# Models utilities module

> Miscellaneous building blocks for the world model: MLP, projector and prober builders, normalization/activation factories, the JEPA projection heads, and the message (`MsgPred`) and observation (`ObsPred`) predictors.

In [None]:
#| default_exp models.misc

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
from fastcore import *
from fastcore.utils import *
import torch

In [None]:
#| export
from typing import List, Union

import torch
from torch import nn
from torch.nn import functional as F
from mawm.models.utils import *


def build_projector(arch: str, embedding: int):
    """Build a projection head for ``embedding``-dimensional inputs.

    Args:
        arch: ``"id"`` for an identity map, otherwise a dash-separated list of
            layer widths (e.g. ``"512-256"``) forwarded to ``build_mlp``.
        embedding: input feature dimension.

    Returns:
        ``(module, output_dim)``: the projector and the size of its output.
    """
    if arch != "id":
        dims = [embedding, *(int(width) for width in arch.split("-"))]
        return build_mlp(dims), dims[-1]
    return nn.Identity(), embedding


def build_norm1d(norm: str, dim: int):
    """Return a 1-D normalization layer over ``dim`` features.

    Args:
        norm: ``"batch_norm"`` (``BatchNorm1d``) or ``"layer_norm"`` (``LayerNorm``).
        dim: number of features to normalize.

    Raises:
        ValueError: if ``norm`` is not one of the supported names.
    """
    factories = {
        "batch_norm": torch.nn.BatchNorm1d,
        "layer_norm": torch.nn.LayerNorm,
    }
    factory = factories.get(norm)
    if factory is None:
        raise ValueError(f"Unknown norm {norm}")
    return factory(dim)


def build_activation(activation: str):
    """Return an in-place activation module by name.

    Args:
        activation: ``"relu"`` or ``"mish"``.

    Raises:
        ValueError: if ``activation`` is not one of the supported names.
    """
    constructors = {"relu": nn.ReLU, "mish": nn.Mish}
    if activation not in constructors:
        raise ValueError(f"Unknown activation {activation}")
    # The positional True is the modules' ``inplace`` flag.
    return constructors[activation](True)


class PartialAffineLayerNorm(nn.Module):
    """LayerNorm applied separately to two slices of the last dimension.

    The first ``first_dim`` features and all remaining features are each
    normalized by their own ``nn.LayerNorm``, then concatenated back together.
    Each half can independently have or lack learnable affine parameters.

    Args:
        first_dim: size of the leading slice of the last dimension.
        second_dim: size of the trailing slice (used to build its LayerNorm).
        first_affine: whether the first LayerNorm has elementwise affine params.
        second_affine: whether the second LayerNorm has elementwise affine params.
    """

    def __init__(
        self,
        first_dim: int,
        second_dim: int,
        first_affine: bool = True,
        second_affine: bool = True,
    ):
        super().__init__()
        self.first_dim = first_dim
        self.second_dim = second_dim
        self.first_ln = nn.LayerNorm(first_dim, elementwise_affine=bool(first_affine))
        self.second_ln = nn.LayerNorm(second_dim, elementwise_affine=bool(second_affine))

    def forward(self, x):
        """Normalize ``x[..., :first_dim]`` and ``x[..., first_dim:]`` separately."""
        split = self.first_dim
        head, tail = x[..., :split], x[..., split:]
        return torch.cat([self.first_ln(head), self.second_ln(tail)], dim=-1)


