# JAX Preprocessing Pipeline for ClimSim

Build a production-ready preprocessing pipeline for ClimSim data with JAX optimization.

## Features

1. **Load & Subsample** - Efficient data loading from Hugging Face
2. **Normalization** - Compute and apply mean/std normalization
3. **NYC Filtering** - Spatial subsetting for the NYC region (40.5–41°N, 74.3–73.7°W)
4. **JAX Data Loaders** - Efficient batching with jax.numpy arrays
5. **Multi-GPU Sharding** - Batch reshaping to a leading device axis for `jax.pmap`
6. **Checkpoint Support** - Orbax serialization for model checkpoints
7. **Lazy Loading** - Memory-efficient processing
8. **Persistence** - Save processed data to leap-scratch as Zarr/npz

**Prerequisites:** Run `leap_startup.ipynb` first!

In [None]:
# Import required packages
import numpy as np
import jax
import jax.numpy as jnp
from jax import random
import pandas as pd
import xarray as xr
from pathlib import Path
import os
from typing import Dict, Tuple, Optional, Iterator
from functools import partial

# Hugging Face
from datasets import load_dataset

# Flax/Orbax for checkpointing
import orbax.checkpoint as ocp
from flax import serialization
from flax.training import train_state

# Visualization
import matplotlib.pyplot as plt

# Report that the imports succeeded and what JAX can see on this machine.
print(
    "‚úÖ All imports successful!",
    f"\nüìç JAX version: {jax.__version__}",
    f"üìç Available devices: {jax.devices()}",
    f"üìç Device count: {jax.device_count()}",
    sep="\n",
)

## 1. Configuration & Setup

Define paths, constants, and NYC bounding box.

In [None]:
# Configuration
class Config:
    """Notebook-wide settings, stored as class attributes.

    Everything is fixed when the class body runs; the module-level ``config``
    instance created below simply exposes the same values by attribute access.
    """

    # --- Locations on the shared leap-scratch volume (one folder per user) ---
    USER = os.getenv('USER', 'default')
    SCRATCH_DIR = Path(f"/home/jovyan/leap-scratch/{USER}")
    OUTPUT_DIR = SCRATCH_DIR / "climsim_processed"
    CHECKPOINT_DIR = SCRATCH_DIR / "checkpoints"

    # --- NYC bounding box in degrees (west longitudes are negative) ---
    NYC_LAT_MIN, NYC_LAT_MAX = 40.5, 41.0
    NYC_LON_MIN, NYC_LON_MAX = -74.3, -73.7

    # --- Dataset selection and split fractions ---
    DATASET_NAME = "LEAP/ClimSim_low-res"
    SAMPLE_SIZE = 10000  # rows requested from the start of the train split
    TRAIN_SPLIT, VAL_SPLIT, TEST_SPLIT = 0.8, 0.1, 0.1

    # --- Batching and randomness ---
    BATCH_SIZE = 32
    SEED = 42

    # --- Feature scaling toggle ---
    NORMALIZE = True

    # --- Parallelism: JAX devices visible when this class body executes ---
    NUM_DEVICES = jax.device_count()


config = Config()

# Create directories (no-op when they already exist).
config.OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
config.CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)


def _format_longitude(lon):
    """Render a signed longitude as its absolute value plus a W/E hemisphere letter.

    The bounding box stores western longitudes as negative numbers. Printing
    the raw value followed by a 'W' (e.g. "-74.3 W") is a double negative
    that denotes an eastern longitude, so the sign is folded into the letter.
    """
    hemisphere = 'W' if lon < 0 else 'E'
    return f"{abs(lon)}¬∞{hemisphere}"


print("=" * 70)
print("CONFIGURATION")
print("=" * 70)
print(f"Scratch directory:    {config.SCRATCH_DIR}")
print(f"Output directory:     {config.OUTPUT_DIR}")
print(f"Checkpoint directory: {config.CHECKPOINT_DIR}")
print(f"\nNYC Bounding Box:")
print(f"  Latitude:  {config.NYC_LAT_MIN}¬∞N - {config.NYC_LAT_MAX}¬∞N")
print(f"  Longitude: {_format_longitude(config.NYC_LON_MIN)} - {_format_longitude(config.NYC_LON_MAX)}")
print(f"\nData splits:")
print(f"  Train: {config.TRAIN_SPLIT*100:.0f}%")
print(f"  Val:   {config.VAL_SPLIT*100:.0f}%")
print(f"  Test:  {config.TEST_SPLIT*100:.0f}%")
print(f"\nJAX configuration:")
print(f"  Devices:    {config.NUM_DEVICES}")
print(f"  Batch size: {config.BATCH_SIZE}")
# Integer share of the global batch per device (same as BATCH_SIZE on one device).
print(f"  Batch per device: {config.BATCH_SIZE // max(config.NUM_DEVICES, 1)}")

## 2. Load ClimSim Dataset

Load a subsample of the dataset with geographic information.

In [None]:
# Load the first SAMPLE_SIZE rows of ClimSim from the Hugging Face Hub.
# If anything in the try-block raises, the except-branch builds a synthetic
# stand-in with the same key conventions ('lat', 'lon', 'state_*', 'ptend_*'),
# so the rest of the notebook can still run end to end.
print(f"Loading {config.SAMPLE_SIZE} samples from {config.DATASET_NAME}...")
print("=" * 70)

