In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
import os

import pyrootutils

# Locate the project root, searching from the current working directory for a
# ".git" or "pyproject.toml" marker. `root` is used below to build log and figure paths.
# NOTE(review): `pythonpath=True` / `dotenv=True` presumably add the root to sys.path
# (making `import src...` work) and load a .env file — confirm against the pyrootutils docs.
root = pyrootutils.setup_root(
    search_from=os.getcwd(),
    indicator=[".git", "pyproject.toml"],
    pythonpath=True,
    dotenv=True,
)

In [None]:
import matplotlib.pyplot as plt
import mlflow
import pandas as pd
import seaborn as sns

# %matplotlib notebook
from omegaconf import OmegaConf, open_dict

import src.eval
import src.utils
import src.utils.plotting

# Inspect Iterative Experiment Results
This notebook helps inspect the results of the iterative data-chunk experiments and collect the material needed for the paper: the figures and the LaTeX table data.

In [None]:
def search_mlflow(
    search_experiment_name,
    mlflow_tracking_uri=None,
    mlflow_file_path=os.path.join(root, "logs", "mlflow", "mlruns"),
):
    """Fetch the MLflow runs of one or more experiments and prettify their model tag.

    Args:
        search_experiment_name: An experiment name, or a list of experiment names, to search.
        mlflow_tracking_uri: Tracking URI to use. When ``None``, a ``file:///`` URI is built
            from ``mlflow_file_path``.
        mlflow_file_path: Path to a local ``mlruns`` directory. Only used (and then required)
            when ``mlflow_tracking_uri`` is ``None``.

    Returns:
        The DataFrame returned by ``mlflow.search_runs``. If it has a ``tags.model`` column,
        every string value in it has "Model" removed and is then mapped to a display name;
        names without a mapping are kept as-is after the removal. Non-string values
        (e.g. NaN for runs logged without a model tag) are left untouched.

    Note:
        Calls ``mlflow.set_tracking_uri`` as a side effect.
    """
    tags_model_to_name = dict(
        XGB="XGBoost",
        TCN="TCN",
        RNN="LSTM",
        Regression="ElasticNet",
        NaiveSeasonal="BaselineSeasonal",
        TCNNoTarget="TCN",
        RNNNoTarget="LSTM",
    )

    def _display_name(model_tag):
        # Runs without a model tag show up as NaN/None; calling .replace on them would raise.
        if not isinstance(model_tag, str):
            return model_tag
        short_name = model_tag.replace("Model", "")
        return tags_model_to_name.get(short_name, short_name)

    if isinstance(search_experiment_name, str):
        search_experiment_name = [search_experiment_name]

    if mlflow_tracking_uri is None:
        assert (
            mlflow_file_path is not None
        ), "Either mlflow_tracking_uri or mlflow_file_path must be given"
        mlflow_tracking_uri = f"file:///{mlflow_file_path}"

    mlflow.set_tracking_uri(mlflow_tracking_uri)
    df = mlflow.search_runs(experiment_names=search_experiment_name)

    # The tag column only exists if at least one returned run carries the tag
    # (e.g. it is absent when the search matches no runs).
    if "tags.model" in df.columns:
        df["tags.model"] = df["tags.model"].apply(_display_name)

    return df

## Performance Metrics

In [None]:
search_experiment_name = "veas_pilot-nowcast-test"
latex_df = search_mlflow(search_experiment_name)

notation = "%.2f"  # printf-style format applied to every number in the table
base = 1
latex_columns = ["tags.model", "metrics.val_mse", "metrics.test_mse", "metrics.test_mae"]

aggregate = "mean"  # central value reported per model (any groupby aggregation name)

# Keep only successfully finished runs and the columns that go into the table
latex_df = latex_df[latex_df["status"] == "FINISHED"]
latex_df = latex_df[latex_columns]
latex_df = latex_df.copy()

# Per model: the central value (chosen by `aggregate`) and the standard deviation
grouped = latex_df.groupby("tags.model")
latex_df_first = grouped.agg(aggregate).reset_index()
latex_df_std = grouped.std().reset_index()

