# Neural Networks Project: CIFAR-10 Classification with CNN

This notebook demonstrates how to use the neural networks framework to train a CNN model on the CIFAR-10 dataset for image classification.

## Setup

First, let's import the necessary libraries and set up our environment.

In [None]:
import os
import sys
import torch
import numpy as np
import matplotlib.pyplot as plt
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Add the project root to the path
sys.path.append('..')

# Import project modules
from src.models.cnn_model import CNNModel
from src.utils.trainer import Trainer
from src.utils.metrics import MetricsTracker
from src.config.config_manager import ConfigManager, get_default_config

# Seed the PyTorch and NumPy random number generators so random draws repeat across runs
SEED = 42
torch.manual_seed(SEED)
np.random.seed(SEED)

# Use the GPU if PyTorch can see one; otherwise run on the CPU
use_cuda = torch.cuda.is_available()
device = torch.device('cuda') if use_cuda else torch.device('cpu')
print(f"Using device: {device}")

## Load and Prepare CIFAR-10 Data

CIFAR-10 is a dataset of 60,000 32x32 color images in 10 classes, with 6,000 images per class. There are 50,000 training images and 10,000 test images.

In [None]:
# CIFAR-10 per-channel mean and std, shared by the training and test pipelines
# so the two normalisations cannot drift apart
CIFAR10_MEAN = (0.4914, 0.4822, 0.4465)
CIFAR10_STD = (0.2470, 0.2435, 0.2616)

# Training pipeline: augmentation (random 32x32 crop after 4-pixel padding,
# random horizontal flip), then tensor conversion and normalisation
train_transform = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

# Test pipeline: no augmentation, only tensor conversion and normalisation
test_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(CIFAR10_MEAN, CIFAR10_STD),
])

# Load CIFAR-10; download=True on both splits so either one can be built on its own
train_dataset = datasets.CIFAR10('../data', train=True, download=True, transform=train_transform)
test_dataset = datasets.CIFAR10('../data', train=False, download=True, transform=test_transform)

# Pinned (page-locked) host memory only speeds up host-to-GPU copies, so only
# request it when CUDA is available (PyTorch warns about it otherwise)
batch_size = 128
use_pinned_memory = torch.cuda.is_available()
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True,
                          num_workers=4, pin_memory=use_pinned_memory)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False,
                         num_workers=4, pin_memory=use_pinned_memory)

# Print dataset information
print(f"Training dataset size: {len(train_dataset)}")
print(f"Test dataset size: {len(test_dataset)}")
print(f"Number of classes: {len(train_dataset.classes)}")
print(f"Classes: {train_dataset.classes}")

## Visualize Some Examples from CIFAR-10

In [None]:
# Helper: undo the CIFAR-10 normalisation so a tensor image can be displayed
def imshow(img):
    """Convert a normalised CHW image tensor into a displayable HWC array.

    Despite its name, this function does not draw anything. It reverses the
    per-channel normalisation applied by the data transforms and returns a
    NumPy array clipped to [0, 1], ready to pass to ``plt.imshow``.

    Args:
        img: CPU tensor laid out as (channels, height, width); the channel
            axis must broadcast against the 3-element mean/std below.

    Returns:
        np.ndarray laid out as (height, width, channels), values in [0, 1].
    """
    channel_mean = np.array([0.4914, 0.4822, 0.4465])
    channel_std = np.array([0.2470, 0.2435, 0.2616])
    # Channels-last so the per-channel statistics broadcast over the last axis
    hwc = np.transpose(img.numpy(), (1, 2, 0))
    restored = hwc * channel_std + channel_mean
    return np.clip(restored, 0, 1)

# Preview one (shuffled) batch from the training loader
example_data, example_targets = next(iter(train_loader))

# Show the first 12 images in a 3x4 grid, each titled with its class name
n_rows, n_cols = 3, 4
plt.figure(figsize=(15, 8))
for idx in range(n_rows * n_cols):
    plt.subplot(n_rows, n_cols, idx + 1)
    plt.imshow(imshow(example_data[idx]))
    plt.title(f"{train_dataset.classes[example_targets[idx]]}")
    plt.axis('off')
plt.tight_layout()
plt.show()

## Configure and Create the Model

Let's use our configuration manager to set up the model parameters for CIFAR-10, which has color images (3 channels) and 10 classes.

In [None]:
# Start from the framework's default configuration
config_manager = ConfigManager(default_config=get_default_config())

# CIFAR-10-specific overrides, applied in order via dotted keys
cifar10_overrides = {
    'model.input_channels': 3,              # CIFAR-10 images are RGB
    'model.num_classes': 10,                # 10 classes in CIFAR-10
    'cnn.conv_channels': [64, 128, 256],    # Deeper CNN architecture
    'cnn.fc_units': [512, 256],             # Larger fully-connected layers
    'model.dropout_rate': 0.5,              # Dropout for regularization
    'training.num_epochs': 20,              # Number of epochs
    'training.learning_rate': 0.001,        # Learning rate
    'training.weight_decay': 1e-4,          # Weight decay for regularization
}
for key, value in cifar10_overrides.items():
    config_manager.set(key, value)

# Read the configuration only AFTER all overrides are applied: if get_all()
# returns a copy, a snapshot taken earlier would still hold the defaults
config = config_manager.get_all()

# Model configuration built from the (overridden) settings
model_config = {
    'input_channels': config['model']['input_channels'],
    'num_classes': config['model']['num_classes'],
    'conv_channels': config['cnn']['conv_channels'],
    'fc_units': config['cnn']['fc_units'],
    'dropout_rate': config['model']['dropout_rate'],
}

# Create the model and move it to the selected device
model = CNNModel(model_config)
model = model.to(device)

# Print model summary
print(f"CNN Model created with {model.get_parameter_count():,} trainable parameters")

## Set Up the Trainer

Now let's set up the training configuration and create our trainer.

In [None]:
# Re-read the training section from the config manager so the trainer sees every
# override, even if `config` was captured before the overrides were applied
training_cfg = config_manager.get_all()['training']

# Create trainer configuration
trainer_config = {
    'learning_rate': training_cfg['learning_rate'],
    'weight_decay': training_cfg['weight_decay'],
    'num_epochs': training_cfg['num_epochs'],
    'batch_size': batch_size,
    'optimizer': 'adam',  # Use Adam optimizer
    'scheduler': 'cosine',  # Use cosine annealing scheduler
    'criterion': 'cross_entropy',  # Use cross-entropy loss
    'clip_grad_norm': 1.0,  # Clip gradients
    'early_stopping_patience': 5,  # Stop training if no improvement after 5 epochs
    'checkpoint_dir': '../checkpoints/cifar10',  # Directory to save model checkpoints
    'save_best_only': True  # Only save the best model
}

# Create the checkpoint directory if it doesn't exist
os.makedirs(trainer_config['checkpoint_dir'], exist_ok=True)

# Create the trainer
trainer = Trainer(model, trainer_config, device)

## Train the Model

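Now we're ready to train our model on CIFAR-10. Note that the test loader is passed to the trainer as its validation set, so the test data also drives early stopping and best-checkpoint selection. The test metrics reported later are therefore somewhat optimistic. Holding out a validation split from the 50,000 training images would give an unbiased test estimate.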

In [None]:
# Train; test_loader is passed in the validation slot. The returned `stats`
# dict carries per-epoch histories and the best-epoch summary printed below.
num_epochs = trainer_config['num_epochs']
print(f"Starting training for {num_epochs} epochs...")
stats = trainer.train(train_loader, test_loader)

# Report the best validation results recorded by the trainer
print(f"\nBest validation accuracy: {stats['best_val_acc']:.2f}%")
print(f"Best validation loss: {stats['best_val_loss']:.4f} (epoch {stats['best_epoch']})")

## Visualize Training Results

Let's visualize how the training and validation metrics changed during training.

In [None]:
# Side-by-side learning curves: loss on the left, accuracy on the right.
# Each panel: (train key, train label, val key, val label, y-axis label, title)
curve_panels = (
    ('train_loss', 'Training Loss', 'val_loss', 'Validation Loss',
     'Loss', 'Training and Validation Loss'),
    ('train_acc', 'Training Accuracy', 'val_acc', 'Validation Accuracy',
     'Accuracy (%)', 'Training and Validation Accuracy'),
)

plt.figure(figsize=(12, 5))
for panel_idx, (train_key, train_label, val_key, val_label, y_label, panel_title) in enumerate(curve_panels, start=1):
    plt.subplot(1, 2, panel_idx)
    for series_key, series_label in ((train_key, train_label), (val_key, val_label)):
        history = stats[series_key]
        plt.plot(range(1, len(history) + 1), history, label=series_label)
    plt.xlabel('Epoch')
    plt.ylabel(y_label)
    plt.title(panel_title)
    plt.legend()
    plt.grid(True)

plt.tight_layout()
plt.show()

# Per-epoch learning rate, to check the scheduler's shape
plt.figure(figsize=(10, 4))
lr_history = stats['learning_rates']
plt.plot(range(1, len(lr_history) + 1), lr_history)
plt.xlabel('Epoch')
plt.ylabel('Learning Rate')
plt.title('Learning Rate Schedule')
plt.grid(True)
plt.show()

## Load the Best Model and Evaluate on Test Set

In [None]:
# Restore the best checkpoint, presumably written by the trainer (save_best_only=True).
# NOTE(review): test_loader was also passed to trainer.train as the validation
# set, so this checkpoint was chosen using the test data; treat the numbers
# below as optimistic rather than a true held-out estimate.
checkpoint_dir = trainer_config['checkpoint_dir']
best_model_path = os.path.join(checkpoint_dir, 'best_model.pt')
model.load(best_model_path)

# Score the restored weights on the test split
print("Evaluating the best model on the test set...")
test_loss, test_acc = trainer.evaluate(test_loader, desc="Test")
print(f"Test Loss: {test_loss:.4f}")
print(f"Test Accuracy: {test_acc:.2f}%")

## Visualize Predictions and Generate Confusion Matrix

In [None]:
# Tracker that accumulates classification metrics (including the confusion
# matrix used further down) over the test set
metrics_tracker = MetricsTracker(task_type='classification', n_classes=len(train_dataset.classes))

model.eval()
with torch.no_grad():
    # Set aside the first test batch (test_loader does not shuffle) for the
    # qualitative prediction grid. It stays on the CPU for plotting; only a
    # device copy goes through the model. It is deliberately NOT fed to the
    # tracker: the full pass below already includes this batch, and adding it
    # here as well would count those samples twice.
    test_images, test_labels = next(iter(test_loader))
    _, predicted = torch.max(model(test_images.to(device)), 1)
    predicted = predicted.cpu()

    # Full pass over the test set. The loop uses its own names so it cannot
    # overwrite the sample-batch predictions shown in the plot below.
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        probabilities = torch.softmax(outputs, dim=1)
        _, batch_predicted = torch.max(outputs, 1)
        metrics_tracker.update(labels, batch_predicted, probabilities)

# Print detailed metrics for the whole test set
metrics_tracker.print_metrics()

# Plot the first 12 sample images: green title = correct, red = misclassified
plt.figure(figsize=(15, 8))
for i in range(12):
    plt.subplot(3, 4, i+1)
    plt.imshow(imshow(test_images[i]))
    title_color = 'green' if predicted[i] == test_labels[i] else 'red'
    plt.title(f"True: {train_dataset.classes[test_labels[i]]}\nPred: {train_dataset.classes[predicted[i]]}",
              color=title_color)
    plt.axis('off')
plt.tight_layout()
plt.show()

# Confusion matrix accumulated by the tracker (rows = true, cols = predicted)
metrics = metrics_tracker.compute()
confusion_mat = metrics['confusion_matrix']

# Row-normalise so each true-class row sums to 1. Rows with no samples stay
# at 0 instead of producing NaN (and a RuntimeWarning) from a 0/0 division.
row_totals = confusion_mat.sum(axis=1, keepdims=True).astype('float')
confusion_mat_normalized = np.divide(
    confusion_mat.astype('float'),
    row_totals,
    out=np.zeros(confusion_mat.shape, dtype=float),
    where=row_totals != 0,
)

class_names = train_dataset.classes
tick_marks = np.arange(len(class_names))

plt.figure(figsize=(12, 10))
plt.imshow(confusion_mat_normalized, interpolation='nearest', cmap=plt.cm.Blues)
plt.title('Normalized Confusion Matrix')
plt.colorbar()
plt.xticks(tick_marks, class_names, rotation=45)
plt.yticks(tick_marks, class_names)

# Annotate each cell; white text on dark (above half-max) cells for contrast
thresh = confusion_mat_normalized.max() / 2.
for i in range(confusion_mat_normalized.shape[0]):
    for j in range(confusion_mat_normalized.shape[1]):
        value = confusion_mat_normalized[i, j]
        plt.text(j, i, format(value, '.2f'),
                 ha="center", va="center",
                 color="white" if value > thresh else "black")

plt.ylabel('True label')
plt.xlabel('Predicted label')
# Lay out only after the axis labels exist so tight_layout accounts for them
plt.tight_layout()
plt.show()

## Identify the Most Difficult Classes

Let's analyze which classes are the most difficult for our model to classify.

In [None]:
# Per-class accuracy = correct predictions (diagonal) / samples of that class
# (row total). Classes with no samples report 0 rather than NaN.
class_totals = confusion_mat.sum(axis=1)
per_class_accuracy = np.divide(
    confusion_mat.diagonal().astype(float),
    class_totals,
    out=np.zeros(len(class_totals), dtype=float),
    where=class_totals != 0,
)

# Bar chart of accuracy for each class
plt.figure(figsize=(12, 6))
classes = train_dataset.classes
plt.bar(range(len(classes)), per_class_accuracy * 100)
plt.xticks(range(len(classes)), classes, rotation=45)
plt.xlabel('Class')
plt.ylabel('Accuracy (%)')
plt.title('Per-Class Accuracy')
plt.grid(axis='y')
# Leave headroom above 100% so the value labels placed above each bar stay visible
plt.ylim(0, 110)

# Write each accuracy value just above its bar
for i, v in enumerate(per_class_accuracy):
    plt.text(i, v * 100 + 1, f"{v*100:.1f}%", ha='center')

plt.tight_layout()
plt.show()

# Rank every off-diagonal cell (true class i predicted as class j) by count,
# most frequent first; sorted() is stable, so ties keep row-major order
n_classes = len(classes)
confusion_pairs = sorted(
    (
        ((true_idx, pred_idx), confusion_mat[true_idx, pred_idx])
        for true_idx in range(n_classes)
        for pred_idx in range(n_classes)
        if true_idx != pred_idx
    ),
    key=lambda pair: pair[1],
    reverse=True,
)

# Report the ten most frequent confusions
print("Top confused class pairs (true -> predicted):")
for (true_idx, pred_idx), count in confusion_pairs[:10]:
    print(f"  {classes[true_idx]} -> {classes[pred_idx]}: {count} instances")

## Save the Model Configuration

Let's save the model configuration for future reference.

In [None]:
# Persist the configuration (defaults plus the CIFAR-10 overrides) for later reference
output_dir = '../outputs'
os.makedirs(output_dir, exist_ok=True)
config_path = f"{output_dir}/cifar10_cnn_config.yaml"
config_manager.save_config(config_path)
print(f"Configuration saved to {config_path}")

## Conclusion

In this notebook, we have demonstrated how to use the neural networks framework to:

1. Load and preprocess the CIFAR-10 dataset
2. Configure and create a CNN model for image classification
3. Train the model with appropriate hyperparameters
4. Evaluate performance with various metrics
5. Visualize predictions and identify difficult classes

CIFAR-10 is more challenging than MNIST, but even this relatively simple CNN architecture can achieve reasonable results. For better performance, more advanced architectures such as ResNet or EfficientNet could be implemented by extending the `BaseModel` class in our framework.