In [3]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchmetrics
from torchvision import datasets, models
from torchvision import transforms as T
from torch.utils.data import DataLoader
from tqdm import tqdm
from torch.utils.tensorboard import SummaryWriter
from torchsummary import summary

import matplotlib.pyplot as plt
import numpy as np

## CNN Model Structure

In [2]:
# Define custom CNN class
class SimpleCNN(nn.Module):
    """Small two-stage convolutional classifier for 3-channel 32x32 images.

    Each stage is conv(3x3, padding=1) -> ReLU -> 2x2 max-pool, so the spatial
    resolution is halved twice (32 -> 16 -> 8). The first fully connected
    layer is therefore sized for 64 * 8 * 8 input features.

    Args:
        num_classes: Number of logits produced by the final linear layer.
    """

    def __init__(self, num_classes=10):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(64 * 8 * 8, 128)
        self.fc2 = nn.Linear(128, num_classes)

    def forward(self, x):
        """Compute class logits.

        Args:
            x: Image tensor of shape (batch, 3, 32, 32). A single unbatched
                image of shape (3, 32, 32) is treated as a batch of one.

        Returns:
            Tensor of shape (batch, num_classes) holding raw logits.

        Raises:
            RuntimeError: If the spatial size of ``x`` does not yield
                64 * 8 * 8 features per sample after the two pooling stages.
        """
        # Treat a lone (C, H, W) image as a batch of one.
        if x.dim() == 3:
            x = x.unsqueeze(0)

        x = self.pool(torch.relu(self.conv1(x)))
        x = self.pool(torch.relu(self.conv2(x)))
        # Flatten per sample. The previous view(-1, 64 * 8 * 8) could silently
        # fold several samples into one row when the input was not 32x32;
        # keeping the batch dimension makes fc1 raise a shape error instead.
        x = torch.flatten(x, 1)
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x


# Define function for loading pre-trained model (ResNet-18)
def load_resnet18(num_classes=10, pretrained=True):
    """Build a ResNet-18 whose final layer is replaced for ``num_classes`` outputs.

    Args:
        num_classes: Number of outputs of the replacement ``fc`` layer.
        pretrained: When True (the default), start from ImageNet weights;
            when False, use randomly initialised weights.

    Returns:
        The ResNet-18 module with a freshly initialised ``fc`` head.
    """
    # Newer torchvision releases deprecate the boolean `pretrained` flag in
    # favour of `weights=`. Use it when the enum exists and fall back to the
    # legacy flag on older installations.
    weights_enum = getattr(models, "ResNet18_Weights", None)
    if weights_enum is not None:
        weights = weights_enum.IMAGENET1K_V1 if pretrained else None
        model = models.resnet18(weights=weights)
    else:
        model = models.resnet18(pretrained=pretrained)

    # Swap the 1000-way ImageNet head for one matching this task.
    model.fc = nn.Linear(model.fc.in_features, num_classes)
    return model

## Utils

### 1. Helper Classes: `AverageMeter` class for tracking loss and other metrics


In [10]:
class AverageMeter:
    """Accumulates a weighted running total and exposes its mean via ``avg``."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Drop everything accumulated so far."""
        self.sum, self.count = 0, 0

    def update(self, value, n=1):
        """Add ``value`` to the total with weight ``n`` (e.g. the batch size)."""
        weighted = value * n
        self.sum += weighted
        self.count += n

    @property
    def avg(self):
        """Weighted mean of all recorded values; 0 while nothing is recorded."""
        if self.count > 0:
            return self.sum / self.count
        return 0

## Transform and load data

