In [None]:
# 2D Sinusoidal Positional Encoding for Vision Transformer
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import math
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

class SinCos2DPositionalEncoding(nn.Module):
    """Fixed 2D sine/cosine positional encoding added to a token sequence.

    Channels ``[0, dim // 2)`` encode the patch row and the channels that
    follow encode the patch column, each as interleaved sin/cos pairs. A row
    of zeros is prepended for the class token, so the registered buffer
    ``pos_enc`` has shape ``(1, h * w + 1, dim)``.

    Args:
        dim: Number of channels of every token.
        h: Number of patches along the height axis.
        w: Number of patches along the width axis.
    """

    def __init__(self, dim, h=8, w=8):
        super(SinCos2DPositionalEncoding, self).__init__()

        # Channel budget per axis
        dim_h = dim // 2
        dim_w = dim // 2

        # Row / column index of every patch, both of shape (h, w)
        y_pos = torch.arange(h).unsqueeze(1).expand(h, w).float()
        x_pos = torch.arange(w).unsqueeze(0).expand(h, w).float()

        pos_enc = torch.zeros(h, w, dim)
        # Row channels must not spill into the column half; column channels
        # must not run past the last channel. The missing upper bound used to
        # raise IndexError whenever dim // 2 was odd (e.g. dim = 6).
        self._fill_axis(pos_enc, y_pos, start=0, n_channels=dim_h, stop=dim_h)
        self._fill_axis(pos_enc, x_pos, start=dim_h, n_channels=dim_w, stop=dim)

        # Flatten the grid row-major to the sequence layout (h*w, dim)
        pos_enc = pos_enc.reshape(h * w, dim)

        # The class token gets an all-zero encoding at sequence index 0
        cls_pos_enc = torch.zeros(1, dim)
        pos_enc = torch.cat([cls_pos_enc, pos_enc], dim=0)

        # Buffer: saved in the state dict and moved by .to(), but not trained
        self.register_buffer('pos_enc', pos_enc.unsqueeze(0))

    @staticmethod
    def _fill_axis(pos_enc, positions, start, n_channels, stop):
        """Write sin/cos pairs for one axis into ``pos_enc`` in place.

        Args:
            pos_enc: Tensor of shape (h, w, dim) that receives the values.
            positions: Float tensor of shape (h, w) with the coordinate to encode.
            start: First channel used by this axis.
            n_channels: Channel budget of the axis; also sets the frequencies.
            stop: Exclusive upper bound for channel indices that may be written.
        """
        if n_channels <= 0:
            # Nothing to encode (dim < 2); also avoids dividing by zero below.
            return

        div_term = torch.exp(
            torch.arange(0, n_channels, 2).float() * -(math.log(10000.0) / n_channels)
        )
        for pair, freq in enumerate(div_term):
            angle = positions * freq
            sin_ch = start + 2 * pair
            cos_ch = sin_ch + 1
            if sin_ch < stop:
                pos_enc[:, :, sin_ch] = torch.sin(angle)
            if cos_ch < stop:
                pos_enc[:, :, cos_ch] = torch.cos(angle)

    def forward(self, x):
        """Return ``x`` with the encoding added (broadcast over the batch axis)."""
        return x + self.pos_enc
    
