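Earth2Studio installation guide: https://nvidia.github.io/earth2studio/userguide/about/install.html

Install with uv (these commands assume a csh/tcsh shell; in bash/zsh activate the environment with `source .venv/bin/activate` instead):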
```csh
mkdir earth2studio-pangu && cd earth2studio-pangu
uv init --python=3.12
uv add "earth2studio @ git+https://github.com/NVIDIA/earth2studio.git"
uv add earth2studio --extra pangu
uv add earth2studio --extra perturbation
uv add earth2studio --extra data
source .venv/bin/activate.csh
uv pip install matplotlib ipykernel cartopy ipywidgets
python -m ipykernel install --user  --name="earth2studio-pangu"
```

In [None]:
import os
import sys
import tempfile
from pathlib import Path

os.environ["LOGURU_LEVEL"] = "INFO"  # Change to "DEBUG", "WARNING", "ERROR", etc.
import cartopy.crs as ccrs
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import xarray as xr
from earth2studio.data import GFS, IFS, fetch_data
from earth2studio.io import NetCDF4Backend
from earth2studio.models.px import Pangu6
from earth2studio.perturbation import (
    BredVector,
    CorrelatedSphericalGaussian,
    Gaussian,
    HemisphericCentredBredVector,
    SphericalGaussian,
    Zero,
)
from earth2studio.run import deterministic, ensemble
from earth2studio.utils.coords import map_coords
from loguru import logger

In [None]:
# Pangu6 is the prognostic model used here. FengWu was tried first, but with the
# bred-vector perturbation it produced:
#   WARNING  | earth2studio.perturbation.bv:__call__:97 - Input data / models that require
#            multiple lead times may lead to unexpected behavior
#   RuntimeError: Error in execution: Non-zero status code returned while running MatMul node.
#            Name:'/enc/enc_list.0/layers.0/blocks.0/attn/proj/MatMul'
#            Status Message: matmul_helper.h:59 Compute MatMul dimension mismatch
Model = Pangu6

# Resolve the model's default package, then build the model from it
package = Model.load_default_package()
model = Model.load_model(package)

In [None]:
# IPython shell escape: show the NVIDIA GPU status to confirm a CUDA device is visible
# before the model is moved to "cuda" in the cells below.
!nvidia-smi

In [None]:
# Initial condition: GFS data at the forecast start, restricted to the model's
# input variables and lead times. fetch_data also returns the matching coordinates.
forecast_time = pd.Timestamp("20240510")
coords = model.input_coords()
wanted_variables, wanted_lead_times = coords["variable"], coords["lead_time"]

control, coords = fetch_data(
    source=GFS(),
    time=[forecast_time],
    variable=wanted_variables,
    lead_time=wanted_lead_times,
)

In [None]:
# Per-variable noise amplitude: spatial std of each field (last two axes), scaled by 1e-8,
# with two trailing singleton axes appended so it broadcasts over the grid.
noise_tensor = (control.std(dim=[-2, -1]) * 1e-8)[..., None, None]

In [None]:
# Amplitude used by the HCBV and its seeding noise. It must be defined before the
# HCBV is constructed below, so this cell works on a fresh kernel.
amplitude = 0.05

device = torch.device("cuda")
# Move the model to the GPU once and reuse it for both bred-vector perturbations
model_gpu = model.to(device)

# Bred vector seeded with the small per-variable noise in `noise_tensor`
bv = BredVector(
    model=model_gpu,
    noise_amplitude=noise_tensor,
    integration_steps=20,
    ensemble_perturb=False,
)
# Hemispheric-centred bred vector, seeded with correlated spherical Gaussian noise
hcbv = HemisphericCentredBredVector(
    model_gpu,
    noise_amplitude=amplitude,
    data=GFS(),
    seeding_perturbation_method=CorrelatedSphericalGaussian(noise_amplitude=amplitude),
)

In [None]:
# --- 1. Setup Control State ---
# A single geopotential level is used because it shows the perturbation structure clearly.
target_variable = "z700"

# Fix the RNG so the stochastic perturbations below are reproducible between runs
torch.manual_seed(0)

# Boolean mask selecting the target variable along the "variable" axis (axis 2 of `control`)
var_mask = coords["variable"] == target_variable

# Slice the control state to just the target variable; the variable axis is kept with size 1
x_control = control[:, :, var_mask]
# Coordinate system matching the sliced tensor, so every perturbation receives a
# consistent (x, coords) pair instead of a 1-variable tensor with all-variable coords
coords_var = coords.copy()
coords_var["variable"] = coords["variable"][var_mask]

# Base control on CPU for plotting and for the noise-only perturbations
x_control_cpu = x_control.cpu()

amplitude = 0.05

print("Applying Gaussian Perturbation...")
pert_gauss = Gaussian(noise_amplitude=amplitude)
x_gauss, _ = pert_gauss(x_control_cpu, coords_var)

print("Applying SphericalGaussian Perturbation...")
pert_sphere = SphericalGaussian(noise_amplitude=amplitude)
x_sphere, _ = pert_sphere(x_control_cpu, coords_var)

print("Applying CorrelatedSphericalGaussian Perturbation...")
pert_corr = CorrelatedSphericalGaussian(noise_amplitude=amplitude)
x_corr, _ = pert_corr(x_control_cpu, coords_var)

print("Applying BredVector Perturbation...")
# The bred vector runs the model, which needs the FULL state (all variables), so the
# unsliced `control`/`coords` go to the GPU and the target variable is sliced out afterwards.
x_control_gpu = control.to(device)
x_bv_full, _ = bv(x_control_gpu, coords)
x_bv = x_bv_full[:, :, var_mask].cpu()

print("Applying HemisphericCentredBredVector Perturbation...")
# Same pattern as above: perturb the full state, then keep only the target variable
x_hcbv_full, _ = hcbv(x_control_gpu, coords)
x_hcbv = x_hcbv_full[:, :, var_mask].cpu()

# --- 2. Calculate Noise Fields ---
# Each difference keeps the 5-D layout of `x_control` (axes ordered like `coords`,
# with a size-1 variable axis at index 2).
noise_gauss = x_gauss - x_control_cpu
# SphericalGaussian is amplified 10x so it is visible on the shared colour scale
noise_sphere = 10 * (x_sphere - x_control_cpu)
noise_corr = x_corr - x_control_cpu
noise_bv = x_bv - x_control_cpu
noise_hcbv = x_hcbv - x_control_cpu

# --- 3. Visualize in Orthographic Projection ---
print("Plotting...")

def get_field(tensor):
    """Return the first 2-D slice (last two axes) of a tensor as a NumPy array.

    Every leading axis is indexed at 0 -- for the 5-D tensors in this notebook that is
    ``tensor[0, 0, 0]``. The last two axes are assumed to be (lat, lon).
    The slice is detached and copied to the CPU, so GPU tensors and tensors that
    require grad are accepted as well.
    """
    # Collapse all leading axes into one; index 0 is then the all-zeros leading index
    flat = tensor.reshape(-1, *tensor.shape[-2:])
    return flat[0].detach().cpu().numpy()

# Panel index (into the flattened axes grid) -> (title, noise tensor) for the noise maps
noise_data_map = {
    1: ("Gaussian (White)", noise_gauss),
    2: ("SphericalGaussian (x10)", noise_sphere),
    3: ("CorrelatedSpherical (Red)", noise_corr),
    4: ("Bred Vector (Flow)", noise_bv),
    5: ("Hemispheric Centred Bred Vector (Flow)", noise_hcbv),
}

# 2x4 grid (8 slots): slot 0 holds the control, slots 1-5 the noise maps,
# and the two remaining slots (6, 7) are left empty.
fig, axes = plt.subplots(
    2, 4, figsize=(19, 9),
    subplot_kw={'projection': ccrs.Orthographic(central_latitude=45, central_longitude=-55)}
)
# Flatten so panels can be addressed with a single index
axes = axes.flatten()

# Shared map decoration used by every map panel in this notebook
def setup_map(ax):
    """Draw faint coastlines (110m resolution) and gridlines on a cartopy GeoAxes."""
    coast_style = dict(resolution='110m', color='black', alpha=0.4)
    grid_style = dict(color='gray', alpha=0.3)
    ax.coastlines(**coast_style)
    ax.gridlines(**grid_style)

# --- Plot 0: The Control State (Absolute Values) ---
ax_ctrl = axes[0]
setup_map(ax_ctrl)
field_ctrl = get_field(x_control_cpu)

# cartopy's imshow defaults `extent` to the full limits of `transform` (lon -180..180 for
# PlateCarree). That misplaces data whose longitudes start elsewhere (presumably 0..360
# here -- confirm). Pass the data's own bounds, exactly as the noise panels do.
ctrl_extent = [coords['lon'].min(), coords['lon'].max(), coords['lat'].min(), coords['lat'].max()]

im_ctrl = ax_ctrl.imshow(
    field_ctrl, transform=ccrs.PlateCarree(), origin='upper', extent=ctrl_extent,
    cmap='viridis', vmin=field_ctrl.min(), vmax=field_ctrl.max()
)
ax_ctrl.set_title(f"Control (GFS {target_variable})")
# Small colorbar just for control
fig.colorbar(im_ctrl, ax=ax_ctrl, fraction=0.046, pad=0.04, orientation='horizontal')

# --- Plots 1-5: The Noise Fields (Diverging Scale) ---

# Shared, symmetric colour scale so the noise panels are directly comparable.
# torch.cat along dim 0 needs identical trailing shapes, which holds because every
# noise tensor was sliced to the single target variable.
all_noise_tensors = torch.cat([noise_gauss, noise_sphere, noise_corr, noise_bv, noise_hcbv])
max_abs_noise = torch.abs(all_noise_tensors).max().item()
noise_vmin = -max_abs_noise
noise_vmax = max_abs_noise

# ".3g" keeps small amplitudes readable (".1f" would print 0.0 for |noise| < 0.05)
print(f"Noise color scale set to: +/- {noise_vmax:.3g}")

# Image bounds (left, right, bottom, top) in degrees, taken from the data's coordinates
noise_extent = [coords['lon'].min(), coords['lon'].max(), coords['lat'].min(), coords['lat'].max()]

for i, (title, tensor) in noise_data_map.items():
    ax = axes[i]
    setup_map(ax)
    field = get_field(tensor)

    im_noise = ax.imshow(
        field, transform=ccrs.PlateCarree(), origin='upper',
        extent=noise_extent,
        cmap='RdBu_r', vmin=noise_vmin, vmax=noise_vmax
    )
    ax.set_title(title)

# --- Clean up and Shared Colorbar ---
# 6 panels (control + 5 noise maps) in an 8-slot grid: hide the last two (bottom right)
axes[6].axis('off')
axes[7].axis('off')

# Shared colorbar for the noise panels, placed in the empty bottom-right area
cbar_ax = fig.add_axes([0.75, 0.15, 0.02, 0.3])
fig.colorbar(im_noise, cax=cbar_ax, label='Perturbation Magnitude')

plt.suptitle(f"Perturbation Method Comparison ({forecast_time.date()})", y=0.98, fontsize=14)
plt.show()

In [None]:
# Standard deviation of each noise field (over all of its elements), labelled by method.
# Gaussian is included so all five methods can be compared side by side.
{
    "Gaussian": noise_gauss.std().item(),
    "SphericalGaussian (x10)": noise_sphere.std().item(),
    "CorrelatedSphericalGaussian": noise_corr.std().item(),
    "BredVector": noise_bv.std().item(),
    "HemisphericCentredBredVector": noise_hcbv.std().item(),
}

In [None]:
# Perturb the control state with the HCBV, then step the model forward from it.
model_gpu = model.to(device)

xi, _ = hcbv(control.to(device), coords)
x, out_coords = model_gpu(xi, model.input_coords())

# Additional model steps after the first one
n_extra_steps = 1
for _step in range(n_extra_steps):
    x, out_coords = model_gpu(x, out_coords)

# The batch coordinate is empty in the model's coordinate system; label it explicitly
out_coords["batch"] = [0]

In [None]:
# Wrap the model forecast `x` in a labelled DataArray and map its z500 field
da = xr.DataArray(x.cpu().numpy(), coords=out_coords)

z500_forecast = da.sel(variable='z500')
globe_view = ccrs.Orthographic(central_latitude=45, central_longitude=-55)
p = z500_forecast.plot(
    subplot_kws={"projection": globe_view},
    transform=ccrs.PlateCarree(),
    figsize=(10, 10),
)

setup_map(p.axes)

In [None]:
amplitude = 0.05

# Output directory: $SCRATCH/tmp when SCRATCH is set, otherwise <system temp dir>/tmp.
# Created if missing so NetCDF4Backend can write into it.
odir = Path(os.getenv("SCRATCH") or tempfile.gettempdir()) / "tmp"
odir.mkdir(parents=True, exist_ok=True)

# Existing outputs smaller than this are treated as incomplete and regenerated
MIN_COMPLETE_BYTES = 110_000

# SphericalGaussian gets 10x the amplitude given to CorrelatedSphericalGaussian;
# otherwise its perturbation is much smaller than the others.
perturbations = [Zero(), hcbv, bv, CorrelatedSphericalGaussian(amplitude), SphericalGaussian(amplitude * 10)]
for perturbation in perturbations:
    ofile = odir / f"{Model.__name__}.{perturbation.__class__.__name__}.nc"
    print(ofile)
    if ofile.exists():
        if ofile.stat().st_size < MIN_COMPLETE_BYTES:
            ofile.unlink()
        else:
            continue  # complete result already on disk
    io_backend = NetCDF4Backend(ofile)
    completed = False
    try:
        ensemble(
            time=[forecast_time],
            nsteps=3,
            nensemble=2,
            prognostic=model,
            data=GFS(),
            io=io_backend,
            perturbation=perturbation,
            device=device,
        )
        completed = True
    finally:
        # Always release the file; drop a partial file so a later run does not mistake it
        # for a complete one.
        io_backend.close()
        if not completed:
            ofile.unlink(missing_ok=True)

In [None]:
# Shape of the single-variable control tensor
x_control.size()

In [None]:
# Stack every perturbation's ensemble output along a new "perturbation" dimension,
# labelled with the perturbation class names.
perturbation_names = [p.__class__.__name__ for p in perturbations]
ifiles = [odir / f"{Model.__name__}.{name}.nc" for name in perturbation_names]

ds_out = xr.open_mfdataset(ifiles, combine="nested", concat_dim="perturbation")
ds_out = ds_out.assign_coords(perturbation=perturbation_names)
ds_out

In [None]:
# Control state as a labelled DataArray. `coords` is returned by fetch_data together with
# `control`, so it labels this tensor exactly (including lead_time 0), independent of
# which model cell ran last.
control_da = xr.DataArray(
    control.cpu().numpy(),
    coords=coords,
)

# Spread of the initial perturbation (lead_time=0 minus control) over members and the grid
p = (ds_out.sel(lead_time=0) - control_da.to_dataset(dim="variable")).std(dim=["ensemble", "lat", "lon"])

In [None]:
# Area mean of z500 for the BredVector experiment (computed eagerly)
ds_out["z500"].sel(perturbation="BredVector").mean(dim=["lat", "lon"]).load()

In [None]:
# One model step from the unperturbed control state. The control is first mapped onto the
# model's INPUT coordinate system (the state being fed in, lead_time 0). New names keep
# `control`/`coords` unchanged for the other cells.
control_mapped, coords_mapped = map_coords(control, coords, model.input_coords())
out, out_coords = (model.to(device))(control_mapped.to(device), model.input_coords())
out_coords["batch"] = [0]

In [None]:
# Shape of the single-step model output, as a plain tuple of ints
tuple(out.shape)

In [None]:
# Single-step output as a labelled DataArray. A new name is used so the tensor `out`
# is not rebound to a DataArray (re-running the shape cell would otherwise fail).
out_da = xr.DataArray(
    out.cpu().numpy(),
    coords=out_coords,
)

out_da.mean(dim=['lat', 'lon'])

In [None]:
# Initial-perturbation spread of u300, computed eagerly
p["u300"].load()

In [None]:
variable = "z500"
# Display-only divisor applied to the BredVector panels (set to 1 to plot them unscaled)
BREDVECTOR_PLOT_SCALE = 40000

# Initial (lead_time=0) perturbation of every member relative to the control state.
# Only `variable` is differenced rather than every variable in the dataset.
perturbation = ds_out[variable].sel(lead_time=0) - control_da.sel(variable=variable)
# Raw, unscaled values; statistics computed from `variable_to_plot` are not rescaled
variable_to_plot = perturbation

# Scaled copy for plotting only. where() leaves everything untouched if there is no
# "BredVector" entry, instead of raising like a .loc assignment would.
is_bred_vector = variable_to_plot["perturbation"] == "BredVector"
scaled_to_plot = variable_to_plot.where(~is_bred_vector, variable_to_plot / BREDVECTOR_PLOT_SCALE)

fg = scaled_to_plot.plot(
    col="perturbation",
    row="ensemble",
    subplot_kws={"projection": ccrs.Orthographic(central_latitude=45, central_longitude=-55)},
    transform=ccrs.PlateCarree(),
)

# FacetGrid.axs replaces the deprecated FacetGrid.axes
for ax in fg.axs.flat:
    setup_map(ax)
plt.suptitle(
    f"{variable} perturbation at lead_time=0 (BredVector divided by {BREDVECTOR_PLOT_SCALE})",
    y=1.02,
)

In [None]:
# Standard deviation of each method's initial perturbation over members and the grid
(
    variable_to_plot
    .groupby("perturbation")
    .std(dim=["ensemble", "lat", "lon"])
    .compute()
)

In [None]:
# Variable currently selected for the plots below
str(variable)

In [None]:
# Ensemble mean of every perturbation experiment at each lead time
mean = ds_out.mean("ensemble")
fg = mean[variable].plot(
    col="perturbation",
    row="lead_time",
    subplot_kws={"projection": ccrs.Orthographic(central_latitude=45, central_longitude=-55)},
    transform=ccrs.PlateCarree(),
)
# FacetGrid.axs replaces the deprecated FacetGrid.axes
for ax in fg.axs.flat:
    setup_map(ax)

In [None]:
# Ensemble spread (standard deviation across members), drawn on the same globe view
# and with the same map decoration as the ensemble-mean plot
spread = ds_out.std("ensemble")
fg = spread[variable].plot(
    col="perturbation",
    row="lead_time",
    subplot_kws={"projection": ccrs.Orthographic(central_latitude=45, central_longitude=-55)},
    transform=ccrs.PlateCarree(),
)
for ax in fg.axs.flat:
    setup_map(ax)