# Interactive Model Training

Train Arctic ice classification models interactively in this notebook.

This notebook allows you to:
- Train models with custom parameters
- Monitor training in real-time
- Visualize training progress
- Experiment with hyperparameters

In [None]:
import sys
sys.path.append('../training')
sys.path.append('../data')

import os
import json
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from torchvision import models
from tqdm.notebook import tqdm

# Report the runtime: PyTorch build, CUDA visibility and, if a GPU is
# visible, the name of device 0.
print(
    f"PyTorch version: {torch.__version__}",
    f"CUDA available: {torch.cuda.is_available()}",
    sep="\n",
)
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

## 1. Configuration

In [None]:
# Central experiment settings; the later cells read their paths and
# hyperparameters from this dict.
CONFIG = dict(
    data_dir='../data/processed',  # root that holds the train/ and val/ splits
    batch_size=16,                 # reduce if out of memory
    num_epochs=10,                 # start with fewer epochs for testing
    learning_rate=0.001,
    num_classes=3,
    device='cuda' if torch.cuda.is_available() else 'cpu',
    save_dir='../models',          # checkpoints and plots are written here
)

# Echo the settings so each run's output records what it used.
print("Training Configuration:")
for key, value in CONFIG.items():
    print(f"  {key}: {value}")

# Make sure the output directory exists before anything tries to write to it.
os.makedirs(CONFIG['save_dir'], exist_ok=True)

## 2. Dataset Loader

In [None]:
class IceDataset(Dataset):
    """Ice imagery dataset backed by paired ``.npy`` files.

    Reads ``<data_dir>/<split>/images`` and ``<data_dir>/<split>/labels``;
    an image and its label array share the same file name. Each item is the
    image as a float32 tensor with its last axis moved to the front, paired
    with the most frequent value of the label array as a ``long`` tensor.
    """

    def __init__(self, data_dir, split='train'):
        split_root = f"{data_dir}/{split}"
        self.data_dir = split_root
        self.image_dir = f"{split_root}/images"
        self.label_dir = f"{split_root}/labels"

        # Only .npy files count as samples; sorting fixes the index -> file
        # mapping regardless of directory listing order.
        npy_names = [name for name in os.listdir(self.image_dir) if name.endswith('.npy')]
        npy_names.sort()
        self.samples = npy_names

        print(f"Loaded {len(self.samples)} {split} samples")

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        file_name = self.samples[idx]

        # Image: axes reordered (2, 0, 1) and cast to float32.
        # Presumably (H, W, C) on disk -> (C, H, W); confirm against the
        # preprocessing code.
        pixels = torch.from_numpy(np.load(f"{self.image_dir}/{file_name}"))
        pixels = pixels.permute(2, 0, 1).float()

        # Label: collapse the whole label array to its single most frequent
        # value. np.unique returns values in ascending order, so a tie goes
        # to the smallest value.
        mask = np.load(f"{self.label_dir}/{file_name}")
        values, freq = np.unique(mask, return_counts=True)
        majority = values[freq.argmax()]

        return pixels, torch.tensor(majority, dtype=torch.long)

# One dataset per split, built in train -> val order.
train_dataset, val_dataset = (
    IceDataset(CONFIG['data_dir'], split_name) for split_name in ('train', 'val')
)

# Batch the datasets; only the training stream is shuffled.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=CONFIG['batch_size'],
    shuffle=True,
    num_workers=2,
)
val_loader = DataLoader(
    dataset=val_dataset,
    batch_size=CONFIG['batch_size'],
    shuffle=False,
    num_workers=2,
)

print(f"\nTrain batches: {len(train_loader)}")
print(f"Val batches: {len(val_loader)}")

## 3. Model Definition

