# Trying to make an Inverse Dynamics Model (IDM) architecture

### Step 1: Load the image dataset 

In [89]:
import torchvision 
import os
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader
import torch.nn.functional as F

Load the images from the `data` directory.
First, collect the PNG filenames.

In [90]:
# Directory that holds the frame images, resolved against the current
# working directory.
relative_path = 'data'
data_dir = os.path.abspath(relative_path)
# os.listdir() returns entries in arbitrary order, so sort the names to load
# the frames in a deterministic order.
# NOTE(review): lexicographic order only matches the temporal order when the
# frame indices in the filenames are zero-padded — confirm against the data.
filenames = sorted(
    name
    for name in os.listdir(data_dir)
    if os.path.splitext(name)[-1] == '.png'
)



Then use the filenames to load the images.


In [91]:
dataset_size = len(filenames)
# Pre-allocate one (3, 128, 128) slot per file; slots that are never written
# stay all-zero.
dataset = torch.zeros(dataset_size, 3, 128, 128)
# NOTE(review): only the first dataset_size - 4 files are loaded, so the last
# four slots remain zero frames — confirm that this is intentional.
frames_to_load = dataset_size - 4
for slot, filename in zip(range(frames_to_load), filenames):
    image_path = os.path.join(data_dir, filename)
    dataset[slot] = torchvision.io.read_image(image_path)
 

Normalize the pixel values by dividing them by 255, so that they lie between 0 and 1.

In [92]:
# Look up the shape of the stacked frames (the value is not used further).
dataset.shape
# Scale every pixel by 1/255; the result is a new tensor bound to `dataset`.
# NOTE(review): the result is in [0, 1] only if the loaded pixels are in
# [0, 255] — confirm for the image files used.
dataset = torch.div(dataset, 255.0)

### Step 2 Pass the data through 3D Convolution

This is not the final model. It is only here to get the data into the proper shape.

In [93]:
class Temporal3DConv(nn.Module):
    """Convolve over the time axis of a stack of frames, then apply ReLU.

    The kernel covers ``temporal_kernel`` frames and a single pixel (1x1)
    spatially. The frame axis is zero-padded by ``temporal_kernel // 2`` on
    both sides, so the number of frames is unchanged.

    Args:
        in_channels: Channels per input frame (3 for RGB images).
        out_channels: Number of learnable filters, i.e. output channels.
        temporal_kernel: Number of frames covered by the kernel. Must be a
            positive odd number so the padding keeps the frame count intact.

    Raises:
        ValueError: If ``temporal_kernel`` is not a positive odd number.
    """

    def __init__(self, in_channels=3, out_channels=128, temporal_kernel=5):
        super().__init__()

        if temporal_kernel < 1 or temporal_kernel % 2 == 0:
            raise ValueError(
                "temporal_kernel must be a positive odd number, "
                f"got {temporal_kernel}"
            )

        # Kernel is (frames, height, width); only the frame axis is padded so
        # that the first and last frames still get a full window.
        self.conv3d = nn.Conv3d(
            in_channels,
            out_channels,
            kernel_size=(temporal_kernel, 1, 1),
            padding=(temporal_kernel // 2, 0, 0),
        )
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``relu(conv3d(x))``.

        Args:
            x: Tensor laid out as (batch, in_channels, frames, height, width).

        Returns:
            Tensor laid out as (batch, out_channels, frames, height, width).
        """
        return self.relu(self.conv3d(x))

In [94]:
# Instantiate the temporal convolution stage defined in the previous cell.
temporal3DConv = Temporal3DConv()

This splits the frames into chunks of 128.

In [95]:
# Wrap the frame tensor so that it can be indexed frame by frame; each item
# is a 1-tuple holding a single frame.
data = TensorDataset(dataset)

# Hand out the frames 128 at a time without shuffling, so every chunk keeps
# the frames in their stored order.
dataloader = DataLoader(dataset=data, batch_size=128, shuffle=False)


For Conv3D the input format is (batch_size, num_channels, num_frames, height, width).
So I use `unsqueeze` to add an outer dimension, which makes batch_size = 1.

Then I use `permute` to put the dimensions into the proper order.

In [96]:
# Run every chunk of frames through the temporal convolution.
# `output` keeps the result of the last chunk and stays None when the
# dataloader yields nothing.
output = None
for framechunk in dataloader:
    # Each batch is a 1-tuple; element 0 is the frame tensor. unsqueeze(0)
    # adds the outer batch dimension of size 1.
    frames = framechunk[0].unsqueeze(0)
    print(">>>", frames.size())
    # Swap dims 1 and 2 so the layout becomes
    # (batch, channels, frames, height, width), as Conv3d expects.
    output = temporal3DConv(frames.permute(0, 2, 1, 3, 4))

    print(output.shape)  # Shape of the output tensor

>>> torch.Size([1, 128, 3, 128, 128])
torch.Size([1, 128, 128, 128, 128])
>>> torch.Size([1, 128, 3, 128, 128])
torch.Size([1, 128, 128, 128, 128])


### Step 3 Pass the 3D Convolution layer outcome through ResNet

This is the middle ResNet stage. `ResNetBlock` represents the architecture from the paper
"Deep Residual Learning for Image Recognition".

`ResNetBlocksWithPooling` represents the ResNet stack mentioned in the VPT paper. We will use three stacks consecutively, then flatten the result before passing it to the attention layer.

In [97]:
class ResNetBlock(nn.Module):
    """Two 3x3 convolutions wrapped by a skip connection.

    Computes ``relu(conv2(relu(conv1(x))) + shortcut(x))``. Both convolutions
    use stride 1 and padding 1, so the spatial size is preserved.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels produced by both convolutions.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)

        # The skip path is a pass-through (empty Sequential) when the channel
        # counts agree; otherwise a 1x1 convolution projects the input to
        # out_channels so that it can be added to the main path.
        needs_projection = in_channels != out_channels
        self.shortcut = (
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1)
            if needs_projection
            else nn.Sequential()
        )

    def forward(self, x):
        """Apply the block to a (batch, in_channels, height, width) tensor."""
        skip = self.shortcut(x)

        out = self.relu(self.conv1(x))
        out = self.conv2(out)
        # In-place add onto the conv2 result; the input x itself is untouched.
        out += skip

        return self.relu(out)

