# Model Training Basics for AG News Classification

## Overview

This notebook demonstrates fundamental model training techniques following methodologies from:
- Devlin et al. (2019): "BERT: Pre-training of Deep Bidirectional Transformers"
- Liu et al. (2019): "RoBERTa: A Robustly Optimized BERT Pretraining Approach"
- He et al. (2021): "DeBERTa: Decoding-enhanced BERT with Disentangled Attention"

### Tutorial Objectives
1. Load and prepare training data
2. Initialize transformer models
3. Configure training parameters
4. Execute training loop
5. Monitor training progress
6. Save and evaluate models

Author: Võ Hải Dũng  
Email: vohaidung.work@gmail.com  
Date: 2025

## 1. Environment Setup

In [None]:
# Standard library imports
import sys
import os
from pathlib import Path
from typing import Dict, List, Tuple, Optional, Any
import warnings

# Data and ML imports
import numpy as np
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from transformers import AutoTokenizer

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.auto import tqdm

# Project imports
PROJECT_ROOT = Path("../..").resolve()
sys.path.insert(0, str(PROJECT_ROOT))

from src.data.datasets.ag_news import AGNewsDataset, AGNewsConfig
from src.data.loaders.dataloader import create_dataloaders
from src.models.transformers.deberta.deberta_v3 import DeBERTaV3Classifier
from src.models.transformers.roberta.roberta_enhanced import RoBERTaEnhancedClassifier
from src.training.trainers.base_trainer import BaseTrainer
from src.training.trainers.standard_trainer import StandardTrainer
from src.training.callbacks.early_stopping import EarlyStopping
from src.training.callbacks.model_checkpoint import ModelCheckpoint
from src.utils.reproducibility import set_seed, get_reproducible_config
from src.utils.logging_config import setup_logging
from configs.config_loader import ConfigLoader
from configs.constants import AG_NEWS_NUM_CLASSES, MODEL_DIR, OUTPUT_DIR

# Setup
warnings.filterwarnings('ignore')
set_seed(42)
logger = setup_logging("model_training_tutorial")

# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    print(f"Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

## 2. Load Configuration

In [None]:
# Load training configuration
config_loader = ConfigLoader()

# Load base training config
training_config = config_loader.load_config('training/standard/base_training.yaml')

# Load model config (using DeBERTa as example)
model_config = config_loader.load_config('models/single/deberta_v3_xlarge.yaml')

# Override for tutorial (smaller settings for demonstration)
tutorial_overrides = {
    'batch_size': 8,  # Smaller batch for memory
    'num_epochs': 3,  # Fewer epochs for speed
    'max_samples': 1000,  # Limit samples for tutorial
    'gradient_accumulation_steps': 2,
    'eval_steps': 50,
    'save_steps': 100,
    'logging_steps': 10,
    'warmup_steps': 100
}

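# Apply overrides. Keys missing from the base config are added as well,
# because later cells index training_config['batch_size'], ['num_epochs'],
# and ['warmup_steps'] directly.
for key, value in tutorial_overrides.items():
    training_config[key] = value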

print("Training Configuration:")
print("="*50)
for key, value in training_config.items():
    if not isinstance(value, dict):
        print(f"  {key}: {value}")
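
The overrides interact with each other: with gradient accumulation, the optimizer only updates once every `gradient_accumulation_steps` batches. The arithmetic can be sketched as follows, assuming the 1000-sample limit applies to the training split and that the last partial batch is kept (both are assumptions; `plan_steps` is a hypothetical helper, not part of the project):

```python
import math

def plan_steps(num_samples, batch_size, accumulation_steps, num_epochs):
    """Return effective batch size, batches per epoch, and total optimizer updates."""
    effective_batch = batch_size * accumulation_steps
    batches_per_epoch = math.ceil(num_samples / batch_size)
    # One optimizer update per `accumulation_steps` batches (a trailing partial group still updates)
    updates_per_epoch = math.ceil(batches_per_epoch / accumulation_steps)
    return {
        "effective_batch_size": effective_batch,
        "batches_per_epoch": batches_per_epoch,
        "optimizer_updates": updates_per_epoch * num_epochs,
    }

plan = plan_steps(num_samples=1000, batch_size=8, accumulation_steps=2, num_epochs=3)
print(plan)
```

Under these assumptions the run has roughly 189 optimizer updates, so a 100-step warmup would cover more than half of training; a shorter warmup may suit such a small demonstration.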

## 3. Data Preparation

In [None]:
# Initialize tokenizer
model_name = model_config.get('model_name', 'microsoft/deberta-v3-base')
tokenizer = AutoTokenizer.from_pretrained(model_name)

print(f"Tokenizer: {model_name}")
print(f"Vocab size: {tokenizer.vocab_size}")

# Load datasets
data_config = AGNewsConfig(
    max_samples=tutorial_overrides['max_samples'],
    tokenizer=tokenizer,
    max_length=model_config.get('max_length', 256)
)

print("\nLoading datasets...")
train_dataset = AGNewsDataset(data_config, split='train')
val_dataset = AGNewsDataset(data_config, split='validation')

print(f"Train samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")

# Create dataloaders
train_dataloader, val_dataloader = create_dataloaders(
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    batch_size=training_config['batch_size'],
    num_workers=2,
    pin_memory=torch.cuda.is_available()
)

print(f"\nDataLoaders created:")
print(f"  Train batches: {len(train_dataloader)}")
print(f"  Validation batches: {len(val_dataloader)}")

## 4. Model Initialization

In [None]:
# Initialize model
def initialize_model(model_type: str = 'deberta') -> nn.Module:
    """
    Initialize a transformer model for classification.
    
    Following initialization strategies from:
        Glorot & Bengio (2010): "Understanding the difficulty of training deep feedforward neural networks"
    """
    if model_type == 'deberta':
        model = DeBERTaV3Classifier(
            model_name=model_name,
            num_labels=AG_NEWS_NUM_CLASSES,
            dropout_rate=model_config.get('dropout_rate', 0.1),
            use_pooler=model_config.get('use_pooler', False)
        )
    elif model_type == 'roberta':
        model = RoBERTaEnhancedClassifier(
            model_name='roberta-base',
            num_labels=AG_NEWS_NUM_CLASSES,
            dropout_rate=model_config.get('dropout_rate', 0.1)
        )
    else:
        raise ValueError(f"Unknown model type: {model_type}")
    
    return model

# Create model
model = initialize_model('deberta')
model = model.to(device)

# Model summary
total_params = sum(p.numel() for p in model.parameters())
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print("Model Information:")
print("="*50)
print(f"Model type: DeBERTa-v3")
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
print(f"Non-trainable parameters: {total_params - trainable_params:,}")
print(f"Model size: {total_params * 4 / 1024**2:.1f} MB (fp32)")
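
The model size printed above covers the weights only. During training, AdamW also keeps a gradient and two moment buffers per parameter, so a rough fp32 estimate is four times the weight size, before counting activations. A minimal sketch of that estimate, using a hypothetical 100M-parameter model rather than the actual DeBERTa-v3 count (`training_memory_mib` is an illustrative helper, not a project function):

```python
def training_memory_mib(num_params, bytes_per_value=4):
    """Rough fp32 memory for weights, gradients, and AdamW moment buffers (no activations)."""
    mib = 1024 ** 2
    weights = num_params * bytes_per_value / mib
    gradients = weights            # one gradient value per parameter
    adam_moments = 2 * weights     # first and second moment estimates
    return {
        "weights": weights,
        "gradients": gradients,
        "adam_moments": adam_moments,
        "total": weights + gradients + adam_moments,
    }

estimate = training_memory_mib(100_000_000)  # hypothetical parameter count
for name, value in estimate.items():
    print(f"{name}: {value:.0f} MiB")
```

Activation memory depends on batch size and sequence length and often dominates, which is one reason the tutorial uses a small batch.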

## 5. Training Setup

In [None]:
# Setup optimizer
from src.training.optimization.optimizers.adamw_custom import create_adamw_optimizer

optimizer = create_adamw_optimizer(
    model=model,
    learning_rate=training_config.get('learning_rate', 2e-5),
    weight_decay=training_config.get('weight_decay', 0.01),
    betas=(0.9, 0.999),
    eps=1e-8
)

# Setup learning rate scheduler
from src.training.optimization.schedulers.cosine_warmup import CosineWarmupScheduler

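# The scheduler should cover optimizer updates, not batches, so account for
# gradient accumulation (assumes StandardTrainer steps the scheduler once per update).
accumulation_steps = training_config.get('gradient_accumulation_steps', 1)
updates_per_epoch = -(-len(train_dataloader) // accumulation_steps)  # ceiling division
num_training_steps = updates_per_epoch * training_config['num_epochs']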
scheduler = CosineWarmupScheduler(
    optimizer=optimizer,
    warmup_steps=training_config['warmup_steps'],
    total_steps=num_training_steps
)

# Setup loss function
from src.training.objectives.losses.label_smoothing import LabelSmoothingCrossEntropy

criterion = LabelSmoothingCrossEntropy(
    num_classes=AG_NEWS_NUM_CLASSES,
    smoothing=training_config.get('label_smoothing', 0.1)
)

print("Training Setup:")
print("="*50)
print(f"Optimizer: AdamW")
print(f"Learning rate: {training_config.get('learning_rate', 2e-5)}")
print(f"Scheduler: Cosine with warmup")
print(f"Warmup steps: {training_config['warmup_steps']}")
print(f"Total training steps: {num_training_steps}")
print(f"Loss function: Label Smoothing Cross-Entropy")
print(f"Label smoothing: {training_config.get('label_smoothing', 0.1)}")
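
To make the "cosine with warmup" schedule concrete, the sketch below computes the multiplier such a schedule typically applies to the base learning rate: a linear ramp from 0 to 1 during warmup, then a half-cosine decay to 0. This is a common formulation and an assumption here; the project's `CosineWarmupScheduler` may differ (for example, by using a nonzero minimum learning rate). The step counts are illustrative.

```python
import math

def cosine_warmup_factor(step, warmup_steps, total_steps):
    """Multiplier applied to the base learning rate at a given optimizer step."""
    if warmup_steps > 0 and step < warmup_steps:
        return step / warmup_steps  # linear warmup from 0 to 1
    span = max(1, total_steps - warmup_steps)
    progress = min(1.0, (step - warmup_steps) / span)
    return 0.5 * (1.0 + math.cos(math.pi * progress))  # cosine decay from 1 to 0

base_lr = 2e-5
for step in (0, 50, 100, 144, 188):
    lr = base_lr * cosine_warmup_factor(step, warmup_steps=100, total_steps=188)
    print(f"step {step:3d}: lr = {lr:.2e}")
```

The learning rate peaks at the end of warmup and returns to zero at the final step, so `total_steps` should match the number of times the scheduler is actually stepped.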

## 6. Training Callbacks

In [None]:
# Setup callbacks
callbacks = []

# Early stopping
early_stopping = EarlyStopping(
    patience=training_config.get('early_stopping_patience', 3),
    min_delta=training_config.get('early_stopping_delta', 0.001),
    mode='max',  # For accuracy
    verbose=True
)
callbacks.append(early_stopping)

# Model checkpoint
checkpoint_dir = OUTPUT_DIR / "checkpoints" / "tutorial"
checkpoint_dir.mkdir(parents=True, exist_ok=True)

model_checkpoint = ModelCheckpoint(
    directory=checkpoint_dir,
    filename_prefix="deberta_v3",
    monitor='val_accuracy',
    mode='max',
    save_best_only=True,
    save_weights_only=False,
    verbose=True
)
callbacks.append(model_checkpoint)

print("Callbacks configured:")
print("="*50)
print(f"[CONFIGURED] Early Stopping (patience={training_config.get('early_stopping_patience', 3)})")
print(f"[CONFIGURED] Model Checkpoint (saving to {checkpoint_dir})")
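
The early-stopping logic configured above (`patience`, `min_delta`, `mode`) can be sketched in a few lines. `SimpleEarlyStopping` is a hypothetical stand-in written for illustration; the project's `EarlyStopping` callback may track state differently.

```python
class SimpleEarlyStopping:
    """Stop when the monitored value has not improved by min_delta for `patience` checks."""

    def __init__(self, patience=3, min_delta=0.001, mode="max"):
        if mode not in ("max", "min"):
            raise ValueError(f"Unknown mode: {mode}")
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode
        self.best = None
        self.bad_epochs = 0

    def check_stop(self, value):
        if self.best is None:
            improved = True
        elif self.mode == "max":
            improved = value > self.best + self.min_delta
        else:
            improved = value < self.best - self.min_delta

        if improved:
            self.best = value
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience

stopper = SimpleEarlyStopping(patience=2, min_delta=0.001, mode="max")
decisions = [stopper.check_stop(acc) for acc in (0.80, 0.85, 0.8505, 0.849)]
print(decisions)
```

Note that the third value (0.8505) exceeds the previous best, but by less than `min_delta`, so it still counts as an epoch without improvement.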

## 7. Training Loop

In [None]:
# Initialize trainer
trainer = StandardTrainer(
    model=model,
    optimizer=optimizer,
    criterion=criterion,
    scheduler=scheduler,
    device=device,
    callbacks=callbacks,
    gradient_clip_val=training_config.get('gradient_clip_val', 1.0),
    gradient_accumulation_steps=training_config.get('gradient_accumulation_steps', 1),
    mixed_precision=training_config.get('mixed_precision', False)
)

# Training history storage
history = {
    'train_loss': [],
    'train_accuracy': [],
    'val_loss': [],
    'val_accuracy': [],
    'learning_rate': []
}

print("Starting training...")
print("="*50)

# Training loop
for epoch in range(training_config['num_epochs']):
    print(f"\nEpoch {epoch + 1}/{training_config['num_epochs']}")
    print("-" * 40)
    
    # Training phase
    train_metrics = trainer.train_epoch(
        train_dataloader=train_dataloader,
        epoch=epoch
    )
    
    # Validation phase
    val_metrics = trainer.validate(
        val_dataloader=val_dataloader
    )
    
    # Store metrics
    history['train_loss'].append(train_metrics['loss'])
    history['train_accuracy'].append(train_metrics['accuracy'])
    history['val_loss'].append(val_metrics['loss'])
    history['val_accuracy'].append(val_metrics['accuracy'])
    history['learning_rate'].append(scheduler.get_last_lr()[0])
    
    # Print metrics
    print(f"Train Loss: {train_metrics['loss']:.4f} | Train Acc: {train_metrics['accuracy']:.4f}")
    print(f"Val Loss: {val_metrics['loss']:.4f} | Val Acc: {val_metrics['accuracy']:.4f}")
    print(f"Learning Rate: {scheduler.get_last_lr()[0]:.2e}")
    
    # Check early stopping
    if early_stopping.check_stop(val_metrics['accuracy']):
        print(f"\nEarly stopping triggered at epoch {epoch + 1}")
        break

print("\nTraining completed!")
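
The trainer is configured with `gradient_accumulation_steps`, so it is worth seeing why accumulation emulates a larger batch. The toy example below uses a one-parameter model with a hand-written gradient (no PyTorch) and shows that averaging the gradients of equal-sized micro-batches reproduces the full-batch gradient. It illustrates the idea only and says nothing about how `StandardTrainer` is implemented.

```python
def mse_grad(w, xs, ys):
    """Gradient of mean((w * x - y) ** 2) with respect to w."""
    return sum(2.0 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)

xs = [1.0, 2.0, 3.0, 4.0]
ys = [2.0, 4.1, 5.9, 8.2]
w = 0.5
accumulation_steps = 2
micro_batch = len(xs) // accumulation_steps

# Gradient computed on the full batch in one pass
full_grad = mse_grad(w, xs, ys)

# Gradient accumulated over micro-batches, each scaled by 1 / accumulation_steps
accumulated = 0.0
for i in range(accumulation_steps):
    lo, hi = i * micro_batch, (i + 1) * micro_batch
    accumulated += mse_grad(w, xs[lo:hi], ys[lo:hi]) / accumulation_steps

print(full_grad, accumulated)
```

The equivalence relies on dividing each micro-batch loss (or gradient) by the number of accumulation steps and on the micro-batches having equal size; layers that depend on batch statistics would break it.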

## 8. Training Visualization

In [None]:
# Plot training history
fig, axes = plt.subplots(1, 3, figsize=(15, 4))

# Loss plot
axes[0].plot(history['train_loss'], label='Train Loss', marker='o')
axes[0].plot(history['val_loss'], label='Val Loss', marker='s')
axes[0].set_xlabel('Epoch')
axes[0].set_ylabel('Loss')
axes[0].set_title('Training and Validation Loss')
axes[0].legend()
axes[0].grid(True, alpha=0.3)

# Accuracy plot
axes[1].plot(history['train_accuracy'], label='Train Accuracy', marker='o')
axes[1].plot(history['val_accuracy'], label='Val Accuracy', marker='s')
axes[1].set_xlabel('Epoch')
axes[1].set_ylabel('Accuracy')
axes[1].set_title('Training and Validation Accuracy')
axes[1].legend()
axes[1].grid(True, alpha=0.3)

# Learning rate plot
axes[2].plot(history['learning_rate'], label='Learning Rate', marker='d', color='green')
axes[2].set_xlabel('Epoch')
axes[2].set_ylabel('Learning Rate')
axes[2].set_title('Learning Rate Schedule')
axes[2].legend()
axes[2].grid(True, alpha=0.3)

plt.suptitle('Training Progress', fontsize=14)
plt.tight_layout()
plt.show()

# Print final metrics
print("\nFinal Training Metrics:")
print("="*50)
print(f"Best Validation Accuracy: {max(history['val_accuracy']):.4f}")
print(f"Final Training Loss: {history['train_loss'][-1]:.4f}")
print(f"Final Validation Loss: {history['val_loss'][-1]:.4f}")

## 9. Model Evaluation

In [None]:
# Detailed evaluation on validation set
from src.evaluation.metrics.classification_metrics import ClassificationMetrics
from sklearn.metrics import classification_report

# Get predictions
model.eval()
all_predictions = []
all_labels = []
all_probs = []

with torch.no_grad():
    for batch in tqdm(val_dataloader, desc="Evaluating"):
        inputs = {k: v.to(device) for k, v in batch.items() if k != 'labels'}
        labels = batch['labels'].to(device)
        
        outputs = model(**inputs)
        logits = outputs.logits if hasattr(outputs, 'logits') else outputs
        
        probs = torch.softmax(logits, dim=-1)
        predictions = torch.argmax(logits, dim=-1)
        
        all_predictions.extend(predictions.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())
        all_probs.extend(probs.cpu().numpy())

# Calculate metrics
metrics_calculator = ClassificationMetrics(num_classes=AG_NEWS_NUM_CLASSES)
metrics = metrics_calculator.compute_metrics(
    predictions=np.array(all_predictions),
    labels=np.array(all_labels),
    probabilities=np.array(all_probs)
)

print("Detailed Evaluation Metrics:")
print("="*50)
print(f"Accuracy: {metrics['accuracy']:.4f}")
print(f"Macro F1: {metrics['macro_f1']:.4f}")
print(f"Weighted F1: {metrics['weighted_f1']:.4f}")
print(f"Macro Precision: {metrics['macro_precision']:.4f}")
print(f"Macro Recall: {metrics['macro_recall']:.4f}")

# Classification report
from configs.constants import ID_TO_LABEL

print("\nClassification Report:")
print("="*50)
print(classification_report(
    all_labels, 
    all_predictions, 
    target_names=[ID_TO_LABEL[i] for i in range(AG_NEWS_NUM_CLASSES)]
))
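
The report above distinguishes macro and weighted F1. As a sketch of the difference, the code below computes both from scratch on a tiny made-up example: macro F1 averages the per-class scores equally, while weighted F1 weights each class by its number of true samples. The labels and predictions are invented for illustration, and `f1_scores` is a hypothetical helper, not the project's `ClassificationMetrics`.

```python
from collections import Counter

def f1_scores(labels, predictions, num_classes):
    """Return per-class F1, macro F1, and support-weighted F1."""
    support = Counter(labels)
    per_class = []
    for c in range(num_classes):
        tp = sum(1 for y, p in zip(labels, predictions) if y == c and p == c)
        fp = sum(1 for y, p in zip(labels, predictions) if y != c and p == c)
        fn = sum(1 for y, p in zip(labels, predictions) if y == c and p != c)
        denom = 2 * tp + fp + fn
        per_class.append(2 * tp / denom if denom else 0.0)
    macro = sum(per_class) / num_classes
    weighted = sum(per_class[c] * support[c] for c in range(num_classes)) / len(labels)
    return per_class, macro, weighted

# Made-up labels: class 0 is frequent, class 3 is rare and always missed
labels      = [0, 0, 0, 0, 1, 1, 2, 3]
predictions = [0, 0, 0, 1, 1, 1, 2, 0]

per_class, macro_f1, weighted_f1 = f1_scores(labels, predictions, num_classes=4)
print(per_class, macro_f1, weighted_f1)
```

Here the missed rare class pulls macro F1 well below weighted F1. AG News is class-balanced, so on the full dataset the two values are usually close.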

## 10. Model Saving and Loading

In [None]:
# Save the trained model
save_path = MODEL_DIR / "tutorial" / "deberta_v3_trained"
save_path.mkdir(parents=True, exist_ok=True)

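# Save a full training checkpoint (final-epoch weights plus optimizer and scheduler state).
# The best-scoring weights are written separately by ModelCheckpoint to checkpoint_dir.
torch.save({
    'epoch': len(history['val_accuracy']),  # epochs actually completed (early stopping may end training sooner)
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': optimizer.state_dict(),
    'scheduler_state_dict': scheduler.state_dict(),
    'best_accuracy': max(history['val_accuracy']),  # best epoch, not necessarily the saved weights
    'history': history,
    'config': training_config
}, save_path / "checkpoint.pt")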

print(f"Model saved to: {save_path}")

# Save tokenizer
tokenizer.save_pretrained(save_path / "tokenizer")
print(f"Tokenizer saved to: {save_path / 'tokenizer'}")

# Demonstrate loading
print("\nLoading saved model...")

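# Load checkpoint. It also stores non-tensor objects (history, config), so
# PyTorch >= 2.6 needs weights_only=False; only do this for files you trust.
checkpoint = torch.load(save_path / "checkpoint.pt", map_location=device, weights_only=False)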

# Create new model instance
loaded_model = initialize_model('deberta')
loaded_model.load_state_dict(checkpoint['model_state_dict'])
loaded_model = loaded_model.to(device)
loaded_model.eval()

print(f"Model loaded successfully!")
print(f"Best accuracy from checkpoint: {checkpoint['best_accuracy']:.4f}")

## 11. Conclusions and Next Steps

### Training Summary

This tutorial demonstrated fundamental model training concepts:

1. **Data Preparation**: Loaded and preprocessed the AG News dataset
2. **Model Initialization**: Set up a DeBERTa-v3 classifier
3. **Training Configuration**: Configured optimizer, scheduler, and loss
4. **Training Loop**: Executed training with validation
5. **Monitoring**: Tracked metrics and visualized progress
6. **Evaluation**: Computed comprehensive metrics
7. **Persistence**: Saved and loaded trained models

### Key Takeaways

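1. **Hyperparameters Matter**: Learning rate, batch size, and warmup length strongly affect fine-tuning results
2. **Early Stopping**: Halts training once validation accuracy stops improving, which limits overfitting and saves computation
3. **Label Smoothing**: Softens the one-hot targets, which often improves generalization for classification
4. **Gradient Accumulation**: Enables larger effective batch sizes without extra memory (here 8 × 2 = 16)
5. **Mixed Precision**: Can speed up training on supporting GPUs with minimal accuracy loss; it stays off in this tutorial unless the config sets `mixed_precision`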
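
To make takeaway 3 concrete, here is a minimal pure-Python sketch of label-smoothed cross-entropy for a single example. It assumes the common formulation in which the target distribution puts `1 - smoothing` on the true class and spreads `smoothing` uniformly over all classes; the project's `LabelSmoothingCrossEntropy` may use a different convention.

```python
import math

def log_softmax(logits):
    """Numerically stable log-softmax over a list of floats."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_z for v in logits]

def label_smoothing_ce(logits, target, smoothing=0.1):
    """Cross-entropy against a smoothed target distribution."""
    num_classes = len(logits)
    log_probs = log_softmax(logits)
    loss = 0.0
    for i, log_p in enumerate(log_probs):
        # Smoothed target: uniform mass plus the remaining mass on the true class
        q = smoothing / num_classes + (1.0 - smoothing) * (1.0 if i == target else 0.0)
        loss -= q * log_p
    return loss

confident = [10.0, 0.0, 0.0, 0.0]  # very confident, correct prediction for class 0
plain = label_smoothing_ce(confident, target=0, smoothing=0.0)
smoothed = label_smoothing_ce(confident, target=0, smoothing=0.1)
print(f"plain: {plain:.4f}, smoothed: {smoothed:.4f}")
```

The smoothed loss stays well above zero even for a confident correct prediction, which discourages the model from pushing logits to extremes.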

### Next Steps

1. **Advanced Training**:
   - Try different models (RoBERTa, XLNet)
   - Implement ensemble training
   - Use adversarial training

2. **Optimization**:
   - Hyperparameter tuning with Optuna
   - Learning rate scheduling experiments
   - Different loss functions

3. **Efficiency**:
   - LoRA fine-tuning
   - Quantization-aware training
   - Knowledge distillation

4. **Production**:
   - Model optimization for inference
   - API deployment
   - Monitoring and maintenance

### References

For a deeper understanding, consult:
- Training documentation: `docs/user_guide/model_training.md`
- Advanced techniques: `notebooks/tutorials/06_instruction_tuning.ipynb`
- API deployment: `notebooks/tutorials/07_api_usage.ipynb`