In [None]:
class IceClassifier(nn.Module):
    """ResNet-50 backbone with a small fully connected classification head.

    Args:
        num_classes: Width of the final linear layer (number of logits).
        pretrained: Forwarded unchanged to ``models.resnet50``.
    """

    def __init__(self, num_classes=3, pretrained=True):
        super().__init__()

        resnet = models.resnet50(pretrained=pretrained)

        # Freeze every parameter tensor except the last 30. The slice is
        # taken before the head is swapped below, so the stock `fc`
        # parameters (presumably the last two in torchvision's ResNet)
        # count toward those 30 - TODO confirm.
        stock_params = list(resnet.parameters())
        for frozen in stock_params[:-30]:
            frozen.requires_grad_(False)

        # New head: in_features -> 512 -> 128 -> num_classes, with dropout
        # after the first hidden layer. The new layers are trainable.
        head_layers = [
            nn.Linear(resnet.fc.in_features, 512),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(512, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes),
        ]
        resnet.fc = nn.Sequential(*head_layers)

        self.backbone = resnet

    def forward(self, x):
        """Return the backbone's output for ``x``, including the new head."""
        return self.backbone(x)

# Build the classifier and place it on the configured device in one step.
model = IceClassifier(num_classes=CONFIG['num_classes'], pretrained=True).to(CONFIG['device'])

# Parameter budget: element counts over all tensors, and over those that
# still require gradients.
total_params = sum(map(torch.numel, model.parameters()))
trainable_params = sum(
    map(torch.numel, filter(lambda p: p.requires_grad, model.parameters()))
)

print("\nModel: ResNet50 Ice Classifier")
print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
print(f"Frozen parameters: {total_params - trainable_params:,}")

## 4. Training Setup

In [None]:
# Loss function: multi-class cross-entropy applied to the model's outputs.
criterion = nn.CrossEntropyLoss()

# Optimizer: Adam over the parameters that still require gradients, so the
# frozen tensors are never handed to the optimizer.
optimizer = optim.Adam(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=CONFIG['learning_rate']
)

# Learning rate scheduler: multiply the learning rate by 0.5 when the
# monitored value (lower is better) plateaus, with a patience of 3 steps.
# NOTE: `verbose=True` is intentionally not passed. The keyword is deprecated
# in PyTorch 2.2+ and is no longer accepted by newer releases, where it
# raises a TypeError. To see learning-rate changes, read
# `optimizer.param_groups[0]['lr']` after each `scheduler.step(...)`.
scheduler = optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=3
)

# Metrics tracking: one entry is appended per epoch by the training cell.
train_losses = []
val_losses = []
val_accuracies = []

print("Training setup complete!")

## 5. Training Loop

In [None]:
def train_epoch(model, train_loader, criterion, optimizer, device):
    """Run one optimisation pass over ``train_loader``.

    Puts the model in training mode, performs one optimizer step per batch
    and shows running loss and accuracy on a progress bar.

    Returns:
        ``(avg_loss, accuracy)``: the mean of the per-batch losses and the
        percentage of samples whose highest-scoring class matched the label.
    """
    model.train()
    running_loss, n_correct, n_seen = 0, 0, 0

    progress = tqdm(train_loader, desc='Training')
    for images, labels in progress:
        images = images.to(device)
        labels = labels.to(device)

        # Clear stale gradients, then forward / backward / update.
        optimizer.zero_grad()
        logits = model(images)
        batch_loss = criterion(logits, labels)
        batch_loss.backward()
        optimizer.step()

        # Running statistics for this epoch.
        batch_loss_value = batch_loss.item()
        running_loss += batch_loss_value
        predicted = logits.max(dim=1).indices
        n_seen += labels.size(0)
        n_correct += (predicted == labels).sum().item()

        # Passed as a dict (not keyword arguments) to keep the
        # loss-then-acc display order.
        progress.set_postfix({
            'loss': f"{batch_loss_value:.4f}",
            'acc': f"{100. * n_correct / n_seen:.2f}%"
        })

    return running_loss / len(train_loader), 100. * n_correct / n_seen