class ResNetBlocksWithPooling(nn.Module):
    """One ResNet stack: 3x3 conv, stride-2 max pooling, two ResNetBlocks.

    The convolution changes the channel count from ``in_channels`` to
    ``out_channels``; the max pool (kernel 3, stride 2, padding 1) then
    reduces height and width (128 -> 64 in the notebook's shape check).

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels used by every layer after the first conv.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.resnet_block1 = ResNetBlock(out_channels, out_channels)
        self.resnet_block2 = ResNetBlock(out_channels, out_channels)

    def forward(self, x):
        """Run x through conv, pool and both residual blocks, in that order."""
        stages = (self.conv, self.pool, self.resnet_block1, self.resnet_block2)
        for stage in stages:
            x = stage(x)
        return x

This is just sample code to check that the design works.

In [98]:
print(output.size())
# The ResNet layers are 2-D convolutions and expect a 4-D input, so the frame
# axis is folded into the batch axis:
#   1. permute swaps the channel and frame dimensions,
#   2. contiguous lays the permuted tensor out contiguously in memory
#      (required before view),
#   3. view reshapes to (batch_size * num_frames, num_channels, height, width).
# The sizes are read from the tensor instead of being hard-coded.
batch_size, num_channels, num_frames, height, width = output.shape
xx = (
    output.permute(0, 2, 1, 3, 4)
    .contiguous()
    .view(batch_size * num_frames, num_channels, height, width)
)
print(xx.size())
layer = ResNetBlocksWithPooling(num_channels, 64)
layer2 = ResNetBlocksWithPooling(64, 64)
layer3 = ResNetBlocksWithPooling(64, 64)
flattenLayer = nn.Flatten()
# Call the modules directly rather than through .forward(), so that any
# registered nn.Module hooks are run as well.
f = layer(xx)
f2 = layer2(f)
f3 = layer3(f2)
f4 = flattenLayer(f3)
print(f.size())
print(f2.size())
print(f3.size())
print(f4.size())

torch.Size([1, 128, 128, 128, 128])
torch.Size([128, 128, 128, 128])
torch.Size([128, 64, 64, 64])
torch.Size([128, 64, 32, 32])
torch.Size([128, 64, 16, 16])
torch.Size([128, 16384])


### Step 4 Pass ResNet outcome through Multiheaded Residual Transformer

In [109]:

class FrameWiseDense(nn.Module):
    """Linear layer followed by ReLU.

    ``nn.Linear`` acts on the last dimension only, so every leading index
    (for example every frame) is transformed independently with the same
    weights.

    Args:
        in_features: Size of the last dimension of the input.
        out_features: Size of the last dimension of the output.
    """

    def __init__(self, in_features, out_features):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``relu(linear(x))``."""
        return self.relu(self.linear(x))

