In [None]:
import sys
import os
sys.path.append('..')

from Utils.Accuracy_measures import topk_accuracy
from Utils.TinyImageNet_loader import get_tinyimagenet_dataloaders
from Utils.Num_parameter import count_parameters
import torchvision.transforms as transforms

import time
import torch


import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import models, datasets, transforms
import torch.nn.functional as F
from torch.utils.data import DataLoader
from tqdm import tqdm

# Side length (in pixels) every image is resized to, and the mini-batch size
image_size = 224
batch_size = 64

# Per-channel mean / std used by the Normalize step of all three pipelines
_norm_mean = (0.485, 0.456, 0.406)
_norm_std = (0.229, 0.224, 0.225)


def _deterministic_pipeline():
    """Build the augmentation-free pipeline: resize, tensor conversion, normalisation."""
    return transforms.Compose([
        transforms.Resize((image_size, image_size)),
        transforms.ToTensor(),
        transforms.Normalize(_norm_mean, _norm_std),
    ])


# Training pipeline: random flip / crop / rotation on top of the deterministic steps
tiny_transform_train = transforms.Compose([
    transforms.RandomHorizontalFlip(),
    transforms.Resize((image_size, image_size)),
    transforms.RandomCrop(image_size, padding=5),
    transforms.RandomRotation(10),
    transforms.ToTensor(),
    transforms.Normalize(_norm_mean, _norm_std),
])

# Validation and test use the same (separately constructed) deterministic pipeline
tiny_transform_val = _deterministic_pipeline()
tiny_transform_test = _deterministic_pipeline()

# Prefer the GPU when one is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f'Using device: {device}')

# Build the three loaders through the project helper
train_loader, val_loader, test_loader = get_tinyimagenet_dataloaders(
    data_dir='../datasets',
    transform_train=tiny_transform_train,
    transform_val=tiny_transform_val,
    transform_test=tiny_transform_test,
    batch_size=batch_size,
    image_size=image_size,
)

# Define the custom patch embedding class with adjustments
class DualPatchEmbedding(nn.Module):
    """Embed an image with two convolutional patch projections of different size.

    ``proj1`` embeds coarse ``patch_size1`` patches and ``proj2`` embeds fine
    ``patch_size2`` patches. A fixed index list (``self.idx``) selects and
    reorders a subset of the fine patches.

    Args:
        img_size: Side length of the (square) input image.
        patch_size1: Side length of the coarse patches.
        patch_size2: Side length of the fine patches.
        embed_dim: Channel width of every produced token.

    Raises:
        ValueError: If the fine patch grid (``img_size // patch_size2``) is
            not a positive multiple of 4, which the selection pattern needs.

    Note:
        ``cls_token`` is declared here but not used by ``forward``; only
        ``pos_embed[:, 1:]`` is used.
    """

    def __init__(self, img_size=224, patch_size1=16, patch_size2=8, embed_dim=768):
        super(DualPatchEmbedding, self).__init__()
        self.img_size = img_size
        self.patch_size1 = patch_size1
        self.patch_size2 = patch_size2
        self.embed_dim = embed_dim

        # Number of patches produced by each projection before any rearrangement
        # (196 and 784 with the default arguments)
        self.num_patches1 = (img_size // patch_size1) ** 2
        self.num_patches2 = (img_size // patch_size2) ** 2

        # The index mapping is fixed, so it is generated once at construction time
        self.idx = self.generate_custom_order()
        self.num_patches = len(self.idx)  # 392 with the default arguments

        # Embedding layers for both patch sizes
        self.proj1 = nn.Conv2d(3, embed_dim, kernel_size=patch_size1, stride=patch_size1)
        self.proj2 = nn.Conv2d(3, embed_dim, kernel_size=patch_size2, stride=patch_size2)

        # Positional embeddings; slot 0 is reserved for a class token
        self.pos_embed = nn.Parameter(torch.zeros(1, self.num_patches + 1, embed_dim))

        # Class token
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))

    def forward(self, x):
        """Compute the three token sequences derived from an image batch.

        Args:
            x: Image batch of shape ``[batch, 3, img_size, img_size]``.

        Returns:
            A tuple ``(x, x1, x2)``:
              * ``x``  - the reordered fine tokens plus ``pos_embed[:, 1:]``;
              * ``x1`` - the coarse tokens, each repeated twice in a row;
              * ``x2`` - the reordered fine tokens without positional embeddings.
        """
        # Coarse patches: [batch, embed_dim, H1, W1] -> [batch, num_patches1, embed_dim]
        x1 = self.proj1(x)
        x1 = x1.flatten(2).transpose(1, 2)

        # Duplicate every coarse token consecutively (2 * num_patches1 tokens)
        x1 = x1.repeat_interleave(2, dim=1)

        # Fine patches: [batch, embed_dim, H2, W2] -> [batch, num_patches2, embed_dim]
        x2 = self.proj2(x)
        x2 = x2.flatten(2).transpose(1, 2)

        # Select / reorder the fine tokens: [batch, len(self.idx), embed_dim]
        x2 = self.rearrange_patches(x2)

        # Positional embeddings for the patch slots only (slot 0 is the class token)
        x = x2 + self.pos_embed[:, 1:, :]

        return x, x1, x2

    def rearrange_patches(self, x):
        """Gather tokens along dim 1 in the order given by ``self.idx``."""
        x = x[:, self.idx, :]
        return x

    def generate_custom_order(self):
        """Build the list of fine-patch indices to keep, in output order.

        The fine grid has ``g = img_size // patch_size2`` patches per side and
        is flattened row-major. Rows are visited in steps of 4 and columns in
        steps of 2; at each position the 2x2 group of patches whose top-left
        corner sits there is emitted (top-left, top-right, bottom-left,
        bottom-right). Rows 2 and 3 of every 4-row band are therefore never
        selected, so ``g * g / 2`` indices are returned.

        Returns:
            list[int]: Flat indices into the fine patch sequence.

        Raises:
            ValueError: If ``g`` is not a positive multiple of 4.
        """
        # Derive the grid width from the constructor arguments instead of
        # assuming the default 224 / 8 = 28 layout.
        grid = self.img_size // self.patch_size2
        if grid < 4 or grid % 4 != 0:
            raise ValueError(
                "img_size // patch_size2 must be a positive multiple of 4, "
                f"got {grid}"
            )

        idx = []
        for row_block in range(0, grid, 4):
            for col_block in range(0, grid, 2):
                base = row_block * grid + col_block
                idx.extend([
                    base,
                    base + 1,
                    base + grid,
                    base + grid + 1,
                ])
        return idx

