In [19]:
import math
import time

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.optim.lr_scheduler import _LRScheduler
from torchvision.transforms import RandAugment

In [20]:
class WarmupCosineAnnealingLR(_LRScheduler):
    """Linear warmup followed by a half-cosine decay to zero.

    For step index ``e`` (``self.last_epoch``) the multiplier applied to each
    base learning rate is:

    * ``(e + 1) / warmup_epochs`` while ``e < warmup_epochs``;
    * ``0.5 * (1 + cos(pi * p))`` afterwards, where ``p`` is the fraction of
      the ``max_epochs - warmup_epochs`` decay span already consumed,
      clamped to 1.0.

    Args:
        optimizer: Optimizer whose parameter-group learning rates are driven.
        warmup_epochs: Number of scheduler steps spent ramping up linearly.
        max_epochs: Step index at which the cosine decay reaches zero.
        last_epoch: Index of the last step, forwarded to the base class.
    """

    def __init__(self, optimizer, warmup_epochs, max_epochs, last_epoch=-1):
        # Both attributes are read by get_lr(), which the base constructor
        # already invokes for its initial step, so they must be set first.
        self.warmup_epochs = warmup_epochs
        self.max_epochs = max_epochs
        super().__init__(optimizer, last_epoch)

    def get_lr(self):
        """Return one learning rate per parameter group as plain Python floats."""
        if self.last_epoch < self.warmup_epochs:
            lr_scale = (self.last_epoch + 1) / self.warmup_epochs
        else:
            # Keep the span >= 1 so max_epochs == warmup_epochs cannot divide
            # by zero.
            decay_epochs = max(self.max_epochs - self.warmup_epochs, 1)
            # Clamp so extra steps past max_epochs stay at the floor instead
            # of riding the cosine back up.
            progress = min((self.last_epoch - self.warmup_epochs) / decay_epochs, 1.0)
            # math.cos (not torch.cos) keeps the result a float; a tensor here
            # would turn the optimizer's ``lr`` entries into tensors.
            lr_scale = 0.5 * (1.0 + math.cos(math.pi * progress))
        return [base_lr * lr_scale for base_lr in self.base_lrs]

class PatchEmbedding(nn.Module):
    """Turn an image batch into a sequence of patch embeddings.

    A single ``nn.Conv2d`` whose kernel size and stride both equal
    ``patch_size`` projects each non-overlapping patch to ``embed_dim``
    channels. ``n_patches`` is ``(img_size // patch_size) ** 2``.

    Args:
        img_size: Side length used to compute ``n_patches``.
        patch_size: Side length of one patch (conv kernel size and stride).
        in_channels: Channel count of the input images.
        embed_dim: Output channels of the projection.
    """

    def __init__(self, img_size=32, patch_size=8, in_channels=3, embed_dim=128):
        super().__init__()
        patches_per_side = img_size // patch_size
        self.img_size, self.patch_size = img_size, patch_size
        self.n_patches = patches_per_side * patches_per_side
        self.proj = nn.Conv2d(
            in_channels,
            embed_dim,
            kernel_size=patch_size,
            stride=patch_size,
        )

    def forward(self, x):
        """Project ``x``, merge the dimensions from index 2 on, then swap axes 1 and 2.

        For a ``(B, in_channels, H, W)`` batch the result is
        ``(B, patches, embed_dim)``.
        """
        return self.proj(x).flatten(2).transpose(1, 2)

class VisionTransformer(nn.Module):
    """Vision Transformer classifier built on ``nn.TransformerEncoder``.

    Pipeline: patch embedding -> prepend a learnable class token -> add a
    learnable position embedding -> dropout -> ``n_layers`` pre-norm encoder
    layers -> LayerNorm + Linear head applied to the class-token output.

    Args:
        img_size: Input image side length, forwarded to ``PatchEmbedding``.
        patch_size: Patch side length, forwarded to ``PatchEmbedding``.
        in_channels: Number of image channels.
        embed_dim: Token width (``d_model`` of the encoder).
        n_classes: Size of the output logit vector.
        n_layers: Number of stacked encoder layers.
        n_heads: Attention heads per encoder layer.
        mlp_dim: Hidden width of each layer's feed-forward sub-block.
        dropout: Dropout probability for the embeddings and encoder layers.
    """

    def __init__(self, img_size=32, patch_size=4, in_channels=3, embed_dim=256, n_classes=10, n_layers=8, n_heads=8, mlp_dim=512, dropout=0.1):
        super().__init__()

        # Token sequence = one class token followed by every image patch.
        self.patch_embedding = PatchEmbedding(img_size, patch_size, in_channels, embed_dim)
        seq_len = 1 + self.patch_embedding.n_patches
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.zeros(1, seq_len, embed_dim))
        self.pos_dropout = nn.Dropout(p=dropout)

        # norm_first=True selects the pre-layer-normalization variant of the
        # encoder layer; batch_first=True keeps tensors as (batch, seq, dim).
        layer_template = nn.TransformerEncoderLayer(
            d_model=embed_dim,
            nhead=n_heads,
            dim_feedforward=mlp_dim,
            dropout=dropout,
            batch_first=True,
            norm_first=True,
        )
        self.transformer_encoder = nn.TransformerEncoder(layer_template, num_layers=n_layers)

        # Classification head applied to the class-token representation.
        self.mlp_head = nn.Sequential(
            nn.LayerNorm(embed_dim),
            nn.Linear(embed_dim, n_classes),
        )

    def forward(self, x):
        """Return class logits of shape ``(batch, n_classes)`` for an image batch ``x``."""
        tokens = self.patch_embedding(x)
        batch_size, _, _ = tokens.shape
        # Broadcast the single learned class token across the batch.
        cls = self.cls_token.expand(batch_size, -1, -1)
        tokens = torch.cat((cls, tokens), dim=1) + self.pos_embed
        encoded = self.transformer_encoder(self.pos_dropout(tokens))
        # Only the class-token position (index 0) feeds the head.
        return self.mlp_head(encoded[:, 0])


