In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from tqdm import tqdm
import matplotlib.pyplot as plt
from collections import deque

# Reproducibility
import random
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

In [2]:
# Training Configuration
epochs = 10
batch_size = 64
lr = 3e-4
weight_decay = 1e-3
device = "mps"
checkpoint_filepath = ""  # Set to a path if you want to load a checkpoint
save_dir = "checkpoints"
dataset_filepath = "./ImagenetHighResolution"
import os
os.makedirs(save_dir, exist_ok=True)

In [3]:
class PatchEmbedding(nn.Module):
    """
    Module that converts image patches to embeddings for Vision Transformer.
    """
    def __init__(self,
                 image_size: tuple = (64, 72),
                 patch_size: int = 16,
                 in_channels: int = 3,
                 embedding_dim: int = 1024):
        super().__init__()
        self.image_size = image_size
        self.patch_size = patch_size
        self.in_channels = in_channels

        # Calculate number of patches
        self.num_patches = (image_size[0] // patch_size) * (image_size[1] // patch_size)

        # Create projection for converting patches to embeddings
        self.projection = nn.Conv2d(
            in_channels=in_channels,
            out_channels=embedding_dim,
            kernel_size=patch_size,
            stride=patch_size
        )

        # CLS token embedding
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embedding_dim))

        # Positional embedding (Normal distribution initialization of value)
        self.positions = nn.Parameter(torch.zeros(1, self.num_patches + 1, embedding_dim))
        nn.init.trunc_normal_(self.positions, std=0.02)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch_size = x.shape[0]

        # Convert image to patches and project to embedding dimension
        # x shape: [batch_size, channels, height, width]
        x = self.projection(x)
        # x shape: [batch_size, embedding_dim, height/patch_size, width/patch_size]

        # Flatten patches to sequence
        x = x.flatten(2).transpose(1, 2)
        # x shape: [batch_size, num_patches, embedding_dim]

        # Add CLS token
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)

        # Add positional embeddings
        x = x + self.positions

        return x


class VisionAttention(nn.Module):
    def __init__(self,
                 hidden_dim: int,
                 head_dim: int,
                 q_head: int,
                 kv_head: int,
                 lora_rank: int = 16,
                 lora_alpha: int = 32,
                 dropout_ratio: float = 0.1):
        super().__init__()
        self.head_dim = head_dim
        self.q_head = q_head
        self.kv_head = kv_head
        self.qkv = nn.Linear(hidden_dim, (q_head+kv_head*2)*head_dim)
        self.o = nn.Linear(q_head*head_dim, hidden_dim)
        self.scaler = 1/math.sqrt(head_dim)
        self.attn_dropout = nn.Dropout(dropout_ratio)

        # LoRA
        self.lora_scale = lora_alpha / lora_rank
        self.lora_qkv_a = nn.Linear(hidden_dim, lora_rank)
        self.lora_qkv_b = nn.Linear(lora_rank, (q_head+kv_head*2)*head_dim)
        self.lora_o_a = nn.Linear(q_head*head_dim, lora_rank)
        self.lora_o_b = nn.Linear(lora_rank, hidden_dim)

        if q_head != kv_head:
            # If we are using multi query attention
            assert q_head % kv_head == 0
            self.multi_query_attention = True
            self.q_kv_scale = q_head//kv_head
        else:
            self.multi_query_attention = False

    def forward(self, tensor: torch.Tensor, fine_tuning: bool = False) -> torch.Tensor:
        batch_size, seq_len, hid_dim = tensor.shape

        qkv_tensor = self.qkv(tensor)
        if fine_tuning:
            lora_tensor = self.lora_qkv_a(tensor)
            lora_tensor = self.lora_qkv_b(lora_tensor)
            lora_tensor = lora_tensor * self.lora_scale
            qkv_tensor = lora_tensor + qkv_tensor
        query, key, value = qkv_tensor.split([self.head_dim*self.q_head, self.head_dim*self.kv_head, self.head_dim*self.kv_head], dim=-1)

        query = query.contiguous().view(batch_size, seq_len, self.q_head, self.head_dim)
        key = key.contiguous().view(batch_size, seq_len, self.kv_head, self.head_dim)
        value = value.contiguous().view(batch_size, seq_len, self.kv_head, self.head_dim)

        if self.multi_query_attention:
            # Efficient broadcasting instead of repeat_interleave
            key = key.view(batch_size, seq_len, self.kv_head, 1, self.head_dim)
            key = key.expand(-1, -1, -1, self.q_kv_scale, -1)
            key = key.reshape(batch_size, seq_len, self.q_head, self.head_dim)
            
            value = value.view(batch_size, seq_len, self.kv_head, 1, self.head_dim)
            value = value.expand(-1, -1, -1, self.q_kv_scale, -1)
            value = value.reshape(batch_size, seq_len, self.q_head, self.head_dim)

        # Switch to batch_size, head, seq_len, head_dim
        query = query.transpose(1, 2)
        key = key.transpose(1, 2)
        value = value.transpose(1, 2)

        # Classic self attention without attention mask ViT
        attention_raw = torch.matmul(query, key.transpose(2, 3))
        attention_scaled = attention_raw * self.scaler
        attention_score = torch.softmax(attention_scaled, dim=-1)
        attention_score = self.attn_dropout(attention_score)
        value = torch.matmul(attention_score, value)

        # Reshape back to batch_size, seq_len, hid_dim
        value = value.transpose(1, 2).contiguous()
        value = value.view(batch_size, seq_len, hid_dim)

        # Output layer
        output = self.o(value)
        if fine_tuning:
            lora_tensor = self.lora_o_a(value)
            lora_tensor = self.lora_o_b(lora_tensor)
            lora_tensor = lora_tensor * self.lora_scale
            output = lora_tensor + output

        output = self.attn_dropout(output)
        return output


