In [None]:
# Imports
import sys
from pathlib import Path

# Make the project's `src/` directory importable so the "Local imports" below resolve.
# `project_root` is the parent of the kernel's current working directory, so this
# only works when the notebook is run from a folder directly under the project root.
project_root = Path.cwd().parent
sys.path.insert(0, str(project_root / 'src'))

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import CosineAnnealingWarmRestarts
from tqdm.notebook import tqdm
import numpy as np
from datetime import datetime

# Local imports
from data.dataset import create_dataloaders
from models.efficientnet import create_efficientnet_b0
from utils.training import (
    calculate_class_weights,
    EarlyStopping,
    MetricsCalculator,
    TensorBoardLogger,
    save_checkpoint,
    load_checkpoint
)

# Report the runtime environment. The GPU name is only queried when CUDA is
# available; device index 0 is the one reported.
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name(0)}")

## 1. Configuration

In [None]:
# Configuration
class Config:
    """Central collection of the paths and hyperparameters used by this notebook.

    All values are plain class attributes; the notebook instantiates the class
    once (``config = Config()``) and reads the attributes from that instance.
    """

    # Paths -- all resolved one level above `project_root`
    DATA_DIR = project_root.parent / 'data' / 'processed'  # passed to create_dataloaders
    CHECKPOINT_DIR = project_root.parent / 'models_exported'  # where save_checkpoint writes .pth files
    LOG_DIR = project_root.parent / 'runs' / 'efficientnet'  # parent folder of per-run TensorBoard logs
    
    # Model -- forwarded to create_efficientnet_b0
    NUM_CLASSES = 38  # also passed to MetricsCalculator
    PRETRAINED = True
    DROPOUT = 0.3
    
    # Training - Phase 1 (Warmup): the backbone is frozen before this phase
    WARMUP_EPOCHS = 5  # number of Phase 1 epochs
    WARMUP_LR = 1e-3  # AdamW learning rate during Phase 1
    
    # Training - Phase 2 (Fine-tuning)
    FINETUNE_EPOCHS = 25  # maximum number of Phase 2 epochs (early stopping may end sooner)
    BACKBONE_LR = 1e-4  # passed as backbone_lr to model.get_optimizer_param_groups
    CLASSIFIER_LR = 1e-3  # passed as classifier_lr to model.get_optimizer_param_groups
    UNFREEZE_BLOCKS = 3  # passed as n to model.unfreeze_last_n_blocks
    
    # Scheduler -- CosineAnnealingWarmRestarts arguments, shared by both phases
    T_0 = 5  # First restart period (in scheduler steps; the loops step once per epoch)
    T_MULT = 2  # Period multiplier applied after each restart
    
    # Data -- forwarded to create_dataloaders
    BATCH_SIZE = 32
    NUM_WORKERS = 4
    IMAGE_SIZE = 224
    
    # Training
    EARLY_STOPPING_PATIENCE = 5  # patience passed to EarlyStopping
    # Evaluated once, when the class body executes: 'cuda' if available, else 'cpu'.
    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
    SEED = 42  # passed to set_seed

# Single settings object read by every cell below.
config = Config()

# Make sure both output locations exist. mkdir is idempotent here: missing
# parents are created and an already-existing directory is not an error.
for _output_dir in (config.CHECKPOINT_DIR, config.LOG_DIR):
    _output_dir.mkdir(parents=True, exist_ok=True)

print(f"Data directory: {config.DATA_DIR}")
print(f"Checkpoint directory: {config.CHECKPOINT_DIR}")
print(f"Device: {config.DEVICE}")

In [None]:
# Set random seeds for reproducibility
def set_seed(seed: int):
    """Seed every random number generator this notebook can reach.

    Seeds Python's built-in ``random`` module, NumPy's legacy global generator
    and PyTorch (CPU and all CUDA devices), then switches cuDNN to
    deterministic kernels with autotuning disabled.

    Args:
        seed: Integer seed applied to all generators.
    """
    # Imported locally so the function is self-contained; the notebook's
    # import cell does not bring in `random`.
    import random

    # Python's own generator was previously left unseeded, so anything drawing
    # from `random` differed between otherwise identical runs.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Safe to call without a GPU; it is a no-op when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)

    # Trade some speed for repeatability: fixed cuDNN algorithms, no benchmarking.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# Apply the configured seed before any data loading or model construction.
set_seed(config.SEED)
print(f"Random seed set to {config.SEED}")

## 2. Load Data

