# JWST Phase Retrieval - Training Notebook

This notebook trains a ResNet-18 model to predict Zernike coefficients (in radians) from point-spread function (PSF) images.

## 🚀 Google Colab Setup

If running on Colab, uncomment and run the Colab-only setup cells below, then run the remaining cells in order to:
1. Install dependencies
2. Clone/upload the project
3. Generate training data
4. Train the model

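## 💻 Local Setup

If running locally, skip the Colab-only cells and run the remaining cells in order, starting with the imports. Start Jupyter from the project root, because the cells below use paths relative to it (`src/`, `configs/config.yaml`, `data/processed/`).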

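Instead of toggling comments by hand, a notebook can detect the runtime itself. The sketch below is minimal and assumes that `google.colab` can be imported only inside a Colab runtime. `running_on_colab` is a hypothetical helper, not part of this project:

```python
def running_on_colab() -> bool:
    """Return True when the google.colab package can be imported (Colab runtime)."""
    try:
        import google.colab  # noqa: F401  (imported only to test availability)
    except ImportError:
        return False
    return True


if running_on_colab():
    print("Colab detected: run the Colab-only setup cells first.")
else:
    print("Local runtime: skip the Colab-only setup cells.")
```

`ModuleNotFoundError` is a subclass of `ImportError`, so a missing `google` namespace package is also reported as "not Colab".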
## 📦 Install Dependencies (Colab Only)

In [None]:
# Uncomment if running on Colab
# !pip install torch torchvision matplotlib pyyaml tqdm tensorboard scipy -q

## 📂 Set Up Project (Colab Only)

In [None]:
# Uncomment if running on Colab
# import os
# if not os.path.exists('neural_wavefront'):
#     # Replace YOUR_USERNAME with the account that hosts the repository
#     !git clone https://github.com/YOUR_USERNAME/neural_wavefront.git
# os.chdir('neural_wavefront')

## 🔧 Import Libraries

In [None]:
import sys
import os
from pathlib import Path
import platform

import torch
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np

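# Add src to sys.path (once) so the neural_wavefront package can be imported.
# Compare against the absolute path actually inserted, not the literal 'src'.
src_path = str(Path.cwd() / 'src')
if src_path not in sys.path:
    sys.path.insert(0, src_path)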

from neural_wavefront.utils.config import load_config
from neural_wavefront.data.dataset import create_dataloaders
from neural_wavefront.models.resnet import create_model
from neural_wavefront.training.trainer import Trainer

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

## ⚙️ Configuration

In [None]:
# Load configuration
config = load_config("configs/config.yaml")

# Experiment settings
experiment_name = "colab_training"  # Names the output directory under outputs/experiments/
device = 'cuda' if torch.cuda.is_available() else 'cpu'

print(f"Device: {device}")
print(f"Experiment: {experiment_name}")

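Otherwise, a missing key in `configs/config.yaml` only surfaces as a `KeyError` several cells later. The sketch below runs that check up front. It assumes `load_config` returns nested dicts, as the `config['data']['batch_size']`-style lookups suggest. The key list mirrors the lookups in this notebook that have no default; keys read with `.get()` are optional and are left out. `missing_keys` is a hypothetical helper:

```python
from collections.abc import Mapping

# Dotted paths that this notebook reads with config[...][...] (no default value).
REQUIRED_KEYS = (
    "data.batch_size",
    "data.augmentation",
    "visualization",
    "model.name",
    "model.output_dim",
    "model.pretrained",
    "model.dropout",
    "training.optimizer",
    "training.learning_rate",
    "training.loss",
    "training.epochs",
)


def missing_keys(config, required=REQUIRED_KEYS):
    """Return the dotted paths from `required` that cannot be resolved in `config`."""
    missing = []
    for path in required:
        node = config
        for part in path.split("."):
            # Stop at the first level that is absent or not a nested mapping.
            if not isinstance(node, Mapping) or part not in node:
                missing.append(path)
                break
            node = node[part]
    return missing


# Possible use right after load_config (hypothetical):
# problems = missing_keys(config)
# if problems:
#     raise KeyError(f"config.yaml is missing: {problems}")
```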
## 📊 Generate Training Data (if needed)

In [None]:
# Check if data exists
train_path = Path("data/processed/train.npz")
val_path = Path("data/processed/val.npz")

if not train_path.exists() or not val_path.exists():
    print("⚠️ Data not found. Generating datasets...")
    print("This will take a few minutes...")

    # Run the data generation script with the same interpreter as this kernel
    import subprocess
    result = subprocess.run(
        [sys.executable, "scripts/generate_data.py"],
        capture_output=True,
        text=True
    )

    if result.returncode != 0:
        print("❌ Error generating data:")
        print(result.stderr)
        # Stop here instead of failing later with a confusing "file not found"
        raise RuntimeError("scripts/generate_data.py failed; see the error output above")
    else:
        print("✅ Data generation complete!")
else:
    print("✅ Training data already exists")

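`scripts/generate_data.py` is not shown in this notebook. For orientation, datasets like this usually rely on the Fraunhofer diffraction forward model. The PSF is the squared magnitude of the Fourier transform of the pupil field `P * exp(i * phi)`. The phase `phi` is a sum of Zernike polynomials weighted by the coefficients the network learns to predict.

Below is a minimal NumPy sketch under several simplifying assumptions:

- a plain circular aperture (JWST's real pupil is segmented);
- only three Noll-indexed terms (defocus and the two astigmatisms);
- coefficients in radians.

The function names are illustrative and are not the project's API:

```python
import numpy as np


def pupil_grid(n_pix: int):
    """Polar coordinates on an n_pix x n_pix grid spanning [-1, 1], plus a unit-disk mask."""
    x = np.linspace(-1.0, 1.0, n_pix)
    xx, yy = np.meshgrid(x, x)
    rho = np.hypot(xx, yy)
    theta = np.arctan2(yy, xx)
    return rho, theta, rho <= 1.0


def zernike_phase(coeffs, rho, theta):
    """Phase in radians from Noll-normalized Z4 (defocus), Z5 and Z6 (astigmatism)."""
    a4, a5, a6 = coeffs
    z4 = np.sqrt(3.0) * (2.0 * rho**2 - 1.0)
    z5 = np.sqrt(6.0) * rho**2 * np.sin(2.0 * theta)
    z6 = np.sqrt(6.0) * rho**2 * np.cos(2.0 * theta)
    return a4 * z4 + a5 * z5 + a6 * z6


def simulate_psf(coeffs, n_pix: int = 64, pad: int = 4):
    """Fraunhofer PSF |FFT(P * exp(i*phi))|^2 on a zero-padded grid, normalized to unit sum."""
    rho, theta, mask = pupil_grid(n_pix)
    field = mask * np.exp(1j * zernike_phase(coeffs, rho, theta))
    # Zero-padding samples the PSF more finely than one pixel per lambda/D.
    n = pad * n_pix
    padded = np.zeros((n, n), dtype=complex)
    padded[:n_pix, :n_pix] = field  # position only adds a phase ramp; |FFT|^2 is unchanged
    psf = np.abs(np.fft.fftshift(np.fft.fft2(padded))) ** 2
    return psf / psf.sum()


perfect = simulate_psf([0.0, 0.0, 0.0])
defocused = simulate_psf([0.5, 0.0, 0.0])  # about 0.5 rad RMS of defocus
print(perfect.shape, defocused.max() / perfect.max())
```

Both PSFs carry the same total energy, because the pupil field has unit modulus inside the aperture. The ratio of their peaks therefore approximates the Strehl ratio of the aberrated case.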
## 📁 Load Datasets

In [None]:
# Use num_workers=0 on Windows to avoid multiprocessing issues
num_workers = 0 if platform.system() == 'Windows' else config['data'].get('num_workers', 4)

print("Loading datasets...")
train_loader, val_loader = create_dataloaders(
    train_path=str(train_path),
    val_path=str(val_path),
    batch_size=config['data']['batch_size'],
    num_workers=num_workers,
    log_scale=config['visualization'].get('log_scale', True),
    normalize_coeffs=False,  # Keep coefficients in radians
    augment_train=config['data']['augmentation'].get('enable', True)
)

print(f"✅ Datasets loaded:")
print(f"   Train batches: {len(train_loader)}")
print(f"   Val batches: {len(val_loader)}")
print(f"   Batch size: {config['data']['batch_size']}")

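The batch counts printed above follow from the dataset size and the batch size. For a map-style dataset, PyTorch's `DataLoader` reports `ceil(N / batch_size)` batches per epoch, or `floor(N / batch_size)` with `drop_last=True`. This notebook does not show whether `create_dataloaders` sets `drop_last`, and the sample count below is hypothetical:

```python
def num_batches(n_samples: int, batch_size: int, drop_last: bool = False) -> int:
    """Batches per epoch for a map-style dataset, as DataLoader's len() reports."""
    if drop_last:
        return n_samples // batch_size                 # incomplete last batch dropped
    return (n_samples + batch_size - 1) // batch_size  # ceiling division


# Hypothetical 10,000-sample training set with batch size 32.
print(num_batches(10_000, 32), num_batches(10_000, 32, drop_last=True))  # prints: 313 312
```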
## 🏗️ Create Model

In [None]:
print("Creating model...")
model = create_model(
    model_name=config['model']['name'],
    n_modes=config['model']['output_dim'],
    pretrained=config['model']['pretrained'],
    dropout=config['model']['dropout']
)

# Count parameters
n_params = sum(p.numel() for p in model.parameters())
print(f"✅ Model: {config['model']['name']}")
print(f"   Parameters: {n_params:,}")

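The printed parameter count can be sanity-checked by hand. torchvision's ImageNet ResNet-18 has 11,689,512 parameters, and 513,000 of them belong to its 1000-class `fc` head.

The sketch below rests on several assumptions:

- `create_model` keeps that backbone.
- It replaces `fc` with a `Linear(512, n_modes)`; dropout adds no parameters.
- It may also change the input channels of the first convolution.

If the printed number differs, the project builds its model differently. `expected_resnet18_params` is a hypothetical helper:

```python
# Reference numbers for torchvision's resnet18 (ImageNet weights, 1000 classes).
RESNET18_IMAGENET_PARAMS = 11_689_512
IMAGENET_FC_PARAMS = 512 * 1000 + 1000   # Linear(512, 1000): weights + biases
CONV1_RGB_PARAMS = 64 * 3 * 7 * 7        # Conv2d(3, 64, kernel_size=7, bias=False)


def expected_resnet18_params(n_modes: int, in_channels: int = 3) -> int:
    """Parameter count if only conv1's input channels and the fc head are changed."""
    backbone = RESNET18_IMAGENET_PARAMS - IMAGENET_FC_PARAMS
    conv1_delta = 64 * in_channels * 7 * 7 - CONV1_RGB_PARAMS
    head = 512 * n_modes + n_modes        # Linear(512, n_modes); dropout has no parameters
    return backbone + conv1_delta + head


# Example with a hypothetical 10-mode head on single-channel PSF images.
print(f"{expected_resnet18_params(n_modes=10, in_channels=1):,}")
```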
## ⚙️ Set Up Optimizer & Loss

In [None]:
# Create optimizer
optimizer_name = config['training']['optimizer']
lr = config['training']['learning_rate']

if optimizer_name == 'Adam':
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
elif optimizer_name == 'AdamW':
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)
elif optimizer_name == 'SGD':
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
else:
    raise ValueError(f"Unknown optimizer: {optimizer_name}")

# Create loss function
loss_name = config['training']['loss']
if loss_name == 'MSE':
    criterion = nn.MSELoss()
elif loss_name == 'MAE':
    criterion = nn.L1Loss()
elif loss_name == 'Huber':
    criterion = nn.HuberLoss()
else:
    raise ValueError(f"Unknown loss: {loss_name}")

print(f"✅ Optimizer: {optimizer_name} (lr={lr:.0e})")
print(f"   Loss: {loss_name}")

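The three loss options weight large errors differently:

- MSE grows quadratically with the error.
- MAE grows linearly.
- Huber is quadratic for errors up to `delta` and linear beyond it.

Below is a pure-Python sketch of the `reduction='mean'` forms of `nn.MSELoss`, `nn.L1Loss` and `nn.HuberLoss` (default `delta=1.0`). It applies them to hypothetical coefficient errors that include one outlier:

```python
def mse(residuals):
    """Mean squared error, like nn.MSELoss() with reduction='mean'."""
    return sum(r * r for r in residuals) / len(residuals)


def mae(residuals):
    """Mean absolute error, like nn.L1Loss() with reduction='mean'."""
    return sum(abs(r) for r in residuals) / len(residuals)


def huber(residuals, delta=1.0):
    """Mean Huber loss, like nn.HuberLoss(delta=1.0): quadratic near 0, linear beyond delta."""
    def term(r):
        a = abs(r)
        return 0.5 * a * a if a <= delta else delta * (a - 0.5 * delta)
    return sum(term(r) for r in residuals) / len(residuals)


# Hypothetical per-coefficient errors (prediction - target) in radians; the last is an outlier.
errors = [0.1, -0.2, 0.05, 3.0]
for name, fn in [("MSE", mse), ("MAE", mae), ("Huber", huber)]:
    print(f"{name:>5}: {fn(errors):.4f}")
```

With these numbers the single 3 rad error dominates the MSE, while MAE and Huber grow only linearly with it. When every error is below `delta`, Huber equals half the MSE.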
## 🎯 Initialize Trainer

In [None]:
print("Initializing trainer...")
trainer = Trainer(
    model=model,
    train_loader=train_loader,
    val_loader=val_loader,
    optimizer=optimizer,
    criterion=criterion,
    device=device,
    config=config,
    experiment_name=experiment_name
)

print(f"✅ Trainer ready!")
print(f"   Experiment dir: {trainer.exp_dir}")

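## 🚀 Train Model

This cell trains the model for `config['training']['epochs']` epochs and shows a progress bar and loss metrics as it runs. If you interrupt the kernel, the current state is saved to `checkpoints/interrupted.pth` in the experiment directory before the cell exits.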

In [None]:
# Number of epochs to train
n_epochs = config['training']['epochs']

print(f"🚀 Starting training for {n_epochs} epochs...")
print("="*70)

try:
    # Train the model
    trainer.train(n_epochs=n_epochs)
    
    print("\n" + "="*70)
    print("✅ Training complete!")
    print(f"   Best validation loss: {trainer.best_val_loss:.6f}")
    print(f"   Best model saved to: {trainer.exp_dir / 'checkpoints' / 'best_model.pth'}")
    
except KeyboardInterrupt:
    print("\n⚠️ Training interrupted by user")
    print(f"   Saving checkpoint...")
    trainer.save_checkpoint('interrupted.pth')
    print(f"   Checkpoint saved to: {trainer.exp_dir / 'checkpoints' / 'interrupted.pth'}")

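The cell reports `trainer.best_val_loss` and a `best_model.pth` checkpoint. The `Trainer` source is not shown, so the sketch below only illustrates the usual rule. It assumes the checkpoint is overwritten whenever the validation loss reaches a new minimum. `BestTracker` and `min_delta` are hypothetical names:

```python
import math


class BestTracker:
    """Remember the lowest validation loss seen so far (hypothetical helper)."""

    def __init__(self, min_delta: float = 0.0):
        self.min_delta = min_delta  # improvement required before a loss counts as better
        self.best_loss = math.inf
        self.best_epoch = None

    def update(self, epoch: int, val_loss: float) -> bool:
        """Record val_loss; return True if it is a new best (time to save a checkpoint)."""
        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.best_epoch = epoch
            return True
        return False


tracker = BestTracker()
saved_at = []
for epoch, val_loss in enumerate([0.90, 0.50, 0.60, 0.40, 0.45], start=1):
    if tracker.update(epoch, val_loss):
        saved_at.append(epoch)  # a real loop would save best_model.pth here
print(saved_at, tracker.best_epoch, tracker.best_loss)
```

A positive `min_delta` prevents tiny, noise-level improvements from triggering a new checkpoint write.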
## 📊 Plot Training History

In [None]:
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

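# Plot training and validation loss; each series gets its own epoch axis so the
# cell still works if an interrupted run recorded a train loss without a val loss
train_epochs = range(1, len(trainer.train_losses) + 1)
val_epochs = range(1, len(trainer.val_losses) + 1)
axes[0].plot(train_epochs, trainer.train_losses, 'b-', label='Train Loss', linewidth=2)
axes[0].plot(val_epochs, trainer.val_losses, 'r-', label='Val Loss', linewidth=2)
axes[0].set_xlabel('Epoch', fontsize=12)
axes[0].set_ylabel(f'Loss ({loss_name})', fontsize=12)  # matches the configured loss
axes[0].set_title('Training and Validation Loss', fontsize=14)
axes[0].legend(fontsize=11)
axes[0].grid(True, alpha=0.3)

# Plot validation MAE
metric_epochs = range(1, len(trainer.val_metrics) + 1)
axes[1].plot(metric_epochs, [m['mae'] for m in trainer.val_metrics], 'g-', linewidth=2)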
axes[1].set_xlabel('Epoch', fontsize=12)
axes[1].set_ylabel('MAE (radians)', fontsize=12)
axes[1].set_title('Validation MAE', fontsize=14)
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(trainer.exp_dir / 'training_curves.png', dpi=150, bbox_inches='tight')
plt.show()

print(f"📊 Training curves saved to: {trainer.exp_dir / 'training_curves.png'}")

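The curves can also be read numerically, for example to find which epoch produced the lowest validation loss and how long training continued afterwards. The sketch below is minimal. `summarize_history` is a hypothetical helper, and the loss values in the example are made up:

```python
def summarize_history(train_losses, val_losses):
    """Best epoch (1-based), its validation loss, epochs run after it, and the train/val gap."""
    if not val_losses:
        raise ValueError("no validation losses recorded")
    best_idx = min(range(len(val_losses)), key=lambda i: val_losses[i])
    summary = {
        "best_epoch": best_idx + 1,
        "best_val_loss": val_losses[best_idx],
        "epochs_since_best": len(val_losses) - (best_idx + 1),
    }
    if best_idx < len(train_losses):
        # Positive gap: validation loss sits above training loss at the best epoch.
        summary["gap_at_best"] = val_losses[best_idx] - train_losses[best_idx]
    return summary


print(summarize_history([1.0, 0.6, 0.4, 0.3], [1.1, 0.7, 0.5, 0.55]))
```

In this notebook the call would be `summarize_history(trainer.train_losses, trainer.val_losses)`. That assumes both are per-epoch lists of floats, as the plot above implies. A growing `epochs_since_best`, together with a rising validation loss, suggests the run continued past its best generalization point.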
## 🎯 Next Steps

After training is complete, you can:

1. **Evaluate the model** on the test set using `scripts/evaluate.py`
2. **View TensorBoard logs**: Run `tensorboard --logdir outputs/experiments/{experiment_name}/tensorboard`
3. **Load the checkpoint** to continue training or make predictions
4. **Download the model** (if on Colab) from the experiment directory

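The best model is saved at `outputs/experiments/{experiment_name}/checkpoints/best_model.pth`. Here `{experiment_name}` is the name set in the Configuration cell (`colab_training` by default); the Initialize Trainer cell prints the exact directory as `trainer.exp_dir`.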
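
The sketch below expands these paths for a given experiment name. It assumes the experiment directory is `outputs/experiments/<experiment_name>`, as stated above. If the `Trainer` adds a suffix such as a timestamp, use the printed `trainer.exp_dir` instead. `experiment_paths` is a hypothetical helper:

```python
from pathlib import Path


def experiment_paths(experiment_name: str, root: str = "outputs/experiments") -> dict:
    """Output locations this notebook mentions, for a given experiment name."""
    exp_dir = Path(root) / experiment_name
    return {
        "exp_dir": exp_dir,
        "tensorboard": exp_dir / "tensorboard",
        "best_model": exp_dir / "checkpoints" / "best_model.pth",
        "interrupted": exp_dir / "checkpoints" / "interrupted.pth",
        "curves": exp_dir / "training_curves.png",
    }


paths = experiment_paths("colab_training")
print(f"tensorboard --logdir {paths['tensorboard']}")
print(f"best model: {paths['best_model']}")
```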