try:
    # Load dataset
    # split="train[:N]" asks `datasets` for only the first N rows of the train split.
    dataset = load_dataset(
        config.DATASET_NAME,
        split=f"train[:{config.SAMPLE_SIZE}]",
        streaming=False
    )
    
    print(f"‚úÖ Loaded {len(dataset)} samples")
    print(f"\nAvailable features: {list(dataset.features.keys())[:10]}...")
    
    # Extract first sample to understand structure
    # (an empty split would raise IndexError here and trigger the fallback below)
    first_sample = dataset[0]
    
    # Identify input and output variables
    # Inputs are the keys starting with "state_", targets the keys starting with
    # "ptend_"; any feature matching neither prefix is ignored from here on.
    input_vars = [k for k in first_sample.keys() if k.startswith('state_')]
    output_vars = [k for k in first_sample.keys() if k.startswith('ptend_')]
    
    print(f"\nüìä Found {len(input_vars)} input variables")
    print(f"üìà Found {len(output_vars)} output variables")
    
    # Check for lat/lon coordinates
    # Informational only: the NYC filter later looks for keys named exactly
    # 'lat' and 'lon'. NOTE(review): confirm the Hub dataset exposes them.
    coord_vars = [k for k in first_sample.keys() if 'lat' in k.lower() or 'lon' in k.lower()]
    print(f"üìç Coordinate variables: {coord_vars if coord_vars else 'Not found in features'}")
    
except Exception as e:
    # Deliberately broad: any failure above switches to synthetic data.
    print(f"‚ö†Ô∏è Error loading dataset: {e}")
    print("\nüìù Creating synthetic dataset for demonstration...")
    
    # Create synthetic ClimSim-like data with NYC columns
    n_samples = config.SAMPLE_SIZE
    n_levels = 60
    
    # Generate random lat/lon for each sample
    # Some will be in NYC, some outside
    # Coordinates are uniform over 35-45 N and 80-70 W. The NYC box covers
    # 0.5 x 0.6 degrees of that 10 x 10 degree area, so only ~0.3% of samples
    # (about 30 of 10,000) land inside it.
    # Seeding the global RNG makes the whole draw sequence below reproducible;
    # the order of the np.random calls therefore determines the data.
    np.random.seed(config.SEED)
    lats = np.random.uniform(35, 45, n_samples)  # Mix of locations
    lons = np.random.uniform(-80, -70, n_samples)
    
    # Create sample data
    # Profile fields get one value per level (n_levels); 'state_ps' is a single
    # scalar per sample. Dict key order determines input_vars/output_vars order.
    synthetic_data = []
    for i in range(n_samples):
        sample = {
            'lat': lats[i],
            'lon': lons[i],
            'state_t': np.random.randn(n_levels) * 30 + 250,
            'state_q0001': np.random.randn(n_levels) * 0.002 + 0.005,
            'state_ps': np.random.randn() * 5000 + 100000,
            'ptend_t': np.random.randn(n_levels) * 0.1,
            'ptend_q0001': np.random.randn(n_levels) * 1e-6,
        }
        synthetic_data.append(sample)
    
    # Create a dict-like structure
    class SyntheticDataset:
        """Minimal stand-in for a `datasets.Dataset`: len() and integer indexing.

        Integer __getitem__ that raises IndexError past the end also makes
        instances iterable through Python's legacy sequence protocol.
        """
        def __init__(self, data):
            self.data = data
        def __len__(self):
            return len(self.data)
        def __getitem__(self, idx):
            return self.data[idx]
    
    dataset = SyntheticDataset(synthetic_data)
    
    # Derive the variable lists from the synthetic keys, as in the try-branch.
    first_sample = dataset[0]
    input_vars = [k for k in first_sample.keys() if k.startswith('state_')]
    output_vars = [k for k in first_sample.keys() if k.startswith('ptend_')]
    
    print(f"‚úÖ Created synthetic dataset with {len(dataset)} samples")
    print(f"   {len(input_vars)} input vars, {len(output_vars)} output vars")

In [None]:
def _first_value(value):
    """Return a coordinate as a Python float, taking element 0 of array-likes.

    Accepts Python numbers, NumPy scalars of any dtype, 0-d arrays, lists and
    arrays. (np.float32 is not a ``float`` subclass, so an isinstance check
    against (int, float) would send it to ``value[0]``, which raises.)
    """
    return float(np.ravel(value)[0])


def _wrap_longitude(lon):
    """Map a longitude in degrees onto [-180, 180); in-range values are returned unchanged."""
    if -180.0 <= lon < 180.0:
        return lon
    return (lon + 180.0) % 360.0 - 180.0


def filter_nyc_samples(dataset, config):
    """Return the samples whose (lat, lon) fall inside the configured NYC box.

    Args:
        dataset: Sized, iterable collection of dict-like samples. Samples
            lacking either a ``'lat'`` or a ``'lon'`` key are skipped.
        config: Object exposing ``NYC_LAT_MIN``/``NYC_LAT_MAX`` and
            ``NYC_LON_MIN``/``NYC_LON_MAX`` in degrees (bounds inclusive).

    Returns:
        list: Matching samples, in their original order.

    Both the sample longitude and the box bounds are wrapped into [-180, 180)
    before comparing. Data stored on a 0-360 longitude grid (common for
    E3SM-derived grids -- TODO confirm for this dataset) therefore still
    matches a box written with negative west longitudes.
    """
    print("Filtering samples to NYC region...")
    print(f"NYC Box: [{config.NYC_LAT_MIN}, {config.NYC_LAT_MAX}] lat, "
          f"[{config.NYC_LON_MIN}, {config.NYC_LON_MAX}] lon")

    lon_min = _wrap_longitude(config.NYC_LON_MIN)
    lon_max = _wrap_longitude(config.NYC_LON_MAX)
    n_total = len(dataset)

    nyc_samples = []
    for i, sample in enumerate(dataset):
        if 'lat' in sample and 'lon' in sample:
            lat = _first_value(sample['lat'])
            lon = _wrap_longitude(_first_value(sample['lon']))

            in_lat = config.NYC_LAT_MIN <= lat <= config.NYC_LAT_MAX
            if lon_min <= lon_max:
                in_lon = lon_min <= lon <= lon_max
            else:
                # After wrapping, the box straddles the antimeridian.
                in_lon = lon >= lon_min or lon <= lon_max

            if in_lat and in_lon:
                nyc_samples.append(sample)

        if i % 1000 == 0:
            print(f"  Processed {i}/{n_total} samples, found {len(nyc_samples)} in NYC")

    print(f"\n‚úÖ Filtered {len(nyc_samples)} NYC samples from {n_total} total")
    # Guard the percentage so an empty dataset does not divide by zero.
    share = len(nyc_samples) / n_total * 100 if n_total else 0.0
    print(f"   ({share:.1f}% of data)")

    return nyc_samples

