Install dependencies and imports

In [None]:
!pip install thop
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
import time
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import confusion_matrix, classification_report, precision_recall_fscore_support, roc_curve, auc
from thop import profile
import pandas as pd

# Device configuration: run on a CUDA GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

Hyperparameters

In [None]:
# Hyperparameters
num_epochs = 50        # number of passes over the training set
batch_size = 128       # mini-batch size shared by the train and test loaders
learning_rate = 0.1    # initial SGD step size (ReduceLROnPlateau may halve it later)
momentum = 0.9         # SGD momentum coefficient
weight_decay = 5e-4    # L2 penalty applied inside the SGD update

Define BasicBlock for ResNet

In [None]:
# Basic block for ResNet
class BasicBlock(nn.Module):
    """Two 3x3 conv/BN layers with a residual connection.

    Computes ``relu(bn2(conv2(relu(bn1(conv1(x))))) + shortcut(x))``. The
    shortcut is an empty ``nn.Sequential`` (identity) unless the block changes
    the stride or the channel count. In that case a strided 1x1 convolution
    followed by BatchNorm projects ``x`` to the output shape.

    Args:
        in_channels: Channels of the input feature map.
        out_channels: Channels produced by both 3x3 convolutions.
        stride: Stride of the first convolution and of the projection shortcut.
    """

    # Output channels = out_channels * expansion; ResNet reads this to size later layers.
    expansion = 1

    def __init__(self, in_channels, out_channels, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3,
                               stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3,
                               stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_channels)

        # Project the input only when its shape differs from the block's output.
        needs_projection = stride != 1 or in_channels != out_channels
        if needs_projection:
            projection = [
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels),
            ]
        else:
            projection = []  # an empty Sequential passes its input through unchanged
        self.shortcut = nn.Sequential(*projection)

    def forward(self, x):
        h = F.relu(self.bn1(self.conv1(x)))
        h = self.bn2(self.conv2(h))
        return F.relu(h + self.shortcut(x))

ResNet Architecture


In [None]:
# ResNet Architecture
class ResNet(nn.Module):
    """CIFAR-style ResNet: a 3x3 stem conv followed by four residual stages.

    Args:
        block: Residual block class, called as ``block(in_channels, out_channels, stride)``
            and exposing an integer ``expansion`` attribute (e.g. ``BasicBlock``).
        num_blocks: Sequence of four ints, the number of blocks in each stage.
        num_classes: Output size of the final linear classifier.
    """

    def __init__(self, block, num_blocks, num_classes=10):
        super().__init__()
        # Channel count fed to the next block; updated by _make_layer as stages are built.
        self.in_channels = 64

        # Initial convolution before the blocks (stride 1, no max-pool).
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)

        # ResNet stages; stages 2-4 halve the spatial size via their first block.
        self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)

        # Final classification layer
        self.linear = nn.Linear(512 * block.expansion, num_classes)

    def _make_layer(self, block, out_channels, num_blocks, stride):
        """Build one stage; only its first block uses ``stride``, the rest use 1."""
        layers = []
        for block_stride in [stride] + [1] * (num_blocks - 1):
            layers.append(block(self.in_channels, out_channels, block_stride))
            self.in_channels = out_channels * block.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        # Global average pooling. For 32x32 inputs layer4 yields 4x4 maps, so this
        # equals a 4x4 average pool. Unlike a fixed kernel, it also keeps the
        # classifier input at 512 * expansion features for other input sizes.
        out = F.adaptive_avg_pool2d(out, 1)
        out = torch.flatten(out, 1)
        return self.linear(out)

# ResNet-18
def ResNet18(num_classes=10):
    """Return a CIFAR-style ResNet-18: four BasicBlock stages of [2, 2, 2, 2] blocks.

    Args:
        num_classes: Output size of the classifier (default 10, matching CIFAR-10).
    """
    return ResNet(BasicBlock, [2, 2, 2, 2], num_classes=num_classes)

Data Loading and Preprocessing


In [None]:
# Load and normalize CIFAR10
print("Preparing data...")

# Per-channel (RGB) mean/std used to normalize every image; these are the
# commonly quoted CIFAR-10 training-set statistics.
cifar10_mean = (0.4914, 0.4822, 0.4465)
cifar10_std = (0.2470, 0.2435, 0.2616)

# Training pipeline: random 4-pixel-padded crop + horizontal flip, then normalize.
transform_train = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(cifar10_mean, cifar10_std),
])

# Evaluation pipeline: no augmentation, identical normalization.
transform_test = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(cifar10_mean, cifar10_std),
])

