# Advanced CIFAR-10 Classification with C1C2C3C40 Architecture

## Objectives
1. **Architecture**: C1C2C3C40 structure (No MaxPooling, stride=2 in last conv)
2. **Receptive Field**: Total receptive field (RF) greater than 44
3. **Advanced Convolutions**: Depthwise Separable + Dilated Convolution
4. **Global Average Pooling**: Required, optionally followed by a fully connected (FC) layer
5. **Data Augmentation**: Albumentations library with HorizontalFlip, ShiftScaleRotate, and CoarseDropout
6. **Performance**: Reach 85% validation accuracy with fewer than 200k parameters
7. **Code Modularity**: Well-organized, reusable modules

## Key Features
- **Depthwise Separable Convolution** in Conv Block 2
- **Dilated Convolution** in Conv Block 3
- **Stride=2** instead of MaxPooling in Conv Block 4
- **Global Average Pooling** with FC layer
- **Albumentations** library for data augmentation
- **OneCycleLR** scheduler for better training


## Import Required Libraries


In [None]:
# Standard-library and third-party dependencies.
# NOTE(review): several of these (e.g. optim, F, summary, np) are not referenced
# directly in the notebook cells shown here; they may be leftovers or meant for
# interactive use.
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torchsummary import summary
import matplotlib.pyplot as plt
import numpy as np
import time
import warnings
# Ignore every warning (no category/module filter) for the rest of the session.
# Because this runs after the imports above, warnings emitted while importing
# those libraries are not affected; the project modules below are imported with
# the filter already active.
# NOTE(review): consider narrowing this (by category or module) so that
# deprecation warnings from dependencies are not hidden.
warnings.filterwarnings('ignore')

# Import our custom modules
from src.models.model import CIFAR10Net, count_model_parameters
from src.data.data_manager import CIFAR10DataManager
from src.training.trainer import ModelTrainer
from src.visualization.visualizer import ModelVisualizer
from src.utils.utils import get_device, print_device_info, print_receptive_field_info
from config import get_config

# Reached only if every import above resolved.
print("All libraries imported successfully!")


## Setup Device and Configuration


In [None]:
# Load the project configuration. Later cells read its data, model, training
# and visualization sections.
config = get_config()

# Resolve the torch device that the model and batches are moved to
# (selection logic lives in src.utils.utils.get_device) and report it.
device = get_device()
print_device_info(device)

# Report the receptive-field breakdown of the architecture.
print_receptive_field_info()

# Plain string: this line has no placeholders, so no f-prefix is needed.
print("\nConfiguration loaded:")
print(f"Target accuracy: {config.training.target_accuracy}%")
print(f"Max parameters: {config.model.max_parameters:,}")
print(f"Training epochs: {config.training.epochs}")
print(f"Learning rate: {config.training.learning_rate}")


## Data Preparation and Augmentation


In [None]:
# Build the CIFAR-10 data pipeline from the data section of the config.
data_manager = CIFAR10DataManager(config.data)

# Dataset mean/std. These are also passed to the visualizer cells below,
# presumably so images can be un-normalized for display.
mean, std = data_manager.calculate_dataset_statistics()

# Separate transforms for training and test data. The augmentation settings
# live in config.data (see the final summary cell).
train_transform, test_transform = data_manager.create_transforms()

# Load both splits. The class names come from the training dataset.
train_dataset, test_dataset = data_manager.load_datasets(train_transform, test_transform)
classes = train_dataset.classes

# Wrap the datasets in loaders; batching settings come from config.data.
train_loader, test_loader = data_manager.create_data_loaders(train_dataset, test_dataset)

# Plain string: this line has no placeholders, so no f-prefix is needed.
print("\nData setup completed:")
print(f"Classes: {classes}")
print(f"Training samples: {len(train_dataset)}")
print(f"Test samples: {len(test_dataset)}")
print(f"Training batches: {len(train_loader)}")
print(f"Test batches: {len(test_loader)}")


## Data Visualization


In [None]:
# One visualizer instance is shared by every plotting cell below.
visualizer = ModelVisualizer(config.visualization)

# Show a grid of training images. mean/std are forwarded, presumably so the
# normalized tensors can be mapped back to displayable colours.
sample_count = config.visualization.num_sample_images
visualizer.display_sample_images(train_loader, classes, mean, std, sample_count)


## Model Architecture


In [None]:
# Instantiate the network from the model config and move it to the device.
model = CIFAR10Net(config.model).to(device)

# Report the parameter count against the budget. The budget is strict:
# the count must be below the limit.
total_params = count_model_parameters(model)
param_limit = config.model.max_parameters
within_budget = total_params < param_limit