# 2D Sinusoidal Positional Encoding Generator
class SinCos2DPositionalEncodingAppend:
    """Builds 2D sinusoidal positional encodings for a grid of image patches.

    After construction the instance exposes:
        pos_enc: tensor of shape (h * w, pos_dim), patches in row-major order.
        cls_pos_enc: all-zero tensor of shape (1, pos_dim) for the class token.

    Args:
        pos_dim: Dimension of the positional encoding vector.
        h: Number of patches in height dimension.
        w: Number of patches in width dimension.
    """

    def __init__(self, pos_dim, h, w):
        self.pos_dim, self.h, self.w = pos_dim, h, w
        # Encodings are computed once, eagerly
        self.generate_encodings()

    def generate_encodings(self):
        """Compute ``self.pos_enc`` and ``self.cls_pos_enc`` from the stored sizes."""
        rows, cols = self.h, self.w

        # Coordinate of every patch along each axis, shape (rows, cols)
        row_idx = torch.arange(rows).unsqueeze(1).repeat(1, cols).reshape(rows, cols).float()
        col_idx = torch.arange(cols).unsqueeze(0).repeat(rows, 1).reshape(rows, cols).float()

        # The width part receives the extra channel when pos_dim is odd
        n_height = self.pos_dim // 2
        n_width = self.pos_dim - n_height

        freq_height = torch.exp(torch.arange(0, n_height, 2).float() * -(math.log(10000.0) / n_height))
        freq_width = torch.exp(torch.arange(0, n_width, 2).float() * -(math.log(10000.0) / n_width))

        encoding = torch.zeros(rows, cols, self.pos_dim)

        # Height: channels [0, n_height), sin on even offsets, cos on odd ones
        for pair in range(freq_height.numel()):
            channel = 2 * pair
            angle = row_idx * freq_height[pair]
            encoding[:, :, channel] = torch.sin(angle)
            if channel + 1 < n_height:
                encoding[:, :, channel + 1] = torch.cos(angle)

        # Width: channels [n_height, pos_dim)
        for pair in range(freq_width.numel()):
            channel = n_height + 2 * pair
            angle = col_idx * freq_width[pair]
            encoding[:, :, channel] = torch.sin(angle)
            if channel + 1 < self.pos_dim:
                encoding[:, :, channel + 1] = torch.cos(angle)

        # Flatten the grid to a sequence of patches
        self.pos_enc = encoding.reshape(rows * cols, self.pos_dim)

        # The class token carries no spatial position
        self.cls_pos_enc = torch.zeros(1, self.pos_dim)

In [2]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import math
import matplotlib.pyplot as plt
import numpy as np
from tqdm import tqdm

# Set random seed for reproducibility (only PyTorch's global RNG is seeded here)
torch.manual_seed(42)
# Use the GPU when CUDA is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Optimisation hyperparameters
batch_size = 128
num_epochs = 50  # also passed as T_max to the cosine LR scheduler further down
learning_rate = 3e-4
weight_decay = 1e-4

# Model hyperparameters
image_size = 32
patch_size = 4  # 32 // 4 = 8 patches per side, i.e. 64 patches per image
num_classes = 10
dim = 256  # token embedding width
depth = 6  # number of TransformerBlock layers
heads = 8  # dim // heads = 32 channels per attention head
mlp_dim = 512  # hidden width of the MLP inside each block
channels = 3
dropout = 0.1

# Data Loading and Preprocessing
# Training images are augmented (padded random crop + horizontal flip) before
# being converted to tensors and normalised per channel.
# NOTE(review): the mean/std constants are presumably CIFAR-10 statistics - confirm.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# Test images get the same normalisation but no augmentation
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

# download=True fetches CIFAR-10 into ./data when it is not already there
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)

# Only the training loader shuffles; both use 4 worker processes
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

# Multi-head Self Attention
class MultiHeadSelfAttention(nn.Module):
    """Standard multi-head scaled dot-product self-attention.

    Args:
        dim: Channel width of the input and output tokens.
        heads: Number of attention heads; must divide ``dim``.
        dropout: Dropout probability applied to the attention weights.

    Raises:
        ValueError: If ``heads`` is not positive or does not divide ``dim``.
    """

    def __init__(self, dim, heads=8, dropout=0.1):
        super(MultiHeadSelfAttention, self).__init__()
        # Fail early: otherwise the mismatch only shows up as an opaque
        # reshape error during the first forward pass.
        if heads <= 0 or dim % heads != 0:
            raise ValueError(
                f"dim ({dim}) must be divisible by a positive number of heads ({heads})"
            )
        self.heads = heads
        self.head_dim = dim // heads
        self.scale = self.head_dim ** -0.5

        self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
        self.to_out = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Attend over the sequence axis of ``x`` with shape (batch, tokens, dim)."""
        b, n, c = x.shape
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        # (b, n, dim) -> (b, heads, n, head_dim) for each of q, k, v
        q, k, v = map(lambda t: t.reshape(b, n, self.heads, self.head_dim).transpose(1, 2), qkv)

        # Scaled dot-product scores, normalised over the key axis
        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        attn = F.softmax(dots, dim=-1)
        attn = self.dropout(attn)

        # Weighted sum of values, then merge the heads back into `c` channels
        out = torch.matmul(attn, v)
        out = out.transpose(1, 2).reshape(b, n, c)
        out = self.to_out(out)
        return out

# MLP Block
class MLP(nn.Module):
    """Token-wise feed-forward network: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        dim: Input and output width.
        hidden_dim: Width of the intermediate layer.
        dropout: Dropout probability used after the activation and the output.
    """

    def __init__(self, dim, hidden_dim, dropout=0.1):
        super().__init__()
        # Layer order fixes the state-dict keys (net.0 / net.3)
        layers = [
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the network to the last axis of ``x``."""
        transformed = self.net(x)
        return transformed

