# Model Training Basics for AG News Classification

## Overview

This tutorial demonstrates fundamental model training concepts following methodologies from:
- Goodfellow et al. (2016): "Deep Learning"
- Smith (2018): "A Disciplined Approach to Neural Network Hyper-Parameters"
- Howard & Ruder (2018): "Universal Language Model Fine-tuning for Text Classification"

### Learning Objectives
1. Understand the training loop architecture
2. Implement basic model training with transformers
3. Apply optimization strategies and learning rate scheduling
4. Monitor training progress and prevent overfitting
5. Save and load model checkpoints

Author: Võ Hải Dũng  
Email: vohaidung.work@gmail.com  
Date: 2025

## 1. Setup and Imports

In [None]:
# Standard library imports
import os
import sys
import json
import time
from pathlib import Path
from typing import Dict, List, Tuple, Optional
from dataclasses import dataclass
import warnings
warnings.filterwarnings('ignore')

# Data manipulation
import numpy as np
import pandas as pd

# PyTorch
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR

# Transformers
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    get_linear_schedule_with_warmup
)

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.notebook import tqdm

# Project setup
PROJECT_ROOT = Path("../..").resolve()
sys.path.insert(0, str(PROJECT_ROOT))

# Project imports
from src.data.datasets.ag_news import AGNewsDataset, AGNewsConfig
from src.models.base.base_model import BaseModel
from src.training.trainers.base_trainer import BaseTrainer
from src.evaluation.metrics.classification_metrics import ClassificationMetrics
from src.utils.reproducibility import set_seed
from src.utils.logging_config import setup_logging
from configs.constants import (
    AG_NEWS_NUM_CLASSES,
    DATA_DIR,
    MODEL_DIR
)

# Global setup: fix the seed for reproducibility, then choose the compute device.
set_seed(42)
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

print(f"PyTorch version: {torch.__version__}")
print(f"Device: {device}")
if device.type == 'cuda':
    bytes_per_gib = 1024**3
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / bytes_per_gib:.1f} GB")

## 2. Prepare Data

In [None]:
# Load dataset
dataset_config = AGNewsConfig(
    data_dir=DATA_DIR / "processed",
    max_samples=1000,  # Use subset for tutorial
    use_cache=True
)


class SyntheticDataset(Dataset):
    """
    Fallback dataset with placeholder texts and uniformly random labels.

    Used only when the processed AG News data cannot be loaded, so the rest
    of the tutorial can still run end to end. Each item is a dict with a
    'text' string and a 'label' tensor element.
    """

    def __init__(self, size=1000):
        self.texts = [f"Sample text {i}" for i in range(size)]
        # Labels drawn uniformly from [0, AG_NEWS_NUM_CLASSES)
        self.labels = torch.randint(0, AG_NEWS_NUM_CLASSES, (size,))

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        return {
            'text': self.texts[idx],
            'label': self.labels[idx]
        }


try:
    train_dataset = AGNewsDataset(dataset_config, split="train")
    val_dataset = AGNewsDataset(dataset_config, split="validation")
except Exception as exc:
    # Catch Exception rather than using a bare except, so KeyboardInterrupt and
    # SystemExit still propagate. Also say why the real data was unavailable.
    print("Using synthetic data for demonstration")
    print(f"  Reason: {type(exc).__name__}: {exc}")
    train_dataset = SyntheticDataset(800)
    val_dataset = SyntheticDataset(200)

# Report sizes for whichever dataset (real or synthetic) is in use
print(f"Train samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")

# Create data loaders
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)

print(f"\nDataLoaders created:")
print(f"  Train batches: {len(train_loader)}")
print(f"  Validation batches: {len(val_loader)}")

## 3. Simple Model Implementation