# The std is NaN for a model with a single run; use 0 so that the "±" part is omitted below
latex_df_std = latex_df_std.fillna(0)

# Sort the table by 'metrics.test_mse'
latex_df_first = latex_df_first.sort_values(by="metrics.test_mse")

# Put the std rows in the same order as the sorted table
latex_df_std = latex_df_std.reindex(latex_df_first.index)

# Combine the aggregated value and the std into "<value> $\pm$ <std>"
for col in latex_columns[1:]:  # Skip the 'tags.model' column
    latex_df_first[col] = latex_df_first[col].apply(lambda x: notation % x)
    latex_df_std[col] = latex_df_std[col].apply(lambda x: notation % x)

    # Show only the aggregated value when the formatted std is zero, otherwise the full format
    latex_df_first[col] = latex_df_first[col] + latex_df_std[col].apply(
        lambda std: "" if std == notation % 0 else r" $\pm$ " + std
    )
# Leading column separator, so the rows can be pasted after an existing first column
latex_df_first["tags.model"] = "& " + latex_df_first["tags.model"]

# Convert the dataframe to LaTeX format. The cells already contain raw LaTeX ("&", "$\pm$"),
# so escaping is switched off explicitly instead of relying on the pandas-version-dependent
# default (older pandas versions escape special characters unless told otherwise).
latex_output = latex_df_first.to_latex(index=False, float_format=notation, escape=False)

print(latex_output)

# Lags Input Length History

In [None]:
# Plot the chosen test metric against the model input length (lags), one line per model,
# and save the figure under <root>/figures/<fig_folder_name>/.
src.utils.plotting.set_matplotlib_attributes(font_size=8)
# Configurable metric for y-axis
metric_name = "test_mse"
metric_column = f"metrics.{metric_name}"
# Axis label derived from the metric name: "test_mse" -> "MSE"
metric_plot_name = " ".join(metric_name.replace("test_", "").split("_")).upper()

lags_column = "params.model_lags"
lags_plot_name = "Length of input (hours)"

model_name_column = "tags.model"

# `models` builds the experiment names; `model_order` fixes the hue order using the
# display names produced by search_mlflow
models = ["xgboost", "elastic_net", "rnn", "tcn"]
model_order = ["ElasticNet", "LSTM", "TCN", "XGBoost"]
task = "forecast"

# One lags-test experiment per model
search_experiment_name = [f"veas_pilot-lags_test-veas_pilot_{model}_{task}" for model in models]
df = search_mlflow(search_experiment_name)

# Drop runs that did not log the metric
df = df.loc[df[metric_column].notna()]


# Rename columns for better plotting
df = df.rename(columns={metric_column: metric_plot_name, lags_column: lags_plot_name})

# Sort by input length (lags) numerically
# NOTE(review): errors="ignore" leaves the column unchanged if the cast fails; the sort would
# then not be numeric and the `length / 6` below would fail on strings — confirm all runs
# log an integer-like lags value.
df[lags_plot_name] = df[lags_plot_name].astype(
    int, errors="ignore"
)  # Convert to integer if it's not already
df = df.sort_values(by=lags_plot_name)

# Plotting
plot = sns.lineplot(
    data=df,
    x=lags_plot_name,
    y=metric_plot_name,
    hue=model_name_column,
    # marker="o",
    hue_order=model_order,
)
# set_figure_size(plot.get_figure(), column_span=fig_column_span, height=fig_height)

# Put an x-tick on every 4th unique lag value present in the data
unique_lags_lengths = df[lags_plot_name].unique()[::4]
plot.set_xticks(unique_lags_lengths)

# Convert x-tick labels from count of 10 minutes to hours
plot.set_xticklabels([f"{length / 6:.0f}" for length in unique_lags_lengths])

if task == "forecast":
    # Keep the legend, but remove its title and frame
    legend = plot.legend_
    legend.set_title("")
    legend.set_frame_on(False)
else:
    # For other tasks, replace the legend with one that has no entries
    plt.legend([])

plot.set_title(task.capitalize())

