# 2. Baseline Pretrained Model Training

**Student:** Student 1

## Purpose
- Load pretrained vision model (ResNet, EfficientNet, VGG, etc.)
- Implement transfer learning strategy
- Train baseline model with default hyperparameters
- Log metrics to TensorBoard
- Save model checkpoint
- Analyze baseline performance

In [None]:
# Execute notebook 01 first. This notebook later uses train_loader, val_loader
# and test_loader without defining them, so they are presumably created by
# that notebook -- verify before running.
%run ./01_data_exploration_preprocessing.ipynb

In [None]:
# TODO: Import libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import models, transforms
from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt
import numpy as np
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, confusion_matrix
import os
from tqdm import tqdm

In [None]:
# DataLoaders: this notebook expects train_loader, val_loader and test_loader
# to exist already (e.g. produced by notebook 01). If they are not available,
# rebuild the datasets and loaders here with the same preprocessing pipeline.


num_classes = 10  # Fashion-MNIST has 10 classes

# Use the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using device:", device)


In [None]:
# Build ResNet-18 initialised from the IMAGENET1K_V1 weights and place it on
# the selected device in one step (Module.to returns the module itself).
resnet18 = models.resnet18(weights=models.ResNet18_Weights.IMAGENET1K_V1).to(device)

# Print the architecture so the layer names (e.g. the final `fc`) are visible.
print(resnet18)


In [None]:
# Transfer learning strategy: feature extraction.
#   1. Freeze every pretrained parameter so the backbone is not updated.
#   2. Swap the final fully connected layer for a fresh num_classes-way head.
#   3. Train only that new head.

# Freeze all pretrained parameters (feature extractor).
for param in resnet18.parameters():
    param.requires_grad = False

# Replace the final fully connected layer; the new layer keeps the input width
# of the old one and outputs one logit per class.
resnet18.fc = nn.Linear(resnet18.fc.in_features, num_classes)

# Mark the head's parameters as trainable explicitly. requires_grad_() sets the
# flag on the layer's parameters; assigning `fc.requires_grad = True` would only
# create an unused attribute on the module object and change nothing.
resnet18.fc.requires_grad_(True)

# NOTE(review): train_one_epoch() calls model.train(), so any batch-norm layers
# in the frozen backbone would still update their running statistics --
# confirm this is intended for the baseline.

# Move model (including the newly created head) to device.
resnet18 = resnet18.to(device)

print("Transfer learning strategy applied: frozen conv layers, fine-tune FC layer")


In [None]:
# Baseline training configuration.
num_epochs = 10  # choose a value in 5-15 as recommended
# NOTE(review): batch_size is only logged in this notebook; the loaders are
# built elsewhere -- confirm they use the same value.
batch_size = 64
learning_rate = 1e-3

# Loss and optimizer. Only the parameters of the `fc` layer are handed to Adam.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(params=resnet18.fc.parameters(), lr=learning_rate)

print(f"Hyperparameters: lr={learning_rate}, batch_size={batch_size}, epochs={num_epochs}")


In [None]:
# TensorBoard setup: make sure the run directory exists, then open a writer on it.
log_dir = "runs/baseline_resnet18"
os.makedirs(log_dir, exist_ok=True)
writer = SummaryWriter(log_dir)

# Store the run configuration as a text entry alongside the scalar curves.
writer.add_text(
    tag="Hyperparameters",
    text_string=f"lr={learning_rate}, batch_size={batch_size}, epochs={num_epochs}",
)


In [None]:
# TODO: Implement training loop

def train_one_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over ``loader`` and return averaged metrics.

    Args:
        model: Network to update; it is switched to training mode first.
        loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Loss callable applied as ``criterion(outputs, labels)``.
        optimizer: Optimizer that is zeroed and stepped once per batch.
        device: Device every batch is moved to before the forward pass.

    Returns:
        Tuple ``(epoch_loss, epoch_acc)``: the accumulated loss divided by the
        number of samples seen, and the fraction of correct predictions.
    """
    model.train()
    loss_sum, n_correct, n_seen = 0.0, 0, 0

    for batch_x, batch_y in tqdm(loader, desc="Training"):
        batch_x = batch_x.to(device)
        batch_y = batch_y.to(device)

        # Standard step: clear old gradients, forward, backward, update.
        optimizer.zero_grad()
        logits = model(batch_x)
        batch_loss = criterion(logits, batch_y)
        batch_loss.backward()
        optimizer.step()

        # Scale the batch loss by the batch size so the final division gives a
        # per-sample value (assumes criterion returns a batch mean, as the
        # default nn.CrossEntropyLoss used in this notebook does).
        loss_sum += batch_loss.item() * batch_x.size(0)

        # Predicted class = index of the largest logit along dim 1.
        predicted = logits.max(dim=1).indices
        n_correct += (predicted == batch_y).sum().item()
        n_seen += batch_y.size(0)

    return loss_sum / n_seen, n_correct / n_seen


In [None]:
# TODO: Implement validation loop

def validate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without updating any weights.

    Args:
        model: Network to evaluate; it is switched to evaluation mode first.
        loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Loss callable applied as ``criterion(outputs, labels)``.
        device: Device every batch is moved to before the forward pass.

    Returns:
        Tuple ``(epoch_loss, epoch_acc)``: the accumulated loss divided by the
        number of samples seen, and the fraction of correct predictions.
    """
    model.eval()
    loss_sum, n_correct, n_seen = 0.0, 0, 0

    # Gradient tracking is disabled for the whole pass.
    with torch.no_grad():
        for batch_x, batch_y in tqdm(loader, desc="Validation"):
            batch_x = batch_x.to(device)
            batch_y = batch_y.to(device)

            logits = model(batch_x)
            batch_loss = criterion(logits, batch_y)

            # Scale by batch size so the final division gives a per-sample
            # value (assumes criterion returns a batch mean).
            loss_sum += batch_loss.item() * batch_x.size(0)

            # Predicted class = index of the largest logit along dim 1.
            predicted = logits.max(dim=1).indices
            n_correct += (predicted == batch_y).sum().item()
            n_seen += batch_y.size(0)

    return loss_sum / n_seen, n_correct / n_seen


