In [None]:
import os 
import time 
import numpy as np
import matplotlib.pyplot as plt

from typing import Optional, Tuple, Callable, Optional, Type, Union
from functools import partial

import torch
import torch.nn as nn
import torchvision
import torchvision.models as models

from torch.utils.data import DataLoader
import torch.nn.functional as F
import torchvision.transforms.v2 as transforms
from torchvision import datasets

import torch.optim as optim
import csv
from tqdm import tqdm 
import time as time

# Seed every RNG library this notebook imports so Restart & Run All reproduces results.
SEED = 42
torch.manual_seed(SEED)  # seeds the default torch generator (all devices); return value discarded
np.random.seed(SEED)

In [None]:
# `ToTensor` is deprecated in torchvision.transforms.v2; `ToImage` + `ToDtype(float32, scale=True)`
# is its documented replacement: PIL image -> uint8 image tensor -> float32 in [0.0, 1.0].
# NOTE: no mean/std normalization is applied, so the model sees raw [0, 1] pixel values.
transform = transforms.Compose([
    transforms.Resize((32, 32)),  # CIFAR-10 is natively 32x32, so this should be a no-op guard
    transforms.ToImage(),
    transforms.ToDtype(torch.float32, scale=True),
])

# Download (if not already present under ./) and create the CIFAR-10 train and test datasets
train_ds = torchvision.datasets.CIFAR10(
    root='./', train=True, download=True, transform=transform
)

test_ds = torchvision.datasets.CIFAR10(
    root='./', train=False, download=True, transform=transform
)

print(f'Train Samples: {len(train_ds)} || Test Samples: {len(test_ds)} || Classes: {len(train_ds.classes)}')

In [None]:
class DepthwiseSeparableConv(nn.Module):
    """Depthwise-separable 2D convolution: per-channel spatial conv, then a 1x1 channel-mixing conv.

    Both stages are bias-free. The depthwise stage keeps ``in_chans`` channels
    (``groups=in_chans``) and applies ``kernel_size`` / ``stride`` / ``padding``;
    the pointwise stage maps ``in_chans`` -> ``out_chans`` with a 1x1 kernel.
    """

    def __init__(self, in_chans: int, out_chans: int, kernel_size: int, stride: int, padding: int):
        super().__init__()
        # Spatial filtering: one filter per input channel.
        self.depthwise = nn.Conv2d(
            in_channels=in_chans,
            out_channels=in_chans,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=in_chans,
            bias=False,
        )
        # Channel mixing with a 1x1 kernel.
        self.pointwise = nn.Conv2d(
            in_channels=in_chans,
            out_channels=out_chans,
            kernel_size=1,
            bias=False,
        )

    def forward(self, x):
        return self.pointwise(self.depthwise(x))

