# Model Training Basics for AG News Classification

## Overview

This tutorial demonstrates fundamental model training concepts following methodologies from:
- Devlin et al. (2019): "BERT: Pre-training of Deep Bidirectional Transformers"
- Liu et al. (2019): "RoBERTa: A Robustly Optimized BERT Pretraining Approach"
- He et al. (2021): "DeBERTa: Decoding-enhanced BERT with Disentangled Attention"

### Tutorial Objectives
1. Load and prepare the AG News dataset
2. Configure transformer models
3. Implement training loops
4. Apply optimization strategies
5. Monitor training progress
6. Save and load checkpoints

Author: Võ Hải Dũng  
Email: vohaidung.work@gmail.com  
Date: 2025

## 1. Environment Setup and Imports

In [None]:
# Standard library imports
import os
import sys
import json
import warnings
from pathlib import Path
from typing import Dict, List, Tuple, Optional, Any
from dataclasses import dataclass
import time

# Data and ML imports
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    get_linear_schedule_with_warmup
)
from sklearn.metrics import accuracy_score, f1_score

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.notebook import tqdm

# Project imports
PROJECT_ROOT = Path("../..").resolve()
sys.path.insert(0, str(PROJECT_ROOT))

from src.data.datasets.ag_news import AGNewsDataset, AGNewsConfig
from src.data.loaders.dataloader import create_dataloader
from src.models.transformers.roberta.roberta_enhanced import RoBERTaEnhanced
from src.training.trainers.base_trainer import BaseTrainer
from src.training.callbacks.early_stopping import EarlyStopping
from src.training.callbacks.model_checkpoint import ModelCheckpoint
from src.utils.reproducibility import set_seed
from src.utils.memory_utils import get_memory_usage, clear_memory
from configs.constants import (
    AG_NEWS_CLASSES,
    AG_NEWS_NUM_CLASSES,
    DATA_DIR,
    MODEL_DIR
)
from configs.config_loader import ConfigLoader

# Configuration
warnings.filterwarnings('ignore')
set_seed(42)

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
if device.type == 'cuda':
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.2f} GB")
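The project's `set_seed` helper is not shown in this notebook. Assuming it follows the common pattern of seeding every random number generator the notebook touches, a hypothetical stand-in (`seed_everything` is an illustrative name, not the project's API) might look like this. The project's version may also configure cuDNN determinism:

```python
import random

import numpy as np
import torch


def seed_everything(seed: int) -> None:
    """Seed Python, NumPy and PyTorch RNGs (hypothetical stand-in for set_seed)."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)  # seeds the RNG on all devices, including CUDA


seed_everything(42)
first = (random.random(), float(np.random.rand()), torch.rand(1).item())
seed_everything(42)
second = (random.random(), float(np.random.rand()), torch.rand(1).item())
print(first == second)  # True
```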

## 2. Load Configuration

In [None]:
# Load training configuration
config_loader = ConfigLoader()
training_config = config_loader.load_config('training/standard/base_training.yaml')

@dataclass
class TrainingArgs:
    """Training arguments following Transformers library conventions."""
    model_name: str = "roberta-base"
    num_epochs: int = 3
    batch_size: int = 16
    learning_rate: float = 2e-5
    warmup_ratio: float = 0.1
    weight_decay: float = 0.01
    max_length: int = 256
    gradient_accumulation_steps: int = 1
    fp16: bool = torch.cuda.is_available()
    evaluation_steps: int = 500
    save_steps: int = 1000
    logging_steps: int = 100
    
# Override with loaded config
args = TrainingArgs(
    model_name=training_config.get('model_name', 'roberta-base'),
    num_epochs=training_config.get('num_epochs', 3),
    batch_size=training_config.get('batch_size', 16),
    learning_rate=training_config.get('learning_rate', 2e-5)
)

print("Training Configuration:")
for key, value in args.__dict__.items():
    print(f"  {key}: {value}")
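The override above copies only four keys from the YAML file, so values such as `warmup_ratio` or `max_length` in that file are silently ignored. A minimal sketch of a more general merge, assuming `training_config` behaves like a flat `dict`; `merge_config` and `DemoArgs` are hypothetical names used for illustration:

```python
from dataclasses import dataclass, fields
from typing import Any, Dict, Tuple, Type, TypeVar

T = TypeVar("T")


def merge_config(cls: Type[T], config: Dict[str, Any]) -> Tuple[T, Dict[str, Any]]:
    """Build dataclass `cls` from the keys of `config` that match its fields.

    Also returns the unrecognized keys, so typos in the YAML file surface
    instead of being dropped silently.
    """
    known = {f.name for f in fields(cls)}
    matched = {k: v for k, v in config.items() if k in known}
    unknown = {k: v for k, v in config.items() if k not in known}
    return cls(**matched), unknown


@dataclass
class DemoArgs:  # small stand-in for TrainingArgs
    model_name: str = "roberta-base"
    learning_rate: float = 2e-5
    max_length: int = 256


demo_args, extra = merge_config(
    DemoArgs, {"learning_rate": 3e-5, "max_length": 128, "lr_sched": "linear"}
)
print(demo_args)  # DemoArgs(model_name='roberta-base', learning_rate=3e-05, max_length=128)
print(extra)      # {'lr_sched': 'linear'}
```

In the notebook this would read `args, extra = merge_config(TrainingArgs, training_config)`, again assuming the loaded config is flat rather than nested.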

## 3. Load and Prepare Dataset

In [None]:
# Load datasets
print("Loading AG News dataset...")
data_config = AGNewsConfig(
    data_dir=DATA_DIR / "processed",
    max_length=args.max_length
)

train_dataset = AGNewsDataset(data_config, split="train")
val_dataset = AGNewsDataset(data_config, split="validation")

print(f"Train samples: {len(train_dataset):,}")
print(f"Validation samples: {len(val_dataset):,}")

# Sample data for quick training (optional for tutorial)
SAMPLE_SIZE = 5000  # Use smaller subset for tutorial
if len(train_dataset) > SAMPLE_SIZE:
    print(f"\nUsing {SAMPLE_SIZE} samples for tutorial (full dataset: {len(train_dataset)})")
    train_indices = np.random.choice(len(train_dataset), SAMPLE_SIZE, replace=False)
    train_dataset = torch.utils.data.Subset(train_dataset, train_indices)

# Initialize tokenizer
print(f"\nLoading tokenizer: {args.model_name}")
tokenizer = AutoTokenizer.from_pretrained(args.model_name)

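# Reference tokenization helper (unused below: train_epoch and evaluate
# tokenize each batch on the fly with dynamic padding instead)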
def tokenize_function(examples):
    """Tokenize text examples."""
    return tokenizer(
        examples,
        padding='max_length',
        truncation=True,
        max_length=args.max_length,
        return_tensors='pt'
    )

# Create data loaders
train_loader = DataLoader(
    train_dataset,
    batch_size=args.batch_size,
    shuffle=True,
    num_workers=2,
    pin_memory=True
)

val_loader = DataLoader(
    val_dataset,
    batch_size=args.batch_size * 2,
    shuffle=False,
    num_workers=2,
    pin_memory=True
)

print(f"\nDataLoaders created:")
print(f"  Train batches: {len(train_loader)}")
print(f"  Validation batches: {len(val_loader)}")
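Because `train_epoch` and `evaluate` tokenize inside the loop, tokenization runs in the main process even though the loaders use `num_workers=2`. A sketch of an alternative, assuming each dataset item is a `(text, label)` pair (which is what `texts, labels = batch` in the loop implies), moves tokenization into a `collate_fn` so the worker processes do it. `tokenize_collate` and `make_collate_fn` are hypothetical helpers:

```python
import functools
from typing import Callable, Dict, List, Tuple

import torch


def tokenize_collate(batch: List[Tuple[str, int]], tokenizer: Callable,
                     max_length: int) -> Dict[str, torch.Tensor]:
    """Collate (text, label) pairs into a tokenized, dynamically padded batch."""
    texts = [text for text, _ in batch]
    labels = torch.tensor([label for _, label in batch], dtype=torch.long)
    # Pad only to the longest sequence in this batch; truncate at max_length
    encoded = tokenizer(texts, padding=True, truncation=True,
                        max_length=max_length, return_tensors="pt")
    batch_dict = dict(encoded)
    batch_dict["labels"] = labels
    return batch_dict


def make_collate_fn(tokenizer: Callable, max_length: int) -> Callable:
    """Bind tokenizer settings; a partial of a top-level function can be pickled."""
    return functools.partial(tokenize_collate, tokenizer=tokenizer, max_length=max_length)
```

Usage would be `DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=2, collate_fn=make_collate_fn(tokenizer, args.max_length))`. Each batch is then a dict, so the loop would move its tensors to `device` and call `model(**batch)`. On platforms that start workers with `spawn` (Windows, macOS), functions defined inside a notebook may not be importable by the workers, so the helper would need to live in a module.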

## 4. Initialize Model

In [None]:
# Initialize model
print(f"Initializing model: {args.model_name}")

model = AutoModelForSequenceClassification.from_pretrained(
    args.model_name,
    num_labels=AG_NEWS_NUM_CLASSES,
    problem_type="single_label_classification"
)

# Move model to device
model = model.to(device)

# Model summary
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"\nModel Architecture:")
print(f"  Total parameters: {total_params:,}")
print(f"  Trainable parameters: {trainable_params:,}")
print(f"  Model size: {total_params * 4 / 1e9:.2f} GB (fp32)")

# Enable mixed precision if available
scaler = torch.cuda.amp.GradScaler() if args.fp16 else None
if scaler:
    print("  Mixed precision training: Enabled")
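The `Model size` line counts only the fp32 weights. Full fine-tuning with AdamW also stores a gradient and two moment buffers per parameter; PyTorch's native mixed precision keeps all of these in fp32 too. A rough back-of-the-envelope sketch that ignores activations (which grow with `batch_size` and `max_length`), using a round 125M parameters, roughly the size of roberta-base:

```python
def estimate_training_memory_gb(num_params: int) -> dict:
    """Rough static memory for fp32 full fine-tuning with AdamW (no activations)."""
    bytes_per_param = {
        "weights": 4,        # fp32 parameters
        "gradients": 4,      # one fp32 gradient per parameter
        "adamw_moments": 8,  # exp_avg and exp_avg_sq, 4 bytes each
    }
    breakdown = {name: num_params * b / 1e9 for name, b in bytes_per_param.items()}
    breakdown["total"] = sum(breakdown.values())
    return breakdown


print(estimate_training_memory_gb(125_000_000))
# {'weights': 0.5, 'gradients': 0.5, 'adamw_moments': 1.0, 'total': 2.0}
```

In the notebook, `estimate_training_memory_gb(total_params)` gives the corresponding figure for the loaded model.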

## 5. Setup Optimizer and Scheduler

In [None]:
# Calculate total training steps
total_steps = len(train_loader) * args.num_epochs
warmup_steps = int(total_steps * args.warmup_ratio)

print("Optimization Configuration:")
print(f"  Total steps: {total_steps}")
print(f"  Warmup steps: {warmup_steps}")
print(f"  Learning rate: {args.learning_rate}")
print(f"  Weight decay: {args.weight_decay}")

# Initialize optimizer
optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=args.learning_rate,
    weight_decay=args.weight_decay,
    betas=(0.9, 0.999),
    eps=1e-8
)

# Initialize learning rate scheduler
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=warmup_steps,
    num_training_steps=total_steps
)

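# Loss function, kept for reference: when `labels` are passed, the Hugging Face
# model computes the same cross-entropy internally and returns it as `outputs.loss`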
criterion = nn.CrossEntropyLoss()

print("\nOptimizer and scheduler initialized successfully")
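The optimizer above applies `weight_decay` to every parameter, including biases and LayerNorm weights, which many fine-tuning scripts exclude from decay. A sketch using one common heuristic (decay only tensors with two or more dimensions, i.e. weight matrices and embeddings); `build_param_groups` and `TinyBlock` are hypothetical names:

```python
import torch
import torch.nn as nn


def build_param_groups(model: nn.Module, weight_decay: float) -> list:
    """Split parameters into decay / no-decay groups for AdamW.

    Heuristic: 1-D tensors (biases, LayerNorm scale and shift) get no decay.
    """
    decay, no_decay = [], []
    for param in model.parameters():
        if not param.requires_grad:
            continue
        (decay if param.ndim >= 2 else no_decay).append(param)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]


class TinyBlock(nn.Module):  # toy stand-in for a transformer sub-layer
    def __init__(self):
        super().__init__()
        self.dense = nn.Linear(8, 8)
        self.norm = nn.LayerNorm(8)


groups = build_param_groups(TinyBlock(), weight_decay=0.01)
optimizer_demo = torch.optim.AdamW(groups, lr=2e-5)
print([len(g["params"]) for g in groups])  # [1, 3]
```

Applied to this notebook, the optimizer would become `torch.optim.AdamW(build_param_groups(model, args.weight_decay), lr=args.learning_rate)`.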

## 6. Training Functions

In [None]:
def train_epoch(model, loader, optimizer, scheduler, criterion, device, scaler=None):
    """
    Train model for one epoch.
    
    Following training practices from:
    - Loshchilov & Hutter (2019): "Decoupled Weight Decay Regularization"
    """
    model.train()
    total_loss = 0
    correct = 0
    total = 0
    
    progress_bar = tqdm(loader, desc="Training")
    
    for batch_idx, batch in enumerate(progress_bar):
        # Tokenize the raw texts (dynamic padding), then move tensors to device
        texts, labels = batch
        inputs = tokenizer(texts, 
                          padding=True, 
                          truncation=True, 
                          max_length=args.max_length, 
                          return_tensors='pt')
        
        input_ids = inputs['input_ids'].to(device)
        attention_mask = inputs['attention_mask'].to(device)
        labels = labels.to(device)
        
        # Forward pass with mixed precision
        if scaler:
            with torch.cuda.amp.autocast():
                outputs = model(input_ids=input_ids, 
                              attention_mask=attention_mask,
                              labels=labels)
                loss = outputs.loss
        else:
            outputs = model(input_ids=input_ids,
                          attention_mask=attention_mask,
                          labels=labels)
            loss = outputs.loss
        
        # Backward pass
        if scaler:
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            loss.backward()
            optimizer.step()
        
        scheduler.step()
        optimizer.zero_grad()
        
        # Calculate accuracy
        predictions = torch.argmax(outputs.logits, dim=-1)
        correct += (predictions == labels).sum().item()
        total += labels.size(0)
        total_loss += loss.item()
        
        # Update progress bar
        progress_bar.set_postfix({
            'loss': loss.item(),
            'acc': correct / total,
            'lr': optimizer.param_groups[0]['lr']
        })
    
    avg_loss = total_loss / len(loader)
    accuracy = correct / total
    
    return avg_loss, accuracy


@torch.no_grad()
def evaluate(model, loader, criterion, device):
    """
    Evaluate model on validation/test set.
    """
    model.eval()
    total_loss = 0
    all_predictions = []
    all_labels = []
    
    progress_bar = tqdm(loader, desc="Evaluating")
    
    for batch in progress_bar:
        texts, labels = batch
        inputs = tokenizer(texts,
                          padding=True,
                          truncation=True,
                          max_length=args.max_length,
                          return_tensors='pt')
        
        input_ids = inputs['input_ids'].to(device)
        attention_mask = inputs['attention_mask'].to(device)
        labels = labels.to(device)
        
        outputs = model(input_ids=input_ids,
                       attention_mask=attention_mask,
                       labels=labels)
        
        loss = outputs.loss
        total_loss += loss.item()
        
        predictions = torch.argmax(outputs.logits, dim=-1)
        all_predictions.extend(predictions.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())
    
    avg_loss = total_loss / len(loader)
    accuracy = accuracy_score(all_labels, all_predictions)
    f1 = f1_score(all_labels, all_predictions, average='macro')
    
    return avg_loss, accuracy, f1, all_predictions, all_labels

print("Training functions defined successfully")
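`evaluate` reports macro-averaged F1 via `f1_score(average='macro')`: F1 is computed per class from true positives, false positives and false negatives, then averaged without weighting by class frequency. A small pure-Python sketch of that computation, handy for checking the metric by hand:

```python
from typing import Sequence


def macro_f1(y_true: Sequence[int], y_pred: Sequence[int]) -> float:
    """Unweighted mean of per-class F1 over every class seen in y_true or y_pred."""
    classes = sorted(set(y_true) | set(y_pred))
    scores = []
    for c in classes:
        tp = sum(1 for y, p in zip(y_true, y_pred) if y == c and p == c)
        fp = sum(1 for y, p in zip(y_true, y_pred) if y != c and p == c)
        fn = sum(1 for y, p in zip(y_true, y_pred) if y == c and p != c)
        # F1 = 2TP / (2TP + FP + FN), algebraically equal to 2PR / (P + R)
        denom = 2 * tp + fp + fn
        scores.append(2 * tp / denom if denom else 0.0)
    return sum(scores) / len(scores)


y_true = [0, 0, 1, 1, 2, 2]
y_pred = [0, 1, 1, 1, 2, 0]
print(round(macro_f1(y_true, y_pred), 4))  # 0.6556
```

Here the per-class scores are 0.5, 0.8 and 2/3; on the same inputs `f1_score(y_true, y_pred, average='macro')` should agree.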

## 7. Training Loop

In [None]:
# Initialize tracking
history = {
    'train_loss': [],
    'train_acc': [],
    'val_loss': [],
    'val_acc': [],
    'val_f1': [],
    'learning_rate': []
}

# Initialize callbacks
early_stopping = EarlyStopping(patience=3, verbose=True)
checkpoint_dir = MODEL_DIR / "checkpoints" / f"{args.model_name.replace('/', '-')}"
checkpoint_dir.mkdir(parents=True, exist_ok=True)

best_val_acc = 0
best_epoch = 0

print("Starting training...")
print("="*60)

for epoch in range(args.num_epochs):
    print(f"\nEpoch {epoch + 1}/{args.num_epochs}")
    print("-" * 40)
    
    # Training phase
    start_time = time.time()
    train_loss, train_acc = train_epoch(
        model, train_loader, optimizer, scheduler, criterion, device, scaler
    )
    train_time = time.time() - start_time
    
    # Validation phase
    start_time = time.time()
    val_loss, val_acc, val_f1, _, _ = evaluate(
        model, val_loader, criterion, device
    )
    val_time = time.time() - start_time
    
    # Update history
    history['train_loss'].append(train_loss)
    history['train_acc'].append(train_acc)
    history['val_loss'].append(val_loss)
    history['val_acc'].append(val_acc)
    history['val_f1'].append(val_f1)
    history['learning_rate'].append(optimizer.param_groups[0]['lr'])
    
    # Print epoch results
    print(f"Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}")
    print(f"Val Loss: {val_loss:.4f} | Val Acc: {val_acc:.4f} | Val F1: {val_f1:.4f}")
    print(f"Time: Train {train_time:.1f}s | Val {val_time:.1f}s")
    print(f"Learning Rate: {optimizer.param_groups[0]['lr']:.2e}")
    
    # Save best model
    if val_acc > best_val_acc:
        best_val_acc = val_acc
        best_epoch = epoch + 1
        
        checkpoint_path = checkpoint_dir / "best_model.pt"
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'scheduler_state_dict': scheduler.state_dict(),
            'val_acc': float(val_acc),
            'val_f1': float(val_f1),
            # Plain dict, so loading doesn't require the TrainingArgs class
            'args': dict(vars(args))
        }, checkpoint_path)
        print(f"✓ Saved best model (Val Acc: {val_acc:.4f})")
    
    # Early stopping check
    if early_stopping(val_loss):
        print(f"\nEarly stopping triggered at epoch {epoch + 1}")
        break
    
    # Memory management
    if device.type == 'cuda':
        clear_memory()

print("\n" + "="*60)
print("Training completed!")
print(f"Best Validation Accuracy: {best_val_acc:.4f} (Epoch {best_epoch})")

## 8. Visualize Training Progress

In [None]:
# Plot training history (x-axis = 1-based epoch number)
epochs = range(1, len(history['train_loss']) + 1)
fig, axes = plt.subplots(2, 2, figsize=(12, 10))

# Loss plot
ax = axes[0, 0]
ax.plot(epochs, history['train_loss'], label='Train Loss', marker='o')
ax.plot(epochs, history['val_loss'], label='Val Loss', marker='s')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Training and Validation Loss')
ax.legend()
ax.grid(True, alpha=0.3)

# Accuracy plot
ax = axes[0, 1]
ax.plot(epochs, history['train_acc'], label='Train Acc', marker='o')
ax.plot(epochs, history['val_acc'], label='Val Acc', marker='s')
ax.set_xlabel('Epoch')
ax.set_ylabel('Accuracy')
ax.set_title('Training and Validation Accuracy')
ax.legend()
ax.grid(True, alpha=0.3)

# F1 Score plot
ax = axes[1, 0]
ax.plot(epochs, history['val_f1'], label='Val F1', marker='D', color='green')
ax.set_xlabel('Epoch')
ax.set_ylabel('F1 Score')
ax.set_title('Validation F1 Score')
ax.legend()
ax.grid(True, alpha=0.3)

# Learning rate plot (LR recorded at the end of each epoch)
ax = axes[1, 1]
ax.plot(epochs, history['learning_rate'], label='Learning Rate', marker='^', color='orange')
ax.set_xlabel('Epoch')
ax.set_ylabel('Learning Rate')
ax.set_title('Learning Rate Schedule')
# Linear y-axis: the linear schedule decays the LR to exactly 0 at the final step,
# which a log scale cannot display
ax.ticklabel_format(axis='y', style='sci', scilimits=(0, 0))
ax.legend()
ax.grid(True, alpha=0.3)

# Integer epoch ticks on every panel
for ax in axes.flat:
    ax.set_xticks(list(epochs))

plt.suptitle('Training Progress Overview', fontsize=14)
plt.tight_layout()
plt.show()

# Save training history
history_path = checkpoint_dir / "training_history.json"
with open(history_path, 'w') as f:
    json.dump(history, f, indent=2)
print(f"\nTraining history saved to: {history_path}")
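The learning-rate panel samples the LR only once per epoch. To see the full curve, the multiplier behind `get_linear_schedule_with_warmup` can be reproduced in plain Python: it ramps linearly from 0 to 1 over the warmup steps, then decays linearly to exactly 0 at `num_training_steps`. This is a sketch written from that documented behaviour, not the library's source:

```python
def linear_warmup_multiplier(step: int, warmup_steps: int, total_steps: int) -> float:
    """LR multiplier for linear warmup followed by linear decay to zero."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    return max(0.0, (total_steps - step) / max(1, total_steps - warmup_steps))


peak_lr, warmup, total = 2e-5, 10, 100
for step in (0, 5, 10, 55, 100):
    print(step, peak_lr * linear_warmup_multiplier(step, warmup, total))
```

As a worked example, assuming the YAML keeps `batch_size=16` and `num_epochs=3` and the training split has more than 5,000 examples: the subsampled loader has ceil(5000 / 16) = 313 batches, so `total_steps` = 313 × 3 = 939 and `warmup_steps` = int(93.9) = 93.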

## 9. Load and Test Best Model

In [None]:
# Load best model
print("Loading best model...")
checkpoint_path = checkpoint_dir / "best_model.pt"
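# Trusted, locally written checkpoint: allow full unpickling
# (PyTorch >= 2.6 defaults torch.load to weights_only=True)
checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)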

model.load_state_dict(checkpoint['model_state_dict'])
print(f"Loaded model from epoch {checkpoint['epoch'] + 1}")
print(f"Best validation accuracy: {checkpoint['val_acc']:.4f}")
print(f"Best validation F1: {checkpoint['val_f1']:.4f}")

# Test on sample texts
test_texts = [
    "The stock market reached record highs today as investors celebrated positive earnings reports.",
    "Scientists discover new exoplanet that could potentially harbor life.",
    "The Lakers defeated the Celtics in overtime with a final score of 115-112.",
    "Apple announces new iPhone with revolutionary camera technology."
]

print("\nTesting on sample texts:")
print("="*60)

model.eval()
with torch.no_grad():
    for text in test_texts:
        inputs = tokenizer(text, 
                          return_tensors='pt',
                          padding=True,
                          truncation=True,
                          max_length=args.max_length)
        
        input_ids = inputs['input_ids'].to(device)
        attention_mask = inputs['attention_mask'].to(device)
        
        outputs = model(input_ids=input_ids, attention_mask=attention_mask)
        probabilities = torch.softmax(outputs.logits, dim=-1)
        predicted_class = torch.argmax(probabilities, dim=-1).item()
        confidence = probabilities[0][predicted_class].item()
        
        print(f"\nText: {text[:80]}...")
        print(f"Predicted: {AG_NEWS_CLASSES[predicted_class]} (confidence: {confidence:.3f})")
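The loop above runs one forward pass per text. The texts could instead be tokenized together with `padding=True` and passed through the model once; a sketch of the post-processing for such a batch takes a `[batch, num_classes]` logits tensor, applies softmax, and returns the most likely class with its probability. `decode_predictions` is a hypothetical helper, and the class order in the example (World, Sports, Business, Sci/Tech) is an assumption; in the notebook, `AG_NEWS_CLASSES` would be passed instead:

```python
from typing import List, Sequence, Tuple

import torch


def decode_predictions(logits: torch.Tensor,
                       class_names: Sequence[str]) -> List[Tuple[str, float]]:
    """Map a [batch, num_classes] logits tensor to (label, confidence) pairs."""
    probs = torch.softmax(logits, dim=-1)
    confidences, indices = probs.max(dim=-1)
    return [(class_names[i], c) for i, c in zip(indices.tolist(), confidences.tolist())]


names = ["World", "Sports", "Business", "Sci/Tech"]  # assumed order
logits = torch.tensor([[0.1, 0.2, 3.0, 0.1],
                       [0.0, 4.0, 0.0, 0.0]])
for label, conf in decode_predictions(logits, names):
    print(f"{label}: {conf:.3f}")
```

In the notebook the input would be `model(**inputs).logits.cpu()` for the batch-tokenized `test_texts`, computed under `torch.no_grad()`.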

## 10. Training Summary and Next Steps

### Training Summary

This tutorial demonstrated fundamental concepts in training transformer models:

1. **Data Preparation**: 
   - Loaded AG News dataset
   - Created efficient data loaders
   - Tokenized each batch on the fly with truncation and dynamic padding

2. **Model Configuration**:
   - Initialized pre-trained transformer model
   - Configured for sequence classification
   - Enabled mixed precision training when a CUDA GPU is available

3. **Training Process**:
   - Implemented training and evaluation loops
   - Optimized with AdamW (decoupled weight decay)
   - Used learning rate scheduling
   - Monitored multiple metrics

4. **Best Practices Applied**:
   - Early stopping to prevent overfitting
   - Model checkpointing for best weights
   - Memory management for efficiency
   - Comprehensive logging and visualization
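The project's `EarlyStopping` callback is called in the training loop as `early_stopping(val_loss)`, returning `True` when training should stop; its implementation isn't shown. A minimal sketch of the usual patience logic under that assumed interface (`PatienceStopper` and `min_delta` are hypothetical names):

```python
class PatienceStopper:
    """Signal a stop after `patience` checks without a `min_delta` improvement."""

    def __init__(self, patience: int = 3, min_delta: float = 0.0):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.bad_epochs = 0

    def __call__(self, val_loss: float) -> bool:
        if val_loss < self.best - self.min_delta:
            self.best = val_loss  # improvement: remember it and reset the counter
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


stopper = PatienceStopper(patience=2)
print([stopper(loss) for loss in [0.50, 0.40, 0.42, 0.41]])
# [False, False, False, True]
```

Under this logic the first epoch always counts as an improvement, so with `patience=3` and `num_epochs=3` as configured here, early stopping could not trigger before training ends on its own.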

### Advanced Training Techniques

To further improve performance, consider:

1. **Advanced Models**:
   - DeBERTa-v3 for better performance
   - Ensemble methods for robustness
   - Domain-adapted models

2. **Training Strategies**:
   - Adversarial training for robustness
   - Curriculum learning for efficiency
   - Knowledge distillation for compression

3. **Optimization Techniques**:
   - SAM optimizer for better generalization
   - Lookahead optimizer for stability
   - Gradient clipping for training stability

4. **Data Augmentation**:
   - Back-translation for more data
   - Paraphrasing for diversity
   - Mixup for regularization
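Mixup trains on convex combinations of example pairs and of their labels. Token IDs can't be interpolated, so for text it is usually applied to embeddings or pooled hidden states. A sketch on a generic feature tensor with one-hot targets; `mixup`, `soft_cross_entropy` and `alpha=0.2` are illustrative choices, not part of this project:

```python
import torch
import torch.nn.functional as F


def mixup(features: torch.Tensor, labels: torch.Tensor, num_classes: int,
          alpha: float = 0.2):
    """Blend each example with a randomly chosen partner from the same batch."""
    lam = torch.distributions.Beta(alpha, alpha).sample().item()
    perm = torch.randperm(features.size(0))
    targets = F.one_hot(labels, num_classes).float()
    mixed_x = lam * features + (1 - lam) * features[perm]
    mixed_y = lam * targets + (1 - lam) * targets[perm]
    return mixed_x, mixed_y


def soft_cross_entropy(logits: torch.Tensor, soft_targets: torch.Tensor) -> torch.Tensor:
    """Cross-entropy against probability-vector targets."""
    return -(soft_targets * F.log_softmax(logits, dim=-1)).sum(dim=-1).mean()


torch.manual_seed(0)
x = torch.randn(6, 16)  # e.g. pooled sentence embeddings
y = torch.randint(0, 4, (6,))
mx, my = mixup(x, y, num_classes=4)
loss = soft_cross_entropy(torch.randn(6, 4), my)
```

Since PyTorch 1.10, `F.cross_entropy(logits, my)` also accepts class-probability targets and could replace `soft_cross_entropy`.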

### Next Steps

1. **Evaluation**: Proceed to `04_evaluation_tutorial.ipynb` for comprehensive model evaluation
2. **Advanced Training**: Explore prompt-based methods in `05_prompt_engineering.ipynb`
3. **Production**: Learn API deployment in `07_api_usage.ipynb`
4. **Optimization**: Implement parameter-efficient fine-tuning (PEFT) methods such as LoRA