# Run the spatial filter. If nothing lands in the box, keep every sample so
# the downstream cells still have data to work with.
nyc_dataset = filter_nyc_samples(dataset, config)

if nyc_dataset:
    print(f"\nüìç Using {len(nyc_dataset)} NYC-filtered samples")
else:
    print("\n‚ö†Ô∏è No samples in NYC region!")
    print("   Using full dataset for demonstration...")
    nyc_dataset = [dataset[idx] for idx in range(len(dataset))]

# Summary of what was kept and which box was used.
_bbox = {
    'lat_min': config.NYC_LAT_MIN,
    'lat_max': config.NYC_LAT_MAX,
    'lon_min': config.NYC_LON_MIN,
    'lon_max': config.NYC_LON_MAX,
}
nyc_metadata = {
    'n_samples': len(nyc_dataset),
    'n_total': len(dataset),
    'bbox': _bbox,
}

print("\nüìã NYC Dataset Metadata:")
print("\n".join(f"   {field}: {value}" for field, value in nyc_metadata.items()))

## 4. Convert to NumPy Arrays

Extract and stack all variables into numpy arrays.

In [None]:
def extract_arrays(samples, input_vars, output_vars):
    """Stack per-sample variables into one NumPy array per variable.

    Args:
        samples: Sequence of dict-like samples. Every sample must contain
            every name in ``input_vars`` and ``output_vars``.
        input_vars: Names of the input variables to extract.
        output_vars: Names of the output variables to extract.

    Returns:
        tuple[dict, dict]: ``(inputs, outputs)`` mapping each variable name to
        an array of shape ``(n_samples, *value_shape)``. Scalar values (Python
        or NumPy scalars, 0-d arrays) are promoted to shape ``(1,)``, so a
        scalar variable comes out as ``(n_samples, 1)``.

    Raises:
        ValueError: From ``np.stack`` if ``samples`` is empty (with at least
            one variable requested) or the per-sample shapes disagree.
    """
    print("Converting to numpy arrays...")

    def _stack(var):
        # np.atleast_1d promotes every scalar flavour (int, float, np.float32,
        # 0-d array) to shape (1,). An isinstance(int, float) check would let
        # NumPy scalars through as 0-d arrays, which then fail to stack next
        # to (1,)-shaped rows.
        return np.stack([np.atleast_1d(np.asarray(sample[var])) for sample in samples])

    inputs_arrays = {}
    for var in input_vars:
        inputs_arrays[var] = _stack(var)
        print(f"  Input  {var:20s} shape: {inputs_arrays[var].shape}")

    outputs_arrays = {}
    for var in output_vars:
        outputs_arrays[var] = _stack(var)
        print(f"  Output {var:20s} shape: {outputs_arrays[var].shape}")

    return inputs_arrays, outputs_arrays

# Materialise the selected samples as per-variable NumPy arrays.
inputs_np, outputs_np = extract_arrays(nyc_dataset, input_vars, output_vars)

print(
    f"\n‚úÖ Extracted {len(inputs_np)} input variables",
    f"‚úÖ Extracted {len(outputs_np)} output variables",
    sep="\n",
)

## 5. Train/Val/Test Split

In [None]:
def create_splits(inputs_dict, outputs_dict, config):
    """Shuffle sample positions and partition every array into train/val/test.

    Args:
        inputs_dict: Mapping of input name -> array with samples on axis 0.
        outputs_dict: Mapping of output name -> array with samples on axis 0.
        config: Object exposing ``SEED``, ``TRAIN_SPLIT``, ``VAL_SPLIT`` and
            ``TEST_SPLIT``. ``TEST_SPLIT`` is only used in the printed
            summary; the test split receives every sample left after train
            and val.

    Returns:
        dict: ``{'train'|'val'|'test': {'inputs': {...}, 'outputs': {...},
        'indices': ndarray}}``, where ``indices`` are positions into the
        original arrays.

    Raises:
        ValueError: If ``inputs_dict`` is empty or ``TRAIN_SPLIT + VAL_SPLIT``
            exceeds 1.
    """
    if not inputs_dict:
        raise ValueError("inputs_dict must contain at least one variable")
    if config.TRAIN_SPLIT + config.VAL_SPLIT > 1.0 + 1e-9:
        raise ValueError(
            f"TRAIN_SPLIT + VAL_SPLIT must not exceed 1 "
            f"(got {config.TRAIN_SPLIT} + {config.VAL_SPLIT})"
        )

    n_samples = len(next(iter(inputs_dict.values())))

    # A dedicated RandomState seeded with SEED yields exactly the permutation
    # that np.random.seed(SEED) + np.random.permutation produced, but without
    # resetting NumPy's global RNG as a side effect.
    indices = np.random.RandomState(config.SEED).permutation(n_samples)

    # Split points: int() truncates, so any rounding remainder goes to test.
    n_train = int(n_samples * config.TRAIN_SPLIT)
    n_val = int(n_samples * config.VAL_SPLIT)

    train_idx = indices[:n_train]
    val_idx = indices[n_train:n_train + n_val]
    test_idx = indices[n_train + n_val:]

    print(f"Creating train/val/test splits...")
    print(f"  Train: {len(train_idx)} samples ({config.TRAIN_SPLIT*100:.0f}%)")
    print(f"  Val:   {len(val_idx)} samples ({config.VAL_SPLIT*100:.0f}%)")
    print(f"  Test:  {len(test_idx)} samples ({config.TEST_SPLIT*100:.0f}%)")

    splits = {}
    for split_name, split_idx in [('train', train_idx), ('val', val_idx), ('test', test_idx)]:
        splits[split_name] = {
            'inputs': {k: v[split_idx] for k, v in inputs_dict.items()},
            'outputs': {k: v[split_idx] for k, v in outputs_dict.items()},
            'indices': split_idx,
        }

    return splits