In [5]:
# Advanced preprocessing and data augmentation
# Training pipeline: light geometric augmentation, then tensor conversion and
# per-channel normalisation with mean 0.5 / std 0.5.
transform_train = T.Compose(
    [
        T.RandomHorizontalFlip(),
        T.RandomRotation(degrees=10),
        T.RandomResizedCrop(size=32, scale=(0.8, 1.0)),
        T.ToTensor(),
        T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

# Validation pipeline: no augmentation, same tensor conversion and normalisation.
transform_val = T.Compose(
    [
        T.ToTensor(),
        T.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]
)

In [None]:
# Load CIFAR-10 dataset
# CIFAR-10 is downloaded into ./data on first use.
train_dataset = datasets.CIFAR10(
    root="./data",
    train=True,
    download=True,
    transform=transform_train,
)
val_dataset = datasets.CIFAR10(
    root="./data",
    train=False,
    download=True,
    transform=transform_val,
)

# Only the training loader shuffles; validation order stays fixed.
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
val_loader = DataLoader(dataset=val_dataset, batch_size=64, shuffle=False)

## Model

In [None]:
# Choose model (custom CNN or ResNet-18)
model_type = "custom"  # change to "resnet" for ResNet-18
num_classes = 10
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Fail loudly on an unrecognised model_type instead of silently falling
# through to ResNet-18 (e.g. on a typo such as "Custom").
if model_type == "custom":
    model = SimpleCNN(num_classes=num_classes)
elif model_type == "resnet":
    model = load_resnet18(num_classes=num_classes)
else:
    raise ValueError(
        f'Unknown model_type {model_type!r}; expected "custom" or "resnet".'
    )

# Both architectures are moved to the device and summarised the same way.
model = model.to(device)
summary(model, (3, 32, 32))


# Define criterion, optimizer, and metrics
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
# Separate metric objects so training and validation state never mix.
train_accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes).to(device)
val_accuracy = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes).to(device)

## Training and Validation Functions


In [None]:
# Training and validation functions
def train_one_epoch(model, dataloader, criterion, optimizer, device, epoch, accuracy_metric):
    """Run one optimisation pass over ``dataloader``.

    Args:
        model: Network to train; switched to training mode here.
        dataloader: Iterable yielding ``(inputs, targets)`` batches.
        criterion: Loss callable applied to ``(outputs, targets)``.
        optimizer: Optimiser stepped once per batch.
        device: Device each batch is moved to before the forward pass.
        epoch: Zero-based epoch index, used only for the progress-bar label.
        accuracy_metric: Metric object providing ``reset``/``update``/``compute``;
            it is reset at the start of the epoch.

    Returns:
        Tuple ``(mean_loss, accuracy)``: the sample-weighted mean loss and the
        value of ``accuracy_metric.compute()`` as a Python number.
    """
    model.train()
    accuracy_metric.reset()
    running_loss = AverageMeter()
    batches = tqdm(dataloader, desc=f"Epoch {epoch+1} [Training]", leave=False)

    for inputs, targets in batches:
        inputs = inputs.to(device)
        targets = targets.to(device)

        # Forward pass and loss
        logits = model(inputs)
        batch_loss = criterion(logits, targets)

        # Backward pass and parameter update
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        # Weight the batch loss by batch size so the epoch mean is per-sample.
        running_loss.update(batch_loss.item(), inputs.shape[0])
        accuracy_metric.update(torch.argmax(logits, dim=1), targets)

        batches.set_postfix(
            loss=running_loss.avg,
            accuracy=accuracy_metric.compute().item(),
        )

    return running_loss.avg, accuracy_metric.compute().item()

In [None]:
def validate(model, dataloader, criterion, device, epoch, accuracy_metric):
    """Evaluate ``model`` on ``dataloader`` with gradient tracking disabled.

    Args:
        model: Network to evaluate; switched to evaluation mode here.
        dataloader: Iterable yielding ``(inputs, targets)`` batches.
        criterion: Loss callable applied to ``(outputs, targets)``.
        device: Device each batch is moved to before the forward pass.
        epoch: Zero-based epoch index, used only for the progress-bar label.
        accuracy_metric: Metric object providing ``reset``/``update``/``compute``;
            it is reset at the start of the pass.

    Returns:
        Tuple ``(mean_loss, accuracy)``: the sample-weighted mean loss and the
        value of ``accuracy_metric.compute()`` as a Python number.
    """
    model.eval()
    accuracy_metric.reset()
    running_loss = AverageMeter()
    batches = tqdm(dataloader, desc=f"Epoch {epoch+1} [Validation]", leave=False)

    with torch.no_grad():
        for inputs, targets in batches:
            inputs = inputs.to(device)
            targets = targets.to(device)

            logits = model(inputs)
            batch_loss = criterion(logits, targets)

            # Weight the batch loss by batch size so the mean is per-sample.
            running_loss.update(batch_loss.item(), inputs.shape[0])
            accuracy_metric.update(torch.argmax(logits, dim=1), targets)

            batches.set_postfix(loss=running_loss.avg)

    return running_loss.avg, accuracy_metric.compute().item()

