# Inference with Makani 

There are currently two ways to run inference with models trained in Makani: `model_package` and `inferencer`. Let us start by adding Makani to our path.

In [None]:
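import os, sys
makani_home = os.path.dirname(os.getcwd())  # assumes the notebook is run from a subdirectory (e.g. examples/) of the Makani repository
sys.path.append(makani_home)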

## Using Model Package

In [None]:
from makani.models.model_package import LocalPackage, load_model_package

# directory of the trained run (adjust to your setup)
path_to_package = os.path.expanduser("~/Projects/fourcastnet/climate_fno/CWO-data/73VarQ/runs/sfno_linear_73chq_sc3_layers8_edim384_asgl2_cadam/ngpu64_sp1")

load_model_package(LocalPackage(path_to_package))

## Using Inferencer

The inferencer module is designed for running models directly within Makani. It supports massively parallel autoregressive roll-outs, ensemble forecasting and scoring. However, its setup is slightly more involved.

To instantiate `inferencer`, we need the `params` data structure. It can be obtained either from the model package or from the configuration with which the model was trained.

Because the inferencer creates its own dataloader for inference, we recommend the latter approach, which lets us set the path to the out-of-sample dataset manually:

In [None]:
from makani.utils.YParams import YParams

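yaml_config = os.path.join(makani_home, "config/sfnonet_devel.yaml")  # training configuration file
config = "sfno_linear_73chq_sc3_layers8_edim384_asgl2_cadam"  # configuration to load from the YAML file
run_num = "ngpu64_sp1"  # name of the training run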

params = YParams(yaml_config, config)
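
The cells below access `params` both by key (`params['experiment_dir'] = ...`) and by attribute (`params.experiment_dir`). The following sketch illustrates that dual-access pattern with a hypothetical `DemoParams` class. It is not how `YParams` is implemented; it only shows the behavior the notebook relies on, and it uses its own variable name so it does not overwrite `params`.

```python
class DemoParams:
    """Hypothetical dict-backed parameters, readable as obj['key'] or obj.key (not makani's YParams)."""

    def __init__(self, initial=None):
        # bypass our own __setattr__ so the backing dict is a real instance attribute
        object.__setattr__(self, "_values", dict(initial or {}))

    def __getitem__(self, key):
        return self._values[key]

    def __setitem__(self, key, value):
        self._values[key] = value

    def __getattr__(self, key):
        # only called when normal attribute lookup fails, i.e. for parameter names
        values = self.__dict__.get("_values", {})
        try:
            return values[key]
        except KeyError:
            raise AttributeError(key) from None

    def __setattr__(self, key, value):
        # attribute assignment writes into the same dict as item assignment
        self._values[key] = value

    def __contains__(self, key):
        return key in self._values


demo = DemoParams({"n_future": 0})
demo["experiment_dir"] = "/tmp/runs/example"  # set by key ...
print(demo.experiment_dir)                     # ... read as an attribute → /tmp/runs/example
demo.valid_autoreg_steps = 20                  # set as an attribute ...
print(demo["valid_autoreg_steps"])             # ... read by key → 20
```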

Let us set some extra parameters needed for running inference. The paths below must be adjusted to your setup; alternatively, they can all be set to the corresponding paths in the `model_package`. Much of this is boilerplate, and a rework to remove it is in progress.

In [None]:
import torch
torch.backends.cudnn.benchmark = True

from makani.utils.parse_dataset_metada import parse_dataset_metadata

# root directory of the dataset (adjust to your setup)
data_dir = os.path.expanduser('~/Projects/fourcastnet/climate_fno/CWO-data/73VarQ/')

# setting the necessary paths
params['inf_data_path'] = os.path.join(data_dir, 'out_of_sample/') # dataset to use for inference
params['experiment_dir'] = os.path.join(data_dir, 'runs/', config, run_num) # directory for writing out results
params['checkpoint_path'] = os.path.join(params.experiment_dir, 'training_checkpoints/ckpt_mp0.tar') # last checkpoint
params['best_checkpoint_path'] = os.path.join(params.experiment_dir, 'training_checkpoints/best_ckpt_mp0.tar') # best checkpoint
params['metadata_json_path'] = os.path.join(data_dir, 'invariants/data.json') # data.json file - see README for detailed info

# normalization statistics
params['min_path'] = os.path.join(data_dir, 'stats/mins.npy')
params['max_path'] = os.path.join(data_dir, 'stats/maxs.npy')
params['time_means_path'] = os.path.join(data_dir, 'stats/time_means.npy')
params['global_means_path'] = os.path.join(data_dir, 'stats/global_means.npy')
params['global_stds_path'] = os.path.join(data_dir, 'stats/global_stds.npy')
params['time_diff_means_path'] = os.path.join(data_dir, 'stats/time_diff_means.npy')
params['time_diff_stds_path'] = os.path.join(data_dir, 'stats/time_diff_stds.npy')

# land-sea-mask and orography
params['orography_path'] = os.path.join(data_dir, 'invariants/orography.nc')
params['landmask_path'] = os.path.join(data_dir, 'invariants/land_mask.nc')

# set parameters which can be read from the metadata file
params, _ = parse_dataset_metadata(params['metadata_json_path'], params=params)

params['multifiles'] = True # use the multifiles dataloader (not DALI)
params['n_future'] = 0 # predict one step at a time
params['valid_autoreg_steps'] = 20 # number of autoregressive steps in the roll-out
params['split_data_channels'] = False

# do not log to wandb
params['log_to_wandb'] = False
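
A missing or mistyped path is a common reason for the inferencer setup to fail. The sketch below is a hypothetical helper, `missing_paths`, that reports which of the path entries set in the cell above do not exist on disk. It only needs item access on `params`; the key list mirrors the cell above, and depending on the configuration not every entry (for example both checkpoints) may actually be required.

```python
import os


def missing_paths(params, keys):
    """Return (key, path) pairs whose path does not exist on disk.

    Hypothetical helper: `params` only needs item access (a dict works),
    and `keys` lists the entries that hold file or directory paths.
    """
    missing = []
    for key in keys:
        path = params[key]
        if not os.path.exists(path):
            missing.append((key, path))
    return missing


# path entries set in the cell above
path_keys = [
    "inf_data_path", "checkpoint_path", "best_checkpoint_path", "metadata_json_path",
    "min_path", "max_path", "time_means_path", "global_means_path", "global_stds_path",
    "time_diff_means_path", "time_diff_stds_path", "orography_path", "landmask_path",
]

# stand-alone example with a plain dict: only the nonexistent file is reported
example_params = {"inf_data_path": os.getcwd(), "min_path": "/nonexistent/mins.npy"}
print(missing_paths(example_params, ["inf_data_path", "min_path"]))
```

In the notebook, `missing_paths(params, path_keys)` should return an empty list once all paths point to existing files.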

In [None]:
from makani import Inferencer

inferencer = Inferencer(params, world_rank=0)

Let's select some channels we want to analyze:

In [None]:
channel_names = ["u10m", "v10m", "z500", "t2m"]
output_channels = [params["channel_names"].index(ch) for ch in channel_names]
output_channels

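Currently, the inferencer API runs inference from a single initial condition, selected via the `ic` argument. Here it returns the ground truth and the predicted roll-out for the selected channels, together with the ACC and RMSE curves: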

In [None]:
truth, pred, _, acc_curve, rmse_curve = inferencer.inference_single(
    ic=0, output_data=True, output_channels=output_channels, compute_metrics=True
)
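
The inferencer computes ACC (anomaly correlation coefficient) and RMSE internally. For reference, the sketch below shows the standard latitude-weighted definitions of both scores in NumPy, with anomalies taken relative to a climatology `clim`. It is a generic illustration under stated assumptions (an equiangular latitude grid that includes both poles, `cos(latitude)` weights normalized to mean one), not Makani's implementation, which may differ in the grid, the weighting or the choice of climatology. All helper names are hypothetical.

```python
import numpy as np


def lat_weights(nlat):
    """cos(latitude) weights on an equiangular grid from 90 to -90 degrees, normalized to mean 1."""
    lat = np.linspace(90.0, -90.0, nlat)
    w = np.cos(np.deg2rad(lat))
    return w / w.mean()


def weighted_rmse(pred, truth):
    """Latitude-weighted RMSE over the last two axes (lat, lon)."""
    w = lat_weights(pred.shape[-2])[:, None]  # shape (nlat, 1), broadcasts over lon
    return np.sqrt(np.mean(w * (pred - truth) ** 2, axis=(-2, -1)))


def weighted_acc(pred, truth, clim):
    """Latitude-weighted anomaly correlation coefficient over the last two axes."""
    w = lat_weights(pred.shape[-2])[:, None]
    a_pred = pred - clim   # predicted anomaly
    a_true = truth - clim  # observed anomaly
    num = np.sum(w * a_pred * a_true, axis=(-2, -1))
    den = np.sqrt(np.sum(w * a_pred**2, axis=(-2, -1)) * np.sum(w * a_true**2, axis=(-2, -1)))
    return num / den


# toy fields on a 32 x 64 grid; a perfect forecast gives ACC = 1 and RMSE = 0
rng = np.random.default_rng(0)
demo_clim = rng.normal(size=(32, 64))
demo_truth = demo_clim + rng.normal(size=(32, 64))
print(weighted_acc(demo_truth, demo_truth, demo_clim), weighted_rmse(demo_truth, demo_truth))
```

Leading dimensions are preserved, so fields shaped `(time, channel, lat, lon)` yield one score per time step and channel, which is the shape of a score curve.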

In [None]:
import matplotlib.pyplot as plt

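# lead time in hours of each roll-out step, starting with the initial condition
t = torch.arange(params.valid_autoreg_steps + 1) * params.dhours
plt.plot(t, acc_curve[0])
plt.xlabel("lead time [h]")
plt.ylabel("ACC")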
plt.show()
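
The time axis above has `valid_autoreg_steps + 1` points: an autoregressive roll-out starts from the initial condition at lead time zero and feeds each prediction back in as the next input, advancing `params.dhours` hours per step (with `n_future = 0`, each step predicts one time step ahead). The sketch below illustrates this bookkeeping with a toy NumPy "model". It is a generic illustration, not Makani's roll-out code; the helper names and the 6-hour step are hypothetical.

```python
import numpy as np


def rollout(step, x0, n_steps):
    """Autoregressive roll-out: apply `step` repeatedly, feeding each output back in.

    Returns an array of shape (n_steps + 1, *x0.shape); index 0 is the initial condition.
    """
    states = [x0]
    x = x0
    for _ in range(n_steps):
        x = step(x)
        states.append(x)
    return np.stack(states)


def lead_times(n_steps, dhours):
    """Lead time in hours for each entry of a roll-out with n_steps steps."""
    return np.arange(n_steps + 1) * dhours


# toy "model" that damps the state by 10 % per step (purely illustrative)
demo_traj = rollout(lambda x: 0.9 * x, np.ones((2, 3)), n_steps=20)
print(demo_traj.shape)                # → (21, 2, 3)
print(lead_times(20, dhours=6)[-1])   # → 120 for a hypothetical 6-hourly model
```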

In [None]:
pred.shape

In [None]:
from makani.utils.visualize import plot_comparison

plot_comparison(pred[-1, 0, 0], truth[-1, 0, 0], diverging=True)