# Interactive Training (CIFAR-10)

This notebook lets you configure, train, and visualize the CIFAR-10 model end-to-end.


### Objectives
- Achieve ≥ 85% test accuracy on CIFAR-10
- Keep parameters < 200k
- Ensure receptive field > 44
- Use C1C2C3C4 architecture without MaxPooling
- Include Depthwise Separable Conv (C2) and Dilated Convs (C3/C4)
- Use Global Average Pooling (GAP) and Albumentations augmentations (HorizontalFlip, ShiftScaleRotate, CoarseDropout)

### Key Features
- C1C2C3C4 network with dilations (d=2,4,8) and DW separable conv in C2
- Global Average Pooling + Linear head
- Albumentations pipeline with dataset mean/std
- OneCycleLR schedule (base lr=0.003, max_lr=0.2) for faster convergence
- Detailed metrics: train/test accuracy, train/test loss, LR curve
- Visualizations: training curves, per-class accuracy, misclassified images
- Interactive cells to override config, train, and analyze results


## Setup & Imports


In [None]:
# Uncomment when running on Colab/Kaggle to fetch the project sources:
# !git clone https://github.com/SachinDangayach/AU_7_NN_CIFAR.git

## Configure Training


In [None]:
# Interactive Training Notebook
# Use: Run cells top-to-bottom. Adjust config overrides in the next cell.
import os, sys

# Add the cloned repository directory to the Python path
# sys.path.append('/content/AU_7_NN_CIFAR')
# sys.path.append('/kaggle/working/AU_7_NN_CIFAR')

from config import get_config, ProjectConfig
from src.data.data_manager import create_data_manager
from src.models.model import create_model
from src.training.trainer import create_trainer
from src.visualization.visualizer import create_visualizer

import torch

# Choose the compute device once; the training config receives the same choice
# as a plain string ('cuda' or 'cpu').
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
config = get_config()
config.training.device = device.type

# Echo the settings that matter most for this run.
print(f"Device: {device}")
print(f"Batch size: {config.data.batch_size}")
print(
    f"Scheduler: {config.training.scheduler_type}, "
    f"base LR={config.training.learning_rate}, "
    f"max_lr={getattr(config.training, 'max_lr', None)}"
)


## Config Overrides


In [None]:
# Optional session-level overrides for the training config.
# The sample values below are an example aimed at faster experimentation.
_training_overrides = {
    'max_epochs': 30,
    'scheduler_type': 'OneCycleLR',
    'learning_rate': 0.003,
    'max_lr': 0.2,
    'post_target_extra_epochs': 3,
    'target_test_accuracy': 85.0,
}
for _field, _value in _training_overrides.items():
    setattr(config.training, _field, _value)

# Show the resulting configuration.
from config import print_config
print_config(config)


## Forward-Pass Check & Data Loaders


In [None]:
# Sanity check: push a random 2x3x32x32 batch through the network and verify
# that the last output dimension equals config.model.num_classes.
# Reuses `model` if an earlier cell already created it; otherwise builds one.
if 'model' not in locals():
    model = create_model(config.model).to(device)

x = torch.randn(2, 3, 32, 32).to(device)

# Run the check in eval mode under no_grad: no autograd graph is built and any
# BatchNorm/Dropout layers are not affected by this random-noise batch.
# Assumes `model` is a torch.nn.Module (it exposes .training/.eval()/.train()).
was_training = model.training
model.eval()
try:
    with torch.no_grad():
        out = model(x)
finally:
    # Restore whatever mode the model was in before the check.
    model.train(was_training)

print("Input:", tuple(x.shape), "Output:", tuple(out.shape))
assert out.shape[-1] == config.model.num_classes, "Output classes mismatch"
print("✅ Forward pass OK")


In [None]:
# Assemble the input pipeline in three steps: transforms -> datasets -> loaders.
data_manager = create_data_manager(config.data)

# Dataset statistics can be recomputed when needed (left disabled for speed):
# mean, std = data_manager.calculate_dataset_statistics()
train_transform, test_transform = data_manager.create_transforms()
train_ds, test_ds = data_manager.load_datasets(train_transform, test_transform)
train_loader, test_loader = data_manager.create_data_loaders(train_ds, test_ds)

# Report len() of each loader (train first, then test).
print(*map(len, (train_loader, test_loader)))


## Model Creation & Summary


In [None]:
from src.models.model import count_model_parameters

# Build the network on the selected device and measure it against the
# parameter budget stored in config.model.max_parameters.
model = create_model(config.model).to(device)
total_params = count_model_parameters(model)
param_limit = config.model.max_parameters
within_budget = total_params < param_limit

print("Model created successfully!")
print(f"Total parameters: {total_params:,}")
print(f"Parameter requirement (< {param_limit:,}): {'✓' if within_budget else '✗'}")

# Display the model summary; the handle is named `visualizer` to match the
# name the later cells use for the same object.
visualizer = create_visualizer(config.visualization)
visualizer.display_model_summary(model)



In [None]:
# Show the complete architecture summary for the current model.
from src.visualization.visualizer import create_visualizer

# Build a model only when no earlier cell has left one in the notebook namespace.
if 'model' not in globals():
    model = create_model(config.model).to(device)

visualizer = create_visualizer(config.visualization)
visualizer.display_model_summary(model)


## Training & Curves


In [None]:
# Start from a freshly created network and wrap it in a trainer.
model = create_model(config.model).to(device)
trainer = create_trainer(model, config.training, device)

# Run training; the epoch limit, accuracy target and extra-epoch count all
# come from the training config.
train_cfg = config.training
metrics = trainer.train(
    train_loader,
    test_loader,
    max_epochs=train_cfg.max_epochs,
    target_test_acc=train_cfg.target_test_accuracy,
    post_target_extra_epochs=train_cfg.post_target_extra_epochs,
)

# Report the best result; fall back to NaN / -1 when a key is missing.
best_metrics = metrics.get_best_metrics()
print("\nTraining completed!")
best_accuracy = best_metrics.get('best_test_accuracy', float('nan'))
best_epoch = best_metrics.get('best_epoch', -1)
print(f"Best test accuracy: {best_accuracy:.2f}%")
print(f"Best epoch: {best_epoch}")


In [None]:
# Plot the recorded training history (losses, accuracies, learning rate).
# Set save_path to something like 'training_curves.png' to provide a save path.
save_path = None

visualizer = create_visualizer(config.visualization)
visualizer.plot_training_curves(
    metrics.train_losses,
    metrics.train_accuracies,
    metrics.learning_rates,
    save_path,
    metrics.test_accuracies,
    metrics.test_losses,
)



## Analysis: Per-Class Accuracy and Misclassifications


In [None]:
# Optional post-training analysis on the test loader.
# NOTE(review): CIFAR10DataManager is imported but not referenced in this cell.
from src.data.data_manager import CIFAR10DataManager

# Class labels come from the training dataset object.
class_names = train_ds.classes

# Both analyses share the same model, loader and label list.
eval_args = (model, test_loader, class_names)

# Per-class accuracy (return value discarded).
_ = visualizer.plot_per_class_accuracy(*eval_args, device)

# Misclassified images, requesting 16 of them.
visualizer.analyze_misclassified_images(*eval_args, device=device, num_images=16)