In [None]:
# Build the loaders from the processed dataset; the helper also returns the
# class names, which are reused for reporting at the end of the notebook.
dataloaders, class_names = create_dataloaders(
    data_dir=config.DATA_DIR,
    batch_size=config.BATCH_SIZE,
    num_workers=config.NUM_WORKERS,
    image_size=config.IMAGE_SIZE,
)
train_loader, val_loader = dataloaders['train'], dataloaders['val']

# Quick sanity summary of what was loaded.
print(f"Number of classes: {len(class_names)}")
print(f"Training samples: {len(train_loader.dataset)}")
print(f"Validation samples: {len(val_loader.dataset)}")
print(f"Training batches: {len(train_loader)}")
print(f"Validation batches: {len(val_loader)}")

In [None]:
# Per-class loss weights derived from the training set, moved straight onto the
# training device so they can be handed to the loss function.
class_weights = calculate_class_weights(train_loader.dataset).to(config.DEVICE)

print(f"Class weights shape: {class_weights.shape}")
print(f"Min weight: {class_weights.min():.4f}")
print(f"Max weight: {class_weights.max():.4f}")

## 3. Create Model

In [None]:
# Build the network from the configured settings and place it on the training device.
model = create_efficientnet_b0(
    num_classes=config.NUM_CLASSES,
    pretrained=config.PRETRAINED,
    dropout=config.DROPOUT,
).to(config.DEVICE)

# Parameter counts before any freezing is applied.
print(f"Model created: EfficientNet-B0")
print(f"Total parameters: {model.get_num_params(trainable_only=False):,}")
print(f"Trainable parameters: {model.get_num_params(trainable_only=True):,}")

In [None]:
# Freeze backbone for Phase 1 (warmup).
# The Phase 1 optimizer only receives parameters with requires_grad=True, so this
# call must run before that optimizer is built. The trainable count printed below
# should drop relative to the previous cell (presumably leaving only the
# classifier head trainable -- confirm in models.efficientnet).
model.freeze_backbone()
print(f"Trainable parameters after freezing: {model.get_num_params(trainable_only=True):,}")

## 4. Training Setup

In [None]:
# Timestamped run name so repeated executions log into separate folders.
experiment_name = f"efficientnet_b0_{datetime.now().strftime('%Y%m%d_%H%M%S')}"

# Cross-entropy weighted per class to counter class imbalance.
criterion = nn.CrossEntropyLoss(weight=class_weights)

# Validation metrics helper, sized to the configured number of classes.
metrics_calculator = MetricsCalculator(num_classes=config.NUM_CLASSES)

# TensorBoard writer for this run.
logger = TensorBoardLogger(log_dir=config.LOG_DIR / experiment_name)

# Early stopping on validation accuracy (higher is better).
early_stopping = EarlyStopping(
    patience=config.EARLY_STOPPING_PATIENCE,
    mode='max',
    min_delta=0.001,
)

print(f"Experiment: {experiment_name}")
print(f"TensorBoard logs: {config.LOG_DIR / experiment_name}")

In [None]:
# Training functions
def train_one_epoch(model, train_loader, criterion, optimizer, device, epoch, logger=None):
    """Run one optimisation pass over ``train_loader``.

    Args:
        model: Network to train; it is switched to training mode.
        train_loader: Iterable yielding ``(inputs, labels)`` batches; must support ``len()``.
        criterion: Loss callable applied to ``(outputs, labels)``.
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to.
        epoch: Epoch number, used in the progress-bar label and as the logging step.
        logger: Optional object with ``log_scalar(tag, value, step)``.

    Returns:
        ``(epoch_loss, epoch_acc)``: the mean of the per-batch losses and the
        fraction of samples whose arg-max prediction equals the label.
    """
    model.train()
    loss_total = 0.0
    predicted, expected = [], []

    progress = tqdm(train_loader, desc=f"Epoch {epoch}")
    for inputs, labels in progress:
        inputs = inputs.to(device)
        labels = labels.to(device)

        # Standard step: clear gradients, forward, backward, update.
        optimizer.zero_grad()
        outputs = model(inputs)
        batch_loss = criterion(outputs, labels)
        batch_loss.backward()
        optimizer.step()

        batch_loss_value = batch_loss.item()
        loss_total += batch_loss_value

        # Index of the largest output along dim 1 is the predicted class.
        batch_preds = torch.max(outputs, 1)[1]
        predicted.extend(batch_preds.cpu().numpy())
        expected.extend(labels.cpu().numpy())

        progress.set_postfix({'loss': batch_loss_value})

    # Averaged over batches, not samples.
    epoch_loss = loss_total / len(train_loader)
    epoch_acc = np.mean(np.array(predicted) == np.array(expected))

    if logger:
        logger.log_scalar('train/loss', epoch_loss, epoch)
        logger.log_scalar('train/accuracy', epoch_acc, epoch)
        # Only the first parameter group's learning rate is recorded.
        logger.log_scalar('train/lr', optimizer.param_groups[0]['lr'], epoch)

    return epoch_loss, epoch_acc