# Plain string: this line has no placeholders, so no f-prefix is needed.
print("Model created successfully!")
print(f"Total parameters: {total_params:,}")
print(f"Parameter requirement (< {param_limit:,}): {'✓' if within_budget else '✗'}")

# Layer-by-layer summary of the model.
visualizer.display_model_summary(model)


## Model Training


In [None]:
# Training is delegated to ModelTrainer, which uses the training config and
# tracks the best validation accuracy/epoch reported below.
trainer = ModelTrainer(model, config.training, device)

print("Starting training...")
print("=" * 50)

# Measure wall-clock time of the full run. perf_counter is monotonic, so it is
# unaffected by system clock changes.
start_time = time.perf_counter()
metrics = trainer.train(train_loader, test_loader)
elapsed_seconds = time.perf_counter() - start_time

# Plain string: this line has no placeholders, so no f-prefix is needed.
print("\nTraining completed!")
print(f"Training time: {elapsed_seconds / 60:.1f} min")
print(f"Best validation accuracy: {trainer.best_val_accuracy:.2f}%")
print(f"Best epoch: {trainer.best_epoch}")


## Results Visualization and Analysis


In [None]:
# Restore the best checkpoint before the analysis cells that follow.
# Presumably this loads weights in place into `model`; verify in ModelTrainer.
trainer.load_best_model()

# Loss, accuracy and learning-rate history for the whole run, in the order
# plot_training_curves expects. The configured output path comes last.
history = (
    metrics.train_losses,
    metrics.train_accuracies,
    metrics.val_losses,
    metrics.val_accuracies,
    metrics.learning_rates,
)
visualizer.plot_training_curves(*history, config.visualization.training_curves_path)

# Plot the learning-rate schedule on its own.
lr_history = metrics.learning_rates
visualizer.plot_learning_rate_schedule(lr_history)


## Per-Class Accuracy Analysis


In [None]:
# Per-class accuracy on the test set. The visualizer plots it (output path
# from config) and returns a mapping of class name -> accuracy in percent.
class_accuracies = visualizer.plot_per_class_accuracy(
    model, test_loader, classes, device, config.visualization.class_accuracy_path
)

print("\nPer-class accuracies:")
for label, accuracy in class_accuracies.items():
    print(f"{label}: {accuracy:.2f}%")


## Misclassified Images Analysis


In [None]:
# Inspect test images the model gets wrong. mean/std are forwarded,
# presumably to un-normalize the images for display.
n_misclassified = config.visualization.num_misclassified_images
visualizer.analyze_misclassified_images(
    model, test_loader, classes, mean, std, n_misclassified, device
)


## Final Results Summary


In [None]:
# Headline metrics from the training history. The pass/fail checks are
# computed once and reused, so the per-line marks and the closing message
# always agree.
best_metrics = metrics.get_best_metrics()
best_val_accuracy = best_metrics['best_val_accuracy']
target_accuracy = config.training.target_accuracy
target_met = best_val_accuracy >= target_accuracy
params_ok = total_params < config.model.max_parameters

print("=" * 60)
print("FINAL RESULTS SUMMARY")
print("=" * 60)

print(f"Best validation accuracy: {best_val_accuracy:.2f}%")
print(f"Best epoch: {best_metrics['best_epoch']}")
print(f"Target accuracy: {target_accuracy}%")
print(f"Target achieved: {'✓' if target_met else '✗'}")

# NOTE: the compliance lines below are fixed statements about the intended
# design. They are NOT verified at runtime against the `model` instance.
print("\nModel Architecture Compliance:")
print("✓ C1C2C3C40 structure: Implemented")
print("✓ No MaxPooling: Implemented")
print("✓ Stride=2 in Conv Block 4: Implemented")
print("✓ Depthwise Separable Convolution: Implemented")
print("✓ Dilated Convolution: Implemented")
print("✓ Global Average Pooling: Implemented")
print("✓ FC layer after GAP: Implemented")
print("✓ Albumentation augmentations: Implemented")

print(f"\nParameter count: {total_params:,} (< {config.model.max_parameters:,} requirement: {'✓' if params_ok else '✗'})")
# Fixed claim, not computed here. See the print_receptive_field_info()
# output in the setup cell for the actual breakdown.
print("Receptive Field: > 44 (requirement: ✓)")

print("\nData Augmentation Applied:")
print(f"✓ Horizontal Flip: p={config.data.horizontal_flip_prob}")
print(f"✓ ShiftScaleRotate: p={config.data.shift_scale_rotate_prob}")
print(f"✓ CoarseDropout: p={config.data.coarse_dropout_prob}")

# Report success only when the accuracy target was actually reached.
# Previously this message was printed unconditionally.
if target_met:
    print("\n✅ Training completed successfully!")
else:
    print(f"\n⚠️ Training completed, but the {target_accuracy}% target accuracy was not reached.")
print("=" * 60)
