# HENS

This notebook walks through an ensemble forecast with HENS. Further background
information on HENS, and more material to explore, can be found in the `11_hens` folder.

There are quite a few settings that can be tweaked, so we set the most important ones here
and turn them into a config object afterwards.







In [1]:
import os

# project name passed on to the configuration
project = 'helene'

start_times = ["2024-09-24 12:00:00"]  # initial condition(s) of the forecast
nsteps = 16      # number of forecasting steps
nensemble = 4    # ensemble size per checkpoint
batch_size = 2   # inference batch size

# Directory holding the model packages (checkpoints). The default is specific
# to the machine this notebook was written on; set HENS_MODEL_PACKAGES in the
# environment of the kernel to point somewhere else without editing the cell.
# Note that the .env file is only loaded in a later cell, so the variable has
# to be present in the process environment already.
model_packages = os.environ.get(
    'HENS_MODEL_PACKAGES',
    '/media/mkoch/9ee63bf8-5a14-4872-86f2-7f16b120269b/hens_data/hens_checkpoints',
)
max_num_checkpoints = 2  # max number of checkpoints which will be used

output_vars = ["t2m", 'u10m', 't850', 'q850', 'z500']  # variables written to the field output
out_dir = './outputs'  # directory for field output and cyclone tracks


Next, let's do some imports and fully configure the inference. For more detailed
information about the configuration options, see the README in the `11_hens` folder.

In [2]:
from omegaconf import DictConfig
from physicsnemo.distributed import DistributedManager
import sys
import os
from dotenv import load_dotenv

# make the modules in the 11_hens folder importable from this notebook
hens_recipe_dir = os.path.join(os.getcwd(), '11_hens')
sys.path.append(hens_recipe_dir)

# load variables from a .env file, then initialise the distributed manager
load_dotenv()
DistributedManager.initialize()

# Assemble the configuration section by section ...
forecast_model_cfg = {
    'architecture': 'earth2studio.models.px.SFNO',  # forecast model class
    'package': model_packages,
    'max_num_checkpoints': max_num_checkpoints,     # max number of checkpoints which will be used
}

data_source_cfg = {'_target_': 'earth2studio.data.GFS'}  # data source class

cyclone_tracking_cfg = {'out_dir': out_dir}

io_backend_cfg = {                                   # io backend class
    '_target_': 'earth2studio.io.NetCDF4Backend',
    '_partial_': True,
    'backend_kwargs': {
        'mode': 'w',
        'diskless': False,
        'persist': False,
        'chunks': {'ensemble': 1, 'time': 1, 'lead_time': 1},
    },
}

file_output_cfg = {
    'path': out_dir,             # directory to which outfiles are written
    'output_vars': output_vars,
    'thread_io': False,          # write out in separate thread
    'format': io_backend_cfg,
}

# ... and combine the sections with the top-level settings
cfg = DictConfig({
    'project': project,
    'random_seed': 377778,
    'start_times': start_times,
    'nsteps': nsteps,            # number of forecasting steps
    'nensemble': nensemble,      # ensemble size per checkpoint
    'batch_size': batch_size,    # inference batch size
    'forecast_model': forecast_model_cfg,
    'data_source': data_source_cfg,
    'cyclone_tracking': cyclone_tracking_cfg,
    'file_output': file_output_cfg,
})

- **TODO:** Why is the batch generation progress not shown? To be fixed (Cursor knows how, but the solution isn't clean).

In [None]:

import pandas as pd
from ensemble_utilities import EnsembleBase
from reproduce_utilities import create_base_seed_string

from utilities import (
    initialise,
    initialise_output,
    store_tracks,
    update_model_dict,
    write_to_disk,
)

# Build everything the inference loop further down needs from the config in a
# single call. The last two return values are not used in this notebook and
# are discarded.
(
    ensemble_configs,     # iterated over in the inference loop
    model_dict,           # model class, package and initialised model
    dx_model_dict,        # passed on to EnsembleBase
    cyclone_tracking,     # passed on to EnsembleBase
    data,                 # data source handed to perturbation and inference
    output_coords_dict,   # used when initialising the output
    base_random_seed,     # input to create_base_seed_string
    all_tracks_dict,      # collects the tracks of all ensemble configs
    _, _
) = initialise(cfg)


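Each entry of `ensemble_configs` combines a model package (checkpoint) with an initial condition (IC) and additionally holds the ensemble index of its first member and the batch ids to produce. Let's have a look at the content and explore how many cases we will run, depending on the number of ensemble members, the batch size, the number of checkpoints and the number of initial conditions.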

The inference is parallelised across ensemble config entries, hence across IC-package pairs. This also means that you cannot use more GPUs than the number of ICs multiplied by the number of checkpoints. If more GPUs are available, they remain idle.

In [None]:
# Print a short, indented summary of every ensemble config entry
n_configs = len(ensemble_configs)
for cfg_no, (pkg, ic, ens_idx, batch_ids_produce) in enumerate(ensemble_configs, start=1):
    summary_lines = [
        f'ensemble config {cfg_no} of {n_configs}:',
        f'    package: {pkg}',
        f'    initial condition: {ic}',
        f'    ensemble index of first member: {ens_idx}',
        f'    batch ids to produce: {batch_ids_produce}\n',
    ]
    print('\n'.join(summary_lines))


`model_dict` holds the model class, the currently loaded package (weights) and the initialised model. Let's have a look at its contents. In each iteration of the inference loop, it is updated according to the `pkg` provided in the ensemble config.

In [None]:
from termcolor import colored

# Show every part of the model dict under its own bold heading
model_dict_overview = (
    ('The model class is:', model_dict['class']),
    ('The model package (weights), which is currently loaded:', model_dict['package']),
    ('The fully initialised model is provided in:', model_dict['model'].parameters),
)

for heading, content in model_dict_overview:
    print(colored(heading, attrs=['bold']))
    print(content, '\n')

Next, we need to set up the HENS perturbation.
To see how it is assembled from the basic blocks provided in Earth2Studio, have a look at `11_hens/hens_perturbation.py`.

In [6]:
# Settings for the HENS perturbation.
# The skill file sits in the same directory as the model packages, so its path
# is derived from `model_packages` instead of repeating the absolute path;
# changing the checkpoint location in the first cell then carries over here.
skill_path = os.path.join(
    model_packages,
    "d2m_sfno_linear_74chq_sc2_layers8_edim620_wstgl2-epoch70_seed16.nc",
)
noise_amplification = 0.35
perturbed_var = ["z500"]
integration_steps = 3

from hens_perturbation import HENSPerturbation

from numpy import ndarray, datetime64

from earth2studio.data import DataSource
from earth2studio.models.px import PrognosticModel
from earth2studio.perturbation import Perturbation

def initialise_perturbation(
    model: PrognosticModel,
    data: DataSource,
    start_time: ndarray[datetime64],
) -> Perturbation:
    """Create the HENS perturbation for one model / initial condition pair.

    The remaining settings (``skill_path``, ``noise_amplification``,
    ``perturbed_var`` and ``integration_steps``) are read from the
    notebook-level variables defined in this cell.

    Parameters
    ----------
    model : PrognosticModel
        Prognostic model handed to the perturbation.
    data : DataSource
        Data source handed to the perturbation.
    start_time : ndarray[datetime64]
        Initial condition time handed to the perturbation.

    Returns
    -------
    Perturbation
        The configured ``HENSPerturbation`` instance.
    """
    notebook_settings = dict(
        skill_path=skill_path,
        noise_amplification=noise_amplification,
        perturbed_var=perturbed_var,
        integration_steps=integration_steps,
    )
    return HENSPerturbation(
        model=model, data=data, start_time=start_time, **notebook_settings
    )

Now let's bring everything together:
- loop over ensemble configs
- update model dict (if package has changed)
- initialise output
- initialise perturbation (as ICs might have changed)
- run inference, where all ensemble members are produced
- write to disk

In [None]:
# Main inference loop: one iteration per ensemble config entry, i.e. per
# combination of model package (pkg) and initial condition (ic).
for pkg, ic, ens_idx, batch_ids_produce in ensemble_configs:
    # create seed base string required for reproducibility of individual batches
    base_seed_string = create_base_seed_string(pkg, ic, base_random_seed)

    # load new weights if necessary; model_dict is replaced by the returned dict
    model_dict = update_model_dict(model_dict, pkg)

    # create a new io object for this initial condition and model
    io_dict = initialise_output(cfg, ic, model_dict, output_coords_dict)

    # initialise perturbation with updated IC and checkpoint
    perturbation = initialise_perturbation(
        model=model_dict["model"], data=data, start_time=ic
    )

    # initialise inference pipeline with updated IC and checkpoint
    run_hens = EnsembleBase(
        time=[ic],
        nsteps=cfg.nsteps,
        nensemble=cfg.nensemble,
        prognostic=model_dict["model"],
        data=data,
        io_dict=io_dict,
        perturbation=perturbation,
        output_coords_dict=output_coords_dict,
        dx_model_dict=dx_model_dict,
        cyclone_tracking=cyclone_tracking,
        batch_size=cfg.batch_size,
        ensemble_idx_base=ens_idx,
        batch_ids_produce=batch_ids_produce,
        base_seed_string=base_seed_string,
    )

    # run inference; returns the tracks per area and the (updated) io object
    df_tracks_dict, io_dict = run_hens()

    # tag the tracks of this run with their initial condition and collect
    # them per key, so that all runs can be written out together afterwards
    for k, v in df_tracks_dict.items():
        v["ic"] = pd.to_datetime(ic)
        all_tracks_dict[k].append(v)

    # if in-memory flavour of io backend was chosen, write content to disk now
    # (the two return values are not needed here and are discarded)
    if io_dict:
        _, _ = write_to_disk(
            cfg,
            ic,
            model_dict,
            io_dict,
            None,
            None,
        )

# write the collected cyclone tracks to disk, one call per area
if "cyclone_tracking" in cfg:
    for area_name, all_tracks in all_tracks_dict.items():
        store_tracks(area_name, all_tracks, cfg)

Now there should be a folder called `outputs` containing all the forecasts.
The cyclone tracks should be in there as well.

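Next, we load the output so that we can plot tracks and fields.

As the dataset summary below shows, the field output contains 4 ensemble members, one IC (`time`) and one entry per lead time.

**TODO:** explain the meaning of the column headers in the track file.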


In [None]:
import xarray as xr
from plotting.fork_n_spoon import extract_tracks_from_csv

# field output written by the inference loop
field_file = 'outputs/global/helene_2024-09-24T12_pkg_seed102.nc'
ds = xr.load_dataset(field_file)
display(ds)

# cyclone tracks written by the inference loop
track_file = 'outputs/global/helene_tracks_rank_000.csv'
tracks = pd.read_csv(track_file, sep=',')
print('tracks columns:', list(tracks.columns), sep='\n')


Global fields can be plotted as follows:

In [None]:
# t2m for the first ensemble member, first IC and first lead time
first_t2m_field = ds['t2m'].isel({'ensemble': 0, 'time': 0, 'lead_time': 0})
first_t2m_field.plot(figsize=(16, 6))

Let's extract the Helene tracks:

In [None]:
# Reference positions of Helene at 3-hourly intervals, used below to pick the
# matching tracks from the track file (55 time stamps, 55 lat and 55 lon values).
# NOTE(review): presumably taken from IBTrACS, with lat in degrees north and
# lon in degrees east on a 0-360 scale — confirm the source and units.
helene_ibtracs_coords = pd.DataFrame({
    'time': pd.date_range('2024-09-21 18:00:00', '2024-09-28 12:00:00', freq='3h'),
    'lat': [13.60, 13.80, 14.00, 14.20, 14.40, 14.60, 14.80, 15.00, 15.20, 15.40,
            15.60, 15.70, 16.00, 16.60, 17.20, 17.60, 17.90, 18.10, 18.20, 18.40,
            18.60, 19.00, 19.30, 19.40, 19.40, 19.50, 19.80, 20.00, 20.30, 20.70,
            21.10, 21.50, 22.00, 22.40, 22.80, 23.20, 23.60, 24.10, 24.70, 25.60,
            26.70, 27.70, 28.70, 29.90, 31.30, 32.90, 34.40, 35.70, 36.70, 37.60,
            38.10, 37.90, 37.40, 37.00, 36.60],
    'lon': [277.3, 277.4, 277.4, 277.4, 277.4, 277.4, 277.4, 277.4, 277.4,
            277.4, 277.5, 277.7, 278. , 278.1, 278.2, 278.2, 278.1, 278. ,
            277.9, 277.7, 277.3, 276.8, 276.3, 275.8, 275.4, 275. , 274.7,
            274.4, 274.1, 273.9, 273.8, 273.7, 273.5, 273.4, 273.3, 273.3,
            273.5, 273.7, 274.1, 274.6, 275.1, 275.4, 275.7, 276.2, 276.7,
            276.9, 276.8, 276.1, 275.1, 274.2, 273.4, 272.5, 272. , 272. ,
            272.4]
})

# Select tracks from the track file, using the reference positions of Helene
# and the first initial condition; the second return value is not needed.
track_selection = dict(
    ic=start_times[0],
    tc_centres=helene_ibtracs_coords,
    max_dist=2.5,
    min_len=4,
    max_stp=nsteps,
)
track_list, _ = extract_tracks_from_csv(
    'outputs/global/helene_tracks_rank_000.csv', **track_selection
)

print(f'found {len(track_list)} tracks')

Now, let's focus on the Gulf of Mexico and plot the tracks of Hurricane Helene:

In [36]:
# field to animate and index of the ensemble member to show
# NOTE(review): ensemble_member is also used as index into track_list in
# make_frame — confirm that the track at this position belongs to this member.
variable = 'u10m'
ensemble_member = 1

max_frames = 17 # maximum number of frames to plot
scale = 1  # factor applied to the 0.25 spacing of the regional lat/lon selection

# bounds of the region that is selected from the dataset and plotted
lat_min, lat_max = 10, 40
lon_min, lon_max = 250, 300

In [14]:
import cartopy.crs as ccrs
import cartopy.feature as cfeature
# import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np
# import pandas as pd
# from matplotlib.colors import TwoSlopeNorm
import matplotlib.animation as animation
from cartopy.mpl.ticker import LongitudeFormatter, LatitudeFormatter

# import xarray as xr

# spacing of the regional lat/lon selection
dx = .25 * scale

# country outlines (Natural Earth, 110m)
countries = cfeature.NaturalEarthFeature(
    category="cultural", name="admin_0_countries", scale="110m",
    facecolor="none", edgecolor="black",
)

# cut the region of interest out of the global dataset
region_lats = list(np.arange(lat_min, lat_max, dx))
region_lons = list(np.arange(lon_min, lon_max, dx))
reg_ds = ds.sel(lat=region_lats, lon=region_lons)

time_str = 'lead time:'
projection = ccrs.PlateCarree()
var_ds = reg_ds[variable]

# fixed colour range, taken over all frames of the selected member (index 0 on the second axis)
member_field = var_ds[ensemble_member, 0, :, :, :]
min_val = float(np.min(member_field))
max_val = float(np.max(member_field))


In [None]:
# import matplotlib.colors as colors

# define plots
def make_figure():
    """Create the map figure used for the plots.

    The axes use the notebook-level ``projection``, show coastlines and rivers
    and label their ticks as longitudes / latitudes.

    Returns
    -------
    tuple
        The figure and its single map axes, ``(fig, ax)``.
    """
    fig, ax = plt.subplots(figsize=(11, 5), subplot_kw={'projection': projection})

    for map_feature in (cfeature.COASTLINE, cfeature.RIVERS):
        ax.add_feature(map_feature, lw=.5)

    ax.xaxis.set_major_formatter(LongitudeFormatter(zero_direction_label=False))
    ax.yaxis.set_major_formatter(LatitudeFormatter())

    return fig, ax

def make_frame(frame):
    """Draw one frame of the animation onto the notebook-level axes ``ax``.

    Parameters
    ----------
    frame : int
        Index of the step to draw (third axis of ``var_ds``). The special
        value -1 marks the initial frame: it shows step 0 and adds the
        colorbar, but no track.

    Returns
    -------
    The object returned by ``ax.pcolormesh`` for this frame.
    """
    # The initial frame (-1) shows the same data as step 0. Use the clamped
    # index for the progress message and the title as well, so that they do
    # not report "frame 0" or a negative lead time.
    step = max(frame, 0)
    n_frames = min(max_frames, var_ds.shape[2])
    print(f'\rprocessing frame {step+1} of {n_frames}', end='')

    plot_ds = var_ds[ensemble_member, 0, step, :, :]
    pc = ax.pcolormesh(reg_ds.lon, reg_ds.lat, plot_ds, transform=projection,
                       cmap='plasma',
                       vmin=min_val, vmax=max_val
                       )

    if frame == -1:
        # add the colorbar only once
        fig.colorbar(pc, extend='both', shrink=0.8, ax=ax)
    else:
        # NOTE(review): track_list is indexed with the ensemble member index —
        # confirm that list position and member index actually coincide.
        track = track_list[ensemble_member]
        # show the first `frame` points of the track
        max_len = min(frame, len(track['lon']))
        # shift longitudes by -360 before plotting in axes coordinates
        ax.plot(track['lon'][:max_len]-360, track['lat'][:max_len],
                color='white', linewidth=2, alpha=1)

    # assumes 6 hours between consecutive steps — TODO confirm
    header = time_str + " " + f'{step*6}:00:00'
    ax.set_title(header, fontsize=14)

    return pc

def animate(frame):
    """Per-frame callback of the animation: draw step ``frame`` via ``make_frame``."""
    return make_frame(frame)

def first_frame():
    """Init callback of the animation: draw the initial frame (``make_frame(-1)``)."""
    return make_frame(-1)

# make animation
%matplotlib inline
plt.rcParams["animation.html"] = "jshtml"
fig, ax = make_figure()
ani = animation.FuncAnimation(fig,
                              animate,
                              min(max_frames, var_ds.shape[2]),
                              init_func=first_frame,
                              blit=False,
                              repeat=False,
                              interval=.1)
plt.close('all')
ani

And finally, the full spaghetti plot:

In [None]:
plt.close('all')

# Same map set-up as for the animation (coastlines, rivers, lon/lat tick
# labels), with ocean and land shading added on top
fig, ax = make_figure()
for shading in (cfeature.OCEAN, cfeature.LAND):
    ax.add_feature(shading)

# Draw every extracted track as a semi-transparent crimson line;
# longitudes are shifted by -360 before plotting
spaghetti_style = dict(color='crimson', linewidth=2, alpha=.4)
for track in track_list:
    ax.plot(track['lon']-360, track['lat'], **spaghetti_style)

# restrict the map to the region of interest
ax.set_extent([lon_min, lon_max, lat_min, lat_max])
plt.show()