# Define the modified multi-head attention class
class ModifiedMultiheadAttention(nn.Module):
    """Multi-head attention whose scores blend two key sources.

    Queries come from ``x``, keys from both ``x1`` and ``x2``, and values
    from ``x2``. All three inputs go through the same bias-free ``qkv``
    projection. The two scaled score matrices are mixed as
    ``alpha * scores(x1) + beta * scores(x2)`` before the softmax.

    Args:
        embed_dim: Token width; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        alpha: Weight of the scores computed against the ``x1`` keys.
        beta: Weight of the scores computed against the ``x2`` keys.
    """

    def __init__(self, embed_dim, num_heads, alpha=0.5, beta=0.5):
        super().__init__()
        self.num_heads = num_heads
        self.embed_dim = embed_dim
        self.alpha = alpha
        self.beta = beta

        self.head_dim = embed_dim // num_heads
        assert (
            self.head_dim * num_heads == embed_dim
        ), "embed_dim must be divisible by num_heads"

        # Fused projection producing Q, K and V side by side along the last dim
        self.qkv = nn.Linear(embed_dim, embed_dim * 3, bias=False)

        # Final projection applied after the heads are merged again
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x, x1, x2):
        """Attend from ``x`` to the blended keys and the ``x2`` values.

        Args:
            x: Query source, ``[batch, seq, embed_dim]``.
            x1: First key source; reshaped with the sizes taken from ``x``.
            x2: Second key source and the value source; reshaped likewise.

        Returns:
            Tensor of shape ``[batch, seq, embed_dim]``.
        """
        bsz, seq_len, width = x.size()
        scale = self.head_dim ** 0.5

        def split_heads(t):
            # [batch, seq, embed_dim] -> [batch, heads, seq, head_dim]
            return t.reshape(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Each input is projected in full; only the needed thirds are kept
        projected_x = self.qkv(x)
        projected_x1 = self.qkv(x1)
        projected_x2 = self.qkv(x2)

        query = split_heads(projected_x.chunk(3, dim=-1)[0])
        key_a = split_heads(projected_x1.chunk(3, dim=-1)[1])
        _, raw_key_b, raw_value = projected_x2.chunk(3, dim=-1)
        key_b = split_heads(raw_key_b)
        value = split_heads(raw_value)

        # Scaled dot-product scores against each key set
        scores_a = torch.matmul(query, key_a.transpose(-2, -1)) / scale
        scores_b = torch.matmul(query, key_b.transpose(-2, -1)) / scale

        # Blend, normalise over the key axis, then weight the values
        weights = F.softmax(self.alpha * scores_a + self.beta * scores_b, dim=-1)
        context = torch.matmul(weights, value)

        # [batch, heads, seq, head_dim] -> [batch, seq, embed_dim]
        merged = context.transpose(1, 2).reshape(bsz, seq_len, width)
        return self.out_proj(merged)

# Define the modified transformer encoder layer
class ModifiedTransformerEncoderLayer(nn.Module):
    """Encoder block: normalised three-input attention followed by an MLP.

    Both sub-blocks are applied with a residual connection around them.
    The same ``norm1`` layer normalises ``x``, ``x1`` and ``x2`` before they
    enter the attention module.

    Args:
        embed_dim: Token width.
        num_heads: Number of attention heads.
        mlp_dim: Hidden width of the feed-forward network.
        dropout_rate: Dropout probability used after attention and in the MLP.
    """

    def __init__(self, embed_dim, num_heads, mlp_dim, dropout_rate=0.1):
        super().__init__()
        # Attention sub-block
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = ModifiedMultiheadAttention(embed_dim, num_heads)
        self.dropout1 = nn.Dropout(dropout_rate)

        # Feed-forward sub-block: expand, GELU, project back, dropout after each linear stage
        self.norm2 = nn.LayerNorm(embed_dim)
        feed_forward = [
            nn.Linear(embed_dim, mlp_dim),
            nn.GELU(),
            nn.Dropout(dropout_rate),
            nn.Linear(mlp_dim, embed_dim),
            nn.Dropout(dropout_rate),
        ]
        self.mlp = nn.Sequential(*feed_forward)

    def forward(self, x, x1, x2):
        """Run one encoder step.

        Args:
            x: Token stream that is updated and returned.
            x1: Extra input handed to the attention module (after ``norm1``).
            x2: Extra input handed to the attention module (after ``norm1``).

        Returns:
            The updated ``x``; ``x1`` and ``x2`` are not modified.
        """
        # Residual attention update
        queries = self.norm1(x)
        keys_a = self.norm1(x1)
        keys_b = self.norm1(x2)
        attended = self.attn(queries, keys_a, keys_b)
        hidden = x + self.dropout1(attended)

        # Residual feed-forward update
        return hidden + self.mlp(self.norm2(hidden))

# Define the modified ViT model with adjustments
class ModifiedViT(nn.Module):
    """Vision-Transformer-style classifier built on ``DualPatchEmbedding``.

    The patch embedding yields three token sequences. A class token is
    prepended to each of them; the first sequence is refined by ``depth``
    encoder layers, while the other two are passed unchanged to every layer
    as additional attention inputs. The class-token output is normalised and
    fed to a linear classification head.

    Args:
        img_size: Side length of the input images.
        patch_size1: Coarse patch size forwarded to the patch embedding.
        patch_size2: Fine patch size forwarded to the patch embedding.
        num_classes: Number of output logits.
        embed_dim: Token width.
        depth: Number of encoder layers.
        num_heads: Attention heads per layer.
        mlp_dim: Hidden width of each layer's MLP.
        dropout_rate: Dropout probability used throughout.
    """

    def __init__(
        self,
        img_size=224,
        patch_size1=16,
        patch_size2=8,
        num_classes=200,
        embed_dim=768,
        depth=12,
        num_heads=12,
        mlp_dim=3072,
        dropout_rate=0.1,
    ):
        super(ModifiedViT, self).__init__()

        self.patch_embed = DualPatchEmbedding(
            img_size=img_size,
            patch_size1=patch_size1,
            patch_size2=patch_size2,
            embed_dim=embed_dim,
        )

        # These are the very same Parameter objects owned by patch_embed, so
        # each appears under two names in the state dict.
        self.pos_embed = self.patch_embed.pos_embed
        self.cls_token = self.patch_embed.cls_token
        self.pos_drop = nn.Dropout(p=dropout_rate)

        self.layers = nn.ModuleList(
            [
                ModifiedTransformerEncoderLayer(
                    embed_dim, num_heads, mlp_dim, dropout_rate
                )
                for _ in range(depth)
            ]
        )

        self.norm = nn.LayerNorm(embed_dim)
        self.head = nn.Linear(embed_dim, num_classes)

    def forward(self, x):
        """Classify a batch of images.

        Args:
            x: Image batch of shape ``[batch, 3, img_size, img_size]``.

        Returns:
            Logits of shape ``[batch, num_classes]``.
        """
        # x: fine tokens with positional embeddings already added;
        # x1 / x2: auxiliary token sequences without positional embeddings
        x, x1, x2 = self.patch_embed(x)

        batch_size = x.size(0)

        # One class token per sample: [batch, 1, embed_dim]
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)

        # patch_embed.forward has already added pos_embed[:, 1:] to the patch
        # tokens, so only the class-token slot (index 0) still needs its
        # positional embedding. Adding the full pos_embed here again would
        # apply it twice to every patch token.
        x = torch.cat((cls_tokens + self.pos_embed[:, :1, :], x), dim=1)
        x1 = torch.cat((cls_tokens, x1), dim=1)
        x2 = torch.cat((cls_tokens, x2), dim=1)

        x = self.pos_drop(x)

        # Only x is updated; x1 and x2 are fed unchanged to every layer
        for layer in self.layers:
            x = layer(x, x1, x2)

        x = self.norm(x)

        # Classify from the class-token position
        cls_output = x[:, 0]
        logits = self.head(cls_output)

        return logits