class FeedForward(nn.Module):
    def __init__(self,
                 hidden_size: int,
                 expansion_factor: int = 4,
                 dropout_ratio: float = 0.1,
                 lora_rank: int = 16,
                 lora_alpha: int = 32):
        super().__init__()
        self.gate_and_up = nn.Linear(hidden_size, hidden_size * expansion_factor * 2)
        self.down = nn.Linear(hidden_size * expansion_factor, hidden_size)
        self.dropout = nn.Dropout(p=dropout_ratio)

        # LoRA
        self.lora_scale = lora_alpha / lora_rank
        self.lora_gate_and_up_a = nn.Linear(hidden_size, lora_rank)
        self.lora_gate_and_up_b = nn.Linear(lora_rank, hidden_size * expansion_factor * 2)
        self.lora_down_a = nn.Linear(hidden_size * expansion_factor, lora_rank)
        self.lora_down_b = nn.Linear(lora_rank, hidden_size)

    def forward(self, tensor: torch.Tensor, fine_tuning: bool = False) -> torch.Tensor:
        gate_and_up = self.gate_and_up(tensor)
        if fine_tuning:
            lora_tensor = self.lora_gate_and_up_a(tensor)
            lora_tensor = self.lora_gate_and_up_b(lora_tensor)
            lora_tensor = lora_tensor * self.lora_scale
            gate_and_up = gate_and_up + lora_tensor
        gate, up = gate_and_up.chunk(chunks=2, dim=-1)
        gate = F.gelu(gate, approximate="tanh")
        tensor = gate * up
        tensor = self.dropout(tensor)
        down_tensor = self.down(tensor)
        if fine_tuning:
            lora_tensor = self.lora_down_a(tensor)
            lora_tensor = self.lora_down_b(lora_tensor)
            lora_tensor = lora_tensor * self.lora_scale
            down_tensor = down_tensor + lora_tensor
        return down_tensor


