# Vision Encoder model

> CNN and MLP encoders for extracting features from images and low-dimensional observations.

In [None]:
#| default_exp models.vision.__init__

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export 
import numpy as np
import torch
import torch.distributions as td
import torch.nn as nn
from fastcore.utils import *
from fastcore.all import *
from torch import nn
from torch.nn import functional as F

In [None]:
#| export
# Conv-stack specifications for MeNet6, keyed by `config.backbone_subclass` and
# turned into layers by `build_conv` (mawm.models.utils).
#
# Plain entries are (in_channels, out_channels, kernel_size, stride, padding):
# the printed MeNet6 for "d4rl_a" shows them mapped one-to-one onto Conv2d layers
# (each followed by GroupNorm + ReLU except the last). The first entry's
# in_channels appears to be overridden by the real input channel count (the
# printed model shows 3 where the table says 6) — confirm in `build_conv`.
#
# Tagged entries ("pad", ...), ("avg_pool", ...) and ("fc", in, out) are also
# interpreted by `build_conv`; presumably an explicit padding op, an average pool
# and a linear layer respectively — TODO confirm against `build_conv`.
ENCODER_LAYERS_CONFIG = {
    # L1: stacks applied directly to the raw observation channels.
    "a": [(2, 32, 5, 1, 0), (32, 32, 4, 2, 0), (32, 32, 3, 1, 0), (32, 16, 1, 1, 0)],
    "b": [(2, 16, 5, 1, 0), (16, 32, 4, 2, 0), (32, 32, 3, 1, 0), (32, 16, 1, 1, 0)],
    "c": [(2, 16, 5, 1, 0), (16, 16, 4, 2, 0), (16, 16, 3, 1, 0)],
    "f": [(2, 16, 5, 1, 0), (16, 16, 5, 2, 0), (16, 16, 5, 1, 2)],
    "g": [(2, 32, 5, 1, 0), (32, 32, 5, 2, 0), (32, 32, 5, 1, 2), (32, 16, 1, 1, 0)],
    "h": [(2, 16, 5, 1, 0), (16, 16, 5, 2, 0), (16, 16, 3, 1, 0)],
    "i": [(2, 16, 5, 1, 0), (16, 16, 5, 2, 0), (16, 16, 3, 1, 1)],
    "i_fc": [
        (2, 16, 5, 1, 0),
        (16, 16, 5, 2, 0),
        (16, 16, 3, 1, 1),
        # NOTE(review): 13456 == 16 * 29 * 29, so this fc only fits one specific
        # input resolution — verify against the data actually used.
        ("fc", 13456, 512),
    ],
    "i_b": [(6, 16, 5, 1, 0), (16, 16, 5, 2, 0), (16, 16, 3, 1, 0)],
    # For a 42x42 input, "d4rl_a" yields a 16 x 15 x 15 map (see the demo cells).
    "d4rl_a": [
        (6, 16, 5, 1, 0),
        (16, 32, 5, 2, 0),
        (32, 32, 3, 1, 0),
        (32, 32, 3, 1, 1),
        (32, 16, 1, 1, 0),
    ],
    "d4rl_b": [
        (6, 16, 5, 1, 0),
        (16, 32, 5, 2, 0),
        (32, 32, 3, 1, 0),
        (32, 32, 3, 1, 1),
        (32, 32, 3, 1, 1),
        (32, 16, 1, 1, 0),
    ],
    "d4rl_c": [
        (6, 16, 5, 1, 0),
        (16, 32, 5, 2, 0),
        (32, 32, 3, 1, 0),
        (32, 32, 3, 1, 1),
        (32, 32, 3, 1, 1),
    ],
    "j": [(2, 32, 5, 1, 0), (32, 32, 5, 2, 0), (32, 32, 3, 1, 1), (32, 16, 1, 1, 0)],
    "k": [(2, 16, 5, 1, 0), (16, 32, 5, 2, 0), (32, 32, 3, 1, 1), (32, 16, 1, 1, 0)],
    # L2: stacks that start from 16 channels (presumably an L1 output — confirm).
    # MeNet6 prepends GroupNorm(4, in_channels) + ReLU when the key contains "l2".
    "d": [(16, 16, 3, 1, 0), (16, 16, 3, 1, 0)],
    "e": [
        ("pad", (0, 1, 0, 1)),
        (16, 16, 3, 1, 0),
        ("avg_pool", 2, 2, 0),
        (16, 16, 3, 1, 0),
    ],
    # Trailing comments below are recorded output shapes for those stacks.
    "l2a": [(16, 16, 5, 1, 2), (16, 16, 5, 2, 2), (16, 16, 3, 1, 1)],  # (8, 16, 15, 15)
    "l2b": [(16, 16, 3, 1, 1), (16, 16, 3, 2, 1), (16, 16, 3, 1, 1)],  # (8, 16, 15, 15)
    "l2c": [(16, 32, 5, 1, 2), (32, 32, 5, 2, 2), (32, 32, 3, 1, 1)],  # (8, 32, 15, 15)
    "l2d": [(16, 32, 3, 1, 1), (32, 32, 3, 2, 1), (32, 32, 3, 1, 1)],  # (8, 32, 15, 15)
    "l2e": [(16, 16, 3, 2, 1), (16, 16, 3, 1, 1)],
}