# Partition the stacked arrays into train/val/test using the configured ratios.
data_splits = create_splits(inputs_np, outputs_np, config)

print("\n‚úÖ Data splits created", f"   Keys: {list(data_splits)}", sep="\n")

## 6. Compute Normalization Statistics

Calculate mean and std from training split only (avoid data leakage).

In [None]:
def compute_normalization_stats(train_data):
    """Derive one mean and one std per variable from a (training) split.

    Each statistic is pooled over the variable's whole array -- all samples
    and all trailing elements together -- so it is a single scalar.

    Args:
        train_data: Split dict with ``'inputs'`` and ``'outputs'`` mappings
            of variable name -> array.

    Returns:
        dict: ``{'inputs': {name: {'mean': m, 'std': s}}, 'outputs': {...}}``.
    """
    print("Computing normalization statistics from training split...")

    # (group key, printed label, number format): inputs print with 4
    # decimals, outputs with 6.
    layout = (('inputs', 'Input ', '.4f'), ('outputs', 'Output', '.6f'))

    stats = {group: {} for group, _, _ in layout}
    for group, label, precision in layout:
        for var_name, values in train_data[group].items():
            entry = {'mean': np.mean(values), 'std': np.std(values)}
            stats[group][var_name] = entry
            print(f"  {label} {var_name:20s} "
                  f"mean={entry['mean']:>10{precision}}, std={entry['std']:>10{precision}}")

    return stats

# Statistics come from the training split only, so val/test values never
# influence the scaling that is later applied to them.
norm_stats = compute_normalization_stats(data_splits['train']) if config.NORMALIZE else None
print("\n‚úÖ Normalization statistics computed" if config.NORMALIZE else "\n‚ö™ Normalization disabled")

## 7. Apply Normalization

Normalize all splits using training statistics.

In [None]:
def normalize_data(data, stats):
    """Standardize every variable of a split as ``(x - mean) / std``.

    Args:
        data: Split dict with ``'inputs'``, ``'outputs'`` and ``'indices'``.
        stats: Per group (``'inputs'``/``'outputs'``), a mapping of variable
            name -> ``{'mean': ..., 'std': ...}``.

    Returns:
        dict: New split dict holding the standardized ``'inputs'`` and
        ``'outputs'`` plus the original ``'indices'`` object, passed through.

    A std that is not greater than 1e-8 (NaN included) is replaced by 1.0,
    so near-constant variables are only mean-shifted rather than divided by ~0.
    """
    def _standardize(group):
        scaled = {}
        for name, values in data[group].items():
            var_stats = stats[group][name]
            mean = var_stats['mean']
            spread = var_stats['std']
            if not spread > 1e-8:
                spread = 1.0
            scaled[name] = (values - mean) / spread
        return scaled

    return {
        'inputs': _standardize('inputs'),
        'outputs': _standardize('outputs'),
        'indices': data['indices'],
    }

if not config.NORMALIZE:
    normalized_splits = data_splits
    print("\n‚ö™ Using unnormalized data")
else:
    print("Applying normalization to all splits...")

    normalized_splits = {}
    for name, split in data_splits.items():
        normalized_splits[name] = normalize_data(split, norm_stats)
        if name != 'train':
            continue

        # Spot-check the training split: the first few inputs should now have
        # mean ~0 and std ~1 (unless a near-zero std was clamped to 1).
        print(f"\n  Verifying normalization on {name} split:")
        train_inputs = normalized_splits[name]['inputs']
        for var_name in list(train_inputs)[:3]:
            values = train_inputs[var_name]
            print(f"    {var_name:20s} mean={np.mean(values):>8.4f}, std={np.std(values):>8.4f}")

    print("\n‚úÖ Normalization applied to all splits")

## 8. Create JAX Data Loaders

Build efficient data loaders with JAX arrays and device sharding.

