# Baseline MLP Climate Emulator

Build and train a Multi-Layer Perceptron (MLP) to emulate climate physics using JAX + Flax.

## Architecture Overview

**Input**: Flattened atmospheric state vector (~200-300 dimensions)
- Temperature profiles (60 levels)
- Humidity profiles (60 levels)
- Surface pressure, winds, etc.

**Model**: Deep MLP
- 4-6 hidden layers
- 512-1024 units per layer
- Swish/GELU activation
- Layer normalization
- Dropout for regularization

**Output**: Physics tendencies
- Temperature tendency (60 levels)
- Humidity tendency (60 levels)
- Other tendency variables

**Loss**: MSE + optional water conservation regularization

**Optimizer**: AdamW with learning rate scheduling

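**Prerequisites**: Run `03_jax_preprocessing_pipeline.ipynb` first to prepare the data. If the processed file is not found, the notebook falls back to random synthetic data so that every cell still runs.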

In [None]:
# Import required packages
import numpy as np
import jax
import jax.numpy as jnp
from jax import random, jit, grad, vmap
from jax.tree_util import tree_map
import optax
from flax import linen as nn
from flax.training import train_state, checkpoints
import orbax.checkpoint as ocp
from pathlib import Path
import os
from typing import Sequence, Callable
import matplotlib.pyplot as plt
import time
from functools import partial

print("✅ All imports successful!")
print(f"\n📍 JAX version: {jax.__version__}")
print(f"📍 Available devices: {jax.devices()}")
print(f"📍 Device count: {jax.device_count()}")

## 1. Configuration

In [None]:
class Config:
    """Configuration for MLP emulator."""
    
    # Paths
    USER = os.environ.get('USER', 'default')
    SCRATCH_DIR = Path(f"/home/jovyan/leap-scratch/{USER}")
    DATA_DIR = SCRATCH_DIR / "climsim_processed"
    MODEL_DIR = SCRATCH_DIR / "models" / "mlp_baseline"
    CHECKPOINT_DIR = MODEL_DIR / "checkpoints"
    
    # Model architecture
    HIDDEN_DIMS = [512, 512, 512, 512]  # 4 layers with 512 units each
    ACTIVATION = 'swish'  # or 'gelu'
    USE_LAYER_NORM = True
    DROPOUT_RATE = 0.1
    
    # Training
    BATCH_SIZE = 64
    NUM_EPOCHS = 50
    LEARNING_RATE = 1e-3
    WEIGHT_DECAY = 1e-4
    WARMUP_EPOCHS = 5
    
    # Loss
    USE_WATER_CONSERVATION = False  # Enable water conservation regularization
    WATER_CONSERVATION_WEIGHT = 0.1
    
    # Checkpointing
    SAVE_EVERY = 5  # Save checkpoint every N epochs
    KEEP_BEST = True  # Keep best model based on validation loss
    
    # Random seed
    SEED = 42

config = Config()

# Create directories
config.MODEL_DIR.mkdir(parents=True, exist_ok=True)
config.CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)

print("=" * 70)
print("CONFIGURATION")
print("=" * 70)
print(f"Data directory:       {config.DATA_DIR}")
print(f"Model directory:      {config.MODEL_DIR}")
print(f"Checkpoint directory: {config.CHECKPOINT_DIR}")
print(f"\nModel architecture:")
print(f"  Hidden layers: {config.HIDDEN_DIMS}")
print(f"  Activation:    {config.ACTIVATION}")
print(f"  Layer norm:    {config.USE_LAYER_NORM}")
print(f"  Dropout:       {config.DROPOUT_RATE}")
print(f"\nTraining:")
print(f"  Batch size:    {config.BATCH_SIZE}")
print(f"  Epochs:        {config.NUM_EPOCHS}")
print(f"  Learning rate: {config.LEARNING_RATE}")
print(f"  Weight decay:  {config.WEIGHT_DECAY}")

## 2. Load Preprocessed Data

In [None]:
print(f"Loading preprocessed data from: {config.DATA_DIR}")

# Load data from npz (faster for small datasets)
data_path = config.DATA_DIR / 'climsim_nyc_processed.npz'

if not data_path.exists():
    print(f"⚠️  Data file not found: {data_path}")
    print("Please run 03_jax_preprocessing_pipeline.ipynb first!")
    
    # Create synthetic data for demonstration
    print("\n📝 Creating synthetic data for demonstration...")
    n_train, n_val, n_test = 8000, 1000, 1000
    n_levels = 60
    
    # Synthetic inputs and outputs
    np.random.seed(config.SEED)
    
    train_data = {
        'train_input_state_t': np.random.randn(n_train, n_levels),
        'train_input_state_q0001': np.random.randn(n_train, n_levels),
        'train_input_state_ps': np.random.randn(n_train, 1),
        'train_output_ptend_t': np.random.randn(n_train, n_levels) * 0.1,
        'train_output_ptend_q0001': np.random.randn(n_train, n_levels) * 0.01,
    }
    
    val_data = {
        'val_input_state_t': np.random.randn(n_val, n_levels),
        'val_input_state_q0001': np.random.randn(n_val, n_levels),
        'val_input_state_ps': np.random.randn(n_val, 1),
        'val_output_ptend_t': np.random.randn(n_val, n_levels) * 0.1,
        'val_output_ptend_q0001': np.random.randn(n_val, n_levels) * 0.01,
    }
    
    test_data = {
        'test_input_state_t': np.random.randn(n_test, n_levels),
        'test_input_state_q0001': np.random.randn(n_test, n_levels),
        'test_input_state_ps': np.random.randn(n_test, 1),
        'test_output_ptend_t': np.random.randn(n_test, n_levels) * 0.1,
        'test_output_ptend_q0001': np.random.randn(n_test, n_levels) * 0.01,
    }
    
    # Combine
    data = {**train_data, **val_data, **test_data}
    
else:
    data = np.load(data_path)
    print(f"✅ Loaded data from {data_path}")

# Extract train/val/test data
def extract_split(data, split='train'):
    """Extract input and output variables for a split.
    
    Works for both an NpzFile (from np.load) and a plain dict (the synthetic
    fallback): both support .keys(), but only NpzFile has a .files attribute.
    Variables are concatenated in sorted-name order.
    """
    prefix = f'{split}_'
    keys = list(data.keys())
    
    # Get all variable names, sorted so the column order is deterministic
    input_vars = sorted(k for k in keys if k.startswith(f'{prefix}input_'))
    output_vars = sorted(k for k in keys if k.startswith(f'{prefix}output_'))
    
    def concat(var_names):
        # Reshape 1-D variables (one value per sample) to (n_samples, 1), then concatenate
        arrays = []
        for var in var_names:
            arr = np.asarray(data[var])
            arrays.append(arr.reshape(-1, 1) if arr.ndim == 1 else arr)
        return np.concatenate(arrays, axis=1)
    
    X = concat(input_vars)
    y = concat(output_vars)
    
    return X, y, input_vars, output_vars

# Extract splits
X_train, y_train, train_input_vars, train_output_vars = extract_split(data, 'train')
X_val, y_val, _, _ = extract_split(data, 'val')
X_test, y_test, _, _ = extract_split(data, 'test')

# Convert to JAX arrays
X_train = jnp.array(X_train)
y_train = jnp.array(y_train)
X_val = jnp.array(X_val)
y_val = jnp.array(y_val)
X_test = jnp.array(X_test)
y_test = jnp.array(y_test)

print(f"\n📊 Data shapes:")
print(f"  Train: X={X_train.shape}, y={y_train.shape}")
print(f"  Val:   X={X_val.shape}, y={y_val.shape}")
print(f"  Test:  X={X_test.shape}, y={y_test.shape}")
print(f"\n  Input dimension:  {X_train.shape[1]}")
print(f"  Output dimension: {y_train.shape[1]}")
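`extract_split` concatenates variables in sorted-name order, so knowing which columns belong to which variable matters for per-variable plots and for any physics-based loss. A minimal sketch of that bookkeeping, using a hypothetical helper `column_slices` (not part of the notebook) and assuming each variable's per-sample shape is known:

```python
import numpy as np

# Hypothetical helper (not part of the notebook above)
def column_slices(var_shapes):
    """Map each variable name to its column slice in the concatenated matrix.

    var_shapes: dict of name -> per-sample shape, e.g. () for a scalar or (60,)
    for a 60-level profile. Names are sorted to match extract_split's order.
    """
    slices, start = {}, 0
    for name in sorted(var_shapes):
        width = int(np.prod(var_shapes[name]))  # np.prod(()) == 1 for scalars
        slices[name] = slice(start, start + width)
        start += width
    return slices

# Synthetic output variables: 'q0001' sorts before 't', so the humidity
# tendency occupies columns 0-59 and the temperature tendency columns 60-119
out_slices = column_slices({'train_output_ptend_t': (60,), 'train_output_ptend_q0001': (60,)})
print(out_slices)
```

In the notebook the shapes could be taken from the data itself, e.g. `column_slices({k: data[k].shape[1:] for k in train_output_vars})`.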

## 3. Define the MLP Model

In [None]:
class ClimateEmulatorMLP(nn.Module):
    """Multi-Layer Perceptron for climate physics emulation."""
    
    hidden_dims: Sequence[int]
    output_dim: int
    activation: str = 'swish'
    use_layer_norm: bool = True
    dropout_rate: float = 0.1
    
    @nn.compact
    def __call__(self, x, training: bool = True):
        """Forward pass.
        
        Args:
            x: Input tensor of shape (batch_size, input_dim)
            training: If True (the default), dropout is active and a 'dropout'
                RNG must be passed to apply(); use training=False for evaluation
            
        Returns:
            Output tensor of shape (batch_size, output_dim)
        """
        
        # Select activation function
        if self.activation == 'swish':
            activation_fn = nn.swish
        elif self.activation == 'gelu':
            activation_fn = nn.gelu
        elif self.activation == 'relu':
            activation_fn = nn.relu
        else:
            raise ValueError(f"Unknown activation: {self.activation}")
        
        # Hidden layers: Dense -> LayerNorm -> activation -> Dropout
        for i, dim in enumerate(self.hidden_dims):
            x = nn.Dense(dim, name=f'dense_{i}')(x)
            
            if self.use_layer_norm:
                x = nn.LayerNorm(name=f'ln_{i}')(x)
            
            x = activation_fn(x)
            
            if self.dropout_rate > 0:
                x = nn.Dropout(rate=self.dropout_rate, deterministic=not training)(x)
        
        # Output layer (no activation for regression)
        x = nn.Dense(self.output_dim, name='output')(x)
        
        return x

# Initialize model
output_dim = y_train.shape[1]

model = ClimateEmulatorMLP(
    hidden_dims=config.HIDDEN_DIMS,
    output_dim=output_dim,
    activation=config.ACTIVATION,
    use_layer_norm=config.USE_LAYER_NORM,
    dropout_rate=config.DROPOUT_RATE
)

# Initialize parameters
rng = random.PRNGKey(config.SEED)
rng, init_rng = random.split(rng)

# Dummy input for initialization (training=False turns dropout off, so no dropout RNG is needed)
dummy_input = jnp.ones((1, X_train.shape[1]))
params = model.init(init_rng, dummy_input, training=False)

print("=" * 70)
print("MODEL ARCHITECTURE")
print("=" * 70)
print(model.tabulate(init_rng, dummy_input, training=False, compute_flops=True, compute_vjp_flops=True))

# Count parameters
param_count = sum(x.size for x in jax.tree_util.tree_leaves(params))
print(f"\n📊 Total parameters: {param_count:,}")
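As a sanity check on the printed total, the count follows directly from the layer shapes: each `nn.Dense` contributes `in × out` weights plus `out` biases, each `nn.LayerNorm` a scale and a bias per feature, and activations and dropout have no parameters. A sketch with a hypothetical helper `count_mlp_params`:

```python
# Hypothetical helper (not part of the notebook above)
def count_mlp_params(input_dim, hidden_dims, output_dim, use_layer_norm=True):
    """Parameter count of a Dense -> LayerNorm -> activation MLP stack."""
    total, fan_in = 0, input_dim
    for width in hidden_dims:
        total += fan_in * width + width      # Dense kernel + bias
        if use_layer_norm:
            total += 2 * width               # LayerNorm scale + bias
        fan_in = width
    total += fan_in * output_dim + output_dim  # output Dense layer
    return total

# Synthetic demo data: 60 + 60 + 1 = 121 inputs, 60 + 60 = 120 outputs
print(count_mlp_params(121, [512, 512, 512, 512], 120))  # 916088
```

In the notebook, `count_mlp_params(X_train.shape[1], config.HIDDEN_DIMS, output_dim, config.USE_LAYER_NORM)` should match `param_count`.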

## 4. Define Loss Function

MSE loss with optional water conservation regularization.

In [None]:
def mse_loss(predictions, targets):
    """Mean squared error loss."""
    return jnp.mean((predictions - targets) ** 2)

def water_conservation_loss(predictions, humidity_indices):
    """Simplified water-conservation penalty.
    
    Penalizes the unweighted sum of the humidity tendencies over vertical
    levels, pushing it toward zero. This is a crude proxy: a physically
    faithful column water budget weights each level by its air mass
    (pressure thickness / g) and balances evaporation minus precipitation.
    
    Args:
        predictions: Model predictions, shape (batch_size, output_dim)
        humidity_indices: Column indices of the humidity-tendency outputs
    
    Returns:
        Conservation loss value
    """
    # Extract humidity tendencies
    humidity_tends = predictions[:, humidity_indices]
    
    # Sum over vertical levels
    total_water_tendency = jnp.sum(humidity_tends, axis=1)
    
    # Penalize deviation from zero
    return jnp.mean(total_water_tendency ** 2)

def humidity_output_indices(output_vars, data, keyword='ptend_q0001'):
    """Columns of y that hold the humidity tendency.
    
    extract_split concatenates outputs in sorted-name order, so 'ptend_q0001'
    comes before 'ptend_t': humidity is the FIRST block, not the second.
    """
    indices, start = [], 0
    for var in sorted(output_vars):
        arr = data[var]
        width = 1 if arr.ndim == 1 else arr.shape[1]
        if keyword in var:
            indices.extend(range(start, start + width))
        start += width
    return jnp.array(indices, dtype=jnp.int32)

HUMIDITY_INDICES = humidity_output_indices(train_output_vars, data)

def compute_loss(params, batch_x, batch_y, apply_fn, rng, use_water_conservation=False):
    """Compute total loss.
    
    Args:
        params: Model parameters
        batch_x: Input batch
        batch_y: Target batch
        apply_fn: Model apply function (e.g. model.apply or state.apply_fn)
        rng: Random key for dropout
        use_water_conservation: Whether to add water conservation regularization
    
    Returns:
        Total loss value
    """
    # Forward pass in training mode (the model's default), so dropout is active
    predictions = apply_fn(params, batch_x, rngs={'dropout': rng})
    
    # MSE loss
    loss = mse_loss(predictions, batch_y)
    
    # Add water conservation regularization if enabled
    if use_water_conservation:
        water_loss = water_conservation_loss(predictions, HUMIDITY_INDICES)
        loss = loss + config.WATER_CONSERVATION_WEIGHT * water_loss
    
    return loss

print("✅ Loss functions defined")
print(f"   Base loss: MSE")
print(f"   Water conservation: {'Enabled' if config.USE_WATER_CONSERVATION else 'Disabled'}")
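The penalty above only pushes the raw sum over levels toward zero. In a column physics step, the mass-weighted column integral of the moisture tendency should instead balance surface evaporation minus precipitation (ignoring cloud condensate). A JAX sketch, assuming per-level pressure thickness `dp` (Pa) and an `evap_minus_precip` field (kg m⁻² s⁻¹) are available; neither is among the synthetic variables, and the helper names are hypothetical:

```python
import jax.numpy as jnp

GRAVITY = 9.81  # m s^-2

# Hypothetical helpers (not part of the notebook above)
def column_water_tendency(dqdt, dp):
    """Mass-weighted vertical integral of a specific-humidity tendency.

    dqdt: (batch, levels) tendency in kg/kg/s
    dp:   (batch, levels) pressure thickness of each level in Pa
    Returns a (batch,) column tendency in kg m^-2 s^-1.
    """
    return jnp.sum(dqdt * dp, axis=-1) / GRAVITY


def water_budget_penalty(dqdt, dp, evap_minus_precip):
    """Mean squared mismatch between the column moisture change and E - P."""
    residual = column_water_tendency(dqdt, dp) - evap_minus_precip
    return jnp.mean(residual ** 2)
```

Dividing by g converts pressure thickness into air mass per unit area, so the result has the same units as a surface water flux and can be compared with E − P directly.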

## 5. Training Setup

Configure optimizer and create training state.

In [None]:
# Create learning rate schedule with warmup
def create_learning_rate_schedule(base_lr, warmup_epochs, total_epochs, steps_per_epoch):
    """Create learning rate schedule with warmup and cosine decay."""
    
    warmup_steps = warmup_epochs * steps_per_epoch
    total_steps = total_epochs * steps_per_epoch
    
    warmup_schedule = optax.linear_schedule(
        init_value=0.0,
        end_value=base_lr,
        transition_steps=warmup_steps
    )
    
    cosine_schedule = optax.cosine_decay_schedule(
        init_value=base_lr,
        decay_steps=total_steps - warmup_steps,
        alpha=0.1  # End at 10% of base_lr
    )
    
    schedule = optax.join_schedules(
        schedules=[warmup_schedule, cosine_schedule],
        boundaries=[warmup_steps]
    )
    
    return schedule

# Calculate steps per epoch
steps_per_epoch = len(X_train) // config.BATCH_SIZE

# Create optimizer
lr_schedule = create_learning_rate_schedule(
    base_lr=config.LEARNING_RATE,
    warmup_epochs=config.WARMUP_EPOCHS,
    total_epochs=config.NUM_EPOCHS,
    steps_per_epoch=steps_per_epoch
)

optimizer = optax.adamw(
    learning_rate=lr_schedule,
    weight_decay=config.WEIGHT_DECAY
)

# Create training state
class TrainState(train_state.TrainState):
    """Extended train state with dropout RNG."""
    dropout_rng: jax.Array

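# Initialize training state with its own dropout key, split off from `rng`
# so that dropout masks are independent of the batch-shuffling keys
rng, dropout_rng = random.split(rng)
state = TrainState.create(
    apply_fn=model.apply,
    params=params,
    tx=optimizer,
    dropout_rng=dropout_rng
)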

print("=" * 70)
print("TRAINING SETUP")
print("=" * 70)
print(f"Optimizer: AdamW")
print(f"  Base learning rate: {config.LEARNING_RATE}")
print(f"  Weight decay:       {config.WEIGHT_DECAY}")
print(f"  Warmup epochs:      {config.WARMUP_EPOCHS}")
print(f"\nSteps per epoch: {steps_per_epoch}")
print(f"Total steps:     {config.NUM_EPOCHS * steps_per_epoch}")
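The schedule built above warms up linearly from 0 to `base_lr`, then follows a cosine decay that ends at `alpha * base_lr` (here 10%). The plain-Python sketch below, a hypothetical `lr_at` intended to mirror how `optax.linear_schedule`, `optax.cosine_decay_schedule` and `optax.join_schedules` compose, makes the values easy to inspect without JAX; `lr_schedule(step)` can be evaluated directly for comparison.

```python
import math

# Hypothetical helper (not part of the notebook above)
def lr_at(step, base_lr, warmup_steps, total_steps, alpha=0.1):
    """Linear warmup to base_lr, then cosine decay to alpha * base_lr."""
    if step < warmup_steps:
        return base_lr * step / warmup_steps
    decay_steps = total_steps - warmup_steps
    t = min(step - warmup_steps, decay_steps) / decay_steps  # progress in [0, 1]
    return base_lr * ((1 - alpha) * 0.5 * (1 + math.cos(math.pi * t)) + alpha)

# Config values with the synthetic data: 8000 // 64 = 125 steps per epoch
warmup, total = 5 * 125, 50 * 125
for step in (0, warmup // 2, warmup, total // 2, total):
    print(step, lr_at(step, 1e-3, warmup, total))
```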

## 6. Training Step (JIT Compiled)

Define training and evaluation functions with JIT compilation.

In [None]:
@jit
def train_step(state, batch_x, batch_y):
    """Single training step (JIT compiled).
    
    Args:
        state: Training state
        batch_x: Input batch
        batch_y: Target batch
    
    Returns:
        Updated state and loss value
    """
    # Split RNG for dropout
    dropout_rng, new_dropout_rng = random.split(state.dropout_rng)
    
    # Compute loss and gradients
    def loss_fn(params):
        return compute_loss(
            params, batch_x, batch_y, 
            state.apply_fn, dropout_rng,
            use_water_conservation=config.USE_WATER_CONSERVATION
        )
    
    loss, grads = jax.value_and_grad(loss_fn)(state.params)
    
    # Update parameters
    state = state.apply_gradients(grads=grads)
    state = state.replace(dropout_rng=new_dropout_rng)
    
    return state, loss

@jit
def eval_step(state, batch_x, batch_y):
    """Single evaluation step (JIT compiled).
    
    Args:
        state: Training state
        batch_x: Input batch
        batch_y: Target batch
    
    Returns:
        Loss value and predictions
    """
    # training=False disables dropout, so no dropout RNG is needed
    predictions = state.apply_fn(state.params, batch_x, training=False)
    
    loss = mse_loss(predictions, batch_y)
    
    return loss, predictions

def create_batches(X, y, batch_size, rng):
    """Create shuffled batches."""
    n_samples = len(X)
    
    # Shuffle indices
    perm = random.permutation(rng, n_samples)
    X_shuffled = X[perm]
    y_shuffled = y[perm]
    
    # Create batches
    n_batches = n_samples // batch_size
    X_batches = X_shuffled[:n_batches * batch_size].reshape(n_batches, batch_size, -1)
    y_batches = y_shuffled[:n_batches * batch_size].reshape(n_batches, batch_size, -1)
    
    return X_batches, y_batches

def evaluate(state, X, y, batch_size):
    """Evaluate model on dataset.
    
    Args:
        state: Training state
        X: Input data
        y: Target data
        batch_size: Batch size
    
    Returns:
        Mean squared error over all samples (the last, possibly smaller,
        batch is included and weighted by its size; its different shape
        triggers one extra JIT compilation of eval_step)
    """
    total_loss = 0.0
    n_samples = len(X)
    
    for start_idx in range(0, n_samples, batch_size):
        batch_x = X[start_idx:start_idx + batch_size]
        batch_y = y[start_idx:start_idx + batch_size]
        
        loss, _ = eval_step(state, batch_x, batch_y)
        total_loss += float(loss) * len(batch_x)
    
    return total_loss / n_samples

print("✅ Training and evaluation functions defined (JIT compiled)")
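Averaging per-batch losses is only exact when every batch has the same size. With 1000 validation samples and a batch size of 64, the last 40 samples form a partial batch. A small NumPy sketch (hypothetical `batched_mse`) of the size-weighted average that keeps them and still equals the full-array MSE:

```python
import numpy as np

# Hypothetical helper (not part of the notebook above)
def batched_mse(pred, target, batch_size):
    """MSE over all samples, accumulated batch by batch.

    Each batch mean is weighted by its number of samples, so a short final
    batch counts exactly as much as its samples deserve.
    """
    total, n = 0.0, len(pred)
    for start in range(0, n, batch_size):
        p = pred[start:start + batch_size]
        t = target[start:start + batch_size]
        total += np.mean((p - t) ** 2) * len(p)
    return total / n

rng = np.random.default_rng(0)
pred, target = rng.normal(size=(1000, 120)), rng.normal(size=(1000, 120))
print(batched_mse(pred, target, 64), np.mean((pred - target) ** 2))
```

Weighting by sample count is valid here because every batch has the same number of output columns.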

## 7. Training Loop

In [None]:
# Training loop
print("=" * 70)
print("STARTING TRAINING")
print("=" * 70)

# Initialize tracking
train_losses = []
val_losses = []
best_val_loss = float('inf')
best_state = None
best_epoch = 0

# Training
start_time = time.time()

for epoch in range(config.NUM_EPOCHS):
    epoch_start = time.time()
    
    # Create batches for this epoch
    rng, batch_rng = random.split(rng)
    X_batches, y_batches = create_batches(X_train, y_train, config.BATCH_SIZE, batch_rng)
    
    # Train for one epoch
    epoch_losses = []
    for batch_x, batch_y in zip(X_batches, y_batches):
        state, loss = train_step(state, batch_x, batch_y)
        epoch_losses.append(loss)
    
    train_loss = jnp.mean(jnp.array(epoch_losses))
    train_losses.append(float(train_loss))
    
    # Evaluate on validation set
    val_loss = evaluate(state, X_val, y_val, config.BATCH_SIZE)
    val_losses.append(float(val_loss))
    
    # Track best model
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        best_state = state
        best_epoch = epoch
    
    epoch_time = time.time() - epoch_start
    
    # Print progress
    if (epoch + 1) % 5 == 0 or epoch == 0:
        print(f"Epoch {epoch+1:3d}/{config.NUM_EPOCHS} | "
              f"Train Loss: {train_loss:.6f} | "
              f"Val Loss: {val_loss:.6f} | "
              f"Time: {epoch_time:.2f}s")
    
    # Save checkpoint
    if config.SAVE_EVERY > 0 and (epoch + 1) % config.SAVE_EVERY == 0:
        ckpt_path = config.CHECKPOINT_DIR / f"checkpoint_epoch_{epoch+1}"
        checkpointer = ocp.PyTreeCheckpointer()
        checkpointer.save(ckpt_path, state)
        print(f"  💾 Saved checkpoint: {ckpt_path}")

total_time = time.time() - start_time

print(f"\n{'=' * 70}")
print(f"TRAINING COMPLETE")
print(f"{'=' * 70}")
print(f"Total time: {total_time:.2f}s ({total_time/60:.2f} min)")
print(f"Time per epoch: {total_time/config.NUM_EPOCHS:.2f}s")
print(f"\nFinal train loss: {train_losses[-1]:.6f}")
print(f"Final val loss:   {val_losses[-1]:.6f}")
print(f"Best val loss:    {best_val_loss:.6f} (epoch {best_epoch+1})")

## 8. Save Best Model

In [None]:
# Save best model
if config.KEEP_BEST and best_state is not None:
    best_model_path = config.MODEL_DIR / "best_model"
    checkpointer = ocp.PyTreeCheckpointer()
    checkpointer.save(best_model_path, best_state)
    
    print(f"\n💾 Saved best model to: {best_model_path}")
    print(f"   Best validation loss: {best_val_loss:.6f}")
    print(f"   Achieved at epoch: {best_epoch+1}")
    
    # Save model config
    model_config = {
        'hidden_dims': config.HIDDEN_DIMS,
        'output_dim': output_dim,
        'input_dim': X_train.shape[1],
        'activation': config.ACTIVATION,
        'use_layer_norm': config.USE_LAYER_NORM,
        'dropout_rate': config.DROPOUT_RATE,
        'best_val_loss': float(best_val_loss),
        'best_epoch': int(best_epoch),
        'total_params': int(param_count),
    }
    
    config_path = config.MODEL_DIR / "model_config.npz"
    np.savez(config_path, **{k: str(v) for k, v in model_config.items()})
    print(f"   Saved config to: {config_path}")

print("\n✅ Model saved successfully!")
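`model_config.npz` stores every value as a string, so numbers, booleans and lists come back as text. A sketch of reading it back with `ast.literal_eval` (hypothetical `load_model_config`), assuming the file was written by the cell above:

```python
import ast
import numpy as np

# Hypothetical helper (not part of the notebook above)
def load_model_config(path):
    """Read a model_config.npz in which every value was saved as str(v).

    ast.literal_eval turns '[512, 512]' back into a list, '0.1' into a float and
    'True' into a bool; text that is not a Python literal (e.g. 'swish') stays a str.
    """
    cfg = {}
    with np.load(path) as f:
        for key in f.files:
            text = str(f[key])  # 0-d unicode array -> plain string
            try:
                cfg[key] = ast.literal_eval(text)
            except (ValueError, SyntaxError):
                cfg[key] = text
    return cfg
```

With the notebook's paths this would be `load_model_config(config.MODEL_DIR / "model_config.npz")`, whose `hidden_dims`, `output_dim`, `activation`, `use_layer_norm` and `dropout_rate` entries are enough to rebuild `ClimateEmulatorMLP`.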

## 9. Visualization: Training Curves

In [None]:
# Plot training curves
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Loss curves
ax = axes[0]
epochs = np.arange(1, len(train_losses) + 1)
ax.plot(epochs, train_losses, 'b-', label='Train Loss', linewidth=2)
ax.plot(epochs, val_losses, 'r-', label='Val Loss', linewidth=2)
ax.axvline(best_epoch + 1, color='g', linestyle='--', alpha=0.7, label=f'Best Model (epoch {best_epoch+1})')
ax.set_xlabel('Epoch', fontsize=12, fontweight='bold')
ax.set_ylabel('Loss (MSE)', fontsize=12, fontweight='bold')
ax.set_title('Training and Validation Loss', fontsize=14, fontweight='bold')
ax.legend(fontsize=10)
ax.grid(True, alpha=0.3)
ax.set_yscale('log')

# Loss difference
ax = axes[1]
loss_diff = np.array(val_losses) - np.array(train_losses)
ax.plot(epochs, loss_diff, 'purple', linewidth=2)
ax.axhline(0, color='k', linestyle='--', alpha=0.5)
ax.set_xlabel('Epoch', fontsize=12, fontweight='bold')
ax.set_ylabel('Val Loss - Train Loss', fontsize=12, fontweight='bold')
ax.set_title('Generalization Gap', fontsize=14, fontweight='bold')
ax.grid(True, alpha=0.3)
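# Train loss is measured with dropout active (plus any water penalty); val loss is not,
# so a negative gap is common and does not by itself indicate underfitting
ax.fill_between(epochs, 0, loss_diff, where=(loss_diff>0), alpha=0.3, color='red', label='Val > Train (possible overfitting)')
ax.fill_between(epochs, 0, loss_diff, where=(loss_diff<0), alpha=0.3, color='green', label='Val < Train')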
ax.legend(fontsize=10)

plt.suptitle('MLP Climate Emulator - Training Progress', fontsize=15, fontweight='bold', y=1.02)
plt.tight_layout()
plt.show()

# Save plot
plot_path = config.MODEL_DIR / 'training_curves.png'
fig.savefig(plot_path, dpi=150, bbox_inches='tight')
print(f"\n💾 Saved training curves to: {plot_path}")

## 10. Evaluation on Test Set

In [None]:
# Evaluate best model on test set
print("=" * 70)
print("TEST SET EVALUATION")
print("=" * 70)

test_loss = evaluate(best_state, X_test, y_test, config.BATCH_SIZE)
print(f"\nTest Loss (MSE): {test_loss:.6f}")

# Get predictions for analysis (all test samples, including a final partial batch)
test_predictions = []

for start_idx in range(0, len(X_test), config.BATCH_SIZE):
    batch_x = X_test[start_idx:start_idx + config.BATCH_SIZE]
    batch_y = y_test[start_idx:start_idx + config.BATCH_SIZE]
    
    _, preds = eval_step(best_state, batch_x, batch_y)
    test_predictions.append(preds)

test_predictions = jnp.concatenate(test_predictions, axis=0)
test_targets = y_test[:len(test_predictions)]

# Compute additional metrics
mae = jnp.mean(jnp.abs(test_predictions - test_targets))
rmse = jnp.sqrt(jnp.mean((test_predictions - test_targets) ** 2))

# R² score, pooled over all output dimensions (dominated by high-variance outputs)
ss_res = jnp.sum((test_targets - test_predictions) ** 2)
ss_tot = jnp.sum((test_targets - jnp.mean(test_targets)) ** 2)
r2 = 1 - (ss_res / ss_tot)

print(f"\nAdditional Metrics:")
print(f"  MAE:  {mae:.6f}")
print(f"  RMSE: {rmse:.6f}")
print(f"  R²:   {r2:.6f}")

# Overall prediction statistics (pooled over all outputs)
print(f"\nPrediction statistics:")
print(f"  Mean prediction: {jnp.mean(test_predictions):.6f}")
print(f"  Std prediction:  {jnp.std(test_predictions):.6f}")
print(f"  Mean target:     {jnp.mean(test_targets):.6f}")
print(f"  Std target:      {jnp.std(test_targets):.6f}")
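The pooled R² above mixes outputs with very different variances (in the synthetic data, temperature tendencies are scaled by 0.1 and humidity tendencies by 0.01), so it mostly reflects the largest-variance columns. A NumPy sketch of R² per output column (hypothetical `r2_per_dim`), which can then be grouped by variable:

```python
import numpy as np

# Hypothetical helper (not part of the notebook above)
def r2_per_dim(pred, target, eps=1e-12):
    """Coefficient of determination computed separately for each output column.

    Columns whose target variance is (near) zero get NaN instead of a
    meaningless huge value.
    """
    pred, target = np.asarray(pred), np.asarray(target)
    ss_res = np.sum((target - pred) ** 2, axis=0)
    ss_tot = np.sum((target - target.mean(axis=0)) ** 2, axis=0)
    return np.where(ss_tot > eps, 1.0 - ss_res / np.maximum(ss_tot, eps), np.nan)
```

For example, `r2_per_dim(test_predictions, test_targets)` returns one value per output column; averaging those values over each variable's block of columns gives a per-variable score.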

## 11. Visualization: Predictions vs Targets

In [None]:
# Visualize predictions vs targets
fig, axes = plt.subplots(2, 2, figsize=(14, 12))

# 1. Scatter plot: Predictions vs Targets
ax = axes[0, 0]
# Sample subset for visibility
n_plot = min(5000, len(test_predictions))
idx = np.random.choice(len(test_predictions), n_plot, replace=False)

scatter = ax.scatter(test_targets[idx].flatten(), 
                     test_predictions[idx].flatten(),
                     alpha=0.3, s=1, c='blue')
                     
# Perfect prediction line
min_val = min(test_targets.min(), test_predictions.min())
max_val = max(test_targets.max(), test_predictions.max())
ax.plot([min_val, max_val], [min_val, max_val], 'r--', linewidth=2, label='Perfect Prediction')

ax.set_xlabel('True Values', fontsize=11, fontweight='bold')
ax.set_ylabel('Predicted Values', fontsize=11, fontweight='bold')
ax.set_title(f'Predictions vs Targets (R²={r2:.4f})', fontsize=12, fontweight='bold')
ax.legend()
ax.grid(True, alpha=0.3)

# 2. Residual histogram
ax = axes[0, 1]
residuals = (test_predictions - test_targets).flatten()
ax.hist(residuals, bins=50, alpha=0.7, color='green', edgecolor='black')
ax.axvline(0, color='r', linestyle='--', linewidth=2, label='Zero Error')
ax.set_xlabel('Residual (Predicted - True)', fontsize=11, fontweight='bold')
ax.set_ylabel('Frequency', fontsize=11, fontweight='bold')
ax.set_title(f'Residual Distribution (Mean={jnp.mean(residuals):.6f})', fontsize=12, fontweight='bold')
ax.legend()
ax.grid(True, alpha=0.3)

# 3. Sample vertical profile comparison
ax = axes[1, 0]
# Plot the first 60 output columns. extract_split concatenates outputs in
# sorted-name order, so this block is the alphabetically first variable
# (ptend_q0001, the humidity tendency, for the synthetic data), not necessarily temperature.
first_output_var = sorted(train_output_vars)[0]
sample_idx = 0
if test_predictions.shape[1] >= 60:
    true_profile = test_targets[sample_idx, :60]
    pred_profile = test_predictions[sample_idx, :60]
    levels = np.arange(60)
    
    ax.plot(true_profile, levels, 'b-', linewidth=2, marker='o', markersize=4, label='True')
    ax.plot(pred_profile, levels, 'r--', linewidth=2, marker='s', markersize=4, label='Predicted')
    ax.invert_yaxis()
    ax.set_xlabel(f'{first_output_var} (first 60 output columns)', fontsize=11, fontweight='bold')
    ax.set_ylabel('Vertical Level', fontsize=11, fontweight='bold')
    ax.set_title(f'Sample Vertical Profile (Test Sample {sample_idx})', fontsize=12, fontweight='bold')
    ax.legend()
    ax.grid(True, alpha=0.3)
else:
    ax.text(0.5, 0.5, 'Need 60+ output dims\nfor vertical profile', 
            ha='center', va='center', transform=ax.transAxes, fontsize=12)
    ax.set_title('Vertical Profile', fontsize=12, fontweight='bold')

# 4. Error by output dimension
ax = axes[1, 1]
mse_per_dim = jnp.mean((test_predictions - test_targets) ** 2, axis=0)
dims = np.arange(len(mse_per_dim))
ax.bar(dims, mse_per_dim, alpha=0.7, color='purple')
ax.set_xlabel('Output Dimension', fontsize=11, fontweight='bold')
ax.set_ylabel('MSE', fontsize=11, fontweight='bold')
ax.set_title('Error by Output Dimension', fontsize=12, fontweight='bold')
ax.grid(True, alpha=0.3, axis='y')

plt.suptitle('Model Evaluation - Test Set Performance', fontsize=15, fontweight='bold', y=0.995)
plt.tight_layout()
plt.show()

# Save plot
eval_plot_path = config.MODEL_DIR / 'evaluation_plots.png'
fig.savefig(eval_plot_path, dpi=150, bbox_inches='tight')
print(f"\n💾 Saved evaluation plots to: {eval_plot_path}")

## 12. Model Loading Example

Show how to load and use the trained model.

In [None]:
print("=" * 70)
print("MODEL LOADING EXAMPLE")
print("=" * 70)

# Example: How to load the saved model
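print("\n📝 Code to load the trained model:\n")
print("""
import os
from pathlib import Path
import orbax.checkpoint as ocp

# ClimateEmulatorMLP must be defined (or imported) exactly as in this notebook
model = ClimateEmulatorMLP(
    hidden_dims=[512, 512, 512, 512],
    output_dim=120,  # must match the trained model (y_train.shape[1]; 120 for the synthetic data)
    activation='swish',
    use_layer_norm=True,
    dropout_rate=0.1
)

# Load checkpoint ($USER is not expanded inside a Python string, so read it from the environment)
checkpoint_dir = Path('/home/jovyan/leap-scratch') / os.environ.get('USER', 'default') / 'models' / 'mlp_baseline'
checkpointer = ocp.PyTreeCheckpointer()
restored_state = checkpointer.restore(
    checkpoint_dir / 'best_model',
    item=state  # Template TrainState with the same structure (restore arguments differ across Orbax versions)
)

# Make predictions (input_data: array of shape (n_samples, input_dim), preprocessed like X_train)
predictions = restored_state.apply_fn(
    restored_state.params,
    input_data,
    training=False  # disable dropout
)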
""")

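For large inputs it is usually safer to run inference in chunks than in one call. A NumPy sketch of a hypothetical `predict_in_batches` helper, assuming `apply_fn` follows the Flax convention `apply_fn(params, x, training=...)` used by `restored_state.apply_fn` in the loading snippet:

```python
import numpy as np

# Hypothetical helper (not part of the notebook above)
def predict_in_batches(apply_fn, params, X, batch_size=1024):
    """Run a model over X in fixed-size chunks and stack the outputs.

    Chunking bounds peak memory; the last chunk may be smaller than batch_size.
    """
    outputs = []
    for start in range(0, len(X), batch_size):
        batch = X[start:start + batch_size]
        outputs.append(np.asarray(apply_fn(params, batch, training=False)))
    return np.concatenate(outputs, axis=0)
```

Usage would look like `predictions = predict_in_batches(restored_state.apply_fn, restored_state.params, input_data)`.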
## Summary & Next Steps

### What We Built

1. ✅ **MLP Architecture** - 4-layer MLP with 512 units per layer
2. ✅ **Advanced Features** - Layer norm, dropout, swish activation
3. ✅ **Training** - AdamW optimizer with warmup + cosine decay
4. ✅ **JIT Compilation** - Fast training with JAX JIT
5. ✅ **Evaluation** - Comprehensive metrics on test set
6. ✅ **Checkpointing** - Saved best model with Orbax
7. ✅ **Visualization** - Training curves and prediction analysis

### Model Performance

```
Best Validation Loss: {best_val_loss:.6f}
Test Loss (MSE):      {test_loss:.6f}
R¬≤ Score:             {r2:.6f}
MAE:                  {mae:.6f}
```

### Saved Files

```
/home/jovyan/leap-scratch/$USER/models/mlp_baseline/
├── best_model/                 # Best model checkpoint
├── model_config.npz            # Model configuration
├── training_curves.png         # Training visualization
├── evaluation_plots.png        # Evaluation visualization
└── checkpoints/                # Periodic checkpoints
    ├── checkpoint_epoch_5/
    ├── checkpoint_epoch_10/
    └── ...
```

### Next Steps

1. **Hyperparameter Tuning**
   - Try different layer sizes (256, 1024)
   - Experiment with more/fewer layers
   - Adjust learning rate and weight decay

2. **Advanced Architectures**
   - Residual connections (ResNet-style)
   - Attention mechanisms
   - Physics-informed architectures

3. **Regularization**
   - Enable water conservation loss
   - Energy conservation constraints
   - Physical consistency checks

4. **Deployment**
   - Create inference pipeline
   - Optimize for production
   - Monitor performance

5. **Ensemble Methods**
   - Train multiple models
   - Average predictions
   - Uncertainty quantification
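The ensemble idea in item 5 can start very simply: train the same MLP several times with different `Config.SEED` values, predict with each, and combine. A NumPy sketch of the combination step (hypothetical `ensemble_predict`); the across-member spread is a rough, uncalibrated uncertainty signal:

```python
import numpy as np

# Hypothetical helper (not part of the notebook above)
def ensemble_predict(member_predictions):
    """Combine predictions from independently trained models.

    member_predictions: array-like of shape (n_models, n_samples, output_dim).
    Returns the ensemble mean and the across-member standard deviation.
    """
    preds = np.asarray(member_predictions)
    return preds.mean(axis=0), preds.std(axis=0)
```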

### Key Metrics

- **Training Time**: printed as "Total time" at the end of the training loop (for `config.NUM_EPOCHS` epochs)
- **Parameters**: printed as "Total parameters" after the model summary (`param_count`)
- **Throughput**: `steps_per_epoch * config.NUM_EPOCHS / total_time` steps/sec

Great work! Your baseline climate emulator is ready! 🌍🚀