In [None]:
# Automatically reload imported modules before each cell execution, so that edits to
# the project's source files take effect without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import copy
import os

import pyrootutils

# Locate the project root, starting the search at the current working directory and using
# the ".project-root" marker file as indicator. With pythonpath/dotenv enabled the root is
# added to the Python path and the project's .env file is loaded (see pyrootutils.setup_root).
root = pyrootutils.setup_root(
    dotenv=True,
    pythonpath=True,
    indicator=".project-root",
    search_from=os.getcwd(),
)

In [None]:
# Render matplotlib figures inline in the notebook output.
%matplotlib inline

import hydra
import omegaconf

import src.eval
import src.utils
import src.utils.plotting

# Configuration

In [None]:
# Model directories to evaluate. To use an ensemble, either
#   i) list multiple model directories,
#   ii) point to the parent directory of several model directories, or
#   iii) give a glob pattern (e.g. logs/train/multiruns/2023-10-14_*).
model_dir = [
    "logs/train/multiruns/2023-10-13_15-25-56/0",
    "logs/train/multiruns/2023-10-13_15-25-56/1",
]
model_dir = list(map(src.utils.get_absolute_project_path, model_dir))

# NB: must be a relative path, and it is relative to <project_root>/src/utils.
config_path = os.path.join("..", "..", "configs", "eval.yaml")

# Overrides in the same dot notation as CLI overrides. Useful for changing whole modules,
# e.g. which datamodule file is loaded.
config_overrides_dot = [
    "++extras.disable_pytorch_lightning_output=True",
    "++eval.kwargs.show_warnings=False",
]

# Dictionary with overrides. Useful for larger changes/additions/deletions that do not
# exist as entire config files.
config_overrides_dict = {"model_dir": model_dir}

cfg = src.utils.initialize_hydra(
    config_path,
    config_overrides_dot,
    config_overrides_dict,
    print_config=False,  # switch to True to print the config and inspect if all settings are as expected
    return_hydra_config=True,
)

In [None]:
# Restore the saved objects described by the configuration. The model and datamodule
# are required entries; trainer and logger are optional and default to None when absent.
object_dict = src.utils.initialize_saved_objects(cfg)
model = object_dict["model"]
datamodule = object_dict["datamodule"]
trainer = object_dict.get("trainer")
logger = object_dict.get("logger")

In [None]:
# Adjust the evaluation settings; open_dict allows keys that are not yet present
# in the config to be added.
with omegaconf.open_dict(cfg):
    # Forecast horizon and stride passed to the evaluation.
    cfg.eval.kwargs.forecast_horizon = 6
    cfg.eval.kwargs.stride = 6

    # Plot every prediction. The "show" presenter displays figures in the cell output,
    # and "savefig" saves them to the model_dir.
    cfg.eval.plot.every_n_prediction = 1
    cfg.eval.plot.presenter = ["show", "savefig"]

# Evaluate
The `src.eval.run` function returns a dictionary with the configured metrics over the evaluated split, together with a dictionary of additional evaluation objects.

In [None]:
# Evaluate the loaded model with the settings in cfg. Two values are returned: the
# metrics (metric_dict) and a dictionary of further evaluation objects (eval_object_dict).
metric_dict, eval_object_dict = src.eval.run(cfg, datamodule, model, trainer, logger)

## Compare model with baseline

In [None]:
# Derive the baseline configuration from the model configuration.
# A deep copy is required here: DictConfig.copy() is shallow, so nested nodes such as
# cfg.eval.kwargs would be shared with cfg and the retrain override below would also
# change the configuration used to evaluate the trained model.
cfg_baseline = copy.deepcopy(cfg)
with omegaconf.open_dict(cfg_baseline):
    # Drop the trained model's config and directory; the baseline model config is
    # selected through the "model=baseline_naive_seasonal" override below.
    del cfg_baseline.model
    del cfg_baseline.model_dir
    # Only enable retraining when the option is present (and not None) in the eval config.
    if omegaconf.OmegaConf.select(cfg_baseline, "eval.kwargs.retrain") is not None:
        cfg_baseline.eval.kwargs.retrain = True

cfg_baseline = src.utils.initialize_hydra(
    config_path,
    ["model=baseline_naive_seasonal"],
    cfg_baseline,
    return_hydra_config=False,
    print_config=False,  # switch to True to print the config and inspect if all settings are as expected
)

# Instantiate the baseline model from its config.
baseline_model = hydra.utils.instantiate(cfg_baseline.model)

In [None]:
# Disable plotting for both runs and request that predictions are returned: the model run
# also returns the prediction data, the baseline run only the predictions.
with omegaconf.open_dict(cfg), omegaconf.open_dict(cfg_baseline):
    cfg.eval.plot = False
    cfg.eval.predictions = {"return": {"data": True}}

    cfg_baseline.eval.plot = False
    cfg_baseline.eval.predictions = {"return": True}

# Evaluate the trained model first, then the baseline, on the same datamodule.
metric_dict, eval_object_dict = src.eval.run(
    cfg, datamodule, model, trainer, logger
)
baseline_metric_dict, baseline_eval_object_dict = src.eval.run(
    cfg_baseline,
    datamodule,
    baseline_model,
    trainer,
    logger,
)

In [None]:
def metric_string(metrics):
    """Format a metric dictionary as a compact, space-separated string.

    The leading underscore-delimited prefix of each metric name is dropped and the value
    is written in scientific notation with two decimals, e.g. ``{"test_mse": 0.00123}``
    becomes ``"mse=1.23E-03"``. A name that has nothing after the prefix (e.g. no
    underscore at all) is kept unchanged instead of producing an empty label.

    :param metrics: Mapping from metric name to a numeric value.
    :return: One ``name=value`` entry per metric, joined by single spaces.
    """
    return " ".join(
        # partition splits on the first underscore only; fall back to the full name when
        # there is no remainder, so the entry never ends up without a label.
        f"{name.partition('_')[2] or name}={value:.2E}"
        for name, value in metrics.items()
    )


# Plot the model predictions with the model and baseline metrics in the figure title.
# Covariates, encodings, past values and the prediction point are left out of the figure.
fig = src.utils.plotting.plot_prediction(
    eval_object_dict["predictions"],
    eval_object_dict["predictions_data"],
    model,
    None,
    fig_title=f"Model: {metric_string(metric_dict)}; Baseline: {metric_string(baseline_metric_dict)}",
    plot_prediction_point=False,
    plot_past=False,
    plot_encodings=False,
    plot_covariates=False,
    separate_target=False,
)

# Add the baseline predictions to the plot; the return value is assigned to "_" so it
# is not echoed as cell output.
_ = baseline_eval_object_dict["predictions"].plot(label="baseline")