# ClimSim Dataset Analysis

Deep dive into the ClimSim low-resolution dataset from Hugging Face.

## About ClimSim

ClimSim is a large-scale climate simulation dataset designed for training machine learning emulators. It contains:
- **Inputs**: Atmospheric state variables (temperature, humidity, etc.) at multiple vertical levels
- **Outputs**: Physical tendencies (how variables change over time)
- **Goal**: Train ML models to predict climate physics faster than traditional simulators

**Dataset:** [LEAP/ClimSim_low-res](https://huggingface.co/datasets/LEAP/ClimSim_low-res)

**Prerequisites:** Run `leap_startup.ipynb` first!

In [None]:
# Import required packages
import numpy as np
import pandas as pd
import xarray as xr
import matplotlib.pyplot as plt
from pathlib import Path
import os

# Hugging Face
from datasets import load_dataset
from huggingface_hub import hf_hub_download

# Set style for better plots
plt.style.use('seaborn-v0_8-darkgrid' if 'seaborn-v0_8-darkgrid' in plt.style.available else 'default')

print("✅ All imports successful!")

## 1. Load ClimSim Dataset

We'll load a small subsample first to understand the structure.

In [None]:
repo_id = "LEAP/ClimSim_low-res"

print(f"Loading ClimSim dataset from: {repo_id}\n")
print("=" * 70)

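# Load the first 1000 training samples.
# Note: split slicing ("train[:1000]") is applied after the split has been
# downloaded and prepared, so it does not avoid a full download; streaming
# mode (streaming=True) is the option that reads samples on demand.
try:
    dataset = load_dataset(
        repo_id,
        split="train[:1000]",  # First 1000 samples
        streaming=False  # Map-style Dataset: supports len() and random access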
    )
    
    print(f"✅ Successfully loaded {len(dataset)} samples!\n")
    print(f"Dataset info:")
    print(f"  - Number of samples: {len(dataset)}")
    print(f"  - Features: {list(dataset.features.keys())}")
    
except Exception as e:
    print(f"⚠️ Error loading dataset: {e}")
    print("\nNote: This might fail if:")
    print("  1. The dataset structure is different than expected")
    print("  2. You need to accept the dataset terms on Hugging Face")
    print("  3. The dataset requires authentication")
    
    # Create synthetic data for demonstration
    print("\n📝 Creating synthetic ClimSim-like data for demonstration...")
    
    # Typical ClimSim structure
    n_samples = 1000
    n_levels = 60  # Vertical levels
    
    # Input variables (atmospheric state)
    state_t = np.random.randn(n_samples, n_levels) * 30 + 250  # Temperature (K)
    # Specific humidity (kg/kg), clipped so it never goes negative
    state_q0001 = np.clip(np.random.randn(n_samples, n_levels) * 0.002 + 0.005, 0.0, None)
    state_ps = np.random.randn(n_samples) * 5000 + 100000  # Surface pressure (Pa)
    
    # Output variables (tendencies), scaled to roughly realistic magnitudes
    ptend_t = np.random.randn(n_samples, n_levels) * 1e-5  # Temperature tendency (K/s), ~1 K/day
    ptend_q0001 = np.random.randn(n_samples, n_levels) * 1e-8  # Humidity tendency (kg/kg/s), ~1 g/kg/day
    
    # Create dictionary mimicking ClimSim structure
    dataset = {
        'state_t': state_t,
        'state_q0001': state_q0001,
        'state_ps': state_ps,
        'ptend_t': ptend_t,
        'ptend_q0001': ptend_q0001,
    }
    
    print("✅ Synthetic dataset created for demonstration")
    print("   (Replace with real data loading above when available)")

## 2. Inspect Data Structure

Let's examine the shapes and types of inputs and outputs.

In [None]:
# Collect every variable as a NumPy array in data_dict
if hasattr(dataset, 'features'):
    # Hugging Face Dataset: column access returns all rows of a variable at once
    data_dict = {
        k: np.asarray(dataset[k])
        for k in dataset.column_names
        if k.startswith(('state_', 'ptend_'))
    }
else:
    # Synthetic dictionary from the fallback above
    data_dict = dataset

# Separate input (state) and output (ptend) variables
input_vars = {k: v for k, v in data_dict.items() if k.startswith('state_')}
output_vars = {k: v for k, v in data_dict.items() if k.startswith('ptend_')}

for title, group in [("INPUT VARIABLES (Atmospheric State)", input_vars),
                     ("OUTPUT VARIABLES (Physical Tendencies)", output_vars)]:
    print("\n" + "=" * 70)
    print(title)
    print("=" * 70)
    for var_name in sorted(group):
        arr = group[var_name]
        print(f"{var_name:20s} shape: {str(arr.shape):20s} dtype: {arr.dtype}")

print("\n💡 Interpretation:")
print("  - state_* variables: Current atmospheric conditions (inputs)")
print("  - ptend_* variables: How variables change over time (outputs/targets)")
print("  - Vertical dimension: Multiple atmospheric levels (surface to top)")

## 3. Variable Details

ClimSim typically includes these variables:

In [None]:
variable_descriptions = {
    # Input variables (state)
    'state_t': 'Temperature [K] at each vertical level',
    'state_q0001': 'Specific humidity [kg/kg] - water vapor',
    'state_q0002': 'Cloud liquid water [kg/kg]',
    'state_q0003': 'Cloud ice [kg/kg]',
    'state_u': 'Zonal wind [m/s] - eastward component',
    'state_v': 'Meridional wind [m/s] - northward component',
    'state_ps': 'Surface pressure [Pa] - single value per column',
    
    # Output variables (tendencies)
    'ptend_t': 'Temperature tendency [K/s] - rate of temperature change',
    'ptend_q0001': 'Specific humidity tendency [kg/kg/s]',
    'ptend_q0002': 'Cloud liquid water tendency [kg/kg/s]',
    'ptend_q0003': 'Cloud ice tendency [kg/kg/s]',
    'ptend_u': 'Zonal wind tendency [m/s²]',
    'ptend_v': 'Meridional wind tendency [m/s²]',
}
print("=" * 70)
print("VARIABLE DESCRIPTIONS")
print("=" * 70)

present_vars = list(data_dict.keys())

print("\n📊 INPUT VARIABLES (State):")
for var, desc in variable_descriptions.items():
    if var.startswith('state_') and var in present_vars:
        print(f"  ✅ {var:15s} - {desc}")
    elif var.startswith('state_'):
        print(f"  ⚪ {var:15s} - {desc} (not in this dataset)")

print("\n📈 OUTPUT VARIABLES (Tendencies):")
for var, desc in variable_descriptions.items():
    if var.startswith('ptend_') and var in present_vars:
        print(f"  ✅ {var:15s} - {desc}")
    elif var.startswith('ptend_'):
        print(f"  ⚪ {var:15s} - {desc} (not in this dataset)")

print(f"\n📋 Total variables in dataset: {len(present_vars)}")
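
The tendency units listed above are per second, so raw tendency values are very small numbers. Rescaling to per-day units (and from kg/kg to g/kg for moisture) often makes them easier to read. A minimal sketch of those conversions; the function names are illustrative:

```python
SECONDS_PER_DAY = 86_400.0


def per_second_to_per_day(rate):
    """Convert a tendency expressed per second (e.g. K/s) to per day (e.g. K/day)."""
    return rate * SECONDS_PER_DAY


def kgkg_to_gkg(q):
    """Convert specific humidity (or its tendency) from kg/kg to g/kg."""
    return q * 1000.0


# 1e-5 K/s of heating corresponds to 0.864 K/day
heating_per_day = per_second_to_per_day(1e-5)
# A humidity tendency of 1e-8 kg/kg/s corresponds to 0.864 g/kg/day
moistening_per_day = kgkg_to_gkg(per_second_to_per_day(1e-8))
print(heating_per_day, moistening_per_day)
```

The same factors apply elementwise to whole NumPy arrays, so `per_second_to_per_day(data_dict['ptend_t'])` would convert every sample and level at once.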

## 4. Vertical Levels

Climate models use vertical levels to represent the atmosphere from surface to top.
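
Climate models with hybrid sigma-pressure coordinates define the pressure of level $k$ in each column from two coefficients and the column's surface pressure: $p_k = A_k P_0 + B_k p_s$, with reference pressure $P_0 = 10^5$ Pa. The coefficients are often stored as `hyam`/`hybm` in CAM/E3SM-style output; whether and where this dataset provides them is an assumption not checked here. The sketch below uses made-up 4-level coefficients purely to show the computation:

```python
import numpy as np

P0 = 100_000.0  # Reference pressure [Pa]


def hybrid_pressure(hyam, hybm, ps):
    """Pressure [Pa] at hybrid levels: p[i, k] = hyam[k] * P0 + hybm[k] * ps[i].

    hyam, hybm: shape (levels,); ps: shape (samples,). Returns (samples, levels).
    """
    hyam = np.asarray(hyam, dtype=float)
    hybm = np.asarray(hybm, dtype=float)
    ps = np.asarray(ps, dtype=float)
    return hyam[None, :] * P0 + hybm[None, :] * ps[:, None]


# Made-up coefficients: pure pressure at the top, pure sigma at the bottom
hyam = np.array([0.01, 0.2, 0.1, 0.0])
hybm = np.array([0.0, 0.1, 0.6, 0.98])
ps = np.array([100_000.0, 95_000.0])  # Two columns with different surface pressure

p_hpa = hybrid_pressure(hyam, hybm, ps) / 100.0  # Pa -> hPa
print(p_hpa)
```

Because `ps` differs between columns, each column gets its own pressure profile; a single 1-D array of pressure levels is only an approximation of this.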

In [None]:
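# Determine the number of vertical levels from the first 2-D input variable
for var_name in sorted(input_vars.keys()):
    if data_dict[var_name].ndim == 2:  # 2D array (samples, levels)
        n_samples, n_levels = data_dict[var_name].shape
        break
else:
    raise ValueError("No 2-D (samples, levels) input variable found")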

print("=" * 70)
print("VERTICAL STRUCTURE")
print("=" * 70)
print(f"Number of samples:        {n_samples}")
print(f"Number of vertical levels: {n_levels}")
print("\n💡 Vertical levels typically represent:")
print("   - Level 0:  Top of atmosphere (~0-10 hPa)")
print(f"   - Level {n_levels//2}: A middle model level (its pressure depends on the level spacing)")
print(f"   - Level {n_levels-1}: Near surface (~1000 hPa)")
print("\n   (Lower pressure = higher altitude)")

# Create evenly spaced placeholder pressure levels for plotting.
# These are NOT the model's hybrid sigma-pressure levels: real level pressures
# depend on hybrid coefficients and on each column's surface pressure.
if n_levels == 60:
    # Placeholder for a 60-level configuration
    pressure_levels = np.linspace(10, 1000, n_levels)  # hPa (mb)
elif n_levels == 30:
    pressure_levels = np.linspace(50, 1000, n_levels)
else:
    pressure_levels = np.linspace(100, 1000, n_levels)

print("\nPlaceholder pressure levels (evenly spaced, hPa):")
print(f"   Top (level 0):    {pressure_levels[0]:.1f} hPa")
print(f"   Middle:           {pressure_levels[n_levels//2]:.1f} hPa")
print(f"   Bottom (level {n_levels-1}): {pressure_levels[-1]:.1f} hPa")

## 5. Sample Statistics

Compute the mean, standard deviation, minimum, and maximum of each variable, pooled over all samples and vertical levels.
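
Pooling over all levels mixes the vertical structure of a profile with its sample-to-sample variability. The toy example below (a made-up profile, not ClimSim data) shows a pooled standard deviation roughly an order of magnitude larger than the per-level spread. Reducing over the sample axis only (`axis=0`) gives one mean and standard deviation per level, a common choice when normalizing profile variables.

```python
import numpy as np

rng = np.random.default_rng(42)
n_samples, n_levels = 500, 60

# Toy profile: mean temperature rises linearly from 200 K (top) to 290 K (surface),
# with 2 K of sample-to-sample noise at every level
mean_profile = np.linspace(200.0, 290.0, n_levels)
temp = mean_profile + rng.normal(0.0, 2.0, (n_samples, n_levels))

pooled_std = temp.std()         # One number over all samples and levels
level_mean = temp.mean(axis=0)  # Shape (n_levels,)
level_std = temp.std(axis=0)    # Shape (n_levels,)

print(f"pooled std: {pooled_std:.1f} K")
print(f"per-level std range: {level_std.min():.2f} to {level_std.max():.2f} K")
```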

In [None]:
# Compute statistics for all variables
print("=" * 70)
print("SAMPLE STATISTICS")
print("=" * 70)

print("\n📊 INPUT VARIABLES (State):")
print(f"{'Variable':<20} {'Mean':>15} {'Std':>15} {'Min':>15} {'Max':>15}")
print("-" * 80)

for var_name in sorted(input_vars.keys()):
    data = data_dict[var_name]
    mean_val = np.mean(data)
    std_val = np.std(data)
    min_val = np.min(data)
    max_val = np.max(data)
    
    print(f"{var_name:<20} {mean_val:>15.6f} {std_val:>15.6f} {min_val:>15.6f} {max_val:>15.6f}")

print("\n📈 OUTPUT VARIABLES (Tendencies):")
print(f"{'Variable':<20} {'Mean':>15} {'Std':>15} {'Min':>15} {'Max':>15}")
print("-" * 80)

for var_name in sorted(output_vars.keys()):
    data = data_dict[var_name]
    mean_val = np.mean(data)
    std_val = np.std(data)
    min_val = np.min(data)
    max_val = np.max(data)
    
    # Scientific notation, since tendencies are tiny per-second rates
    print(f"{var_name:<20} {mean_val:>15.3e} {std_val:>15.3e} {min_val:>15.3e} {max_val:>15.3e}")

print("\n💡 Notes on statistics:")
print("  - Input variables: Represent the physical atmospheric state")
print("  - Output tendencies: Per-second rates, typically many orders of magnitude smaller")
print("  - These statistics are crucial for normalization in ML training")

## 6. Check for Pre-Applied Normalization

In [None]:
# Check if data appears normalized
print("=" * 70)
print("NORMALIZATION CHECK")
print("=" * 70)

print("\nChecking if data is pre-normalized (mean~0, std~1)...\n")

def check_normalization(data, var_name):
    mean = np.mean(data)
    std = np.std(data)
    
    # Check if close to standard normal
    is_normalized = (abs(mean) < 0.5 and 0.5 < std < 1.5)
    
    status = "✅ Likely normalized" if is_normalized else "⚪ Not normalized (raw data)"
    print(f"{var_name:<20} mean={mean:>8.3f}, std={std:>8.3f}  {status}")
    
    return is_normalized

print("INPUT VARIABLES:")
input_normalized = []
for var_name in sorted(input_vars.keys()):
    is_norm = check_normalization(data_dict[var_name], var_name)
    input_normalized.append(is_norm)

print("\nOUTPUT VARIABLES:")
output_normalized = []
for var_name in sorted(output_vars.keys()):
    is_norm = check_normalization(data_dict[var_name], var_name)
    output_normalized.append(is_norm)

if any(input_normalized) or any(output_normalized):
    print("\n✅ Some variables appear pre-normalized")
    print("   Those may not need additional normalization; the others still do")
else:
    print("\n⚪ Variables appear to be raw (not normalized)")
    print("   You should normalize before ML training:")
    print("   - Standardization: (x - mean) / std")
    print("   - Min-max scaling: (x - min) / (max - min)")
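
Standardization also needs an inverse, so that normalized predictions can be turned back into physical units. A minimal per-level sketch (function names are illustrative): the `eps` floor is an assumption that guards against levels whose standard deviation is zero, for example a tendency that never changes at some level, and it should be chosen relative to the scale of the data.

```python
import numpy as np


def fit_standardizer(x, eps=1e-12):
    """Per-level mean and std over the sample axis, with std floored at eps."""
    mean = x.mean(axis=0)
    std = np.maximum(x.std(axis=0), eps)
    return mean, std


def standardize(x, mean, std):
    """Map physical values to zero mean and unit variance per level."""
    return (x - mean) / std


def destandardize(z, mean, std):
    """Invert standardize(): map normalized values back to physical units."""
    return z * std + mean


rng = np.random.default_rng(0)
y = rng.normal(0.0, 1e-5, (200, 60))  # Toy tendencies in K/s
y[:, 0] = 0.0  # A level that never changes: its std is 0 before flooring

mean, std = fit_standardizer(y)
z = standardize(y, mean, std)
y_back = destandardize(z, mean, std)

print(np.isfinite(z).all(), np.allclose(y_back, y))
```

Without the floor, the constant level would produce NaNs (0/0); with it, that level simply maps to zeros.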

## 7. Visualization: Temperature Vertical Profile

Visualize temperature as a function of height for a single atmospheric column.
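
The plot below marks the freezing point at the model level whose temperature is closest to 273.15 K. A finer estimate interpolates linearly in pressure between the two levels that bracket 273.15 K. This sketch assumes levels are ordered from model top (index 0) to surface, as in this notebook, scans upward from the surface for the first crossing, and returns `None` if the column never crosses freezing; the function name is illustrative.

```python
import numpy as np

T_FREEZE = 273.15  # Freezing point [K]


def freezing_level_pressure(temp, pressure):
    """Pressure where temp first crosses T_FREEZE, scanning up from the surface.

    temp, pressure: 1-D arrays ordered from model top (index 0) to surface.
    Interpolates linearly in pressure between the two bracketing levels.
    """
    temp = np.asarray(temp, dtype=float)
    pressure = np.asarray(pressure, dtype=float)
    for k in range(len(temp) - 1, 0, -1):  # k: lower level, k - 1: the level above
        t_lo, t_hi = temp[k], temp[k - 1]
        if (t_lo - T_FREEZE) * (t_hi - T_FREEZE) <= 0 and t_lo != t_hi:
            frac = (t_lo - T_FREEZE) / (t_lo - t_hi)
            return pressure[k] + frac * (pressure[k - 1] - pressure[k])
    return None


# Toy column: 5 levels from 200 hPa (top) to 1000 hPa (surface)
p = np.array([200.0, 400.0, 600.0, 800.0, 1000.0])
t = np.array([220.0, 245.0, 265.0, 280.0, 290.0])
print(freezing_level_pressure(t, p))
```

In the toy column the crossing lies between 800 hPa (280 K) and 600 hPa (265 K), a bit above 700 hPa.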

In [None]:
# Get temperature data
temp_data = data_dict['state_t']

# Select a random column to visualize
sample_idx = np.random.randint(0, temp_data.shape[0])
temp_profile = temp_data[sample_idx, :]

# Create figure
fig, ax = plt.subplots(1, 1, figsize=(8, 10))

# Plot temperature vs pressure (height proxy)
ax.plot(temp_profile, pressure_levels, 'b-', linewidth=2, marker='o', markersize=4)

# Formatting
ax.set_xlabel('Temperature (K)', fontsize=12, fontweight='bold')
ax.set_ylabel('Pressure (hPa)', fontsize=12, fontweight='bold')
ax.set_title(f'Vertical Temperature Profile\nSample #{sample_idx}',
             fontsize=14, fontweight='bold', pad=20)

# Invert y-axis (pressure decreases with height)
ax.invert_yaxis()

# Add grid
ax.grid(True, alpha=0.3, linestyle='--')

# Add annotations
ax.axhline(y=500, color='r', linestyle='--', alpha=0.5, label='~500 hPa (mid-troposphere)')
ax.axhline(y=200, color='orange', linestyle='--', alpha=0.5, label='~200 hPa (upper troposphere)')

# Add freezing point reference
if temp_profile.min() < 273.15 < temp_profile.max():
    # Find approximate altitude where temp = 273.15K
    freezing_idx = np.argmin(np.abs(temp_profile - 273.15))
    ax.axvline(x=273.15, color='cyan', linestyle=':', alpha=0.7, label='Freezing point (273.15 K)')
    ax.plot(273.15, pressure_levels[freezing_idx], 'c*', markersize=15)

ax.legend(loc='best', fontsize=10)

# Add text with statistics
stats_text = f'Mean: {np.mean(temp_profile):.2f} K\nStd: {np.std(temp_profile):.2f} K\nMin: {np.min(temp_profile):.2f} K\nMax: {np.max(temp_profile):.2f} K'
ax.text(0.02, 0.98, stats_text, transform=ax.transAxes, 
        verticalalignment='top', bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5),
        fontsize=9, family='monospace')

plt.tight_layout()
plt.show()

print(f"\n✅ Visualized temperature profile for sample #{sample_idx}")
print(f"   Temperature range: {temp_profile.min():.2f} - {temp_profile.max():.2f} K")
print(f"   Vertical levels: {len(temp_profile)}")

## 8. Visualization: Temperature Tendency Profile

Visualize how temperature changes over time (the target variable for ML prediction).
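
A tendency profile can also be summarized as one column-integrated heating rate, a common physical sanity check. Each layer holds $\Delta p_k / g$ kilograms of air per square metre, so the mass-weighted integral is $Q = \frac{c_p}{g}\sum_k \left(\frac{\partial T}{\partial t}\right)_k \Delta p_k$ in W m⁻². This sketch assumes layer thicknesses $\Delta p_k$ in Pa are available (with hybrid coordinates they would come from interface pressures, which this notebook does not load) and uses approximate constants $c_p = 1004$ J kg⁻¹ K⁻¹ and $g = 9.81$ m s⁻².

```python
import numpy as np

CP_DRY_AIR = 1004.0  # Specific heat of dry air at constant pressure [J kg-1 K-1]
GRAVITY = 9.81       # Gravitational acceleration [m s-2]


def column_heating(ptend_t, dp):
    """Mass-weighted column integral of a temperature tendency, in W m-2.

    ptend_t: (samples, levels) in K/s; dp: (levels,) or (samples, levels) in Pa.
    """
    return CP_DRY_AIR / GRAVITY * np.sum(ptend_t * dp, axis=-1)


# Toy check: 1 K/day of uniform heating over a 1000 hPa column split into 10 layers
dp = np.full(10, 10_000.0)                  # 10 layers of 100 hPa each [Pa]
heating = np.full((1, 10), 1.0 / 86_400.0)  # 1 K/day expressed in K/s
print(column_heating(heating, dp))          # Roughly 118 W m-2
```

The toy result shows why tendencies that look negligible in K/s still matter: about 1 K/day over a whole column corresponds to an energy flux of order 100 W m⁻².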

In [None]:
# Get temperature tendency data
temp_tend_data = data_dict['ptend_t']

# Use same sample as before for consistency
temp_tend_profile = temp_tend_data[sample_idx, :]

# Create figure
fig, ax = plt.subplots(1, 1, figsize=(8, 10))

# Plot tendency vs pressure
ax.plot(temp_tend_profile, pressure_levels, 'r-', linewidth=2, marker='s', markersize=4)

# Add zero line
ax.axvline(x=0, color='k', linestyle='-', alpha=0.3, linewidth=1)

# Formatting
ax.set_xlabel('Temperature Tendency (K/s)', fontsize=12, fontweight='bold')
ax.set_ylabel('Pressure (hPa)', fontsize=12, fontweight='bold')
ax.set_title(f'Vertical Temperature Tendency Profile\nSample #{sample_idx}',
             fontsize=14, fontweight='bold', pad=20)

# Invert y-axis
ax.invert_yaxis()

# Add grid
ax.grid(True, alpha=0.3, linestyle='--')

# Add shading for cooling/warming
warming_mask = temp_tend_profile > 0
cooling_mask = temp_tend_profile < 0

if np.any(warming_mask):
    ax.fill_betweenx(pressure_levels, 0, temp_tend_profile, 
                      where=warming_mask, alpha=0.2, color='red', label='Warming')
if np.any(cooling_mask):
    ax.fill_betweenx(pressure_levels, 0, temp_tend_profile, 
                      where=cooling_mask, alpha=0.2, color='blue', label='Cooling')

ax.legend(loc='best', fontsize=10)

# Add statistics box
stats_text = f'Mean: {np.mean(temp_tend_profile):.2e} K/s\nStd: {np.std(temp_tend_profile):.2e} K/s\nMin: {np.min(temp_tend_profile):.2e} K/s\nMax: {np.max(temp_tend_profile):.2e} K/s'
ax.text(0.02, 0.98, stats_text, transform=ax.transAxes, 
        verticalalignment='top', bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.5),
        fontsize=9, family='monospace')

# Add interpretation text
interp_text = 'Positive: Heating\nNegative: Cooling\n\nPhysical processes:\n• Radiation\n• Convection\n• Cloud formation'
ax.text(0.98, 0.02, interp_text, transform=ax.transAxes, 
        verticalalignment='bottom', horizontalalignment='right',
        bbox=dict(boxstyle='round', facecolor='lightyellow', alpha=0.5),
        fontsize=8)

plt.tight_layout()
plt.show()

print(f"\n✅ Visualized temperature tendency for sample #{sample_idx}")
print(f"   Tendency range: {temp_tend_profile.min():.2e} - {temp_tend_profile.max():.2e} K/s")
print(f"   Mean tendency: {np.mean(temp_tend_profile):.2e} K/s")
print("\n💡 Interpretation:")
print(f"   - Positive values: Warming (heating from radiation, convection)")
print(f"   - Negative values: Cooling (radiative cooling, evaporation)")
print(f"   - Goal of ML emulator: Predict these tendencies from state variables")

## 9. Multi-Variable Comparison

Compare input state and output tendency for multiple variables.
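
The plots below compare a single column by eye. A quick quantitative complement is the per-level Pearson correlation across samples between an input profile and a target profile. The sketch uses toy data in which the target is built to depend on the input at the last two levels only; on the synthetic fallback data above, inputs and outputs are generated independently, so correlations there should be near zero. The function name is illustrative.

```python
import numpy as np


def per_level_correlation(x, y):
    """Pearson correlation between x[:, k] and y[:, k] for every level k."""
    xa = x - x.mean(axis=0)
    ya = y - y.mean(axis=0)
    num = (xa * ya).sum(axis=0)
    den = np.sqrt((xa ** 2).sum(axis=0) * (ya ** 2).sum(axis=0))
    return num / den


rng = np.random.default_rng(1)
x = rng.normal(size=(1000, 4))
noise = rng.normal(size=(1000, 4))
# Toy target: tied to x at levels 2 and 3, unrelated to x at levels 0 and 1
y = np.where(np.arange(4) >= 2, 2.0 * x + 0.1 * noise, noise)

r = per_level_correlation(x, y)
print(np.round(r, 2))
```

Applied to `data_dict['state_t']` and `data_dict['ptend_t']`, this gives one correlation per level; a level where either variable is constant yields NaN.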

In [None]:
# Create side-by-side comparison
fig, axes = plt.subplots(1, 2, figsize=(14, 8))

# Left plot: Input state variables
ax = axes[0]
if 'state_t' in data_dict:
    ax.plot(data_dict['state_t'][sample_idx], pressure_levels, 'b-', label='Temperature (K)', linewidth=2)

if 'state_q0001' in data_dict:
    # Scale humidity for visibility
    q_scaled = data_dict['state_q0001'][sample_idx] * 1000  # g/kg
    ax2 = ax.twiny()
    ax2.plot(q_scaled, pressure_levels, 'g--', label='Specific humidity (g/kg)', linewidth=2)
    ax2.set_xlabel('Specific Humidity (g/kg)', fontsize=11, color='g')
    ax2.tick_params(axis='x', labelcolor='g')

ax.set_xlabel('Temperature (K)', fontsize=11, color='b')
ax.set_ylabel('Pressure (hPa)', fontsize=12, fontweight='bold')
ax.set_title('Input: Atmospheric State', fontsize=13, fontweight='bold')
ax.invert_yaxis()
ax.grid(True, alpha=0.3)
ax.tick_params(axis='x', labelcolor='b')

# Right plot: Output tendencies
ax = axes[1]
if 'ptend_t' in data_dict:
    ax.plot(data_dict['ptend_t'][sample_idx] * 3600, pressure_levels,  # Convert to K/hour
             'r-', label='Temp. tendency (K/hr)', linewidth=2)

if 'ptend_q0001' in data_dict:
    # Scale humidity tendency for visibility
    q_tend_scaled = data_dict['ptend_q0001'][sample_idx] * 3600 * 1000  # g/kg/hr
    ax2 = ax.twiny()
    ax2.plot(q_tend_scaled, pressure_levels, 'm--', label='Humidity tendency', linewidth=2)
    ax2.set_xlabel('Humidity Tendency (g/kg/hr)', fontsize=11, color='m')
    ax2.tick_params(axis='x', labelcolor='m')

ax.axvline(x=0, color='k', linestyle='-', alpha=0.3)
ax.set_xlabel('Temperature Tendency (K/hr)', fontsize=11, color='r')
ax.set_ylabel('Pressure (hPa)', fontsize=12, fontweight='bold')
ax.set_title('Output: Physical Tendencies', fontsize=13, fontweight='bold')
ax.invert_yaxis()
ax.grid(True, alpha=0.3)
ax.tick_params(axis='x', labelcolor='r')

plt.suptitle(f'ClimSim Data Analysis - Sample #{sample_idx}', 
             fontsize=15, fontweight='bold', y=0.98)
plt.tight_layout()
plt.show()

print("\n✅ Multi-variable comparison complete!")
print("\n💡 Machine Learning Task:")
print("   INPUT:  Atmospheric state (temperature, humidity, etc.)")
print("   OUTPUT: Physical tendencies (how state changes)")
print("   GOAL:   Train neural network to predict tendencies from state")

## Summary & Next Steps

### What We Learned

1. **Data Structure**: ClimSim contains atmospheric columns with vertical profiles
2. **Inputs**: State variables (temperature, humidity, winds, etc.)
3. **Outputs**: Tendencies (rates of change) to be predicted
4. **Vertical Levels**: Multiple atmospheric layers from surface to top
5. **Statistics**: Ranges and distributions of each variable
6. **Normalization**: Whether data is pre-normalized or needs preprocessing

### Next Steps for ML Emulator

1. **Data Preprocessing**
   - Normalize/standardize variables
   - Handle missing values
   - Create train/validation/test splits
   
2. **Model Architecture**
   - Design neural network (MLP, CNN, or Transformer)
   - Consider physical constraints
   - Account for vertical structure
   
3. **Training**
   - Define loss function (MSE, MAE, or physics-informed)
   - Use JAX/Flax for efficient training
   - Monitor validation metrics
   
4. **Evaluation**
   - Compare predictions vs targets
   - Check physical consistency
   - Test on unseen data
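
Several of these steps can be prototyped before any neural network exists. The sketch below makes a random train/validation split, fits a linear least-squares baseline $Y \approx [X, 1]\,W$ with `np.linalg.lstsq`, and reports the validation $R^2$ averaged over outputs; a trained emulator should beat this score. Toy data stand in for stacked ClimSim input/target matrices, and the 20% validation fraction is an arbitrary choice. For time-ordered climate data, splitting by time rather than at random helps avoid leakage between correlated neighboring samples.

```python
import numpy as np


def train_val_split(n, val_fraction, rng):
    """Random index split into (train, validation) index arrays."""
    idx = rng.permutation(n)
    n_val = int(round(n * val_fraction))
    return idx[n_val:], idx[:n_val]


def fit_linear(X, Y):
    """Least-squares weights for Y ~ [X, 1] @ W (bias column appended)."""
    Xb = np.hstack([X, np.ones((X.shape[0], 1))])
    W, *_ = np.linalg.lstsq(Xb, Y, rcond=None)
    return W


def predict_linear(X, W):
    """Apply weights from fit_linear() to new inputs."""
    return np.hstack([X, np.ones((X.shape[0], 1))]) @ W


def r2_score(Y_true, Y_pred):
    """Coefficient of determination, averaged over output columns."""
    ss_res = ((Y_true - Y_pred) ** 2).sum(axis=0)
    ss_tot = ((Y_true - Y_true.mean(axis=0)) ** 2).sum(axis=0)
    return float(np.mean(1.0 - ss_res / ss_tot))


rng = np.random.default_rng(0)
X = rng.normal(size=(2000, 8))
true_W = rng.normal(size=(8, 3))
Y = X @ true_W + 0.1 * rng.normal(size=(2000, 3))  # Toy linear targets with noise

train_idx, val_idx = train_val_split(len(X), 0.2, rng)
W = fit_linear(X[train_idx], Y[train_idx])
score = r2_score(Y[val_idx], predict_linear(X[val_idx], W))
print(f"validation R^2: {score:.3f}")
```

Real targets should first be standardized per level, since otherwise the least-squares fit and the averaged $R^2$ are dominated by whichever outputs have the largest raw magnitude.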

### Resources

- **ClimSim Paper**: [Link to paper if available]
- **Hugging Face Dataset**: https://huggingface.co/datasets/LEAP/ClimSim_low-res
- **LEAP Documentation**: Check hackathon materials

Good luck building your climate emulator! 🌍🚀