In [None]:
#| hide
ENCODER_LAYERS_CONFIG['d4rl_a']

[(6, 16, 5, 1, 0),
 (16, 32, 5, 2, 0),
 (32, 32, 3, 1, 0),
 (32, 32, 3, 1, 1),
 (32, 16, 1, 1, 0)]

In [None]:
#| export
class PassThrough(nn.Module):
    """Identity module: ``forward`` hands back exactly the object it was given."""

    def forward(self, x):
        passed = x
        return passed


In [None]:
#| export
from mawm.models.vision.enums import BackboneOutput
class MLPNet(nn.Module):
    """Two-layer MLP backbone over flattened inputs.

    ``x`` is flattened from dim 1 onward, passed through
    ``Linear -> ReLU -> Linear`` and wrapped as ``BackboneOutput(encodings=...)``.

    Args:
        output_dim: width of both the hidden layer and the output encoding.
        input_dim: number of features after flattening all non-batch dims.
            Defaults to ``28 * 28``, the value previously hard-coded.
    """

    def __init__(self, output_dim: int = 64, input_dim: int = 28 * 28):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, output_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(output_dim, output_dim)

    def forward(self, x):
        out = x.flatten(1)  # [BS, input_dim]
        out = self.relu(self.fc1(out))
        out = self.fc2(out)
        return BackboneOutput(encodings=out)

In [None]:
#| export
from mawm.models.vision.enums import BackboneOutput
from mawm.models.vision.base import SequenceBackbone
import torch
import torch.nn as nn