# Loader settings shared by both splits; only shuffling differs.
loader_kwargs = dict(batch_size=batch_size, num_workers=2)

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
trainloader = DataLoader(trainset, shuffle=True, **loader_kwargs)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
testloader = DataLoader(testset, shuffle=False, **loader_kwargs)

# Class names in CIFAR-10 label-index order (index i -> classes[i]).
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')


Model Setup


In [None]:
# Model setup
import copy

model = ResNet18().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=momentum, weight_decay=weight_decay)
# Halve the LR after 5 epochs without improvement in the stepped metric.
# The `verbose` flag is deprecated for LR schedulers in recent PyTorch, so it
# is not passed; the training loop prints the LR every epoch instead.
scheduler = ReduceLROnPlateau(optimizer, 'min', patience=5, factor=0.5)

# Calculate parameters and MACs.
# thop.profile registers `total_ops` / `total_params` buffers on the modules it
# visits. Depending on the thop version, some stay attached, which would pollute
# model.state_dict() and break strict loading of the saved checkpoint. So we
# profile a throwaway copy and leave the trained model untouched.
# NOTE: thop counts multiply-accumulate operations (MACs), not FLOPs (~2 FLOPs per MAC).
# The variable keeps the name `flops` because later cells reference it.
print("Calculating model parameters and MACs...")
dummy_input = torch.randn(1, 3, 32, 32).to(device)
flops, params = profile(copy.deepcopy(model), inputs=(dummy_input,), verbose=False)
print(f"Number of parameters: {params:,.0f}")
print(f"Number of MACs: {flops:,.0f}")

# Lists to store per-epoch metrics (filled by the training loop)
train_losses = []
train_accuracies = []
test_losses = []
test_accuracies = []
epoch_times = []
batch_times = []
learning_rates = []


Training Function


In [None]:
# Training function
def train(epoch):
    """Run one training epoch over ``trainloader`` and return summary statistics.

    Uses the module-level ``model``, ``trainloader``, ``optimizer``, ``criterion``
    and ``device``. Every 100 mini-batches it prints the mean loss of those 100
    batches and the running accuracy so far this epoch.

    Args:
        epoch: Zero-based epoch index, used only in log messages.

    Returns:
        ``(train_loss, train_accuracy, epoch_time, avg_batch_time)``:
        - train_loss: mean of the per-batch losses over the whole epoch.
        - train_accuracy: accuracy in percent.
        - epoch_time: wall-clock seconds for the epoch.
        - avg_batch_time: mean seconds per batch, measured from after the batch
          is fetched from the loader.
    """
    model.train()
    epoch_loss = 0.0    # sum of batch losses over the whole epoch
    window_loss = 0.0   # sum of batch losses since the last progress print
    correct = 0
    total = 0
    epoch_start_time = time.time()
    batch_times_epoch = []

    for batch_idx, (inputs, targets) in enumerate(trainloader):
        batch_start_time = time.time()
        inputs, targets = inputs.to(device), targets.to(device)

        optimizer.zero_grad()
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        batch_loss = loss.item()
        epoch_loss += batch_loss
        window_loss += batch_loss
        _, predicted = outputs.max(1)
        total += targets.size(0)
        correct += predicted.eq(targets).sum().item()

        batch_times_epoch.append(time.time() - batch_start_time)

        if batch_idx % 100 == 99:    # Print every 100 mini-batches
            print(f'Epoch: {epoch+1}, Batch: {batch_idx+1}, Loss: {window_loss/100:.3f}, Acc: {100.*correct/total:.3f}%')
            # Only the progress window is reset; epoch_loss keeps accumulating.
            window_loss = 0.0

    epoch_time = time.time() - epoch_start_time
    num_batches = len(batch_times_epoch)
    # Guard against an empty loader instead of dividing by zero.
    avg_batch_time = sum(batch_times_epoch) / num_batches if num_batches else 0.0

    train_loss = epoch_loss / num_batches if num_batches else 0.0
    train_accuracy = 100. * correct / total if total else 0.0

    print(f'Epoch {epoch+1} complete. Time taken: {epoch_time:.2f}s, Avg Batch Time: {avg_batch_time:.4f}s')

    return train_loss, train_accuracy, epoch_time, avg_batch_time


Testing Function