class MOE(nn.Module):
    def __init__(self, hidden_size: int, num_experts: int = 8, expansion_factor: int = 4, dropout_ratio: float = 0.1, lora_rank: int = 16, lora_alpha: int = 32):
        super().__init__()
        self.gate = nn.Linear(hidden_size, num_experts)
        self.num_experts = num_experts
        self.experts = nn.ModuleList([FeedForward(hidden_size, expansion_factor=expansion_factor, dropout_ratio=dropout_ratio, lora_rank=lora_rank, lora_alpha=lora_alpha) for _ in range(num_experts)])

    def forward(self, tensor: torch.Tensor, fine_tuning: bool = False) -> tuple[torch.Tensor, torch.Tensor]:
        # Flatten for better manipulation, this is ok because tokens are independent at this stage
        batch_size, seq_len, hidden_size = tensor.shape
        flat_tensor = tensor.reshape(batch_size * seq_len, hidden_size)

        # Pass through the gating network and select experts
        tensor = self.gate(flat_tensor)
        tensor = F.softmax(tensor, dim=-1)

        # The output of this step is a tensor of shape [batch_size * seq_len, 2] with element i in the second dimension representing ith expert selected for this token
        value_tensor, index_tensor = tensor.topk(k=2, dim=-1)

        # Find the load balancing loss
        counts = torch.bincount(index_tensor[:, 0], minlength=self.num_experts)
        frequencies = counts.float() / (batch_size * seq_len) # This is the hard one-hot frequency
        probability = tensor.mean(0) # This is the soft probability
        load_balancing_loss = (probability * frequencies).mean() * float(self.num_experts ** 2)

        # Normalize top1 and top2 score
        top_expert_score = value_tensor[:, 0]
        second_expert_score = value_tensor[:, 1]
        total_score = top_expert_score + second_expert_score
        top_expert_score = top_expert_score / total_score
        second_expert_score = second_expert_score / total_score

        # Split into top 2 experts
        split_tensors = torch.split(index_tensor, 1, dim=-1)
        top_expert, second_expert = split_tensors[0], split_tensors[1]
        indices = torch.arange(batch_size * seq_len).unsqueeze(-1).to(tensor.device)
        top_expert = torch.cat((indices, top_expert), dim=-1)
        second_expert = torch.cat((indices, second_expert), dim=-1)

        # Sort based on expert selection
        top_expert = top_expert[top_expert[:,1].argsort()]
        second_expert = second_expert[second_expert[:,1].argsort()]

        # Count how many tokens goes to each expert
        top_expert_counts = torch.zeros(self.num_experts, dtype=int)
        for i in range(self.num_experts):
            top_expert_counts[i] = (top_expert[:,1] == i).sum()
        top_expert_counts = top_expert_counts.tolist()

        second_expert_counts = torch.zeros(self.num_experts, dtype=int)
        for i in range(self.num_experts):
            second_expert_counts[i] = (second_expert[:,1] == i).sum()
        second_expert_counts = second_expert_counts.tolist()

        # Split input tokens for each expert
        top_expert_tokens = flat_tensor[top_expert[:,0]]
        second_expert_tokens = flat_tensor[second_expert[:,0]]

        # Split into a list of tensors, element i tensor is for ith expert.
        top_expert_tokens = torch.split(top_expert_tokens, top_expert_counts, dim=0)
        second_expert_tokens = torch.split(second_expert_tokens, second_expert_counts, dim=0)

        # Input into each expert and obtain results in a list
        top_expert_outputs = [self.experts[i](top_expert_tokens[i], fine_tuning) if top_expert_counts[i] > 0 else torch.zeros(0, hidden_size, dtype=torch.float16).to(tensor.device) for i in range(self.num_experts)]
        second_expert_outputs = [self.experts[i](second_expert_tokens[i], fine_tuning) if second_expert_counts[i] > 0 else torch.zeros(0, hidden_size, dtype=torch.float16).to(tensor.device) for i in range(self.num_experts)]

        # Combine outputs
        top_expert_outputs = torch.cat(top_expert_outputs, dim=0)
        second_expert_outputs = torch.cat(second_expert_outputs, dim=0)

        # Re-index the output back to original token order
        # flat_top_expert_tensor = torch.zeros_like(flat_tensor, dtype=torch.float32).to(tensor.device)
        # flat_top_expert_tensor.index_copy_(0, top_expert[:, 0], top_expert_outputs)

        # flat_second_expert_tensor = torch.zeros_like(flat_tensor, dtype=torch.float32).to(tensor.device)
        # flat_second_expert_tensor.index_copy_(0, second_expert[:, 0], second_expert_outputs)
        flat_top_expert_tensor = torch.zeros_like(flat_tensor, dtype=torch.float32).to("cpu")
        flat_top_expert_tensor = flat_top_expert_tensor.index_copy_(0, top_expert[:, 0].to("cpu"), top_expert_outputs.to("cpu")).to(tensor.device)
        flat_second_expert_tensor = torch.zeros_like(flat_tensor, dtype=torch.float32).to("cpu")
        flat_second_expert_tensor = flat_second_expert_tensor.index_copy_(0, second_expert[:, 0].to("cpu"), second_expert_outputs.to("cpu")).to(tensor.device)

        # Find final output tensor based on weight between top and second expert
        final_tensor = top_expert_score.unsqueeze(-1) * flat_top_expert_tensor + second_expert_score.unsqueeze(-1) * flat_second_expert_tensor

        # Reshape to original [batch_size, seq_len, hidden_size]
        final_tensor = final_tensor.contiguous().reshape(batch_size, seq_len, hidden_size)

        return final_tensor, load_balancing_loss


