# Earth System Simulation Example

This notebook demonstrates how to use the Earth system simulation framework to:
1. Set up and run a simulation
2. Visualize the results
3. Analyze component interactions
4. Inspect the physics conservation losses
5. Build an interactive Plotly visualization

In [None]:
import sys
from pathlib import Path
import torch
import numpy as np
import matplotlib.pyplot as plt
import yaml

# Add the project root (the parent of the current working directory) to the
# module search path so the `scripts.*` imports below can be resolved.
project_root = Path.cwd().parent

# Only append when the entry is missing: re-running this cell would otherwise
# keep adding duplicate copies of the same path to sys.path.
if str(project_root) not in sys.path:
    sys.path.append(str(project_root))

from scripts.run_simulation import EarthSystemSimulation
from scripts.visualize_results import create_visualizations

## 1. Setup

First, let's load the configuration and create a simulation instance.

In [None]:
# Choose the compute device first: CUDA when torch reports it as available,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Model configuration file, located in the project's `config` directory.
config_path = project_root / 'config' / 'model_config.yaml'

# Build the simulation from the config path (passed as a string) and the device.
sim = EarthSystemSimulation(str(config_path), device)

print(f"Using device: {device}")

## 2. Run Simulation

Now we'll run a short simulation and examine the results.

In [None]:
# Step the simulation forward and keep the returned trajectory; the cells
# below read their data from it.
# NOTE(review): presumably a snapshot is stored every `save_frequency` steps —
# confirm against `run_simulation`.
trajectory = sim.run_simulation(num_steps=100, save_frequency=10)

print("Simulation complete!")
saved_count = len(trajectory['times'])
print(f"Saved {saved_count} timesteps")

## 3. Basic Visualization

Let's create some basic visualizations of the simulation results.

In [None]:
# Parse the visualization settings with PyYAML's safe loader.
config_dir = project_root / 'config'
viz_config_path = config_dir / 'visualization_config.yaml'
with open(viz_config_path, 'r') as config_file:
    viz_config = yaml.safe_load(config_file)

# Directory handed to the visualization routine; an already-existing directory
# is reused rather than treated as an error.
output_dir = project_root / 'example_outputs'
output_dir.mkdir(exist_ok=True)

# Render the trajectory with the loaded settings. The output directory is
# passed as a plain string, and animations are requested as well.
create_visualizations(trajectory, viz_config, str(output_dir), make_animations=True)

## 4. Analyze Component Interactions

Let's examine how the different components interact over time.

In [None]:
# Extract the time axis and the per-component state arrays from the trajectory.
# NOTE(review): assumes every entry of trajectory['times'] is a mapping with a
# 'physical' key — confirm against `run_simulation`.
times = np.array([t['physical'] for t in trajectory['times']])
physical_states = np.array(trajectory['physical'])
bio_states = np.array(trajectory['biosphere'])
geo_states = np.array(trajectory['geosphere'])

# The state arrays are indexed as [snapshot, row, col, channel] (the
# interactive-visualization cell slices them as `[-1, :, :, 0]`). Averaging
# channel 0 over axes (1, 2) therefore gives one value per saved snapshot.
# The vegetation series previously used axis=0, which averaged over snapshots
# and produced a 2-D map that cannot be plotted against `times`.
temp_mean = np.mean(physical_states[..., 0], axis=(1, 2))
veg_mean = np.mean(bio_states[..., 0], axis=(1, 2))
elev_mean = np.mean(geo_states[..., 0], axis=(1, 2))

# One stacked panel per component, all sharing the time axis.
fig, axes = plt.subplots(3, 1, figsize=(12, 12), sharex=True)

# (series, line format, legend label, y-axis label) for each panel, top to bottom.
panels = [
    (temp_mean, 'r-', 'Mean Temperature', 'Temperature (K)'),
    (veg_mean, 'g-', 'Mean Vegetation', 'Vegetation Density'),
    (elev_mean, 'b-', 'Mean Elevation', 'Elevation (m)'),
]
for ax, (series, fmt, label, ylabel) in zip(axes, panels):
    ax.plot(times, series, fmt, label=label)
    ax.set_ylabel(ylabel)
    ax.legend()
    ax.grid(True)

# With sharex=True only the bottom panel needs the x-axis label.
axes[-1].set_xlabel('Time')

plt.tight_layout()
plt.show()

## 5. Examine Conservation Laws

Let's verify that physical conservation laws are being respected.

In [None]:
# Add a leading axis of length 1 to the stacked physical states and copy the
# result into a tensor on `device`.
# NOTE(review): the tensor inherits the NumPy array's dtype; if the model
# parameters use a different floating-point precision an explicit dtype may be
# needed — confirm.
batched_states = physical_states[None, ...]
physical_input = torch.tensor(batched_states, device=device)

# Evaluate the physical component with gradient tracking disabled. The call
# returns a pair; the second item is iterated below as a name -> loss mapping.
with torch.no_grad():
    predictions, physics_losses = sim.physical(physical_input)

# One bar per loss term. Each bar is drawn by its own `bar` call.
fig, ax = plt.subplots(figsize=(10, 6))

for loss_name, loss_tensor in physics_losses.items():
    loss_value = loss_tensor.cpu().item()
    ax.bar(loss_name, loss_value)

ax.set_title('Physics Conservation Losses')
ax.set_ylabel('Loss Value')
ax.grid(True)

plt.show()

## 6. Interactive Visualization

Finally, let's create an interactive visualization using Plotly.

In [None]:
import plotly.graph_objects as go
from plotly.subplots import make_subplots

# Build a 2x2 grid. Each cell has to declare the subplot type its trace needs:
# the heatmaps and the line plot are 2-D cartesian ('xy'), whereas go.Surface
# is a 3-D trace that can only live in a 'scene' cell. Without `specs` every
# cell defaults to 'xy' and adding the surface raises a ValueError.
fig = make_subplots(
    rows=2,
    cols=2,
    specs=[
        [{'type': 'xy'}, {'type': 'xy'}],
        [{'type': 'scene'}, {'type': 'xy'}],
    ],
    subplot_titles=('Temperature', 'Vegetation', 'Elevation', 'Mean Temperature'),
)

# Temperature heatmap: channel 0 of the last saved physical snapshot.
fig.add_trace(
    go.Heatmap(z=physical_states[-1, :, :, 0],
               colorscale='RdBu_r',
               name='Temperature'),
    row=1, col=1
)

# Vegetation heatmap: channel 0 of the last saved biosphere snapshot.
fig.add_trace(
    go.Heatmap(z=bio_states[-1, :, :, 0],
               colorscale='YlGn',
               name='Vegetation'),
    row=1, col=2
)

# Elevation surface: channel 0 of the last saved geosphere snapshot.
# 'terrain' is a matplotlib colormap name, not one of Plotly's built-in
# colorscales, and is rejected by Plotly's validation; 'Earth' is used instead.
fig.add_trace(
    go.Surface(z=geo_states[-1, :, :, 0],
               colorscale='Earth',
               name='Elevation'),
    row=2, col=1
)

# Mean-temperature time series (uses `times` and `temp_mean` computed in the
# component-interaction cell).
fig.add_trace(
    go.Scatter(x=times, y=temp_mean,
               name='Temperature',
               line=dict(color='red')),
    row=2, col=2
)

# Overall figure layout.
fig.update_layout(
    title='Earth System Components',
    height=800,
    width=1000,
    showlegend=True
)

fig.show()