In [None]:
class SimpleClassifier(nn.Module):
    """
    Three-layer MLP classifier used to illustrate the training loop.

    Layout: Linear(input_dim -> hidden_dim) -> ReLU -> Dropout ->
    Linear(hidden_dim -> hidden_dim // 2) -> ReLU -> Dropout ->
    Linear(hidden_dim // 2 -> num_classes). The output is raw logits.
    """

    def __init__(self, input_dim: int, hidden_dim: int, num_classes: int, dropout: float = 0.1):
        """
        Build the layers and apply Xavier initialization.

        Args:
            input_dim: Size of each input feature vector
            hidden_dim: Width of the first hidden layer (the second is half of it)
            num_classes: Number of output logits
            dropout: Dropout probability applied after each hidden layer
        """
        super().__init__()

        bottleneck = hidden_dim // 2
        # Registration order (fc1, dropout1, fc2, dropout2, fc3) determines the
        # order of self.modules(), and therefore the seeded weight initialization.
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.dropout1 = nn.Dropout(dropout)
        self.fc2 = nn.Linear(hidden_dim, bottleneck)
        self.dropout2 = nn.Dropout(dropout)
        self.fc3 = nn.Linear(bottleneck, num_classes)

        self._init_weights()

    def _init_weights(self):
        """Apply Xavier-normal weights and zero biases to every nn.Linear submodule."""
        linear_layers = (m for m in self.modules() if isinstance(m, nn.Linear))
        for layer in linear_layers:
            nn.init.xavier_normal_(layer.weight)
            if layer.bias is not None:
                nn.init.zeros_(layer.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Map a batch of feature vectors to class logits.

        Args:
            x: Input tensor [batch_size, input_dim]

        Returns:
            Logits tensor [batch_size, num_classes]
        """
        hidden_stages = ((self.fc1, self.dropout1), (self.fc2, self.dropout2))
        for linear, drop in hidden_stages:
            x = drop(F.relu(linear(x)))
        return self.fc3(x)

# Instantiate the demo MLP on 768-dim inputs and place it on the chosen device
simple_model = SimpleClassifier(
    input_dim=768,  # BERT embedding dimension
    hidden_dim=256,
    num_classes=AG_NEWS_NUM_CLASSES,
    dropout=0.1,
).to(device)

# Parameter counts: everything vs. only tensors with requires_grad set
total_params = sum(param.numel() for param in simple_model.parameters())
trainable_params = sum(
    param.numel() for param in simple_model.parameters() if param.requires_grad
)

print(
    "Model created:",
    f"  Total parameters: {total_params:,}",
    f"  Trainable parameters: {trainable_params:,}",
    "\nModel architecture:",
    simple_model,
    sep="\n",
)

## 4. Training Loop Implementation

In [None]:
@dataclass
class TrainingConfig:
    """Hyper-parameters consumed by the simple-classifier training loop."""
    num_epochs: int = 3              # passes over the training loader
    learning_rate: float = 2.0e-5    # AdamW lr, also the peak of the warmup schedule
    warmup_steps: int = 100          # scheduler warmup length, in optimizer steps
    weight_decay: float = 1e-2       # AdamW weight decay
    gradient_clip_norm: float = 1.0  # max_norm passed to clip_grad_norm_
    log_interval: int = 10           # refresh the progress-bar postfix every N batches

def train_epoch(model, dataloader, optimizer, scheduler, device, config):
    """
    Train model for one epoch.

    The optimizer and the scheduler are each stepped once per batch. Gradients
    are clipped to ``config.gradient_clip_norm`` before every step. The
    reported loss is a per-sample mean: each batch's mean loss is weighted by
    its size, so a short final batch does not bias the epoch average.

    Args:
        model: Model to train
        dataloader: Training data loader; each batch is a dict with a 'label' tensor
        optimizer: Optimizer
        scheduler: Learning rate scheduler (stepped every batch)
        device: Device to use
        config: Training configuration (reads gradient_clip_norm, log_interval)

    Returns:
        Dictionary with 'loss' (per-sample mean cross-entropy) and 'accuracy'
        (percent). Both are NaN when the loader yields no samples.
    """
    model.train()
    total_loss = 0.0
    correct = 0
    total = 0

    progress_bar = tqdm(dataloader, desc="Training")

    for batch_idx, batch in enumerate(progress_bar):
        # Prepare batch (simplified for demo)
        # In real scenario, tokenize texts here
        inputs = torch.randn(len(batch['label']), 768).to(device)  # Dummy embeddings
        labels = batch['label'].to(device)

        # Forward pass
        optimizer.zero_grad()
        outputs = model(inputs)
        loss = F.cross_entropy(outputs, labels)

        # Backward pass
        loss.backward()

        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), config.gradient_clip_norm)

        # Optimizer step
        optimizer.step()
        scheduler.step()

        # Accuracy is measured on this batch's pre-step outputs (no extra forward)
        _, predicted = torch.max(outputs, 1)
        batch_size = labels.size(0)
        correct += (predicted == labels).sum().item()
        total += batch_size

        # One host sync per batch; undo the batch mean so the epoch mean is per-sample
        batch_loss = loss.item()
        total_loss += batch_loss * batch_size

        # Update progress bar
        if batch_idx % config.log_interval == 0:
            progress_bar.set_postfix({
                'loss': f'{batch_loss:.4f}',
                'acc': f'{100 * correct / total:.2f}%',
                'lr': f'{scheduler.get_last_lr()[0]:.2e}'
            })

    if total == 0:
        # Empty loader: avoid ZeroDivisionError and make the absence explicit
        return {'loss': float('nan'), 'accuracy': float('nan')}

    return {
        'loss': total_loss / total,
        'accuracy': 100 * correct / total
    }

def validate(model, dataloader, device):
    """
    Evaluate the model on a data loader without updating weights.

    Runs with the model in eval mode (dropout disabled) under
    ``torch.no_grad()``. The loss is a per-sample mean: each batch's mean loss
    is weighted by its size, so a short final batch does not bias the average.

    Args:
        model: Model to validate
        dataloader: Validation data loader; each batch is a dict with a 'label' tensor
        device: Device to use

    Returns:
        Dictionary with 'loss' (per-sample mean cross-entropy) and 'accuracy'
        (percent). Both are NaN when the loader yields no samples.
    """
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Validation"):
            # Prepare batch (simplified): random stand-in embeddings, replace
            # with real encoded inputs in a full pipeline
            inputs = torch.randn(len(batch['label']), 768).to(device)
            labels = batch['label'].to(device)

            # Forward pass
            outputs = model(inputs)
            loss = F.cross_entropy(outputs, labels)

            # Calculate accuracy
            _, predicted = torch.max(outputs, 1)
            batch_size = labels.size(0)
            correct += (predicted == labels).sum().item()
            total += batch_size

            # Undo the batch mean so the final division is per-sample
            total_loss += loss.item() * batch_size

    if total == 0:
        # Empty loader: avoid ZeroDivisionError and make the absence explicit
        return {'loss': float('nan'), 'accuracy': float('nan')}

    return {
        'loss': total_loss / total,
        'accuracy': 100 * correct / total
    }

# Setup training
config = TrainingConfig()

# AdamW over every parameter of the demo model
optimizer = AdamW(
    simple_model.parameters(),
    weight_decay=config.weight_decay,
    lr=config.learning_rate,
)

# train_epoch steps the scheduler once per batch, so the schedule spans
# (batches per epoch) x (number of epochs) steps.
total_steps = config.num_epochs * len(train_loader)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_training_steps=total_steps,
    num_warmup_steps=config.warmup_steps,
)

print(
    "Training configuration:",
    f"  Epochs: {config.num_epochs}",
    f"  Learning rate: {config.learning_rate}",
    f"  Total steps: {total_steps}",
    f"  Warmup steps: {config.warmup_steps}",
    sep="\n",
)

## 5. Training Execution

In [None]:
# Training history: one entry per epoch for each metric
history = {
    'train_loss': [],
    'train_acc': [],
    'val_loss': [],
    'val_acc': []
}

# Best model tracking
best_val_acc = 0
best_model_state = None

print("Starting training...")
print("=" * 50)

for epoch in range(config.num_epochs):
    print(f"\nEpoch {epoch + 1}/{config.num_epochs}")

    # Train
    train_metrics = train_epoch(
        simple_model, train_loader, optimizer, scheduler, device, config
    )

    # Validate
    val_metrics = validate(simple_model, val_loader, device)

    # Update history
    history['train_loss'].append(train_metrics['loss'])
    history['train_acc'].append(train_metrics['accuracy'])
    history['val_loss'].append(val_metrics['loss'])
    history['val_acc'].append(val_metrics['accuracy'])

    # Save best model. The first epoch is always recorded, so a snapshot exists
    # even if validation accuracy never rises above the initial 0.
    # state_dict() tensors share storage with the live parameters, so a shallow
    # dict.copy() would be overwritten in place by later optimizer steps.
    # Clone each tensor to freeze this epoch's weights.
    if best_model_state is None or val_metrics['accuracy'] > best_val_acc:
        best_val_acc = val_metrics['accuracy']
        best_model_state = {
            name: tensor.detach().clone()
            for name, tensor in simple_model.state_dict().items()
        }

    # Print metrics
    print(f"Train Loss: {train_metrics['loss']:.4f}, Train Acc: {train_metrics['accuracy']:.2f}%")
    print(f"Val Loss: {val_metrics['loss']:.4f}, Val Acc: {val_metrics['accuracy']:.2f}%")

print("\nTraining completed!")
print(f"Best validation accuracy: {best_val_acc:.2f}%")

## 6. Visualize Training Progress

In [None]:
# Plot training history.
# The x-axis uses 1-based epoch numbers so it matches the "Epoch k/N" log lines.
epoch_numbers = list(range(1, len(history['train_loss']) + 1))
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))

# Loss plot
ax1.plot(epoch_numbers, history['train_loss'], label='Train Loss', marker='o')
ax1.plot(epoch_numbers, history['val_loss'], label='Val Loss', marker='s')
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Loss')
ax1.set_title('Training and Validation Loss')
ax1.set_xticks(epoch_numbers)  # epochs are discrete; avoid fractional ticks
ax1.legend()
ax1.grid(True, alpha=0.3)

# Accuracy plot
ax2.plot(epoch_numbers, history['train_acc'], label='Train Acc', marker='o')
ax2.plot(epoch_numbers, history['val_acc'], label='Val Acc', marker='s')
ax2.set_xlabel('Epoch')
ax2.set_ylabel('Accuracy (%)')
ax2.set_title('Training and Validation Accuracy')
ax2.set_xticks(epoch_numbers)
ax2.legend()
ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

# Analyze training behavior
print("Training Analysis:")
print("=" * 50)

if not history['train_acc']:
    # Guard: indexing [-1] on an empty history would raise IndexError
    print("No epochs were recorded; nothing to analyze.")
else:
    # Check for overfitting via the final train-minus-validation accuracy gap
    # (in percentage points)
    final_train_acc = history['train_acc'][-1]
    final_val_acc = history['val_acc'][-1]
    gap = final_train_acc - final_val_acc

    if gap > 10:
        print(f"Warning: Potential overfitting detected (gap: {gap:.2f}%)")
        print("Recommendations:")
        print("  - Increase dropout")
        print("  - Add regularization")
        print("  - Use data augmentation")
        print("  - Reduce model complexity")
    elif gap < 2:
        print(f"Model is well-generalized (gap: {gap:.2f}%)")
        print("Could potentially increase model capacity")
    else:
        print(f"Reasonable generalization (gap: {gap:.2f}%)")

## 7. Transformer Model Setup

In [None]:
# Load a pre-trained transformer model
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# DistilBERT keeps the demo small; any sequence-classification checkpoint works here
model_name = "distilbert-base-uncased"

print(f"Loading {model_name}...")

# Tokenizer plus a classification head sized for AG News, moved to the device
tokenizer = AutoTokenizer.from_pretrained(model_name)
transformer_model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    num_labels=AG_NEWS_NUM_CLASSES,
).to(device)

# Parameter counts: everything vs. only tensors with requires_grad set
total_params = sum(param.numel() for param in transformer_model.parameters())
trainable_params = sum(
    param.numel() for param in transformer_model.parameters() if param.requires_grad
)

print(
    "\nTransformer model loaded:",
    f"  Total parameters: {total_params:,}",
    f"  Trainable parameters: {trainable_params:,}",
    sep="\n",
)

# Training configuration for transformer
@dataclass
class TransformerTrainingConfig:
    """Fine-tuning hyper-parameters for the transformer demo (printed below)."""
    num_epochs: int = 2
    learning_rate: float = 5.0e-5
    warmup_ratio: float = 0.1   # presumably warmup steps as a fraction of total steps — not consumed in this notebook
    weight_decay: float = 1e-2
    max_length: int = 128       # reported as "Max sequence length"; presumably the tokenizer truncation length
    batch_size: int = 8

transformer_config = TransformerTrainingConfig()

# Echo the transformer settings, one per line
print(
    "\nTraining configuration:",
    f"  Epochs: {transformer_config.num_epochs}",
    f"  Learning rate: {transformer_config.learning_rate}",
    f"  Batch size: {transformer_config.batch_size}",
    f"  Max sequence length: {transformer_config.max_length}",
    sep="\n",
)

## 8. Model Checkpointing

In [None]:
class ModelCheckpoint:
    """
    Save and restore training checkpoints for a single model.

    Two files are kept under ``save_dir``:
    ``{model_name}_latest.pt`` is rewritten on every save, and
    ``{model_name}_best.pt`` is rewritten only when ``is_best=True``.
    Each file holds a dict with the keys 'epoch', 'model_state_dict',
    'optimizer_state_dict', 'scheduler_state_dict' (None when no scheduler
    was given) and 'metrics'.
    """

    def __init__(self, save_dir: Path, model_name: str):
        """
        Initialize checkpoint manager.

        Args:
            save_dir: Directory to save checkpoints (created if missing)
            model_name: Name of the model, used as the checkpoint file prefix
        """
        self.save_dir = Path(save_dir)
        self.save_dir.mkdir(parents=True, exist_ok=True)
        self.model_name = model_name
        # Metric recorded by the most recent "best" save; -inf until then
        self.best_metric = float('-inf')

    def _path(self, checkpoint_type: str) -> Path:
        """Return the file path for a checkpoint type such as 'best' or 'latest'."""
        return self.save_dir / f"{self.model_name}_{checkpoint_type}.pt"

    def save(self, model, optimizer, scheduler, epoch, metrics, is_best=False):
        """
        Save model checkpoint.

        Args:
            model: Model to save
            optimizer: Optimizer whose state is saved
            scheduler: Scheduler whose state is saved, or None
            epoch: Current epoch
            metrics: Dict of current metrics; 'val_accuracy' (else 'val_loss',
                else 0) becomes ``best_metric`` when ``is_best`` is True
            is_best: Whether this is the best model so far
        """
        checkpoint = {
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict() if scheduler is not None else None,
            'metrics': metrics
        }

        # Save latest checkpoint
        torch.save(checkpoint, self._path('latest'))

        # Save best checkpoint if applicable
        if is_best:
            torch.save(checkpoint, self._path('best'))
            self.best_metric = metrics.get('val_accuracy', metrics.get('val_loss', 0))
            print(f"Saved best model with metric: {self.best_metric:.4f}")

    def load(self, model, optimizer=None, scheduler=None, checkpoint_type='best',
             map_location=None):
        """
        Load model checkpoint.

        Args:
            model: Model to load weights into
            optimizer: Optimizer to restore, or None to skip
            scheduler: Scheduler to restore, or None to skip
            checkpoint_type: 'best' or 'latest'
            map_location: Forwarded to ``torch.load``. When None, the
                notebook-level ``device`` is used.

        Returns:
            The checkpoint dict, or None if the file does not exist.
        """
        checkpoint_path = self._path(checkpoint_type)

        if not checkpoint_path.exists():
            print(f"No checkpoint found at {checkpoint_path}")
            return None

        if map_location is None:
            map_location = device  # notebook-level device from the setup cell
        # NOTE: torch.load unpickles data; only load checkpoints from trusted sources.
        checkpoint = torch.load(checkpoint_path, map_location=map_location)

        model.load_state_dict(checkpoint['model_state_dict'])

        # save() always writes these keys but may store None (e.g. no scheduler),
        # so check the value itself rather than key presence before restoring.
        optimizer_state = checkpoint.get('optimizer_state_dict')
        if optimizer is not None and optimizer_state is not None:
            optimizer.load_state_dict(optimizer_state)

        scheduler_state = checkpoint.get('scheduler_state_dict')
        if scheduler is not None and scheduler_state is not None:
            scheduler.load_state_dict(scheduler_state)

        print(f"Loaded checkpoint from epoch {checkpoint['epoch']}")
        print(f"Metrics: {checkpoint['metrics']}")

        return checkpoint

# Persist the best weights found during training through the checkpoint manager
checkpoint_manager = ModelCheckpoint(
    save_dir=MODEL_DIR / "checkpoints",
    model_name="simple_classifier",
)

if best_model_state:
    # Restore the best weights into the model before writing them out.
    # NOTE(review): the optimizer/scheduler states and `epoch` reflect the end
    # of training, not the best epoch — fine for inference, misleading for resuming.
    simple_model.load_state_dict(best_model_state)
    checkpoint_manager.save(
        simple_model,
        optimizer,
        scheduler,
        epoch=config.num_epochs,
        metrics={'val_accuracy': best_val_acc},
        is_best=True,
    )
    print(f"\nBest model saved to {checkpoint_manager.save_dir}")

## 9. Training Best Practices

In [None]:
# Checklist of training best practices, grouped by topic (printed below)
best_practices = {
    "Data Preparation": [
        "Use stratified splits to maintain class balance",
        "Apply appropriate data augmentation",
        "Normalize/standardize inputs consistently",
        "Use proper train/val/test splits",
    ],
    "Model Selection": [
        "Start with pre-trained models when available",
        "Choose architecture based on task requirements",
        "Consider model size vs performance trade-offs",
        "Use ensemble methods for better performance",
    ],
    "Training Configuration": [
        "Use learning rate scheduling (warmup + decay)",
        "Apply gradient clipping for stability",
        "Use mixed precision training for efficiency",
        "Monitor multiple metrics during training",
    ],
    "Optimization": [
        "Use AdamW optimizer for transformers",
        "Apply weight decay for regularization",
        "Tune learning rate as primary hyperparameter",
        "Use gradient accumulation for larger effective batch sizes",
    ],
    "Monitoring": [
        "Track training and validation metrics",
        "Use early stopping to prevent overfitting",
        "Save checkpoints regularly",
        "Log experiments for reproducibility",
    ],
    "Debugging": [
        "Start with small data subset",
        "Verify model can overfit single batch",
        "Check gradient flow and magnitudes",
        "Visualize predictions on samples",
    ],
}

print("Training Best Practices")
print("=" * 50)

for category in best_practices:
    print(f"\n{category}:")
    print("\n".join(
        f"  {number}. {tip}"
        for number, tip in enumerate(best_practices[category], start=1)
    ))

## 10. Summary and Next Steps

In [None]:
# Training summary
print("Training Tutorial Summary")
print("=" * 50)

print(
    "\nWhat we covered:",
    "1. Basic model architecture implementation",
    "2. Training loop with optimization",
    "3. Learning rate scheduling strategies",
    "4. Model checkpointing and saving",
    "5. Training monitoring and visualization",
    "6. Best practices for model training",
    sep="\n",
)

print(
    "\nKey Takeaways:",
    "- Start simple and gradually increase complexity",
    "- Monitor training closely to detect issues early",
    "- Use pre-trained models when possible",
    "- Apply regularization to prevent overfitting",
    "- Save checkpoints for model recovery",
    "- Document experiments for reproducibility",
    sep="\n",
)

print(
    "\nNext Steps:",
    "1. Train full transformer model on complete dataset",
    "2. Implement advanced training strategies (adversarial, curriculum)",
    "3. Explore model ensembling techniques",
    "4. Optimize hyperparameters systematically",
    "5. Deploy trained model for inference",
    sep="\n",
)

# Persist a JSON report of this run (config, per-epoch history, best score)
training_report = {
    'model_type': 'simple_classifier',
    'config': config.__dict__,
    'history': history,
    'best_val_accuracy': best_val_acc,
    'device': str(device),
    'timestamp': time.strftime('%Y-%m-%d %H:%M:%S'),
}

report_path = PROJECT_ROOT / "outputs" / "training" / "tutorial_report.json"
report_path.parent.mkdir(parents=True, exist_ok=True)

with report_path.open('w') as report_file:
    json.dump(training_report, report_file, indent=2)

print(f"\nTraining report saved to: {report_path}")