In [None]:
class JAXDataLoader:
    """Mini-batch iterator over in-memory dicts of arrays, yielding jax.numpy batches.

    Each ``iter(loader)`` starts a new epoch. With ``shuffle=True`` the sample
    order comes from a PRNG key that is split (advanced) every epoch, so
    epochs differ but the sequence is reproducible for a given ``seed``.
    The loader is its own iterator (``__iter__`` returns ``self``), so only one
    pass over it can be active at a time.

    No device placement or sharding happens here; batches are plain jax
    arrays that can be reshaped for ``jax.pmap`` afterwards.
    """

    def __init__(self, inputs, outputs, batch_size, shuffle=True, seed=42, drop_last=True):
        """
        Args:
            inputs: Dict of input arrays; all share the sample count on axis 0.
            outputs: Dict of output arrays, aligned with ``inputs`` on axis 0.
            batch_size: Samples per batch; must be positive.
            shuffle: Re-permute the samples at the start of every epoch.
            seed: Seed for the PRNG key that drives shuffling.
            drop_last: If True (default, the original behaviour) a trailing
                partial batch is discarded, so a split with fewer than
                ``batch_size`` samples yields no batches at all. If False the
                partial batch is yielded as the last batch of each epoch.

        Raises:
            ValueError: If ``batch_size`` is not positive or ``inputs`` is empty.
        """
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")
        if not inputs:
            raise ValueError("inputs must contain at least one array")

        # Copy everything into jax arrays once; batching is then pure indexing.
        self.inputs = {k: jnp.array(v) for k, v in inputs.items()}
        self.outputs = {k: jnp.array(v) for k, v in outputs.items()}
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.seed = seed
        self.drop_last = drop_last

        self.n_samples = len(next(iter(self.inputs.values())))
        if drop_last:
            self.n_batches = self.n_samples // batch_size
        else:
            self.n_batches = -(-self.n_samples // batch_size)  # ceiling division

        self.rng = random.PRNGKey(seed)
        self._create_epoch()

    def _create_epoch(self):
        """Reset the cursor and choose this epoch's sample order."""
        if self.shuffle:
            # Split so the stored key advances and the next epoch gets a new order.
            self.rng, shuffle_rng = random.split(self.rng)
            self.indices = random.permutation(shuffle_rng, self.n_samples)
        else:
            self.indices = jnp.arange(self.n_samples)
        self.current_idx = 0

    def __iter__(self):
        self._create_epoch()
        return self

    def __next__(self):
        # One past the last position this epoch may emit: the end of the final
        # full batch when dropping the remainder, otherwise the end of the data.
        limit = self.n_batches * self.batch_size if self.drop_last else self.n_samples
        if self.current_idx >= limit:
            raise StopIteration

        start_idx = self.current_idx
        end_idx = min(start_idx + self.batch_size, self.n_samples)
        batch_indices = self.indices[start_idx:end_idx]

        batch_inputs = {k: v[batch_indices] for k, v in self.inputs.items()}
        batch_outputs = {k: v[batch_indices] for k, v in self.outputs.items()}

        self.current_idx = end_idx
        return batch_inputs, batch_outputs

    def __len__(self):
        """Number of batches per epoch (respects ``drop_last``)."""
        return self.n_batches

# Create data loaders: train reshuffles every epoch, val/test keep a fixed order.
print(f"Creating JAX data loaders...")
print(f"  Batch size: {config.BATCH_SIZE}")
print(f"  Devices: {config.NUM_DEVICES}")

train_loader = JAXDataLoader(
    normalized_splits['train']['inputs'],
    normalized_splits['train']['outputs'],
    batch_size=config.BATCH_SIZE,
    shuffle=True,
    seed=config.SEED
)

val_loader = JAXDataLoader(
    normalized_splits['val']['inputs'],
    normalized_splits['val']['outputs'],
    batch_size=config.BATCH_SIZE,
    shuffle=False
)

test_loader = JAXDataLoader(
    normalized_splits['test']['inputs'],
    normalized_splits['test']['outputs'],
    batch_size=config.BATCH_SIZE,
    shuffle=False
)

print(f"\n‚úÖ Data loaders created:")
print(f"   Train: {len(train_loader)} batches ({train_loader.n_samples} samples)")
print(f"   Val:   {len(val_loader)} batches ({val_loader.n_samples} samples)")
print(f"   Test:  {len(test_loader)} batches ({test_loader.n_samples} samples)")

# Test loader
# A split smaller than BATCH_SIZE has zero full batches. The NYC subset of the
# synthetic fallback data is ~0.3% of SAMPLE_SIZE, i.e. roughly 24 training
# samples. Calling next() on such a loader raises StopIteration, so only pull
# a batch when one exists.
print(f"\nüß™ Testing data loader...")
if len(train_loader) == 0:
    print(f"   ‚ö†Ô∏è Train split has {train_loader.n_samples} samples, fewer than "
          f"batch size {config.BATCH_SIZE}; no full batch to inspect.")
else:
    batch_inputs, batch_outputs = next(iter(train_loader))
    print(f"   Batch shapes:")
    for var_name, var_data in list(batch_inputs.items())[:3]:
        print(f"     Input  {var_name}: {var_data.shape}")
    for var_name, var_data in list(batch_outputs.items())[:3]:
        print(f"     Output {var_name}: {var_data.shape}")

In [None]:
def shard_batch(batch, num_devices):
    """Reshape an ``(inputs, outputs)`` batch so axis 0 indexes devices.

    Each array of shape ``(B, ...)`` is cut to its first
    ``(B // num_devices) * num_devices`` rows (the remainder is dropped). It is
    then reshaped to ``(num_devices, B // num_devices, ...)``, the layout
    ``jax.pmap`` maps over. Nothing is transferred between devices here.

    Args:
        batch: Pair ``(inputs, outputs)`` of dicts mapping names to arrays.
        num_devices: Length of the new leading axis.

    Returns:
        tuple[dict, dict]: The reshaped ``(inputs, outputs)`` dicts.
    """
    def _split_leading_axis(array):
        per_device = array.shape[0] // num_devices
        kept = array[:per_device * num_devices]
        return kept.reshape((num_devices, per_device, *kept.shape[1:]))

    inputs, outputs = batch[0], batch[1]
    return (
        {name: _split_leading_axis(arr) for name, arr in inputs.items()},
        {name: _split_leading_axis(arr) for name, arr in outputs.items()},
    )

if config.NUM_DEVICES > 1:
    print(f"Setting up device sharding for {config.NUM_DEVICES} devices...")

    # next() on a loader with no full batch raises StopIteration (e.g. a tiny
    # NYC subset), so only run the demo when a batch exists.
    if len(train_loader) == 0:
        print("   ‚ö†Ô∏è Train loader has no full batch; skipping sharding demo.")
    else:
        batch = next(iter(train_loader))
        sharded_batch = shard_batch(batch, config.NUM_DEVICES)

        print(f"\n‚úÖ Device sharding configured")
        print(f"   Original batch shape: {next(iter(batch[0].values())).shape}")
        print(f"   Sharded batch shape:  {next(iter(sharded_batch[0].values())).shape}")
        print(f"   (devices, batch_per_device, ...)")

    # Example of using sharded batch with pmap
    print(f"\nüí° Usage with jax.pmap:")
    print(f"   @jax.pmap")
    print(f"   def train_step(state, batch):")
    print(f"       # Automatically executes on each device")
    print(f"       ...")
else:
    print(f"\n‚ö™ Single device mode (no sharding needed)")

## 10. Save Processed Data to Leap-Scratch

Persist NYC subset as Zarr for efficient access.

In [None]:
def _zarr_dims(shape):
    """Return xarray dimension names for an array whose axis 0 is the sample axis.

    Axis 0 is always ``sample``, so ``isel(sample=...)`` works on every
    variable. Trailing axes are named after their length (``level_<n>`` for
    axis 1, ``dim<axis>_<n>`` beyond). Variables of different per-sample sizes,
    such as an (n, 60) profile and an (n, 1) scalar, therefore no longer share
    one ``level`` dimension with conflicting lengths, which makes
    ``xr.Dataset`` raise.
    """
    dims = ['sample']
    if len(shape) >= 2:
        dims.append(f'level_{shape[1]}')
    dims.extend(f'dim{axis}_{size}' for axis, size in enumerate(shape[2:], start=2))
    return dims


def save_as_zarr(data_splits, norm_stats, output_path, normalized=None):
    """Write each split to ``<output_path>/<split>.zarr`` and the stats to ``norm_stats.npz``.

    Each variable becomes a data variable named ``input_<name>`` or
    ``output_<name>`` with dimensions from ``_zarr_dims``. Existing stores
    are overwritten (``mode='w'``).

    Args:
        data_splits: ``{split: {'inputs': {...}, 'outputs': {...}, 'indices': ...}}``.
        norm_stats: Normalization statistics or None. When truthy, every mean
            and std is written to ``norm_stats.npz`` under the key
            ``<inputs|outputs>_<name>_<mean|std>``.
        output_path: Destination directory (created if missing); str or Path.
        normalized: Value stored in each store's ``normalized`` attribute.
            Defaults to ``norm_stats is not None``, so the function depends
            only on its arguments instead of reading a global ``config``.

    Returns:
        Path: ``output_path`` as a Path.
    """
    output_path = Path(output_path)
    if normalized is None:
        normalized = norm_stats is not None

    print(f"Saving processed data to Zarr...")
    print(f"  Output path: {output_path}")

    output_path.mkdir(parents=True, exist_ok=True)

    for split_name, split_data in data_splits.items():
        split_path = output_path / f"{split_name}.zarr"

        data_vars = {}
        for group, prefix in (('inputs', 'input'), ('outputs', 'output')):
            for var_name, var_data in split_data[group].items():
                data_vars[f'{prefix}_{var_name}'] = (_zarr_dims(var_data.shape), var_data)

        ds = xr.Dataset(data_vars)
        ds.attrs['split'] = split_name
        ds.attrs['n_samples'] = len(split_data['indices'])
        ds.attrs['normalized'] = bool(normalized)

        ds.to_zarr(split_path, mode='w')
        print(f"  ‚úÖ Saved {split_name} split ({ds.nbytes / 1e6:.2f} MB)")

    if norm_stats:
        stats_path = output_path / 'norm_stats.npz'

        # Flatten the nested stats into npz-friendly "<group>_<var>_<stat>" keys.
        stats_flat = {}
        for io_type in ['inputs', 'outputs']:
            for var_name, var_stats in norm_stats[io_type].items():
                stats_flat[f'{io_type}_{var_name}_mean'] = var_stats['mean']
                stats_flat[f'{io_type}_{var_name}_std'] = var_stats['std']

        np.savez(stats_path, **stats_flat)
        print(f"  ‚úÖ Saved normalization statistics")

    print(f"\n‚úÖ All data saved to: {output_path}")
    return output_path

# Persist every split to Zarr (one store per split, plus normalization stats).
zarr_path = save_as_zarr(normalized_splits, norm_stats, config.OUTPUT_DIR)

# A single compressed .npz is convenient for small subsets like this one.
print("\nSaving compact .npz format...")
npz_path = config.OUTPUT_DIR / 'climsim_nyc_processed.npz'

# Keys look like "<split>_<input|output>_<variable>".
npz_data = {
    f'{split_name}_{role}_{var_name}': var_data
    for split_name, split_data in normalized_splits.items()
    for role, group in (('input', split_data['inputs']), ('output', split_data['outputs']))
    for var_name, var_data in group.items()
}

np.savez_compressed(npz_path, **npz_data)
npz_size_mb = npz_path.stat().st_size / 1e6
print(f"‚úÖ Saved to: {npz_path}", f"   Size: {npz_size_mb:.2f} MB", sep="\n")

## 11. Checkpoint Management with Orbax

Set up checkpoint saving/loading for model training.

In [None]:
class CheckpointManager:
    """Save and restore pytrees as ``<checkpoint_dir>/checkpoint_<step>`` via Orbax.

    Thin wrapper around ``ocp.PyTreeCheckpointer`` with one directory per step.
    """

    def __init__(self, checkpoint_dir):
        """Create ``checkpoint_dir`` (with parents) if needed and set up the checkpointer."""
        self.checkpoint_dir = Path(checkpoint_dir)
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)

        # Orbax checkpointer for arbitrary pytrees (e.g. a flax TrainState).
        self.checkpointer = ocp.PyTreeCheckpointer()

        print(f"Checkpoint manager initialized:")
        print(f"  Directory: {self.checkpoint_dir}")

    def _path_for(self, step):
        """Directory that holds the checkpoint for ``step``."""
        return self.checkpoint_dir / f"checkpoint_{step}"

    def save(self, state, step):
        """Write ``state`` to the directory for ``step`` and return that path.

        NOTE(review): Orbax's ``save`` is expected to refuse to overwrite an
        existing directory by default, so re-saving the same step may fail.
        Confirm this against the installed Orbax version.
        """
        ckpt_path = self._path_for(step)
        self.checkpointer.save(ckpt_path, state)
        print(f"‚úÖ Checkpoint saved: {ckpt_path}")
        return ckpt_path

    def restore(self, step, state_template):
        """Load the checkpoint for ``step``, using ``state_template`` as the target structure."""
        ckpt_path = self._path_for(step)
        restored = self.checkpointer.restore(ckpt_path, item=state_template)
        print(f"‚úÖ Checkpoint restored from: {ckpt_path}")
        return restored

    @staticmethod
    def _step_of(path):
        """Return the step in a ``checkpoint_<digits>`` name, or None for any other name."""
        prefix, _, step = path.name.partition("_")
        # isascii() excludes Unicode digits such as '²' that int() rejects.
        if prefix != "checkpoint" or not (step.isascii() and step.isdigit()):
            return None
        return int(step)

    def latest_checkpoint(self):
        """Return the path of the highest-step checkpoint directory, or None if there is none.

        Entries matching ``checkpoint_*`` whose suffix is not purely numeric
        are skipped instead of crashing ``int()``. Examples are in-progress
        temporary directories such as ``checkpoint_5.orbax-checkpoint-tmp-…``.
        """
        by_step = {}
        for candidate in self.checkpoint_dir.glob("checkpoint_*"):
            step = self._step_of(candidate)
            if step is not None:
                by_step[step] = candidate

        if not by_step:
            return None
        return by_step[max(by_step)]