def build_mlp(
    layers_dims: Union[List[int], str],
    input_dim: int = None,
    output_shape: int = None,
    norm="batch_norm",
    activation="relu",
    pre_actnorm=False,
    post_norm=False,
):
    """Build a multi-layer perceptron as an ``nn.Sequential``.

    Every hidden layer is ``Linear -> [norm] -> [activation]``; the final layer
    is a bare ``Linear``.

    Args:
        layers_dims: layer widths, either a sequence of ints or a dash-separated
            string such as ``"512-256"`` (``""`` means no widths).
        input_dim: if given, prepended to ``layers_dims``.
        output_shape: if given, appended to ``layers_dims``.
        norm: ``"batch_norm"``, ``"layer_norm"`` or ``None`` for no normalization.
        activation: ``"relu"``, ``"mish"`` or ``None`` for no activation.
        pre_actnorm: if True, apply norm/activation to the input before the
            first ``Linear``.
        post_norm: if True, apply norm after the final ``Linear``. This is skipped
            when ``norm`` is ``None``, matching how ``pre_actnorm`` treats it.

    Returns:
        ``nn.Sequential`` with the layers in order.

    Raises:
        ValueError: if fewer than two widths remain, so no ``Linear`` can be built.
    """
    if isinstance(layers_dims, str):
        layers_dims = (
            list(map(int, layers_dims.split("-"))) if layers_dims != "" else []
        )
    else:
        # Copy into a list: tuples then work with list concatenation below, and
        # the caller's sequence is never aliased.
        layers_dims = list(layers_dims)

    if input_dim is not None:
        layers_dims = [input_dim] + layers_dims

    if output_shape is not None:
        layers_dims = layers_dims + [output_shape]

    if len(layers_dims) < 2:
        raise ValueError(
            "build_mlp needs at least two layer widths (input and output), "
            f"got {layers_dims}"
        )

    layers = []

    if pre_actnorm:
        if norm is not None:
            layers.append(build_norm1d(norm, layers_dims[0]))
        if activation is not None:
            layers.append(build_activation(activation))

    # Hidden layers: every consecutive width pair except the last one.
    for fan_in, fan_out in zip(layers_dims[:-2], layers_dims[1:-1]):
        layers.append(nn.Linear(fan_in, fan_out))
        if norm is not None:
            layers.append(build_norm1d(norm, fan_out))
        if activation is not None:
            layers.append(build_activation(activation))

    layers.append(nn.Linear(layers_dims[-2], layers_dims[-1]))

    if post_norm and norm is not None:
        layers.append(build_norm1d(norm, layers_dims[-1]))

    return nn.Sequential(*layers)


class MLP(torch.nn.Module):
    """Module wrapper around ``build_mlp``.

    Unlike ``build_mlp`` (whose default norm is ``"batch_norm"``), this wrapper
    defaults to ``norm=None``, i.e. no normalization layers.

    Args:
        arch: layer widths as a dash-separated string (or list), see ``build_mlp``.
        input_dim: optional width prepended to ``arch``.
        output_shape: optional width appended to ``arch``.
        norm: normalization name or ``None``.
        activation: activation name or ``None``.
    """

    def __init__(
        self,
        arch: str,
        input_dim: int = None,
        output_shape: int = None,
        norm=None,
        activation="relu",
    ):
        super().__init__()

        options = dict(
            input_dim=input_dim,
            output_shape=output_shape,
            norm=norm,
            activation=activation,
        )
        self.mlp = build_mlp(arch, **options)

    def forward(self, x):
        """Apply the wrapped ``nn.Sequential`` to ``x``."""
        mlp = self.mlp
        return mlp(x)


# Layer specifications for the convolutional prober, keyed by `arch_subclass`.
# `Prober` passes the selected list to `build_conv` (from mawm.models.utils).
# NOTE(review): the tuple layouts are interpreted by build_conv, which is not
# visible here. They appear to be (in_ch, out_ch, kernel, stride, padding) for
# convolutions, ("max_pool", kernel, stride, padding) for pooling and
# ("fc", in_features, out_features) for the output head, with -1 presumably
# meaning "infer from the input / previous layer". Confirm against build_conv.
PROBER_CONV_LAYERS_CONFIG = {
    "a": [
        (-1, 16, 3, 1, 1),
        ("max_pool", 2, 2, 0),
        (16, 8, 1, 1, 0),
        ("max_pool", 2, 2, 0),
        ("fc", -1, 2),
    ],
    "b": [
        (-1, 32, 3, 1, 1),
        ("max_pool", 2, 2, 0),
        (32, 32, 3, 1, 1),
        ("max_pool", 2, 2, 0),
        (32, 32, 3, 1, 1),
        ("fc", -1, 2),
    ],
    "c": [
        (-1, 32, 3, 1, 1),
        ("max_pool", 2, 2, 0),
        (32, 32, 3, 1, 1),
        ("max_pool", 2, 2, 0),
        (32, 32, 3, 1, 1),
        (32, 32, 3, 1, 1),
        ("fc", -1, 2),
    ],
}