In [None]:
# Testing function
def test(epoch):
    """Evaluate ``model`` on ``testloader`` and compute summary metrics.

    Uses the module-level ``model``, ``testloader``, ``criterion`` and ``device``.
    The number of classes is taken from the width of the model's output.

    Args:
        epoch: Zero-based epoch index (unused; kept for symmetry with ``train``).

    Returns:
        ``(test_loss, test_accuracy, cm, precision, recall, f1, fpr, tpr, roc_auc)``:
        - test_loss: mean of the per-batch losses.
        - test_accuracy: accuracy in percent.
        - cm: a num_classes x num_classes confusion matrix (rows = true, columns = predicted).
        - precision, recall, f1: per-class arrays of length num_classes.
        - fpr, tpr, roc_auc: one-vs-rest ROC data, as dicts keyed by class index.
    """
    model.eval()
    test_loss = 0.0
    correct = 0
    total = 0
    all_preds = []
    all_targets = []
    all_probs = []  # per-batch softmax arrays, concatenated after the loop (for ROC)

    with torch.no_grad():
        for inputs, targets in testloader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            loss = criterion(outputs, targets)

            test_loss += loss.item()
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()

            # Store predictions and targets for confusion matrix and per-class metrics
            all_preds.extend(predicted.cpu().numpy())
            all_targets.extend(targets.cpu().numpy())
            all_probs.append(F.softmax(outputs, dim=1).cpu().numpy())

    test_loss = test_loss / len(testloader)
    test_accuracy = 100. * correct / total

    print(f'Test Loss: {test_loss:.3f}, Test Acc: {test_accuracy:.3f}%')

    probs = np.concatenate(all_probs, axis=0)
    num_classes = probs.shape[1]
    # Pin the label set so the outputs always cover every class, even if a
    # class never appears among the targets or predictions this epoch.
    labels = list(range(num_classes))

    cm = confusion_matrix(all_targets, all_preds, labels=labels)
    precision, recall, f1, _ = precision_recall_fscore_support(
        all_targets, all_preds, labels=labels, average=None, zero_division=0)

    # One-vs-rest ROC curve per class.
    targets_arr = np.asarray(all_targets)
    fpr, tpr, roc_auc = {}, {}, {}
    for i in labels:
        binary_targets = (targets_arr == i).astype(int)
        fpr[i], tpr[i], _ = roc_curve(binary_targets, probs[:, i])
        roc_auc[i] = auc(fpr[i], tpr[i])

    return test_loss, test_accuracy, cm, precision, recall, f1, fpr, tpr, roc_auc

Main Training Loop

In [None]:
# Main training loop
print("Starting training...")

for epoch in range(num_epochs):
    # Record the LR used for this epoch before the scheduler can lower it.
    current_lr = optimizer.param_groups[0]['lr']
    learning_rates.append(current_lr)

    # One pass over the training data, then record its statistics.
    train_loss, train_accuracy, epoch_time, avg_batch_time = train(epoch)
    for history, value in ((train_losses, train_loss),
                           (train_accuracies, train_accuracy),
                           (epoch_times, epoch_time),
                           (batch_times, avg_batch_time)):
        history.append(value)

    # Evaluate. cm/precision/recall/f1/fpr/tpr/roc_auc are overwritten each epoch,
    # so after the loop they hold the final epoch's values used by the plots.
    (test_loss, test_accuracy, cm, precision, recall, f1,
     fpr, tpr, roc_auc) = test(epoch)
    test_losses.append(test_loss)
    test_accuracies.append(test_accuracy)

    # ReduceLROnPlateau is stepped on the *test* loss; no separate validation split exists.
    scheduler.step(test_loss)

    summary = ', '.join([
        f'LR: {current_lr:.6f}',
        f'Train Loss: {train_loss:.4f}',
        f'Train Acc: {train_accuracy:.2f}%',
        f'Test Loss: {test_loss:.4f}',
        f'Test Acc: {test_accuracy:.2f}%',
    ])
    print(f'Epoch {epoch+1}/{num_epochs} - {summary}')

print("Training complete!")


Plotting Metrics

In [None]:
# Plot metrics
plt.figure(figsize=(20, 15))

# x-axes come from the number of epochs actually recorded, so an interrupted
# run still plots instead of failing on mismatched lengths.
train_epochs = range(1, len(train_losses) + 1)
test_epochs = range(1, len(test_losses) + 1)

# Plot loss curves
plt.subplot(2, 3, 1)
plt.plot(train_epochs, train_losses, label='Train Loss')
plt.plot(test_epochs, test_losses, label='Test Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.title('Epoch vs Loss')
plt.legend()

# Plot accuracy curves
plt.subplot(2, 3, 2)
plt.plot(range(1, len(train_accuracies) + 1), train_accuracies, label='Train Accuracy')
plt.plot(range(1, len(test_accuracies) + 1), test_accuracies, label='Test Accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy (%)')
plt.title('Epoch vs Accuracy')
plt.legend()

# Plot confusion matrix (using the last epoch's results)
plt.subplot(2, 3, 3)
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=classes, yticklabels=classes)
plt.xlabel('Predicted')
plt.ylabel('True')
plt.title('Confusion Matrix')


