# Earth System Simulation - Google Colab Version

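This notebook installs the `earth_system_sim` package and runs a short Earth system simulation with debug output for its physical, biosphere, and geosphere subsystems. It then plots how the mean state of each subsystem evolves across the saved snapshots.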

In [None]:
# Install dependencies
!pip install torch numpy matplotlib pyyaml tqdm plotly

# Clone repository
!git clone https://github.com/yourusername/earth_system_sim.git
%cd earth_system_sim

# Install package
!pip install -e .

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import yaml
from pathlib import Path
import sys
from typing import Dict, List, Tuple, Optional
from tqdm.notebook import tqdm

# Make packages in the current working directory (e.g. `scripts`) importable
# by appending that directory to the module search path, at most once.
project_root = Path.cwd()
if not any(entry == str(project_root) for entry in sys.path):
    sys.path.append(str(project_root))

from scripts.run_simulation import EarthSystemSimulation

In [None]:
# Load the base configuration, check that every key this notebook reads is
# present, print a summary, and write the copy used by the Colab run.
BASE_CONFIG_PATH = Path('config/model_config.yaml')
COLAB_CONFIG_PATH = Path('config/colab_config.yaml')

with open(BASE_CONFIG_PATH, 'r', encoding='utf-8') as f:
    config = yaml.safe_load(f)

# (section, key) pairs read below; a section of None means a top-level key.
REQUIRED_CONFIG_KEYS = [
    (None, 'grid_height'),
    (None, 'grid_width'),
    ('physical_system', 'input_dim'),
    ('biosphere', 'state_dim'),
    ('biosphere', 'action_dim'),
    ('geosphere', 'state_dim'),
    ('geosphere', 'action_dim'),
]

# An empty YAML file parses to None; fail with a clear message instead of a
# TypeError on the first subscript.
if not isinstance(config, dict):
    raise ValueError(
        f"{BASE_CONFIG_PATH} did not parse to a mapping (got {type(config).__name__})"
    )

missing_config_keys = [
    key if section is None else f"{section}.{key}"
    for section, key in REQUIRED_CONFIG_KEYS
    if key not in (config if section is None else (config.get(section) or {}))
]
if missing_config_keys:
    raise KeyError(
        f"{BASE_CONFIG_PATH} is missing required keys: {', '.join(missing_config_keys)}"
    )

# Print configuration
print("Configuration:")
print("\nPhysical System:")
print(f"  Input dimension: {config['physical_system']['input_dim']}")
print(f"  Grid size: {config['grid_height']}x{config['grid_width']}")

for section, title in (('biosphere', 'Biosphere System'), ('geosphere', 'Geosphere System')):
    print(f"\n{title}:")
    print(f"  State dimension: {config[section]['state_dim']}")
    print(f"  Action dimension: {config[section]['action_dim']}")

# The config is currently written back unchanged; apply any Colab-specific
# overrides to `config` before this point. sort_keys=False keeps the key order
# of the source file so the copy stays easy to diff against it.
with open(COLAB_CONFIG_PATH, 'w', encoding='utf-8') as f:
    yaml.dump(config, f, sort_keys=False)

In [None]:
def print_tensor_info(name: str, tensor: torch.Tensor):
    """Print a short debug summary of a tensor.

    Prints the shape, dtype, device and ``requires_grad`` flag. For non-empty,
    non-complex tensors it also prints the min/max range. For floating-point
    tensors, the range is taken over finite entries only, and the number of
    NaN/Inf entries is reported. A diverging state then shows up as a count
    instead of a bare ``[nan, nan]`` range.

    Args:
        name: Label printed as the heading.
        tensor: Tensor to summarize. It is only read, never modified.
    """
    print(f"{name}:")
    print(f"  Shape: {tensor.shape}")
    print(f"  Dtype: {tensor.dtype}")
    print(f"  Device: {tensor.device}")
    print(f"  Requires grad: {tensor.requires_grad}")
    # min/max are not defined for complex tensors, so skip the range for them.
    if tensor.numel() > 0 and not tensor.is_complex():
        # detach() so the reductions never record autograd history.
        values = tensor.detach()
        if values.is_floating_point():
            finite_mask = torch.isfinite(values)
            num_nonfinite = int((~finite_mask).sum().item())
            if num_nonfinite:
                print(f"  Non-finite values: {num_nonfinite}/{values.numel()}")
            values = values[finite_mask]
        if values.numel() > 0:
            print(f"  Range: [{values.min().item():.3f}, {values.max().item():.3f}]")
    print()

In [None]:
# Labels used for debug printing and the trajectory keys, in the order in which
# the simulation returns its three subsystem states.
_STATE_LABELS = ("Physical", "Biosphere", "Geosphere")
_TRAJECTORY_KEYS = ('physical', 'biosphere', 'geosphere')


def _snapshot_state(state: torch.Tensor) -> np.ndarray:
    """Return a NumPy copy of ``state`` that later updates to the tensor cannot change."""
    # For CPU tensors ``.cpu()`` returns the same tensor and ``.numpy()`` shares
    # its memory, so ``.copy()`` is needed in case the simulation updates its
    # state tensors in place. ``.detach()`` makes tensors with
    # ``requires_grad=True`` acceptable, which ``.numpy()`` would otherwise reject.
    return state.detach().cpu().numpy().copy()


def _record_states(trajectory: Dict, sim, states) -> None:
    """Append snapshots of the three subsystem states and the current times."""
    import copy

    for key, state in zip(_TRAJECTORY_KEYS, states):
        trajectory[key].append(_snapshot_state(state))
    # Deep-copied in case ``current_times`` is a mutable object that the
    # synchronizer updates in place (its type is not visible here); otherwise
    # every saved entry could end up showing the final times.
    trajectory['times'].append(copy.deepcopy(sim.synchronizer.current_times))


def _print_states(title: str, states) -> None:
    """Print ``title`` followed by a debug summary of each available state."""
    print(title)
    for label, state in zip(_STATE_LABELS, states):
        if state is None:
            print(f"{label}: not available\n")
        else:
            print_tensor_info(label, state)


def run_simulation_with_debug(
    sim: EarthSystemSimulation,
    num_steps: int,
    save_frequency: int,
    debug_every: int = 10,
) -> Dict:
    """Run ``sim`` for ``num_steps`` timesteps, printing debug information.

    The initial states are always recorded. After that, states are recorded
    after every ``save_frequency``-th completed timestep. The final state is
    also recorded, even when ``num_steps`` is not a multiple of
    ``save_frequency``. For ``num_steps=50, save_frequency=5``, the trajectory
    holds the states after 0, 5, 10, ..., 50 timesteps.

    Args:
        sim: Simulation object. Only ``_initialize_states()``,
            ``run_timestep(physical, biosphere, geosphere)`` and
            ``synchronizer.current_times`` are used.
        num_steps: Number of timesteps to run.
        save_frequency: Record states every this many completed timesteps.
            Must be a positive integer.
        debug_every: Print a step header every this many completed timesteps.
            On those steps, also print state summaries when states were
            recorded. Must be a positive integer.

    Returns:
        Dict with keys ``'physical'``, ``'biosphere'`` and ``'geosphere'``
        (lists of NumPy arrays), plus ``'times'`` (a list of copies of
        ``sim.synchronizer.current_times``). Each list has one entry per
        recorded point.

    Raises:
        ValueError: If ``save_frequency`` or ``debug_every`` is not positive.
        Any exception raised during the run is re-raised after the most
        recent states are printed.
    """
    if save_frequency < 1:
        raise ValueError(f"save_frequency must be a positive integer, got {save_frequency}")
    if debug_every < 1:
        raise ValueError(f"debug_every must be a positive integer, got {debug_every}")

    # Bound up front so the error handler can report them even if the failure
    # happens before the first assignment.
    physical_state = biosphere_state = geosphere_state = None
    try:
        print("Starting simulation...")
        trajectory = {key: [] for key in (*_TRAJECTORY_KEYS, 'times')}

        # No gradients are needed anywhere in this run.
        with torch.no_grad():
            physical_state, biosphere_state, geosphere_state = sim._initialize_states()
            states = (physical_state, biosphere_state, geosphere_state)
            _print_states("\nInitial states:", states)
            _record_states(trajectory, sim, states)

            for step in tqdm(range(num_steps), desc='Simulation Progress'):
                physical_state, biosphere_state, geosphere_state = sim.run_timestep(
                    physical_state, biosphere_state, geosphere_state
                )
                states = (physical_state, biosphere_state, geosphere_state)

                # Count completed timesteps (not the 0-based loop index) so the
                # saved states are evenly spaced from the initial state.
                completed = step + 1
                saved = completed % save_frequency == 0
                if saved:
                    _record_states(trajectory, sim, states)
                if completed % debug_every == 0:
                    print(f"\nStep {completed}/{num_steps}:")
                    if saved:
                        _print_states("\nCurrent states:", states)

            # Keep the final state when the loop did not end on a save point.
            if num_steps > 0 and num_steps % save_frequency != 0:
                _record_states(trajectory, sim, states)

        print("\nSimulation completed successfully!")
        return trajectory

    except Exception as e:
        print(f"\nError during simulation: {str(e)}")
        # Report the most recent states that were assigned before the failure.
        try:
            _print_states(
                "\nState information at error:",
                (physical_state, biosphere_state, geosphere_state),
            )
        except Exception as info_error:
            print(f"Could not print state information: {info_error}")
        raise

In [None]:
# Run settings: tune these for longer or denser runs.
NUM_STEPS = 50
SAVE_FREQUENCY = 5
SEED = 42

# Seed every RNG before building the simulation so any random initialization
# (model weights, initial states) is reproducible across runs.
import random

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # also seeds CUDA devices

# Initialize simulation
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}")

sim = EarthSystemSimulation(
    config_path='config/colab_config.yaml',
    device=device
)
sim.debug = True  # Enable debug mode after initialization

# Run simulation with debug output
trajectory = run_simulation_with_debug(
    sim, num_steps=NUM_STEPS, save_frequency=SAVE_FREQUENCY
)

In [None]:
def _channel_mean_series(states: List[np.ndarray], channel: int) -> np.ndarray:
    """Return one value per saved state: feature ``channel`` (last axis) averaged over all other axes."""
    stacked = np.stack([np.asarray(state)[..., channel] for state in states])
    # Axis 0 indexes the saved snapshots; reduce every remaining axis so extra
    # leading dimensions (e.g. a batch dim) still give a single curve.
    return stacked.mean(axis=tuple(range(1, stacked.ndim)))


def plot_results(trajectory: Dict):
    """Plot the per-snapshot mean of one feature for each subsystem.

    Draws three stacked panels, one per subsystem. Each panel shows, for every
    recorded snapshot, the mean of a single feature (last axis) over all other
    axes. The min/max of each series is then printed.

    Feature indices (NOTE(review): the meanings below come from the labels in
    this notebook; confirm them against the model definition):
      * physical  -> index 1, labelled "Mean Temperature"
      * biosphere -> index 0, labelled "Mean Vegetation"
      * geosphere -> index 0, labelled "Mean Elevation"

    Args:
        trajectory: Dict holding lists of equally shaped per-snapshot arrays
            under ``'physical'``, ``'biosphere'`` and ``'geosphere'``, as
            returned by ``run_simulation_with_debug``.
    """
    # (trajectory key, feature index, legend label, line format, panel title,
    #  statistic name)
    panels = [
        ('physical', 1, 'Mean Temperature', 'r-', 'Physical System', 'Temperature'),
        ('biosphere', 0, 'Mean Vegetation', 'g-', 'Biosphere System', 'Vegetation'),
        ('geosphere', 0, 'Mean Elevation', 'b-', 'Geosphere System', 'Elevation'),
    ]

    # Reducing an empty series would fail (or warn and produce NaN), so bail out early.
    missing = [key for key, *_ in panels if len(trajectory.get(key, ())) == 0]
    if missing:
        print(f"No saved states for: {', '.join(missing)}; nothing to plot.")
        return

    fig, axes = plt.subplots(len(panels), 1, figsize=(12, 12))

    series = {}
    for ax, (key, channel, label, fmt, title, _) in zip(axes, panels):
        values = _channel_mean_series(trajectory[key], channel)
        series[key] = values
        ax.plot(values, fmt, label=label)
        ax.set_title(title)
        ax.set_xlabel('Saved snapshot index')
        ax.set_ylabel(label)
        ax.grid(True)
        ax.legend()

    plt.tight_layout()
    plt.show()

    # Print statistical information
    print("\nTrajectory Statistics:")
    for key, _, _, _, title, stat_name in panels:
        values = series[key]
        print(f"\n{title}:")
        print(f"  {stat_name} range: [{values.min():.3f}, {values.max():.3f}]")

# Render the per-subsystem mean curves and print their summary statistics.
plot_results(trajectory=trajectory)