def validate(model, val_loader, criterion, device, epoch, metrics_calculator, logger=None):
    """Evaluate ``model`` on ``val_loader`` without computing gradients.

    Args:
        model: Network to evaluate; it is switched to evaluation mode.
        val_loader: Iterable yielding ``(inputs, labels)`` batches; must support ``len()``.
        criterion: Loss callable applied to ``(outputs, labels)``.
        device: Device the batches are moved to.
        epoch: Step value used when logging.
        metrics_calculator: Object whose ``calculate(labels, preds)`` returns a mapping
            containing at least ``'accuracy'``, ``'f1_macro'`` and ``'f1_weighted'``
            when a logger is supplied.
        logger: Optional object with ``log_scalar(tag, value, step)``.

    Returns:
        ``(epoch_loss, metrics)``: the mean of the per-batch losses and the mapping
        returned by ``metrics_calculator.calculate``.
    """
    model.eval()
    loss_total = 0.0
    predicted, expected = [], []

    with torch.no_grad():
        for inputs, labels in tqdm(val_loader, desc="Validation"):
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            loss_total += criterion(outputs, labels).item()

            # Index of the largest output along dim 1 is the predicted class.
            batch_preds = torch.max(outputs, 1)[1]
            predicted.extend(batch_preds.cpu().numpy())
            expected.extend(labels.cpu().numpy())

    # Averaged over batches, not samples.
    epoch_loss = loss_total / len(val_loader)

    # Argument order is (true labels, predictions).
    metrics = metrics_calculator.calculate(np.array(expected), np.array(predicted))

    if logger:
        logger.log_scalar('val/loss', epoch_loss, epoch)
        logger.log_scalar('val/accuracy', metrics['accuracy'], epoch)
        logger.log_scalar('val/f1_macro', metrics['f1_macro'], epoch)
        logger.log_scalar('val/f1_weighted', metrics['f1_weighted'], epoch)

    return epoch_loss, metrics

print("Training functions defined.")

## 5. Phase 1: Warmup Training (Frozen Backbone)

In [None]:
# Phase 1: train for config.WARMUP_EPOCHS epochs with the backbone frozen.
# The names assigned here (epoch, val_metrics, best_val_acc_phase1, ...) are
# notebook globals that later cells read, so this cell must run before Phase 2.
print("=" * 60)
print("PHASE 1: Warmup Training (Frozen Backbone)")
print("=" * 60)

# Phase 1 optimizer - only classifier parameters
# (strictly: every parameter that still has requires_grad=True after
# model.freeze_backbone(); presumably that is just the classifier head).
optimizer_phase1 = optim.AdamW(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=config.WARMUP_LR,
    weight_decay=0.01
)

# Phase 1 scheduler
scheduler_phase1 = CosineAnnealingWarmRestarts(
    optimizer_phase1,
    T_0=config.T_0,
    T_mult=config.T_MULT
)

# Highest validation accuracy seen so far in this phase; also the starting
# bar for Phase 2.
best_val_acc_phase1 = 0.0

# Epochs are numbered from 1 so the logged step matches the printed epoch.
for epoch in range(1, config.WARMUP_EPOCHS + 1):
    print(f"\n--- Epoch {epoch}/{config.WARMUP_EPOCHS} (Phase 1) ---")
    
    # Train
    train_loss, train_acc = train_one_epoch(
        model, train_loader, criterion, optimizer_phase1, 
        config.DEVICE, epoch, logger
    )
    
    # Validate
    val_loss, val_metrics = validate(
        model, val_loader, criterion, config.DEVICE, 
        epoch, metrics_calculator, logger
    )
    
    # Step scheduler (once per epoch, after validation)
    scheduler_phase1.step()
    
    print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
    print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_metrics['accuracy']:.4f}, F1: {val_metrics['f1_macro']:.4f}")
    
    # Save best model -- only on a strict improvement in validation accuracy
    if val_metrics['accuracy'] > best_val_acc_phase1:
        best_val_acc_phase1 = val_metrics['accuracy']
        save_checkpoint(
            model, optimizer_phase1, epoch, val_metrics['accuracy'],
            config.CHECKPOINT_DIR / 'efficientnet_phase1_best.pth'
        )
        print(f"  -> New best model saved! (Acc: {best_val_acc_phase1:.4f})")

