# Alphanumeric CNN Training Notebook

This notebook lets you interactively train the CNN model on different datasets (MNIST and the EMNIST splits).

In [None]:
%load_ext autoreload
%autoreload 2

import os
import time
import torch
import torch.nn as nn
import torch.optim as optim
from tqdm.notebook import tqdm
import matplotlib.pyplot as plt

# Import local modules
# Ensure these files are in the same directory
from model_extended import create_model, get_model_summary, get_class_info
from data_loader import load_dataset
from utils import get_device, plot_training_history, plot_confusion_matrix

In [None]:
# === Configuration ===

# Dataset Options: 'mnist' (digits), 'letters' (A-Z), 'balanced' (alphanumeric), 'byclass' (full)
DATASET_TYPE = 'mnist'

# Training Hyperparameters
BATCH_SIZE = 64       # samples per mini-batch
EPOCHS = 5            # full passes over the training set
LEARNING_RATE = 1e-3  # initial Adam learning rate

# Paths: checkpoints go to SAVE_DIR, other artefacts to OUTPUT_DIR
SAVE_DIR = './models'
OUTPUT_DIR = './outputs'

# Setup: create both directories up front (exist_ok makes re-running the cell safe)
for _dir in (SAVE_DIR, OUTPUT_DIR):
    os.makedirs(_dir, exist_ok=True)

DEVICE = get_device()
print(f"Using Setup: {DEVICE}, Dataset: {DATASET_TYPE}")

## 1. Load Dataset

In [None]:
# Build the train/test loaders for the configured dataset; the third value
# returned by load_dataset is not used in this notebook.
train_loader, test_loader, _ = load_dataset(
    batch_size=BATCH_SIZE,
    dataset_type=DATASET_TYPE,
)

## 2. Initialize Model

In [None]:
# Instantiate the architecture for the selected dataset and move it to DEVICE.
model = create_model(DATASET_TYPE).to(DEVICE)

# Print a summary of the freshly created model.
print(get_model_summary(model, DATASET_TYPE))

## 3. Training Logic
We define the training and validation steps here so that they integrate with Jupyter's `tqdm` progress bars.

In [None]:
def train_epoch(model, loader, criterion, optimizer, device):
    """Run one training epoch over ``loader`` and return aggregate metrics.

    The model is put into training mode. For every batch the gradients are
    reset, the loss is back-propagated and ``optimizer`` takes one step. A
    notebook ``tqdm`` bar shows the latest batch loss and the running accuracy.

    Args:
        model: The network being trained.
        loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        optimizer: Optimizer over ``model``'s parameters.
        device: Device each batch is moved to before the forward pass.

    Returns:
        ``(mean_loss, accuracy_percent)``: the sample-weighted mean loss over
        the epoch and the percentage of correctly classified samples.

    Raises:
        ValueError: If ``loader`` yields no samples (the metrics would
            otherwise divide by zero).
    """
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    pbar = tqdm(loader, desc='Training', leave=False)
    for images, labels in pbar:
        images, labels = images.to(device), labels.to(device)

        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)

        loss.backward()
        optimizer.step()

        # .item() copies the scalar back to the host; do it once per batch.
        batch_loss = loss.item()
        # Weight by batch size so a smaller final batch does not skew the mean.
        running_loss += batch_loss * images.size(0)
        # Class prediction = index of the highest score along dim 1.
        predicted = outputs.argmax(dim=1)
        total += labels.size(0)
        correct += predicted.eq(labels).sum().item()

        # Update progress bar
        pbar.set_postfix({'loss': f'{batch_loss:.4f}', 'acc': f'{100. * correct / total:.2f}%'})

    if total == 0:
        raise ValueError("train_epoch: loader yielded no samples")
    return running_loss / total, 100. * correct / total

def validate(model, loader, criterion, device):
    """Evaluate ``model`` on ``loader`` without updating its parameters.

    Runs in eval mode under ``torch.no_grad()`` and collects every label and
    prediction so the caller can build a confusion matrix.

    Args:
        model: The network being evaluated.
        loader: Iterable yielding ``(images, labels)`` batches.
        criterion: Loss function called as ``criterion(outputs, labels)``.
        device: Device each batch is moved to before the forward pass.

    Returns:
        ``(mean_loss, accuracy_percent, all_labels, all_preds)``. The two lists
        are filled in loader order from each batch's ``.cpu().numpy()`` values.

    Raises:
        ValueError: If ``loader`` yields no samples (the metrics would
            otherwise divide by zero).
    """
    model.eval()
    running_loss = 0.0
    correct = 0
    total = 0
    all_preds = []
    all_labels = []

    with torch.no_grad():
        for images, labels in tqdm(loader, desc='Validating', leave=False):
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Weight by batch size so a smaller final batch does not skew the mean.
            running_loss += loss.item() * images.size(0)
            # Presumably outputs are (batch, num_classes) scores; reduce over dim 1.
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += predicted.eq(labels).sum().item()

            all_preds.extend(predicted.cpu().numpy())
            all_labels.extend(labels.cpu().numpy())

    if total == 0:
        raise ValueError("validate: loader yielded no samples")
    return running_loss / total, 100. * correct / total, all_labels, all_preds

## 4. Run Training

In [None]:
# Loss, optimizer and LR schedule: halve the learning rate when the validation
# loss has not improved for `patience` (2) consecutive epochs.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)
# NOTE: ReduceLROnPlateau's `verbose` argument is deprecated since PyTorch 2.2,
# so learning-rate reductions are reported manually inside the loop instead.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=2
)

history = {'train_loss': [], 'train_acc': [], 'val_loss': [], 'val_acc': []}
# Start below any reachable accuracy so the first epoch always writes a checkpoint.
best_acc = float('-inf')
model_path = os.path.join(SAVE_DIR, f'{DATASET_TYPE}_notebook_best.pth')
# Defined up front so the visualisation cell still finds them if EPOCHS < 1.
y_true, y_pred = [], []

print(f"🚀 Starting training for {EPOCHS} epochs...")
start_time = time.perf_counter()

for epoch in range(1, EPOCHS + 1):
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, DEVICE)
    # y_true / y_pred are overwritten every epoch: after the loop they hold the
    # final epoch's predictions, not necessarily those of the saved best model.
    val_loss, val_acc, y_true, y_pred = validate(model, test_loader, criterion, DEVICE)

    lr_before = optimizer.param_groups[0]['lr']
    scheduler.step(val_loss)
    lr_after = optimizer.param_groups[0]['lr']

    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)

    print(f"Epoch {epoch}/{EPOCHS} | Train Loss: {train_loss:.4f} Acc: {train_acc:.2f}% | Val Loss: {val_loss:.4f} Acc: {val_acc:.2f}%")
    if lr_after < lr_before:
        print(f"  ↓ Learning rate reduced from {lr_before:.2e} to {lr_after:.2e}")

    if val_acc > best_acc:
        best_acc = val_acc
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'val_acc': val_acc,
            'dataset_type': DATASET_TYPE
        }, model_path)
        print(f"  ✓ Saved best model to {model_path} ({val_acc:.2f}%)")

total_time = (time.perf_counter() - start_time) / 60
print(f"\n✨ Training Complete in {total_time:.2f} minutes!")

## 5. Visualizations

In [None]:
# Visualise the per-epoch train/val loss and accuracy recorded in `history`.
plot_training_history(history, show=True)
# y_true / y_pred come from the last `validate` call in the training loop, so this
# reflects the final epoch's model, not necessarily the saved best checkpoint.
plot_confusion_matrix(y_true, y_pred, show=True)