In [6]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import LambdaLR
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
import math
import random

# Set random seeds
def set_seed(seed=42):
    random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

# Simple 3-layer MLP
class MLP(nn.Module):
    def __init__(self, input_dim=28*28, hidden_dim=256, output_dim=10):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.fc3 = nn.Linear(hidden_dim, output_dim)
        self.relu = nn.ReLU()
    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        return self.fc3(x)

# Freeze/unfreeze helper
def freeze_module(m):
    for p in m.parameters():
        p.requires_grad = False
def unfreeze_module(m):
    for p in m.parameters():
        p.requires_grad = True

# Warmup + cosine decay LR scheduler
def make_warmup_cosine_scheduler(optimizer, warmup_epochs, total_epochs, min_lr_ratio=0.0):
    def lr_lambda(epoch):
        if epoch < warmup_epochs and warmup_epochs > 0:
            return epoch / warmup_epochs
        else:
            t = (epoch - warmup_epochs) / max(1, total_epochs - warmup_epochs)
            return min_lr_ratio + 0.5 * (1 - min_lr_ratio) * (1 + math.cos(math.pi * t))
    return LambdaLR(optimizer, lr_lambda=lr_lambda)

def train_one_epoch(model, loader, criterion, optimizer, device):
    model.train()
    total_loss, correct, total = 0, 0, 0
    for x, y in tqdm(loader, desc='train', leave=False):
        x, y = x.to(device), y.to(device)
        optimizer.zero_grad()
        logits = model(x)
        loss = criterion(logits, y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * x.size(0)
        preds = logits.argmax(dim=1)
        correct += (preds == y).sum().item()
        total += x.size(0)
    return total_loss / total, 100 * correct / total

@torch.no_grad()
def validate(model, loader, criterion, device):
    model.eval()
    total_loss, correct, total = 0, 0, 0
    for x, y in loader:
        x, y = x.to(device), y.to(device)
        logits = model(x)
        loss = criterion(logits, y)
        total_loss += loss.item() * x.size(0)
        preds = logits.argmax(dim=1)
        correct += (preds == y).sum().item()
        total += x.size(0)
    return total_loss / total, 100 * correct / total

def main():
    set_seed()
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
    train_dataset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
    val_dataset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

    train_loader = DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=0)
    val_loader = DataLoader(val_dataset, batch_size=128, shuffle=False, num_workers=0)

    model = MLP().to(device)

    # Initially freeze all layers except last fc3
    freeze_module(model)
    unfreeze_module(model.fc3)

    optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3, weight_decay=1e-4)
    scheduler = make_warmup_cosine_scheduler(optimizer, warmup_epochs=5, total_epochs=30)

    criterion = nn.CrossEntropyLoss()

    # Gradual unfreeze schedule: unfreeze fc2 after 10 epochs, fc1 after 20
    unfreeze_schedule = {10: model.fc2, 20: model.fc1}

    for epoch in range(30):
        if epoch in unfreeze_schedule:
            unfreeze_module(unfreeze_schedule[epoch])
            optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-3, weight_decay=1e-4)
            scheduler = make_warmup_cosine_scheduler(optimizer, warmup_epochs=5, total_epochs=30)
            print(f"Epoch {epoch}: Unfroze layer {unfreeze_schedule[epoch]} and reset optimizer/scheduler.")

        train_loss, train_acc = train_one_epoch(model, train_loader, criterion, optimizer, device)
        val_loss, val_acc = validate(model, val_loader, criterion, device)
        scheduler.step()

        print(f"Epoch {epoch+1} - Train loss: {train_loss:.4f}, Train acc: {train_acc:.2f}% | Val loss: {val_loss:.4f}, Val acc: {val_acc:.2f}%")

if __name__ == "__main__":
    main()


train:   0%|          | 0/469 [00:00<?, ?it/s]

Epoch 1 - Train loss: 2.3087, Train acc: 8.95% | Val loss: 2.3093, Val acc: 9.03%


train:   0%|          | 0/469 [00:00<?, ?it/s]

Epoch 2 - Train loss: 2.0895, Train acc: 58.80% | Val loss: 1.8805, Val acc: 74.68%


train:   0%|          | 0/469 [00:00<?, ?it/s]

Epoch 3 - Train loss: 1.5977, Train acc: 76.08% | Val loss: 1.3384, Val acc: 79.04%


train:   0%|          | 0/469 [00:00<?, ?it/s]

Epoch 4 - Train loss: 1.1521, Train acc: 79.06% | Val loss: 0.9658, Val acc: 81.19%


train:   0%|          | 0/469 [00:00<?, ?it/s]

Epoch 5 - Train loss: 0.8767, Train acc: 81.27% | Val loss: 0.7545, Val acc: 83.15%


train:   0%|          | 0/469 [00:00<?, ?it/s]

Epoch 6 - Train loss: 0.7159, Train acc: 83.03% | Val loss: 0.6295, Val acc: 84.81%


train:   0%|          | 0/469 [00:00<?, ?it/s]