class VisionLayer(nn.Module):
    def __init__(self,
                 hidden_dim: int,
                 head_dim: int,
                 q_head: int,
                 kv_head: int,
                 expansion_factor: int = 4,
                 dropout_ratio: float = 0.1,
                 use_moe: bool = False,
                 num_experts: int = 8,
                 lora_rank: int = 16,
                 lora_alpha:int = 32):
        super().__init__()
        self.use_moe = use_moe

        self.norm1 = nn.LayerNorm(hidden_dim)
        self.attention = VisionAttention(hidden_dim, head_dim, q_head, kv_head, lora_rank=lora_rank, lora_alpha=lora_alpha, dropout_ratio=dropout_ratio)

        self.norm2 = nn.LayerNorm(hidden_dim)
        if self.use_moe:
            self.moe = MOE(hidden_dim, num_experts=num_experts, expansion_factor=expansion_factor,
                           dropout_ratio=dropout_ratio, lora_rank=lora_rank)
        else:
            self.ffn = FeedForward(hidden_dim, expansion_factor=expansion_factor, dropout_ratio=dropout_ratio,
                                  lora_rank=lora_rank, lora_alpha=lora_alpha)

    def forward(self, tensor: torch.Tensor, fine_tuning: bool = False):
        skip_connection = tensor
        tensor = self.norm1(tensor)
        tensor = self.attention(tensor, fine_tuning=fine_tuning)
        tensor += skip_connection

        skip_connection = tensor
        tensor = self.norm2(tensor)
        if self.use_moe:
            tensor, load_balancing_loss = self.moe(tensor, fine_tuning=fine_tuning)
        else:
            tensor = self.ffn(tensor, fine_tuning=fine_tuning)
            load_balancing_loss = torch.tensor(0.0, dtype=tensor.dtype, device=tensor.device)

        tensor += skip_connection

        return tensor, load_balancing_loss


