# Vision encoder

> A ConvNet module for perception.

In [None]:
#| default_exp models.encoder

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export
from fastcore import *
from fastcore.utils import *

In [None]:
#| export
import torch
import torch.nn as nn
import torch.nn.functional as F

class VisionEncoder(nn.Module):
    """Strided ConvNet that embeds every view of a multi-view image batch.

    The encoder stacks four ``Conv2d(k=3, s=2, p=1) -> BatchNorm2d -> LeakyReLU``
    stages (32, 64, 128, 256 channels), flattens the result and projects it to
    ``latent_dim`` with a single linear layer.

    Args:
        in_channels: Number of channels of the input images.
        latent_dim: Size of the embedding produced for each view.
        input_size: Spatial size of the input images, either an ``int`` for
            square images or a ``(height, width)`` pair. The default of 42
            gives the path 42 -> 21 -> 11 -> 6 -> 3, i.e. 256 * 3 * 3 = 2304
            flattened features.

    Raises:
        ValueError: If ``input_size`` contains a non-positive dimension.
    """

    def __init__(self, in_channels=3, latent_dim=128, input_size=42):
        super().__init__()

        self.latent_dim = latent_dim
        hidden_dims = [32, 64, 128, 256]

        if isinstance(input_size, int):
            height, width = input_size, input_size
        else:
            height, width = input_size
        if height <= 0 or width <= 0:
            raise ValueError(f"input_size must be positive, got {input_size!r}")

        # -----------------------
        #        Encoder
        # -----------------------
        modules = []
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, h_dim, kernel_size=3,
                              stride=2, padding=1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU()
                )
            )
            in_channels = h_dim
            # Output size of a k=3, s=2, p=1 convolution: floor((n - 1) / 2) + 1.
            height = (height - 1) // 2 + 1
            width = (width - 1) // 2 + 1

        self.encoder = nn.Sequential(*modules)

        # Derived from input_size instead of assuming a fixed 3x3 bottleneck.
        self.flattened_dim = hidden_dims[-1] * height * width

        self.proj = nn.Linear(self.flattened_dim, latent_dim)

    def forward(self, x):
        """Embed a batch of multi-view images.

        Args:
            x: Tensor of shape ``(N, V, C, H, W)`` with ``N`` samples and
                ``V`` views per sample.

        Returns:
            Tensor of shape ``(V, N, latent_dim)`` (views first).

        Raises:
            ValueError: If ``x`` is not 5-dimensional.
        """
        if x.dim() != 5:
            raise ValueError(
                f"expected input of shape (N, V, C, H, W), got {tuple(x.shape)}"
            )
        N, V = x.shape[:2]
        # Fold the view axis into the batch axis so the ConvNet sees N * V images.
        x = self.encoder(x.flatten(0, 1))
        x = torch.flatten(x, start_dim=1)
        proj = self.proj(x)
        return proj.reshape(N, V, -1).transpose(0, 1)


In [None]:
#| hide
from MAWM.data.loaders import RolloutObservationDataset

# | hide
from torchvision import transforms
import numpy as np

# Size constants. None of them is referenced in the visible cells of this
# notebook — presumably leftovers from another script (TODO confirm before removing).
ASIZE, LSIZE, RSIZE, RED_SIZE, SIZE =\
    3, 32, 256, 32, 40


