# HENS

First, some background on HENS. Note that there is more to explore in the
accompanying folder.

Quite a few settings can be tweaked, so we set the most important ones here
and then turn them into a config object afterwards.







In [1]:
# project name (stored in cfg['project'])
project = 'helene'

# initial condition time(s) to start the forecasts from
start_times = ["2024-09-24 12:00:00"]

# number of forecasting steps, ensemble size per checkpoint, inference batch size
nsteps, nensemble, batch_size = 3, 4, 2

# directory with the model packages (checkpoints) and the maximum number of
# checkpoints that will be used
model_packages = (
    '/media/mkoch/9ee63bf8-5a14-4872-86f2-7f16b120269b'
    '/hens_data/hens_checkpoints'
)
max_num_checkpoints = 2


Next, fully configure the inference, do some imports and run the basic initialisations.

In [2]:
from omegaconf import DictConfig
from physicsnemo.distributed import DistributedManager
import sys
import os
from dotenv import load_dotenv

# Make the helper modules in ./11_hens importable (presumably where
# ensemble_utilities, utilities, ... live). Append only once, so re-running
# this notebook cell does not keep adding duplicate entries to sys.path.
hens_dir = os.path.join(os.getcwd(), '11_hens')
if hens_dir not in sys.path:
    sys.path.append(hens_dir)

load_dotenv()  # load environment variables from a .env file (python-dotenv)
DistributedManager.initialize()  # initialise PhysicsNeMo's DistributedManager

# Cyclone tracks and forecast fields are both written to this directory.
output_dir = './outputs'

# Create the configuration dictionary
cfg = DictConfig({
    'project': project,
    'random_seed': 377778,
    'start_times': start_times,
    'nsteps': nsteps,            # number of forecasting steps
    'nensemble': nensemble,      # ensemble size per checkpoint
    'batch_size': batch_size,    # inference batch size

    'forecast_model': {
        'architecture': 'earth2studio.models.px.SFNO',   # forecast model class
        'package': model_packages,
        'max_num_checkpoints': max_num_checkpoints,  # max number of checkpoints which will be used
    },

    'data_source': {
        '_target_': 'earth2studio.data.GFS',  # data source class
    },

    'cyclone_tracking': {
        'out_dir': output_dir,
    },

    'file_output': {
        'path': output_dir,       # directory to which outfiles are written
        'output_vars': ['t2m', 'u10m', 't850', 'q850', 'z500'],
        'thread_io': False,       # write out in separate thread
        # io backend class; '_target_'/'_partial_' follow the Hydra instantiate
        # convention (presumably consumed by utilities.initialise — confirm)
        'format': {
            '_target_': 'earth2studio.io.NetCDF4Backend',
            '_partial_': True,
            'backend_kwargs': {
                'mode': 'w',
                'diskless': False,
                'persist': False,
                'chunks': {
                    'ensemble': 1,
                    'time': 1,
                    'lead_time': 1,
                },
            },
        },
    },
})

- Why is the batch generation progress not shown? To be fixed (Cursor knows how, but its solution isn't clean).

In [None]:

import pandas as pd

from ensemble_utilities import EnsembleBase
from reproduce_utilities import create_base_seed_string
from utilities import initialise, initialise_output, store_tracks
from utilities import update_model_dict, write_to_disk

# initialise(cfg) returns a 10-tuple holding all state the forecast loop
# further down needs; unpack it into module-level names for the next cells.
(ensemble_configs,      # per-run (package, ic, ensemble index, batch ids) tuples
 model_dict,            # model class / loaded package / initialised model
 dx_model_dict,
 cyclone_tracking,
 data,                  # data source handed to the model and perturbation
 output_coords_dict,
 base_random_seed,
 all_tracks_dict,       # per-area lists of cyclone-track tables
 writer_executor,       # None unless threaded IO is set up (see the final cell)
 writer_threads) = initialise(cfg)


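The ensemble configs list the runs to perform: each entry holds a model package, an initial condition, an ensemble index and the batch ids to produce. Let's have a look at their content and explore how many cases we will run, depending on the number of ensemble members, the batch size, the number of checkpoints and the initial conditions.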

In [None]:
# Print one summary paragraph per ensemble config (1-based numbering).
n_configs = len(ensemble_configs)
for config_no, (pkg, ic, ens_idx, batch_ids_produce) in enumerate(
    ensemble_configs, start=1
):
    summary = [
        f'ensemble config {config_no} of {n_configs}:',
        f'    package: {pkg}',
        f'    initial condition: {ic}',
        f'    ensemble index: {ens_idx}',
        # trailing newline leaves a blank line between configs
        f'    batch ids to produce: {batch_ids_produce}\n',
    ]
    print('\n'.join(summary))


The model dict holds the model class, the currently loaded package (weights) and the initialised model. Let's have a look at its contents. In each loop iteration it is updated according to the package (`pkg`) provided in the ensemble config.

In [None]:
from termcolor import colored

# (bold heading, value) pairs describing the currently loaded forecast model.
model_summary = [
    ('The model class is:', model_dict['class']),
    ('The model package (weights), which is currently loaded:', model_dict['package']),
    # Print the model object itself. The original printed `.parameters`
    # without calling it, which only shows a bound-method wrapper
    # (assuming the model is a torch.nn.Module).
    ('The fully initialised model is provided in:', model_dict['model']),
]

for heading, value in model_summary:
    print(colored(heading, attrs=['bold']))
    print(value, '\n')

Next, we need to set up the HENS perturbation.

In [6]:
# for perturbation
skill_path = "/media/mkoch/9ee63bf8-5a14-4872-86f2-7f16b120269b/hens_data/hens_checkpoints/d2m_sfno_linear_74chq_sc2_layers8_edim620_wstgl2-epoch70_seed16.nc"
noise_amplification = 0.35
perturbed_var = ["z500"]
integration_steps = 3



from hens_perturbation import HENSPerturbation

from numpy import ndarray, datetime64

from earth2studio.data import DataSource
from earth2studio.models.px import PrognosticModel
from earth2studio.perturbation import Perturbation

def initialise_perturbation(
    model: PrognosticModel,
    data: DataSource,
    start_time: ndarray[datetime64],
) -> Perturbation:
    """Create a HENSPerturbation for the given model, data source and start time.

    The remaining perturbation settings (``skill_path``,
    ``noise_amplification``, ``perturbed_var``, ``integration_steps``) are
    looked up from the notebook's module-level variables at call time.

    Args:
        model: prognostic model passed through to HENSPerturbation.
        data: data source passed through to HENSPerturbation.
        start_time: initial condition passed through as ``start_time``
            (the forecast loop passes a single ``ic`` here).

    Returns:
        The newly constructed HENSPerturbation.
    """
    hens_settings = dict(
        skill_path=skill_path,
        noise_amplification=noise_amplification,
        perturbed_var=perturbed_var,
        integration_steps=integration_steps,
    )
    return HENSPerturbation(
        model=model, data=data, start_time=start_time, **hens_settings
    )

Now bring everything together:
- loop over ensemble configs
- update model dict (if package has changed)
- initialise output
- initialise perturbation (as ICs might have changed)
- run inference, where all ensemble members are produced
- write to disk

In [None]:

# run forecasts
# One iteration per ensemble config: (package, initial condition,
# ensemble index, batch ids to produce).
for pkg, ic, ens_idx, batch_ids_produce in ensemble_configs:
    # seed base string required for reproducibility of individual batches
    base_seed_string = create_base_seed_string(pkg, ic, base_random_seed)

    # hand the package to update_model_dict, which loads new weights if necessary
    model_dict = update_model_dict(model_dict, pkg)

    io_dict = initialise_output(cfg, ic, model_dict, output_coords_dict)

    # rebuilt every iteration because the initial condition may have changed
    perturbation = initialise_perturbation(
        model=model_dict["model"],
        data=data,
        start_time=ic,
    )

    run_hens = EnsembleBase(
        time=[ic],
        nsteps=cfg.nsteps,
        nensemble=cfg.nensemble,
        prognostic=model_dict["model"],
        data=data,
        io_dict=io_dict,
        perturbation=perturbation,
        output_coords_dict=output_coords_dict,
        dx_model_dict=dx_model_dict,
        cyclone_tracking=cyclone_tracking,
        batch_size=cfg.batch_size,
        ensemble_idx_base=ens_idx,
        batch_ids_produce=batch_ids_produce,
        base_seed_string=base_seed_string,
    )
    df_tracks_dict, io_dict = run_hens()

    # tag each track table with its initial condition and collect it per area
    for area, tracks_df in df_tracks_dict.items():
        tracks_df["ic"] = pd.to_datetime(ic)
        all_tracks_dict[area].append(tracks_df)

    # a falsy io_dict means there is nothing held in memory to write out
    if not io_dict:
        continue

    # in-memory flavour of io backend was chosen: write content to disk now
    writer_threads, writer_executor = write_to_disk(
        cfg, ic, model_dict, io_dict, writer_threads, writer_executor, ens_idx,
    )

# Output summaries of cyclone tracks if required
if "cyclone_tracking" in cfg:
    for area_name, all_tracks in all_tracks_dict.items():
        store_tracks(area_name, all_tracks, cfg)

# wait for every pending background write, then shut the executor down
if writer_executor is not None:
    for pending in list(writer_threads):
        pending.result()  # presumably a concurrent.futures.Future; blocks until done
        writer_threads.remove(pending)
    writer_executor.shutdown()

Now there should be a folder called "outputs" containing all the forecasts.
The cyclone tracks should be in there as well.

plot tracks and fields.