def validate(model, val_loader, criterion, device):
    """Evaluate ``model`` on ``val_loader`` without tracking gradients.

    Returns:
        ``(avg_loss, accuracy)``: the mean of the per-batch losses and the
        percentage of samples whose highest-scoring class matched the label.
    """
    model.eval()
    running_loss, n_correct, n_seen = 0, 0, 0

    with torch.no_grad():
        for images, labels in tqdm(val_loader, desc='Validation'):
            images = images.to(device)
            labels = labels.to(device)

            logits = model(images)
            running_loss += criterion(logits, labels).item()

            predicted = logits.max(dim=1).indices
            n_seen += labels.size(0)
            n_correct += (predicted == labels).sum().item()

    return running_loss / len(val_loader), 100. * n_correct / n_seen

# Cell-completion marker: printed once both function definitions above have executed.
print("Training functions defined!")

## 6. Train Model

In [None]:
# Main training driver: runs CONFIG['num_epochs'] train/validate rounds,
# records the metrics, steps the LR scheduler and checkpoints the model
# whenever validation loss improves.
print("="*60)
print("Starting Training")
print("="*60)

# Lowest validation loss seen in this run of the cell.
best_val_loss = float('inf')

for epoch in range(CONFIG['num_epochs']):
    print(f"\nEpoch {epoch + 1}/{CONFIG['num_epochs']}")
    print("-" * 60)
    
    # Train
    # (train accuracy is only printed; just the loss is kept in the history)
    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, CONFIG['device'])
    train_losses.append(train_loss)
    
    # Validate
    val_loss, val_acc = validate(model, val_loader, criterion, CONFIG['device'])
    val_losses.append(val_loss)
    val_accuracies.append(val_acc)
    
    # Learning rate scheduling
    # (the scheduler is driven by this epoch's validation loss)
    scheduler.step(val_loss)
    
    # Print results
    print(f"\nEpoch {epoch + 1} Results:")
    print(f"  Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.2f}%")
    print(f"  Val Loss:   {val_loss:.4f} | Val Acc:   {val_acc:.2f}%")
    
    # Save best model
    # Selection is by validation loss, not accuracy. The same file is
    # overwritten on every improvement, and 'epoch' is stored zero-based.
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        model_path = f"{CONFIG['save_dir']}/ice_classifier_notebook.pth"
        torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': val_loss,
            'val_acc': val_acc,
        }, model_path)
        print(f"  ✅ Saved best model (val_loss: {val_loss:.4f})")

print("\n" + "="*60)
print("Training Complete!")
print("="*60)
print(f"Best val loss: {best_val_loss:.4f}")
# NOTE: this is the maximum over everything in val_accuracies, which need not
# be the accuracy of the checkpoint saved above (that one minimises loss).
# The history lists are created in the setup cell, so re-running only this
# cell keeps appending to the earlier entries.
print(f"Best val acc: {max(val_accuracies):.2f}%")

## 7. Visualize Training Progress

In [None]:
# Two side-by-side panels: loss curves on the left, validation accuracy on
# the right.
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(14, 5))

# Left panel: training vs. validation loss per epoch.
ax1.plot(train_losses, label='Train Loss', marker='o', linewidth=2)
ax1.plot(val_losses, label='Val Loss', marker='s', linewidth=2)
ax1.set_title('Training and Validation Loss', fontsize=14, fontweight='bold')
ax1.set_xlabel('Epoch', fontsize=12)
ax1.set_ylabel('Loss', fontsize=12)
ax1.grid(alpha=0.3)
ax1.legend(fontsize=11)

# Right panel: validation accuracy per epoch.
ax2.plot(val_accuracies, label='Val Accuracy', marker='o', linewidth=2, color='green')
ax2.set_title('Validation Accuracy', fontsize=14, fontweight='bold')
ax2.set_xlabel('Epoch', fontsize=12)
ax2.set_ylabel('Accuracy (%)', fontsize=12)
ax2.grid(alpha=0.3)
ax2.legend(fontsize=11)