print(f"\nPhase 1 Complete! Best Val Accuracy: {best_val_acc_phase1:.4f}")

## 6. Phase 2: Fine-tuning (Unfrozen Last 3 Blocks)

In [None]:
print("=" * 60)
print("PHASE 2: Fine-tuning (Progressive Unfreezing)")
print("=" * 60)

# Unfreeze last N blocks (N = config.UNFREEZE_BLOCKS). This is a single
# unfreezing step performed once before fine-tuning; the model keeps the
# weights it had at the end of the last Phase 1 epoch.
model.unfreeze_last_n_blocks(n=config.UNFREEZE_BLOCKS)
print(f"Trainable parameters after unfreezing: {model.get_num_params(trainable_only=True):,}")

In [None]:
# Phase 2 optimizer with differential learning rates.
# The model supplies the parameter groups; the fine-tuning loop prints
# param_groups[0] as "Backbone" and param_groups[1] as "Classifier", so the groups
# are presumably returned in that order -- verify against get_optimizer_param_groups.
param_groups = model.get_optimizer_param_groups(
    backbone_lr=config.BACKBONE_LR,
    classifier_lr=config.CLASSIFIER_LR
)

# No top-level `lr` is given here, so each group is expected to carry its own 'lr'
# entry (NOTE(review): confirm; a group without one would use AdamW's default).
optimizer_phase2 = optim.AdamW(
    param_groups,
    weight_decay=0.01
)

# Phase 2 scheduler -- a fresh schedule with the same restart settings as Phase 1
scheduler_phase2 = CosineAnnealingWarmRestarts(
    optimizer_phase2,
    T_0=config.T_0,
    T_mult=config.T_MULT
)

# Reset early stopping for phase 2: a new instance replaces the one created in
# the training-setup cell, so patience counting starts from zero.
early_stopping = EarlyStopping(
    patience=config.EARLY_STOPPING_PATIENCE,
    mode='max',
    min_delta=0.001
)

print(f"Backbone LR: {config.BACKBONE_LR}")
print(f"Classifier LR: {config.CLASSIFIER_LR}")

In [None]:
# Phase 2: fine-tune for up to config.FINETUNE_EPOCHS epochs with early stopping.
# Epoch numbering continues from Phase 1 so TensorBoard curves stay on one axis.

# Checkpoints written by the two phases.
phase1_best_path = config.CHECKPOINT_DIR / 'efficientnet_phase1_best.pth'
best_model_path = config.CHECKPOINT_DIR / 'efficientnet_best.pth'

# The bar to beat starts at the Phase 1 best, so 'efficientnet_best.pth' is only
# rewritten when fine-tuning improves on warmup. Seed that file with the Phase 1
# checkpoint first: otherwise a run whose fine-tuning never beats warmup would leave
# the file missing (or left over from an earlier run) when the evaluation section
# loads it.
if best_val_acc_phase1 > 0.0 and phase1_best_path.exists():
    best_model_path.write_bytes(phase1_best_path.read_bytes())
    best_val_acc_phase2 = best_val_acc_phase1
else:
    # No Phase 1 checkpoint to fall back on: let the first improving
    # fine-tuning epoch establish the best model.
    best_val_acc_phase2 = 0.0

start_epoch = config.WARMUP_EPOCHS + 1