Epoch 7 - Train loss: 0.6213, Train acc: 84.37% | Val loss: 0.5587, Val acc: 85.75%


train:   0%|          | 0/469 [00:00<?, ?it/s]

Epoch 8 - Train loss: 0.5633, Train acc: 85.37% | Val loss: 0.5122, Val acc: 86.42%


train:   0%|          | 0/469 [00:00<?, ?it/s]

Epoch 9 - Train loss: 0.5233, Train acc: 86.05% | Val loss: 0.4795, Val acc: 87.09%


train:   0%|          | 0/469 [00:00<?, ?it/s]

Epoch 10 - Train loss: 0.4940, Train acc: 86.55% | Val loss: 0.4550, Val acc: 87.59%
Epoch 10: Unfroze layer Linear(in_features=256, out_features=256, bias=True) and reset optimizer/scheduler.


train:   0%|          | 0/469 [00:00<?, ?it/s]

Epoch 11 - Train loss: 0.4810, Train acc: 86.84% | Val loss: 0.4550, Val acc: 87.59%


train:   0%|          | 0/469 [00:00<?, ?it/s]

Epoch 12 - Train loss: 0.2943, Train acc: 91.33% | Val loss: 0.2237, Val acc: 93.32%


train:   0%|          | 0/469 [00:00<?, ?it/s]

Epoch 13 - Train loss: 0.1959, Train acc: 94.33% | Val loss: 0.1584, Val acc: 95.30%


train:   0%|          | 0/469 [00:00<?, ?it/s]

Epoch 14 - Train loss: 0.1463, Train acc: 95.71% | Val loss: 0.1311, Val acc: 95.96%


train:   0%|          | 0/469 [00:00<?, ?it/s]

Epoch 15 - Train loss: 0.1137, Train acc: 96.65% | Val loss: 0.1247, Val acc: 96.21%


train:   0%|          | 0/469 [00:00<?, ?it/s]

Epoch 16 - Train loss: 0.0938, Train acc: 97.19% | Val loss: 0.1120, Val acc: 96.49%


train:   0%|          | 0/469 [00:00<?, ?it/s]

Epoch 17 - Train loss: 0.0694, Train acc: 97.90% | Val loss: 0.1042, Val acc: 96.90%


train:   0%|          | 0/469 [00:00<?, ?it/s]

Epoch 18 - Train loss: 0.0580, Train acc: 98.25% | Val loss: 0.1002, Val acc: 97.01%


train:   0%|          | 0/469 [00:00<?, ?it/s]

Epoch 19 - Train loss: 0.0431, Train acc: 98.78% | Val loss: 0.1061, Val acc: 96.91%


train:   0%|          | 0/469 [00:00<?, ?it/s]

Epoch 20 - Train loss: 0.0360, Train acc: 98.95% | Val loss: 0.1069, Val acc: 96.79%
Epoch 20: Unfroze layer Linear(in_features=784, out_features=256, bias=True) and reset optimizer/scheduler.


train:   0%|          | 0/469 [00:00<?, ?it/s]

Epoch 21 - Train loss: 0.0332, Train acc: 99.19% | Val loss: 0.1069, Val acc: 96.79%


train:   0%|          | 0/469 [00:00<?, ?it/s]

Epoch 22 - Train loss: 0.0598, Train acc: 98.02% | Val loss: 0.0943, Val acc: 97.23%


train:   0%|          | 0/469 [00:00<?, ?it/s]

Epoch 23 - Train loss: 0.0738, Train acc: 97.64% | Val loss: 0.1160, Val acc: 96.99%


train:   0%|          | 0/469 [00:00<?, ?it/s]

Epoch 24 - Train loss: 0.0775, Train acc: 97.59% | Val loss: 0.1370, Val acc: 96.58%


train:   0%|          | 0/469 [00:00<?, ?it/s]

Epoch 25 - Train loss: 0.0905, Train acc: 97.41% | Val loss: 0.1127, Val acc: 97.28%


train:   0%|          | 0/469 [00:00<?, ?it/s]

Epoch 26 - Train loss: 0.0899, Train acc: 97.46% | Val loss: 0.1365, Val acc: 96.57%


train:   0%|          | 0/469 [00:00<?, ?it/s]

Epoch 27 - Train loss: 0.0649, Train acc: 98.03% | Val loss: 0.1608, Val acc: 96.42%


train:   0%|          | 0/469 [00:00<?, ?it/s]

Epoch 28 - Train loss: 0.0523, Train acc: 98.45% | Val loss: 0.1366, Val acc: 96.95%


train:   0%|          | 0/469 [00:00<?, ?it/s]

Epoch 29 - Train loss: 0.0471, Train acc: 98.60% | Val loss: 0.1274, Val acc: 96.99%


train:   0%|          | 0/469 [00:00<?, ?it/s]

Epoch 30 - Train loss: 0.0387, Train acc: 98.86% | Val loss: 0.1760, Val acc: 96.65%
