# Communication model

> CNN modules for encoding primitive-grid messages and generating them from latent feature maps.

In [None]:
#| default_exp models.comm

In [None]:
#| hide
from nbdev.showdoc import *

In [None]:
#| export 
import numpy as np
import torch
import torch.distributions as td
import torch.nn as nn
from fastcore.utils import *

from torch import nn
from torch.nn import functional as F

### MSG Encoder

In [None]:
#| export
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

class MSGEnc(nn.Module):
    """Encode square grids of primitive channels into latent vectors.

    Each ``(num_primitives, grid_size, grid_size)`` grid is passed through two
    convolutions and a two-layer MLP, producing one ``latent_dim`` vector per
    grid.

    Args:
        num_primitives: Number of input channels of every grid.
        latent_dim: Size of the output vector produced for every grid.
        grid_size: Height/width of the (square) input grids. The default of 7
            reproduces the original ``16 * 3 * 3`` flattened feature size.

    Raises:
        ValueError: If ``grid_size`` is too small for the 3x3 stride-2 conv.
    """

    def __init__(self, num_primitives=5, latent_dim=32, grid_size=7):
        # nn.Module must be initialised before attributes are set on it.
        super().__init__()
        if grid_size < 3:
            raise ValueError(f"grid_size must be >= 3, got {grid_size}")
        self.latent_dim = latent_dim
        self.grid_size = grid_size

        # The first conv keeps the spatial size (padding=1); the second one
        # (kernel 3, stride 2, no padding) maps S -> floor((S - 3) / 2) + 1.
        reduced = (grid_size - 3) // 2 + 1
        self.net = nn.Sequential(
            nn.Conv2d(num_primitives, 16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(16, 16, kernel_size=3, stride=2),
            nn.ReLU(),
            nn.Flatten(),
            nn.Linear(16 * reduced * reduced, 64),
            nn.ReLU(),
            nn.Linear(64, latent_dim),
        )

    def forward(self, x):
        """Encode ``x`` and return a ``(B, T, latent_dim)`` tensor.

        Args:
            x: Either ``(B, T, C, H, W)`` or ``(B, C, H, W)``; a 4-D input is
                treated as a sequence of length one.

        Raises:
            ValueError: If ``x`` is neither 4-D nor 5-D.
        """
        if x.dim() == 4:
            x = x.unsqueeze(1)  # Add time dimension if missing
        if x.dim() != 5:
            raise ValueError(
                f"MSGEnc expects a 4-D or 5-D input, got {x.dim()}-D"
            )
        B, T = x.shape[:2]
        # Fold time into the batch axis so the 2-D convs see one grid per row.
        x = x.reshape(B * T, *x.shape[2:])
        x = self.net(x)  # [B*T, latent_dim]
        # Explicit sizes (no -1) so an empty batch is still well defined.
        return x.reshape(B, T, x.shape[-1])


In [None]:
#| hide
# Smoke test: encode a random (batch=16, time=8) stack of 5-channel 7x7 grids.
model = MSGEnc(num_primitives=5, latent_dim=32)
inp = torch.randn(16, 8, 5, 7, 7)
out = model(inp)
# One latent vector per (batch, time) entry.
out.shape

torch.Size([16, 8, 32])

### Communication Module

In [None]:
#| export
import torch
import torch.nn as nn
from einops import rearrange
class CommModule(nn.Module):
    """Map latent feature maps to per-cell primitive logits.

    A 3x3 conv refines the features at full resolution, a stride-2 3x3 conv
    (no padding) downsamples them, and a 1x1 conv projects to
    ``num_primitives`` channels. A 15x15 input therefore yields a 7x7 output:
    ``floor((15 - 3) / 2) + 1 = 7``.

    Args:
        input_channel: Number of channels of the incoming feature map.
        num_primitives: Number of output channels (defaults to the original
            hard-coded value of 5).
    """

    def __init__(self, input_channel=32, num_primitives=5):
        super().__init__()

        self.network = nn.Sequential(
            # Refine latent features without changing the spatial size.
            nn.Conv2d(input_channel, 32, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            # Downsample: S -> floor((S - 3) / 2) + 1 (15x15 -> 7x7).
            nn.Conv2d(32, 64, kernel_size=3, stride=2, padding=0),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            # 1x1 conv to reach the target channel count.
            nn.Conv2d(64, num_primitives, kernel_size=1),
        )

    def forward(self, x):
        """Run the network on a 4-D or 5-D input.

        Args:
            x: ``(B, C, H, W)`` or ``(B, T, C, H, W)``.

        Returns:
            ``(B, num_primitives, H', W')`` for a 4-D input, or
            ``(B, T, num_primitives, H', W')`` for a 5-D input.

        Raises:
            ValueError: If ``x`` is neither 4-D nor 5-D.
        """
        if x.dim() == 4:
            return self.network(x)
        if x.dim() != 5:
            raise ValueError(
                f"CommModule expects a 4-D or 5-D input, got {x.dim()}-D"
            )
        b, t = x.shape[:2]
        # reshape (not view) so non-contiguous inputs, e.g. after a
        # transpose/permute, are accepted instead of raising.
        y = self.network(x.reshape(b * t, *x.shape[2:]))
        # Restore the separate batch and time axes.
        return y.reshape(b, t, *y.shape[1:])



In [None]:
#| hide
# Smoke test: run a random (batch=64, time=20) stack of 32-channel 15x15 maps.
model = CommModule(input_channel=32)
latent_input = torch.randn(64, 20, 32, 15, 15)  # Simulated latent representation
output = model(latent_input)

print(f"Input shape: {latent_input.shape}")
print(f"Output shape: {output.shape}") # Should be [64, 20, 5, 7, 7]

Input shape: torch.Size([64, 20, 32, 15, 15])
Output shape: torch.Size([64, 20, 5, 7, 7])


In [None]:
#| hide
from torch.nn.functional import softmax
# `output` is (batch, time, channel, H, W): normalise over the channel axis
# (dim=2), the same axis the argmax uses, rather than over time (dim=1).
softmax(output, dim=2).shape

torch.Size([64, 20, 5, 7, 7])

In [None]:
#| hide
# Pick the highest-scoring channel for every cell; dim=2 is the channel axis
# of the 5-D (batch, time, channel, H, W) output.
class_indices = torch.argmax(output, dim=2)
# Sanity check: indices should span the valid channel range.
class_indices.max(), class_indices.min()

(tensor(4), tensor(0))

In [None]:
#| hide
# Display the per-cell channel indices.
class_indices

tensor([[[[1, 0, 0,  ..., 1, 3, 2],
          [1, 1, 0,  ..., 3, 1, 2],
          [1, 2, 1,  ..., 1, 2, 2],
          ...,
          [2, 4, 0,  ..., 1, 3, 1],
          [3, 0, 2,  ..., 1, 1, 3],
          [1, 1, 2,  ..., 0, 1, 1]],

         [[0, 3, 1,  ..., 2, 1, 2],
          [1, 1, 1,  ..., 3, 4, 1],
          [2, 0, 3,  ..., 1, 1, 2],
          ...,
          [1, 2, 3,  ..., 0, 2, 0],
          [3, 0, 0,  ..., 2, 1, 0],
          [3, 2, 0,  ..., 1, 1, 0]],

         [[3, 1, 2,  ..., 2, 3, 3],
          [1, 0, 3,  ..., 3, 1, 0],
          [2, 3, 1,  ..., 2, 3, 3],
          ...,
          [1, 1, 2,  ..., 3, 4, 1],
          [1, 0, 1,  ..., 3, 3, 3],
          [2, 4, 2,  ..., 1, 1, 3]],

         ...,

         [[1, 1, 1,  ..., 3, 3, 2],
          [1, 3, 0,  ..., 3, 0, 3],
          [3, 3, 1,  ..., 3, 0, 3],
          ...,
          [3, 1, 2,  ..., 4, 1, 1],
          [3, 3, 1,  ..., 2, 1, 1],
          [1, 1, 1,  ..., 3, 3, 1]],

         [[2, 1, 1,  ..., 3, 3, 3],
          [1, 3, 

In [None]:
#| hide
import torch.nn.functional as F
# One-hot encode the indices; F.one_hot appends the class axis as the last dim.
one_hot_grid = F.one_hot(class_indices, num_classes=5)
one_hot_grid

tensor([[[[[0, 1, 0, 0, 0],
           [1, 0, 0, 0, 0],
           [1, 0, 0, 0, 0],
           ...,
           [0, 1, 0, 0, 0],
           [0, 0, 0, 1, 0],
           [0, 0, 1, 0, 0]],

          [[0, 1, 0, 0, 0],
           [0, 1, 0, 0, 0],
           [1, 0, 0, 0, 0],
           ...,
           [0, 0, 0, 1, 0],
           [0, 1, 0, 0, 0],
           [0, 0, 1, 0, 0]],

          [[0, 1, 0, 0, 0],
           [0, 0, 1, 0, 0],
           [0, 1, 0, 0, 0],
           ...,
           [0, 1, 0, 0, 0],
           [0, 0, 1, 0, 0],
           [0, 0, 1, 0, 0]],

          ...,

          [[0, 0, 1, 0, 0],
           [0, 0, 0, 0, 1],
           [1, 0, 0, 0, 0],
           ...,
           [0, 1, 0, 0, 0],
           [0, 0, 0, 1, 0],
           [0, 1, 0, 0, 0]],

          [[0, 0, 0, 1, 0],
           [1, 0, 0, 0, 0],
           [0, 0, 1, 0, 0],
           ...,
           [0, 1, 0, 0, 0],
           [0, 1, 0, 0, 0],
           [0, 0, 0, 1, 0]],

          [[0, 1, 0, 0, 0],
           [0, 1, 0, 0, 0]

In [None]:
#| hide
# Export the cells marked `#| export` to the module named by `default_exp`.
import nbdev; nbdev.nbdev_export()