class ResidualTransformerBlock(nn.Module):
    """Self-attention followed by a two-layer feed-forward net, each residual.

    Both sub-layers follow the pattern ``LayerNorm(sublayer(x) + x)``:
      1. multi-head self-attention (query = key = value = x), then dropout;
      2. ``dense1`` -> dropout -> ``dense2``, where both dense layers end in
         a ReLU.

    ``nn.MultiheadAttention`` is built without ``batch_first``, so per the
    PyTorch documentation the input is read as
    (sequence, batch, embedding_dim) or, when 2-D, as one unbatched
    (sequence, embedding_dim) sequence.

    Args:
        embedding_dim: Size of the last dimension of the input and output.
        num_heads: Number of attention heads.
        dropout: Dropout probability used inside the attention and in both
            dropout layers.
        hidden_dim: Width of the feed-forward hidden layer. Defaults to
            16384, the value that used to be hard-coded.
    """

    def __init__(self, embedding_dim, num_heads, dropout=0.1, hidden_dim=16384):
        super().__init__()
        self.attention = nn.MultiheadAttention(embedding_dim, num_heads, dropout=dropout)
        self.dropout1 = nn.Dropout(dropout)
        self.norm1 = nn.LayerNorm(embedding_dim)
        self.dense1 = FrameWiseDense(embedding_dim, hidden_dim)
        self.dropout2 = nn.Dropout(dropout)
        self.norm2 = nn.LayerNorm(embedding_dim)
        self.dense2 = FrameWiseDense(hidden_dim, embedding_dim)

    def forward(self, x):
        """Apply the block; the output has the same shape as ``x``."""
        # Attention sub-layer; the attention weights are not needed.
        attended, _ = self.attention(x, x, x)
        x = self.norm1(self.dropout1(attended) + x)

        # Feed-forward sub-layer: expand to hidden_dim, then project back.
        hidden = self.dropout2(self.dense1(x))
        return self.norm2(self.dense2(hidden) + x)

class ActionPredictionModel(nn.Module):
    """Map flattened per-frame ResNet features to action probabilities.

    Pipeline: two dense layers (16384 -> 256 -> 4096), four residual
    transformer blocks at width 4096, two more dense layers
    (4096 -> 16384 -> 4096), a linear action head and a softmax.

    Args:
        num_actions: Number of actions, i.e. size of the output's last
            dimension.
    """

    def __init__(self, num_actions):
        super().__init__()
        # 16384 is the flattened ResNet output per frame (64 channels * 16 * 16).
        self.dense1 = FrameWiseDense(16384, 256)
        self.dense2 = FrameWiseDense(256, 4096)
        self.residual_transformer_blocks = nn.Sequential(
            *(
                ResidualTransformerBlock(embedding_dim=4096, num_heads=32)
                for _ in range(4)
            )
        )
        self.dense3 = FrameWiseDense(4096, 16384)
        self.dense4 = FrameWiseDense(16384, 4096)
        self.action_head = nn.Linear(4096, num_actions)

    def forward(self, x):
        """Return one probability distribution over actions per frame.

        Args:
            x: Features whose last dimension is 16384. The notebook passes a
               2-D (frames, 16384) tensor.

        Returns:
            Tensor shaped like ``x`` except that the last dimension is
            ``num_actions``; the values along that dimension sum to 1.
        """
        out = self.dense1(x)
        out = self.dense2(out)
        print(out.size())  # debug trace of the shape entering the transformer
        # The tensor goes into the attention blocks as-is: a 2-D
        # (frames, features) input is treated as a single unbatched sequence.
        # NOTE(review): a batched input would have to be laid out as
        # (frames, batch, features), because the attention layers are built
        # without batch_first — confirm before feeding batches.
        out = self.residual_transformer_blocks(out)
        out = self.dense3(out)
        out = self.dense4(out)
        print(out.size())  # debug trace of the shape entering the action head
        # No pooling over frames: every frame keeps its own prediction.
        out = self.action_head(out)
        # Normalise over the action axis. dim=-1 equals the former dim=1 for
        # 2-D input and stays correct when leading dimensions are added.
        return F.softmax(out, dim=-1)




Dummy code to check model compatibility

In [110]:
# Build the model with four possible actions.
model = ActionPredictionModel(num_actions=4)

# Feed the flattened per-frame ResNet features (f4) through the model.
output = model(f4)
print(output)  # Prints the probabilities themselves; expected shape is (128, 4)

torch.Size([128, 4096])
torch.Size([128, 4096])
tensor([[0.2615, 0.2518, 0.2016, 0.2851],
        [0.2600, 0.2561, 0.1945, 0.2894],
        [0.2632, 0.2460, 0.1998, 0.2911],
        [0.2584, 0.2536, 0.1964, 0.2916],
        [0.2614, 0.2594, 0.1989, 0.2803],
        [0.2647, 0.2504, 0.1989, 0.2860],
        [0.2624, 0.2554, 0.1919, 0.2903],
        [0.2511, 0.2613, 0.1989, 0.2887],
        [0.2591, 0.2571, 0.1968, 0.2869],
        [0.2695, 0.2571, 0.1995, 0.2739],
        [0.2523, 0.2622, 0.1965, 0.2890],
        [0.2689, 0.2510, 0.1986, 0.2815],
        [0.2542, 0.2593, 0.2032, 0.2833],
        [0.2609, 0.2476, 0.2034, 0.2881],
        [0.2605, 0.2552, 0.1984, 0.2860],
        [0.2540, 0.2559, 0.2041, 0.2859],
        [0.2612, 0.2537, 0.2023, 0.2829],
        [0.2674, 0.2513, 0.2000, 0.2814],
        [0.2570, 0.2571, 0.2025, 0.2834],
        [0.2679, 0.2579, 0.1928, 0.2814],
        [0.2603, 0.2495, 0.2016, 0.2887],
        [0.2669, 0.2479, 0.1948, 0.2904],
        [0.2562, 0.2514, 0.2