# HENS

First, some background info on HENS and that there's more to explore
in the folder...

In [1]:
from hydra import initialize, compose
from omegaconf import DictConfig
import sys
import os
from dotenv import load_dotenv

sys.path.append(os.path.join(os.getcwd(), '11_hens'))
load_dotenv()

# Import the module
from hens import main

def load_config(config_name: str = "config") -> DictConfig:
    with initialize(version_base=None, config_path="./11_hens/conf/"):
        cfg = compose(config_name=config_name)
        return cfg

cfg = load_config("helene_nb.yaml")  # Loads helene.yaml


- why is batch generation progress not shown? To be fixed (cursor knows how, but it isn't clean)

In [2]:
# main(cfg)

In [None]:

from datetime import datetime

import pandas as pd
from ensemble_utilities import EnsembleBase
from physicsnemo.distributed import DistributedManager
from reproduce_utilities import create_base_seed_string

from utilities import (
    initialise,
    initialise_output,
    store_tracks,
    update_model_dict,
    write_to_disk,
)

DistributedManager.initialize()

(
    ensemble_configs,
    model_dict,
    dx_model_dict,
    cyclone_tracking,
    data,
    output_coords_dict,
    base_random_seed,
    all_tracks_dict,
    writer_executor,
    writer_threads
) = initialise(cfg)




[32m2025-04-15 14:39:34.254[0m | [1mINFO    [0m | [36mutilities[0m:[36mpair_packages_ics[0m:[36m330[0m - [1mrank 0: predicting from following models/initial times: [('/media/mkoch/9ee63bf8-5a14-4872-86f2-7f16b120269b/hens_data/hens_checkpoints/sfno_linear_74chq_sc2_layers8_edim620_wstgl2-epoch70_seed102', numpy.datetime64('2024-09-24T12:00:00.000000000'), 0, [0, 1])][0m
pkg='/media/mkoch/9ee63bf8-5a14-4872-86f2-7f16b120269b/hens_data/hens_checkpoints/sfno_linear_74chq_sc2_layers8_edim620_wstgl2-epoch70_seed102' ic=numpy.datetime64('2024-09-24T12:00:00.000000000') ens_idx=0 batch_ids_produce=[0, 1]


ensemble configs include XYZ and represent blah. Let's have a look at the content and explore how many cases we will run, depending on number of ensemble members, batch size, number of checkpoints and initial conditions.

In [7]:
for ii, (pkg, ic, ens_idx, batch_ids_produce) in enumerate(ensemble_configs):
    print(f'{ii=} {pkg=} {ic=} {ens_idx=} {batch_ids_produce=}')

ii=0 pkg='/media/mkoch/9ee63bf8-5a14-4872-86f2-7f16b120269b/hens_data/hens_checkpoints/sfno_linear_74chq_sc2_layers8_edim620_wstgl2-epoch70_seed102' ic=numpy.datetime64('2024-09-24T12:00:00.000000000') ens_idx=0 batch_ids_produce=[0, 1]


model dict includes model and weights. Let's have a look at its contents. At each loop, it will be updated according to pkg in provided in the ensemble config.

In [21]:
print('The model class is:')
print(model_dict['class'], '\n')

print('The model package (weights), which is currently loaded:')
print(model_dict['package'], '\n')

print('The fully initialised model is provided in:')
print(model_dict['model'].parameters, '\n')

The model class is:
<class 'earth2studio.models.px.sfno.SFNO'> 

The model package (weights), which is currently loaded:
/media/mkoch/9ee63bf8-5a14-4872-86f2-7f16b120269b/hens_data/hens_checkpoints/sfno_linear_74chq_sc2_layers8_edim620_wstgl2-epoch70_seed102 

The fully initialised model is provided in:
<bound method Module.parameters of SFNO(
  (model): ModelWrapper(
    (model): SingleStepWrapper(
      (preprocessor): Preprocessor2D()
      (model): SphericalFourierNeuralOperatorNet(
        (trans_down): RealSHT(
          nlat=721, nlon=1440,
           lmax=360, mmax=361,
           grid=equiangular, csphase=True
        )
        (itrans_up): InverseRealSHT(
          nlat=721, nlon=1440,
           lmax=360, mmax=361,
           grid=equiangular, csphase=True
        )
        (trans): RealSHT(
          nlat=360, nlon=720,
           lmax=360, mmax=361,
           grid=legendre-gauss, csphase=True
        )
        (itrans): InverseRealSHT(
          nlat=360, nlon=720,
      

In [4]:
from hens_perturbation import HENSPerturbation

from numpy import ndarray, datetime64

from earth2studio.data import DataSource
from earth2studio.models.px import PrognosticModel
from earth2studio.perturbation import Perturbation

def initialise_perturbation(
    model: PrognosticModel,
    data: DataSource,
    start_time: ndarray[datetime64],
    cfg: DictConfig,
) -> Perturbation:
    perturbation = HENSPerturbation(
        model=model,
        data=data,
        start_time=start_time,
        skill_path=cfg.perturbation.skill_path,
        noise_amplification=cfg.perturbation.noise_amplification,
        perturbed_var=cfg.perturbation.perturbed_var,
        integration_steps=cfg.perturbation.integration_steps
    )

    return perturbation

In [5]:

# run forecasts
for pkg, ic, ens_idx, batch_ids_produce in ensemble_configs:
    # create seed base string required for reproducibility of individual batches
    base_seed_string = create_base_seed_string(pkg, ic, base_random_seed)

    # load new weights if necessary
    model_dict = update_model_dict(model_dict, pkg)

    io_dict = initialise_output(cfg, ic, model_dict, output_coords_dict)

    perturbation = initialise_perturbation(
        model=model_dict["model"], data=data, start_time=ic, cfg=cfg
    )

    run_hens = EnsembleBase(
        time=[ic],
        nsteps=cfg.nsteps,
        nensemble=cfg.nensemble,
        prognostic=model_dict["model"],
        data=data,
        io_dict=io_dict,
        perturbation=perturbation,
        output_coords_dict=output_coords_dict,
        dx_model_dict=dx_model_dict,
        cyclone_tracking=cyclone_tracking,
        batch_size=cfg.batch_size,
        ensemble_idx_base=ens_idx,
        batch_ids_produce=batch_ids_produce,
        base_seed_string=base_seed_string,
    )
    df_tracks_dict, io_dict = run_hens()
    for k, v in df_tracks_dict.items():
        v["ic"] = pd.to_datetime(ic)
        all_tracks_dict[k].append(v)

    # if in-memory flavour of io backend was chosen, write content to disk now
    if io_dict:
        writer_threads, writer_executor = write_to_disk(
            cfg,
            ic,
            model_dict,
            io_dict,
            writer_threads,
            writer_executor,
            ens_idx,
        )

# Output summaries of cyclone tracks if required
if "cyclone_tracking" in cfg:
    for area_name, all_tracks in all_tracks_dict.items():
        store_tracks(area_name, all_tracks, cfg)

if writer_executor is not None:
    for thread in list(writer_threads):
        thread.result()
        writer_threads.remove(thread)
    writer_executor.shutdown()

Fetching GFS for 2024-09-23 18:00:00: 100%|██████████| 74/74 [00:01<00:00, 44.61it/s]
Fetching GFS for 2024-09-24 00:00:00: 100%|██████████| 74/74 [00:01<00:00, 55.15it/s]
Fetching GFS for 2024-09-24 06:00:00: 100%|██████████| 74/74 [00:01<00:00, 55.40it/s]
Fetching GFS for 2024-09-24 12:00:00: 100%|██████████| 74/74 [00:01<00:00, 54.13it/s]


[32m2025-04-15 14:39:40.983[0m | [1mINFO    [0m | [36mensemble_utilities[0m:[36m__init__[0m:[36m101[0m - [1mSetting up HENS.[0m
[32m2025-04-15 14:39:40.983[0m | [1mINFO    [0m | [36mensemble_utilities[0m:[36mmove_models_to_device[0m:[36m151[0m - [1mInference device: cuda[0m


Fetching GFS for 2024-09-24 12:00:00: 100%|██████████| 74/74 [00:01<00:00, 55.80it/s]


[32m2025-04-15 14:39:43.673[0m | [32m[1mSUCCESS [0m | [36mensemble_utilities[0m:[36mfetch_ics[0m:[36m192[0m - [32m[1mFetched data from GFS[0m
[32m2025-04-15 14:39:43.676[0m | [1mINFO    [0m | [36mensemble_utilities[0m:[36m__call__[0m:[36m324[0m - [1mStarting 4 Member Ensemble inference with 2 number of batches.[0m


Total Ensemble Batches: 100%|██████████| 2/2 [01:01<00:00, 30.58s/it]

[32m2025-04-15 14:40:44.848[0m | [32m[1mSUCCESS [0m | [36mensemble_utilities[0m:[36m__call__[0m:[36m381[0m - [32m[1mInference complete[0m



