# Trying to make an IDM architecture 

### Step 1: Load the image dataset 

In [89]:
import torchvision 
import os
import torch
import torch.nn as nn
from torch.utils.data import TensorDataset, DataLoader

Load the images from the `data` directory.
First, collect the filenames of the PNG files.

In [90]:
# Collect the PNG frame filenames from the ./data directory.
relative_path = 'data'
data_dir = os.path.abspath(relative_path)
# os.listdir returns entries in arbitrary order, but the frames must stay in
# sequence for the temporal convolution, so sort the listing.
# NOTE(review): a lexicographic sort assumes zero-padded frame numbers
# (e.g. frame_0001.png) — confirm against the dataset's naming scheme.
# The extension check is case-insensitive so '.PNG' files are included too.
filenames = sorted(
    name
    for name in os.listdir(data_dir)
    if os.path.splitext(name)[1].lower() == '.png'
)



Then use the filenames to load the images.


In [91]:
# Preallocate one float tensor holding every frame: (num_frames, 3, 128, 128).
dataset_size = len(filenames)
dataset = torch.zeros(dataset_size, 3, 128, 128)
# Load every frame. Skipping any would leave all-zero frames at the end of the
# sequence. ImageReadMode.RGB forces 3 channels so grayscale or RGBA PNGs
# still fit the 3-channel slot.
# NOTE(review): images are assumed to already be 128x128 — confirm.
for i, name in enumerate(filenames):
    dataset[i] = torchvision.io.read_image(
        os.path.join(data_dir, name),
        mode=torchvision.io.ImageReadMode.RGB,
    )
 

Normalize the pixel values by dividing by 255 so that they lie between 0 and 1.

In [92]:
# Scale the raw 0-255 pixel intensities into [0, 1]. This is done in place, so
# no second full-size copy of the frame tensor is allocated.
dataset.div_(255.0)

### Step 2: Pass the data through a 3D convolution

This is not the final model; it is written to get the data into the proper shape.