# Transformer Encoder Block
class TransformerBlock(nn.Module):
    """Pre-norm Transformer encoder layer: attention and MLP, each behind a residual.

    Args:
        dim: Token width.
        heads: Number of attention heads.
        mlp_dim: Hidden width of the feed-forward part.
        dropout: Dropout probability handed to both sub-layers.
    """

    def __init__(self, dim, heads, mlp_dim, dropout=0.1):
        super().__init__()
        # Sub-modules are registered in the same order as before
        self.norm1 = nn.LayerNorm(dim)
        self.attn = MultiHeadSelfAttention(dim, heads, dropout)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = MLP(dim, mlp_dim, dropout)

    def forward(self, x):
        """Return ``x`` after the attention and MLP residual updates."""
        attended = self.attn(self.norm1(x))
        x = x + attended
        expanded = self.mlp(self.norm2(x))
        return x + expanded

# Vision Transformer
class ViT(nn.Module):
    """Vision Transformer classifier with a fixed 2D sin/cos positional encoding.

    Args:
        image_size: Side length of the (square) input images.
        patch_size: Side length of a patch; must divide ``image_size``.
        num_classes: Number of output logits.
        dim: Token embedding width.
        depth: Number of Transformer blocks.
        heads: Attention heads per block.
        mlp_dim: Hidden width of each block's MLP.
        channels: Number of image channels.
        dropout: Dropout probability used after the embedding and in the blocks.
    """

    def __init__(self, image_size, patch_size, num_classes, dim, depth, heads,
                 mlp_dim, channels=3, dropout=0.1):
        super().__init__()
        assert image_size % patch_size == 0, 'Image size must be divisible by patch size'

        patches_per_side = image_size // patch_size
        self.num_patches = patches_per_side ** 2
        patch_dim = channels * patch_size ** 2

        self.patch_size = patch_size
        self.h_patches = patches_per_side
        self.w_patches = patches_per_side

        # Linear projection of flattened patches
        self.to_patch_embedding = nn.Linear(patch_dim, dim)

        # Learnable class token prepended to every sequence
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))

        # Fixed 2D positional encoding (includes a slot for the class token)
        self.pos_embedding = SinCos2DPositionalEncoding(dim, h=self.h_patches, w=self.w_patches)

        self.dropout = nn.Dropout(dropout)

        # Encoder stack
        layers = []
        for _ in range(depth):
            layers.append(TransformerBlock(dim, heads, mlp_dim, dropout))
        self.transformer_blocks = nn.ModuleList(layers)

        # Classification head applied to the class token
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes),
        )

    def forward(self, img):
        """Map images of shape (batch, channels, H, W) to class logits."""
        batch, chans, _, _ = img.shape
        p = self.patch_size

        # Cut the image into non-overlapping p x p tiles, row-major over the grid
        tiles = img.unfold(2, p, p).unfold(3, p, p)
        tiles = tiles.contiguous().view(batch, chans, -1, p, p)
        flat_patches = tiles.permute(0, 2, 1, 3, 4).contiguous().view(batch, -1, chans * p * p)

        tokens = self.to_patch_embedding(flat_patches)

        # Prepend the class token
        cls_tokens = self.cls_token.expand(batch, -1, -1)
        tokens = torch.cat((cls_tokens, tokens), dim=1)

        # Inject positions, then regularise
        tokens = self.dropout(self.pos_embedding(tokens))

        for layer in self.transformer_blocks:
            tokens = layer(tokens)

        # Classify from the class token (sequence index 0)
        return self.mlp_head(tokens[:, 0])

