In [4]:
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms, models

def set_seed(seed: int):
    """Seed the Python, NumPy and PyTorch random number generators.

    Args:
        seed: Value handed unchanged to every generator.
    """
    # Same order as always: stdlib RNG, NumPy global RNG, then torch.
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)

    # Without CUDA there is nothing further to seed.
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)

# Seed the Python, NumPy and PyTorch RNGs before anything random happens.
set_seed(42)

# Preprocessing applied to every image, in order.
_preprocessing_steps = [
    transforms.Resize((224, 224)),
    # NOTE: the Resize above already yields 224x224, so this crop changes nothing.
    transforms.CenterCrop(224),
    # Produce 3 channels from the single-channel input.
    transforms.Grayscale(num_output_channels=3),
    transforms.ToTensor(),
    # Per-channel normalisation (the usual ImageNet statistics).
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]
transform = transforms.Compose(_preprocessing_steps)

batch_size = 64

# Both splits share the storage location, download flag and preprocessing.
_dataset_options = dict(root="./data", download=True, transform=transform)
train_ds = datasets.FashionMNIST(train=True, **_dataset_options)
test_ds = datasets.FashionMNIST(train=False, **_dataset_options)

# Only the training split is reshuffled each epoch.
train_loader = DataLoader(dataset=train_ds, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_ds, batch_size=batch_size, shuffle=False)

# Prefer CUDA, then Apple MPS, and fall back to the CPU.
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)
print(f"Using device: {device}")

# Bookkeeping shared by the training sweep and the final report:
#   results    -> per-configuration list of accuracies
#   best_acc   -> highest final accuracy seen so far
#   best_label -> name of the configuration that reached best_acc
#   best_wts   -> weights of that configuration
results = {}
best_acc, best_label, best_wts = 0.0, None, None

epochs = 10

# Freezing strategies tried for each backbone, in training order.
_FREEZE_SCHEDULES = {
    'resnet18': ['full', 'freeze_early', 'freeze_half', 'unfreeze_all'],
    'vgg16': ['full', 'freeze_features', 'freeze_half', 'unfreeze_all'],
}

# Number of leading VGG16 conv layers frozen by 'freeze_half'.
_VGG_HALF_FROZEN_CONVS = 8


def _freeze(module):
    """Disable gradient tracking for every parameter of *module*."""
    module.requires_grad_(False)


def _build_model(base_model, freeze_type):
    """Create a pretrained backbone with a new 10-class head.

    Args:
        base_model: 'resnet18' selects ResNet-18; anything else selects VGG16.
        freeze_type: Name of the freezing strategy. Names without a matching
            branch (e.g. 'unfreeze_all') leave every parameter trainable.

    Returns:
        The model, still on the CPU, with ``requires_grad`` set per strategy.
    """
    if base_model == 'resnet18':
        model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)
        model.fc = nn.Linear(model.fc.in_features, 10)
        if freeze_type == 'full':
            # Train only the new classification head.
            _freeze(model)
            model.fc.requires_grad_(True)
        elif freeze_type in ('freeze_early', 'freeze_half'):
            frozen_children = ['conv1', 'bn1', 'layer1']
            if freeze_type == 'freeze_half':
                frozen_children.append('layer2')
            for child_name in frozen_children:
                _freeze(getattr(model, child_name))
    else:
        model = models.vgg16(weights=models.VGG16_Weights.DEFAULT)
        model.classifier[6] = nn.Linear(model.classifier[6].in_features, 10)
        if freeze_type == 'full':
            # Train only the replaced final classifier layer.
            _freeze(model)
            model.classifier[6].requires_grad_(True)
        elif freeze_type == 'freeze_features':
            _freeze(model.features)
        elif freeze_type == 'freeze_half':
            # Freeze whole conv layers. Slicing the flat parameter list at an
            # odd index (previously 15) froze one layer's weight while leaving
            # its bias trainable.
            conv_layers = [m for m in model.features if isinstance(m, nn.Conv2d)]
            for layer in conv_layers[:_VGG_HALF_FROZEN_CONVS]:
                _freeze(layer)
    return model


