# Model Training Guide

This notebook covers training the spectrum denoiser and retrieval models.

## Contents
1. Generating synthetic training data
2. Training the denoiser model
3. Evaluating model performance
4. Fine-tuning on real data

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm

import sys
sys.path.insert(0, '..')

from models.architectures.denoiser import create_denoiser

# Fix both global RNGs this notebook draws from (NumPy for the synthetic
# data, torch for e.g. DataLoader shuffling) so reruns are repeatable, then
# prefer the GPU when one is present.
np.random.seed(42)
torch.manual_seed(42)

if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f"Using device: {device}")

## 1. Generate Synthetic Training Data

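We create synthetic transmission spectra as a flat baseline plus randomly placed Gaussian absorption features (a simplified stand-in for real atmospheric models), then add Gaussian noise at a randomly drawn signal-to-noise ratio.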

In [None]:
def generate_synthetic_spectrum(wavelength, n_features=5, base_depth=100, edge_margin=0.5):
    """Generate a synthetic transmission spectrum with random molecular features.

    The spectrum starts as a flat baseline at ``base_depth`` and gets
    ``n_features`` Gaussian features added on top of it. All feature
    parameters come from NumPy's global RNG, so ``np.random.seed`` makes the
    output reproducible.

    Args:
        wavelength: array-like of wavelength samples (same units as
            ``edge_margin``; the notebook uses μm).
        n_features: number of Gaussian features to add.
        base_depth: baseline value of the spectrum (plotted as ppm in this
            notebook).
        edge_margin: feature centres are drawn at least this far inside both
            ends of the wavelength range. The range should span more than
            ``2 * edge_margin``.

    Returns:
        float64 ndarray with the same shape as ``wavelength``.
    """
    # Work in float so an integer wavelength grid does not make the in-place
    # ``+=`` below fail with a float -> int casting error.
    wavelength = np.asarray(wavelength, dtype=float)
    spectrum = np.full(wavelength.shape, float(base_depth))
    low = wavelength.min() + edge_margin
    high = wavelength.max() - edge_margin

    for _ in range(n_features):
        # Draw order (centre, width, depth) is kept fixed so seeded runs
        # reproduce the same spectra.
        center = np.random.uniform(low, high)
        width = np.random.uniform(0.05, 0.3)
        depth = np.random.uniform(5, 30)

        # Add Gaussian absorption feature
        spectrum += depth * np.exp(-((wavelength - center) ** 2) / (2 * width ** 2))

    return spectrum

def add_noise(spectrum, snr):
    """Add Gaussian noise based on signal-to-noise ratio.

    The noise standard deviation is ``std(spectrum) / snr``: the SNR is
    measured against the spread of the spectrum, not its mean level, so a
    perfectly flat spectrum receives no noise.

    Args:
        spectrum: array-like of samples, any shape. The standard deviation is
            taken over all elements.
        snr: positive signal-to-noise ratio.

    Returns:
        float ndarray with the same shape as ``spectrum``.

    Raises:
        ValueError: if ``snr`` is not a positive number (including NaN).
    """
    # ``not snr > 0`` also rejects NaN; zero or negative values would
    # otherwise produce an infinite or invalid noise scale.
    if not snr > 0:
        raise ValueError(f"snr must be positive, got {snr!r}")
    spectrum = np.asarray(spectrum, dtype=float)
    noise_std = spectrum.std() / snr
    # Drawing with the full shape matches the old len()-based draw for 1-D
    # input and also supports N-D arrays.
    noise = np.random.normal(0, noise_std, spectrum.shape)
    return spectrum + noise

# Build the paired dataset: each row holds one clean spectrum and a noisy
# copy of it. Every sample gets 3-7 features (randint's upper bound is
# exclusive) and an SNR drawn uniformly between 10 and 50.
n_samples = 1000
n_wavelengths = 512
wavelength = np.linspace(0.5, 5.0, n_wavelengths)

clean_spectra = np.zeros((n_samples, n_wavelengths))
noisy_spectra = np.zeros_like(clean_spectra)

for row in tqdm(range(n_samples), desc='Generating spectra'):
    # The RNG draw order (feature count, features, SNR, noise) determines
    # which spectra a given seed produces.
    clean = generate_synthetic_spectrum(wavelength, np.random.randint(3, 8))
    clean_spectra[row] = clean
    noisy_spectra[row] = add_noise(clean, np.random.uniform(10, 50))

print(f"Generated {n_samples} spectra with {n_wavelengths} wavelength points")

In [None]:
# Overlay the noisy samples (points) on the clean ground truth (line) for
# the first six generated spectra, on a 2x3 grid.
fig, axes = plt.subplots(2, 3, figsize=(15, 8))

for panel, ax in enumerate(axes.ravel()):
    ax.plot(wavelength, noisy_spectra[panel], 'b.', alpha=0.5, markersize=2, label='Noisy')
    ax.plot(wavelength, clean_spectra[panel], 'r-', linewidth=1.5, label='Clean')
    ax.set(
        xlabel='Wavelength (μm)',
        ylabel='Transit Depth (ppm)',
        title=f'Sample {panel + 1}',
    )
    ax.legend()

plt.tight_layout()
plt.show()

## 2. Prepare DataLoaders and Train Model

In [None]:
# Convert to float32 tensors with an explicit channel axis: (N, 1, L).
# torch.tensor(..., dtype=...) replaces the legacy torch.FloatTensor constructor.
X = torch.tensor(noisy_spectra, dtype=torch.float32).unsqueeze(1)
y = torch.tensor(clean_spectra, dtype=torch.float32).unsqueeze(1)

# Hold out the last 20% for validation. A plain slice is enough because every
# spectrum was drawn independently, so there is no ordering to undo.
train_fraction = 0.8
train_size = int(train_fraction * len(X))
if not 0 < train_size < len(X):
    raise ValueError(
        f"train/val split of {len(X)} samples leaves an empty side (train_size={train_size})"
    )
X_train, X_val = X[:train_size], X[train_size:]
y_train, y_val = y[:train_size], y[train_size:]

train_dataset = TensorDataset(X_train, y_train)
val_dataset = TensorDataset(X_val, y_val)

# Only the training loader shuffles; validation batches come out in X_val order.
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32)

print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")

In [None]:
# Build the v1 denoiser and place it on the selected device.
model = create_denoiser(model_type='v1', base_channels=32, num_residual_blocks=3).to(device)

# MSE between the model output and the clean spectrum, optimised with Adam.
# ReduceLROnPlateau halves the learning rate (factor=0.5) once the monitored
# loss stops improving for `patience=5` epochs.
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)
scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5, factor=0.5)

n_params = sum(param.numel() for param in model.parameters())
print(f"Model parameters: {n_params:,}")

In [None]:
# Training loop: fit the denoiser, record the per-epoch mean MSE for both
# splits, and checkpoint the weights whenever the validation loss improves.
import os

n_epochs = 50
train_losses = []
val_losses = []
best_val_loss = float('inf')
checkpoint_path = '../models/checkpoints/denoiser_demo.pt'
# torch.save does not create missing parent directories; make sure the
# checkpoint folder exists so the first improvement does not crash the run.
os.makedirs(os.path.dirname(checkpoint_path), exist_ok=True)


def _run_epoch(loader, train):
    """Make one pass over ``loader`` and return the per-sample mean loss.

    Uses the notebook-level ``model``, ``criterion``, ``optimizer`` and
    ``device``. With ``train=True`` the model is in training mode and the
    optimizer steps after every batch. Otherwise the model is in eval mode
    and gradients are not tracked.
    """
    model.train(train)
    total_loss = 0.0
    with torch.set_grad_enabled(train):
        for batch_x, batch_y in loader:
            batch_x, batch_y = batch_x.to(device), batch_y.to(device)
            if train:
                optimizer.zero_grad()
            loss = criterion(model(batch_x), batch_y)
            if train:
                loss.backward()
                optimizer.step()
            # criterion returns a mean over the batch; weight it by batch size
            # so the smaller final batch is not over-counted in the epoch mean.
            total_loss += loss.item() * batch_x.size(0)
    return total_loss / len(loader.dataset)


for epoch in range(n_epochs):
    train_loss = _run_epoch(train_loader, train=True)
    train_losses.append(train_loss)

    val_loss = _run_epoch(val_loader, train=False)
    val_losses.append(val_loss)

    # The plateau scheduler monitors the validation loss.
    scheduler.step(val_loss)

    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), checkpoint_path)

    if (epoch + 1) % 10 == 0:
        print(f"Epoch {epoch+1}/{n_epochs} | Train Loss: {train_loss:.6f} | Val Loss: {val_loss:.6f}")

In [None]:
# Plot the per-epoch training and validation MSE to check convergence and
# spot overfitting.
fig, ax = plt.subplots(figsize=(10, 5))
for curve, curve_label in ((train_losses, 'Training Loss'), (val_losses, 'Validation Loss')):
    ax.plot(curve, label=curve_label)
ax.set(xlabel='Epoch', ylabel='MSE Loss', title='Training Progress')
ax.legend()
ax.grid(True, alpha=0.3)
plt.show()

print(f"Best validation loss: {best_val_loss:.6f}")

## 3. Evaluate Model Performance

In [None]:
# Reload the checkpoint with the lowest validation loss and score it on the
# held-out split.
checkpoint_path = '../models/checkpoints/denoiser_demo.pt'
# map_location lets a checkpoint written on another device (e.g. GPU) load
# here. weights_only restricts unpickling to tensors/state-dict contents.
state_dict = torch.load(checkpoint_path, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.eval()

# Predict batch by batch so memory use does not scale with the whole
# validation set. val_loader does not shuffle, so rows stay in X_val order.
with torch.no_grad():
    predictions = torch.cat(
        [model(batch_x.to(device)).cpu() for batch_x, _ in val_loader]
    ).numpy()  # assumed (N_val, 1, L), matching y_val — TODO confirm for other model types

# Calculate SNR improvement: ratio of per-spectrum MSE before vs after
# denoising, in decibels.
clean_val = clean_spectra[train_size:]
noisy_val = noisy_spectra[train_size:]
input_error = np.mean((noisy_val - clean_val) ** 2, axis=1)
# Select the channel axis explicitly rather than squeeze(), which would also
# drop any other singleton axis.
output_error = np.mean((predictions[:, 0] - clean_val) ** 2, axis=1)
snr_improvement = 10 * np.log10(input_error / output_error)

print(f"Average SNR improvement: {snr_improvement.mean():.2f} ± {snr_improvement.std():.2f} dB")

In [None]:
# Compare the noisy input, the model's reconstruction and the ground truth
# for the first six validation spectra. Each panel title shows that
# spectrum's SNR improvement.
fig, axes = plt.subplots(2, 3, figsize=(15, 8))

for idx, ax in enumerate(axes.ravel()):
    ax.plot(wavelength, X_val[idx, 0].numpy(), 'b.', alpha=0.3, markersize=2, label='Noisy')
    ax.plot(wavelength, predictions[idx, 0], 'g-', linewidth=1.5, label='Recovered')
    ax.plot(wavelength, y_val[idx, 0].numpy(), 'r--', linewidth=1, label='Ground Truth')
    ax.set(
        xlabel='Wavelength (μm)',
        ylabel='Transit Depth (ppm)',
        title=f'SNR Improvement: {snr_improvement[idx]:.1f} dB',
    )
    ax.legend(loc='upper right', fontsize=8)

plt.tight_layout()
plt.show()