# SwellSight Wave Analysis - Complete Training Pipeline

## Training Loop with Real-time Monitoring

In [None]:
def train_epoch(model, train_loader, criterion, optimizer, device):
    """Run one optimisation pass over ``train_loader`` and report mean losses.

    Args:
        model: Network to train; put into training mode via ``model.train()``.
        train_loader: Iterable yielding ``(images, targets)`` batches, where
            ``targets`` is a dict whose values are moved with ``.to(device)``.
            Must support ``len()``.
        criterion: Callable taking ``(predictions, targets)`` and returning a
            dict with the keys ``'total_loss'``, ``'height_loss'``,
            ``'wave_type_loss'`` and ``'direction_loss'``.
        optimizer: Optimizer whose gradients are zeroed before, and stepped
            after, every batch.
        device: Device the images and targets are moved to.

    Returns:
        Dict with the same four loss keys. ``'total_loss'`` is the summed
        per-batch total divided by ``len(train_loader)``; the other three are
        ``np.mean`` over the per-batch values.

    Raises:
        ZeroDivisionError: If ``len(train_loader)`` is 0.
    """
    model.train()
    total_loss = 0
    height_losses = []
    wave_type_losses = []
    direction_losses = []

    pbar = tqdm(train_loader, desc='Training')
    for images, targets in pbar:
        images = images.to(device)
        targets = {k: v.to(device) for k, v in targets.items()}

        optimizer.zero_grad()

        # Forward pass
        predictions = model(images)
        loss_dict = criterion(predictions, targets)

        # Backward pass
        loss_dict['total_loss'].backward()
        optimizer.step()

        # Pull each scalar off the tensor exactly once; the same values feed
        # both the running statistics and the progress bar, so there is no
        # reason to pay for a second ``.item()`` transfer per loss.
        batch_total = loss_dict['total_loss'].item()
        batch_height = loss_dict['height_loss'].item()
        batch_wave_type = loss_dict['wave_type_loss'].item()
        batch_direction = loss_dict['direction_loss'].item()

        # Track losses
        total_loss += batch_total
        height_losses.append(batch_height)
        wave_type_losses.append(batch_wave_type)
        direction_losses.append(batch_direction)

        # Update progress bar
        pbar.set_postfix({
            'Loss': f"{batch_total:.4f}",
            'Height': f"{batch_height:.4f}",
            'Type': f"{batch_wave_type:.4f}",
            'Dir': f"{batch_direction:.4f}"
        })

    return {
        'total_loss': total_loss / len(train_loader),
        'height_loss': np.mean(height_losses),
        'wave_type_loss': np.mean(wave_type_losses),
        'direction_loss': np.mean(direction_losses)
    }

def validate_epoch(model, val_loader, criterion, device):
    """Evaluate the model on ``val_loader`` and report mean losses.

    The model is switched to eval mode and every batch is processed inside
    ``torch.no_grad()``; no optimizer is involved.

    Args:
        model: Network to evaluate.
        val_loader: Iterable yielding ``(images, targets)`` batches, where
            ``targets`` is a dict whose values are moved with ``.to(device)``.
            Must support ``len()``.
        criterion: Callable taking ``(predictions, targets)`` and returning a
            dict with the keys ``'total_loss'``, ``'height_loss'``,
            ``'wave_type_loss'`` and ``'direction_loss'``.
        device: Device the images and targets are moved to.

    Returns:
        Dict with the same four loss keys. ``'total_loss'`` is the summed
        per-batch total divided by ``len(val_loader)``; the other three are
        ``np.mean`` over the per-batch values.
    """
    model.eval()

    running_total = 0
    # One history list per auxiliary loss, keyed by the criterion's own names.
    histories = {
        'height_loss': [],
        'wave_type_loss': [],
        'direction_loss': [],
    }

    with torch.no_grad():
        progress = tqdm(val_loader, desc='Validation')
        for batch_images, batch_targets in progress:
            batch_images = batch_images.to(device)
            batch_targets = {
                name: value.to(device) for name, value in batch_targets.items()
            }

            batch_losses = criterion(model(batch_images), batch_targets)

            running_total += batch_losses['total_loss'].item()
            for loss_name, history in histories.items():
                history.append(batch_losses[loss_name].item())

    # The total is averaged first, so an empty loader fails on the division
    # before any np.mean call is reached.
    summary = {'total_loss': running_total / len(val_loader)}
    for loss_name, history in histories.items():
        summary[loss_name] = np.mean(history)
    return summary

In [None]:
# Training loop with real-time plotting
import matplotlib.pyplot as plt
from IPython.display import clear_output

def plot_training_progress(train_losses, val_losses, epoch):
    """Clear the cell output and draw a 2x2 grid of train/val loss curves.

    Args:
        train_losses: Sequence of per-epoch dicts containing the keys
            ``'total_loss'``, ``'height_loss'``, ``'wave_type_loss'`` and
            ``'direction_loss'``.
        val_losses: Sequence of per-epoch dicts with the same keys.
        epoch: Epoch number shown in the figure title.
    """
    clear_output(wait=True)

    _figure, panels = plt.subplots(2, 2, figsize=(15, 10))

    # (loss key, panel title), listed in row-major order so that they line up
    # with iteration over ``panels.flat``.
    layout = (
        ('total_loss', 'Total Loss'),
        ('height_loss', 'Height Loss (Regression)'),
        ('wave_type_loss', 'Wave Type Loss (Classification)'),
        ('direction_loss', 'Direction Loss (Classification)'),
    )

    for panel, (loss_key, panel_title) in zip(panels.flat, layout):
        panel.plot(
            [record[loss_key] for record in train_losses],
            label='Train',
            color='blue',
        )
        panel.plot(
            [record[loss_key] for record in val_losses],
            label='Val',
            color='red',
        )
        panel.set_title(panel_title)
        panel.legend()
        panel.grid(True)

    plt.suptitle(f'Training Progress - Epoch {epoch}')
    plt.tight_layout()
    plt.show()

# Main training loop: for each epoch, train, validate, advance the LR
# schedule, checkpoint on a new best validation loss, and periodically plot.
num_epochs = 20  # Reduced for demo
train_losses = []  # one dict per epoch, as returned by train_epoch
val_losses = []  # one dict per epoch, as returned by validate_epoch
best_val_loss = float('inf')

print("Starting training...")
for epoch in range(num_epochs):
    print(f"\nEpoch {epoch+1}/{num_epochs}")

    # Train
    train_loss = train_epoch(model, train_loader, criterion, optimizer, device)
    train_losses.append(train_loss)

    # Validate
    val_loss = validate_epoch(model, val_loader, criterion, device)
    val_losses.append(val_loss)

    # Update learning rate
    scheduler.step()

    # Save best model (selected on the validation total loss only)
    if val_loss['total_loss'] < best_val_loss:
        best_val_loss = val_loss['total_loss']
        torch.save({
            'epoch': epoch,  # zero-based loop index, not the 1-based number printed above
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'val_loss': val_loss,
            'config': config
        }, checkpoints_dir / 'best_model.pth')
        print(f"✅ New best model saved! Val loss: {best_val_loss:.4f}")

    # Plot progress every 5 epochs
    if (epoch + 1) % 5 == 0:
        plot_training_progress(train_losses, val_losses, epoch + 1)

    # Print epoch summary
    print(f"Train Loss: {train_loss['total_loss']:.4f} | Val Loss: {val_loss['total_loss']:.4f}")
    # Read after scheduler.step(), so this is the rate the next epoch will use.
    print(f"LR: {scheduler.get_last_lr()[0]:.6f}")

print("\n🎉 Training completed!")