# Initialize the modified model
model = ModifiedViT(
    img_size=224,
    patch_size1=16,
    patch_size2=8,
    num_classes=200,
    embed_dim=768,
    depth=12,
    num_heads=12,
    mlp_dim=3072,
    dropout_rate=0.1,
)

# Move the model to the appropriate device
model = model.to(device)

# Define loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)
# T_max must equal the number of scheduler.step() calls (one per epoch, and
# num_epochs is set to 20 further down). With a smaller T_max the learning
# rate bottoms out early and then rises back towards its initial value for
# the remaining epochs, because CosineAnnealingLR does not restart.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=20)

# Define accuracy calculation function
def calculate_accuracy(output, target, topk=(1,)):
    """Return top-k accuracies (in percent) for each k in ``topk``.

    Args:
        output: Score tensor; the top-k search runs along dim 1.
        target: Label tensor; its first dimension gives the sample count.
        topk: Iterable of k values to evaluate.

    Returns:
        list[float]: One percentage per entry of ``topk``, in the same order.
    """
    largest_k = max(topk)
    num_samples = target.size(0)

    # Indices of the best `largest_k` scores per sample, transposed so that
    # row r holds every sample's rank-r prediction.
    ranked = output.topk(largest_k, dim=1, largest=True, sorted=True).indices.t()

    # hits[r, i] is True when sample i's rank-r prediction equals its label
    hits = ranked.eq(target.view(1, -1))

    # A sample is correct at k when any of its first k ranks hit
    return [
        (hits[:k].reshape(-1).float().sum(0, keepdim=True) / num_samples).item() * 100
        for k in topk
    ]

