# Transfer Learning for Fashion-MNIST with ResNet50

This notebook implements transfer learning for classifying Fashion-MNIST images using a pretrained ResNet50 model. The dataset is in CSV format (`fashion-mnist_train.csv`, `fashion-mnist_test.csv`), with each row containing a `label` (0–9) and 784 pixel values (`pixel1`–`pixel784`) for a 28×28 grayscale image. We adapt ResNet50, train a custom head, fine-tune deeper layers, and visualize results.

## Objectives
- Load and validate the CSV dataset.
- Analyze dataset characteristics (class distribution, pixel stats, sample images).
- Preprocess images (resize to 224×224, convert to 3 channels, augment).
- Train a custom head on pretrained ResNet50, then fine-tune deeper layers.
- Display training and validation metrics per epoch for explainability.
- Visualize training progress and performance with plots.
- Optimize for a 16 GB GPU.

## Hardware
- **GPU**: 16 GB (e.g., NVIDIA Tesla T4 or RTX 4060 Ti 16 GB).
- **Optimizations**: Mixed precision training, batch size of 32, `pin_memory=True`, selective layer unfreezing.

## Why Per-Epoch Metrics?
Displaying training and validation loss, accuracy, and GPU memory usage per epoch:
- **Explainability**: Shows model convergence and performance trends in real-time.
- **Debugging**: Helps detect overfitting (large train-val accuracy gap), underfitting (high loss), or memory issues.
- **Optimization**: Confirms GPU memory stays within 16 GB (the ~8–10 GB figure used throughout this notebook is an estimate, not a measurement).
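As a rough sketch of the overfitting check described above, the helper below flags epochs where training accuracy exceeds validation accuracy by more than a chosen margin. The function name `flag_overfitting` and the 5-point threshold are illustrative assumptions, not values used elsewhere in this notebook.

```python
def flag_overfitting(train_accs, val_accs, gap_threshold=5.0):
    """Hypothetical helper: return 1-based epochs where train accuracy exceeds
    validation accuracy by more than gap_threshold percentage points."""
    flagged = []
    for epoch, (train_acc, val_acc) in enumerate(zip(train_accs, val_accs), start=1):
        if train_acc - val_acc > gap_threshold:
            flagged.append(epoch)
    return flagged


# Made-up per-epoch accuracies (in %), not results from this notebook
train_accs = [82.0, 88.5, 92.0, 95.5]
val_accs = [83.0, 87.0, 86.5, 87.0]
print(flag_overfitting(train_accs, val_accs))  # [3, 4]
```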

## Step 1: Setup and Imports

We install dependencies, import libraries, and configure the environment. Logging captures metrics and GPU memory usage, and random seeds ensure reproducibility.

In [None]:
# Install dependencies (run if needed)
!pip install torch torchvision pandas numpy matplotlib seaborn scikit-learn pillow

# Import libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import Dataset, DataLoader
from torch.cuda.amp import GradScaler, autocast
import torchvision.transforms as transforms
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix
import logging
import os
from datetime import datetime
from torchvision import models
from PIL import Image

# Set up logging
logging.basicConfig(
    filename=f'fashion_mnist_training_{datetime.now().strftime("%Y%m%d_%H%M%S")}.log',
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)

# Set random seed for reproducibility
torch.manual_seed(42)
np.random.seed(42)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logging.info(f"Using device: {device}")

# Log initial GPU memory
if device.type == "cuda":
    logging.info(f"GPU Memory Allocated: {torch.cuda.memory_allocated(device)/1e6:.2f} MB")
    logging.info(f"GPU Memory Cached: {torch.cuda.memory_reserved(device)/1e6:.2f} MB")
    print(f"Initial GPU Memory Allocated: {torch.cuda.memory_allocated(device)/1e6:.2f} MB")

# Class names for Fashion-MNIST
class_names = [
    'T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
    'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot'
]

# Enable inline plotting
%matplotlib inline
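
The cell above seeds PyTorch and NumPy only. A broader helper might also seed Python's `random` module and request deterministic cuDNN kernels. The sketch below assumes a hypothetical `seed_everything` function; even with these settings some GPU operations can remain nondeterministic, so treat it as improving reproducibility rather than guaranteeing it.

```python
import random

import numpy as np
import torch


def seed_everything(seed=42):
    """Hypothetical helper: seed Python, NumPy and PyTorch RNGs."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)  # seeds the CPU and all CUDA devices
    # Trade some speed for repeatable cuDNN convolution algorithm choices
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


seed_everything(42)
a = torch.rand(3)
seed_everything(42)
b = torch.rand(3)
print(torch.equal(a, b))  # True
```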

## Step 2: Dataset Loading and Validation

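We define a `FashionMNISTCSVDataset` class to load the CSV files (785 columns: 1 `label`, 784 pixels). Validation checks the column count, the presence of `label` and 784 `pixel*` columns, and the [0, 255] value range, so a malformed file fails early with a clear error. A sample of the data is displayed to verify structure.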

In [None]:
class FashionMNISTCSVDataset(Dataset):
    def __init__(self, csv_file, transform=None):
        self.data = pd.read_csv(csv_file)
        # Validate dataset format
        expected_columns = 785  # 1 label + 784 pixels
        if len(self.data.columns) != expected_columns:
            raise ValueError(f"Expected {expected_columns} columns, got {len(self.data.columns)}")
        if 'label' not in self.data.columns:
            raise ValueError("CSV must contain a 'label' column")
        pixel_cols = [col for col in self.data.columns if col.startswith('pixel')]
        if len(pixel_cols) != 784:
            raise ValueError(f"Expected 784 pixel columns, got {len(pixel_cols)}")
        pixel_values = self.data[pixel_cols].values
        if pixel_values.min() < 0 or pixel_values.max() > 255:
            raise ValueError("Pixel values must be in range [0, 255]")
        
        self.transform = transform
        self.labels = self.data['label'].values
        self.images = pixel_values.reshape(-1, 28, 28).astype(np.uint8)
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, idx):
        image = self.images[idx]
        label = self.labels[idx]
        image = Image.fromarray(image)  # 2-D uint8 array -> grayscale ('L') PIL image
        if self.transform:
            image = self.transform(image)
        return image, label

# Check CSV files exist
train_csv = 'fashion-mnist_train.csv'
test_csv = 'fashion-mnist_test.csv'
if not (os.path.exists(train_csv) and os.path.exists(test_csv)):
    raise FileNotFoundError("CSV files not found. Please check the file paths.")

# Display sample data
sample_data = pd.read_csv(train_csv, nrows=5)
print("Sample of training data (first 5 rows, first 10 columns):")
print(sample_data.iloc[:, :10])  # Show label and first 9 pixels
print(f"Training set size: {len(pd.read_csv(train_csv))} images")
print(f"Test set size: {len(pd.read_csv(test_csv))} images")
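
The class above validates by column count and the `pixel` prefix. A stricter check could require the exact names `pixel1`–`pixel784`. The sketch below runs such a check on a small synthetic DataFrame; `check_fashion_mnist_frame` is a hypothetical helper, and the synthetic pixels are random values, not real images.

```python
import numpy as np
import pandas as pd


def check_fashion_mnist_frame(df):
    """Hypothetical helper: raise ValueError unless df has 'label' plus pixel1..pixel784."""
    if 'label' not in df.columns:
        raise ValueError("missing 'label' column")
    pixel_cols = [f'pixel{i}' for i in range(1, 785)]
    missing = [c for c in pixel_cols if c not in df.columns]
    if missing:
        raise ValueError(f"{len(missing)} pixel columns missing, e.g. {missing[0]}")
    pixels = df[pixel_cols].to_numpy()
    if pixels.min() < 0 or pixels.max() > 255:
        raise ValueError("pixel values outside [0, 255]")
    if not df['label'].between(0, 9).all():
        raise ValueError("labels outside 0-9")
    # Row-major reshape, the same as reshape(-1, 28, 28) in the dataset class
    return df['label'].to_numpy(), pixels.reshape(-1, 28, 28).astype(np.uint8)


# Synthetic two-row frame standing in for the real CSV
rng = np.random.default_rng(0)
frame = pd.DataFrame(rng.integers(0, 256, size=(2, 784)),
                     columns=[f'pixel{i}' for i in range(1, 785)])
frame.insert(0, 'label', [3, 7])
labels, images = check_fashion_mnist_frame(frame)
print(labels, images.shape)  # [3 7] (2, 28, 28)
```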

## Step 3: Data Analysis

We analyze the training dataset to understand its properties:
- **Class Distribution**: Plots a histogram to check if classes are balanced (~6,000 images per class).
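- **Pixel Statistics**: Computes the global pixel mean and standard deviation (scaled to [0, 1]) for reference; the pipeline below normalizes with ImageNet statistics because the pretrained weights expect them.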
- **Sample Images**: Visualizes 10 images to confirm data integrity.

These plots provide insights into the dataset’s structure and quality.

In [None]:
def analyze_dataset(train_csv):
    data = pd.read_csv(train_csv)
    pixel_cols = [col for col in data.columns if col.startswith('pixel')]
    labels = data['label'].values
    images = data[pixel_cols].values.reshape(-1, 28, 28)
    
    # Plot class distribution
    plt.figure(figsize=(10, 6))
    sns.countplot(x=labels)
    plt.title('Class Distribution in Fashion-MNIST Training Set')
    plt.xticks(ticks=range(10), labels=class_names, rotation=45)
    plt.xlabel('Class')
    plt.ylabel('Count')
    plt.show()
    logging.info("Class distribution plot displayed")
    
    # Compute pixel statistics
    mean_pixel = images.mean() / 255.0
    std_pixel = images.std() / 255.0
    logging.info(f"Pixel mean: {mean_pixel:.4f}, Pixel std: {std_pixel:.4f}")
    print(f"Pixel Mean: {mean_pixel:.4f}, Pixel Std: {std_pixel:.4f}")
    
    # Display sample images
    fig, axes = plt.subplots(2, 5, figsize=(12, 5))
    for i, ax in enumerate(axes.flat):
        ax.imshow(images[i], cmap='gray')
        ax.set_title(class_names[labels[i]])
        ax.axis('off')
    plt.tight_layout()
    plt.show()
    logging.info("Sample images displayed")

# Run analysis
analyze_dataset(train_csv)
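
For reference, the class counts in `analyze_dataset` can be reproduced with `np.bincount`. Dividing the raw mean and std by 255, as the function does, gives the same result as computing them on pixels already scaled to [0, 1], because both statistics scale linearly with a positive factor. The sketch uses synthetic data so the expected values are easy to check by hand; `class_counts_and_pixel_stats` is a hypothetical name.

```python
import numpy as np


def class_counts_and_pixel_stats(labels, images, num_classes=10):
    """Hypothetical helper: per-class counts plus global pixel mean/std in [0, 1]."""
    counts = np.bincount(labels, minlength=num_classes)
    scaled = images.astype(np.float64) / 255.0
    return counts, scaled.mean(), scaled.std()


# Synthetic data: 20 images, 2 per class, first half black and second half white
labels = np.repeat(np.arange(10), 2)
images = np.zeros((20, 28, 28), dtype=np.uint8)
images[10:] = 255
counts, mean, std = class_counts_and_pixel_stats(labels, images)
print(counts)     # [2 2 2 2 2 2 2 2 2 2]
print(mean, std)  # 0.5 0.5
```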

## Step 4: Data Pipeline

We create data loaders to preprocess images:
- **Resize**: Scale 28×28 images to 224×224 for ResNet50.
- **Channel Duplication**: Convert grayscale to 3 channels using `Grayscale(num_output_channels=3)`.
- **Augmentation**: Apply random horizontal flips to improve generalization.
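- **Normalization**: Use ImageNet statistics, which the pretrained weights expect. Because the three channels are identical copies of the grayscale image, the red-channel values (`mean=0.485`, `std=0.229`) are reused for all three channels; the full ImageNet values are `mean=[0.485, 0.456, 0.406]`, `std=[0.229, 0.224, 0.225]`.
- **Batch Size**: 32, chosen to keep GPU memory well within 16 GB (estimated ~8–10 GB).
- **Optimization**: `pin_memory=True` speeds up host-to-GPU copies; `num_workers=1` loads batches in one background process (raise it if the GPU waits on data, since resizing to 224×224 with PIL is CPU-bound).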

We verify the batch shape to confirm preprocessing.
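
The normalization arithmetic can be checked by hand: `ToTensor` maps a pixel p in [0, 255] to x = p/255, and `Normalize` computes (x - 0.485)/0.229 on each channel, so a black pixel becomes about -2.118 and a white pixel about 2.249. The NumPy sketch below mimics these steps on a tiny array; it skips `Resize`, which changes spatial size but not this value mapping, and `to_normalized_rgb` is a hypothetical helper, not part of torchvision.

```python
import numpy as np

MEAN, STD = 0.485, 0.229  # ImageNet red-channel stats, reused for all 3 channels


def to_normalized_rgb(gray_uint8):
    """Hypothetical helper mimicking ToTensor + Grayscale(3) + Normalize.

    Takes a 2-D uint8 array (H, W) and returns a float array of shape (3, H, W)."""
    x = gray_uint8.astype(np.float32) / 255.0   # ToTensor: [0, 255] -> [0, 1]
    x = np.repeat(x[None, :, :], 3, axis=0)     # Grayscale(3): copy into 3 channels
    return (x - MEAN) / STD                     # Normalize: per-channel standardization


img = np.array([[0, 255]], dtype=np.uint8)  # one black and one white pixel
out = to_normalized_rgb(img)
print(out.shape)     # (3, 1, 2)
print(out[0, 0])     # values close to -2.118 and 2.249
```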

In [None]:
def get_data_loaders(train_csv, test_csv, batch_size=32):
    train_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.Grayscale(num_output_channels=3),
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.485, 0.485], std=[0.229, 0.229, 0.229])
    ])
    
    test_transform = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.Grayscale(num_output_channels=3),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.485, 0.485], std=[0.229, 0.229, 0.229])
    ])
    
    train_dataset = FashionMNISTCSVDataset(train_csv, transform=train_transform)
    test_dataset = FashionMNISTCSVDataset(test_csv, transform=test_transform)
    
    # Pinned host memory only helps when a GPU is present (PyTorch warns otherwise)
    pin = device.type == "cuda"
    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, num_workers=1, pin_memory=pin
    )
    test_loader = DataLoader(
        test_dataset, batch_size=batch_size, shuffle=False, num_workers=1, pin_memory=pin
    )
    
    logging.info(f"Data loaders created with batch size {batch_size}")
    return train_loader, test_loader

# Create data loaders
train_loader, test_loader = get_data_loaders(train_csv, test_csv, batch_size=32)

# Verify a batch
images, labels = next(iter(train_loader))
print(f"Batch shape: {images.shape}")  # Expect [32, 3, 224, 224]
print(f"Label shape: {labels.shape}")  # Expect [32]
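
Step 6 passes `test_loader` as the validation set, so the test set influences early stopping and checkpoint selection. One way to keep the test set untouched is to hold out part of the training data with `torch.utils.data.random_split`. In the sketch below, a `TensorDataset` of dummy tensors stands in for `FashionMNISTCSVDataset`, and the 10% split is an assumption. Because a `Subset` inherits its parent's transform, the validation indices are reused on a second dataset object (in the notebook, one built with `test_transform`) so validation images are not augmented.

```python
import torch
from torch.utils.data import Subset, TensorDataset, random_split

# Stand-in for FashionMNISTCSVDataset(train_csv, transform=train_transform)
full_train = TensorDataset(torch.zeros(100, 1), torch.arange(100) % 10)

val_fraction = 0.1  # assumption: hold out 10% of the training data
n_val = int(len(full_train) * val_fraction)
n_train = len(full_train) - n_val
# A seeded generator makes the split repeatable across runs
train_subset, val_subset = random_split(
    full_train, [n_train, n_val], generator=torch.Generator().manual_seed(42)
)

# Stand-in for a second dataset object built with test_transform (no augmentation)
eval_view = TensorDataset(torch.zeros(100, 1), torch.arange(100) % 10)
val_no_aug = Subset(eval_view, val_subset.indices)
print(len(train_subset), len(val_no_aug))  # 90 10
```

The resulting subsets could be wrapped in `DataLoader`s exactly like the existing loaders, leaving `test_loader` for a single final evaluation.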

## Step 5: Model Setup

We load a pretrained ResNet50, freeze its backbone to save GPU memory, and replace the fully connected layer with a custom head:
- **Input**: 2048 features from ResNet50’s global average-pooling layer (`model.fc.in_features`).
- **Head**: Linear(2048, 512) → ReLU → Dropout(0.5) → Linear(512, 10).
- **Device**: The model is moved to the selected device; mixed precision is applied later, inside the training loop.

The head architecture is displayed for clarity.

In [None]:
def get_model(num_classes=10):
    model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
    for param in model.parameters():
        param.requires_grad = False
    in_features = model.fc.in_features
    model.fc = nn.Sequential(
        nn.Linear(in_features, 512),
        nn.ReLU(),
        nn.Dropout(0.5),
        nn.Linear(512, num_classes)
    )
    model = model.to(device)
    logging.info("Model initialized with pretrained ResNet50 and custom head")
    return model

# Initialize model
model = get_model()
print("Custom head architecture:")
print(model.fc)
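
With the backbone frozen, only the head's parameters are trainable. Each `Linear(a, b)` layer holds a·b weights plus b biases, so the head has 2048·512 + 512 + 512·10 + 10 = 1,054,218 trainable parameters. The sketch below rebuilds the same head in isolation to confirm the count; running the same `sum` over `model.parameters()` after `get_model()` should give the same number if freezing worked.

```python
import torch.nn as nn

in_features, num_classes = 2048, 10  # ResNet50 fc input size, Fashion-MNIST classes
head = nn.Sequential(
    nn.Linear(in_features, 512),
    nn.ReLU(),
    nn.Dropout(0.5),
    nn.Linear(512, num_classes),
)
# Each Linear(a, b) has a*b weights plus b biases; ReLU and Dropout have none
n_params = sum(p.numel() for p in head.parameters() if p.requires_grad)
print(n_params)  # 1054218
```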

## Step 6: Head-Only Training

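We train only the custom head for up to 10 epochs, keeping the ResNet50 backbone frozen. The test set doubles as the validation set here (it drives the scheduler, early stopping, and checkpoint selection), so the reported validation accuracy is an optimistic estimate of true test performance: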
- **Optimizer**: Adam with learning rate 0.001, weight decay 1e-4.
- **Loss**: Cross-entropy for multi-class classification.
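- **Scheduler**: ReduceLROnPlateau multiplies the learning rate by 0.1 once validation loss has failed to improve for more than 3 consecutive epochs (`patience=3`).
- **Early Stopping**: Stops if validation accuracy doesn’t improve for 5 epochs.
- **Mixed Precision**: `torch.cuda.amp` runs most operations in float16, reducing activation memory (to an estimated ~8–10 GB in total) and speeding up training on GPUs with Tensor Cores.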
- **Metrics Display**: Prints training loss, training accuracy, validation loss, validation accuracy, and GPU memory per epoch.
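
To see what `patience=3` means in practice, the sketch below feeds `ReduceLROnPlateau` a validation loss that never improves, using a dummy parameter instead of a real model. The first value sets the baseline, and the learning rate is cut only after more than three further non-improving epochs, so the drop appears on the fifth call.

```python
import torch
from torch.optim.lr_scheduler import ReduceLROnPlateau

param = torch.nn.Parameter(torch.zeros(1))  # dummy parameter, no real training
optimizer = torch.optim.Adam([param], lr=1e-3)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3)

lrs = []
for val_loss in [0.50, 0.50, 0.50, 0.50, 0.50]:  # validation loss that never improves
    optimizer.step()         # no gradients, so parameters are left unchanged
    scheduler.step(val_loss)
    lrs.append(optimizer.param_groups[0]['lr'])
print(lrs)  # the first four entries stay at 1e-3; the fifth drops to about 1e-4
```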

### What to Monitor
- **Training Loss**: Should decrease steadily.
- **Validation Loss**: Should decrease but may rise if overfitting.
- **Accuracy Gap**: Large gap (train > val) suggests overfitting.
- **GPU Memory**: Should stay ~8–10 GB, well within 16 GB.

After training, we plot loss and accuracy curves to visualize convergence.
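
The early-stopping rule in `train_model` can be traced without a GPU. The hypothetical helper below replays it on a list of validation accuracies: a strict improvement resets the counter, a tie does not, and training stops once the counter reaches `patience`. The accuracies are made up for illustration.

```python
def early_stop_epoch(val_accs, patience=5):
    """Hypothetical helper replaying train_model's early-stopping rule.

    Returns (best_epoch, stop_epoch), both 1-based; stop_epoch is None
    if training would run through every epoch."""
    best_acc, best_epoch, counter = 0.0, None, 0
    for epoch, acc in enumerate(val_accs, start=1):
        if acc > best_acc:  # strict improvement: save checkpoint, reset counter
            best_acc, best_epoch, counter = acc, epoch, 0
        else:
            counter += 1
        if counter >= patience:
            return best_epoch, epoch
    return best_epoch, None


# Peaks at epoch 3, then five epochs without a strict improvement (epoch 6 only ties)
print(early_stop_epoch([85.0, 87.0, 88.0, 87.5, 87.9, 88.0, 86.0, 87.0, 90.0]))  # (3, 8)
```

Note that the 90.0 at epoch 9 is never reached, which is the usual trade-off of a fixed patience.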

In [None]:
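def train_epoch(model, loader, criterion, optimizer, scaler):
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0
    use_amp = device.type == "cuda"  # torch.cuda.amp autocast only applies on CUDA
    for images, labels in loader:
        # non_blocking pairs with pin_memory=True for asynchronous host-to-GPU copies
        images = images.to(device, non_blocking=True)
        labels = labels.to(device, non_blocking=True)
        optimizer.zero_grad()
        with autocast(enabled=use_amp):
            outputs = model(images)
            loss = criterion(outputs, labels)
        # Scale the loss to avoid float16 gradient underflow; step() unscales first
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        # Weight by batch size so a smaller final batch is averaged correctly
        running_loss += loss.item() * labels.size(0)
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
    epoch_loss = running_loss / total
    epoch_acc = 100 * correct / total
    return epoch_loss, epoch_acc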

def evaluate(model, loader, criterion):
    model.eval()  # disable dropout; BatchNorm uses its running statistics
    running_loss = 0.0
    correct = 0
    total = 0
    all_preds = []
    all_labels = []
    use_amp = device.type == "cuda"
    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device, non_blocking=True)
            labels = labels.to(device, non_blocking=True)
            with autocast(enabled=use_amp):
                outputs = model(images)
                loss = criterion(outputs, labels)
            running_loss += loss.item() * labels.size(0)
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())
    val_loss = running_loss / total
    val_acc = 100 * correct / total
    return val_loss, val_acc, all_preds, all_labels

def train_model(model, train_loader, test_loader, num_epochs=10, fine_tune=False):
    criterion = nn.CrossEntropyLoss()
    lr = 0.0001 if fine_tune else 0.001
    # Optimize only trainable parameters (the head, plus layer4 when fine-tuning)
    trainable_params = [p for p in model.parameters() if p.requires_grad]
    optimizer = optim.Adam(trainable_params, lr=lr, weight_decay=1e-4)
    scaler = GradScaler(enabled=device.type == "cuda")
    # `verbose` is deprecated in recent PyTorch; the LR column below shows reductions instead
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=3)
    ckpt_path = 'best_model_finetune.pth' if fine_tune else 'best_model_head.pth'
    phase = 'Fine-Tuning' if fine_tune else 'Head-Only'
    best_val_acc = 0.0
    patience = 5  # early-stopping patience (epochs without val-accuracy improvement)
    patience_counter = 0
    train_losses, train_accs = [], []
    val_losses, val_accs = [], []

    # Print header for metrics
    print(f"{'Epoch':<6} {'Train Loss':<12} {'Train Acc':<12} {'Val Loss':<12} "
          f"{'Val Acc':<12} {'LR':<10} {'Peak GPU (MB)':<15}")
    print("-" * 85)

    for epoch in range(num_epochs):
        if device.type == "cuda":
            torch.cuda.reset_peak_memory_stats(device)
        train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, scaler)
        val_loss, val_acc, _, _ = evaluate(model, test_loader, criterion)
        current_lr = optimizer.param_groups[0]['lr']  # LR used during this epoch
        scheduler.step(val_loss)
        train_losses.append(train_loss)
        train_accs.append(train_acc)
        val_losses.append(val_loss)
        val_accs.append(val_acc)

        # Peak allocation during this epoch (the current allocation would understate usage)
        gpu_memory = torch.cuda.max_memory_allocated(device) / 1e6 if device.type == "cuda" else 0.0
        print(f"{epoch+1:<6} {train_loss:<12.4f} {train_acc:<12.2f} {val_loss:<12.4f} "
              f"{val_acc:<12.2f} {current_lr:<10.1e} {gpu_memory:<15.2f}")

        # Log metrics
        logging.info(f"Epoch {epoch+1}/{num_epochs} - Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.2f}%")
        logging.info(f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.2f}%, LR: {current_lr:.1e}")
        if device.type == "cuda":
            logging.info(f"Peak GPU Memory Allocated: {gpu_memory:.2f} MB")

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), ckpt_path)
            logging.info(f"Best model saved to {ckpt_path}")
            patience_counter = 0
        else:
            patience_counter += 1

        if patience_counter >= patience:
            logging.info("Early stopping triggered")
            print("Early stopping triggered")
            break

    # Plot training curves (x-axis counts epochs from 1)
    epochs = range(1, len(train_losses) + 1)
    plt.figure(figsize=(12, 4))
    plt.subplot(1, 2, 1)
    plt.plot(epochs, train_losses, label='Train Loss')
    plt.plot(epochs, val_losses, label='Val Loss')
    plt.title(f"Loss Curves ({phase})")
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.subplot(1, 2, 2)
    plt.plot(epochs, train_accs, label='Train Accuracy')
    plt.plot(epochs, val_accs, label='Val Accuracy')
    plt.title(f"Accuracy Curves ({phase})")
    plt.xlabel('Epoch')
    plt.ylabel('Accuracy (%)')
    plt.legend()
    plt.tight_layout()
    plt.show()
    logging.info(f"Training curves plotted ({phase.lower()})")

    # Restore the best checkpoint so the model and the returned predictions
    # (used for the confusion matrix) reflect the best epoch, not the last one
    if os.path.exists(ckpt_path):
        model.load_state_dict(torch.load(ckpt_path, map_location=device))
    _, best_acc, preds, labels = evaluate(model, test_loader, criterion)
    logging.info(f"Evaluated best {phase.lower()} model: Val Acc {best_acc:.2f}%")

    return train_losses, train_accs, val_losses, val_accs, preds, labels

# Train head only
if device.type == "cuda":
    torch.cuda.empty_cache()
logging.info("Starting head-only training")
print("\nHead-Only Training:")
head_train_losses, head_train_accs, head_val_losses, head_val_accs, head_preds, head_labels = train_model(
    model, train_loader, test_loader, num_epochs=10, fine_tune=False
)

## Step 7: Confusion Matrix for Head-Only Training

We plot a confusion matrix to analyze the model’s performance on the test set after head-only training. This visualizes which classes are correctly classified and which are confused (e.g., Shirt vs. T-shirt/top).

In [None]:
def plot_confusion_matrix(preds, labels, title):
    cm = confusion_matrix(labels, preds)
    plt.figure(figsize=(10, 8))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=class_names, yticklabels=class_names)
    plt.title(title)
    plt.xlabel('Predicted')
    plt.ylabel('True')
    plt.show()
    logging.info(f"Confusion matrix plotted: {title}")

# Plot confusion matrix
plot_confusion_matrix(head_preds, head_labels, 'Confusion Matrix (Head-Only Training)')
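
Beyond the heatmap, per-class recall (the diagonal divided by each row sum, since scikit-learn puts true labels on rows) and the largest off-diagonal cell summarize the matrix numerically. The sketch below uses a toy 3-class matrix; with the real results you would pass `confusion_matrix(head_labels, head_preds)` and the full `class_names`. `per_class_report` is a hypothetical helper.

```python
import numpy as np


def per_class_report(cm, class_names):
    """Hypothetical helper: per-class recall and the most frequent confusion."""
    cm = np.asarray(cm)
    recall = cm.diagonal() / cm.sum(axis=1)  # rows are true labels
    off_diag = cm.copy()
    np.fill_diagonal(off_diag, 0)            # ignore correct predictions
    true_idx, pred_idx = np.unravel_index(off_diag.argmax(), off_diag.shape)
    worst = (class_names[true_idx], class_names[pred_idx], int(off_diag[true_idx, pred_idx]))
    return recall, worst


# Toy 3-class matrix (rows = true, columns = predicted), not real results
names = ['T-shirt/top', 'Trouser', 'Shirt']
cm = [[90, 0, 10],
      [1, 98, 1],
      [25, 0, 75]]
recall, worst = per_class_report(cm, names)
print(recall)  # recall of 0.9, 0.98 and 0.75
print(worst)   # ('Shirt', 'T-shirt/top', 25)
```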

## Step 8: Fine-Tuning

We load the best model from head-only training, unfreeze the `layer4` block of ResNet50, and fine-tune for up to 10 epochs with a lower learning rate (0.0001):
- **Purpose**: Adapt deeper features to Fashion-MNIST while preserving pretrained weights.
- **Metrics Display**: Prints the same per-epoch metrics (training/validation loss, accuracy, GPU memory) as head-only training.
- **Optimization**: Uses mixed precision and selective unfreezing to keep memory ~8–10 GB.

### What to Monitor
- **Loss Improvement**: Expect validation loss to decrease further compared to head-only training.
- **Accuracy Gain**: Validation accuracy should improve (e.g., from ~85–90% to ~90–93%).
- **Overfitting**: Watch for increasing validation loss or a larger train-val accuracy gap.
- **GPU Memory**: Should remain stable, slightly higher than head-only due to more trainable parameters.

Training curves are plotted to compare with head-only results.

In [None]:
def fine_tune_model(model):
    for param in model.layer4.parameters():
        param.requires_grad = True
    logging.info("Unfrozen layer4 for fine-tuning")
    return model

# Load best head-only model and fine-tune
model.load_state_dict(torch.load('best_model_head.pth', map_location=device))
model = fine_tune_model(model)
if device.type == "cuda":
    torch.cuda.empty_cache()
logging.info("Starting fine-tuning")
print("\nFine-Tuning:")
finetune_train_losses, finetune_train_accs, finetune_val_losses, finetune_val_accs, finetune_preds, finetune_labels = train_model(
    model, train_loader, test_loader, num_epochs=10, fine_tune=True
)
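
The effect of `fine_tune_model` on which parameters train can be checked on a small stand-in model. `TinyBackbone` below is hypothetical; it only mirrors ResNet50's attribute names `layer3`, `layer4`, and `fc`. It also shows passing only trainable parameters to Adam, which makes explicit what the optimizer updates.

```python
import torch
import torch.nn as nn


class TinyBackbone(nn.Module):
    """Hypothetical stand-in with ResNet50-style attribute names."""

    def __init__(self):
        super().__init__()
        self.layer3 = nn.Linear(8, 8)
        self.layer4 = nn.Linear(8, 8)
        self.fc = nn.Linear(8, 10)


model = TinyBackbone()
for p in model.parameters():         # freeze everything, as get_model does
    p.requires_grad = False
for p in model.fc.parameters():      # the replacement head is trainable
    p.requires_grad = True
for p in model.layer4.parameters():  # fine_tune_model unfreezes layer4
    p.requires_grad = True

trainable = [name for name, p in model.named_parameters() if p.requires_grad]
print(trainable)  # ['layer4.weight', 'layer4.bias', 'fc.weight', 'fc.bias']

# Hand the optimizer only the parameters that should change
optimizer = torch.optim.Adam(
    (p for p in model.parameters() if p.requires_grad), lr=1e-4, weight_decay=1e-4
)
n_opt = sum(p.numel() for group in optimizer.param_groups for p in group['params'])
print(n_opt)  # 162 = (8*8 + 8) + (8*10 + 10)
```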

## Step 9: Confusion Matrix for Fine-Tuning

We plot a confusion matrix for the fine-tuned model to assess improvements in classification performance compared to head-only training.

In [None]:
plot_confusion_matrix(finetune_preds, finetune_labels, 'Confusion Matrix (Fine-Tuning)')
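
To quantify the improvement rather than compare two heatmaps by eye, per-class recall can be differenced between the two confusion matrices. The matrices below are toy values, not results from this notebook; with real outputs you would use `confusion_matrix(head_labels, head_preds)` and `confusion_matrix(finetune_labels, finetune_preds)`. `recall_change` is a hypothetical helper.

```python
import numpy as np


def recall_change(cm_before, cm_after, class_names):
    """Hypothetical helper: per-class recall change (after - before), most improved first."""
    def recall(cm):
        cm = np.asarray(cm, dtype=float)
        return cm.diagonal() / cm.sum(axis=1)

    delta = recall(cm_after) - recall(cm_before)
    order = np.argsort(-delta)  # descending by improvement
    return [(class_names[i], round(float(delta[i]), 3)) for i in order]


names = ['T-shirt/top', 'Trouser', 'Shirt']
cm_head = [[90, 0, 10], [1, 98, 1], [25, 0, 75]]       # toy values, not real results
cm_finetune = [[92, 0, 8], [1, 98, 1], [15, 0, 85]]
print(recall_change(cm_head, cm_finetune, names))
# [('Shirt', 0.1), ('T-shirt/top', 0.02), ('Trouser', 0.0)]
```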

## Step 10: Summary and Experimentation

### Results
- **Head-Only Training**:
  - Expected validation accuracy: ~85–90%.
  - Only the custom head was trained, leveraging pretrained features.
- **Fine-Tuning**:
  - Expected validation accuracy: ~90–93%.
  - Improved by adapting `layer4` features to Fashion-MNIST.
- **GPU Usage**:
  - Expected peak memory: ~8–10 GB on a 16 GB GPU with mixed precision, batch size 32, and selective unfreezing.
  - Confirm the actual figure on your hardware with the per-epoch memory column and the log file.

### Visualizations
- **Class Distribution**: Confirms balanced classes (~6,000 per class).
- **Sample Images**: Verifies data integrity and class representation.
- **Training Curves**: Shows loss and accuracy trends for head-only and fine-tuning phases.
- **Confusion Matrices**: Highlights misclassifications and improvements post-fine-tuning.

### Per-Epoch Metrics
The displayed metrics (training/validation loss, accuracy, GPU memory) per epoch provide real-time insights:
- **Loss Trends**: Decreasing loss indicates learning; training loss that keeps falling while validation loss rises suggests overfitting.
- **Accuracy Trends**: Rising accuracy shows improvement; large train-val gaps indicate overfitting.
- **GPU Memory**: Stable usage (~8–10 GB) confirms optimization.

Check the log file (`fashion_mnist_training_YYYYMMDD_HHMMSS.log`) for detailed records.

### Experimentation Ideas
Based on observed metrics, try the following:
- **Augmentation**: Add `transforms.RandomRotation(10)` to `train_transform` if the model overfits (large train-val gap); augmentation acts as a regularizer:
  ```python
  train_transform = transforms.Compose([
      transforms.Resize((224, 224)),
      transforms.Grayscale(num_output_channels=3),
      transforms.RandomHorizontalFlip(p=0.5),
      transforms.RandomRotation(10),
      transforms.ToTensor(),
      transforms.Normalize(mean=[0.485, 0.485, 0.485], std=[0.229, 0.229, 0.229])
  ])
  ```
- **Dropout**: Increase to 0.7 in `get_model` if overfitting (large train-val gap), or reduce to 0.3 if underfitting:
  ```python
  model.fc = nn.Sequential(
      nn.Linear(in_features, 512),
      nn.ReLU(),
      nn.Dropout(0.3),  # Adjust here
      nn.Linear(512, num_classes)
  )
  ```
- **Weight Decay**: Test 5e-4 or 1e-3 in `train_model` to reduce overfitting:
  ```python
  optimizer = optim.Adam(model.parameters(), lr=0.001 if not fine_tune else 0.0001, weight_decay=5e-4)
  ```
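- **Image Size**: Use 112×112 in both `train_transform` and `test_transform` to save memory if GPU usage is high (the pretrained features were learned at 224×224, so accuracy may drop):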
  ```python
  transforms.Resize((112, 112))
  ```
- **Batch Size**: Reduce to 16 if memory errors occur:
  ```python
  train_loader, test_loader = get_data_loaders(train_csv, test_csv, batch_size=16)
  ```
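
As a rough rule of thumb for the image-size and batch-size suggestions, activation memory grows about linearly with batch size and with the number of input pixels, so halving the side length cuts it to roughly a quarter. The sketch below encodes that assumption; `relative_activation_cost` is a hypothetical helper that ignores weights, optimizer state, and allocator overhead, so real savings will be smaller than the ratio suggests.

```python
def relative_activation_cost(size, batch, ref_size=224, ref_batch=32):
    """Hypothetical helper: activation-memory ratio vs. the 224x224, batch-32 baseline.

    Assumes memory scales linearly with batch size and with pixel count (size**2)."""
    return (batch / ref_batch) * (size / ref_size) ** 2


print(relative_activation_cost(112, 32))  # 0.25
print(relative_activation_cost(224, 16))  # 0.5
print(relative_activation_cost(112, 16))  # 0.125
```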

### Debugging Tips
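- **High or Rising Validation Loss**: Lower the learning rate or add regularization (dropout, weight decay, augmentation); increasing `patience` in `train_model` only helps if the loss is still slowly improving.
- **Low Accuracy**: Train for more epochs (`num_epochs=15`); if training accuracy is high but validation accuracy lags, add augmentation instead.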
- **Memory Issues**: Check GPU memory in per-epoch display; reduce batch size or image size if >12 GB.

The best models are saved as `best_model_head.pth` and `best_model_finetune.pth` for further evaluation or deployment.