In [2]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader, Dataset
from PIL import Image
from einops import rearrange
from tqdm import tqdm

# Locations of the Tiny ImageNet training and validation splits
train_dir = 'tiny-imagenet-200/train'
val_dir = 'tiny-imagenet-200/val'

# Preprocessing applied to every image: resize to 64x64, then convert to a tensor
transform = transforms.Compose([transforms.Resize((64, 64)), transforms.ToTensor()])

# Custom Dataset for Tiny ImageNet Validation Set
class TinyImageNetValDataset(Dataset):
    """Tiny ImageNet validation split, described by ``val_annotations.txt``.

    Each non-blank line of ``<val_dir>/val_annotations.txt`` is tab-separated;
    the first field is an image file name (looked up in ``<val_dir>/images``)
    and the second field is its class label name.

    Args:
        val_dir: Directory containing ``images/`` and ``val_annotations.txt``.
        transform: Optional callable applied to each loaded RGB image.
        label_to_idx: Optional mapping from label name to class index. When
            given (e.g. the mapping used by the training set) it is used
            as-is, so both splits share one index space. When omitted, the
            mapping is built from the sorted label names found in the
            annotations file.

    Raises:
        KeyError: If ``label_to_idx`` is given and an annotated label is
            missing from it.
    """

    def __init__(self, val_dir, transform=None, label_to_idx=None):
        self.transform = transform
        self.val_dir = val_dir
        self.img_dir = os.path.join(val_dir, 'images')
        self.annotations = os.path.join(val_dir, 'val_annotations.txt')

        # Read the annotations file into (image_path, label_name) pairs
        samples = []
        with open(self.annotations, 'r') as f:
            for line in f:
                line = line.strip()
                if not line:
                    # A blank (e.g. trailing) line carries no annotation
                    continue
                tokens = line.split('\t')
                img_name, label = tokens[0], tokens[1]
                samples.append((os.path.join(self.img_dir, img_name), label))

        # Mapping between label names and integer class indices
        if label_to_idx is None:
            label_names = sorted({label for _, label in samples})
            label_to_idx = {label: idx for idx, label in enumerate(label_names)}
        self.label_to_idx = label_to_idx
        self.idx_to_label = {idx: label for label, idx in self.label_to_idx.items()}

        # Store samples as (image_path, class_index)
        self.data = [(img_path, self.label_to_idx[label]) for img_path, label in samples]

    def __len__(self):
        """Return the number of annotated validation images."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, class_index)``."""
        img_path, label = self.data[idx]
        image = Image.open(img_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, label

# Create a mapping from label names to indices for the training set
def get_train_label_mapping(train_dir):
    """Build the label-name <-> class-index mappings for the training set.

    Every sub-directory of ``train_dir`` is treated as one class; classes are
    numbered in sorted name order. Non-directory entries (for example a stray
    ``.DS_Store`` file) are ignored so they cannot shift the class indices.

    Args:
        train_dir: Directory holding one sub-directory per class.

    Returns:
        A ``(label_to_idx, idx_to_label)`` tuple of dictionaries.
    """
    classes = sorted(
        entry for entry in os.listdir(train_dir)
        if os.path.isdir(os.path.join(train_dir, entry))
    )
    label_to_idx = {label: idx for idx, label in enumerate(classes)}
    idx_to_label = {idx: label for label, idx in label_to_idx.items()}
    return label_to_idx, idx_to_label

# Custom Dataset for Tiny ImageNet Training Set (to ensure consistent label mapping)
class TinyImageNetTrainDataset(Dataset):
    """Tiny ImageNet training split laid out as ``<train_dir>/<label>/images/*``.

    Args:
        train_dir: Directory holding one sub-directory per class label.
        label_to_idx: Mapping from label name to integer class index.
        transform: Optional callable applied to each loaded RGB image.

    Raises:
        KeyError: If a class directory name is missing from ``label_to_idx``.
    """

    def __init__(self, train_dir, label_to_idx, transform=None):
        self.transform = transform
        self.train_dir = train_dir
        self.data = []
        self.label_to_idx = label_to_idx

        # Iterate over each class folder. Listings are sorted because
        # os.listdir returns entries in arbitrary order, which would make the
        # sample order differ between machines/runs.
        for label in sorted(os.listdir(train_dir)):
            if not os.path.isdir(os.path.join(train_dir, label)):
                # Stray files next to the class folders are not classes
                continue
            class_dir = os.path.join(train_dir, label, 'images')
            class_idx = self.label_to_idx[label]
            self.data.extend(
                (os.path.join(class_dir, img_name), class_idx)
                for img_name in sorted(os.listdir(class_dir))
            )

    def __len__(self):
        """Return the number of training images found."""
        return len(self.data)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image, class_index)``."""
        img_path, label = self.data[idx]
        image = Image.open(img_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image, label

# Get label mappings
train_label_to_idx, train_idx_to_label = get_train_label_mapping(train_dir)

# Load datasets
train_dataset = TinyImageNetTrainDataset(train_dir, train_label_to_idx, transform=transform)
val_dataset = TinyImageNetValDataset(val_dir, transform=transform)

# Ensure that the validation set uses the same label mapping as the training set.
# The stored validation indices must be decoded with the validation set's OWN
# idx_to_label, so the remap has to happen before that attribute is replaced;
# decoding with the training mapping would turn the remap into a no-op.
val_dataset.data = [
    (img_path, train_label_to_idx[val_dataset.idx_to_label[label]])
    for img_path, label in val_dataset.data
]
val_dataset.label_to_idx = train_label_to_idx
val_dataset.idx_to_label = train_idx_to_label

# Data loaders
# NOTE(review): num_workers=4 starts worker processes that need access to the
# dataset classes defined in this file; if loading stalls at 0% inside a
# notebook, try num_workers=0 — confirm on the target platform.
batch_size = 64
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

class TransformerEncoderLayer(nn.Module):
    """Pre-norm transformer encoder block: self-attention, then an MLP.

    Both sub-layers are wrapped in a residual connection and preceded by a
    LayerNorm.

    Args:
        dim: Embedding size of every token.
        heads: Number of attention heads.
        mlp_dim: Hidden size of the feed-forward network.
        dropout: Dropout probability used in the attention and in the MLP.
        seq_length: Number of tokens this layer is meant to process. It is
            only stored on the instance; the layer itself does not use it.
    """

    def __init__(self, dim, heads, mlp_dim, dropout, seq_length):
        super(TransformerEncoderLayer, self).__init__()
        self.seq_length = seq_length
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(embed_dim=dim, num_heads=heads, dropout=dropout)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(
            nn.Linear(dim, mlp_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(mlp_dim, dim),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """Apply the block to ``x`` of shape (batch, seq, dim); shape is preserved."""
        # Self-attention. nn.MultiheadAttention is constructed with its default
        # (seq, batch, dim) layout, hence the transposes around the call.
        # The very same tensor is passed as query, key and value so the module
        # can recognise self-attention, and need_weights=False skips building
        # the averaged attention map that was previously computed and dropped.
        tokens = self.norm1(x).transpose(0, 1)
        attn_out, _ = self.attn(tokens, tokens, tokens, need_weights=False)
        x = x + attn_out.transpose(0, 1)

        # Feed-forward
        x = x + self.mlp(self.norm2(x))
        return x

class AlternatingViT(nn.Module):
    """Vision transformer whose layers alternate between 16 and 64 tokens.

    The image is embedded with 16x16 patches (a 4x4 grid, i.e. 16 tokens, for
    the 64x64 inputs this model is sized for). Layers are tagged with patch
    sizes 16, 8, 16, 8, ...; between two layers the token grid is upsampled
    (4x4 -> 8x8) or average-pooled (8x8 -> 4x4) to the next layer's length.
    The final tokens are mean-pooled and classified.

    Args:
        num_classes: Number of output classes.
        dim: Token embedding size.
        depth: Number of transformer layers (even or odd).
        heads: Attention heads per layer.
        mlp_dim: Hidden size of each layer's feed-forward network.
        dropout: Dropout probability used inside the layers.
    """

    def __init__(self, num_classes=200, dim=256, depth=6, heads=8, mlp_dim=512, dropout=0.1):
        super(AlternatingViT, self).__init__()
        self.dim = dim
        self.depth = depth
        # One entry per layer. Built per index so an odd depth also gets
        # exactly `depth` entries ([16, 8] * (depth // 2) would be one short).
        self.patch_sizes = [16 if i % 2 == 0 else 8 for i in range(depth)]

        # Positional embeddings for different patch sizes
        self.pos_embedding_16 = nn.Parameter(torch.randn(1, 16, dim))
        self.pos_embedding_64 = nn.Parameter(torch.randn(1, 64, dim))

        # Patch embeddings
        # NOTE: patch_embed_8 and pos_embedding_64 are registered but not used
        # by forward(); they are kept so the parameter set (and any saved
        # state_dict) stays unchanged.
        self.patch_embed_16 = nn.Conv2d(3, dim, kernel_size=16, stride=16)
        self.patch_embed_8 = nn.Conv2d(3, dim, kernel_size=8, stride=8)

        # Transformer encoder layers
        self.transformer_layers = nn.ModuleList()
        for patch_size in self.patch_sizes:
            seq_length = (64 // patch_size) ** 2
            layer = TransformerEncoderLayer(dim, heads, mlp_dim, dropout, seq_length)
            self.transformer_layers.append(layer)

        # Classification head
        self.to_cls_token = nn.Identity()
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, x):
        """Classify a batch of images; returns logits of shape (batch, num_classes)."""
        # Initial patch embeddings and positional encodings
        x_16 = self.patch_embed_16(x)  # [B, D, 4, 4]
        x_16 = x_16.flatten(2).transpose(1, 2)  # [B, 16, D]
        x_16 = x_16 + self.pos_embedding_16

        # Start with sequence of length 16. The 8x8-patch embedding was
        # previously computed here as well but never consumed, so it is no
        # longer evaluated.
        seq = x_16

        for i, layer in enumerate(self.transformer_layers):
            seq_length = seq.size(1)

            # Apply transformer layer
            seq = layer(seq)

            # Rearrangement between layers if necessary
            if i < self.depth - 1:
                next_seq_length = self.transformer_layers[i + 1].seq_length
                if seq_length != next_seq_length:
                    seq = self.rearrange_tokens(seq, seq_length, next_seq_length)

        # Classification head
        seq = self.to_cls_token(seq.mean(dim=1))  # Global average pooling
        out = self.mlp_head(seq)
        return out

    def rearrange_tokens(self, seq, seq_length, next_seq_length):
        """Resize the token grid between 16 (4x4) and 64 (8x8) tokens.

        Any other ``(seq_length, next_seq_length)`` pair returns ``seq``
        unchanged.
        """
        if seq_length == 16 and next_seq_length == 64:
            # Expand sequence from 16 to 64: nearest-neighbour upsample 4x4 -> 8x8
            seq = rearrange(seq, 'b (h w) d -> b d h w', h=4, w=4)
            seq = nn.functional.interpolate(seq, scale_factor=2, mode='nearest')
            seq = rearrange(seq, 'b d h w -> b (h w) d')
        elif seq_length == 64 and next_seq_length == 16:
            # Reduce sequence from 64 to 16: 2x2 average pooling 8x8 -> 4x4
            seq = rearrange(seq, 'b (h w) d -> b d h w', h=8, w=8)
            seq = nn.functional.avg_pool2d(seq, kernel_size=2, stride=2)
            seq = rearrange(seq, 'b d h w -> b (h w) d')
        return seq

def train(model, device, train_loader, optimizer, criterion, epoch):
    """Run one training epoch and print the mean per-batch loss.

    Args:
        model: Network to optimise; switched to training mode.
        device: Device the batches are moved to.
        train_loader: Iterable yielding ``(images, labels)`` batches.
        optimizer: Optimizer stepped once per batch.
        criterion: Loss called as ``criterion(outputs, labels)``.
        epoch: Epoch number, used only in the progress bar and the summary.
    """
    model.train()
    loss_sum = 0.0
    progress = tqdm(train_loader, desc=f"Training Epoch {epoch}")
    for inputs, targets in progress:
        inputs = inputs.to(device)
        targets = targets.to(device)

        # Standard step: clear old gradients, forward, backward, update
        optimizer.zero_grad()
        batch_loss = criterion(model(inputs), targets)
        batch_loss.backward()
        optimizer.step()

        loss_sum += batch_loss.item()

    mean_loss = loss_sum / len(train_loader)
    print(f"Epoch [{epoch}], Loss: {mean_loss:.4f}")

def evaluate(model, device, data_loader, mode='Validation'):
    """Print and return top-1/top-3/top-5 accuracy of ``model`` on ``data_loader``.

    Args:
        model: Network to evaluate; switched to eval mode.
        device: Device the batches are moved to.
        data_loader: Iterable yielding ``(images, labels)`` batches.
        mode: Name of the split, used in the progress bar and printed lines.

    Returns:
        ``(top1_acc, top3_acc, top5_acc)`` as percentages. All three are 0.0
        when the loader yields no samples.
    """
    model.eval()
    ks = (1, 3, 5)
    correct_at = {k: 0 for k in ks}
    total = 0
    with torch.no_grad():
        for images, labels in tqdm(data_loader, desc=f"Evaluating {mode}"):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)

            # topk raises if asked for more entries than there are classes,
            # so never request more than the model actually outputs.
            max_k = min(max(ks), outputs.size(1))
            _, pred = outputs.topk(max_k, 1, largest=True, sorted=True)

            total += labels.size(0)
            # Column j is True where the (j+1)-th ranked class is the label
            correct = pred.eq(labels.view(-1, 1).expand_as(pred))

            for k in ks:
                correct_at[k] += correct[:, :k].sum().item()

    if total == 0:
        # Avoid dividing by zero when there was nothing to evaluate
        print(f"{mode}: no samples to evaluate")
        return 0.0, 0.0, 0.0

    top1_acc = 100 * correct_at[1] / total
    top3_acc = 100 * correct_at[3] / total
    top5_acc = 100 * correct_at[5] / total
    print(f"{mode} Top-1 Accuracy: {top1_acc:.2f}%")
    print(f"{mode} Top-3 Accuracy: {top3_acc:.2f}%")
    print(f"{mode} Top-5 Accuracy: {top5_acc:.2f}%")
    return top1_acc, top3_acc, top5_acc

def count_parameters(model):
    """Return the total element count of the model's trainable parameters."""
    trainable = 0
    for param in model.parameters():
        # Frozen parameters (requires_grad=False) are not counted
        if param.requires_grad:
            trainable += param.numel()
    return trainable

# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = AlternatingViT(num_classes=200)
model = model.to(device)

print(f"Number of parameters: {count_parameters(model)}")

# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Training loop: one optimisation pass, then accuracy on both splits
num_epochs = 10
for epoch in range(1, num_epochs + 1):
    train(model, device, train_loader, optimizer, criterion, epoch)
    for split_name, split_loader in (('Training', train_loader), ('Validation', val_loader)):
        evaluate(model, device, split_loader, mode=split_name)


Number of parameters: 3481288


Training Epoch 1:   0%|          | 0/1563 [00:00<?, ?it/s]

: 