In [None]:
# Train the baseline model.
# Each epoch runs one training pass and one validation pass, appends the
# metrics to the history lists (used later for plotting), logs them to
# TensorBoard and writes a checkpoint of the model weights.

train_losses, val_losses = [], []
train_accs, val_accs = [], []

save_dir = "saved_models/baseline_pretrained"
os.makedirs(save_dir, exist_ok=True)

for epoch in range(num_epochs):
    print(f"Epoch {epoch + 1}/{num_epochs}")

    train_loss, train_acc = train_one_epoch(resnet18, train_loader, criterion, optimizer, device)
    val_loss, val_acc = validate(resnet18, val_loader, criterion, device)

    train_losses.append(train_loss)
    val_losses.append(val_loss)
    train_accs.append(train_acc)
    val_accs.append(val_acc)

    # Log to TensorBoard (global step is the 0-based epoch index)
    writer.add_scalars("Loss", {"train": train_loss, "val": val_loss}, epoch)
    writer.add_scalars("Accuracy", {"train": train_acc, "val": val_acc}, epoch)

    # Save a per-epoch checkpoint (weights only); file names are 1-based
    checkpoint_path = os.path.join(save_dir, f"model_epoch_{epoch + 1}.pt")
    torch.save(resnet18.state_dict(), checkpoint_path)

    print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
    print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}")

# The writer buffers events; flush so the last epochs are on disk and visible
# in TensorBoard right away. The writer stays open for any later logging.
writer.flush()


In [None]:
# Evaluate the trained baseline on the test set: collect predictions and
# ground-truth labels over all batches, then compute summary metrics.
resnet18.eval()
all_preds, all_labels = [], []

with torch.no_grad():
    for images, labels in tqdm(test_loader, desc="Testing"):
        images = images.to(device)
        labels = labels.to(device)
        outputs = resnet18(images)
        # Predicted class = index of the largest logit along dim 1.
        preds = outputs.max(dim=1).indices
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

# Confusion matrix plus accuracy and class-frequency-weighted precision/recall/F1.
cm = confusion_matrix(all_labels, all_preds)
accuracy = accuracy_score(all_labels, all_preds)
precision, recall, f1, _ = precision_recall_fscore_support(
    all_labels,
    all_preds,
    average="weighted",
)

print(f"Test Accuracy: {accuracy:.4f}")
print(f"Precision: {precision:.4f}, Recall: {recall:.4f}, F1-score: {f1:.4f}")
print("Confusion Matrix:\n", cm)


In [None]:
# Plot training curves: loss on the left, accuracy on the right.

# Build the x-axis from the recorded history rather than num_epochs, so the
# plot still works if training stopped early or num_epochs was changed after
# the training cell ran (a length mismatch would make plt.plot raise).
epochs = np.arange(1, len(train_losses) + 1)

plt.figure(figsize=(12, 5))

# Left panel: training vs. validation loss per epoch.
plt.subplot(1, 2, 1)
plt.plot(epochs, train_losses, label="Train Loss")
plt.plot(epochs, val_losses, label="Validation Loss")
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()
plt.title("Loss curves")

# Right panel: training vs. validation accuracy per epoch.
plt.subplot(1, 2, 2)
plt.plot(epochs, train_accs, label="Train Acc")
plt.plot(epochs, val_accs, label="Validation Acc")
plt.xlabel("Epoch")
plt.ylabel("Accuracy")
plt.legend()
plt.title("Accuracy curves")

plt.show()


In [None]:
# TODO: Write baseline performance analysis (example)
"""
The ResNet18 baseline achieves good accuracy on Fashion-MNIST after fine-tuning only the final layer.
Training and validation loss curves indicate the model converges quickly without significant overfitting.
The confusion matrix shows that the model performs slightly worse on visually similar classes (e.g., Shirt vs. T-shirt/top),
but overall classification is strong for most categories.

Since only the final fully connected layer was fine-tuned, the feature extractor from ImageNet transfers well.
Future experiments could include unfreezing some deeper layers or adding mild data augmentation to further improve accuracy.
"""


In [None]:
# Persist the final baseline weights (state_dict only) in the same directory
# that holds the per-epoch checkpoints.
final_checkpoint = os.path.join(save_dir, "model_checkpoint.pt")
torch.save(obj=resnet18.state_dict(), f=final_checkpoint)
print(f"Baseline ResNet18 checkpoint saved to {final_checkpoint}")
