In [None]:
# Interactive Training Notebook
# Use: Run cells top-to-bottom. Adjust config overrides in the next cell.
import os, sys
sys.path.append('.')

from config import get_config, ProjectConfig
from src.data.data_manager import create_data_manager
from src.models.model import create_model
from src.training.trainer import create_trainer
from src.visualization.visualizer import create_visualizer

import torch

# Resolve the compute device once and derive the config's device string from it,
# so `device` and `config.training.device` can never disagree.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
config = get_config()
config.training.device = device.type  # 'cuda' or 'cpu' -- same strings as before

# Seed torch's RNGs (CPU and all CUDA devices) so weight init -- and, unless the
# data loaders use their own generator, shuffling -- is reproducible under
# Restart & Run All.
# NOTE(review): if the project config defines its own seed, use that value here instead.
SEED = 42
torch.manual_seed(SEED)

print(f"Device: {device}")
print(f"Batch size: {config.data.batch_size}")
print(f"Scheduler: {config.training.scheduler_type}, base LR={config.training.learning_rate}, max_lr={getattr(config.training,'max_lr', None)}")


In [None]:
# Optional: override some config values here.
# The values below favour faster experimentation; edit entries as needed.
# NOTE(review): with OneCycleLR, confirm whether `learning_rate` is used at all
# or whether only `max_lr` drives the schedule -- that lives in the trainer.
TRAINING_OVERRIDES = {
    'max_epochs': 30,
    'scheduler_type': 'OneCycleLR',
    'learning_rate': 0.003,
    'max_lr': 0.2,
    'post_target_extra_epochs': 3,
    'target_test_accuracy': 85.0,
}
for field_name, value in TRAINING_OVERRIDES.items():
    setattr(config.training, field_name, value)

# Print a quick summary of the effective configuration.
from config import print_config
print_config(config)


In [None]:
# Data: transforms -> datasets -> loaders, all built by the project's data manager.
# Dataset mean/std can be recomputed with data_manager.calculate_dataset_statistics();
# that step is skipped here for speed.
data_manager = create_data_manager(config.data)

train_t, test_t = data_manager.create_transforms()
train_ds, test_ds = data_manager.load_datasets(train_t, test_t)
train_loader, test_loader = data_manager.create_data_loaders(train_ds, test_ds)

# len() of each loader -- presumably the number of batches per split.
loader_lengths = (len(train_loader), len(test_loader))
print(*loader_lengths)


In [None]:
# Model and trainer: build a fresh model on `device` and train it.
# The target-accuracy / extra-epoch arguments come straight from config; their
# exact stopping semantics are defined inside trainer.train.
model = create_model(config.model).to(device)
trainer = create_trainer(model, config.training, device)

train_kwargs = dict(
    max_epochs=config.training.max_epochs,
    target_test_acc=config.training.target_test_accuracy,
    post_target_extra_epochs=config.training.post_target_extra_epochs,
)
metrics = trainer.train(train_loader, test_loader, **train_kwargs)

# Summarise the best epoch; missing keys fall back to NaN / -1.
best = metrics.get_best_metrics()
print("\nTraining completed!")
best_acc = best.get('best_test_accuracy', float('nan'))
best_epoch = best.get('best_epoch', -1)
print(f"Best test accuracy: {best_acc:.2f}%")
print(f"Best epoch: {best_epoch}")


In [None]:
# Visualize curves recorded by the trainer (loss, accuracy, learning rate).
visualizer = create_visualizer(config.visualization)

# Set to a path such as 'training_curves.png' to save the figure; None = don't save.
CURVES_SAVE_PATH = None

# NOTE(review): arguments are positional (the save path is 4th); consider keyword
# arguments once plot_training_curves' parameter names are confirmed.
visualizer.plot_training_curves(
    metrics.train_losses,
    metrics.train_accuracies,
    metrics.learning_rates,
    CURVES_SAVE_PATH,
    metrics.test_accuracies,
    metrics.test_losses,
)



In [None]:
# Optional: visualize per-class accuracy and misclassified images
# NOTE(review): CIFAR10DataManager is not used in this cell; kept in case later cells rely on it.
from src.data.data_manager import CIFAR10DataManager
class_names = train_ds.classes

# Switch to inference mode before evaluating on the test set. trainer.train may
# leave the model in training mode; there, dropout stays active and (if the model
# has BatchNorm) running statistics get updated from test batches, which skews
# the numbers and silently mutates the trained model.
# Harmless if the visualizer already calls model.eval() itself.
model.eval()

# Per-class accuracy (named variable instead of `_`, which would clobber
# IPython's last-output history and suppress nothing extra).
per_class_accuracy = visualizer.plot_per_class_accuracy(model, test_loader, class_names, device)

# Misclassified images
visualizer.analyze_misclassified_images(model, test_loader, class_names, device=device, num_images=16)