class Prober(torch.nn.Module):
    """Probe head mapping an embedding to ``output_shape`` outputs.

    If ``arch == "conv"``, the prober is built by ``build_conv`` from
    ``PROBER_CONV_LAYERS_CONFIG[arch_subclass]``. Otherwise ``arch`` is a
    dash-separated string of hidden widths (``""`` for a single linear layer),
    and the prober is ``Linear`` layers with in-place ReLU between them.

    Args:
        embedding: input feature size for the MLP variant.
        arch: ``"conv"`` or a dash-separated hidden-width string.
        output_shape: number of output features.
        input_dim: forwarded to ``build_conv`` for the conv variant.
        arch_subclass: key into ``PROBER_CONV_LAYERS_CONFIG`` for the conv variant.
    """

    def __init__(
        self,
        embedding: int,
        arch: str,
        output_shape: int,
        input_dim=None,
        arch_subclass: str = "a",
    ):
        super().__init__()
        self.output_shape = output_shape
        self.arch = arch

        if arch == "conv":
            self.prober = build_conv(
                PROBER_CONV_LAYERS_CONFIG[arch_subclass], input_dim=input_dim
            )
            return

        hidden = [int(width) for width in arch.split("-")] if arch != "" else []
        widths = [embedding, *hidden, self.output_shape]

        modules = []
        # Every consecutive width pair except the last gets Linear + ReLU.
        for fan_in, fan_out in zip(widths[:-2], widths[1:-1]):
            modules += [torch.nn.Linear(fan_in, fan_out), torch.nn.ReLU(True)]
        modules.append(torch.nn.Linear(widths[-2], widths[-1]))
        self.prober = torch.nn.Sequential(*modules)

    def forward(self, e):
        """Run the prober; the MLP variant first passes ``e`` through ``flatten_conv_output``."""
        if self.arch != "conv":
            e = flatten_conv_output(e)
        return self.prober(e)


class Projector(torch.nn.Module):
    """Projection head built by ``build_projector``.

    Args:
        arch: ``"id"`` for an identity map, otherwise dash-separated layer widths.
        embedding: input feature dimension.
        random: if True, every parameter is frozen (``requires_grad=False``) and
            ``maybe_reinit`` re-draws the weight matrices.
    """

    def __init__(self, arch: str, embedding: int, random: bool = False):
        super().__init__()

        self.arch = arch
        self.embedding = embedding
        self.random = random

        self.model, self.output_dim = build_projector(arch, embedding)

        if self.random:
            for param in self.parameters():
                param.requires_grad = False

    def maybe_reinit(self):
        """Re-draw the weight matrices with Xavier-uniform init for random projectors.

        Does nothing unless ``random`` is set and ``arch`` is not ``"id"``.
        ``xavier_uniform_`` needs a fan-in/fan-out and raises ValueError on 1-D
        tensors (biases, normalization affine params), so only parameters with
        two or more dimensions are re-drawn. Random projectors are frozen, so the
        1-D parameters presumably still hold their construction-time values.
        """
        if not self.random or self.arch == "id":
            return
        for param in self.parameters():
            if param.dim() >= 2:
                torch.nn.init.xavier_uniform_(param)
        print("initialized")

    def forward(self, x: torch.Tensor):
        """Apply the projection head to ``x``."""
        return self.model(x)

{'CellEmpty': 0, 'CellObstacle': 1, 'CellItem': 2, 'CellGoal': 3, 'CellAgent': 4, 'GoalAt': 5, 'ItemAt': 6, 'Near': 7, 'SeeGoal': 8, 'CanMove': 9, 'OtherAgentAt': 10, 'OtherAgentNear': 11, 'OtherAgentDirection': 12}


In [None]:
#| export
import torch
from torch import nn
from einops import rearrange
class JepaProjector(nn.Module):
    """Project sender latents and messages into a shared latent space.

    Each input gets its own ``Linear -> BatchNorm1d -> ReLU -> Linear`` head.
    Both heads end in ``out_dim`` features, the shared latent space.

    Args:
        z_input_dim: flattened size (C*H*W) of one sender latent.
        c_input_dim: size of one message vector.
        hidden_dim: width of the hidden layer of both heads.
        out_dim: size of the shared latent space.
    """

    def __init__(self, z_input_dim=4050, c_input_dim=32, hidden_dim=2048, out_dim=128):
        super().__init__()
        # Built in this order (z first) so parameter init and state_dict keys
        # match the original layout.
        self.z_projector = self._make_head(z_input_dim, hidden_dim, out_dim)
        self.msg_projector = self._make_head(c_input_dim, hidden_dim, out_dim)

    @staticmethod
    def _make_head(in_dim, hidden_dim, out_dim):
        """Build one Linear-BatchNorm-ReLU-Linear projection head."""
        return nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, out_dim),
        )

    @staticmethod
    def _time_major(head, x):
        """Apply ``head`` to every (b, t) item of a (B, T, ...) tensor; return (T, B, out)."""
        x = x.transpose(0, 1)  # (T, B, ...)
        steps, batch = x.shape[:2]
        # Rows are ordered t-major (row = t * B + b), i.e. einops' '(t b)'.
        out = head(x.reshape(steps * batch, -1))
        return out.reshape(steps, batch, -1)

    def forward(self, z_sender, C):
        """Project latents and messages, dropping the final time step.

        Args:
            z_sender: sender latents of shape (B, T, C, H, W); each frame is
                flattened to C*H*W, which must equal ``z_input_dim``.
            C: messages of shape (B, T, D) with D == ``c_input_dim``.

        Returns:
            ``(proj_z, proj_c)``, each of shape (T - 1, B, out_dim). Both are
            time-major, with the last time step removed.
        """
        proj_z = self._time_major(self.z_projector, z_sender)
        proj_c = self._time_major(self.msg_projector, C)
        return proj_z[:-1], proj_c[:-1]