# Checkpoint manager that later training code can use for model state.
ckpt_manager = CheckpointManager(config.CHECKPOINT_DIR)

# Record how this dataset was produced, so a training run can be traced back
# to the exact preprocessing settings.
nyc_bbox = dict(
    lat_min=config.NYC_LAT_MIN,
    lat_max=config.NYC_LAT_MAX,
    lon_min=config.NYC_LON_MIN,
    lon_max=config.NYC_LON_MAX,
)
split_sizes = dict(
    train=train_loader.n_samples,
    val=val_loader.n_samples,
    test=test_loader.n_samples,
)
preprocessing_config = {
    'nyc_bbox': nyc_bbox,
    'normalization': config.NORMALIZE,
    'n_samples': split_sizes,
    'batch_size': config.BATCH_SIZE,
    'input_vars': input_vars,
    'output_vars': output_vars,
}

# Every entry is stored as its str() rendering, i.e. a plain string array
# (nested dicts/lists would otherwise become pickled object arrays).
config_path = config.CHECKPOINT_DIR / 'preprocessing_config.npz'
stringified = {key: str(value) for key, value in preprocessing_config.items()}
np.savez(config_path, **stringified)
print(f"\n‚úÖ Preprocessing config saved to: {config_path}")

print(
    "\nüí° Usage in training:",
    "   # Save model checkpoint",
    "   ckpt_manager.save(train_state, step=epoch)",
    "   ",
    "   # Restore checkpoint",
    "   train_state = ckpt_manager.restore(step=10, state_template=train_state)",
    sep="\n",
)