More Plots: per-class precision/recall/F1, per-class ROC curves, and epoch time vs. learning rate

In [None]:
# Plot per-class metrics (using the last epoch's results)
# These panels continue the 2x3 grid started in the previous cell. With the
# Jupyter inline backend that figure may already be shown and closed, in which
# case gcf() starts a new one; sizing it keeps the 2x3 layout legible either way.
fig = plt.gcf()
fig.set_size_inches(20, 15)

plt.subplot(2, 3, 4)
x = np.arange(len(classes))
width = 0.2
plt.bar(x - width, precision, width, label='Precision')
plt.bar(x, recall, width, label='Recall')
plt.bar(x + width, f1, width, label='F1-Score')
plt.xlabel('Classes')
plt.ylabel('Score')
plt.title('Performance per Class')
plt.xticks(x, classes, rotation=45)
plt.legend()

# Plot ROC curves (using the last epoch's results)
plt.subplot(2, 3, 5)
for i in range(len(classes)):
    plt.plot(fpr[i], tpr[i], label=f'{classes[i]} (AUC = {roc_auc[i]:.2f})')
plt.plot([0, 1], [0, 1], 'k--')
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('ROC Curve per Class')
plt.legend(loc="lower right")

# Plot epoch times and learning rate in the sixth panel of the SAME figure
# (plt.subplots() here would open a separate figure, leaving this panel empty
# and making savefig() write only that small plot). The LR gets a twin y-axis.
ax1 = plt.subplot(2, 3, 6)
color = 'tab:red'
ax1.set_xlabel('Epochs')
ax1.set_ylabel('Time (s)', color=color)
ax1.plot(range(1, len(epoch_times) + 1), epoch_times, color=color, label='Epoch Time')
ax1.tick_params(axis='y', labelcolor=color)

ax2 = ax1.twinx()
color = 'tab:blue'
ax2.set_ylabel('Learning Rate', color=color)
ax2.plot(range(1, len(learning_rates) + 1), learning_rates, color=color, label='Learning Rate')
ax2.tick_params(axis='y', labelcolor=color)

ax1.set_title('Epoch Training Time & Learning Rate')
fig.tight_layout()

fig.savefig('resnet18_metrics.png')
plt.show()


Create Statistics Table and Save Model

In [None]:
# Create a table of training statistics
# test_accuracies is appended last in each epoch, so its length is the number of
# fully completed epochs. Slicing every column to it keeps the columns equal in
# length even if training was interrupted mid-epoch.
num_completed = len(test_accuracies)
stats_df = pd.DataFrame({
    'Epoch': range(1, num_completed + 1),
    'Train Loss': train_losses[:num_completed],
    'Test Loss': test_losses[:num_completed],
    'Train Accuracy': train_accuracies[:num_completed],
    'Test Accuracy': test_accuracies[:num_completed],
    'Epoch Time (s)': epoch_times[:num_completed],
    'Avg Batch Time (s)': batch_times[:num_completed],
    'Learning Rate': learning_rates[:num_completed],
})

print("\nTraining Statistics:")
print(stats_df.to_string(index=False))

# Save the model
# Strip any `total_ops` / `total_params` buffers that thop.profile may have
# attached, so the checkpoint loads into a fresh ResNet18() with strict=True.
# Entries are deleted in place to keep the state_dict's metadata intact.
state_dict = model.state_dict()
for key in [k for k in state_dict if k.rsplit('.', 1)[-1] in ('total_ops', 'total_params')]:
    del state_dict[key]
torch.save(state_dict, 'resnet18_cifar10.pth')
print("Model saved to 'resnet18_cifar10.pth'")

Final Evaluation Report

In [None]:
# Final evaluation report
print("\nFinal Model Evaluation:")
print("Model: ResNet-18")
print("Dataset: CIFAR-10")
print(f"Number of parameters: {params:,.0f}")
# `flops` holds thop's operation count, which is multiply-accumulates (MACs).
print(f"Number of MACs: {flops:,.0f}")
if test_accuracies:
    # NOTE: the "best" epoch is picked on the test set itself (which also drives the
    # LR scheduler), so this number is an optimistic estimate of generalization.
    print(f"Best test accuracy: {max(test_accuracies):.2f}%")
    print(f"Final test accuracy: {test_accuracies[-1]:.2f}%")
else:
    print("No completed epochs: test accuracy unavailable.")
if epoch_times:
    print(f"Average epoch training time: {sum(epoch_times)/len(epoch_times):.2f}s")
if batch_times:
    print(f"Average batch processing time: {sum(batch_times)/len(batch_times):.4f}s")