class PreNormAttention(nn.Module):
    """Multi-head self-attention over tokens, fused with a depthwise-separable conv branch.

    The ``P`` input tokens are also folded into a square ``sqrt(P) x sqrt(P)`` feature map
    and passed through a depthwise-separable convolution; that branch's output is added to
    the attention output. A parameter-free layer norm over channels and a linear
    projection follow.

    NOTE(review): despite the class name, normalization is applied *after* attention
    (on the fused output, just before ``proj``), not before it as in a pre-norm block.

    Args:
        dim: token embedding size ``C``; must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        kernel_size: kernel size of the conv branch. Padding is ``kernel_size // 2``, so only
            odd sizes preserve the spatial size that ``forward`` needs.
    """

    def __init__(
        self,
        dim: int,
        num_heads: int = 8,
        kernel_size: int = 3
    ):
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.dim = dim
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        # 1/sqrt(head_dim) scaling of the attention logits.
        self.scale = self.head_dim ** -0.5

        self.q = nn.Linear(dim, dim)
        self.k = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, dim)
        self.proj = nn.Linear(dim, dim)

        # Depthwise-separable conv branch over the token grid (size-preserving for odd kernel_size).
        self.depthwise_conv = DepthwiseSeparableConv(dim, dim, kernel_size=kernel_size, stride=1, padding=kernel_size // 2)

        # Detached conv-branch output of the most recent forward pass, kept for visualization;
        # None until forward has run once (avoids AttributeError in plotting code).
        self.conv_feature_maps = None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply attention and the conv branch, then normalize and project.

        Args:
            x: tokens of shape ``(B, P, C)`` with ``C == dim`` and ``P`` a perfect square.

        Returns:
            Tensor of shape ``(B, P, C)``.

        Raises:
            ValueError: if ``P`` is not a perfect square (the conv branch needs a square grid).

        Side effect: stores the detached conv-branch output, shape ``(B, C, sqrt(P), sqrt(P))``,
        in ``self.conv_feature_maps``.
        """
        B, P, C = x.shape
        H = self.num_heads

        # Validate the square-grid assumption up front instead of failing inside .view()
        # with an opaque size-mismatch error.
        side = int(round(P ** 0.5))
        if side * side != P:
            raise ValueError(
                f'PreNormAttention needs a square number of tokens for its conv branch, got P={P}'
            )

        # Project and split heads: (B, P, C) -> (B, H, P, head_dim).
        q = self.q(x).view(B, P, H, -1).transpose(1, 2)
        k = self.k(x).view(B, P, H, -1).transpose(1, 2)
        v = self.v(x).view(B, P, H, -1).transpose(1, 2)

        q = q * self.scale
        attn = q @ k.transpose(-2, -1)  # (B, H, P, P)
        attn = attn.softmax(dim=-1)

        # Conv branch: fold tokens onto the grid, (B, P, C) -> (B, C, side, side).
        x_reshaped = x.transpose(1, 2).view(B, C, side, side)
        conv_out = self.depthwise_conv(x_reshaped)

        self.conv_feature_maps = conv_out.detach()  # Store conv feature maps for visualization

        conv_out = conv_out.view(B, C, P).transpose(1, 2)  # back to (B, P, C)

        x = attn @ v  # (B, H, P, head_dim)
        x = x.transpose(1, 2).reshape(B, P, C)  # merge heads
        x = x + conv_out

        # Parameter-free layer norm over the channel dimension (no learnable affine).
        x = F.layer_norm(x, [C])
        x = self.proj(x)
        return x

class PatchEmbed(nn.Module):
    """Cut an image into non-overlapping square patches and linearly embed each patch.

    Implemented as a Conv2d whose kernel and stride both equal ``patch_size``. The grid is
    ``img_size // patch_size`` patches per side; trailing pixels that do not fill a whole
    patch are dropped by the convolution.
    """

    def __init__(self,
                 img_size: int,
                 patch_size: int,
                 in_chans: int,
                 embed_dim: int,
    ):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        patches_per_side = img_size // patch_size
        self.grid_size = (patches_per_side, patches_per_side)
        self.num_patches = patches_per_side * patches_per_side
        self.proj = nn.Conv2d(
            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, padding=0
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (B, in_chans, H, W) -> (B, embed_dim, H // p, W // p)
        feature_map = self.proj(x)
        # -> (B, embed_dim, N) -> (B, N, embed_dim), patches in row-major grid order
        tokens = feature_map.flatten(start_dim=2)
        return tokens.permute(0, 2, 1)

class PositionalEmbedding(nn.Module):
    """Learnable absolute position embedding added to a token sequence.

    Holds one trainable vector per position as a ``(1, num_patches, embed_dim)`` parameter,
    initialised to zeros and broadcast over the batch dimension in ``forward``.
    """

    def __init__(self, num_patches: int, embed_dim: int):
        super().__init__()
        initial_table = torch.zeros(1, num_patches, embed_dim)
        self.position_embeddings = nn.Parameter(initial_table)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The leading singleton dim broadcasts the same table onto every sample.
        return torch.add(x, self.position_embeddings)

class Attention(nn.Module):
    """Token stem: patch embedding, positional embedding, then one attention+conv layer.

    Pipeline: image -> PatchEmbed -> PositionalEmbedding -> PreNormAttention.
    NOTE(review): the attention output replaces the embedded tokens (no residual
    connection around ``combined_attn``) — confirm this is intended.
    """

    def __init__(self, img_size: int, patch_size: int, in_channels: int, embed_dim: int, num_heads: int):
        super().__init__()
        patch_embed = PatchEmbed(img_size, patch_size, in_channels, embed_dim)
        self.patch_embed = patch_embed
        self.positional_embedding = PositionalEmbedding(patch_embed.num_patches, embed_dim)
        self.combined_attn = PreNormAttention(embed_dim, num_heads)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        tokens = self.positional_embedding(self.patch_embed(x))  # (B, num_patches, embed_dim)
        return self.combined_attn(tokens)  # (B, num_patches, embed_dim)
    
class PreNormAttentionModel(nn.Module):
    """Image classifier: attention stem followed by a linear head over all flattened tokens.

    Args:
        img_size: input image side length.
        patch_size: patch side length used by the stem's patch embedding.
        in_channels: number of input image channels.
        embed_dim: token embedding size.
        num_heads: number of attention heads.
        classes: number of output logits (default 10).
    """

    def __init__(self, img_size: int, patch_size: int, in_channels: int, embed_dim: int, num_heads: int, classes: int = 10):
        super().__init__()
        self.attn = Attention(img_size, patch_size, in_channels, embed_dim, num_heads)

        # Read the token count from the patch embedding itself so the head can never
        # disagree with the sequence length the stem actually produces.
        num_patches = self.attn.patch_embed.num_patches
        self.fc = nn.Linear(num_patches * embed_dim, classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        tokens = self.attn(x)  # (B, num_patches, embed_dim)
        # One feature vector per sample: (B, num_patches * embed_dim); flatten needs no contiguity.
        return self.fc(tokens.flatten(start_dim=1))