# Create model, optimizer, and loss function
# All sizes come from the hyperparameter block of this cell.
model = ViT(
    image_size=image_size,
    patch_size=patch_size,
    num_classes=num_classes,
    dim=dim,
    depth=depth,
    heads=heads,
    mlp_dim=mlp_dim,
    channels=channels,
    dropout=dropout
).to(device)

# AdamW with a cosine schedule whose period spans the whole run (T_max=num_epochs)
optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
criterion = nn.CrossEntropyLoss()

# Training loop
# Per-epoch history, appended to by the main loop and plotted afterwards
train_losses = []
train_accs = []
test_losses = []
test_accs = []

def train_one_epoch(model, train_loader, optimizer, criterion, epoch):
    """Run one optimisation pass over ``train_loader``.

    Reads the module-level ``device`` and ``num_epochs``.

    Args:
        model: Network to train (switched to train mode).
        train_loader: Iterable of ``(inputs, targets)`` batches.
        optimizer: Optimizer stepping the model parameters.
        criterion: Loss function called as ``criterion(outputs, targets)``.
        epoch: Zero-based epoch index, used only for the progress-bar label.

    Returns:
        Tuple ``(mean batch loss, accuracy in percent)``.
    """
    model.train()
    loss_sum, n_correct, n_seen = 0.0, 0, 0

    progress = tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}')
    for step, (inputs, targets) in enumerate(progress, start=1):
        inputs = inputs.to(device)
        targets = targets.to(device)

        # Standard step: clear grads, forward, backward, update
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        # Running statistics
        loss_sum += loss.item()
        predicted = outputs.max(1)[1]
        n_seen += targets.size(0)
        n_correct += predicted.eq(targets).sum().item()

        progress.set_postfix({
            'loss': loss_sum / step,
            'acc': 100. * n_correct / n_seen
        })

    return loss_sum / len(train_loader), 100. * n_correct / n_seen

def test(model, test_loader, criterion):
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    
    with torch.no_grad():
        for batch_idx, (inputs, targets) in enumerate(test_loader):
            inputs, targets = inputs.to(device), targets.to(device)
            
            # Forward pass
            outputs = model(inputs)
            loss = criterion(outputs, targets)
            
            # Statistics
            running_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    
    test_loss = running_loss / len(test_loader)
    test_acc = 100. * correct / total
    
    print(f'Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.2f}%')
    
    return test_loss, test_acc

# Main training loop
# Each epoch: one training pass, one evaluation pass, then one scheduler step.
for epoch in range(num_epochs):
    # Train
    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion, epoch)
    train_losses.append(train_loss)
    train_accs.append(train_acc)
    
    # Test (test() already prints its own loss/accuracy line)
    test_loss, test_acc = test(model, test_loader, criterion)
    test_losses.append(test_loss)
    test_accs.append(test_acc)
    
    # Update learning rate
    scheduler.step()
    
    print(f'Epoch {epoch+1}/{num_epochs}: Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%, Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.2f}%')

# Plot training and testing curves
plt.figure(figsize=(12, 5))

