# In [None]:  (Jupyter cell marker left over from the notebook export)
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.transforms as transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import numpy as np
from sklearn.preprocessing import LabelEncoder
from tqdm import tqdm
from sklearn.model_selection import train_test_split
import matplotlib.pyplot as plt
import torch.nn.functional as F

# Paths to the Tiny ImageNet dataset (relative to the current working directory)
TRAIN_DIR = 'tiny-imagenet-200/train'
VAL_DIR = 'tiny-imagenet-200/val'

# Hyperparameters
BATCH_SIZE = 64  # samples per DataLoader batch
NUM_EPOCHS = 20  # number of iterations of the training loop below
LEARNING_RATE = 0.001  # learning rate passed to the Adam optimizer
NUM_CLASSES = 200  # default output size of the ViT classification head
# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

class TinyImageNetDataset(Dataset):
    """Map-style dataset over image file paths with string class labels.

    Args:
        data: Sequence of image file paths.
        labels: Sequence of class names, aligned index-by-index with ``data``.
        transform: Optional callable applied to each loaded RGB ``PIL.Image``.
        label_encoder: Optional, already fitted ``LabelEncoder``. Pass the
            same encoder to every split so that a class name maps to the same
            integer everywhere. When omitted, a new encoder is fitted on
            ``labels`` (the original behavior), in which case the mapping only
            matches other splits if they contain exactly the same class set.

    Raises:
        ValueError: If ``data`` and ``labels`` differ in length.
    """

    def __init__(self, data, labels, transform=None, label_encoder=None):
        if len(data) != len(labels):
            raise ValueError(
                f"data and labels must have the same length, "
                f"got {len(data)} and {len(labels)}"
            )
        self.data = data
        self.labels = labels
        self.transform = transform
        if label_encoder is None:
            label_encoder = LabelEncoder()
            label_encoder.fit(labels)
        self.le = label_encoder
        # Integer class index for every sample, computed once up front.
        self.encoded_labels = self.le.transform(labels)

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(image, encoded_label)`` for sample ``idx``."""
        img_path = self.data[idx]
        # Image.open is lazy and keeps the file open; convert() forces the
        # pixel data to load into a new image, so the handle can be released
        # immediately instead of waiting for garbage collection.
        with Image.open(img_path) as img:
            image = img.convert('RGB')
        label = self.encoded_labels[idx]

        if self.transform is not None:
            image = self.transform(image)

        return image, label

def load_training_data(train_dir=None):
    """Collect training image paths and their class names.

    The directory is expected to contain one sub-directory per class, each
    holding an ``images`` folder. Entries without an ``images`` folder are
    skipped.

    Args:
        train_dir: Root of the training split. Defaults to ``TRAIN_DIR``.

    Returns:
        ``(data, labels)``: two parallel lists of image paths and class
        directory names, ordered by class name and then by file name.
    """
    root = TRAIN_DIR if train_dir is None else train_dir
    data = []
    labels = []
    # os.listdir returns entries in arbitrary order; sorting makes the sample
    # order reproducible across runs and machines.
    for wnid in sorted(os.listdir(root)):
        wnid_dir = os.path.join(root, wnid, 'images')
        if not os.path.isdir(wnid_dir):
            continue
        for img_file in sorted(os.listdir(wnid_dir)):
            data.append(os.path.join(wnid_dir, img_file))
            labels.append(wnid)
    return data, labels

def load_validation_data(val_dir=None):
    """Collect validation image paths and their class names.

    Reads ``val_annotations.txt`` (tab-separated; first column file name,
    second column class name) and pairs it with the files found in the
    ``images`` folder. Images without an annotation are ignored.

    Args:
        val_dir: Root of the validation split. Defaults to ``VAL_DIR``.

    Returns:
        ``(val_data, val_labels)``: two parallel lists of image paths and
        class names, ordered by file name.
    """
    root = VAL_DIR if val_dir is None else val_dir
    annotations_file = os.path.join(root, 'val_annotations.txt')
    val_annotations = {}
    with open(annotations_file, 'r', encoding='utf-8') as f:
        for line in f:
            tokens = line.strip().split('\t')
            if len(tokens) < 2:
                # Blank or truncated line: nothing to map.
                continue
            val_annotations[tokens[0]] = tokens[1]

    images_dir = os.path.join(root, 'images')
    val_data = []
    val_labels = []
    # Sorted so that any later seeded split of this list is reproducible;
    # os.listdir order is arbitrary.
    for img_file in sorted(os.listdir(images_dir)):
        wnid = val_annotations.get(img_file)
        if wnid is None:
            continue
        val_data.append(os.path.join(images_dir, img_file))
        val_labels.append(wnid)
    return val_data, val_labels

# Preprocessing pipeline shared by the training, validation and test datasets.
transform = transforms.Compose([
    transforms.Resize((64, 64)),  # force a fixed 64x64 spatial size
    transforms.ToTensor(),  # PIL image -> tensor
    # Per-channel normalization; presumably Tiny ImageNet statistics -- TODO confirm.
    transforms.Normalize(mean=[0.480, 0.448, 0.398],
                         std=[0.277, 0.269, 0.282])
])

# Load training and validation file lists
train_data, train_labels = load_training_data()
val_data, val_labels = load_validation_data()

# Split validation data into new validation and test sets (50/50, stratified
# by class so both halves keep the class balance)
val_data, test_data, val_labels, test_labels = train_test_split(
    val_data, val_labels, test_size=0.5, random_state=42, stratify=val_labels)

# Each TinyImageNetDataset fits its own LabelEncoder on the labels it is
# given, so integer class indices agree across splits only when every split
# contains exactly the same set of class names. Fail loudly otherwise rather
# than train and evaluate against mismatched label indices.
train_classes = set(train_labels)
for split_name, split_labels in (('validation', val_labels), ('test', test_labels)):
    if set(split_labels) != train_classes:
        raise ValueError(
            f"The {split_name} split does not contain the same classes as the "
            f"training split; label indices would be inconsistent."
        )

# Create training dataset and dataloader
train_dataset = TinyImageNetDataset(train_data, train_labels, transform=transform)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=4)

# Create validation dataset and dataloader
val_dataset = TinyImageNetDataset(val_data, val_labels, transform=transform)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)

# Create test dataset and dataloader
test_dataset = TinyImageNetDataset(test_data, test_labels, transform=transform)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=4)

# Custom Multi-Head Attention with learnable alpha and beta
class CustomMultiheadAttention(nn.Module):
    """Multi-head attention with learnable weights on two groups of keys.

    The key/value sequence is treated as ``[CLS, group 1, group 2]`` where
    group 1 is the ``num_patches1`` tokens following the CLS token and group 2
    is everything after that. After the softmax (and dropout), attention paid
    to group 1 is multiplied by ``alpha`` and attention paid to group 2 by
    ``beta``, with ``(alpha, beta) = softmax(gamma)``. Attention paid to the
    CLS token is left unscaled.

    Args:
        embed_dim: Token embedding size; must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, embed_dim, num_heads, dropout=0.1):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout

        assert embed_dim % num_heads == 0, "Embedding dimension must be divisible by num_heads."

        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5

        # Projection layers
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)

        # Output projection
        self.out_proj = nn.Linear(embed_dim, embed_dim)

        # Learnable logits for [alpha, beta]; softmax-normalized in forward.
        self.gamma = nn.Parameter(torch.tensor([0.5, 0.5]))

    def forward(self, query, key, value, num_patches1):
        """Apply the group-weighted attention.

        Args:
            query: Tensor of shape ``(B, N, embed_dim)``.
            key: Tensor of shape ``(B, S, embed_dim)``.
            value: Tensor of shape ``(B, S, embed_dim)``.
            num_patches1: Number of key tokens (after CLS) that form group 1.

        Returns:
            Tensor of shape ``(B, N, embed_dim)``.

        Raises:
            ValueError: If ``num_patches1`` does not fit in the ``S - 1``
                non-CLS key positions.
        """
        B, N, _ = query.shape
        S = key.shape[1]
        if not 0 <= num_patches1 <= S - 1:
            raise ValueError(
                f"num_patches1 must be in [0, {S - 1}], got {num_patches1}"
            )

        # Linear projections, split into heads: (B, num_heads, len, head_dim)
        Q = self.q_proj(query).view(B, N, self.num_heads, self.head_dim).transpose(1, 2)
        K = self.k_proj(key).view(B, S, self.num_heads, self.head_dim).transpose(1, 2)
        V = self.v_proj(value).view(B, S, self.num_heads, self.head_dim).transpose(1, 2)

        # Scaled dot-product attention weights: (B, num_heads, N, S)
        attn_scores = torch.matmul(Q, K.transpose(-2, -1)) * self.scale
        attn_weights = F.softmax(attn_scores, dim=-1)
        attn_weights = F.dropout(attn_weights, p=self.dropout, training=self.training)

        # alpha and beta sum to 1.
        gamma = torch.softmax(self.gamma, dim=0)

        # Per-key multiplier: 1 for CLS, alpha for group 1, beta for group 2.
        group = torch.zeros(S - 1, dtype=torch.long, device=gamma.device)
        group[num_patches1:] = 1
        key_scale = torch.cat([gamma.new_ones(1), gamma[group]]).to(attn_weights.dtype)

        # The contributions of CLS, group 1 and group 2 are *summed* for each
        # query position. Scaling the weights per key and doing one matmul is
        # that sum; concatenating the three partial outputs instead would
        # triple the sequence length and break the reshape below.
        attn_output = torch.matmul(attn_weights * key_scale, V)  # (B, num_heads, N, head_dim)
        attn_output = attn_output.transpose(1, 2).reshape(B, N, self.embed_dim)

        # Final linear projection
        return self.out_proj(attn_output)