fig_folder_name = "veas_pilot"
# fig_folder_name = "-".join(search_experiment_name[0].split("-")[:-1])
fig_path = os.path.join(root, "figures", fig_folder_name, f"{task}_lags")
src.utils.plotting.set_figure_size(plot.figure, "single", height=6)
src.utils.plotting.save_figure(plot.figure, fig_path)
plt.show()

# Plot model outputs for paper

In [None]:
def load_objects(model_dir, dot_overrides=None):
    """Build the evaluation config for a saved model and instantiate its saved objects.

    Args:
        model_dir: Model directory; it is resolved with
            ``src.utils.hydra.get_absolute_project_path`` before use.
        dot_overrides: Optional extra overrides in CLI dot notation, appended after the
            built-in ones.

    Returns:
        The dictionary returned by ``src.utils.instantiate.instantiate_saved_objects``,
        with the composed config added under the key "cfg".
    """
    # NB: relative to <project_root>/src/utils (must be relative path)
    eval_config_path = os.path.join("..", "..", "configs", "eval.yaml")
    absolute_model_dir = src.utils.hydra.get_absolute_project_path(model_dir)

    # Same notation as for cli overrides (dot notation). Useful for changing whole modules,
    # e.g. change which datamodule file is loaded.
    dot_notation_overrides = [
        "++extras.disable_pytorch_lightning_output=True",
        "++extras.select_gpu=False",
        "++extras.matplotlib_backend=null",
        *([] if dot_overrides is None else dot_overrides),
    ]

    # Dictionary with overrides. Useful for larger changes/additions/deletions that do not
    # exist as entire files.
    dict_overrides = {
        "model_dir": absolute_model_dir,
        "datamodule": {
            "train_val_test_split": {"test": ["2024-01-29 10:30:00", "2024-02-28"]},
        },
        "eval": {
            "split": "test",
            "plot": False,
            "predictions": {"return": {"data": True}},
            "show_warnings": False,
            "kwargs": {"metric": None},
        },
    }

    # Set print_config=True to inspect whether all settings are as expected
    cfg = src.utils.initialize_hydra(
        eval_config_path,
        dot_notation_overrides,
        dict_overrides,
        return_hydra_config=True,
        print_config=False,
    )
    with open_dict(cfg):
        cfg.logger = None

    loaded = src.utils.instantiate.instantiate_saved_objects(cfg)
    loaded["cfg"] = cfg
    return loaded

In [None]:
# Model directory (passed to load_objects as `model_dir`) for each task
model_dirs = dict(
    forecast="logs/train/multiruns/2024-09-04_16-27-29/20",
    nowcast="logs/train/multiruns/2024-09-03_20-02-19/14",
)

# Experiment override passed to load_objects for each task
experiment_overrides = dict(
    forecast=["experiment=veas_pilot_forecast_test"],
    nowcast=["experiment=veas_pilot_nowcast_test"],
)

# Loaded objects per task, in the order forecast -> nowcast
objects = {
    task_name: load_objects(model_dirs[task_name], experiment_overrides[task_name])
    for task_name in ("forecast", "nowcast")
}

In [None]:
# Run the evaluation for every task and keep its returned objects and metrics together
eval_results = {}
for task, object_dict in objects.items():
    # "trainer" and "logger" are optional entries; None is passed when they are missing
    metric_dict, eval_object_dict = src.eval.run(
        object_dict["cfg"],
        object_dict["datamodule"],
        object_dict["model"],
        object_dict.get("trainer"),
        object_dict.get("logger"),
    )
    eval_object_dict["metrics"] = metric_dict
    eval_results[task] = eval_object_dict

In [None]:
%matplotlib inline