for epoch in range(start_epoch, start_epoch + config.FINETUNE_EPOCHS):
    print(f"\n--- Epoch {epoch}/{start_epoch + config.FINETUNE_EPOCHS - 1} (Phase 2) ---")

    # Train
    train_loss, train_acc = train_one_epoch(
        model, train_loader, criterion, optimizer_phase2,
        config.DEVICE, epoch, logger
    )

    # Validate
    val_loss, val_metrics = validate(
        model, val_loader, criterion, config.DEVICE,
        epoch, metrics_calculator, logger
    )

    # Step scheduler (once per epoch, after validation)
    scheduler_phase2.step()

    print(f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}")
    print(f"Val Loss: {val_loss:.4f}, Val Acc: {val_metrics['accuracy']:.4f}, F1: {val_metrics['f1_macro']:.4f}")
    # The scheduler has already stepped, so these are the rates the next epoch will use.
    print(f"LR - Backbone: {optimizer_phase2.param_groups[0]['lr']:.6f}, Classifier: {optimizer_phase2.param_groups[1]['lr']:.6f}")

    # Save best model -- only on a strict improvement over the best seen so far
    if val_metrics['accuracy'] > best_val_acc_phase2:
        best_val_acc_phase2 = val_metrics['accuracy']
        save_checkpoint(
            model, optimizer_phase2, epoch, val_metrics['accuracy'],
            best_model_path
        )
        print(f"  -> New best model saved! (Acc: {best_val_acc_phase2:.4f})")

    # Early stopping check
    early_stopping(val_metrics['accuracy'])
    if early_stopping.early_stop:
        print(f"\nEarly stopping triggered at epoch {epoch}!")
        break

print(f"\nPhase 2 Complete! Best Val Accuracy: {best_val_acc_phase2:.4f}")

## 7. Save Final Model

In [None]:
# Save final model.
# `epoch` and `val_metrics` are the values left behind by the last executed
# iteration of the Phase 2 loop, so this stores the model as it was when training
# ended (including an early stop) -- not necessarily the best-scoring epoch.
# Class names and the settings needed to rebuild the model are stored alongside
# the weights via `extra_info`.
save_checkpoint(
    model, optimizer_phase2, epoch, val_metrics['accuracy'],
    config.CHECKPOINT_DIR / 'efficientnet_final.pth',
    extra_info={
        'class_names': class_names,
        'config': {
            'num_classes': config.NUM_CLASSES,
            'dropout': config.DROPOUT,
            'image_size': config.IMAGE_SIZE
        }
    }
)

print(f"Final model saved to: {config.CHECKPOINT_DIR / 'efficientnet_final.pth'}")

In [None]:
# Close TensorBoard logger. Nothing is logged after this point; the final
# evaluation cells below only print their results.
logger.close()

print("\n" + "=" * 60)
print("TRAINING COMPLETE!")
print("=" * 60)
print(f"Best Validation Accuracy: {best_val_acc_phase2:.4f}")
# The suggested --logdir is the parent folder, so every timestamped run under it is shown.
print(f"\nTo view TensorBoard logs, run:")
print(f"  tensorboard --logdir {config.LOG_DIR}")

## 8. Quick Evaluation on Validation Set

In [None]:
# Load best model for evaluation.
# 'efficientnet_best.pth' is produced by the Phase 2 training cell. The live `model`
# is passed in, presumably so its weights are restored in place -- confirm in
# utils.training.load_checkpoint. The returned mapping is read for its 'epoch' and
# 'best_metric' entries below.
best_checkpoint = load_checkpoint(
    config.CHECKPOINT_DIR / 'efficientnet_best.pth',
    model
)

print(f"Loaded best model from epoch {best_checkpoint['epoch']}")
print(f"Best validation accuracy: {best_checkpoint['best_metric']:.4f}")

In [None]:
# Final evaluation: predict over the validation set with the loaded weights.
model.eval()
all_preds, all_labels = [], []

with torch.no_grad():
    for inputs, labels in tqdm(val_loader, desc="Final Evaluation"):
        # Only the inputs are moved to the device; labels are used as loaded.
        outputs = model(inputs.to(config.DEVICE))
        preds = torch.max(outputs, 1)[1]
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.numpy())

# Calculate final metrics; argument order is (true labels, predictions).
final_metrics = metrics_calculator.calculate(np.array(all_labels), np.array(all_preds))

print("\nFinal Validation Metrics:")
print(f"  Accuracy: {final_metrics['accuracy']:.4f}")
print(f"  F1 (Macro): {final_metrics['f1_macro']:.4f}")
print(f"  F1 (Weighted): {final_metrics['f1_weighted']:.4f}")

In [None]:
# Per-class F1 table, sorted ascending so the weakest classes come first.
import pandas as pd

f1_scores = final_metrics['f1_per_class']
class_f1_df = pd.DataFrame({'Class': class_names, 'F1 Score': f1_scores})
class_f1_df = class_f1_df.sort_values('F1 Score')

# Show only the ten lowest-scoring classes.
print("\nTop 10 Worst Performing Classes:")
print(class_f1_df.head(10).to_string(index=False))