# Left panel: loss per epoch
plt.subplot(1, 2, 1)
plt.plot(train_losses, label='Train Loss')
plt.plot(test_losses, label='Test Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.title('Loss Curves')

# Right panel: accuracy per epoch
plt.subplot(1, 2, 2)
plt.plot(train_accs, label='Train Accuracy')
plt.plot(test_accs, label='Test Accuracy')
plt.xlabel('Epoch')
plt.ylabel('Accuracy (%)')
plt.legend()
plt.title('Accuracy Curves')

# The figure is written to the working directory before being shown
plt.tight_layout()
plt.savefig('vit_training_curves.png')
plt.show()

# Save the model
# Only the state dict is stored, and only after the full loop has finished
torch.save(model.state_dict(), 'vit_cifar10.pth')
print('Training complete. Model saved to vit_cifar10.pth')  

Using device: cuda


Epoch 1/50: 100%|██████████| 391/391 [00:56<00:00,  6.88it/s, loss=1.68, acc=37.7]


Test Loss: 1.3857, Test Acc: 50.27%
Epoch 1/50: Train Loss: 1.6754, Train Acc: 37.71%, Test Loss: 1.3857, Test Acc: 50.27%


Epoch 2/50: 100%|██████████| 391/391 [00:57<00:00,  6.82it/s, loss=1.31, acc=52.4]


Test Loss: 1.1557, Test Acc: 58.07%
Epoch 2/50: Train Loss: 1.3108, Train Acc: 52.35%, Test Loss: 1.1557, Test Acc: 58.07%


Epoch 3/50: 100%|██████████| 391/391 [00:57<00:00,  6.82it/s, loss=1.17, acc=57.8]


Test Loss: 1.0661, Test Acc: 61.22%
Epoch 3/50: Train Loss: 1.1732, Train Acc: 57.78%, Test Loss: 1.0661, Test Acc: 61.22%


Epoch 4/50: 100%|██████████| 391/391 [00:57<00:00,  6.83it/s, loss=1.09, acc=60.9]


Test Loss: 0.9872, Test Acc: 64.76%
Epoch 4/50: Train Loss: 1.0868, Train Acc: 60.94%, Test Loss: 0.9872, Test Acc: 64.76%


Epoch 5/50: 100%|██████████| 391/391 [00:57<00:00,  6.84it/s, loss=1.02, acc=63.6]


Test Loss: 0.9306, Test Acc: 66.45%
Epoch 5/50: Train Loss: 1.0176, Train Acc: 63.56%, Test Loss: 0.9306, Test Acc: 66.45%


Epoch 6/50: 100%|██████████| 391/391 [00:57<00:00,  6.80it/s, loss=0.961, acc=65.7]


Test Loss: 0.9055, Test Acc: 67.10%
Epoch 6/50: Train Loss: 0.9613, Train Acc: 65.69%, Test Loss: 0.9055, Test Acc: 67.10%


Epoch 7/50: 100%|██████████| 391/391 [00:58<00:00,  6.67it/s, loss=0.914, acc=67.5]


Test Loss: 0.8678, Test Acc: 69.24%
Epoch 7/50: Train Loss: 0.9145, Train Acc: 67.52%, Test Loss: 0.8678, Test Acc: 69.24%


Epoch 8/50:   0%|          | 1/391 [00:00<04:17,  1.52it/s, loss=0.765, acc=72.3]


KeyboardInterrupt: 

In [8]:
import torch

# Sanity check: report the name of CUDA device 0 when CUDA is usable
if torch.cuda.is_available():
    print("GPU is available:", torch.cuda.get_device_name(0))
else:
    print("GPU not available")


GPU is available: NVIDIA GeForce RTX 3050 Laptop GPU


In [17]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
from tqdm import tqdm

# Import from APE repository
from ape.nn.position.algebraic import Grid
from ape.nn.position.schemes import grid_applicative

# Device setup
# NOTE(review): unlike the sinusoidal-encoding cell, no random seed is set here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Hyperparameters
batch_size = 128
num_epochs = 50  # also passed as T_max to the cosine LR scheduler further down
learning_rate = 3e-4
weight_decay = 1e-4

image_size = 32
patch_size = 4  # 32 // 4 = 8 patches per side, i.e. 64 patches per image
num_classes = 10
dim = 256
depth = 6
heads = 8
mlp_dim = 512
channels = 3  # read as a global by ViT.__init__ in this cell
dropout = 0.1

# Dataset
# Training images are augmented; test images are only normalised.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])

train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
test_dataset = datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)

train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