class MeNet5(SequenceBackbone):
    """Three-layer conv backbone with a pooling stage and optional final linear head.

    Args:
        input_dim: shape of one un-batched input, e.g. ``(C, H, W)``; used only
            to probe the flattened feature size for ``self.fc``.
            ``torch.randn(input_dim).unsqueeze(0)`` must produce a conv input,
            so a shape sequence (not a single int) is required.
        output_dim: output width of the final linear layer.
        input_channels: channels expected by the first convolution.
        width_factor: multiplier on the 16/32 conv channel counts.
        conv_out_dim: stored on ``self.conv_out_dim``; not otherwise used here.
        backbone_norm: ``"batch_norm"`` selects ``BatchNorm2d``; any other value
            selects ``GroupNorm(4, channels)``.
        backbone_pool: ``"avg_pool"`` selects ``AvgPool2d(2, stride=2)``; any
            other value (including the default) selects a 1x1 conv to 32 channels.
        backbone_final_fc: if True, ``forward`` flattens and projects to
            ``output_dim``.
    """

    def __init__(
        self,
        input_dim: tuple,
        output_dim: int = 64,
        input_channels: int = 1,
        width_factor: int = 1,
        conv_out_dim: int = 9 * 32,
        backbone_norm: str = "batch_norm",
        backbone_pool: str = "backbone_pool",
        backbone_final_fc: bool = True,
    ):
        super().__init__()
        self.width_factor = width_factor
        self.conv_out_dim = conv_out_dim
        self.backbone_final_fc = backbone_final_fc

        def norm(channels: int) -> nn.Module:
            # "batch_norm" -> BatchNorm2d; anything else -> GroupNorm with 4 groups.
            if backbone_norm == "batch_norm":
                return nn.BatchNorm2d(channels)
            return nn.GroupNorm(4, channels)

        c16, c32 = 16 * width_factor, 32 * width_factor
        # Module order (and hence state_dict indices) matches the original layout.
        self.layer1 = nn.Sequential(
            nn.Conv2d(input_channels, c16, kernel_size=5, stride=2, padding=2),
            nn.ReLU(),
            norm(c16),
            nn.Conv2d(c16, c32, kernel_size=5, stride=2, padding=2),
            nn.ReLU(),
            norm(c32),
            nn.Conv2d(c32, c32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(),
            norm(c32),
        )
        if backbone_pool == "avg_pool":
            self.pool = nn.AvgPool2d(2, stride=2)
        else:
            self.pool = nn.Conv2d(in_channels=c32, out_channels=32, kernel_size=1)

        # Probe the flattened feature size with one random sample. Do it in eval
        # mode under no_grad: in train mode every BatchNorm2d would fold this
        # random sample into running_mean/running_var and bump
        # num_batches_tracked. Each module's train/eval flag is restored afterwards.
        modes = [(module, module.training) for module in self.modules()]
        self.eval()
        try:
            with torch.no_grad():
                sample_input = torch.randn(input_dim).unsqueeze(0)
                sample_output = self.pool(self.layer1(sample_input)).reshape(1, -1)
        finally:
            for module, flag in modes:
                module.training = flag
        if backbone_final_fc:
            self.fc = nn.Linear(sample_output.shape[-1], output_dim)

    def forward(self, x):
        out = self.layer1(x)  # [BS, 32 * width_factor, H', W']
        # channels after pool: 32 * width_factor (avg_pool) or 32 (1x1 conv)
        out = self.pool(out)
        if self.backbone_final_fc:
            out = self.fc(out.reshape(out.size(0), -1))
        return BackboneOutput(encodings=out)


In [None]:
#| export
class ResizeConv2d(nn.Module):
    """Resize-then-convolve layer.

    ``forward`` upsamples/downsamples ``x`` with ``F.interpolate`` (using
    ``scale_factor`` and ``mode``), applies a stride-1 ``Conv2d`` and wraps the
    result as ``BackboneOutput(encodings=...)``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        scale_factor,
        mode="nearest",
        groups=1,
        bias=False,
        padding=1,
    ):
        super().__init__()
        self.scale_factor, self.mode = scale_factor, mode
        conv_options = dict(stride=1, padding=padding, groups=groups, bias=bias)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, **conv_options)

    def forward(self, x):
        resized = F.interpolate(x, scale_factor=self.scale_factor, mode=self.mode)
        return BackboneOutput(encodings=self.conv(resized))


In [None]:
#| export

class Canonical(nn.Module):
    """Fixed conv stack (1 -> 32 -> 64 -> 64 channels) pooled to a ``res x res`` grid.

    ``output_dim`` must equal ``64 * res * res`` for some integer ``res``
    (enforced with an ``assert``); the pooled map is flattened, so each
    encoding has ``output_dim`` features. The first conv expects 1 input channel.
    """

    def __init__(self, output_dim: int = 64):
        super().__init__()
        side = int(np.sqrt(output_dim / 64))
        fits = side * side * 64 == output_dim
        assert fits, "canonical backbone resolution error: cant fit desired output_dim"

        stages = [
            nn.Conv2d(1, 32, 4, stride=2, padding=0),
            nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2, padding=0),
            nn.ReLU(),
            nn.Conv2d(64, 64, 3, stride=1, padding=0),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d((side, side)),
        ]
        self.backbone = nn.Sequential(*stages)

    def forward(self, x):
        features = self.backbone(x)
        return BackboneOutput(encodings=features.flatten(1))

In [None]:
#| export
from mawm.models.misc import  (
    build_mlp,
    Projector,
    MLP,
    build_norm1d,
    PartialAffineLayerNorm,
)

class MLPEncoder(SequenceBackbone):
    """Backbone that runs an ``MLP`` built from ``cfg``.

    ``cfg.backbone_subclass`` is passed as the MLP ``arch`` and
    ``cfg.backbone_norm`` as its ``norm``; the MLP output is returned as
    ``BackboneOutput(encodings=...)``.
    """

    def __init__(self, cfg, input_dim):
        super().__init__()
        arch, norm = cfg.backbone_subclass, cfg.backbone_norm
        self.encoder = MLP(arch=arch, input_dim=input_dim, norm=norm)

    def forward(self, x):
        return BackboneOutput(encodings=self.encoder(x))


In [None]:
#| export
class ObposEncoder1(SequenceBackbone):
    """
    Fused encoder for observation and pos state.
    cat(obs, pos) --> encoder --> final_ln --> encodings

    The MLP input width is ``obs_dim + config.pos_dim``; its output width is the
    last ``"-"``-separated token of ``config.backbone_subclass``. ``final_ln`` is
    ``build_norm1d(config.backbone_norm, out_dim)`` when ``config.final_ln`` is
    truthy, otherwise an identity.
    """

    def __init__(self, config, obs_dim):
        super().__init__()
        self.config = config

        fused_dim = obs_dim + config.pos_dim
        self.encoder = MLP(
            arch=config.backbone_subclass,
            input_dim=fused_dim,
            norm=config.backbone_norm,
            activation="mish",
        )

        final_width = int(config.backbone_subclass.split("-")[-1])
        self.final_ln = (
            build_norm1d(config.backbone_norm, final_width)
            if config.final_ln
            else nn.Identity()
        )

    def forward(self, obs, pos):
        fused = torch.cat((obs, pos), dim=-1)
        return BackboneOutput(encodings=self.final_ln(self.encoder(fused)))

In [None]:
#| export
class ObposEncoder2(SequenceBackbone):
    """
    Disentangled encoder for observation and pos state.
    obs --> obs_encoder --> obs_out
    pos --> pos_encoder --> pos_out
    encodings = final_ln(cat(obs_out, pos_out))
    return: BackboneOutput(encodings, obs_component=obs_out, pos_component=pos_out)

    ``config.backbone_subclass`` must be ``"<obs_arch>,<pos_arch>"``. Each part is
    either ``"id"`` (pass-through) or a layer spec handed to ``build_mlp`` whose
    last ``"-"``-separated token is that branch's output width.
    """

    def __init__(self, config, obs_dim):
        super().__init__()
        self.config = config

        obs_subclass, pos_subclass = config.backbone_subclass.split(",")

        # Build obs before pos so parameter-initialisation order is unchanged.
        self.obs_encoder, obs_out_dim = self._make_branch(
            obs_subclass, obs_dim, config.backbone_norm
        )
        self.pos_encoder, pos_out_dim = self._make_branch(
            pos_subclass, config.pos_dim, config.backbone_norm
        )

        if config.final_ln:
            # Affine params only for branches that have a learned encoder.
            self.final_ln = PartialAffineLayerNorm(
                first_dim=obs_out_dim,
                second_dim=pos_out_dim,
                first_affine=(obs_subclass != "id"),
                second_affine=(pos_subclass != "id"),
            )
        else:
            self.final_ln = nn.Identity()

    @staticmethod
    def _make_branch(subclass, input_dim, norm):
        """Return ``(module, out_dim)`` for one branch; ``"id"`` means pass-through."""
        if subclass == "id":
            return nn.Identity(), input_dim
        module = build_mlp(
            layers_dims=subclass,
            input_dim=input_dim,
            norm=norm,
            activation="mish",
        )
        return module, int(subclass.split("-")[-1])

    def forward(self, obs, pos):
        obs_out = self.obs_encoder(obs)
        pos_out = self.pos_encoder(pos)

        # Concatenate on the feature (last) axis, as ObposEncoder1 does. dim=1 is
        # equivalent for 2-D [BS, D] inputs but joins the wrong axis when inputs
        # carry extra leading dims (e.g. [T, BS, D]).
        next_state = torch.cat([obs_out, pos_out], dim=-1)
        next_state = self.final_ln(next_state)

        return BackboneOutput(
            encodings=next_state,
            obs_component=obs_out,
            pos_component=pos_out,
        )



In [None]:
#| export
import torch
import torch.nn as nn
from mawm.models.utils import Expander2D, build_conv

class MeNet6(nn.Module):
    """Configurable conv encoder with optional broadcast position channels.

    Conv layers come from ``ENCODER_LAYERS_CONFIG[config.backbone_subclass]``
    via ``build_conv``; subclasses whose name contains ``"l2"`` get a
    ``GroupNorm + ReLU`` pre-layer. When ``config.position_dim`` is truthy, a
    position vector is optionally passed through a ReLU MLP
    (``config.position_encoder_arch``, e.g. ``"2-32-8"``; ``"id"`` or empty means
    none). It is then expanded by ``Expander2D`` to the conv output's spatial
    size and concatenated to the conv features along the channel axis.

    Args:
        config: needs ``backbone_subclass`` and ``position_dim``, plus
            ``position_encoder_arch`` when ``position_dim`` is truthy.
        input_dim: shape of one un-batched observation, e.g. ``(C, H, W)``.
    """

    def __init__(
        self,
        config,
        input_dim: int,
    ):
        super().__init__()

        self.config = config
        self.input_dim = input_dim
        subclass = config.backbone_subclass
        layers_config = ENCODER_LAYERS_CONFIG[subclass]

        if "l2" in subclass:
            # Pre-normalise and rectify inputs coming from a previous stage.
            pre_conv = nn.Sequential(nn.GroupNorm(4, layers_config[0][0]), nn.ReLU())
        else:
            pre_conv = nn.Identity()
        conv_layers = build_conv(layers_config, (input_dim[0],))

        self.layers = nn.Sequential(pre_conv, conv_layers)

        if config.position_dim:
            # (C', H', W') of the conv output for a single observation.
            encoder_output_dim = self._probe(
                self.layers, torch.randn(input_dim).unsqueeze(0)
            )

            arch = self.config.position_encoder_arch
            if arch and arch != "id":
                layer_dims = [int(d) for d in arch.split("-")]
                layers = []
                for d_in, d_out in zip(layer_dims[:-1], layer_dims[1:]):
                    layers.append(nn.Linear(d_in, d_out))
                    layers.append(nn.ReLU())
                layers.pop()  # no activation after the last Linear

                self.position_encoder = nn.Sequential(
                    *layers,
                    Expander2D(w=encoder_output_dim[-2], h=encoder_output_dim[-1]),
                )
            else:
                self.position_encoder = Expander2D(
                    w=encoder_output_dim[-2], h=encoder_output_dim[-1]
                )

    @staticmethod
    def _probe(module, *args, **kwargs):
        """Run ``module`` once without autograd in eval mode; return ``out.shape[1:]``.

        Eval mode keeps any BatchNorm running statistics untouched; every
        submodule's previous train/eval flag is restored afterwards.
        """
        modes = [(m, m.training) for m in module.modules()]
        module.eval()
        try:
            with torch.no_grad():
                out = module(*args, **kwargs)
        finally:
            for m, flag in modes:
                m.training = flag
        return tuple(out.shape[1:])

    @property
    def repr_dim(self):
        """Shape ``(C, H, W)`` of one encoded observation, position channels included.

        Computed from an actual forward pass, so the position channel count is
        right whether the position encoder is the identity expander (``position_dim``
        channels) or an MLP (its last layer width).
        """
        sample_obs = torch.randn(self.input_dim).unsqueeze(0)  # [1, C, H, W]
        if self.config.position_dim:
            sample_pos = torch.randn(1, self.config.position_dim)
            return self._probe(self, sample_obs, position=sample_pos)
        return self._probe(self.layers, sample_obs)

    def forward(self, x, position=None):
        """Encode observations and optionally append position channels.

        input:
            x: [BS, C, H, W] (or 2-D) is encoded directly; any other rank is
               treated as [T, BS, C, H, W] or [BS, T, C, H, W] -- the two leading
               dims are flattened for encoding and restored afterwards, in
               whatever order they came.
            position: optional; a 3-D tensor has its two leading dims flattened.
               Requires ``config.position_dim`` to be truthy (otherwise no
               ``position_encoder`` exists).
        output:
            z: same leading dims as ``x``, then [C_obs + C_pos, H', W'].
        """
        has_time = x.dim() not in (2, 4)
        obs = x.flatten(0, 1) if has_time else x  # [T*BS, C, H, W] when has_time
        z = self.layers(obs)

        if position is not None:
            if position.dim() == 3:
                position = position.flatten(0, 1)  # [T*BS, pos_dim]
            z = torch.cat([z, self.position_encoder(position)], dim=1)

        if has_time:
            z = z.reshape(x.shape[:2] + z.shape[1:])
        return z
    


In [None]:
#| export
import torch
import torch.nn as nn
from mawm.models.vision.enums import BackboneConfig
from mawm.models.vision.resnet import resnet18, resnet18ID
from mawm.models.misc import build_mlp,Projector, MLP, build_norm1d, PartialAffineLayerNorm

def _collapse_dim(shape):
    """Return ``shape[0]`` for a 1-element shape tuple, else the tuple unchanged."""
    return shape[0] if len(shape) == 1 else shape


def build_backbone(
    config: BackboneConfig,
    input_dim,
):
    """Build a ``MeNet6`` backbone and annotate it with probed output sizes.

    A random sample of shape ``(1, *input_dim)`` (plus a ``(1, position_dim)``
    position when ``config.position_dim`` is truthy) is pushed through the
    backbone once, without autograd and in eval mode, to set:
    ``input_dim``, ``output_dim``, ``output_obs_dim``, ``output_position_dim``
    and ``config`` on the returned module.

    The backbone may return either a plain tensor (as ``MeNet6`` does) or a
    ``BackboneOutput``-style object with ``encodings`` and optionally
    ``obs_component`` / ``pos_component``.
    """
    backbone = MeNet6(
        config=config,
        input_dim=input_dim,
    )

    if config.backbone_mlp is not None:
        # NOTE(review): the original passed an always-None `embedding` here, and
        # nn.Sequential cannot forward the `position=` kwarg used below — confirm
        # Projector's expected input size and how positions should reach it.
        backbone_mlp = Projector(config.backbone_mlp, None)
        backbone = nn.Sequential(backbone, backbone_mlp)

    backbone.input_dim = input_dim
    sample_input = torch.randn(input_dim).unsqueeze(0)
    # Truthiness matches MeNet6, which only builds `position_encoder` when
    # position_dim is truthy (a position_dim of 0 would otherwise crash).
    sample_position = (
        torch.randn(config.position_dim).unsqueeze(0) if config.position_dim else None
    )

    # Probe without autograd in eval mode, restoring each submodule's flag.
    modes = [(m, m.training) for m in backbone.modules()]
    backbone.eval()
    try:
        with torch.no_grad():
            if sample_position is not None:
                sample_output = backbone(sample_input, position=sample_position)
            else:
                sample_output = backbone(sample_input)
    finally:
        for m, flag in modes:
            m.training = flag

    encodings = getattr(sample_output, "encodings", sample_output)
    output_dim = _collapse_dim(tuple(encodings.shape[1:]))
    backbone.output_dim = output_dim

    pos_component = getattr(sample_output, "pos_component", None)
    if pos_component is not None:
        output_obs_dim = _collapse_dim(tuple(sample_output.obs_component.shape[1:]))
        output_position_dim = _collapse_dim(tuple(pos_component.shape[1:]))
    else:
        output_obs_dim = output_dim
        output_position_dim = 0

    backbone.output_obs_dim = output_obs_dim
    backbone.output_position_dim = output_position_dim

    backbone.config = config

    return backbone

In [None]:
#| hide
from omegaconf import OmegaConf


In [None]:
#| hide
cfg = OmegaConf.load("../cfgs/MPCJepa/mpc.yaml")

cfg.model.backbone

{'arch': 'MeNet6', 'backbone_subclass': 'd4rl_a', 'backbone_mlp': None, 'backbone_width_factor': 2, 'input_dim': 4, 'channels': 3, 'position_dim': 2, 'position_encoder_arch': 'id'}

In [None]:
#| hide
model = MeNet6(
    config=cfg.model.backbone,
    input_dim=(3, cfg.model.img_size, cfg.model.img_size),
)

In [None]:
model

MeNet6(
  (layers): Sequential(
    (0): Identity()
    (1): Sequential(
      (0): Conv2d(3, 16, kernel_size=(5, 5), stride=(1, 1))
      (1): GroupNorm(4, 16, eps=1e-05, affine=True)
      (2): ReLU()
      (3): Conv2d(16, 32, kernel_size=(5, 5), stride=(2, 2))
      (4): GroupNorm(8, 32, eps=1e-05, affine=True)
      (5): ReLU()
      (6): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1))
      (7): GroupNorm(8, 32, eps=1e-05, affine=True)
      (8): ReLU()
      (9): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (10): GroupNorm(8, 32, eps=1e-05, affine=True)
      (11): ReLU()
      (12): Conv2d(32, 16, kernel_size=(1, 1), stride=(1, 1))
    )
  )
  (position_encoder): Expander2D()
)

In [None]:
#| hide
T = 8
BS = 16
inp = torch.randn(BS, T, 3, 42, 42)
pos = torch.randn(BS, T, 2)

In [None]:
#| hide
out = model(inp, position=pos)
out.shape

torch.Size([16, 8, 18, 15, 15])

In [None]:
#| hide
model.repr_dim

(18, 15, 15)

### MSG Encoder

In [None]:
#| export
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
class SemanticEncoder(nn.Module):
    """Encode a sequence of small multi-channel grids into per-step latent vectors.

    ``forward`` takes ``x`` of shape ``[B, T, num_primitives, grid_size, grid_size]``
    and returns ``[B, T, latent_dim]``; every (b, t) grid is encoded independently
    by ``self.net``.

    Args:
        num_primitives: input channels per grid.
        latent_dim: size of the output vector per time step.
        grid_size: side length of each square grid. Defaults to 7, which the
            original hard-coded through ``16 * 3 * 3``.
    """

    def __init__(self, num_primitives=5, latent_dim=32, grid_size=7):
        super().__init__()  # must run before assigning attributes on an nn.Module
        self.latent_dim = latent_dim
        # conv1 (k=3, padding=1) keeps the side length; conv2 (k=3, stride=2,
        # no padding) maps side s -> (s - 3) // 2 + 1, e.g. 7 -> 3.
        reduced = (grid_size - 3) // 2 + 1
        self.net = nn.Sequential(
            nn.Conv2d(num_primitives, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 16, kernel_size=3, stride=2),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(16 * reduced * reduced, 64),
            nn.ReLU(),
            nn.Linear(64, latent_dim),
        )

    def forward(self, x):
        batch, steps, _, _, _ = x.shape  # enforces a 5-D input
        encoded = self.net(x.flatten(0, 1))  # [B*T, latent_dim]
        return encoded.reshape(batch, steps, encoded.shape[-1])


In [None]:
#| hide
model = SemanticEncoder(num_primitives=5, latent_dim=32)
inp = torch.randn(16, 8, 5, 7, 7)
out = model(inp)
out.shape

torch.Size([16, 8, 32])

In [None]:
#| hide
from mawm.models.utils import flatten_conv_output
out_flat = flatten_conv_output(out)
out_flat.shape

torch.Size([16, 8, 32])

In [None]:
#| hide
from mawm.models.utils import Expander2D
inp = out_flat
m = Expander2D(w= 15, h=15)
out = m(inp)
out.shape

torch.Size([16, 8, 32, 15, 15])

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()