# 🌊🏜️ Natural Disaster Prediction Model Training

## Big Data and Deep Learning-Based Natural Disaster Prediction Using Multi-Source Environmental Data

This notebook trains a multi-encoder deep learning model for predicting floods and droughts.

### Architecture Overview:
- **CNN Encoder**: Processes satellite features (NDVI, EVI, LST) - spatial patterns
- **LSTM Encoder**: Processes weather sequences (temp, precipitation, wind, etc.) - temporal patterns
- **MLP Encoder**: Processes static features (elevation, landcover, coordinates) - geographic context
- **Mid-Level Fusion**: Combines all encoder outputs
- **Prediction Heads**: Flood and Drought classification, both binary (each head outputs two logits)

## 1. Setup and Imports

In [1]:
# Standard libraries
import os
import sys
import time
import warnings
from datetime import datetime

# Data processing
import numpy as np
import pandas as pd

# Deep Learning
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim import AdamW
from torch.optim.lr_scheduler import ReduceLROnPlateau
from torch.amp import GradScaler, autocast

# Visualization
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.notebook import tqdm

# Suppress warnings
warnings.filterwarnings('ignore')

# Add project root to path - FIX for notebook
if 'notebooks' in os.getcwd():
    PROJECT_ROOT = os.path.dirname(os.getcwd())
else:
    PROJECT_ROOT = os.getcwd()

# Ensure project root is in path
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

# Change to project root for consistent paths
os.chdir(PROJECT_ROOT)

print(f"Project root: {PROJECT_ROOT}")
print(f"Current working directory: {os.getcwd()}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if hasattr(torch.backends, 'mps'):
    print(f"MPS available: {torch.backends.mps.is_available()}")

Project root: /Users/leonnn/Desktop/DL for disaster
Current working directory: /Users/leonnn/Desktop/DL for disaster
PyTorch version: 2.10.0
CUDA available: False
MPS available: True


In [2]:
# Import project modules
from configs.config import (
    get_config, DATA_PATH, MODEL_DIR, LOG_DIR,
    SATELLITE_FEATURES, WEATHER_FEATURES, STATIC_FEATURES
)
from src.dataset import (
    DisasterDataProcessor, DisasterDataset, 
    create_dataloaders, compute_class_weights
)
from src.models import (
    DisasterPredictionModel, MultiTaskLoss, create_model
)
from src.utils import (
    setup_logger, MetricsCalculator, EarlyStopping,
    CheckpointManager, TrainingHistory, set_seed, get_device,
    format_time, plot_confusion_matrix
)

print("\n✅ All modules imported successfully!")


✅ All modules imported successfully!


## 2. Configuration

In [3]:
# Load configuration
config = get_config()

# Set random seed for reproducibility
set_seed(config.data.random_seed)

# Get device
device = get_device()
config.device = str(device)

# Print configuration summary
print("="*60)
print("CONFIGURATION SUMMARY")
print("="*60)
print(f"\n📊 Data Configuration:")
print(f"   Train/Val/Test: {config.data.train_ratio}/{config.data.valid_ratio}/{config.data.test_ratio}")
print(f"   Sequence length: {config.data.sequence_length} days")
print(f"   Grid size: {config.data.grid_size}x{config.data.grid_size}")
print(f"   Batch size: {config.data.batch_size}")

print(f"\n🏗️ Model Configuration:")
print(f"   Encoder output dim: {config.model.encoder_output_dim}")
print(f"   CNN channels: {config.model.cnn_channels}")
print(f"   LSTM hidden: {config.model.lstm_hidden_size}, layers: {config.model.lstm_num_layers}")
print(f"   MLP hidden: {config.model.mlp_hidden_sizes}")

print(f"\n🎯 Training Configuration:")
print(f"   Epochs: {config.training.num_epochs}")
print(f"   Learning rate: {config.training.learning_rate}")
print(f"   Early stopping patience: {config.training.early_stopping_patience}")
print(f"   Device: {device}")
print("="*60)

Random seed set to 42
Using Apple Silicon MPS
CONFIGURATION SUMMARY

📊 Data Configuration:
   Train/Val/Test: 0.7/0.15/0.15
   Sequence length: 7 days
   Grid size: 5x5
   Batch size: 256

🏗️ Model Configuration:
   Encoder output dim: 128
   CNN channels: [32, 64, 128]
   LSTM hidden: 128, layers: 2
   MLP hidden: [64, 128]

🎯 Training Configuration:
   Epochs: 100
   Learning rate: 0.001
   Early stopping patience: 10
   Device: mps


## 3. Load and Preprocess Data

In [4]:
# Initialize data processor
processor = DisasterDataProcessor(config.data)

# Process data (load, handle missing values, normalize, split)
print("Loading and preprocessing data...")
train_df, val_df, test_df = processor.process(DATA_PATH)

# Save preprocessors for inference
processor.save_preprocessors()

Loading and preprocessing data...
Loading data from /Users/leonnn/Desktop/DL for disaster/SEA_2024_FINAL_CLEAN.csv...
Loaded 1,323,822 rows with 21 columns

Data splits:
  Train: 926,346 samples (70.0%)
  Valid: 198,738 samples (15.0%)
  Test:  198,738 samples (15.0%)
Filled 2,227 missing values in 'ndvi' with 0.6749
Filled 2,227 missing values in 'evi' with 0.4121
Filled 544,641 missing values in 'lst' with 27.3900
Filled 16,836 missing values in 'precip_mm' with 2.1642
Filled 16,836 missing values in 'temp_c' with 25.0372
Filled 16,836 missing values in 'dewpoint_c' with 20.9961
Filled 16,836 missing values in 'wind_u' with -0.1192
Filled 16,836 missing values in 'wind_v' with 0.2172
Filled 16,836 missing values in 'evap_mm' with 2.9369
Filled 20,273 missing values in 'pressure_hpa' with 978.2073
Filled 16,836 missing values in 'soil_temp_c' with 25.7139
Filled 720 missing values in 'ndvi' with 0.6749
Filled 720 missing values in 'evi' with 0.4121
Filled 117,325 missing values in 'ls
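The log above reports identical fill values in more than one pass (for example 0.6749 for `ndvi`), which is consistent with statistics being fitted on the training split and then reused for the other splits. The following is a minimal sketch of that pattern in plain Python. The function names are hypothetical, and using the mean is an assumption: the notebook does not show whether `DisasterDataProcessor` uses the mean, the median, or something else.

```python
from statistics import fmean


def fit_fill_values(rows, columns):
    """Compute one fill value per column from the non-missing training values."""
    fill = {}
    for col in columns:
        observed = [r[col] for r in rows if r[col] is not None]
        # Assumption: mean imputation; the real processor may use another statistic.
        fill[col] = fmean(observed)
    return fill


def apply_fill_values(rows, fill):
    """Return copies of the rows with missing entries replaced by fitted values."""
    return [
        {c: (fill[c] if c in fill and v is None else v) for c, v in r.items()}
        for r in rows
    ]


# Toy data standing in for the real splits
train_rows = [{"ndvi": 0.6}, {"ndvi": None}, {"ndvi": 0.8}]
val_rows = [{"ndvi": None}, {"ndvi": 0.5}]

fill_values = fit_fill_values(train_rows, ["ndvi"])    # fitted on train only
train_filled = apply_fill_values(train_rows, fill_values)
val_filled = apply_fill_values(val_rows, fill_values)  # reuses the train statistics
```

Fitting on the training split alone keeps validation and test statistics out of preprocessing. A column such as `lst`, where the log reports several hundred thousand filled values in a single pass, is largely constant after imputation, so its contribution to the model is worth checking separately.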

In [5]:
# Compute class weights for imbalanced data
class_weights = compute_class_weights(train_df)

# Move weights to device
flood_weights = class_weights['flood'].to(device)
drought_weights = class_weights['drought'].to(device)

print(f"\n📊 Data Distribution:")
print(f"   Training samples: {len(train_df):,}")
print(f"   Validation samples: {len(val_df):,}")
print(f"   Test samples: {len(test_df):,}")

print(f"\n   Flood distribution (train):")
print(train_df['flood'].value_counts())
print(f"\n   Drought distribution (train):")
print(train_df['drought'].value_counts())

Class weights computed:
  Flood: [0.5268532037734985, 9.809869766235352]
  Drought: [0.5105904936790466, 24.106016159057617]

📊 Data Distribution:
   Training samples: 926,346
   Validation samples: 198,738
   Test samples: 198,738

   Flood distribution (train):
flood
0    879131
1     47215
Name: count, dtype: int64

   Drought distribution (train):
drought
0    907132
1     19214
Name: count, dtype: int64
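The printed weights can be reproduced from the class counts with the common "balanced" inverse-frequency rule, `n_samples / (n_classes * class_count)`. The sketch below applies that rule to the counts shown above. Whether `compute_class_weights` implements exactly this formula is an assumption, and the function name here is hypothetical.

```python
def balanced_class_weights(counts):
    """Inverse-frequency weights: n_samples / (n_classes * count_per_class)."""
    total = sum(counts.values())
    n_classes = len(counts)
    return {label: total / (n_classes * n) for label, n in counts.items()}


# Class counts copied from the training distribution printed above
flood_counts = {0: 879_131, 1: 47_215}
drought_counts = {0: 907_132, 1: 19_214}

flood_w = balanced_class_weights(flood_counts)
drought_w = balanced_class_weights(drought_counts)

print({k: round(v, 4) for k, v in flood_w.items()})
print({k: round(v, 4) for k, v in drought_w.items()})
```

With these weights, each class contributes the same total weight to the loss, so the roughly 5% of flood samples and 2% of drought samples are not drowned out by the majority class.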


## 4. Create DataLoaders

In [6]:
# Create dataloaders
print("Creating DataLoaders...")
train_loader, val_loader, test_loader = create_dataloaders(
    train_df, val_df, test_df, config.data
)

print(f"\n✅ DataLoaders created:")
print(f"   Train batches: {len(train_loader):,}")
print(f"   Val batches: {len(val_loader):,}")
print(f"   Test batches: {len(test_loader):,}")

Creating DataLoaders...
Created 911,160 valid samples
Created 195,480 valid samples
Created 195,480 valid samples

✅ DataLoaders created:
   Train batches: 3,560
   Val batches: 764
   Test batches: 764
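The sample counts are smaller than the split sizes because a 7-day window cannot end within the first 6 days of a series, and the batch counts follow from dividing by the batch size and rounding up. The sketch below checks that arithmetic. The per-location layout (366 daily rows per location, 2,531 training locations) is inferred from the printed totals and is an assumption, not something the notebook states. The helper names are hypothetical.

```python
import math


def windows_per_series(series_length, sequence_length):
    """Number of full sliding windows (stride 1) in one time series."""
    return max(series_length - sequence_length + 1, 0)


def batches_per_epoch(n_samples, batch_size, drop_last=False):
    """Number of batches a loader yields for n_samples."""
    if drop_last:
        return n_samples // batch_size
    return math.ceil(n_samples / batch_size)


# Assumed layout: 366 daily rows per location, 2,531 locations in the training split
per_location = windows_per_series(366, 7)
train_samples = 2_531 * per_location
train_batches = batches_per_epoch(train_samples, 256)
val_batches = batches_per_epoch(195_480, 256)

print(per_location, train_samples, train_batches, val_batches)
```

Under these assumptions the numbers agree with the output above: 911,160 training windows in 3,560 batches and 764 validation batches, with the last batch of each epoch being smaller than 256.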


In [7]:
# Verify batch shapes
sample_batch = next(iter(train_loader))

print("📦 Sample batch shapes:")
print(f"   Satellite (CNN input): {sample_batch['satellite'].shape}")
print(f"   Weather (LSTM input): {sample_batch['weather'].shape}")
print(f"   Static (MLP input): {sample_batch['static'].shape}")
print(f"   Flood labels: {sample_batch['flood_label'].shape}")
print(f"   Drought labels: {sample_batch['drought_label'].shape}")

📦 Sample batch shapes:
   Satellite (CNN input): torch.Size([256, 3, 5, 5])
   Weather (LSTM input): torch.Size([256, 7, 8])
   Static (MLP input): torch.Size([256, 4])
   Flood labels: torch.Size([256])
   Drought labels: torch.Size([256])


## 5. Initialize Model

In [8]:
# Create model
model = create_model(config.model, device)

# Move to device
model = model.to(device)

# Print model architecture
print("\n🏗️ Model Architecture:")
print(model)


Model Parameter Summary:
  cnn_encoder: 126,720
  lstm_encoder: 743,041
  mlp_encoder: 25,536
  fusion: 148,736
  flood_head: 8,386
  drought_head: 8,386
  total: 1,060,805


🏗️ Model Architecture:
DisasterPredictionModel(
  (cnn_encoder): CNNEncoder(
    (conv_layers): Sequential(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): ReLU(inplace=True)
      (3): MaxPool2d(kernel_size=2, stride=1, padding=0, dilation=1, ceil_mode=False)
      (4): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (5): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (6): ReLU(inplace=True)
      (7): MaxPool2d(kernel_size=2, stride=1, padding=0, dilation=1, ceil_mode=False)
      (8): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (9): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, 
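The layer listing above makes it possible to trace how the 5x5 grid shrinks through the CNN and how many parameters each convolution holds. The following sketch does that arithmetic in plain Python. Because the printed architecture is cut off after the third batch-norm layer, the code assumes that only the first two convolution blocks are followed by pooling. The helper names are hypothetical.

```python
def conv2d_params(in_ch, out_ch, kernel):
    """Weights plus biases of a square-kernel Conv2d layer."""
    return in_ch * out_ch * kernel * kernel + out_ch


def out_size(size, kernel, stride=1, padding=0):
    """Spatial output size of a conv or pooling layer (floor mode, no dilation)."""
    return (size + 2 * padding - kernel) // stride + 1


size = 5  # 5x5 input grid
trace = [size]
for block in range(3):
    size = out_size(size, kernel=3, padding=1)  # Conv2d(k=3, p=1) keeps the size
    if block < 2:  # assumption: pooling follows only the first two blocks
        size = out_size(size, kernel=2, stride=1)  # MaxPool2d(k=2, s=1) trims one pixel
    trace.append(size)

conv_params = [conv2d_params(i, o, 3) for i, o in [(3, 32), (32, 64), (64, 128)]]
print(trace, conv_params)
```

Using stride 1 in the pooling layers is what keeps such a small grid from collapsing: a stride of 2 would reduce 5x5 to 2x2 after the first block and to 1x1 after the second.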

In [9]:
# Test forward pass
model.eval()
with torch.no_grad():
    satellite = sample_batch['satellite'].to(device)
    weather = sample_batch['weather'].to(device)
    static = sample_batch['static'].to(device)
    
    outputs = model(satellite, weather, static)
    
    print("\n✅ Forward pass successful!")
    print(f"   Flood logits: {outputs['flood_logits'].shape}")
    print(f"   Drought logits: {outputs['drought_logits'].shape}")
    print(f"   CNN features: {outputs['cnn_features'].shape}")
    print(f"   LSTM features: {outputs['lstm_features'].shape}")
    print(f"   MLP features: {outputs['mlp_features'].shape}")
    print(f"   Fused features: {outputs['fused_features'].shape}")


✅ Forward pass successful!
   Flood logits: torch.Size([256, 2])
   Drought logits: torch.Size([256, 2])
   CNN features: torch.Size([256, 128])
   LSTM features: torch.Size([256, 128])
   MLP features: torch.Size([256, 128])
   Fused features: torch.Size([256, 128])


## 6. Setup Training Components

In [10]:
# Loss function with class weights
criterion = MultiTaskLoss(
    flood_weight=config.training.flood_loss_weight,
    drought_weight=config.training.drought_loss_weight,
    flood_class_weights=flood_weights,
    drought_class_weights=drought_weights
)

# Optimizer
optimizer = AdamW(
    model.parameters(),
    lr=config.training.learning_rate,
    weight_decay=config.training.weight_decay
)

# Learning rate scheduler (verbose removed for PyTorch compatibility)
scheduler = ReduceLROnPlateau(
    optimizer,
    mode='max',
    patience=config.training.scheduler_patience,
    factor=config.training.scheduler_factor,
    min_lr=config.training.scheduler_min_lr
)

# Early stopping
early_stopping = EarlyStopping(
    patience=config.training.early_stopping_patience,
    min_delta=config.training.early_stopping_delta,
    mode='max'
)

# Checkpoint manager
checkpoint_manager = CheckpointManager(
    model_dir=MODEL_DIR,
    experiment_name=config.experiment_name,
    save_best_only=config.training.save_best_only,
    mode='max'
)

# Training history
history = TrainingHistory()

# Mixed precision scaler
scaler = GradScaler() if config.training.use_amp and device.type == 'cuda' else None

print("✅ Training components initialized!")

✅ Training components initialized!
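To make the loss setup concrete, here is a plain-Python sketch of a class-weighted cross-entropy and a two-task combination. The weighted mean divides by the summed weights of the targets, which is how `torch.nn.CrossEntropyLoss` normalizes when given `weight=` with the default `reduction='mean'`. The internals of `MultiTaskLoss` are not shown in this notebook, so the weighted-sum combination and both function names below are assumptions.

```python
import math


def weighted_cross_entropy(logits, labels, class_weights):
    """Class-weighted cross-entropy, normalized by the summed target weights."""
    weighted_sum, weight_total = 0.0, 0.0
    for row, y in zip(logits, labels):
        m = max(row)
        log_z = m + math.log(sum(math.exp(v - m) for v in row))  # stable log-sum-exp
        nll = log_z - row[y]  # negative log-likelihood of the true class
        weighted_sum += class_weights[y] * nll
        weight_total += class_weights[y]
    return weighted_sum / weight_total


def multi_task_loss(flood_loss, drought_loss, flood_weight=1.0, drought_weight=1.0):
    """Assumed combination rule: a weighted sum of the per-task losses."""
    return flood_weight * flood_loss + drought_weight * drought_loss


# Toy logits and labels; the class weights are the rounded values printed earlier
flood_logits = [[2.0, -1.0], [0.5, 0.5]]
flood_labels = [0, 1]
flood_loss = weighted_cross_entropy(flood_logits, flood_labels, [0.5269, 9.8099])
drought_loss = weighted_cross_entropy([[1.0, 0.0]], [0], [0.5106, 24.1060])
total = multi_task_loss(flood_loss, drought_loss)
```

In the toy flood batch, the uncertain positive sample dominates the result because its weight is roughly 19 times that of the confident negative sample, which is the intended effect of class weighting on imbalanced labels.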


## 7. Training Functions

In [11]:
def train_epoch(model, dataloader, criterion, optimizer, device, scaler=None):
    """
    Train for one epoch
    """
    model.train()
    total_loss = 0
    metrics_calc = MetricsCalculator()
    
    pbar = tqdm(dataloader, desc="Training", leave=False)
    
    for batch in pbar:
        # Move data to device
        satellite = batch['satellite'].to(device)
        weather = batch['weather'].to(device)
        static = batch['static'].to(device)
        flood_labels = batch['flood_label'].to(device)
        drought_labels = batch['drought_label'].to(device)
        
        # Zero gradients
        optimizer.zero_grad()
        
        # Forward pass with mixed precision
        if scaler:
            # torch.amp.autocast requires an explicit device_type
            with autocast(device_type=device.type):
                outputs = model(satellite, weather, static)
                losses = criterion(
                    outputs['flood_logits'],
                    outputs['drought_logits'],
                    flood_labels,
                    drought_labels
                )
            
            # Backward pass with scaling
            scaler.scale(losses['total_loss']).backward()
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), config.training.gradient_clip_value)
            scaler.step(optimizer)
            scaler.update()
        else:
            outputs = model(satellite, weather, static)
            losses = criterion(
                outputs['flood_logits'],
                outputs['drought_logits'],
                flood_labels,
                drought_labels
            )
            
            # Backward pass
            losses['total_loss'].backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), config.training.gradient_clip_value)
            optimizer.step()
        
        # Update metrics
        total_loss += losses['total_loss'].item()
        
        flood_preds = torch.argmax(outputs['flood_logits'], dim=1)
        drought_preds = torch.argmax(outputs['drought_logits'], dim=1)
        flood_probs = F.softmax(outputs['flood_logits'], dim=1)
        drought_probs = F.softmax(outputs['drought_logits'], dim=1)
        
        metrics_calc.update('flood', flood_preds, flood_labels, flood_probs)
        metrics_calc.update('drought', drought_preds, drought_labels, drought_probs)
        
        # Update progress bar
        pbar.set_postfix({'loss': f"{losses['total_loss'].item():.4f}"})
    
    avg_loss = total_loss / len(dataloader)
    metrics = metrics_calc.compute()
    
    return avg_loss, metrics


@torch.no_grad()
def validate_epoch(model, dataloader, criterion, device):
    """
    Validate for one epoch
    """
    model.eval()
    total_loss = 0
    metrics_calc = MetricsCalculator()
    
    pbar = tqdm(dataloader, desc="Validating", leave=False)
    
    for batch in pbar:
        # Move data to device
        satellite = batch['satellite'].to(device)
        weather = batch['weather'].to(device)
        static = batch['static'].to(device)
        flood_labels = batch['flood_label'].to(device)
        drought_labels = batch['drought_label'].to(device)
        
        # Forward pass
        outputs = model(satellite, weather, static)
        losses = criterion(
            outputs['flood_logits'],
            outputs['drought_logits'],
            flood_labels,
            drought_labels
        )
        
        # Update metrics
        total_loss += losses['total_loss'].item()
        
        flood_preds = torch.argmax(outputs['flood_logits'], dim=1)
        drought_preds = torch.argmax(outputs['drought_logits'], dim=1)
        flood_probs = F.softmax(outputs['flood_logits'], dim=1)
        drought_probs = F.softmax(outputs['drought_logits'], dim=1)
        
        metrics_calc.update('flood', flood_preds, flood_labels, flood_probs)
        metrics_calc.update('drought', drought_preds, drought_labels, drought_probs)
    
    avg_loss = total_loss / len(dataloader)
    metrics = metrics_calc.compute()
    
    return avg_loss, metrics

print("✅ Training functions defined!")

✅ Training functions defined!
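Both branches of `train_epoch` clip gradients by their global norm before the optimizer step. The sketch below reproduces that rule on plain Python lists, following the scaling that `torch.nn.utils.clip_grad_norm_` applies for the L2 norm: the gradients are multiplied by `max_norm / (total_norm + 1e-6)`, capped at 1. The function name is hypothetical and the code is an illustration, not the PyTorch implementation.

```python
import math


def clip_grad_norm(grads, max_norm, eps=1e-6):
    """Rescale gradient vectors so their joint L2 norm is at most max_norm.

    Returns the rescaled gradients and the norm measured before clipping.
    """
    total_norm = math.sqrt(sum(g * g for vec in grads for g in vec))
    scale = min(max_norm / (total_norm + eps), 1.0)  # never scale gradients up
    return [[g * scale for g in vec] for vec in grads], total_norm


grads = [[3.0, 0.0], [0.0, 4.0]]  # joint norm is 5.0
clipped, norm_before = clip_grad_norm(grads, max_norm=1.0)
norm_after = math.sqrt(sum(g * g for vec in clipped for g in vec))

small, _ = clip_grad_norm([[0.1]], max_norm=1.0)  # already below the limit
```

Because every gradient is scaled by the same factor, clipping changes the step length but not its direction. In the mixed-precision branch the gradients must be unscaled first, which is why `scaler.unscale_(optimizer)` precedes the clipping call.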


## 8. Training Loop

In [None]:
print("="*60)
print("🚀 STARTING TRAINING")
print("="*60)
print(f"Epochs: {config.training.num_epochs}")
print(f"Device: {device}")
print(f"Mixed Precision: {scaler is not None}")
print("="*60 + "\n")

best_val_f1 = 0
start_time = time.time()

for epoch in range(1, config.training.num_epochs + 1):
    epoch_start = time.time()
    
    # Training
    train_loss, train_metrics = train_epoch(
        model, train_loader, criterion, optimizer, device, scaler
    )
    
    # Validation
    val_loss, val_metrics = validate_epoch(
        model, val_loader, criterion, device
    )
    
    epoch_time = time.time() - epoch_start
    current_lr = optimizer.param_groups[0]['lr']
    
    # Calculate average F1 for monitoring
    val_f1_avg = (val_metrics.get('flood_f1', 0) + val_metrics.get('drought_f1', 0)) / 2
    
    # Update history
    history.update(train_loss, val_loss, train_metrics, val_metrics, current_lr, epoch_time)
    
    # Update scheduler
    scheduler.step(val_f1_avg)
    
    # Print epoch summary
    print(f"\nEpoch {epoch}/{config.training.num_epochs} | Time: {format_time(epoch_time)}")
    print(f"  Train Loss: {train_loss:.4f} | Val Loss: {val_loss:.4f}")
    print(f"  Flood   - Train F1: {train_metrics.get('flood_f1', 0):.4f} | Val F1: {val_metrics.get('flood_f1', 0):.4f}")
    print(f"  Drought - Train F1: {train_metrics.get('drought_f1', 0):.4f} | Val F1: {val_metrics.get('drought_f1', 0):.4f}")
    print(f"  Average Val F1: {val_f1_avg:.4f} | LR: {current_lr:.2e}")
    
    # Save best model
    if val_f1_avg > best_val_f1:
        best_val_f1 = val_f1_avg
        checkpoint_manager.save(
            model, optimizer, scheduler, epoch,
            val_metrics, val_f1_avg
        )
        print(f"  ✅ New best model saved! F1: {val_f1_avg:.4f}")
    
    # Early stopping check
    if early_stopping(val_f1_avg, epoch):
        print(f"\n⚠️ Early stopping triggered at epoch {epoch}")
        break

total_time = time.time() - start_time
print("\n" + "="*60)
print("✅ TRAINING COMPLETED")
print("="*60)
print(f"Total time: {format_time(total_time)}")
print(f"Best Val F1: {best_val_f1:.4f}")
print(f"Best epoch: {early_stopping.best_epoch}")
print("="*60)

🚀 STARTING TRAINING
Epochs: 100
Device: mps
Mixed Precision: False



Training:   0%|          | 0/3560 [00:00<?, ?it/s]
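The loop above calls `early_stopping(val_f1_avg, epoch)` and later reads `early_stopping.best_epoch`. The project's `EarlyStopping` class lives in `src.utils` and is not shown here, so the following is a hypothetical stand-in with the same call shape, assuming the usual patience and `min_delta` semantics for a metric that should increase.

```python
class EarlyStoppingSketch:
    """Stop when the monitored score has not improved for `patience` epochs."""

    def __init__(self, patience=10, min_delta=0.0, mode="max"):
        self.patience = patience
        self.min_delta = min_delta
        self.sign = 1.0 if mode == "max" else -1.0  # flip the sign to minimize
        self.best_score = None
        self.best_epoch = 0
        self.bad_epochs = 0

    def __call__(self, score, epoch):
        """Record the score for this epoch; return True when training should stop."""
        value = self.sign * score
        if self.best_score is None or value > self.best_score + self.min_delta:
            self.best_score = value
            self.best_epoch = epoch
            self.bad_epochs = 0
        else:
            self.bad_epochs += 1
        return self.bad_epochs >= self.patience


stopper = EarlyStoppingSketch(patience=2, mode="max")
scores = [0.50, 0.55, 0.54, 0.55, 0.53]  # illustrative validation F1 values
stopped_at = None
for epoch, f1 in enumerate(scores, start=1):
    if stopper(f1, epoch):
        stopped_at = epoch
        break
```

In this toy run the score peaks at epoch 2, a tie at epoch 4 does not count as an improvement, and training stops at epoch 4 with `best_epoch` equal to 2. With the notebook's patience of 10, the same logic would allow ten consecutive epochs without improvement before stopping.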

## 9. Save Training History and Visualize

In [None]:
# Save training history
history_path = os.path.join(LOG_DIR, f"{config.experiment_name}_history.json")
history.save(history_path)

# Plot training curves
fig = history.plot(save_path=os.path.join(LOG_DIR, f"{config.experiment_name}_curves.png"))
plt.show()

## 10. Evaluate on Test Set

In [None]:
# Load best model
checkpoint_manager.load(model, filepath=os.path.join(MODEL_DIR, f"{config.experiment_name}_best.pt"))

# Evaluate on test set
test_loss, test_metrics = validate_epoch(model, test_loader, criterion, device)

print("\n" + "="*60)
print("📊 TEST SET EVALUATION")
print("="*60)
print(f"\nTest Loss: {test_loss:.4f}")
print(f"\n🌊 Flood Prediction:")
print(f"   Accuracy: {test_metrics.get('flood_accuracy', 0):.4f}")
print(f"   Precision: {test_metrics.get('flood_precision', 0):.4f}")
print(f"   Recall: {test_metrics.get('flood_recall', 0):.4f}")
print(f"   F1 Score: {test_metrics.get('flood_f1', 0):.4f}")
if 'flood_roc_auc' in test_metrics:
    print(f"   ROC-AUC: {test_metrics.get('flood_roc_auc', 0):.4f}")

print(f"\n🏜️ Drought Prediction:")
print(f"   Accuracy: {test_metrics.get('drought_accuracy', 0):.4f}")
print(f"   Precision: {test_metrics.get('drought_precision', 0):.4f}")
print(f"   Recall: {test_metrics.get('drought_recall', 0):.4f}")
print(f"   F1 Score: {test_metrics.get('drought_f1', 0):.4f}")
if 'drought_roc_auc' in test_metrics:
    print(f"   ROC-AUC: {test_metrics.get('drought_roc_auc', 0):.4f}")

print(f"\n📈 Average Metrics:")
print(f"   Accuracy: {test_metrics.get('avg_accuracy', 0):.4f}")
print(f"   F1 Score: {test_metrics.get('avg_f1', 0):.4f}")
print("="*60)

In [None]:
# Generate detailed evaluation metrics and confusion matrices
model.eval()
all_flood_preds, all_flood_labels = [], []
all_drought_preds, all_drought_labels = [], []

with torch.no_grad():
    for batch in tqdm(test_loader, desc="Evaluating"):
        satellite = batch['satellite'].to(device)
        weather = batch['weather'].to(device)
        static = batch['static'].to(device)
        
        outputs = model(satellite, weather, static)
        
        flood_preds = torch.argmax(outputs['flood_logits'], dim=1)
        drought_preds = torch.argmax(outputs['drought_logits'], dim=1)
        
        all_flood_preds.extend(flood_preds.cpu().numpy())
        all_flood_labels.extend(batch['flood_label'].numpy())
        all_drought_preds.extend(drought_preds.cpu().numpy())
        all_drought_labels.extend(batch['drought_label'].numpy())

# Plot confusion matrices
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Flood confusion matrix
from sklearn.metrics import confusion_matrix
cm_flood = confusion_matrix(all_flood_labels, all_flood_preds)
sns.heatmap(cm_flood, annot=True, fmt='d', cmap='Blues', ax=axes[0],
            xticklabels=['No Flood', 'Flood'],
            yticklabels=['No Flood', 'Flood'])
axes[0].set_title('Flood Prediction Confusion Matrix')
axes[0].set_xlabel('Predicted')
axes[0].set_ylabel('True')

# Drought confusion matrix
cm_drought = confusion_matrix(all_drought_labels, all_drought_preds)
sns.heatmap(cm_drought, annot=True, fmt='d', cmap='Oranges', ax=axes[1],
            xticklabels=['No Drought', 'Drought'],
            yticklabels=['No Drought', 'Drought'])
axes[1].set_title('Drought Prediction Confusion Matrix')
axes[1].set_xlabel('Predicted')
axes[1].set_ylabel('True')

plt.tight_layout()
plt.savefig(os.path.join(LOG_DIR, f"{config.experiment_name}_confusion_matrices.png"), dpi=150)
plt.show()
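The confusion matrices plotted above contain everything needed to recompute the headline metrics by hand. The sketch below derives accuracy, precision, recall and F1 for the positive class from a 2x2 matrix in the layout `sklearn.metrics.confusion_matrix` returns for labels 0 and 1 (rows are true classes, columns are predictions). The counts are illustrative only and are not results from this notebook. The function name is hypothetical.

```python
def binary_metrics(cm):
    """Positive-class metrics from a 2x2 matrix laid out as [[TN, FP], [FN, TP]]."""
    (tn, fp), (fn, tp) = cm
    total = tn + fp + fn + tp
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    if precision + recall:
        f1 = 2 * precision * recall / (precision + recall)
    else:
        f1 = 0.0
    return {
        "accuracy": (tn + tp) / total,
        "precision": precision,
        "recall": recall,
        "f1": f1,
    }


# Illustrative counts only: 5% positives, similar in spirit to the flood labels
example_cm = [[900, 50], [10, 40]]
metrics = binary_metrics(example_cm)
```

In this made-up example accuracy is 0.94 while F1 is only about 0.57, which illustrates why the training loop monitors F1 rather than accuracy on labels this imbalanced.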

## 11. Save Final Results

In [None]:
import json

# Save final results
results = {
    "experiment_name": config.experiment_name,
    "model_config": {
        "encoder_output_dim": config.model.encoder_output_dim,
        "cnn_channels": config.model.cnn_channels,
        "lstm_hidden_size": config.model.lstm_hidden_size,
        "lstm_num_layers": config.model.lstm_num_layers
    },
    "training_config": {
        "num_epochs": config.training.num_epochs,
        "learning_rate": config.training.learning_rate,
        "batch_size": config.data.batch_size,
        "sequence_length": config.data.sequence_length
    },
    "best_epoch": early_stopping.best_epoch,
    "best_val_f1": best_val_f1,
    "test_metrics": {k: float(v) for k, v in test_metrics.items()},
    "total_training_time": total_time
}

results_path = os.path.join(LOG_DIR, f"{config.experiment_name}_results.json")
with open(results_path, 'w') as f:
    json.dump(results, f, indent=2)

print(f"\n✅ Results saved to {results_path}")
print(f"\n🎉 Training complete! Model saved to {MODEL_DIR}")

---
## Summary

This notebook trained a multi-encoder deep learning model for natural disaster prediction:

1. **CNN Encoder** processed satellite imagery features (NDVI, EVI, LST) to capture spatial patterns
2. **LSTM Encoder** processed weather sequences to capture temporal patterns
3. **MLP Encoder** processed static geographic features (elevation, landcover, coordinates)
4. **Mid-Level Fusion** combined all encoder outputs
5. **Dual Prediction Heads** made predictions for both flood and drought

The model was trained with:
- Class-weighted loss for handling imbalanced data
- Learning rate scheduling
- Early stopping
- Mixed precision training (enabled only on CUDA devices; it was off for the MPS run shown here)

Next steps:
- Run the evaluation notebook for detailed analysis
- Fine-tune hyperparameters if needed
- Deploy model for inference