class VisionTransformer(nn.Module):
    def __init__(self,
                 image_size: tuple,
                 num_classes: int = 1,
                 patch_size: int = 8,
                 in_channels: int = 3,
                 num_layer: int = 3,
                 hidden_dim: int = 1024,
                 expansion_factor: int = 8,
                 head_dim: int = 64,
                 q_head: int = 16,
                 kv_head: int = 4,
                 dropout_ratio: float = 0.1,
                 use_moe: bool = True,
                 num_experts: int = 8,
                 load_balancing_loss_weight: float = 1e-2,
                 fine_tuning: bool = False,
                 lora_rank: int = 16,
                 lora_alpha: int = 32):
        super().__init__()
        self.num_layer = num_layer
        self.load_balancing_loss_weight = load_balancing_loss_weight
        self.fine_tuning = fine_tuning

        # Patch embedding
        self.patch_embedding = PatchEmbedding(
            image_size=image_size,
            patch_size=patch_size,
            in_channels=in_channels,
            embedding_dim=hidden_dim
        )

        # Calculate number of patches (sequence length)
        self.num_patches = (image_size[0] // patch_size) * (image_size[1] // patch_size) + 1  # +1 for cls token

        if q_head == None:
            q_head = (hidden_dim // head_dim)

        if kv_head == None:
            kv_head = (hidden_dim // head_dim)

        if hidden_dim % (head_dim * q_head) != 0 or hidden_dim % (head_dim * kv_head):
            raise ValueError("Error: hidden_dim or projection_dim (if specified) must be divisible by the product of the number of q or kv heads and the head dimension.")

        # Create transformer layers
        self.transformer = nn.ModuleList()
        for _ in range(self.num_layer):
            self.transformer.append(VisionLayer(
                hidden_dim, head_dim, q_head, kv_head,
                expansion_factor=expansion_factor,
                dropout_ratio=dropout_ratio,
                use_moe=use_moe,
                num_experts=num_experts,
                lora_rank=lora_rank,
                lora_alpha=lora_alpha
            ))
        self.output_norm = nn.LayerNorm(hidden_dim)

        # Final classifier head
        self.classifier = nn.Linear(hidden_dim, num_classes)

    def begin_fine_tunning(self) -> None:
        self.fine_tuning = True
        for name, param in self.named_parameters():
            if "lora" not in name:
                param.requires_grad = False
            else:
                param.requires_grad = True

    def exit_fine_tunning(self) -> None:
        self.fine_tuning = False
        for name, param in self.named_parameters():
            param.requires_grad = True

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        # Handle input shape
        if len(x.shape) == 3:  # [batch_size, 64, 72]
            # Reshape to [batch_size, channels, height, width]
            # Assuming the input is grayscale (1 channel)
            batch_size, height, width = x.shape
            x = x.unsqueeze(1)  # Add channel dimension [batch_size, 1, 64, 72]

        # Apply patch embedding
        x = self.patch_embedding(x)

        # Track load-balancing across layers (only if MoE is used)
        load_balancing_sum = torch.tensor(0.0, device=x.device)

        # Pass through transformer layers
        for layer in self.transformer:
            x, load_balancing_loss = layer(x, fine_tuning=self.fine_tuning)
            load_balancing_sum += load_balancing_loss

        load_balancing_loss = (load_balancing_sum / self.num_layer) * self.load_balancing_loss_weight

        # Apply output normalization
        x = self.output_norm(x)

        # Use CLS token for classification
        x = x[:, 0]  # Take only the CLS token

        # Apply classifier
        x = self.classifier(x)

        return x, load_balancing_loss

In [4]:
# Define image transformations for preprocessing
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
])

def valid_image_folder(path: str) -> bool:
    # Check if file starts with '._' or ends with '.DS_Store'
    filename = os.path.basename(path)
    if filename.startswith("._") or filename == ".DS_Store": # Stupid MacOS
        return False
    
    return True

# Use ImageFolder to automatically label images based on folder names
dataset = datasets.ImageFolder(root=dataset_filepath, is_valid_file=valid_image_folder, transform=transform)

train_size = int(0.8 * len(dataset))
val_size = int(0.1 * len(dataset))
test_size = len(dataset) - train_size - val_size
g = torch.Generator().manual_seed(seed)
train_dataset, val_dataset, test_dataset = torch.utils.data.random_split(
    dataset, [train_size, val_size, test_size], generator=g
)

# Create DataLoaders for training and validation
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)