In [21]:
def train_model(epochs=25, batch_size=64, lr=3e-4, weight_decay=1e-4, warmup_epochs=5):
    """Train a default ``VisionTransformer`` on the CIFAR-10 training split.

    The training set is downloaded to ``./data`` if needed and augmented with
    random crop (padding 4), horizontal flip and ``RandAugment`` before
    normalisation. Optimisation uses AdamW with cross-entropy loss; the
    learning rate follows ``WarmupCosineAnnealingLR`` and is advanced once
    per epoch. The mean batch loss and current learning rate are printed
    after every epoch.

    Args:
        epochs: Number of passes over the training set.
        batch_size: Mini-batch size for the training ``DataLoader``.
        lr: Peak (base) learning rate handed to AdamW.
        weight_decay: AdamW weight-decay coefficient.
        warmup_epochs: Number of linear-warmup epochs for the scheduler.

    Returns:
        tuple: ``(model, training_time)`` where ``training_time`` is the
        wall-clock duration of the epoch loop in seconds.
    """
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {device}")

    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        RandAugment(num_ops=2, magnitude=9),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=2, pin_memory=True)

    model = VisionTransformer().to(device)
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = WarmupCosineAnnealingLR(optimizer, warmup_epochs=warmup_epochs, max_epochs=epochs)

    start_time = time.time()
    print("Starting training with STABILIZED configuration...")
    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        for inputs, labels in trainloader:
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        print(f"Epoch {epoch + 1}, Loss: {running_loss / len(trainloader):.4f}, LR: {scheduler.get_last_lr()[0]:.6f}")
        # Advance the schedule only when another epoch follows. A step after
        # the final epoch is never used, and with epochs == warmup_epochs it
        # would make the scheduler evaluate a zero-length decay span and
        # raise once training had already finished.
        if epoch + 1 < epochs:
            scheduler.step()

    training_time = time.time() - start_time
    print(f"\nFinished Training in {training_time:.2f}s")

    return model, training_time

In [22]:
def evaluate_model(model, batch_size=64):
    """Compute top-1 accuracy of ``model`` on the CIFAR-10 test split.

    The test set is downloaded to ``./data`` if needed and only normalised
    (no augmentation). The model is switched to eval mode and is left in
    that mode when the function returns.

    Args:
        model: Classifier mapping an image batch to per-class scores.
        batch_size: Mini-batch size for the test ``DataLoader``.

    Returns:
        float: Percentage (0-100) of correctly classified test images.
    """
    # Evaluate on the device the model already lives on. Choosing the device
    # from CUDA availability alone fails with a device mismatch whenever a
    # CPU-resident model is evaluated on a CUDA-capable machine.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        # Parameter-free model: fall back to the availability-based choice.
        device = 'cuda' if torch.cuda.is_available() else 'cpu'

    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
    ])
    testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
    testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=2, pin_memory=True)

    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            # Index of the highest score per sample; inside no_grad() the
            # legacy ``.data`` detour is unnecessary.
            _, predicted = torch.max(outputs, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    final_accuracy = 100 * correct / total
    return final_accuracy

In [23]:
# Driver cell: train for 25 epochs, evaluate on the test split, report results.
if __name__ == '__main__':
    # train_model returns the trained network and the wall-clock training time.
    trained_model, total_training_time = train_model(epochs=25)
    # evaluate_model returns accuracy as a percentage (it computes 100 * correct / total).
    accuracy = evaluate_model(trained_model)

    print("\n--- Results ---")
    print(f"Final Test Accuracy: {accuracy:.2f}%")
    print(f"Total Training Time: {total_training_time:.2f} seconds")


Using device: cuda
Starting training with STABILIZED configuration...




Epoch 1, Loss: 1.9603, LR: 0.000060
Epoch 2, Loss: 1.7416, LR: 0.000120
Epoch 3, Loss: 1.5941, LR: 0.000180
Epoch 4, Loss: 1.4878, LR: 0.000240
Epoch 5, Loss: 1.4237, LR: 0.000300
Epoch 6, Loss: 1.3570, LR: 0.000300
Epoch 7, Loss: 1.3002, LR: 0.000298
Epoch 8, Loss: 1.2480, LR: 0.000293
Epoch 9, Loss: 1.2026, LR: 0.000284
Epoch 10, Loss: 1.1572, LR: 0.000271
Epoch 11, Loss: 1.1202, LR: 0.000256
Epoch 12, Loss: 1.0815, LR: 0.000238
Epoch 13, Loss: 1.0343, LR: 0.000218
Epoch 14, Loss: 0.9981, LR: 0.000196
Epoch 15, Loss: 0.9607, LR: 0.000173
Epoch 16, Loss: 0.9194, LR: 0.000150
Epoch 17, Loss: 0.8886, LR: 0.000127
Epoch 18, Loss: 0.8598, LR: 0.000104
Epoch 19, Loss: 0.8276, LR: 0.000082
Epoch 20, Loss: 0.8019, LR: 0.000062
Epoch 21, Loss: 0.7802, LR: 0.000044
Epoch 22, Loss: 0.7603, LR: 0.000029
Epoch 23, Loss: 0.7508, LR: 0.000016
Epoch 24, Loss: 0.7467, LR: 0.000007
Epoch 25, Loss: 0.7255, LR: 0.000002

Finished Training in 1498.44s

--- Results ---
Final Test Accuracy: 78.62%
Total Tr