# Custom Transformer Encoder Layer using CustomMultiheadAttention
class CustomTransformerEncoderLayer(nn.Module):
    """Encoder block: CustomMultiheadAttention self-attention, then a two-layer MLP.

    Each sub-layer is wrapped in dropout and a residual connection, and is
    followed by LayerNorm.
    """

    def __init__(self, embed_dim, num_heads, dim_feedforward=2048, dropout=0.1):
        super().__init__()

        # Self-attention sub-layer
        self.self_attn = CustomMultiheadAttention(embed_dim, num_heads, dropout=dropout)

        # Feed-forward sub-layer: embed_dim -> dim_feedforward -> embed_dim
        self.linear1 = nn.Linear(embed_dim, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, embed_dim)

        # One LayerNorm and one residual dropout per sub-layer
        self.norm1 = nn.LayerNorm(embed_dim)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        # Non-linearity used between the two feed-forward layers
        self.activation = F.relu

    def forward(self, src, num_patches1):
        """Run the block on ``src``; ``num_patches1`` is passed to the attention."""
        attended = self.self_attn(src, src, src, num_patches1)
        normed = self.norm1(src + self.dropout1(attended))

        hidden = self.activation(self.linear1(normed))
        projected = self.linear2(self.dropout(hidden))
        return self.norm2(normed + self.dropout2(projected))