import torch
from torchvision.transforms import v2
# Training-time augmentation pipeline: random crop / colour / grayscale / flip
# on a PIL image, then conversion to a float tensor normalised per channel.
train_tf = v2.Compose(
    [
        v2.ToPILImage(),
        # Keep the crop size at 42 to match the model architecture
        v2.RandomResizedCrop(42, scale=(0.8, 1.0)), 
        v2.RandomApply([v2.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
        v2.RandomGrayscale(p=0.2),
        # Gaussian blur (with a reduced kernel size for the small resolution) is
        # currently disabled:
        # v2.RandomApply([v2.GaussianBlur(kernel_size=3, sigma=(0.1, 2.0))], p=0.1),
        v2.RandomHorizontalFlip(),
        v2.ToImage(),
        v2.ToDtype(torch.float32, scale=True),
        # With mean = std = 0.5 this maps scaled [0, 1] values to [-1, 1],
        # the range of a Tanh output
        v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)

# Rollout observations of `agent_0` read from ../marl_grid_data/, with the
# training augmentations passed as `transform`.
# NOTE(review): the meaning of `buffer_size` and of `load_next_buffer()` is
# defined by RolloutObservationDataset, which is not visible here — presumably
# 10 rollout files are loaded into memory at a time; confirm against the loader.
dataset = RolloutObservationDataset(
    agent='agent_0',
    root='../marl_grid_data/',
    transform=train_tf,
    buffer_size=10,
    train=True
)
dataset.load_next_buffer()

Loading file buffer ...: 100%|██████████| 10/10 


In [None]:
#| hide
class HFDataset(torch.utils.data.Dataset):
    """Wrap the notebook-level ``dataset`` and return ``V`` views per sample.

    With ``V > 1`` every view is produced by the augmentation pipeline
    (``train_tf``); otherwise the deterministic resize / center-crop pipeline
    stored in ``self.test`` is used.

    Args:
        V: Number of views stacked for each sample.
    """

    def __init__(self, V=1):
        self.V = V
        self.ds = dataset
        self.aug = train_tf
        # Evaluation pipeline: same tensor conversion and normalisation as the
        # augmentation pipeline, but without any random transform.
        eval_steps = [
            v2.ToPILImage(),
            v2.Resize(42),
            v2.CenterCrop(42),
            v2.ToImage(),
            v2.ToDtype(torch.float32, scale=True),
            v2.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ]
        self.test = v2.Compose(eval_steps)

    def __len__(self):
        """Return the number of samples in the wrapped dataset."""
        return len(self.ds)

    def __getitem__(self, i):
        """Return the ``V`` transformed views of sample ``i`` stacked on axis 0."""
        # Only the first element of the wrapped sample (the image) is used.
        img = self.ds[i][0]
        if self.V > 1:
            pipeline = self.aug
        else:
            pipeline = self.test
        views = [pipeline(img) for _ in range(self.V)]
        return torch.stack(views)

In [None]:
#| hide
# V=2 views per sample, so HFDataset uses the augmentation pipeline for both.
ds = HFDataset(V = 2)
from torch.utils.data import DataLoader

# Shuffled batches of 16 samples; drop_last=True discards a final partial batch.
loader = DataLoader(ds, batch_size = 16, shuffle=True, drop_last=True)

In [None]:
#| hide
# Sanity check: flattening the batch and view axes merges (16, 2) into 32.
x = torch.randn(16, 2, 3, 42, 42)
x.flatten(0, 1).shape

torch.Size([32, 3, 42, 42])

In [None]:
#| hide
# Smoke test: push the first batch of the loader through a fresh encoder.
model = VisionEncoder()
# x = torch.randn(4, 2, 3, 42, 42)
for i, data in enumerate(loader):
    y = model(data)
    # NOTE: this bare expression has no effect inside the loop; the shape is
    # inspected in the next cell instead.
    y.shape
    break

In [None]:
# Encoder output is views-first: (V, N, latent_dim).
y.shape

torch.Size([2, 16, 128])

## VAE Encoder

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class VAE2(nn.Module):
    """Convolutional variational auto-encoder for 42x42 images.

    The encoder is four ``Conv2d(k=3, s=2, p=1) -> BatchNorm2d -> LeakyReLU``
    stages (32, 64, 128, 256 channels) followed by two linear heads for the
    mean and log-variance of the latent distribution. The decoder mirrors it
    with transposed convolutions and ends in a ``Tanh``, so reconstructions
    lie in ``[-1, 1]``.

    Args:
        in_channels: Number of channels of the input images; reconstructions
            have the same number of channels.
        latent_dim: Size of the latent vector.
    """

    def __init__(self, in_channels=3, latent_dim=128):
        super().__init__()

        self.latent_dim = latent_dim
        hidden_dims = [32, 64, 128, 256]
        # The encoder loop below rebinds `in_channels`, so remember the image
        # channel count for the reconstruction head.
        image_channels = in_channels

        # -----------------------
        #        Encoder
        # -----------------------
        # Path: 42x42 -> 21x21 -> 11x11 -> 6x6 -> 3x3
        modules = []
        for h_dim in hidden_dims:
            modules.append(
                nn.Sequential(
                    nn.Conv2d(in_channels, h_dim, kernel_size=3,
                              stride=2, padding=1),
                    nn.BatchNorm2d(h_dim),
                    nn.LeakyReLU()
                )
            )
            in_channels = h_dim

        self.encoder = nn.Sequential(*modules)

        # Bottleneck feature map for 42x42 inputs: (256, 3, 3) -> 2304 features.
        self.bottleneck_shape = (hidden_dims[-1], 3, 3)
        self.flattened_dim = hidden_dims[-1] * 3 * 3

        self.fc_mu = nn.Linear(self.flattened_dim, latent_dim)
        self.fc_var = nn.Linear(self.flattened_dim, latent_dim)

        # -----------------------
        #        Decoder
        # -----------------------
        self.decoder_input = nn.Linear(latent_dim, self.flattened_dim)

        # Reverse hidden dims for decoder: [256, 128, 64, 32]
        hidden_dims = hidden_dims[::-1]

        # Output paddings that undo the encoder's rounding, one per stage of
        # `self.decoder`:
        #   3x3   -> 6x6   (output_padding=1)
        #   6x6   -> 11x11 (output_padding=0)
        #   11x11 -> 21x21 (output_padding=0)
        # The last step, 21x21 -> 42x42 (output_padding=1), is in `final_layer`.
        out_paddings = [1, 0, 0]

        modules = []
        for c_in, c_out, out_pad in zip(hidden_dims, hidden_dims[1:], out_paddings):
            modules.append(
                nn.Sequential(
                    nn.ConvTranspose2d(c_in,
                                       c_out,
                                       kernel_size=3,
                                       stride=2,
                                       padding=1,
                                       output_padding=out_pad),
                    nn.BatchNorm2d(c_out),
                    nn.LeakyReLU()
                )
            )

        self.decoder = nn.Sequential(*modules)

        # Final layer: 21x21 -> 42x42, then map back to the image channels.
        self.final_layer = nn.Sequential(
            nn.ConvTranspose2d(hidden_dims[-1],
                               hidden_dims[-1],
                               kernel_size=3,
                               stride=2,
                               padding=1,
                               output_padding=1),
            nn.BatchNorm2d(hidden_dims[-1]),
            nn.LeakyReLU(),
            nn.Conv2d(hidden_dims[-1], image_channels, kernel_size=3, padding=1),
            nn.Tanh()
        )

    def encode(self, x):
        """Return ``(mu, log_var)``, each of shape ``(batch, latent_dim)``."""
        h = self.encoder(x)
        h = torch.flatten(h, start_dim=1)
        mu = self.fc_mu(h)
        log_var = self.fc_var(h)
        return mu, log_var

    def reparameterize(self, mu, log_var):
        """Sample ``mu + eps * std`` with ``eps ~ N(0, I)`` and ``std = exp(log_var / 2)``."""
        std = torch.exp(0.5 * log_var)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent vectors ``(batch, latent_dim)`` to images ``(batch, C, 42, 42)``."""
        h = self.decoder_input(z)
        # Reshape to (batch, 256, 3, 3)
        h = h.view(-1, *self.bottleneck_shape)
        h = self.decoder(h)
        h = self.final_layer(h)
        return h

    def forward(self, x):
        """Return ``(recon, x, mu, log_var)`` for a batch of images ``x``."""
        mu, log_var = self.encode(x)
        z = self.reparameterize(mu, log_var)
        recon = self.decode(z)
        return recon, x, mu, log_var

In [None]:
#| hide
model = VAE2(3, 128)
x = torch.randn(16, 3, 42, 42)
u, lo = model.encode(x)
x_hat = model.decode(u)
x_hat.shape, u.shape


(torch.Size([16, 3, 42, 42]), torch.Size([16, 128]))

## ResNet-18 Encoder

In [None]:
#| export
import torch
import torch.nn as nn
import torch.nn.functional as F

class ResidualBlock(nn.Module):
    """Two 3x3 conv + batch-norm layers with a skip connection.

    Args:
        in_len: Number of input channels.
        out_len: Number of output channels.
        stride: Stride of the first convolution (and of the projection
            shortcut, when one is needed).
    """

    def __init__(self, in_len, out_len, stride=1):
        super().__init__()

        self.conv1 = nn.Conv2d(in_len, out_len, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_len)
        self.conv2 = nn.Conv2d(out_len, out_len, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_len)

        # A 1x1 conv + batch-norm projection is required whenever the main
        # path changes the spatial size or the channel count; otherwise the
        # shortcut is an empty Sequential, which passes its input through.
        needs_projection = stride != 1 or in_len != out_len
        if needs_projection:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_len, out_len, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_len),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Return ``relu(main_path(x) + shortcut(x))``."""
        main = self.conv1(x)
        main = F.relu(self.bn1(main))
        main = self.conv2(main)
        main = self.bn2(main)
        main += self.shortcut(x)
        return F.relu(main)

In [None]:
#| export
class ResNet18(nn.Module):
    """ResNet-18 style encoder for small multi-view images.

    A stride-1 3x3 stem is followed by four stages of two ``ResidualBlock``s
    each (64, 128, 256, 512 channels; stages 2-4 halve the resolution). The
    final feature map is adaptively average-pooled to 3x3, flattened and
    projected with a linear layer.

    Args:
        in_channels: Number of channels of the input images.
        out_dim: Size of the vector produced for each view.
    """

    def __init__(self, in_channels=3, out_dim=10):
        super().__init__()

        self.conv1 = nn.Conv2d(in_channels, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)

        self.layer1 = nn.Sequential(
            ResidualBlock(64, 64, stride=1),
            ResidualBlock(64, 64, stride=1)
        )
        self.layer2 = nn.Sequential(
            ResidualBlock(64, 128, stride=2),
            ResidualBlock(128, 128, stride=1)
        )
        self.layer3 = nn.Sequential(
            ResidualBlock(128, 256, stride=2),
            ResidualBlock(256, 256, stride=1)
        )
        self.layer4 = nn.Sequential(
            ResidualBlock(256, 512, stride=2),
            ResidualBlock(512, 512, stride=1)
        )

        # Adaptive pooling fixes the pooled map at 3x3 regardless of the input
        # resolution (for 42x42 inputs layer4 yields 6x6, pooled to 3x3), so
        # the linear layer always receives 512 * 3 * 3 = 4608 features.
        self.adaptive_pool = nn.AdaptiveAvgPool2d((3, 3))
        self.linear = nn.Linear(512 * 3 * 3, out_dim)

    def get_feature_space(self, x):
        """Return the flattened pooled features, shape ``(batch, 4608)``.

        Args:
            x: Image batch of shape ``(batch, in_channels, H, W)``.
        """
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)

        out = self.adaptive_pool(out)
        out = out.view(out.size(0), -1)
        return out

    def forward(self, x):
        """Embed a batch of multi-view images.

        Args:
            x: Tensor of shape ``(N, V, C, H, W)`` with ``N`` samples and
                ``V`` views per sample.

        Returns:
            Tensor of shape ``(V, N, out_dim)`` (views first).

        Raises:
            ValueError: If ``x`` is not 5-dimensional.
        """
        if x.dim() != 5:
            raise ValueError(
                f"expected input of shape (N, V, C, H, W), got {tuple(x.shape)}"
            )
        N, V = x.shape[:2]
        # Fold the view axis into the batch axis so the network sees N * V images.
        out = self.get_feature_space(x.flatten(0, 1))
        proj = self.linear(out)
        return proj.reshape(N, V, -1).transpose(0, 1)

In [None]:
#| hide
x = torch.randn(16, 2, 3, 42, 42)
model = ResNet18()
y = model(x)
y.shape

torch.Size([2, 16, 10])

In [None]:
#| hide
# Run nbdev's export step so the cells marked `#| export` are written to the library module.
import nbdev; nbdev.nbdev_export()