# Samudra and SamudraAdjoint Demo
 
This notebook demonstrates how to use the Samudra ocean emulator model and its adjoint version for sensitivity analysis.

Samudra is the U-Net model designed by Dheeshjith et al. in "Samudra: An AI Global Ocean Emulator for Climate" (https://arxiv.org/abs/2412.03795). 

SamudraAdjoint is a subclass of Samudra, written by Shaunticlair Ruiz (the author of this document), that computes adjoints of the Samudra model.


## Setup
First, let's import the necessary libraries and set up the path to the Samudra directory. Note that this notebook requires the Samudra GitHub repository to already be installed elsewhere. (This is a leftover from an older version of the notebook; I plan to fix it in a future version.)



Imports:

In [None]:
import sys
import numpy as np
import torch
import xarray as xr
import matplotlib.pyplot as plt
from pathlib import Path

# Import setup utilities and model adjoint module
import setup
import model_adjoint



Next, we add the Samudra directory to the path and set up the PyTorch device:

In [None]:
# Add the Samudra package to the path
# Replace this with your actual path to the Samudra directory
samudra_path = Path("./")  # Assuming the notebook is in the same directory as setup.py
sys.path.append(str(samudra_path))

# Configure the environment for CUDA or CPU and set random seeds for reproducibility
device = setup.torch_config_cuda_cpu_seed()
print(f"Using device: {device}")

Finally, some settings that determine which kind of model we're using.

In [None]:
# Set model parameters
hist = 1  # History length. DO NOT EDIT THIS VALUE.
N_test = 40  # Number of timesteps to use for testing
state_in_vars_config = "3D_thermo_dynamic_all"  # Options: "3D_thermo_all" or "3D_thermo_dynamic_all"
boundary_vars_config = "3D_all_hfds_anom"

## Loading
Now, we'll load up our dataset and model.



There are two model configurations, `3D_thermo_all` and `3D_thermo_dynamic_all`. Based on the configuration chosen above, we select the matching state variables and boundary variables.

In [None]:
# Get the variable lists and channel counts for our model configuration
list_list_str, list_num_channels = setup.choose_model(
    state_in_vars_config, 
    boundary_vars_config, 
    hist
)

# Unpack the data
input_list_str, boundary_list_str, output_list_str, vars_list_str = list_list_str
num_input_channels, num_output_channels = list_num_channels

print(f"Model will use {num_input_channels} input channels and {num_output_channels} output channels")
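With `hist = 1` the model works on `hist + 1 = 2` consecutive states, so the output channel count is `(hist + 1)` times the number of state channels per timestep; the rollout call later recovers that per-timestep count as `num_output_channels // (hist + 1)`. The sketch below shows the resulting index arithmetic for locating one variable in a given time slot. The helper `state_channel` is hypothetical, and the stacking order is an assumption, chosen to be consistent with surface temperature sitting at channel 0 for t=0 and channel 77 for t=1 in the `3D_thermo_dynamic_all` configuration.

```python
def state_channel(var_index: int, time_slot: int, n_state: int) -> int:
    """Hypothetical helper: channel index of state variable `var_index` in
    time slot `time_slot`, assuming the hist + 1 time slots are stacked one
    after another and each slot holds `n_state` state channels."""
    if not 0 <= var_index < n_state:
        raise ValueError("var_index must lie in [0, n_state)")
    return time_slot * n_state + var_index


# Assumed example value: 77 state channels per time slot, with variable 0
# being potential temperature at 2.5 m depth.
demo_n_state = 77
print(state_channel(0, 0, demo_n_state))  # first time slot  -> 0
print(state_channel(0, 1, demo_n_state))  # second time slot -> 77
```

In the notebook itself, `n_state` would be `num_output_channels // (hist + 1)`, the same quantity passed as `N_out` to `generate_model_rollout`, so the index adapts automatically when the configuration changes.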

Next, we load the dataset and select the desired test subsequence.

In [None]:
# Compute data indices
s_train, e_train, s_test, e_test = setup.compute_indices(
    hist=hist, 
    N_samples=2850,  # Used for training
    N_val=50,        # Used for validation 
    N_test=N_test    # Used for testing
)

# Load the data into a Test object
# The Test object handles normalization of the data using pre-computed means and standard deviations
print(f"Loading data from time indices {s_test} to {e_test}...")
test_data, wet, data_mean, data_std = setup.load_data(
    s_test, e_test, N_test,
    input_list_str, boundary_list_str, output_list_str,
    hist=hist, device=device
)
print("Data loaded successfully!")

# The 'wet' mask indicates ocean vs. land areas (1 for ocean, 0 for land)
print(f"Wet mask shape: {wet.shape}")
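The comment above says the Test object normalizes each field with pre-computed means and standard deviations, and the wet mask separates ocean (1) from land (0). A minimal sketch of that kind of masked z-score normalization and its inverse follows. The toy values, the helper names, and the choice to fill land with 0 (normalized units) or NaN (physical units) are illustrative assumptions, not the Test object's actual code.

```python
import numpy as np


def normalize(field, mean, std, wet_mask):
    """Standardize ocean points and set land points (wet == 0) to 0."""
    z = (field - mean) / std
    return np.where(wet_mask.astype(bool), z, 0.0)


def denormalize(z, mean, std, wet_mask):
    """Invert normalize() on ocean points and mark land points as NaN."""
    field = z * std + mean
    return np.where(wet_mask.astype(bool), field, np.nan)


# Toy 2x2 "temperature" field with one land point (NaN), made-up statistics
toy_field = np.array([[10.0, 12.0], [14.0, np.nan]])
toy_wet = np.array([[1, 1], [1, 0]])

toy_z = normalize(toy_field, mean=12.0, std=2.0, wet_mask=toy_wet)
print(toy_z)
print(denormalize(toy_z, 12.0, 2.0, toy_wet))
```

Keeping land at 0 in normalized space avoids NaNs propagating through a network, while restoring NaN after un-normalizing lets plotting code mask land automatically.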



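Finally, we build both the standard Samudra model and the SamudraAdjoint model with the same configuration and load the pretrained weights for `state_in_vars_config` into each.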

In [None]:
# Import Samudra model class
from model import Samudra

# Initialize the standard Samudra model
samudra_model = Samudra(
    n_out=num_output_channels, 
    ch_width=[num_input_channels]+[200,250,300,400], 
    wet=wet.to(device), 
    hist=hist
)

# Initialize the SamudraAdjoint model (extends Samudra with adjoint capabilities)
samudra_adjoint = model_adjoint.SamudraAdjoint(
    n_out=num_output_channels,
    ch_width=[num_input_channels]+[200,250,300,400],
    wet=wet.to(device),
    hist=hist
)

# Load weights into both models
print("Loading model weights...")
samudra_model = setup.load_weights(samudra_model, state_in_vars_config, device=device)
samudra_adjoint = setup.load_weights(samudra_adjoint, state_in_vars_config, device=device)
print("Model weights loaded successfully!")

# Brief summary of model functionalities
print("\nModel Summary:")
print("- Samudra: A deep learning ocean emulator that can simulate ocean dynamics")
print("  forward in time, predicting temperature, salinity, and sea surface height")
print("  (and optionally ocean velocities).")
print("- SamudraAdjoint: Extends Samudra with adjoint capabilities for sensitivity")
print("  analysis, allowing computation of how changes in initial conditions")
print("  propagate to affect future ocean states.")
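In adjoint sensitivity analysis, the gradient of a scalar diagnostic J of a future state with respect to the initial state is obtained by a single backward sweep that applies the transposed Jacobian of each model step, instead of perturbing every input separately. The toy sketch below illustrates the idea with a made-up linear step x_{t+1} = A x_t, whose Jacobian is simply A, and checks the adjoint gradient against finite differences. It illustrates the technique only; it is not how SamudraAdjoint is implemented.

```python
import numpy as np

rng = np.random.default_rng(0)
n_points, toy_steps = 4, 10
A = 0.3 * rng.standard_normal((n_points, n_points))  # toy linear step: x_{t+1} = A @ x_t
x0 = rng.standard_normal(n_points)                     # toy initial state
target = 2                                             # J = x_T[target], one grid point


def forward(x, steps):
    """Run the toy model forward `steps` times."""
    for _ in range(steps):
        x = A @ x
    return x


# Adjoint sweep: start from dJ/dx_T (a unit vector) and apply A^T once per step.
adjoint = np.zeros(n_points)
adjoint[target] = 1.0
for _ in range(toy_steps):
    adjoint = A.T @ adjoint  # after the loop this holds dJ/dx0

# Reference: central finite differences, perturbing one input at a time.
eps = 1e-6
fd_grad = np.zeros(n_points)
for i in range(n_points):
    dx = np.zeros(n_points)
    dx[i] = eps
    fd_grad[i] = (forward(x0 + dx, toy_steps)[target]
                  - forward(x0 - dx, toy_steps)[target]) / (2 * eps)

print("max |adjoint - finite difference|:", np.max(np.abs(adjoint - fd_grad)))
```

The adjoint sweep costs one backward pass regardless of how many inputs there are, whereas finite differences need two forward runs per input. For a nonlinear network, A is replaced by the Jacobian of each step evaluated along the forward trajectory, which reverse-mode automatic differentiation supplies without forming the matrix explicitly.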

## Running
Now we'll use the Samudra model to simulate the ocean state forward by 10 timesteps and visualize the initial and final states.


Here, we run the model forward.

In [None]:

# Import model function for generating rollouts
from model import generate_model_rollout

# Number of timesteps to run the model forward
n_steps = 10

# Run the model forward to generate a rollout
print(f"Generating a {n_steps}-step rollout...")
model_pred, model_outs = generate_model_rollout(
    n_steps,
    test_data,
    samudra_model,
    hist,
    num_output_channels // (hist + 1),  # N_out
    len(boundary_list_str),             # N_extra
    initial_input=None,
    device=device
)
print("Rollout completed successfully!")
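The internals of `generate_model_rollout` are not shown in this notebook. As a rough mental model, assume that with `hist = 1` each model call consumes two consecutive states plus boundary forcing, predicts the next two states, and then has its own prediction fed back in as the next input. The toy sketch below illustrates that autoregressive loop; `toy_model` is a made-up damped-persistence rule standing in for the network, and the exact bookkeeping inside Samudra's rollout may differ.

```python
import numpy as np


def toy_model(inputs):
    """Hypothetical stand-in for the network: maps (x_{t-1}, x_t, forcing)
    to the stacked pair [x_{t+1}, x_{t+2}]. x_{t-1} is unused by this toy rule."""
    x_prev, x_curr, forcing = inputs
    x_next = 0.9 * x_curr + 0.1 * forcing
    x_next2 = 0.9 * x_next + 0.1 * forcing
    return np.stack([x_next, x_next2])


def toy_rollout(x_init, forcings, n_calls):
    """Autoregressive rollout: each call's output becomes the next call's state input."""
    states = np.asarray(x_init)  # shape (2, ...): the hist + 1 initial states
    outputs = []
    for k in range(n_calls):
        pred = toy_model((states[0], states[1], forcings[k]))
        outputs.append(pred)
        states = pred  # feed the prediction back in
    return np.concatenate(outputs)  # shape (2 * n_calls, ...)


x_init = np.zeros((2, 3))    # two initial states on a 3-point "grid"
forcings = np.ones((5, 3))   # constant boundary forcing for 5 calls
traj = toy_rollout(x_init, forcings, n_calls=5)
print(traj.shape)  # (10, 3)
```

Because predictions are fed back in, errors can accumulate over the rollout, which is one reason to compare predicted states against ground truth at several timesteps rather than only the first.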

Data processing:

In [None]:
# Convert model predictions to xarray Dataset for easier visualization
from utils import post_processor, convert_train_data

# Create DataArray from model predictions
ds_prediction = xr.DataArray(
    data=model_pred,
    dims=["time", "x", "y", "var"]
)
ds_prediction = ds_prediction.to_dataset(name="predictions")

# Get ground truth data
ds_groundtruth = test_data.inputs.isel(time=slice(hist+1, hist+1+n_steps))
ds_groundtruth = convert_train_data(ds_groundtruth)

# Post-process predictions to match ground truth format
ds_prediction = post_processor(ds_prediction, ds_groundtruth, vars_list_str)
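The ground-truth slice above starts at time index `hist + 1` because the first `hist + 1` timesteps are the model's initial input. Assuming the post-processed predictions follow the same time axis, absolute timestep t maps to prediction index `t - (hist + 1)`, which is `t - 2` for `hist = 1`. A small hypothetical helper makes that mapping explicit and guards against asking for timesteps outside the rollout:

```python
def prediction_index(t_abs: int, hist: int, n_steps: int) -> int:
    """Hypothetical helper: map an absolute timestep to an index on the
    prediction's time axis, assuming predictions start at t = hist + 1
    (as the ground-truth slice does) and contain n_steps entries."""
    idx = t_abs - (hist + 1)
    if not 0 <= idx < n_steps:
        raise IndexError(f"timestep {t_abs} is outside the predicted window")
    return idx


print(prediction_index(9, hist=1, n_steps=10))  # 7
```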

Finally, we plot the results.

In [None]:
# Now let's plot the initial and final states of potential temperature at 2.5 m depth
# First, define a helper function for plotting
def plot_ocean_temperature(ax, data, title, cmap='viridis', vmin=None, vmax=None):
    """Plot a 2D ocean temperature field, masking land (NaN) points."""
    # Mask land areas (NaN values in the data)
    masked_data = np.ma.masked_invalid(data)

    # Plot the data
    im = ax.imshow(masked_data, origin='lower', cmap=cmap, vmin=vmin, vmax=vmax)
    ax.set_title(title)

    # Return the image so a colorbar can be attached
    return im

# Number of state channels per time slot in the model input
# (the same quantity passed as N_out to generate_model_rollout)
n_state = num_output_channels // (hist + 1)

def get_surface_temperature(t):
    """Return (field, title) for potential temperature at 2.5 m depth at timestep t."""
    if t < hist + 1:
        # Initial states come from the model input: time slot t occupies channels
        # [t * n_state, (t + 1) * n_state), and the first channel of each slot is
        # potential temperature at 2.5 m depth (channels 0 and 77 for 3D_thermo_dynamic_all)
        temp = test_data[0][0][0, t * n_state].cpu().numpy()
        title = f"Initial State (t={t})"
    else:
        # Later timesteps come from the model output, which starts at t = hist + 1
        temp = ds_prediction.thetao.isel(time=t - (hist + 1), lev=0).values
        title = f"Final State (t={t})"
    return temp, title

timesteps = [0, 1, 8, 9]
fields = [get_surface_temperature(t) for t in timesteps]

# Compute global min and max over non-NaN values for a consistent colormap
vmin = min(np.nanmin(temp) for temp, _ in fields)
vmax = max(np.nanmax(temp) for temp, _ in fields)

# Create a 2x2 grid of subplots on a single figure
fig, axes = plt.subplots(2, 2, figsize=(16, 10))
for ax, (temp, title) in zip(axes.flat, fields):
    im = plot_ocean_temperature(ax, temp, title, cmap='viridis', vmin=vmin, vmax=vmax)

# Leave room at the bottom, then add a shared horizontal colorbar to the same figure
fig.tight_layout(rect=[0, 0.08, 1, 1])
cbar_ax = fig.add_axes([0.15, 0.03, 0.7, 0.02])
fig.colorbar(im, cax=cbar_ax, orientation='horizontal', label='Potential Temperature at 2.5m depth (°C)')
plt.show()

# Save the timestep=9 output for comparison with adjoint results
final_state_t9 = ds_prediction.thetao.isel(time=7, lev=0).values
np.save('final_state_t9.npy', final_state_t9)

print("Plotting complete! We can see how the model simulates ocean temperature evolution.")
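All panels share one colour scale, obtained by taking the minimum and maximum over every field while skipping NaN (land) points. A standalone sketch of that computation with `np.nanmin` and `np.nanmax`, using toy arrays in place of the temperature maps; the helper name `shared_color_limits` is hypothetical:

```python
import numpy as np


def shared_color_limits(fields):
    """Return (vmin, vmax) spanning all non-NaN values of several 2D fields,
    so every panel uses the same colour scale."""
    vmin = min(float(np.nanmin(f)) for f in fields)
    vmax = max(float(np.nanmax(f)) for f in fields)
    return vmin, vmax


field_a = np.array([[1.0, np.nan], [3.0, 4.0]])
field_b = np.array([[-2.0, 0.5], [np.nan, 2.0]])
print(shared_color_limits([field_a, field_b]))  # (-2.0, 4.0)
```

Without shared limits, each `imshow` call would rescale its own colormap and visually similar panels could hide real temperature differences between timesteps.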