In [2]:
# ==============================================================================
# Cell 1: Imports and Setup
# ==============================================================================
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm
import math

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

# ==============================================================================
# Cell 2: Configuration
# ==============================================================================
config = {
    "patch_size": 4,
    "image_size": 32,
    "in_channels": 3,
    "embed_dim": 512,
    "num_heads": 8,
    "mlp_ratio": 4.0,
    "depth": 6,
    "num_classes": 10,
    "dropout_rate": 0.1,
    "learning_rate": 1e-3,
    "weight_decay": 0.05,
    "batch_size": 256,
    "num_epochs": 100,
    "warmup_epochs": 5,
}

# ==============================================================================
# Cell 3: Building Blocks for the Vision Transformer (ViT)
# ==============================================================================
class PatchEmbedding(nn.Module):
    def __init__(self, image_size, patch_size, in_channels, embed_dim):
        super().__init__()
        self.num_patches = (image_size // patch_size) ** 2
        self.proj = nn.Conv2d(in_channels, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        x = self.proj(x)
        x = x.flatten(2)
        x = x.transpose(1, 2)
        return x

class MultiHeadAttention(nn.Module):
    def __init__(self, embed_dim, num_heads, dropout_rate):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        self.attn_drop = nn.Dropout(dropout_rate)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.proj_drop = nn.Dropout(dropout_rate)

    def forward(self, x):
        B, N, D = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn = (q @ k.transpose(-2, -1)) * (self.head_dim ** -0.5)
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, N, D)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

class MLP(nn.Module):
    def __init__(self, in_features, hidden_features, out_features, dropout_rate):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = nn.GELU()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(dropout_rate)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

class TransformerEncoder(nn.Module):
    def __init__(self, embed_dim, num_heads, mlp_ratio, dropout_rate):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadAttention(embed_dim, num_heads, dropout_rate)
        self.norm2 = nn.LayerNorm(embed_dim)
        mlp_hidden_dim = int(embed_dim * mlp_ratio)
        self.mlp = MLP(embed_dim, mlp_hidden_dim, embed_dim, dropout_rate)

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x

# ==============================================================================
# Cell 4: The Full Vision Transformer (ViT) Model
# ==============================================================================
class VisionTransformer(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.patch_embed = PatchEmbedding(config["image_size"], config["patch_size"], config["in_channels"], config["embed_dim"])
        self.cls_token = nn.Parameter(torch.zeros(1, 1, config["embed_dim"]))
        num_patches = self.patch_embed.num_patches
        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, config["embed_dim"]))
        self.pos_drop = nn.Dropout(config["dropout_rate"])
        self.blocks = nn.ModuleList([TransformerEncoder(config["embed_dim"], config["num_heads"], config["mlp_ratio"], config["dropout_rate"]) for _ in range(config["depth"])])
        self.norm = nn.LayerNorm(config["embed_dim"])
        self.head = nn.Linear(config["embed_dim"], config["num_classes"])

    def forward(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)
        cls_tokens = self.cls_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, x), dim=1)
        x = x + self.pos_embed
        x = self.pos_drop(x)
        for block in self.blocks:
            x = block(x)
        x = self.norm(x)
        cls_token_final = x[:, 0]
        logits = self.head(cls_token_final)
        return logits

# ==============================================================================
# Cell 5: Data Loading and Augmentation
# ==============================================================================
transform_train = transforms.Compose([
    transforms.RandomResizedCrop(config["image_size"], scale=(0.8, 1.0)),
    transforms.RandomHorizontalFlip(),
    transforms.TrivialAugmentWide(),
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)),
])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = DataLoader(trainset, batch_size=config["batch_size"], shuffle=True, num_workers=2, pin_memory=True)
testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = DataLoader(testset, batch_size=config["batch_size"], shuffle=False, num_workers=2, pin_memory=True)

# ==============================================================================
# Cell 6: Training and Evaluation Loop
# ==============================================================================
model = VisionTransformer(config).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=config["learning_rate"], weight_decay=config["weight_decay"])

# Warmup Scheduler
def warmup_cosine_scheduler(epoch):
    if epoch < config["warmup_epochs"]:
        return float(epoch) / float(max(1, config["warmup_epochs"]))
    else:
        progress = float(epoch - config["warmup_epochs"]) / float(max(1, config["num_epochs"] - config["warmup_epochs"]))
        return 0.5 * (1.0 + math.cos(math.pi * progress))

scheduler = LambdaLR(optimizer, lr_lambda=warmup_cosine_scheduler)