# Setup Grid Algebraic Positional Encoding
# seq_len is not referenced again in this cell
seq_len = (image_size // patch_size) ** 2 + 1  # +1 for cls token
grid_size = image_size // patch_size  # read as a global by MultiHeadSelfAttention.forward
head_dim = dim // heads

# Initialize APE for 2D grid
# A single instance is created here; every attention layer stores a reference to it.
# NOTE(review): the expected arguments of Grid(...) and precompute(...) are
# defined by the external `ape` package - confirm against its documentation.
ape = Grid(num_axes=2, dim=head_dim, num_heads=heads).to(device)
ape.precompute(grid_size)

# ViT Components
class MultiHeadSelfAttention(nn.Module):
    """Multi-head self-attention with APE scores between grid tokens.

    Token 0 is the class token and has no grid position: every score that
    involves it is a plain scaled dot product. Scores between two patch tokens
    come from the attention function built by ``grid_applicative`` from the
    module-level ``ape`` instance.

    Reads the module-level names ``ape``, ``grid_size`` and ``grid_applicative``.

    NOTE(review): the recorded run of this cell stops inside the APE attention
    function with ``einsum(): subscript q has size 64 for operand 1 which does
    not broadcast with previously seen size 8``. 64 matches the number of grid
    tokens; 8 matches both ``heads`` and ``grid_size``. The likely suspects are
    the (batch, heads, tokens, head_dim) layout of the queries/keys passed to
    the APE function and the manual batch expansion of the maps - confirm the
    expected input and output layouts against the APE package before relying
    on this class.

    Args:
        dim: Channel width of the input and output tokens.
        heads: Number of attention heads.
        dropout: Dropout probability applied to the attention weights.
    """

    def __init__(self, dim, heads, dropout=0.1):
        super().__init__()
        self.heads = heads
        self.head_dim = dim // heads
        self.scale = self.head_dim ** -0.5

        self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
        self.to_out = nn.Linear(dim, dim)
        self.dropout = nn.Dropout(dropout)

        # Reference to the shared, module-level APE instance
        self.ape = ape

    def forward(self, x):
        """Attend over the token axis of ``x`` with shape (batch, tokens, dim)."""
        b, n, _ = x.shape
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        # (b, n, dim) -> (b, heads, n, head_dim) for each of q, k, v
        q, k, v = map(lambda t: t.reshape(b, n, self.heads, self.head_dim).transpose(1, 2), qkv)

        if n > 1:
            attn = self._scores_with_grid(q, k)
        else:
            # Only the class token is present: regular attention
            attn = (q @ k.transpose(-2, -1)) * self.scale

        attn = attn.softmax(dim=-1)
        attn = self.dropout(attn)

        out = attn @ v
        out = out.transpose(1, 2).reshape(b, n, -1)
        return self.to_out(out)

    def _scores_with_grid(self, q, k):
        """Assemble the (b, heads, n, n) score matrix for class + grid tokens."""
        b, n = q.size(0), q.size(2)

        # Row-major grid coordinates of the n - 1 patch tokens
        patch_index = torch.arange(n - 1, device=q.device)
        grid_w = grid_size
        y_positions = (patch_index // grid_w).long()
        x_positions = (patch_index % grid_w).long()

        q_cls, k_cls = q[:, :, 0:1, :], k[:, :, 0:1, :]
        q_grid, k_grid = q[:, :, 1:, :], k[:, :, 1:, :]

        # Queries and keys sit at the same positions, so one set of maps
        # serves both sides.
        maps_x, maps_y = self.ape(x_positions, y_positions)
        # Add a leading batch axis (assumes the maps carry four axes - TODO confirm)
        maps_x = maps_x.unsqueeze(0).expand(b, -1, -1, -1, -1)
        maps_y = maps_y.unsqueeze(0).expand(b, -1, -1, -1, -1)
        maps = (maps_x, maps_y)

        # Evaluated once; the earlier code built and called this function
        # twice and threw the first result away.
        atn_fn = grid_applicative(maps, maps, None)
        grid_attn = atn_fn(q_grid, k_grid, None) * self.scale

        # Scores involving the class token: plain scaled dot products. The
        # cls->cls entry used to be a constant 0, which disagreed with the
        # single-token branch in forward().
        cls_to_cls = (q_cls @ k_cls.transpose(-2, -1)) * self.scale
        cls_to_grid = (q_cls @ k_grid.transpose(-2, -1)) * self.scale
        grid_to_cls = (q_grid @ k_cls.transpose(-2, -1)) * self.scale

        top_row = torch.cat([cls_to_cls, cls_to_grid], dim=-1)
        grid_rows = torch.cat([grid_to_cls, grid_attn], dim=-1)
        return torch.cat([top_row, grid_rows], dim=-2)

class MLP(nn.Module):
    """Two-layer feed-forward network with GELU and dropout, applied per token.

    Args:
        dim: Input and output width.
        hidden_dim: Width of the intermediate layer.
        dropout: Dropout probability after the activation and after the output.
    """

    def __init__(self, dim, hidden_dim, dropout=0.1):
        super(MLP, self).__init__()
        widen = nn.Linear(dim, hidden_dim)
        narrow = nn.Linear(hidden_dim, dim)
        # Positions inside the Sequential determine the state-dict keys
        self.net = nn.Sequential(
            widen,
            nn.GELU(),
            nn.Dropout(dropout),
            narrow,
            nn.Dropout(dropout),
        )

    def forward(self, x):
        """Apply the network to the last axis of ``x``."""
        result = self.net(x)
        return result

class TransformerBlock(nn.Module):
    """Pre-norm encoder layer: self-attention then MLP, each with a residual.

    Args:
        dim: Token width.
        heads: Number of attention heads.
        mlp_dim: Hidden width of the feed-forward part.
        dropout: Dropout probability handed to both sub-layers.
    """

    def __init__(self, dim, heads, mlp_dim, dropout=0.1):
        super(TransformerBlock, self).__init__()
        # Registration order is kept as-is (attn, norm1, mlp, norm2)
        self.attn = MultiHeadSelfAttention(dim, heads, dropout)
        self.norm1 = nn.LayerNorm(dim)
        self.mlp = MLP(dim, mlp_dim, dropout)
        self.norm2 = nn.LayerNorm(dim)

    def forward(self, x):
        """Return ``x`` after the attention and MLP residual updates."""
        after_attn = x + self.attn(self.norm1(x))
        after_mlp = after_attn + self.mlp(self.norm2(after_attn))
        return after_mlp

class ViT(nn.Module):
    """Vision Transformer classifier without an additive positional embedding.

    Position information is expected to come from the attention layers inside
    ``TransformerBlock``; this class only embeds patches, prepends the class
    token and classifies from it.

    Args:
        image_size: Side length of the (square) input images.
        patch_size: Side length of a patch; must divide ``image_size``.
        num_classes: Number of output logits.
        dim: Token embedding width.
        depth: Number of Transformer blocks.
        heads: Attention heads per block.
        mlp_dim: Hidden width of each block's MLP.
        dropout: Dropout probability used after the embedding and in the blocks.
        channels: Number of image channels. This used to be read from a
            module-level variable; the default matches the notebook's value.
    """

    def __init__(self, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim,
                 dropout=0.1, channels=3):
        super().__init__()
        # Same guard as the sinusoidal ViT: a partial patch row would
        # otherwise be dropped silently by unfold().
        assert image_size % patch_size == 0, 'Image size must be divisible by patch size'

        self.patch_size = patch_size
        self.patch_dim = channels * patch_size ** 2
        self.grid_size = image_size // patch_size
        self.num_patches = self.grid_size ** 2

        # Patch embedding
        self.to_patch_embedding = nn.Linear(self.patch_dim, dim)
        self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout = nn.Dropout(dropout)

        # Transformer blocks with APE
        self.transformer = nn.Sequential(
            *[TransformerBlock(dim, heads, mlp_dim, dropout) for _ in range(depth)]
        )

        # MLP head
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(dim),
            nn.Linear(dim, num_classes)
        )

    def forward(self, img):
        """Map images of shape (batch, channels, H, W) to class logits."""
        b = img.shape[0]
        p = self.patch_size

        # (b, c, H, W) -> (b, gh, gw, c, p, p) -> (b, gh * gw, c * p * p),
        # patches in row-major grid order
        patches = img.unfold(2, p, p).unfold(3, p, p)
        patches = patches.permute(0, 2, 3, 1, 4, 5).contiguous()
        patches = patches.view(b, -1, self.patch_dim)

        # Project patches to embedding dimension
        x = self.to_patch_embedding(patches)

        # Prepend the class token at sequence index 0
        cls_token = self.cls_token.expand(b, -1, -1)
        x = torch.cat((cls_token, x), dim=1)

        x = self.dropout(x)
        x = self.transformer(x)

        # Use cls token for classification
        return self.mlp_head(x[:, 0])

# Model
# Built from the hyperparameters defined at the top of this cell.
model = ViT(
    image_size=image_size,
    patch_size=patch_size,
    num_classes=num_classes,
    dim=dim,
    depth=depth,
    heads=heads,
    mlp_dim=mlp_dim,
    dropout=dropout
).to(device)

# AdamW with a cosine schedule whose period spans the whole run (T_max=num_epochs)
optimizer = optim.AdamW(model.parameters(), lr=learning_rate, weight_decay=weight_decay)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
criterion = nn.CrossEntropyLoss()

# Per-epoch history, filled by the training loop and plotted afterwards
train_losses, train_accs, test_losses, test_accs = [], [], [], []

def train_one_epoch(model, loader, optimizer, criterion):
    """Run one optimisation pass over ``loader``.

    Reads the module-level ``device``.

    Args:
        model: Network to train (switched to train mode).
        loader: Iterable of ``(inputs, labels)`` batches.
        optimizer: Optimizer stepping the model parameters.
        criterion: Loss function called as ``criterion(outputs, labels)``.

    Returns:
        Tuple ``(mean batch loss, accuracy in percent)``.
    """
    model.train()
    loss_sum = 0
    hits = 0
    seen = 0
    for inputs, labels in tqdm(loader, desc='Training'):
        inputs = inputs.to(device)
        labels = labels.to(device)

        optimizer.zero_grad()
        logits = model(inputs)
        batch_loss = criterion(logits, labels)
        batch_loss.backward()
        optimizer.step()

        # Running statistics
        loss_sum += batch_loss.item()
        hits += (logits.argmax(1) == labels).sum().item()
        seen += labels.size(0)

    mean_loss = loss_sum / len(loader)
    accuracy = 100 * hits / seen
    return mean_loss, accuracy

def evaluate(model, loader, criterion):
    """Evaluate ``model`` on ``loader`` without gradient tracking.

    Reads the module-level ``device``.

    Args:
        model: Network to evaluate (switched to eval mode).
        loader: Iterable of ``(inputs, labels)`` batches.
        criterion: Loss function called as ``criterion(outputs, labels)``.

    Returns:
        Tuple ``(mean batch loss, accuracy in percent)``.
    """
    model.eval()
    loss_sum = 0
    hits = 0
    seen = 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs = inputs.to(device)
            labels = labels.to(device)
            logits = model(inputs)
            loss_sum += criterion(logits, labels).item()
            hits += (logits.argmax(1) == labels).sum().item()
            seen += labels.size(0)
    mean_loss = loss_sum / len(loader)
    accuracy = 100 * hits / seen
    return mean_loss, accuracy

# Each epoch: one training pass, one evaluation pass, then one scheduler step
for epoch in range(num_epochs):
    train_loss, train_acc = train_one_epoch(model, train_loader, optimizer, criterion)
    test_loss, test_acc = evaluate(model, test_loader, criterion)
    scheduler.step()

    # Record the history used by the plots below
    train_losses.append(train_loss)
    train_accs.append(train_acc)
    test_losses.append(test_loss)
    test_accs.append(test_acc)

    print(f"Epoch {epoch+1}/{num_epochs}: Train Acc={train_acc:.2f}%, Test Acc={test_acc:.2f}%")

# Plotting
# Left panel: loss per epoch; right panel: accuracy per epoch
plt.figure(figsize=(12, 5))
plt.subplot(1, 2, 1)
plt.plot(train_losses, label='Train Loss')
plt.plot(test_losses, label='Test Loss')
plt.legend()
plt.title('Loss Curves')
plt.subplot(1, 2, 2)
plt.plot(train_accs, label='Train Acc')
plt.plot(test_accs, label='Test Acc')
plt.legend()
plt.title('Accuracy Curves')
plt.tight_layout()
plt.savefig('vit_ape_curves.png')
plt.show()

# Save
# Only the state dict is stored, and only after the full loop has finished
torch.save(model.state_dict(), 'vit_ape.pth')
print("Model saved as vit_ape.pth")

Using device: cuda


Training:   0%|          | 0/391 [00:00<?, ?it/s]


RuntimeError: einsum(): subscript q has size 64 for operand 1 which does not broadcast with previously seen size 8