In [None]:
#| hide
from mawm.models.jepa import JEPA
from omegaconf import OmegaConf
# Load the MPC-JEPA config and build a model, used here to read off the
# backbone's per-frame latent shape.
cfg = OmegaConf.load("../cfgs/MPCJepa/mpc.yaml")

# input_dim=(3, 42, 42) matches the dummy inputs fed to the backbone further below.
model = JEPA(cfg.model,input_dim=(3, 42, 42), action_dim= 5)
# Per-frame latent shape (C, H, W); the recorded output is (16, 15, 15).
model.backbone.repr_dim


(16, 15, 15)

In [None]:

#| hide
from mawm.models.vision import SemanticEncoder
# Message encoder; its latent_dim (32) is used as JepaProjector's c_input_dim below.
msg_enc = SemanticEncoder(latent_dim = 32)
msg_enc.latent_dim

32

In [None]:
#| hide
# Smoke test for JepaProjector with random messages and latents.
T = 8
B = 16
d_c = 32
C = torch.randn(B, T, d_c)
# Per-frame latents shaped like the backbone output, (16, 15, 15).
z = torch.randn(B, T, 16, 15, 15)
from functools import reduce
# Flattened latent size C*H*W (16*15*15 = 3600 for repr_dim (16, 15, 15)).
prod = reduce(lambda x, y: x * y, model.backbone.repr_dim)
proj = JepaProjector(z_input_dim=prod, c_input_dim=msg_enc.latent_dim)
out_z, out_c = proj(z, C)
# Both outputs are (T - 1, B, 128): time-major, with the last step dropped.
out_z.shape, out_c.shape

torch.Size([128, 32])


(torch.Size([7, 16, 128]), torch.Size([7, 16, 128]))

In [None]:
#| hide
from omegaconf import OmegaConf


In [None]:
#| hide
cfg = OmegaConf.load("../cfgs/MPCJepa/mpc.yaml")

# Inspect the backbone section of the model config (recorded output below).
cfg.model.backbone

{'arch': 'MeNet6', 'backbone_subclass': 'd4rl_a', 'backbone_mlp': None, 'backbone_width_factor': 2, 'input_dim': 4, 'channels': 3, 'position_dim': 0, 'position_encoder_arch': 'id'}

In [None]:
#| hide
# Dummy batch: BS sequences of T frames, each of shape (3, 42, 42).
T = 8
BS = 16
inp = torch.randn(BS, T, 3, 42, 42)
# Random (BS, T, 2) tensor; the next cell calls the backbone with position=None,
# so it is not used there.
pos = torch.randn(BS, T, 2)

In [None]:
#| hide
# Encode the dummy batch without position; recorded output is (BS, T, 16, 15, 15).
z = model.backbone(inp, position=None)
z.shape

torch.Size([16, 8, 16, 15, 15])

A message predictor predicts the message from the latent representation of an observation, so its input has shape
$$\mathbf{z} \in \mathbb{R}^{B \times C \times H \times W}$$
where $C$ is the number of channels of the observation latent and $H$ and $W$ are its height and width. (`MsgPred` also accepts a leading time axis, $B \times T \times C \times H \times W$.)

The complication is that the observation encoder outputs one of two shapes, depending on whether the position is passed. If it is passed, the output shape is
$$\mathbf{z} \in \mathbb{R}^{B \times 18 \times 15 \times 15}$$
Otherwise, the output shape is

$$\mathbf{z} \in \mathbb{R}^{B \times 16 \times 15 \times 15}$$
For this reason, `MsgPred` takes the number of input channels as its `in_channels` argument.