# Training and validation loops
def train_model(model, criterion, optimizer, scheduler, num_epochs, train_loader, val_loader):
    """Train ``model`` for ``num_epochs`` epochs, validating after each one.

    Every epoch runs one optimisation pass over ``train_loader`` and one
    gradient-free pass over ``val_loader``, prints loss and top-1/3/5
    accuracy for both, steps ``scheduler`` once, and writes the model's
    state dict to ``best_modified_vit_model.pth`` whenever validation top-1
    accuracy exceeds the best seen so far. Batches are moved to the
    module-level ``device``. Nothing is returned.

    Args:
        model: Network to optimise.
        criterion: Loss callable taking ``(outputs, labels)``.
        optimizer: Optimiser stepped once per training batch.
        scheduler: Learning-rate scheduler stepped once per epoch.
        num_epochs: Number of epochs to run.
        train_loader: Iterable of ``(inputs, labels)`` training batches.
        val_loader: Iterable of ``(inputs, labels)`` validation batches.
    """
    ks = (1, 3, 5)
    best_acc = 0.0

    for epoch in range(num_epochs):
        print(f'\nEpoch {epoch+1}/{num_epochs}')
        print('-' * 30)

        # ---------------- training pass ----------------
        model.train()
        loss_total = 0.0
        acc_totals = [0.0, 0.0, 0.0]

        for inputs, labels in tqdm(train_loader, desc='Training'):
            inputs = inputs.to(device)
            labels = labels.to(device)

            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Weight per-batch means by batch size so the epoch figures are
            # per-sample averages
            count = inputs.size(0)
            loss_total += loss.item() * count
            batch_accs = calculate_accuracy(outputs, labels, topk=ks)
            for slot, acc in enumerate(batch_accs):
                acc_totals[slot] += acc * count

        train_loss = loss_total / len(train_loader.dataset)
        train_top1, train_top3, train_top5 = (
            total / len(train_loader.dataset) for total in acc_totals
        )

        print(f'Train Loss: {train_loss:.4f} | Top-1 Acc: {train_top1:.2f}% | Top-3 Acc: {train_top3:.2f}% | Top-5 Acc: {train_top5:.2f}%')

        # ---------------- validation pass ----------------
        model.eval()
        loss_total = 0.0
        acc_totals = [0.0, 0.0, 0.0]

        with torch.no_grad():
            for inputs, labels in tqdm(val_loader, desc='Validation'):
                inputs = inputs.to(device)
                labels = labels.to(device)

                outputs = model(inputs)
                loss = criterion(outputs, labels)

                count = inputs.size(0)
                loss_total += loss.item() * count
                batch_accs = calculate_accuracy(outputs, labels, topk=ks)
                for slot, acc in enumerate(batch_accs):
                    acc_totals[slot] += acc * count

        val_loss = loss_total / len(val_loader.dataset)
        val_top1, val_top3, val_top5 = (
            total / len(val_loader.dataset) for total in acc_totals
        )

        print(f'Val Loss: {val_loss:.4f} | Top-1 Acc: {val_top1:.2f}% | Top-3 Acc: {val_top3:.2f}% | Top-5 Acc: {val_top5:.2f}%')

        # One scheduler step per epoch
        scheduler.step()

        # Checkpoint on a strict improvement of validation top-1 accuracy
        if val_top1 > best_acc:
            best_acc = val_top1
            torch.save(model.state_dict(), 'best_modified_vit_model.pth')

    print(f'\nBest Validation Top-1 Accuracy: {best_acc:.2f}%')