def plot_period(
    eval_results,
    start_time,
    end_time,
    every_n_predictions=3,
    fig_path="../figures/pilot/predictions_lstm",
):
    """Plot the forecasts starting within a period together with the nowcast, then save it.

    Args:
        eval_results: Mapping with a "forecast" and a "nowcast" entry. The forecast entry
            must hold "predictions" (an iterable of series exposing ``start_time()``) and
            "predictions_data"; the nowcast entry must hold "predictions" supporting
            ``.slice(start_time, end_time).plot(...)``.
        start_time: Start of the period. A forecast is kept when its start time is
            greater than or equal to this value.
        end_time: End of the period. A forecast is kept when its start time is
            less than or equal to this value.
        every_n_predictions: Keep only every n-th forecast, counted by position in the
            full prediction list (before filtering on the period).
        fig_path: Path handed to ``src.utils.plotting.save_figure``. The default is the
            path that was previously hard-coded; pass a distinct path per period when
            plotting several periods, since otherwise every call saves to the same path.

    Note:
        The forecast model is read from the notebook-level ``objects`` dictionary.
    """
    src.utils.plotting.set_matplotlib_attributes()

    forecasts_in_period = [
        p
        for p_i, p in enumerate(eval_results["forecast"]["predictions"])
        if p_i % every_n_predictions == 0 and start_time <= p.start_time() <= end_time
    ]

    figures = src.utils.plotting.plot_prediction(
        forecasts_in_period,
        eval_results["forecast"]["predictions_data"],
        objects["forecast"]["model"],
        None,
        plot_covariates=False,
        plot_past=False,
    )
    # plot_prediction returns a sequence; the first item is the figure drawn on here
    fig = figures[0]
    ax = fig.axes[0]

    # Overlay the nowcast for the same period on the forecast axes
    eval_results["nowcast"]["predictions"].slice(start_time, end_time).plot(
        ax=ax, color="orange", linestyle="dashed", label="Nowcast"
    )

    # Use the same thin line width for every line on the axes
    for line in ax.lines:
        line.set_linewidth(1)

    ax.set_ylabel("Nitrate concentration [mg/l]")
    ax.set_title("Nowcast and hour-ahead forecasts from LSTM")
    src.utils.plotting.set_figure_size(fig, column_span="double", height=8)
    src.utils.plotting.save_figure(fig, fig_path)
    plt.show()

In [None]:
# One-day window: 2024-02-13 12:00 to 2024-02-14 12:00.
# NOTE: called like this, plot_period always saves to the same path, so any later call
# writes to the same file as this one.
plot_period(eval_results, pd.Timestamp("2024-02-13 12:00:00"), pd.Timestamp("2024-02-14 12:00:00"))

In [None]:
# Three-day window: 2024-02-14 12:00 to 2024-02-17 12:00.
# NOTE: called like this, plot_period always saves to the same path, so this call writes
# to the same file as any earlier call.
plot_period(eval_results, pd.Timestamp("2024-02-14 12:00:00"), pd.Timestamp("2024-02-17 12:00:00"))

In [None]:
import darts.metrics

# Every iteration scores against the same series and reads from the same nowcast
# predictions, so look them up once instead of on every pass through the loop.
actual_series = eval_results["forecast"]["predictions_data"]["series"]
nowcast_predictions = eval_results["nowcast"]["predictions"]

# One MSE value per forecast, for the forecast itself and for the nowcast at its start time
metrics = {"forecast": [], "nowcast": []}

for p in eval_results["forecast"]["predictions"]:
    # MSE of this forecast against the series it is scored on
    metrics["forecast"].append(darts.metrics.mse(actual_series, p, intersect=True))
    # MSE of the nowcast value at the time this forecast starts
    metrics["nowcast"].append(
        darts.metrics.mse(actual_series, nowcast_predictions[p.start_time()], intersect=True)
    )

In [None]:
# Each entry in `metrics` belongs to one forecast, and the nowcast value is looked up at that
# forecast's start time, so the start times are the matching x-values. Taking the first N
# timestamps of the nowcast index instead is only correct if the two line up position by
# position, and fails if the nowcast has fewer points than there are forecasts.
forecast_start_times = [p.start_time() for p in eval_results["forecast"]["predictions"]]

fig, ax = plt.subplots(1, 1, figsize=(20, 5))
ax.plot(forecast_start_times, metrics["forecast"], label="forecast")
ax.plot(forecast_start_times, metrics["nowcast"], label="nowcast")
ax.set_xlabel("Forecast start time")
ax.set_ylabel("MSE")
ax.legend()