In [1]:
import torch
import lightning as L
import matplotlib.pyplot as plt
import numpy as np

from models.cnn import AdvancedMNISTCNN
from data.mnist import MNISTDataModule

In [2]:
# Configuration
# Single source of truth for the run: the model, the data module and the
# trainer cells below all read their settings from this mapping.
config = dict(
    # Optimisation / regularisation settings forwarded to the model.
    learning_rate=0.003,
    dropout=0.05,
    # Data loading and training-length settings.
    batch_size=32,
    max_epochs=20,
    num_workers=4,
    data_dir='../data',
    # Kept last so the printed mapping lists keys in the same order as before.
    weight_decay=0,
)

print(f"Config: {config}")

Config: {'learning_rate': 0.003, 'dropout': 0.05, 'batch_size': 32, 'max_epochs': 20, 'num_workers': 4, 'data_dir': '../data', 'weight_decay': 0}


In [3]:
# Seed every RNG the run may touch, not just PyTorch's global generator.
# `seed_everything` seeds Python's `random`, NumPy and torch in one call, and
# `workers=True` additionally derives per-worker seeds for the dataloader
# processes (the data module below is created with `num_workers` workers).
SEED = 42
L.seed_everything(SEED, workers=True)

# Model hyperparameters come straight from the config cell.
model = AdvancedMNISTCNN(
    learning_rate=config['learning_rate'],
    dropout=config['dropout'],
    weight_decay=config['weight_decay']
)

# Data pipeline settings come from the same config mapping.
data_module = MNISTDataModule(
    batch_size=config['batch_size'],
    data_dir=config['data_dir'],
    num_workers=config['num_workers']
)

# Count every element of every parameter tensor registered on the model.
total_params = sum(p.numel() for p in model.parameters())
print(f"Total parameters: {total_params}")

Total parameters: 19306


In [None]:
class MetricsLogger(L.Callback):
    """Record one train/validation metric sample per training epoch.

    The four public lists (`train_losses`, `train_accs`, `val_losses`,
    `val_accs`) always have the same length: exactly one entry is appended to
    each of them at the end of every training epoch. That invariant is what
    the plotting and summary cells rely on.

    Why validation metrics are staged instead of appended directly:
    Lightning invokes `on_validation_epoch_end` *before* `on_train_epoch_end`
    within an epoch, and also invokes it during the pre-fit sanity check and
    during a standalone `trainer.validate()` call. Appending (and printing the
    last training metric) from the validation hook therefore fails on an empty
    training history and leaves the validation lists longer than the training
    lists. Staging the latest validation result and committing it from
    `on_train_epoch_end` avoids both problems.

    If validation runs more than once in an epoch, only the last result is
    kept; if it does not run at all, NaN is recorded for that epoch.
    """

    def __init__(self):
        super().__init__()
        self.train_losses = []
        self.train_accs = []
        self.val_losses = []
        self.val_accs = []
        # (val_loss, val_acc) from the most recent validation run that has not
        # yet been attached to a training epoch; None when nothing is staged.
        self._pending_val = None

    def on_train_epoch_end(self, trainer, pl_module):
        """Commit this epoch's train and staged validation metrics, then print them.

        Args:
            trainer: The running Trainer; metrics are read from
                `trainer.logged_metrics` (missing keys default to 0).
            pl_module: The module being trained (unused).
        """
        logged = trainer.logged_metrics
        self.train_losses.append(float(logged.get('train_loss', 0)))
        self.train_accs.append(float(logged.get('train_acc', 0)))

        if self._pending_val is None:
            # No validation result was produced during this epoch.
            val_loss = val_acc = float('nan')
        else:
            val_loss, val_acc = self._pending_val
        # Clear the stage so a stale result is never reused for a later epoch.
        self._pending_val = None
        self.val_losses.append(val_loss)
        self.val_accs.append(val_acc)

        epoch = trainer.current_epoch + 1  # current_epoch is zero-based
        current_lr = trainer.optimizers[0].param_groups[0]['lr']

        # Read the epoch budget from the trainer itself rather than from a
        # notebook-level global, so the callback is self-contained.
        print(
            f"Epoch {epoch:2d}/{trainer.max_epochs} | "
            f"LR: {current_lr:.6f} | "
            f"Train Loss: {self.train_losses[-1]:.4f} | "
            f"Train Acc: {self.train_accs[-1]:.4f} | "
            f"Val Loss: {val_loss:.4f} | "
            f"Val Acc: {val_acc:.4f}"
        )

    def on_validation_epoch_end(self, trainer, pl_module):
        """Stage the latest validation metrics until the training epoch ends.

        Nothing is appended to the history here, so sanity-check runs and
        standalone `trainer.validate()` calls do not add entries.

        Args:
            trainer: The running Trainer; metrics are read from
                `trainer.logged_metrics` (missing keys default to 0).
            pl_module: The module being validated (unused).
        """
        logged = trainer.logged_metrics
        self._pending_val = (
            float(logged.get('val_loss', 0)),
            float(logged.get('val_acc', 0)),
        )