# Save the figure next to the checkpoints, then display it inline.
fig.tight_layout()
fig.savefig(f"{CONFIG['save_dir']}/training_curves_notebook.png", dpi=150, bbox_inches='tight')
plt.show()

print("✅ Training curves saved!")

## 8. Test Model on Sample

In [None]:
def test_on_sample(model, dataset, idx=0):
    """Classify ``dataset[idx]`` and show the image beside the class probabilities.

    Draws a two-panel figure (input image, horizontal probability bars with
    the predicted class highlighted) and prints the true label, the
    prediction with its confidence, and whether the two agree.
    """
    class_names = ['Open Water', 'Thin Ice', 'Thick Ice']

    model.eval()
    image, true_label = dataset[idx]

    # Single-sample batch on the configured device; no gradients are needed.
    with torch.no_grad():
        logits = model(image.unsqueeze(0).to(CONFIG['device']))
        probabilities = torch.softmax(logits, dim=1)
        predicted_class = torch.argmax(probabilities, dim=1).item()
        confidence = probabilities[0, predicted_class].item()

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

    # Left: the input, moved back to channels-last for imshow.
    ax1.imshow(image.permute(1, 2, 0).numpy())
    ax1.set_title('Test Image', fontsize=12, fontweight='bold')
    ax1.axis('off')

    # Right: per-class probability in percent.
    percentages = probabilities[0].cpu().numpy() * 100
    bars = ax2.barh(class_names, percentages, color=['blue', 'orange', 'green'])
    ax2.set_xlabel('Confidence (%)', fontsize=11)
    ax2.set_title('Prediction Probabilities', fontsize=12, fontweight='bold')
    ax2.grid(axis='x', alpha=0.3)

    # Recolour the predicted class's bar so it stands out.
    winner = bars[predicted_class]
    winner.set_color('red')
    winner.set_alpha(0.8)

    plt.tight_layout()
    plt.show()

    print(f"\nTrue Label: {class_names[true_label]}")
    print(f"Predicted: {class_names[predicted_class]} ({confidence*100:.1f}% confidence)")
    print("✅ Correct prediction!" if true_label == predicted_class else "❌ Incorrect prediction")

# Run the single-sample check on up to the first three validation samples,
# each under its own banner.
for i in range(min(3, len(val_dataset))):
    print(f"\n{'='*60}", f"Sample {i}", '='*60, sep='\n')
    test_on_sample(model, val_dataset, i)

## 9. Save Final Model

In [None]:
# Persist the model's current weights together with the run configuration
# and the metric histories.
final_model_path = f"{CONFIG['save_dir']}/ice_classifier_final.pth"
torch.save(
    obj=dict(
        model_state_dict=model.state_dict(),
        config=CONFIG,
        train_losses=train_losses,
        val_losses=val_losses,
        val_accuracies=val_accuracies,
    ),
    f=final_model_path,
)

print(f"✅ Final model saved to: {final_model_path}")

# Summary taken from the recorded validation history.
print("\nModel Performance:")
print(f"  Best Val Loss: {min(val_losses):.4f}")
print(f"  Best Val Accuracy: {max(val_accuracies):.2f}%")

# Hint for copying the checkpoint into the backend.
print("\nTo use this model in production:")
print(f"  cp {final_model_path} ../../backend/app/models/")

## Conclusion

You've successfully trained an ice classification model!

**Results**:
- Model trained for the number of epochs set in `CONFIG['num_epochs']` (10 by default)
- Best validation accuracy: Check output above
- Model saved and ready for deployment

**Next Steps**:
1. Run `03_model_evaluation.ipynb` for detailed evaluation
2. Experiment with hyperparameters (learning rate, batch size, epochs)
3. Try different model architectures
4. Deploy model to backend