def _train_one_epoch(model, loader, optimizer, criterion, device):
    """Run one optimisation pass over *loader*; return the summed batch losses."""
    # NOTE: train() also puts BatchNorm layers inside frozen blocks into
    # training mode, so their running statistics keep updating.
    model.train()
    total_loss = 0.0
    for X, y in loader:
        X, y = X.to(device), y.to(device)
        optimizer.zero_grad()
        loss = criterion(model(X), y)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    return total_loss


def _evaluate(model, loader, device):
    """Return the fraction of samples in *loader* classified correctly."""
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for X, y in loader:
            X, y = X.to(device), y.to(device)
            preds = model(X).argmax(dim=1)
            correct += (preds == y).sum().item()
            total += y.size(0)
    # An empty loader yields 0.0 instead of a ZeroDivisionError.
    return correct / total if total else 0.0


for base_model, freeze_types in _FREEZE_SCHEDULES.items():
    for freeze_type in freeze_types:
        print(f"\n→ Training {base_model.upper()} ({freeze_type})")

        model = _build_model(base_model, freeze_type).to(device)

        # Only parameters left trainable are handed to the optimizer.
        trainable_params = [p for p in model.parameters() if p.requires_grad]
        optimizer = optim.Adam(trainable_params, lr=1e-3)
        criterion = nn.CrossEntropyLoss()

        history = []
        for ep in range(1, epochs + 1):
            # "Loss" is the sum of per-batch losses over the epoch, not a mean.
            total_loss = _train_one_epoch(model, train_loader, optimizer, criterion, device)
            acc = _evaluate(model, test_loader, device)
            history.append(acc)
            print(f"[{ep:2d}/{epochs}] Loss: {total_loss:.4f}  Acc: {acc:.4f}")

        label = f"{base_model}_{freeze_type}"
        results[label] = history
        # Configurations are ranked by their final-epoch accuracy, which matches
        # the weights that are kept.
        if history and history[-1] > best_acc:
            best_acc = history[-1]
            best_label = label
            # Snapshot on the CPU: state_dict() returns references to the live
            # device tensors, which would keep this model's accelerator memory
            # alive for the rest of the sweep and tie the checkpoint to that device.
            best_wts = {
                key: tensor.detach().cpu().clone()
                for key, tensor in model.state_dict().items()
            }

# Persist the winning weights, if any configuration produced some.
_found_best = best_wts is not None
if _found_best:
    _checkpoint_file = "FashionMNIST_best_model.pt"
    torch.save(best_wts, _checkpoint_file)
    print(f"\nSaved best model: [{best_label}]  Acc: {best_acc:.4f}")

# Per-configuration accuracy table, one line per epoch.
print("\n===== Training Results =====")
for label in results:
    history = results[label]
    print(f"\nModel: {label}")
    for epoch, acc in enumerate(history, start=1):
        print(f"  Epoch {epoch:2d}: Accuracy = {acc:.4f}")

# Final summary line (printed even when nothing was saved).
print(f"\nBest Model: [{best_label}] with Accuracy = {best_acc:.4f}")


Using device: cuda

→ Training RESNET18 (full)
[ 1/10] Loss: 567.8811  Acc: 0.8439
[ 2/10] Loss: 404.6588  Acc: 0.8578
[ 3/10] Loss: 378.2006  Acc: 0.8506
[ 4/10] Loss: 364.6926  Acc: 0.8654
[ 5/10] Loss: 356.8064  Acc: 0.8637
[ 6/10] Loss: 349.4091  Acc: 0.8580
[ 7/10] Loss: 344.8453  Acc: 0.8654
[ 8/10] Loss: 341.2461  Acc: 0.8651
[ 9/10] Loss: 338.6684  Acc: 0.8670
[10/10] Loss: 335.0125  Acc: 0.8635

→ Training RESNET18 (freeze_early)
[ 1/10] Loss: 271.7766  Acc: 0.9084


KeyboardInterrupt: 