In [5]:
# Checkpoint loading
def load_checkpoint(model, optimizer, scheduler, filepath, device):
    checkpoint = torch.load(filepath, map_location=device, weights_only=False)
    model.load_state_dict(checkpoint["model_state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    scheduler.load_state_dict(checkpoint["scheduler_state_dict"])

    epoch = checkpoint["epoch"]
    best_loss = checkpoint.get("best_loss", float("inf"))
    print(f"Resumed from epoch {epoch}  |  val loss {checkpoint['validation_loss']:.6f}")
    return epoch + 1, best_loss

# Checkpoint saving
def save_checkpoint(model, optimizer, scheduler, epoch, val_loss, best_loss):
    checkpoint = {
        "model_state_dict":      model.state_dict(),
        "optimizer_state_dict":  optimizer.state_dict(),
        "scheduler_state_dict":  scheduler.state_dict(),
        "epoch":                 epoch,
        "validation_loss":       val_loss,
        "best_loss":             best_loss,
    }
    torch.save(checkpoint, f"{save_dir}/vit_checkpoint_epoch_{epoch}.pt")

    if val_loss < best_loss:
        torch.save(checkpoint, f"{save_dir}/vit_best_model.pt")
        print(f"🔖  New best model - val loss {val_loss:.3f}")
        best_loss = val_loss
    return best_loss

In [6]:
vit = VisionTransformer(
    image_size=(256, 256),    # Your input image dimensions
    patch_size=16,           # Size of each patch
    in_channels=3,          # We only technically don't have channel value here.
    num_classes=1000,       # Number of output classes
    num_layer=3,            # Number of transformer layers
    hidden_dim=1024,        # Hidden dimension
    expansion_factor=8,     # Expansion factor for FFN
    head_dim=64,            # Dimension of each attention head
    q_head=16,              # Number of query heads
    kv_head=4,              # Number of key/value heads
    use_moe=False,          # Whether to use Mixture of Experts
    dropout_ratio=0.5
).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(vit.parameters(), lr=lr, weight_decay=weight_decay)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs*len(train_loader), eta_min=1e-5)

# Load checkpoint if available
current_epoch = 0
best_loss = float("inf")
if checkpoint_filepath is not None and checkpoint_filepath != "":
    current_epoch, best_loss = load_checkpoint(vit, optimizer, scheduler, checkpoint_filepath, device)

print(f"This model has {sum(p.numel() for p in vit.parameters())} parameters.")
print(f"Training on {device}")

This model has 87071912 parameters.
Training on mps


In [7]:
# Initialize loss tracking lists
loss_train = []
loss_valid = []
accuracy_train = []
accuracy_valid = []

In [8]:
# Training loop
for epoch in range(current_epoch, epochs):
    print(f"Epoch {epoch+1}/{epochs}")
    
    # Training phase
    vit.train()
    loss_train_epoch = []
    correct_train = 0
    total_train = 0
    
    # Sliding windows to store metrics for the last 1000 iterations
    recent_losses = deque(maxlen=1000)
    recent_corrects = deque(maxlen=1000)
    recent_totals = deque(maxlen=1000)
    
    for i, (inputs, targets) in enumerate(tqdm(train_loader, desc="Training")):
        inputs = inputs.to(device)
        targets = targets.to(device)
        
        # Forward pass with mixed precision
        outputs, load_balancing_loss = vit(inputs)
        loss = criterion(outputs, targets) + load_balancing_loss
        
        # Backward pass with gradient scaling
        loss.backward()
        
        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(vit.parameters(), max_norm=1.0)
        
        # Optimizer step
        optimizer.step()
        
        # Zero gradients
        optimizer.zero_grad()
        
        # Update scheduler
        scheduler.step()
        
        # Record loss for epoch statistics
        loss_train_epoch.append(loss.item())
        
        # Calculate accuracy
        predictions = torch.argmax(outputs, dim=1)
        batch_correct = (predictions == targets).sum().item()
        batch_total = targets.size(0)
        
        # Update running totals for epoch statistics
        total_train += batch_total
        correct_train += batch_correct
        
        # Add current batch metrics to the sliding windows
        recent_losses.append(loss.item())
        recent_corrects.append(batch_correct)
        recent_totals.append(batch_total)
        
        # Calculate global iteration number
        iteration = epoch * len(train_loader) + i
        
        # Report metrics every 1000 iterations
        if (iteration + 1) % 1000 == 0:
            # Calculate average metrics over the last 1000 iterations
            avg_loss = sum(recent_losses) / len(recent_losses)
            avg_accuracy = 100 * sum(recent_corrects) / sum(recent_totals)
            
            print(f"Iteration {iteration + 1} - Avg Loss: {avg_loss:.6f}, Avg Accuracy: {avg_accuracy:.2f}%")
    
    # Calculate epoch statistics
    epoch_loss = np.mean(loss_train_epoch)
    epoch_accuracy = 100 * correct_train / total_train
    loss_train.append(epoch_loss)
    accuracy_train.append(epoch_accuracy)
    
    # Validation phase
    vit.eval()
    loss_val_epoch = []
    correct_val = 0
    total_val = 0
    
    with torch.no_grad():
        for inputs, targets in tqdm(val_loader, desc="Validation"):
            inputs = inputs.to(device)
            targets = targets.to(device)
            
            # Forward pass
            outputs, load_balancing_loss = vit(inputs)
            loss = criterion(outputs, targets) + load_balancing_loss
            
            # Record loss
            loss_val_epoch.append(loss.item())
            
            # Calculate accuracy
            predictions = torch.argmax(outputs, dim=1)
            total_val += targets.size(0)
            correct_val += (predictions == targets).sum().item()
    
    # Calculate epoch validation statistics
    epoch_val_loss = np.mean(loss_val_epoch)
    epoch_val_accuracy = 100 * correct_val / total_val
    loss_valid.append(epoch_val_loss)
    accuracy_valid.append(epoch_val_accuracy)
    
    # Print epoch results
    print(f"Training - Loss: {epoch_loss:.6f}, Accuracy: {epoch_accuracy:.2f}%")
    print(f"Validation - Loss: {epoch_val_loss:.6f}, Accuracy: {epoch_val_accuracy:.2f}%")
    
    # Save checkpoint
    best_loss = save_checkpoint(vit, optimizer, scheduler, epoch, epoch_val_loss, best_loss)

Epoch 1/10


Training:   7%|███████████                                                                                                                                                    | 1000/14403 [39:12<9:07:45,  2.45s/it]

Iteration 1000 - Avg Loss: 6.886652, Avg Accuracy: 0.40%


Training:  14%|█████████████████████▊                                                                                                                                       | 2000/14403 [1:20:43<8:35:38,  2.49s/it]

Iteration 2000 - Avg Loss: 6.676625, Avg Accuracy: 0.74%


Training:  21%|████████████████████████████████▋                                                                                                                            | 3000/14403 [2:02:39<8:04:08,  2.55s/it]

Iteration 3000 - Avg Loss: 6.583114, Avg Accuracy: 0.87%


Training:  28%|███████████████████████████████████████████▌                                                                                                                 | 4000/14403 [2:45:11<7:21:09,  2.54s/it]

Iteration 4000 - Avg Loss: 6.520658, Avg Accuracy: 0.98%


Training:  35%|██████████████████████████████████████████████████████▌                                                                                                      | 5000/14403 [3:27:52<6:47:43,  2.60s/it]

Iteration 5000 - Avg Loss: 6.455154, Avg Accuracy: 1.21%


Training:  42%|█████████████████████████████████████████████████████████████████▍                                                                                           | 6000/14403 [4:10:55<6:08:01,  2.63s/it]

Iteration 6000 - Avg Loss: 6.397080, Avg Accuracy: 1.31%


Training:  49%|████████████████████████████████████████████████████████████████████████████▎                                                                                | 7000/14403 [4:53:56<5:18:02,  2.58s/it]

Iteration 7000 - Avg Loss: 6.338722, Avg Accuracy: 1.56%


Training:  56%|███████████████████████████████████████████████████████████████████████████████████████▏                                                                     | 8000/14403 [5:37:02<4:34:51,  2.58s/it]

Iteration 8000 - Avg Loss: 6.285608, Avg Accuracy: 1.84%


Training:  62%|██████████████████████████████████████████████████████████████████████████████████████████████████                                                           | 9000/14403 [6:20:16<3:52:02,  2.58s/it]

Iteration 9000 - Avg Loss: 6.242333, Avg Accuracy: 2.02%


Training:  69%|████████████████████████████████████████████████████████████████████████████████████████████████████████████▎                                               | 10000/14403 [7:04:00<3:10:55,  2.60s/it]

Iteration 10000 - Avg Loss: 6.198742, Avg Accuracy: 2.20%


Training:  76%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▏                                    | 11000/14403 [7:47:26<2:27:40,  2.60s/it]

Iteration 11000 - Avg Loss: 6.165244, Avg Accuracy: 2.28%


Training:  83%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▉                          | 12000/14403 [8:29:09<1:35:58,  2.40s/it]

Iteration 12000 - Avg Loss: 6.117333, Avg Accuracy: 2.63%


Training:  90%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌               | 13000/14403 [9:07:20<55:00,  2.35s/it]

Iteration 13000 - Avg Loss: 6.095272, Avg Accuracy: 2.70%


Training:  97%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████▌    | 14000/14403 [9:46:24<15:52,  2.36s/it]

Iteration 14000 - Avg Loss: 6.073475, Avg Accuracy: 2.78%


Training: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 14403/14403 [10:02:18<00:00,  2.51s/it]
Validation: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1801/1801 [23:56<00:00,  1.25it/s]


Training - Loss: 6.350513, Accuracy: 1.72%
Validation - Loss: 6.171954, Accuracy: 2.44%
🔖  New best model - val loss 6.172
Epoch 2/10


Training:   4%|██████▌                                                                                                                                                        | 597/14403 [25:50<10:13:42,  2.67s/it]

Iteration 15000 - Avg Loss: 6.009665, Avg Accuracy: 2.97%


Training:  11%|█████████████████▍                                                                                                                                           | 1597/14403 [1:09:11<9:00:46,  2.53s/it]

Iteration 16000 - Avg Loss: 5.997274, Avg Accuracy: 3.25%


Training:  18%|████████████████████████████▎                                                                                                                                | 2597/14403 [1:52:36<8:33:14,  2.61s/it]

Iteration 17000 - Avg Loss: 5.958666, Avg Accuracy: 3.33%


Training:  25%|███████████████████████████████████████▏                                                                                                                     | 3597/14403 [2:36:31<7:25:34,  2.47s/it]

Iteration 18000 - Avg Loss: 5.932547, Avg Accuracy: 3.66%


Training:  32%|██████████████████████████████████████████████████                                                                                                           | 4597/14403 [3:18:35<7:07:25,  2.62s/it]

Iteration 19000 - Avg Loss: 5.911347, Avg Accuracy: 3.60%


Training:  39%|█████████████████████████████████████████████████████████████                                                                                                | 5597/14403 [4:00:43<6:12:28,  2.54s/it]

Iteration 20000 - Avg Loss: 5.883861, Avg Accuracy: 3.84%


Training:  40%|███████████████████████████████████████████████████████████████                                                                                              | 5785/14403 [4:08:59<6:10:55,  2.58s/it]


KeyboardInterrupt: 

In [None]:
# Final evaluation on test set
print("\nEvaluating on test set...")
vit.eval()
test_loss = 0
correct = 0
total = 0

with torch.no_grad():
    for inputs, targets in tqdm(test_loader, desc="Testing"):
        inputs = inputs.to(device)
        targets = targets.to(device)
        
        outputs, load_balancing_loss = vit(inputs)
        loss = criterion(outputs, targets) + load_balancing_loss
        
        test_loss += loss.item()
        predictions = torch.argmax(outputs, dim=1)
        total += targets.size(0)
        correct += (predictions == targets).sum().item()

avg_test_loss = test_loss / len(test_loader)
test_accuracy = 100 * correct / total

print(f"Test set - Loss: {avg_test_loss:.6f}, Accuracy: {test_accuracy:.2f}%")

# Save final model
torch.save({
    'model_state_dict': vit.state_dict(),
    'test_accuracy': test_accuracy,
    'test_loss': avg_test_loss
}, f"{save_dir}/vit_final_model.pt")

print("Training completed!")