## 12. Lazy Loading Utility

Create utility for lazy loading large datasets.

In [None]:
class LazyDataLoader:
    """Iterate over one split of the saved Zarr data, reading a batch at a time.

    The split's store is opened with ``xr.open_zarr`` and array values are not
    read up front. Each ``get_batch`` call selects a contiguous range along
    ``sample``, loads only that range, and returns it as jax.numpy arrays.
    """

    def __init__(self, zarr_path, split='train', batch_size=32, drop_last=True):
        """
        Args:
            zarr_path: Directory containing ``<split>.zarr`` stores.
            split: Which split to load ('train', 'val', 'test').
            batch_size: Samples per batch; must be positive.
            drop_last: If True (default, the original behaviour) iteration
                skips a trailing partial batch; if False it is yielded too.

        Raises:
            ValueError: If ``batch_size`` is not positive.
        """
        if batch_size <= 0:
            raise ValueError(f"batch_size must be positive, got {batch_size}")
        self.zarr_path = Path(zarr_path) / f"{split}.zarr"
        self.batch_size = batch_size
        self.drop_last = drop_last

        # Open zarr without loading into memory
        self.ds = xr.open_zarr(self.zarr_path)

        # Dataset.sizes is the name -> length mapping; looking up lengths via
        # Dataset.dims[...] is deprecated in recent xarray releases.
        self.n_samples = self.ds.sizes['sample']
        if drop_last:
            self.n_batches = self.n_samples // batch_size
        else:
            self.n_batches = -(-self.n_samples // batch_size)  # ceiling division

        print(f"Lazy loader initialized:")
        print(f"  Split: {split}")
        print(f"  Samples: {self.n_samples}")
        print(f"  Batches: {self.n_batches}")
        print(f"  Variables: {list(self.ds.data_vars)[:5]}...")

    def get_batch(self, batch_idx):
        """Load batch ``batch_idx`` from disk and return ``(inputs, outputs)`` dicts.

        Keys are the stored variable names with the leading ``input_`` /
        ``output_`` prefix removed; values are jax.numpy arrays. A batch at
        the end of the split may be shorter than ``batch_size``.
        """
        start_idx = batch_idx * self.batch_size
        end_idx = start_idx + self.batch_size

        # Read only this slice of the store into memory.
        batch_ds = self.ds.isel(sample=slice(start_idx, end_idx)).load()

        inputs = {}
        outputs = {}
        for var_name in batch_ds.data_vars:
            # Strip only the leading prefix; str.replace would also delete
            # later occurrences inside the variable name itself.
            if var_name.startswith('input_'):
                inputs[var_name[len('input_'):]] = jnp.array(batch_ds[var_name].values)
            elif var_name.startswith('output_'):
                outputs[var_name[len('output_'):]] = jnp.array(batch_ds[var_name].values)

        return inputs, outputs

    def __iter__(self):
        for batch_idx in range(self.n_batches):
            yield self.get_batch(batch_idx)

    def __len__(self):
        """Number of batches one pass yields (respects ``drop_last``)."""
        return self.n_batches

# Smoke-test the Zarr-backed loader on the training split.
print("Testing lazy data loader...")
lazy_loader = LazyDataLoader(zarr_path, split='train', batch_size=config.BATCH_SIZE)

# Pull the first batch straight from disk and show a few shapes.
batch_inputs, batch_outputs = lazy_loader.get_batch(0)
print("\n‚úÖ Lazy loading works!", "   Loaded batch shapes:", sep="\n")
for var_name in list(batch_inputs)[:3]:
    print(f"     {var_name}: {batch_inputs[var_name].shape}")

print(
    "\nüí° Benefits of lazy loading:",
    "   - Only loads data when needed",
    "   - Works with datasets larger than memory",
    "   - Efficient for distributed training",
    sep="\n",
)

## 13. Visualization: Data Distribution

Visualize normalized data distributions.

In [None]:
def _plot_histogram(ax, values, title, xlabel, color, mark_zero=False):
    """Draw a 50-bin histogram of the flattened ``values`` on ``ax``.

    With ``mark_zero`` a dashed red line at x=0 (labelled 'Mean=0') and a
    legend are added; this is used for the standardized panels.
    """
    ax.hist(np.ravel(values), bins=50, alpha=0.7, color=color, edgecolor='black')
    ax.set_title(title, fontweight='bold', fontsize=12)
    ax.set_xlabel(xlabel)
    ax.set_ylabel('Frequency')
    if mark_zero:
        ax.axvline(0, color='r', linestyle='--', linewidth=2, label='Mean=0')
        ax.legend()
    ax.grid(True, alpha=0.3)


# Plot distributions before and after normalization
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# Visualize the first input and the first output variable.
example_input_var = list(input_vars)[0]
example_output_var = list(output_vars)[0]

# Both columns show the TRAINING split: raw values from `data_splits`, scaled
# values from `normalized_splits`. Using all samples for the raw panels but
# only the training split for the normalized ones would compare different
# samples.
raw_train = data_splits['train']
scaled_train = normalized_splits['train']

panels = (
    (axes[0, 0], raw_train['inputs'][example_input_var],
     f'Input: {example_input_var} (Raw)', 'Value', 'blue', False),
    (axes[0, 1], scaled_train['inputs'][example_input_var],
     f'Input: {example_input_var} (Normalized)', 'Value (standardized)', 'green', True),
    (axes[1, 0], raw_train['outputs'][example_output_var],
     f'Output: {example_output_var} (Raw)', 'Value', 'blue', False),
    (axes[1, 1], scaled_train['outputs'][example_output_var],
     f'Output: {example_output_var} (Normalized)', 'Value (standardized)', 'green', True),
)
for ax, values, title, xlabel, color, mark_zero in panels:
    _plot_histogram(ax, values, title, xlabel, color, mark_zero=mark_zero)

plt.suptitle('Data Distributions: Before and After Normalization',
             fontsize=14, fontweight='bold', y=0.995)
plt.tight_layout()
plt.show()

print("‚úÖ Distribution visualization complete!")

## Summary & Next Steps

### What We Built

1. ✅ **Data Loading** - Loaded ClimSim from Hugging Face (or a synthetic stand-in if loading fails)
2. ✅ **NYC Filtering** - Filtered samples to the NYC region (40.5–41°N, 74.3–73.7°W)
3. ✅ **Train/Val/Test Splits** - 80/10/10 split with seeded shuffling
4. ✅ **Normalization** - Computed stats from the training split only
5. ✅ **JAX Data Loaders** - Efficient batching with jax.numpy arrays
6. ✅ **Multi-GPU Sharding** - Batches reshaped to a leading device axis for `jax.pmap`
7. ✅ **Checkpoint Management** - Orbax-based checkpoint saving/loading
8. ✅ **Lazy Loading** - Memory-efficient loading for large datasets
9. ✅ **Persistence** - Saved processed data as Zarr and npz

### Saved Files

```
/home/jovyan/leap-scratch/$USER/
├── climsim_processed/
│   ├── train.zarr/          # Training data (Zarr format)
│   ├── val.zarr/            # Validation data
│   ├── test.zarr/           # Test data
│   ├── norm_stats.npz       # Normalization statistics
│   └── climsim_nyc_processed.npz  # Compact npz format
├── checkpoints/
│   └── preprocessing_config.npz   # Preprocessing configuration
```

### Usage in Training

```python
# Load processed data
from pathlib import Path

# Quick loading from npz
data = np.load(config.OUTPUT_DIR / 'climsim_nyc_processed.npz')
train_inputs = {k.split('_', 2)[2]: data[k] 
                for k in data.files if k.startswith('train_input_')}

# Or use lazy loading for large datasets
lazy_train = LazyDataLoader(config.OUTPUT_DIR, split='train', batch_size=32)
for batch_inputs, batch_outputs in lazy_train:
    # Train your model
    loss = train_step(model, batch_inputs, batch_outputs)
```

### Next Steps

1. **Build Model Architecture** - Create JAX/Flax neural network
2. **Training Loop** - Implement training with pmap for multi-GPU
3. **Evaluation** - Compute metrics on validation set
4. **Hyperparameter Tuning** - Optimize model configuration
5. **Production Deployment** - Save and serve final model

### Key Files to Reference

- Data loaders: `JAXDataLoader`, `LazyDataLoader`
- Checkpoint manager: `CheckpointManager`
- Normalization stats: `norm_stats` dictionary
- NYC filtering: `filter_nyc_samples()` function

Happy training! 🚀🌍