# ViT model
class ViT(nn.Module):
    """Vision Transformer that tokenizes the image at two patch sizes.

    Token layout fed to the encoder:
    ``[CLS, coarse tokens (each repeated 4x), fine tokens (2x2-block order)]``.
    The fine patches are reordered so that every four consecutive fine tokens
    form a 2x2 block of the fine patch grid; with
    ``patch_size1 == 2 * patch_size2`` such a block covers the same pixels as
    one coarse patch.

    Args:
        img_size: Side length of the (square) input images.
        patch_size1: Side length of the coarse patches.
        patch_size2: Side length of the fine patches.
        num_classes: Number of output logits.
        dim: Token embedding size.
        depth: Number of encoder layers.
        heads: Number of attention heads per layer.
        mlp_dim: Hidden size of each layer's feed-forward network.

    Raises:
        ValueError: If the fine patch grid has an odd side length and so
            cannot be tiled by 2x2 blocks.
    """

    # Each coarse token is repeated once per fine patch of a 2x2 block.
    _COARSE_REPEAT = 4

    def __init__(self, img_size=64, patch_size1=16, patch_size2=8, num_classes=NUM_CLASSES, dim=512, depth=6, heads=8, mlp_dim=1024):
        super().__init__()

        # Coarse patches
        num_patches1 = (img_size // patch_size1) ** 2
        patch_dim1 = 3 * patch_size1 * patch_size1
        self.patch_size1 = patch_size1

        # Fine patches
        grid2 = img_size // patch_size2
        if grid2 % 2 != 0:
            raise ValueError(
                f"img_size // patch_size2 must be even to form 2x2 blocks, got {grid2}"
            )
        num_patches2 = grid2 ** 2
        patch_dim2 = 3 * patch_size2 * patch_size2
        self.patch_size2 = patch_size2

        # Token counts after processing
        self.num_patches1 = num_patches1 * self._COARSE_REPEAT
        self.num_patches2 = num_patches2
        total_patches = self.num_patches1 + self.num_patches2

        # Linear projection layers
        self.to_patch_embedding1 = nn.Linear(patch_dim1, dim)
        self.to_patch_embedding2 = nn.Linear(patch_dim2, dim)

        # Positional embeddings (+1 for the CLS token)
        self.pos_embedding = nn.Parameter(torch.randn(1, total_patches + 1, dim))

        # CLS token
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))

        # Transformer encoder. Build a fresh layer per depth step: repeating a
        # single instance in the list would make all layers share one set of
        # weights.
        self.transformer = nn.ModuleList([
            CustomTransformerEncoderLayer(embed_dim=dim, num_heads=heads, dim_feedforward=mlp_dim)
            for _ in range(depth)
        ])

        # Classification head
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

        # Row-major fine-patch indices regrouped into 2x2 blocks, e.g. for an
        # 8x8 grid: 0, 1, 8, 9, 2, 3, 10, 11, ... Built once here rather than
        # on every forward pass. Non-persistent so checkpoints keep the same
        # keys as before.
        order = [
            (2 * block_row + d_row) * grid2 + (2 * block_col + d_col)
            for block_row in range(grid2 // 2)
            for block_col in range(grid2 // 2)
            for d_row in range(2)
            for d_col in range(2)
        ]
        self.register_buffer('patch_order', torch.tensor(order, dtype=torch.long), persistent=False)

    @staticmethod
    def _extract_patches(x, patch_size):
        """Cut ``x`` (B, C, H, W) into non-overlapping flattened patches.

        Returns a tensor of shape ``(B, num_patches, C * patch_size ** 2)``
        with patches in row-major order.
        """
        B, C = x.shape[:2]
        patches = x.unfold(2, patch_size, patch_size).unfold(3, patch_size, patch_size)
        patches = patches.contiguous().view(B, C, -1, patch_size, patch_size)
        patches = patches.permute(0, 2, 1, 3, 4)
        return patches.reshape(B, -1, C * patch_size * patch_size)

    def forward(self, x):
        """Classify a batch of images ``x`` of shape (B, 3, H, W); returns (B, num_classes) logits."""
        B = x.shape[0]

        # Coarse tokens. Embedding before repeating gives the same values as
        # repeating the raw patches first, with a quarter of the projections.
        tokens1 = self.to_patch_embedding1(self._extract_patches(x, self.patch_size1))
        tokens1 = tokens1.repeat_interleave(self._COARSE_REPEAT, dim=1)

        # Fine tokens in 2x2-block order
        patches2 = self._extract_patches(x, self.patch_size2)[:, self.patch_order, :]
        tokens2 = self.to_patch_embedding2(patches2)

        # [CLS, coarse, fine] -> (B, total_patches + 1, dim)
        cls_tokens = self.cls_token.expand(B, -1, -1)
        tokens = torch.cat((cls_tokens, tokens1, tokens2), dim=1)

        # Add positional embeddings (out of place, leaving the cat result intact)
        tokens = tokens + self.pos_embedding[:, :tokens.size(1), :]

        # Transformer encoding
        for encoder_layer in self.transformer:
            tokens = encoder_layer(tokens, self.num_patches1)

        # Classification using the final CLS token
        return self.mlp_head(tokens[:, 0, :])

def accuracy(output, target, topk=(1, 3, 5)):
    """Return the top-k accuracy in percent for every k in ``topk``.

    ``output`` holds per-class scores of shape (batch, classes) and ``target``
    the true class indices of shape (batch,). The result is a list with one
    single-element tensor per requested k.
    """
    largest_k = max(topk)
    num_samples = target.size(0)

    # Class indices of the largest_k best scores per sample, best first,
    # transposed so that row r holds every sample's rank-r prediction.
    ranked = output.topk(largest_k, dim=1, largest=True, sorted=True).indices.t()
    hits = ranked.eq(target.view(1, -1).expand_as(ranked))

    percent_per_hit = 100.0 / num_samples
    return [
        hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(percent_per_hit)
        for k in topk
    ]

def train(model, device, train_loader, criterion, optimizer, epoch):
    """Train ``model`` for one epoch and return the epoch's average metrics.

    Args:
        model: Network to optimize.
        device: Device the batches are moved to.
        train_loader: Iterable yielding ``(inputs, targets)`` batches.
        criterion: Loss function called as ``criterion(outputs, targets)``.
        optimizer: Optimizer stepping ``model``'s parameters.
        epoch: Current epoch number, used only for the progress-bar label.

    Returns:
        ``(loss, top1, top3, top5)``: per-sample averages over the epoch,
        accuracies in percent.

    Raises:
        ValueError: If ``train_loader`` yields no samples.
    """
    model.train()
    loss_sum = 0.0
    top1_sum = 0.0
    top3_sum = 0.0
    top5_sum = 0.0
    num_samples = 0

    loop = tqdm(train_loader)
    loop.set_description(f"Epoch [{epoch}/{NUM_EPOCHS}]")
    for inputs, targets in loop:
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        # Weight every batch by its size so that a smaller final batch does
        # not count as much as a full one. This assumes the criterion returns
        # the batch mean (the default for nn.CrossEntropyLoss).
        batch_size = targets.size(0)
        acc1, acc3, acc5 = accuracy(outputs.detach(), targets, topk=(1, 3, 5))
        loss_sum += loss.item() * batch_size
        top1_sum += acc1.item() * batch_size
        top3_sum += acc3.item() * batch_size
        top5_sum += acc5.item() * batch_size
        num_samples += batch_size

        # Update tqdm loop with running averages
        loop.set_postfix(loss=(loss_sum/num_samples), top1_acc=(top1_sum/num_samples), top3_acc=(top3_sum/num_samples), top5_acc=(top5_sum/num_samples))

    if num_samples == 0:
        raise ValueError("train_loader yielded no samples")

    return loss_sum/num_samples, top1_sum/num_samples, top3_sum/num_samples, top5_sum/num_samples

def _evaluate(model, device, loader, criterion, description):
    """Run ``model`` over ``loader`` without gradients and average the metrics.

    Returns ``(loss, top1, top3, top5)`` as per-sample averages, accuracies in
    percent. Raises ValueError if ``loader`` yields no samples.
    """
    model.eval()
    loss_sum = 0.0
    top1_sum = 0.0
    top3_sum = 0.0
    top5_sum = 0.0
    num_samples = 0

    with torch.no_grad():
        loop = tqdm(loader)
        loop.set_description(description)
        for inputs, targets in loop:
            inputs, targets = inputs.to(device), targets.to(device)

            outputs = model(inputs)
            loss = criterion(outputs, targets)

            # Weight every batch by its size so that a smaller final batch
            # does not skew the averages. This assumes the criterion returns
            # the batch mean (the default for nn.CrossEntropyLoss).
            batch_size = targets.size(0)
            acc1, acc3, acc5 = accuracy(outputs, targets, topk=(1, 3, 5))
            loss_sum += loss.item() * batch_size
            top1_sum += acc1.item() * batch_size
            top3_sum += acc3.item() * batch_size
            top5_sum += acc5.item() * batch_size
            num_samples += batch_size

            # Update tqdm loop with running averages
            loop.set_postfix(loss=(loss_sum/num_samples), top1_acc=(top1_sum/num_samples), top3_acc=(top3_sum/num_samples), top5_acc=(top5_sum/num_samples))

    if num_samples == 0:
        raise ValueError("data loader yielded no samples")

    return loss_sum/num_samples, top1_sum/num_samples, top3_sum/num_samples, top5_sum/num_samples


def validate(model, device, val_loader, criterion):
    """Evaluate ``model`` on ``val_loader``.

    Returns:
        ``(loss, top1, top3, top5)``: per-sample averages, accuracies in percent.
    """
    return _evaluate(model, device, val_loader, criterion, "Validation")


def test(model, device, test_loader, criterion):
    """Evaluate ``model`` on ``test_loader`` and print the results.

    Returns:
        ``(loss, top1, top3, top5)``: per-sample averages, accuracies in percent.
    """
    avg_loss, avg_top1_acc, avg_top3_acc, avg_top5_acc = _evaluate(
        model, device, test_loader, criterion, "Test")

    print(f"Test Loss: {avg_loss:.4f}")
    print(f"Test Acc@1: {avg_top1_acc:.2f}%")
    print(f"Test Acc@3: {avg_top3_acc:.2f}%")
    print(f"Test Acc@5: {avg_top5_acc:.2f}%")

    return avg_loss, avg_top1_acc, avg_top3_acc, avg_top5_acc

# Model, loss function and optimizer
model = ViT().to(DEVICE)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

# Per-epoch metric histories (one entry appended per epoch)
train_losses, val_losses = [], []
train_top1_acc, val_top1_acc = [], []
train_top3_acc, val_top3_acc = [], []
train_top5_acc, val_top5_acc = [], []

# Best validation top-1 accuracy seen so far
best_acc = 0.0

for epoch in range(1, NUM_EPOCHS + 1):
    train_loss, train_acc1, train_acc3, train_acc5 = train(
        model, DEVICE, train_loader, criterion, optimizer, epoch)
    val_loss, val_acc1, val_acc3, val_acc5 = validate(
        model, DEVICE, val_loader, criterion)

    # Record this epoch's numbers in the matching history list.
    for history, value in (
        (train_losses, train_loss),
        (val_losses, val_loss),
        (train_top1_acc, train_acc1),
        (val_top1_acc, val_acc1),
        (train_top3_acc, train_acc3),
        (val_top3_acc, val_acc3),
        (train_top5_acc, train_acc5),
        (val_top5_acc, val_acc5),
    ):
        history.append(value)

    # Checkpoint whenever validation top-1 accuracy improves.
    if val_acc1 > best_acc:
        best_acc = val_acc1
        torch.save(model.state_dict(), 'best_model.pth')
        print(f"Best model saved with accuracy: {best_acc:.2f}%")

    # Epoch summary
    print(f"Epoch {epoch}: Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}")
    print(f"Train Acc@1: {train_acc1:.2f}%, Val Acc@1: {val_acc1:.2f}%")
    print(f"Train Acc@3: {train_acc3:.2f}%, Val Acc@3: {val_acc3:.2f}%")
    print(f"Train Acc@5: {train_acc5:.2f}%, Val Acc@5: {val_acc5:.2f}%")

def plot_metrics(train_metric, val_metric, metric_name):
    """Plot a training and a validation curve against the epoch number.

    The figure is saved as ``'<metric_name>.png'`` in the working directory
    and then shown.

    Args:
        train_metric: Per-epoch training values.
        val_metric: Per-epoch validation values.
        metric_name: Label used in the title, legend, y-axis and file name.
    """
    # Derive the x-axis from the data instead of NUM_EPOCHS so that histories
    # from a shortened or interrupted run can still be plotted.
    train_epochs = range(1, len(train_metric) + 1)
    val_epochs = range(1, len(val_metric) + 1)
    plt.figure()
    plt.plot(train_epochs, train_metric, 'b', label=f'Training {metric_name}')
    plt.plot(val_epochs, val_metric, 'r', label=f'Validation {metric_name}')
    plt.title(f'Training and Validation {metric_name}')
    plt.xlabel('Epochs')
    plt.ylabel(metric_name)
    plt.legend()
    plt.savefig(f'{metric_name}.png')
    plt.show()

# Plot loss and top-1/3/5 accuracy curves, one figure each
for train_history, val_history, metric_name in (
    (train_losses, val_losses, 'Loss'),
    (train_top1_acc, val_top1_acc, 'Top-1 Accuracy'),
    (train_top3_acc, val_top3_acc, 'Top-3 Accuracy'),
    (train_top5_acc, val_top5_acc, 'Top-5 Accuracy'),
):
    plot_metrics(train_history, val_history, metric_name)

# List the shape of every trainable parameter
print("Model Parameters:")
trainable = (
    (name, param)
    for name, param in model.named_parameters()
    if param.requires_grad
)
for name, param in trainable:
    print(f"{name}: {param.data.shape}")

# Load the best model. map_location remaps the saved tensors onto DEVICE, so a
# checkpoint written on a GPU can still be loaded when only a CPU is available.
model.load_state_dict(torch.load('best_model.pth', map_location=DEVICE))

# Evaluate on test set
test_loss, test_acc1, test_acc3, test_acc5 = test(model, DEVICE, test_loader, criterion)


# Captured cell output:   0%|          | 0/1563 [00:00<?, ?it/s]