# Calculate number of parameters
def count_parameters(model):
    """Return the total element count of the parameters that require gradients."""
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

# Run the training process
num_epochs = 20  # Adjust the number of epochs as needed
model_parameters = count_parameters(model)
print(f'\nTotal number of trainable parameters: {model_parameters}')

# Smoke-test the model with a dummy batch. The forward pass runs under
# no_grad so that no autograd graph is built and kept alive through the
# module-level `output` for the duration of training.
dummy_input = torch.randn(2, 3, 224, 224).to(device)
with torch.no_grad():
    output = model(dummy_input)
print(f'\nDummy output shape: {output.shape}')

# Start training
train_model(model, criterion, optimizer, scheduler, num_epochs, train_loader, val_loader)

# Evaluate on the test set
def evaluate_model(model, test_loader):
    """Print top-1, top-3 and top-5 accuracy of ``model`` over ``test_loader``.

    The model is switched to eval mode and run without gradient tracking.
    Batches are moved to the module-level ``device``. Per-batch accuracies
    are weighted by batch size and divided by ``len(test_loader.dataset)``.
    Nothing is returned.

    Args:
        model: Network to evaluate.
        test_loader: Iterable of ``(inputs, labels)`` batches.
    """
    model.eval()
    weighted = [0.0, 0.0, 0.0]  # running sums for top-1, top-3, top-5

    with torch.no_grad():
        for inputs, labels in tqdm(test_loader, desc='Testing'):
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)

            count = inputs.size(0)
            batch_accs = calculate_accuracy(outputs, labels, topk=(1, 3, 5))
            for slot, acc in enumerate(batch_accs):
                weighted[slot] += acc * count

    top1, top3, top5 = (total / len(test_loader.dataset) for total in weighted)

    print(f'\nTest Top-1 Accuracy: {top1:.2f}%')
    print(f'Test Top-3 Accuracy: {top3:.2f}%')
    print(f'Test Top-5 Accuracy: {top5:.2f}%')

# Load the best model and evaluate.
# map_location remaps the stored tensors onto the current device, so a
# checkpoint written on a GPU can still be loaded on a CPU-only host.
best_state = torch.load('best_modified_vit_model.pth', map_location=device)
model.load_state_dict(best_state)
model = model.to(device)
evaluate_model(model, test_loader)


Using device: cuda

Total number of trainable parameters: 86223560

Dummy output shape: torch.Size([2, 200])

Epoch 1/20
------------------------------


Training:   0%|          | 0/1563 [00:00<?, ?it/s]