## Training Script

In [None]:
# Initialize TensorBoard  
writer = SummaryWriter()

# Training loop
num_epochs = 20
best_val_acc = 0.0
try:
    for epoch in range(num_epochs):
        train_loss, train_acc = train_one_epoch(
            model, train_loader, criterion, optimizer, device, epoch, train_accuracy
        )
        val_loss, val_acc = validate(model, val_loader, criterion, device, epoch, val_accuracy)

        # Log metrics to TensorBoard
        writer.add_scalar('Loss/Train', train_loss, epoch)
        writer.add_scalar('Accuracy/Train', train_acc, epoch)
        writer.add_scalar('Loss/Validation', val_loss, epoch)
        writer.add_scalar('Accuracy/Validation', val_acc, epoch)

        print(f"Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}, Train Accuracy: {train_acc:.4f}, "
              f"Val Loss: {val_loss:.4f}, Val Accuracy: {val_acc:.4f}")

        # Checkpoint only when validation accuracy strictly improves.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), "best_model.pth")
            print(f"Best model saved at epoch {epoch+1} with validation accuracy: {best_val_acc:.4f}")
finally:
    # Close the writer even if training fails or the cell is interrupted,
    # so the scalars logged so far are not lost.
    writer.close()

In [None]:
# Plot training and validation loss and accuracy
def plot_metrics(metric_values, title, xlabel, ylabel):
    """Draw a single line plot of ``metric_values`` and show it.

    Args:
        metric_values: Sequence of values to plot against their index.
        title: Figure title.
        xlabel: Label for the x axis.
        ylabel: Label for the y axis; also used as the legend entry.
    """
    fig, ax = plt.subplots(figsize=(10, 5))
    ax.plot(metric_values, label=f"{ylabel}")
    ax.set_title(title)
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.legend()
    plt.show()

# Data visualization function
def show_sample_predictions(model, dataloader, class_names, device, num_images=8):
    """Plot the first images of one batch with their true and predicted labels.

    Args:
        model: Classifier to run; it is switched to evaluation mode.
        dataloader: Loader from which the first ``(images, labels)`` batch is taken.
        class_names: Sequence mapping a class index to a display name.
        device: Device on which inference runs.
        num_images: Maximum number of images to display (default 8). Fewer
            are shown when the batch is smaller.
    """
    model.eval()
    with torch.no_grad():
        X_batch, y_batch = next(iter(dataloader))
        X_batch, y_batch = X_batch.to(device), y_batch.to(device)

        outputs = model(X_batch)
        preds = outputs.argmax(dim=1)

    # Never index past the end of the batch: the previous hard-coded range(8)
    # raised IndexError for batches with fewer than eight samples.
    n_shown = min(num_images, X_batch.size(0))
    if n_shown <= 0:
        return

    # Move only what gets displayed to the host, as plain Python / NumPy data.
    images = X_batch[:n_shown].cpu().numpy()
    true_idx = y_batch[:n_shown].tolist()
    pred_idx = preds[:n_shown].tolist()

    n_cols = 4
    n_rows = (n_shown + n_cols - 1) // n_cols  # ceiling division
    plt.figure(figsize=(12, 4 * n_rows))
    for i in range(n_shown):
        plt.subplot(n_rows, n_cols, i + 1)
        # Un-normalize for display (inverse of mean 0.5 / std 0.5), then clip
        # so floating-point round-off cannot leave the valid [0, 1] range.
        image = np.transpose(images[i], (1, 2, 0)) * 0.5 + 0.5
        plt.imshow(np.clip(image, 0.0, 1.0))
        plt.title(f"True: {class_names[true_idx[i]]}, Pred: {class_names[pred_idx[i]]}")
        plt.axis("off")
    plt.show()

# Demo: visualise predictions for the first batch of validation data
class_names = train_dataset.classes  # CIFAR-10 class names
show_sample_predictions(
    model=model,
    dataloader=val_loader,
    class_names=class_names,
    device=device,
)

In [None]:
# Load the TensorBoard notebook extension and open the dashboard inline.
# The dashboard reads from ./runs; the SummaryWriter in this notebook is
# created without arguments (assumed to log under ./runs by default — confirm).
%load_ext tensorboard 
%tensorboard --logdir runs