In [None]:
#| export
import torch
import torch.nn as nn
class MsgPred(nn.Module):
    """Predict a message vector from an observation latent.

    Two unpadded 3x3 stride-2 convolutions, flatten, then a two-layer MLP.

    Args:
        h_dim: size of the predicted message vector.
        in_channels: channels of the input latent (16 without position, 18 with).
        input_hw: spatial size (H, W) of the latent, or an int for square inputs.
            The default (15, 15) gives the original 32 * 3 * 3 flatten size.

    Raises:
        ValueError: if ``input_hw`` is too small for the two convolutions.
    """

    def __init__(self, h_dim=32, in_channels=16, input_hw=(15, 15)):
        super().__init__()
        if isinstance(input_hw, int):
            input_hw = (input_hw, input_hw)
        height, width = input_hw
        out_h = self._conv_out(self._conv_out(height))
        out_w = self._conv_out(self._conv_out(width))
        if out_h < 1 or out_w < 1:
            raise ValueError(f"input_hw={tuple(input_hw)} is too small for MsgPred")

        self.net = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=3, stride=2),  # (15, 15) -> (7, 7)
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=3, stride=2),  # (7, 7) -> (3, 3)
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(32 * out_h * out_w, 64),
            nn.ReLU(),
            nn.Linear(64, h_dim),
        )

    @staticmethod
    def _conv_out(size, kernel=3, stride=2):
        """Output length of an unpadded convolution along one spatial axis."""
        return (size - kernel) // stride + 1

    def forward(self, z):
        """Map latents of shape (..., C, H, W) to messages of shape (..., h_dim).

        Accepts (B, C, H, W), (B, T, C, H, W), or any number of leading dims.
        """
        lead = z.shape[:-3]
        # Explicit row count (not -1) so empty batches reshape cleanly.
        flat = z.reshape(lead.numel(), *z.shape[-3:])
        out = self.net(flat)
        return out.reshape(*lead, out.shape[-1])

In [None]:
#| hide
from einops import rearrange
# Shape check: a (B, T, 16, 15, 15) latent batch maps to (B, T, h_dim) messages.
inp = torch.randn(64, 8, 16, 15, 15)
model = MsgPred(h_dim=32, in_channels=16)
out = model(inp)
out.shape

torch.Size([64, 8, 32])

In [None]:
#| export
import torch
import torch.nn as nn
class ObsPred(nn.Module):
    """Decode a hidden vector into an (out_channels, 15, 15) observation latent.

    An MLP maps ``h_dim`` to 16*7*7 features, which are reshaped to (16, 7, 7).
    A stride-2 transposed convolution then upsamples to 15x15, and a 1x1
    convolution maps the channels to ``out_channels``.

    Args:
        h_dim: size of the input hidden vector.
        out_channels: channels of the predicted latent (default 18).
    """

    def __init__(self, h_dim=32, out_channels= 18):
        super().__init__()
        self.fc = nn.Sequential(
            nn.Linear(h_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 16 * 7 * 7),
            nn.ReLU(),
        )
        self.upsample = nn.Sequential(
            # (N, 16, 7, 7) -> (N, 32, 15, 15), since (7 - 1) * 2 + 3 = 15
            nn.ConvTranspose2d(16, 32, kernel_size=3, stride=2),
            nn.ReLU(),
            # (N, 32, 15, 15) -> (N, out_channels, 15, 15)
            nn.Conv2d(32, out_channels, kernel_size=1),
        )

    def forward(self, h):
        """Map hidden vectors (..., h_dim) to latents (..., out_channels, 15, 15).

        Accepts (B, T, h_dim), (B, h_dim), or any number of leading dims.
        """
        lead = h.shape[:-1]
        # Explicit row count (not -1) so empty batches reshape cleanly.
        z = self.fc(h.reshape(lead.numel(), h.shape[-1]))
        z = self.upsample(z.view(z.shape[0], 16, 7, 7))
        return z.reshape(*lead, *z.shape[1:])
    

In [None]:
#| hide
from einops import rearrange
# Shape check: (B, T, h_dim) hidden vectors decode to (B, T, 18, 15, 15)
# latents with the default out_channels.
model = ObsPred(h_dim=32)
inp = torch.randn(16, 8, 32)
out = model(inp)
out.shape

torch.Size([16, 8, 18, 15, 15])

In [None]:
#| hide
# Regenerate the library modules from the notebooks' `#| export` cells via nbdev.
import nbdev; nbdev.nbdev_export()