In [93]:
class Temporal3DConv(nn.Module):
    """Temporal-only 3D convolution followed by a ReLU.

    ``nn.Conv3d`` expects input laid out as
    ``(batch, channels, frames, height, width)``. The kernel spans
    ``temporal_kernel`` frames but only a single pixel spatially (1x1). It
    therefore mixes information across time without mixing neighbouring
    pixels. Zero padding of ``temporal_kernel // 2`` frames on each side of
    the time axis keeps the number of output frames equal to the number of
    input frames.

    Args:
        in_channels: Input channels (3 for RGB frames).
        out_channels: Number of learnable filters, i.e. output channels.
        temporal_kernel: Temporal kernel width in frames. It must be a
            positive odd integer so the padding can be symmetric.

    Raises:
        ValueError: If ``temporal_kernel`` is not a positive odd integer.
    """

    def __init__(self, in_channels=3, out_channels=128, temporal_kernel=5):
        super().__init__()
        if temporal_kernel < 1 or temporal_kernel % 2 == 0:
            raise ValueError(
                f"temporal_kernel must be a positive odd integer, got {temporal_kernel}"
            )
        # Pad only the time axis so the first and last frames get a full
        # temporal window and the frame count is preserved.
        self.conv3d = nn.Conv3d(
            in_channels,
            out_channels,
            kernel_size=(temporal_kernel, 1, 1),
            padding=(temporal_kernel // 2, 0, 0),
        )
        self.relu = nn.ReLU()

    def forward(self, x):
        """Return ``relu(conv3d(x))``; only the channel count changes."""
        return self.relu(self.conv3d(x))

In [94]:
# Instantiate the temporal 3D convolution that the shape-check loop below runs
# each frame chunk through. No weights are loaded here, so they are untrained.
temporal3DConv = Temporal3DConv()

This splits the frames into chunks of 128 frames.

In [95]:
# Wrap the frame tensor so a DataLoader can hand it out in fixed-size chunks.
data = TensorDataset(dataset)
# Chunks of 128 frames; shuffle=False keeps the frames in the order they were
# loaded.
dataloader = DataLoader(dataset=data, shuffle=False, batch_size=128)


For `Conv3d` the input format is `(batch_size, num_channels, num_frames, height, width)`.
So I use `unsqueeze(0)` to add an outer dimension, making batch_size = 1.

Then I use `permute` to swap the frame and channel dimensions into the order `Conv3d` expects.

In [96]:
# Iterate over the dataloader in batches
# Run each chunk of frames through the temporal convolution and print the
# shapes. Only the last chunk's result is kept in ``output``; it is the sample
# input for the ResNet shape check further down.
output = None
# Shape checking only: no gradients are needed, so skip building the
# autograd graph.
with torch.no_grad():
    # Each batch is a one-element sequence holding the frame tensor chunk.
    for (frames,) in dataloader:
        # Add a leading batch dimension of 1: (1, frames, channels, H, W).
        clip = frames.unsqueeze(0)
        print(">>>", clip.size())
        # Conv3d expects (batch, channels, frames, H, W): swap axes 1 and 2.
        output = temporal3DConv(clip.permute(0, 2, 1, 3, 4))
        print(output.shape)  # Shape of the output tensor

>>> torch.Size([1, 128, 3, 128, 128])
torch.Size([1, 128, 128, 128, 128])
>>> torch.Size([1, 128, 3, 128, 128])
torch.Size([1, 128, 128, 128, 128])


### Step 3: Pass the 3D convolution output through ResNet

This is the middle ResNet layer. `ResNetBlock` implements the residual block architecture from the paper *Deep Residual Learning for Image Recognition*.

`ResNetBlocksWithPooling` represents the ResNet stack described in the VPT paper. We will use three stacks consecutively, then flatten the result before passing it to the attention layer.

In [97]:
class ResNetBlock(nn.Module):
    """Two 3x3 convolutions wrapped in a residual (skip) connection.

    Computes ``relu(conv2(relu(conv1(x))) + shortcut(x))``. Both convolutions
    use stride 1 and padding 1, so height and width are preserved. When the
    channel count changes, the shortcut is a learnable 1x1 convolution that
    projects ``x`` to ``out_channels``. Otherwise it is an empty
    ``nn.Sequential`` that passes ``x`` through unchanged.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the block.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        conv3x3_kwargs = dict(kernel_size=3, stride=1, padding=1)
        self.conv1 = nn.Conv2d(in_channels, out_channels, **conv3x3_kwargs)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, **conv3x3_kwargs)
        # Project the input only when the channel counts differ.
        self.shortcut = (
            nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1)
            if in_channels != out_channels
            else nn.Sequential()
        )

    def forward(self, x):
        skip = self.shortcut(x)
        out = self.conv2(self.relu(self.conv1(x)))
        return self.relu(out + skip)

class ResNetBlocksWithPooling(nn.Module):
    """One ResNet stack: 3x3 conv, max-pool, then two residual blocks.

    The 3x3 convolution (stride 1, padding 1) maps ``in_channels`` to
    ``out_channels`` without changing height or width. The max-pool
    (kernel 3, stride 2, padding 1) then downsamples height and width by a
    factor of about two. The two ``ResNetBlock``s keep both the channel count
    and the spatial size.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels used by the rest of the stack.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )
        self.pool = nn.MaxPool2d(3, 2, 1)  # kernel_size, stride, padding
        self.resnet_block1 = ResNetBlock(out_channels, out_channels)
        self.resnet_block2 = ResNetBlock(out_channels, out_channels)

    def forward(self, x):
        # Apply the stages strictly in sequence.
        stages = (self.conv, self.pool, self.resnet_block1, self.resnet_block2)
        for stage in stages:
            x = stage(x)
        return x

This is just sample code to check that the design (the tensor shapes) works out.

In [98]:
# Sample run: push the last temporal-conv output through three pooled ResNet
# stacks plus a flatten, printing every intermediate shape.
print(output.size())
# (batch, channels, frames, H, W) -> (batch, frames, channels, H, W). Then merge
# the batch and frame axes so each frame becomes an independent 2D image for
# Conv2d. flatten() copies only if needed, so this works for a shorter final
# chunk too (the old hard-coded view assumed exactly 128 frames).
xx = output.permute(0, 2, 1, 3, 4).flatten(0, 1)
print(xx.size())
# The first stack's input channels follow the temporal conv's output channels.
layer = ResNetBlocksWithPooling(xx.size(1), 64)
layer2 = ResNetBlocksWithPooling(64, 64)
layer3 = ResNetBlocksWithPooling(64, 64)
flattenLayer = nn.Flatten()
# Shape check only, so no gradients are needed. Call the modules themselves
# (not .forward) so that any registered hooks run.
with torch.no_grad():
    f = layer(xx)
    f2 = layer2(f)
    f3 = layer3(f2)
    f4 = flattenLayer(f3)
print(f.size())
print(f2.size())
print(f3.size())
print(f4.size())

torch.Size([1, 128, 128, 128, 128])
torch.Size([128, 128, 128, 128])
torch.Size([128, 64, 64, 64])
torch.Size([128, 64, 32, 32])
torch.Size([128, 64, 16, 16])
torch.Size([128, 16384])
