In [4]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, random_split
from PIL import Image
# -- no augmentations --
# Both pipelines: PIL image -> float tensor in [0, 1] -> (x - 0.5) / 0.5 per channel, i.e. [-1, 1].
SEED = 42  # fixes the train/validation split (and later torch RNG use) across Restart & Run All
torch.manual_seed(SEED)

train_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])

# 1. Download the full training dataset.
# NOTE(review): transform=None, so full_train_dataset and both subsets below yield raw PIL
# images, not tensors; train_transform is not applied anywhere in this cell. Apply it
# before batching with a DataLoader.
full_train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=None)

# 2. Split the full training dataset into a training and a validation set.
# A dedicated, seeded generator makes the split reproducible and independent of any
# other consumers of the global RNG.
train_size = int(0.9 * len(full_train_dataset))  # 90% for training
val_size = len(full_train_dataset) - train_size   # remaining 10% for validation
split_generator = torch.Generator().manual_seed(SEED)
train_subset, val_subset = random_split(
    full_train_dataset, [train_size, val_size], generator=split_generator
)

print(f"Total training images: {len(full_train_dataset)}")
print(f"Training subset size: {len(train_subset)}")
print(f"Validation subset size: {len(val_subset)}")


# 3. Download the test dataset (normalized tensors via test_transform).
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transform)
print(f"Test dataset size: {len(test_dataset)}")

100.0%


Total training images: 50000
Training subset size: 45000
Validation subset size: 5000
Test dataset size: 10000


In [10]:
from torchvision.utils import save_image
from torchvision.transforms.functional import to_tensor
# Inverse of Normalize(mean=0.5, std=0.5) applied per channel.
# Normalize computes T(x) = (x - mean) / std, so x = T(x) * std + mean.
# Written as another Normalize: x = (T(x) - (-mean / std)) / (1 / std),
# i.e. mean' = -0.5 / 0.5 = -1 and std' = 1 / 0.5 = 2 for every channel.
un_normalize = transforms.Normalize(
   mean=[-0.5 / 0.5, -0.5 / 0.5, -0.5 / 0.5],
   std=[1 / 0.5, 1 / 0.5, 1 / 0.5]
)


# Get the first image and its label. full_train_dataset was created with transform=None,
# so this is a PIL image and `.size` is (width, height), not a tensor shape.
pil_image, label = full_train_dataset[0]
print(f"Original image size (W, H): {pil_image.size}")
print(f"Class label: {full_train_dataset.classes[label]}")

# Normalize exactly as the model input would be, then undo it before saving.
# (Un-normalizing a plain to_tensor() output, which was never normalized, maps [0, 1]
# to [0.5, 1] and produces a washed-out image.)
normalized = test_transform(pil_image)
image_to_save = un_normalize(normalized)
# Sanity check: the round trip must recover the original [0, 1] pixel values.
assert torch.allclose(image_to_save, to_tensor(pil_image), atol=1e-6)

# Save the image (written to the current working directory)
save_image(image_to_save, 'cifar_image1.png')
print("Saved one image as 'cifar_image1.png'")

Original tensor shape: (32, 32)
Class label: frog
Saved one image as 'cifar_image1.png'


In [None]:
# save_image above writes 'cifar_image1.png' to the current working directory,
# so read it back from there.
test_path = 'cifar_image1.png'
test_img = Image.open(test_path)
test_img

(32, 32)

In [None]:

import torch.nn as nn
from dataclasses import dataclass
from einops import rearrange, repeat
@dataclass
class DataParameter:
    """Description of the input images: square side length and channel count."""

    img_size: int = 32    # height == width of each input image, in pixels
    in_channel: int = 3   # number of image channels (3 for RGB)

@dataclass
class Hyperparameters:
    """Model and training hyperparameters for the Vision Transformer."""

    patch_size: int = 4           # side length P of each square patch, in pixels
    D: int = 256                  # hidden (embedding) dimension of the Transformer
    batch_size: int = 1024        # number of samples per training batch
    num_attn_head: int = 4        # attention heads per layer; must divide D
    transformer_layers: int = 6   # number of stacked encoder blocks