metrics_logger = MetricsLogger()

In [None]:
# Train the model
# Trainer options are collected in one mapping first so they can be inspected
# or tweaked before the Trainer is built.
trainer_options = {
    'max_epochs': config['max_epochs'],
    'accelerator': 'auto',
    'devices': 1,
    'deterministic': True,
    # The metrics callback records per-epoch history used by the cells below.
    'callbacks': [metrics_logger],
    'enable_progress_bar': True,
    'log_every_n_steps': 50,
}
trainer = L.Trainer(**trainer_options)

print("Starting training...")
trainer.fit(model, data_module)
print("Training completed!")

In [None]:
# Final validation
# Run one more validation pass with the trained weights and report it next to
# the best accuracy seen in the recorded history.
val_results = trainer.validate(model, data_module)
print(f"\nFinal validation results: {val_results}")

# The first entry of the results is indexed by metric name.
final_val_acc = val_results[0]['val_acc']
print(f"Final validation accuracy: {final_val_acc:.4f}")

best_val_acc = max(metrics_logger.val_accs)
print(f"Best validation accuracy: {best_val_acc:.4f}")

In [None]:
# Plot training curves
# The x-axis is the 1-based epoch index, sized from the training-loss history.
epochs = range(1, len(metrics_logger.train_losses) + 1)

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

# One spec per panel: (axes, quantity name, training series, validation series).
# Both panels share identical styling, so they are drawn by the same loop.
panel_specs = (
    (ax1, 'Loss', metrics_logger.train_losses, metrics_logger.val_losses),
    (ax2, 'Accuracy', metrics_logger.train_accs, metrics_logger.val_accs),
)

for axis, quantity, train_series, val_series in panel_specs:
    axis.plot(epochs, train_series, 'b-', label=f'Train {quantity}', linewidth=2)
    axis.plot(epochs, val_series, 'r-', label=f'Validation {quantity}', linewidth=2)
    axis.set_xlabel('Epoch')
    axis.set_ylabel(quantity)
    axis.set_title(f'Training and Validation {quantity}')
    axis.legend()
    axis.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
# Training summary
def _summary_lines():
    """Yield the report lines one at a time, in display order.

    Lines are produced lazily so each value is looked up only when its line is
    about to be printed.
    """
    heavy_rule = "=" * 60
    light_rule = "-" * 60

    # Banner, preceded by a blank line.
    yield "\n" + heavy_rule
    yield "TRAINING SUMMARY"
    yield heavy_rule

    # Run configuration.
    yield "Model: AdvancedMNISTCNN"
    yield f"Total Parameters: {total_params:,}"
    yield f"Epochs: {config['max_epochs']}"
    yield f"Batch Size: {config['batch_size']}"
    # NOTE(review): this schedule is a hard-coded string, not read from the
    # config or the optimizer; confirm it matches the model's LR scheduler.
    yield "Learning Rate Schedule: 0.003→0.001→0.0003→0.0001"
    yield f"Dropout: {config['dropout']}"
    yield f"Weight Decay: {config['weight_decay']}"
    yield light_rule

    # Outcome metrics taken from the recorded per-epoch history.
    yield f"Final Train Accuracy: {metrics_logger.train_accs[-1]:.4f}"
    yield f"Final Validation Accuracy: {metrics_logger.val_accs[-1]:.4f}"
    yield f"Best Validation Accuracy: {max(metrics_logger.val_accs):.4f}"
    yield f"Final Train Loss: {metrics_logger.train_losses[-1]:.6f}"
    yield f"Final Validation Loss: {metrics_logger.val_losses[-1]:.6f}"
    yield heavy_rule


for summary_line in _summary_lines():
    print(summary_line)