# CIFAR-10 Classification with Pre-activation ResNet (resnetx_cifar10)

Welcome to the demo notebook for the `resnetx_cifar10` project! In this notebook, we'll walk through:

- Loading the CIFAR-10 dataset
- Training a custom Pre-activation ResNet (ResNet v2 style)
- Visualizing training progress
- Displaying a confusion matrix and sample predictions
- Drawing insights from results
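For readers new to the "pre-activation" idea: in ResNet v2, each residual block applies BatchNorm and ReLU *before* each convolution, and the sum at the end of the block is left un-activated, which keeps the identity path clean. Below is a minimal sketch of such a basic block in PyTorch. The class name `PreActBlock` and its arguments are illustrative assumptions and are not taken from `models.py`.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class PreActBlock(nn.Module):
    """Sketch of a pre-activation basic block (ResNet v2 style): BN -> ReLU -> conv, twice."""

    def __init__(self, in_planes: int, planes: int, stride: int = 1):
        super().__init__()
        self.bn1 = nn.BatchNorm2d(in_planes)
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        # 1x1 projection only when the block changes resolution or channel count
        self.shortcut = None
        if stride != 1 or in_planes != planes:
            self.shortcut = nn.Conv2d(in_planes, planes, kernel_size=1, stride=stride, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = F.relu(self.bn1(x))  # normalize and activate *before* the first conv
        # Identity path keeps the raw input; a projection (if any) reads the pre-activated tensor
        shortcut = self.shortcut(out) if self.shortcut is not None else x
        out = self.conv1(out)
        out = self.conv2(F.relu(self.bn2(out)))
        # No ReLU after the addition: the next block's BN/ReLU takes care of it
        return out + shortcut
```

A ResNet-34 layout stacks [3, 4, 6, 3] of these basic blocks across four stages; whether `PreActResNet34` in `models.py` follows exactly this layout is an assumption.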

This notebook is fully modular and uses:
- `models.py` for the architecture
- `dataloader_generator.py` for dataset loading
- `utils.py` for training, plotting, and evaluation

> **Goal:** Achieve high accuracy on CIFAR-10 using a simplified and interpretable ResNet architecture.

---

Let's begin by importing the necessary modules and setting everything up.


In [None]:
# Third-party imports
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt
import pandas as pd

# Project-specific imports
from models import PreActResNet34
from utils import train_model, plot_confusion_matrix, plot_predictions
from dataloader_generator import get_cifar10_dataloaders

# Device configuration
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f"Using device: {device}")

# Hyperparameters
batch_size = 100

# Get the train and validation dataloaders + class name mapping
train_dl, valid_dl, class_names_dict = get_cifar10_dataloaders(batch_size=batch_size)

# Quick sanity check
print("Train batches:", len(train_dl))
print("Validation batches:", len(valid_dl))
print("Classes:", class_names_dict)

# Initialize the model and move it to device
model = PreActResNet34().to(device)

# Loss function and optimizer
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

# Training parameters
num_epochs = 50

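# Train the model (validation metrics are computed on valid_dl)
train_loss, train_acc, val_loss, val_acc = train_model(
    model, train_dl, valid_dl, loss_fn, optimizer, num_epochs, device
)

# Plot training history: loss and accuracy live on different scales, so use separate axes
history = pd.DataFrame({
    'train_loss': train_loss,
    'train_acc': train_acc,
    'val_loss': val_loss,
    'val_acc': val_acc,
})
history.index = range(1, len(history) + 1)  # label epochs starting from 1
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
history[['train_loss', 'val_loss']].plot(ax=axes[0], title='Loss')
history[['train_acc', 'val_acc']].plot(ax=axes[1], title='Accuracy')
for ax in axes:
    ax.set_xlabel('Epoch')
    ax.grid(True)
fig.suptitle("Training and Validation Metrics Over Epochs")
plt.tight_layout()
plt.show()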
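`train_model` lives in `utils.py`, and its body is not shown here. A typical epoch-based loop that returns the same four per-epoch series could look like the sketch below. The names `run_epoch` and `fit`, and details such as averaging the loss per sample, are assumptions rather than a description of `utils.py`.

```python
import torch


def run_epoch(model, loader, loss_fn, device, optimizer=None):
    """One pass over `loader`: trains when `optimizer` is given, otherwise evaluates.

    Returns (mean loss per sample, accuracy in [0, 1]).
    """
    training = optimizer is not None
    model.train(training)  # train mode vs. eval mode (affects BatchNorm/Dropout)
    total_loss, correct, seen = 0.0, 0, 0
    # Track gradients only when training
    with torch.set_grad_enabled(training):
        for images, labels in loader:
            images, labels = images.to(device), labels.to(device)
            logits = model(images)
            loss = loss_fn(logits, labels)
            if training:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()
            # Assumes loss_fn averages over the batch (the CrossEntropyLoss default)
            batch = labels.size(0)
            total_loss += loss.item() * batch
            correct += (logits.argmax(dim=1) == labels).sum().item()
            seen += batch
    return total_loss / seen, correct / seen


def fit(model, train_dl, valid_dl, loss_fn, optimizer, num_epochs, device):
    """Hypothetical stand-in for utils.train_model; returns per-epoch histories."""
    history = {"train_loss": [], "train_acc": [], "val_loss": [], "val_acc": []}
    for epoch in range(num_epochs):
        tr_loss, tr_acc = run_epoch(model, train_dl, loss_fn, device, optimizer)
        va_loss, va_acc = run_epoch(model, valid_dl, loss_fn, device)
        history["train_loss"].append(tr_loss)
        history["train_acc"].append(tr_acc)
        history["val_loss"].append(va_loss)
        history["val_acc"].append(va_acc)
        print(f"Epoch {epoch + 1}/{num_epochs}: "
              f"train_loss={tr_loss:.4f} train_acc={tr_acc:.3f} "
              f"val_loss={va_loss:.4f} val_acc={va_acc:.3f}")
    return history
```

Passing the returned dict to `pd.DataFrame` gives one row per epoch, ready for plotting.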

In [None]:
# Switch to evaluation mode (e.g. BatchNorm uses its running statistics)
model.eval()

# Plot confusion matrix on the validation set
plot_confusion_matrix(model, valid_dl, class_names_dict, device)

# Plot examples of correct and incorrect predictions on the validation set
plot_predictions(model, valid_dl, class_names_dict, device, row=1, col=8, figsize=(15,3), max_size=20)
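`plot_confusion_matrix` and `plot_predictions` handle the plotting. The bookkeeping underneath is simple enough to sketch directly: collect predictions over a loader, count (true, predicted) pairs, then read per-class accuracy off the diagonal and the most frequent mistakes off the largest off-diagonal cells. The helper names below are hypothetical and independent of `utils.py`.

```python
from typing import List, Sequence, Tuple

import torch


@torch.no_grad()
def collect_predictions(model, loader, device) -> Tuple[List[int], List[int]]:
    """Return (true labels, predicted labels) over a whole loader."""
    model.eval()
    y_true, y_pred = [], []
    for images, labels in loader:
        logits = model(images.to(device))
        y_true.extend(labels.tolist())
        y_pred.extend(logits.argmax(dim=1).cpu().tolist())
    return y_true, y_pred


def confusion_matrix(y_true: Sequence[int], y_pred: Sequence[int], num_classes: int) -> List[List[int]]:
    """cm[t][p] counts samples of true class t that were predicted as class p."""
    cm = [[0] * num_classes for _ in range(num_classes)]
    for t, p in zip(y_true, y_pred):
        cm[t][p] += 1
    return cm


def per_class_accuracy(cm: List[List[int]]) -> List[float]:
    """Diagonal divided by row sum (per-class recall); 0.0 for classes with no samples."""
    return [row[i] / sum(row) if sum(row) else 0.0 for i, row in enumerate(cm)]


def top_confusions(cm: List[List[int]], k: int = 3) -> List[Tuple[int, int, int]]:
    """The k largest off-diagonal cells as (count, true_class, predicted_class)."""
    n = len(cm)
    pairs = [(cm[t][p], t, p) for t in range(n) for p in range(n) if t != p and cm[t][p] > 0]
    return sorted(pairs, reverse=True)[:k]
```

Looking up the indices in `class_names_dict` turns `top_confusions` output into readable pairs (which class is most often mistaken for which), assuming that dict is keyed by integer class index.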

## Conclusion and Next Steps

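- The PreActResNet34 model trained on CIFAR-10 performs well; the final validation accuracy and the gap between the training and validation curves in the history plot quantify its accuracy and generalization.
- Data augmentations applied during loading (random cropping, flipping, and color jittering) help make the model more robust to small spatial shifts and color variation.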
- Confusion matrix and prediction plots provide insights into class-wise performance and common misclassifications.
- Potential improvements:
  - Experiment with learning rate schedulers or other optimizers.
  - Try deeper variants such as PreActResNet50 or PreActResNet101 for potentially better accuracy, at a higher compute cost.
  - Implement early stopping or checkpointing for more efficient training.
  - Expand to other datasets or tasks to validate model generalization.
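Two of the suggested improvements, a learning-rate schedule and early stopping with checkpointing, can be combined in a small driver. The sketch below uses PyTorch's `CosineAnnealingLR` (one possible scheduler among several) and keeps an in-memory copy of the best weights. `fit_with_schedule`, the `run_one_epoch` callback, and the patience value are illustrative assumptions, not part of this project's `utils.py`.

```python
import copy
import math

import torch


def fit_with_schedule(model, optimizer, num_epochs, run_one_epoch, patience=5):
    """Cosine LR decay plus patience-based early stopping with a best-weights checkpoint.

    `run_one_epoch(epoch)` is a hypothetical callback that trains for one epoch
    and returns that epoch's validation accuracy.
    """
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)
    best_acc, best_epoch = -math.inf, -1
    best_state = copy.deepcopy(model.state_dict())
    for epoch in range(num_epochs):
        val_acc = run_one_epoch(epoch)
        scheduler.step()  # advance the cosine schedule once per epoch
        if val_acc > best_acc:
            best_acc, best_epoch = val_acc, epoch
            # In-memory checkpoint; torch.save(best_state, path) would persist it to disk
            best_state = copy.deepcopy(model.state_dict())
        elif epoch - best_epoch >= patience:
            print(f"Early stop at epoch {epoch + 1}; best was epoch {best_epoch + 1}")
            break
    model.load_state_dict(best_state)  # restore the best-scoring weights
    return best_acc, best_epoch
```

Restoring the best weights at the end means the model you evaluate afterwards is the one from the best validation epoch rather than the last one trained.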

Feel free to explore and build upon this foundation!