def train_one_epoch(epoch):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    progress_bar = tqdm(trainloader, desc=f"Epoch {epoch+1}/{config['num_epochs']} [Train]")

    for i, (inputs, labels) in enumerate(progress_bar):
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        # Gradient Clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        running_loss += loss.item()
        _, predicted = outputs.max(1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

        progress_bar.set_postfix(loss=running_loss/(i+1), acc=f"{(100.*correct/total):.2f}%")

def test(epoch):
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in testloader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs)
            _, predicted = outputs.max(1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()
    accuracy = 100. * correct / total
    return accuracy

# ==============================================================================
# Cell 7: Run Training
# ==============================================================================
best_accuracy = 0.0
for epoch in range(config["num_epochs"]):
    train_one_epoch(epoch)
    current_accuracy = test(epoch)
    scheduler.step()

    if current_accuracy > best_accuracy:
        best_accuracy = current_accuracy
        print(f"🎉 New best accuracy at epoch {epoch+1}: {best_accuracy:.2f}%")
    else:
        print(f"Epoch {epoch+1} Test Accuracy: {current_accuracy:.2f}% (Best: {best_accuracy:.2f}%)")

print(f"\n🏁 Training Finished! Best Test Accuracy: {best_accuracy:.2f}%")

Using device: cuda


100%|██████████| 170M/170M [00:55<00:00, 3.09MB/s]
Epoch 1/100 [Train]: 100%|██████████| 196/196 [02:06<00:00,  1.55it/s, acc=9.55%, loss=2.43]


🎉 New best accuracy at epoch 1: 9.94%


Epoch 2/100 [Train]: 100%|██████████| 196/196 [02:11<00:00,  1.49it/s, acc=24.93%, loss=2.03]


🎉 New best accuracy at epoch 2: 37.78%


Epoch 3/100 [Train]: 100%|██████████| 196/196 [02:12<00:00,  1.48it/s, acc=32.84%, loss=1.83]


🎉 New best accuracy at epoch 3: 40.42%


Epoch 4/100 [Train]: 100%|██████████| 196/196 [02:12<00:00,  1.48it/s, acc=36.24%, loss=1.75]


🎉 New best accuracy at epoch 4: 44.00%


Epoch 5/100 [Train]: 100%|██████████| 196/196 [02:12<00:00,  1.48it/s, acc=37.31%, loss=1.73]


🎉 New best accuracy at epoch 5: 45.21%


Epoch 6/100 [Train]: 100%|██████████| 196/196 [02:11<00:00,  1.49it/s, acc=35.81%, loss=1.76]


Epoch 6 Test Accuracy: 43.38% (Best: 45.21%)


Epoch 7/100 [Train]: 100%|██████████| 196/196 [02:12<00:00,  1.48it/s, acc=36.11%, loss=1.76]


Epoch 7 Test Accuracy: 42.61% (Best: 45.21%)


Epoch 8/100 [Train]: 100%|██████████| 196/196 [02:12<00:00,  1.48it/s, acc=34.39%, loss=1.79]


Epoch 8 Test Accuracy: 42.59% (Best: 45.21%)


Epoch 9/100 [Train]: 100%|██████████| 196/196 [02:11<00:00,  1.49it/s, acc=34.73%, loss=1.8]


Epoch 9 Test Accuracy: 42.41% (Best: 45.21%)


Epoch 10/100 [Train]: 100%|██████████| 196/196 [02:12<00:00,  1.48it/s, acc=34.19%, loss=1.8]


Epoch 10 Test Accuracy: 43.82% (Best: 45.21%)


Epoch 11/100 [Train]: 100%|██████████| 196/196 [02:11<00:00,  1.49it/s, acc=35.44%, loss=1.77]


Epoch 11 Test Accuracy: 42.62% (Best: 45.21%)


Epoch 12/100 [Train]: 100%|██████████| 196/196 [02:11<00:00,  1.49it/s, acc=35.00%, loss=1.78]


Epoch 12 Test Accuracy: 45.07% (Best: 45.21%)


Epoch 13/100 [Train]: 100%|██████████| 196/196 [02:12<00:00,  1.48it/s, acc=35.07%, loss=1.79]


Epoch 13 Test Accuracy: 42.95% (Best: 45.21%)


Epoch 14/100 [Train]: 100%|██████████| 196/196 [02:11<00:00,  1.49it/s, acc=34.93%, loss=1.78]


Epoch 14 Test Accuracy: 45.08% (Best: 45.21%)


Epoch 15/100 [Train]: 100%|██████████| 196/196 [02:12<00:00,  1.48it/s, acc=35.96%, loss=1.76]


Epoch 15 Test Accuracy: 44.88% (Best: 45.21%)


Epoch 16/100 [Train]: 100%|██████████| 196/196 [02:12<00:00,  1.48it/s, acc=37.02%, loss=1.74]


Epoch 16 Test Accuracy: 44.86% (Best: 45.21%)


Epoch 17/100 [Train]: 100%|██████████| 196/196 [02:12<00:00,  1.48it/s, acc=36.50%, loss=1.75]


Epoch 17 Test Accuracy: 43.85% (Best: 45.21%)


Epoch 18/100 [Train]: 100%|██████████| 196/196 [02:11<00:00,  1.49it/s, acc=36.15%, loss=1.76]


Epoch 18 Test Accuracy: 43.39% (Best: 45.21%)


Epoch 19/100 [Train]: 100%|██████████| 196/196 [02:12<00:00,  1.48it/s, acc=36.74%, loss=1.74]


🎉 New best accuracy at epoch 19: 46.61%


Epoch 20/100 [Train]: 100%|██████████| 196/196 [02:11<00:00,  1.49it/s, acc=36.97%, loss=1.73]


Epoch 20 Test Accuracy: 45.07% (Best: 46.61%)


Epoch 21/100 [Train]: 100%|██████████| 196/196 [02:12<00:00,  1.48it/s, acc=37.64%, loss=1.72]


🎉 New best accuracy at epoch 21: 46.93%


Epoch 22/100 [Train]: 100%|██████████| 196/196 [02:12<00:00,  1.48it/s, acc=37.49%, loss=1.72]


Epoch 22 Test Accuracy: 45.61% (Best: 46.93%)


Epoch 23/100 [Train]:  36%|███▌      | 70/196 [00:48<01:27,  1.44it/s, acc=37.89%, loss=1.71]


KeyboardInterrupt: 