class PatchEmbedding(nn.Module):
    """
    Split an image into non-overlapping patches, project each patch to D dims,
    prepend a learnable CLS token and add learnable positional embeddings.

    Args:
        image_data: object with ``img_size`` and ``in_channel`` (defaults to a fresh
            ``DataParameter()``). ``img_size`` must be divisible by ``patch_size``.
        hparams: object with ``patch_size`` and ``D`` (defaults to a fresh
            ``Hyperparameters()``).
        dropout: dropout probability applied to the final token embeddings
            (default 0.0, i.e. disabled).

    Raises:
        ValueError: if ``img_size`` is not divisible by ``patch_size``.
    """
    def __init__(self, image_data=None, hparams=None, dropout=0.0):
        super().__init__()
        # Build defaults per instance rather than sharing one instance created when
        # the class is defined.
        image_data = DataParameter() if image_data is None else image_data
        hparams = Hyperparameters() if hparams is None else hparams
        if image_data.img_size % hparams.patch_size != 0:
            raise ValueError(
                f"img_size ({image_data.img_size}) must be divisible by "
                f"patch_size ({hparams.patch_size})"
            )
        self.hparams = hparams
        self.patch_size = hparams.patch_size
        self.n_patches = (image_data.img_size // hparams.patch_size) ** 2
        self.cls_token = nn.Parameter(torch.rand(1, 1, hparams.D))
        # One learnable position vector per token (CLS + each patch). Self-attention has
        # no built-in notion of token order, so without this the encoder cannot use
        # where a patch came from.
        self.pos_embed = nn.Parameter(torch.zeros(1, self.n_patches + 1, hparams.D))
        nn.init.trunc_normal_(self.pos_embed, std=0.02)
        # kernel_size == stride == P: a linear projection applied independently to each patch.
        self.patch_embed = nn.Conv2d(
            in_channels=image_data.in_channel,
            out_channels=hparams.D,
            kernel_size=hparams.patch_size,
            stride=hparams.patch_size
        )
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """
        Forward pass.
        Args:
            x (torch.Tensor): Input image tensor with shape [B, C, H, W].
        Returns:
            torch.Tensor: Token embeddings with shape [B, n_patches + 1, D]; the CLS
            token is at index 0, followed by the patches in row-major order.
        Raises:
            ValueError: if the input does not produce exactly ``n_patches`` patches.
        """
        # Example: [B, 3, 32, 32] -> [B, 256, 8, 8] -> [B, 64, 256]
        b = x.shape[0]
        x = self.patch_embed(x)            # [B, D, H/P, W/P]
        x = x.flatten(2).transpose(1, 2)   # [B, N, D]
        if x.shape[1] != self.n_patches:
            raise ValueError(
                f"expected {self.n_patches} patches, got {x.shape[1]}; "
                "check the input image size"
            )
        cls_token = self.cls_token.expand(b, -1, -1)  # [B, 1, D], learnable
        x = torch.cat((cls_token, x), dim=1) + self.pos_embed
        return self.dropout(x)


In [None]:
from torch.nn import Softmax
class Attention(nn.Module):
    """
    Multi-head scaled dot-product self-attention.

    Args:
        hparams: object with ``D`` (embedding dim) and ``num_attn_head`` (defaults
            to a fresh ``Hyperparameters()``). ``D`` must be divisible by
            ``num_attn_head``.
        dropout: dropout probability for the attention weights and for the output
            projection (default 0.0, the previously hard-coded value).

    Raises:
        ValueError: if ``D`` is not divisible by ``num_attn_head``.
    """
    def __init__(self, hparams=None, dropout=0.0):
        super().__init__()
        hparams = Hyperparameters() if hparams is None else hparams
        if hparams.D % hparams.num_attn_head != 0:
            raise ValueError("Embedding dim (D) must be divisible by num_heads")
        self.hparams = hparams
        self.num_attn_head = hparams.num_attn_head
        self.attn_head_size = hparams.D // hparams.num_attn_head
        self.all_head_size = self.num_attn_head * self.attn_head_size

        self.query = nn.Linear(hparams.D, self.all_head_size)
        self.key = nn.Linear(hparams.D, self.all_head_size)
        self.value = nn.Linear(hparams.D, self.all_head_size)
        self.output = nn.Linear(self.all_head_size, hparams.D)

        self.attn_dropout = nn.Dropout(dropout)
        self.proj_dropout = nn.Dropout(dropout)
        self.softmax = Softmax(dim=-1)

    def _split_heads(self, t):
        """Reshape [B, N, h*d] -> [B, h, N, d] for per-head processing."""
        b, n, _ = t.shape
        return t.reshape(b, n, self.num_attn_head, self.attn_head_size).transpose(1, 2)

    def forward(self, x, mask=None):
        """
        Args:
            x: token embeddings [B, N, D] (N = num_patches + 1 with the CLS token).
            mask: optional tensor broadcastable to [B, h, N, N]; score positions where
                ``mask == 0`` receive (numerically) zero attention weight.
        Returns:
            torch.Tensor of shape [B, N, D].
        """
        q = self._split_heads(self.query(x))
        k = self._split_heads(self.key(x))
        v = self._split_heads(self.value(x))

        # [B, h, N, N] scaled dot products
        attn_scores = torch.matmul(q, k.transpose(-1, -2)) / (self.attn_head_size ** 0.5)
        if mask is not None:
            # Most negative finite value of the dtype: unlike -1e9 it is representable
            # in fp16/bf16 as well.
            attn_scores = attn_scores.masked_fill(mask == 0, torch.finfo(attn_scores.dtype).min)
        attn_probs = self.attn_dropout(self.softmax(attn_scores))

        context = torch.matmul(attn_probs, v)  # [B, h, N, d]
        b, _, n, _ = context.shape
        # combine all heads: [B, h, N, d] -> [B, N, h*d]
        context = context.transpose(1, 2).reshape(b, n, self.all_head_size)
        return self.proj_dropout(self.output(context))


In [None]:
from torch.nn import GELU
class MLP(nn.Module):
    """
    Position-wise feed-forward network applied to the last dimension:
    Linear(D -> mlp_ratio*D) -> GELU -> Dropout -> Linear(mlp_ratio*D -> D) -> Dropout.

    Args:
        dropout: dropout probability used after the activation and after the output layer.
        hparams: object with ``D`` (defaults to a fresh ``Hyperparameters()``).
        mlp_ratio: hidden-layer width as a multiple of D (default 4, the previously
            hard-coded value).
    """
    def __init__(self, dropout, hparams=None, mlp_ratio=4):
        super().__init__()
        hparams = Hyperparameters() if hparams is None else hparams
        self.D = hparams.D
        self.hidden_dim = mlp_ratio * self.D
        self.dropout = dropout
        self.net = nn.Sequential(
            nn.Linear(in_features=self.D, out_features=self.hidden_dim),
            nn.GELU(),
            nn.Dropout(self.dropout),
            nn.Linear(in_features=self.hidden_dim, out_features=self.D),
            nn.Dropout(self.dropout)
        )

    def forward(self, x):
        """Map [..., D] -> [..., D]; e.g. [B, 65, D] from attention with the default sizes."""
        return self.net(x)

In [None]:
class EncoderBlock(nn.Module):
    """
    Pre-norm Transformer encoder block:
        x = x + Attention(LayerNorm(x))
        x = x + MLP(LayerNorm(x))

    Args:
        hparams: object providing ``D`` and the fields Attention/MLP read
            (defaults to a fresh ``Hyperparameters()``).
        dropout: dropout probability passed to the MLP (default 0.1, the previously
            hard-coded value).
    """
    def __init__(self, hparams=None, dropout=0.1):
        super().__init__()
        hparams = Hyperparameters() if hparams is None else hparams
        self.D = hparams.D
        self.norm1 = nn.LayerNorm(self.D)
        self.attn = Attention(hparams=hparams)
        self.norm2 = nn.LayerNorm(self.D)
        self.ffn = MLP(dropout=dropout, hparams=hparams)

    def forward(self, x):
        """Apply both residual sub-layers; the last dimension of x must be D."""
        x = x + self.attn(self.norm1(x))
        x = x + self.ffn(self.norm2(x))
        return x
class Transformer(nn.Module):
    """Stack of ``hparams.transformer_layers`` EncoderBlocks applied one after another."""

    def __init__(self, hparams=Hyperparameters()):
        super().__init__()
        self.depth = hparams.transformer_layers
        blocks = [EncoderBlock(hparams=hparams) for _ in range(self.depth)]
        self.layers = nn.ModuleList(blocks)

    def forward(self, x):
        """Feed x through every encoder block in order and return the final output."""
        out = x
        for block in self.layers:
            out = block(out)
        return out