In [None]:
import os

import pyrootutils

# Resolve the project root, starting the search at the current working directory and
# using ".git" / "pyproject.toml" as the root indicators. `root` is used below to build
# the path of the evaluation logs.
# NOTE(review): pythonpath=True / dotenv=True presumably add the root to sys.path and load
# a .env file — confirm against the pyrootutils documentation.
root = pyrootutils.setup_root(
    search_from=os.getcwd(),
    indicator=[".git", "pyproject.toml"],
    pythonpath=True,
    dotenv=True,
)

In [None]:
import matplotlib
import matplotlib.pyplot as plt

%matplotlib inline
import ast
import glob
import pickle
import re
import shutil

import numpy as np
import pandas as pd
import seaborn as sns

# %matplotlib notebook
from omegaconf import OmegaConf

import src.utils.plotting

# Inspect Sine Experiment Results
This notebook helps inspect the results of the sine experiments and collect the material needed for the paper: it gathers the figures and prints the LaTeX table data.

In [None]:
# Search criteria used by the run-discovery cell below.
# search_tags: optional collection of tags; checked against the literal stored on the first
#   line of a run's tags.log when the experiment name does not match.
# search_experiment_name: compared for equality with logger.mlflow.experiment_name in each
#   run's Hydra config.
search_tags = None
search_experiment_name = "sine-eval"

In [None]:
# TODO: this notebook uses slow file-based search. Update to use mlflow-search as used in the other notebooks


def extract_model_name(path):
    """Derive a display name for a model from the timestamp embedded in its path.

    The first ``YYYY-MM-DD_HH-MM-SS`` segment found in ``path`` identifies the run. When
    that segment is directly followed by ``/<digits>`` the name is reported as a "Global"
    model including that number; otherwise it is reported as a "ReWTS" model.

    Args:
        path: Path string to inspect.

    Returns:
        ``"Global: <timestamp>/<number>"``, ``"ReWTS: <timestamp>"``, or ``None`` when the
        path contains no timestamp segment.
    """
    found = re.search(r"(\d{4}-\d{2}-\d{2}_\d{2}-\d{2}-\d{2})(?:\/(\d+))?", path)
    if found is None:
        return None

    timestamp, job_number = found.groups()
    if job_number:
        return f"Global: {timestamp}/{job_number}"
    return f"ReWTS: {timestamp}"


# Collect one record per evaluation run that matches the configured search criteria.
matched_runs = []
model_names = []

# Search as soon as at least one criterion is configured. Gating on the experiment name
# alone meant that a tag-only search never inspected any run.
if search_experiment_name is not None or search_tags is not None:
    for filename in glob.iglob(
        str(root) + "/logs/eval/multiruns/2023-09-06_18*/**/.hydra/config.yaml", recursive=True
    ):
        config = OmegaConf.load(filename)
        # <run_dir>/.hydra/config.yaml -> <run_dir>. normpath collapses the ".." segments
        # lexically; the un-normalised path cannot be opened on POSIX because it would
        # traverse a regular file (config.yaml) as if it were a directory.
        run_dir = os.path.normpath(os.path.join(filename, os.pardir, os.pardir))

        # Criterion 1: the mlflow experiment name recorded in the Hydra config.
        match = (
            search_experiment_name is not None
            and config.get("logger", {}).get("mlflow", {}).get("experiment_name", "")
            == search_experiment_name
        )

        # Criterion 2 (only if the first did not match): at least one shared tag.
        if not match and search_tags is not None:
            tags_path = os.path.join(run_dir, "tags.log")
            # A run without a tags file (or with an empty one) simply has no tags.
            if os.path.isfile(tags_path):
                with open(tags_path, "r") as tag_file:
                    first_line = tag_file.readline()
                # The first line is parsed as a Python literal holding the run's tags.
                file_tags = ast.literal_eval(first_line) if first_line.strip() else []
                match = not set(search_tags).isdisjoint(file_tags)

        if not match:
            continue

        model_name = extract_model_name(config["model_dir"])
        eval_results = OmegaConf.load(os.path.join(run_dir, "eval_test_results.yaml"))
        matched_runs.append(
            dict(
                dataset_name=config["datamodule"]["dataset_name"],
                chunk_idx=config["datamodule"].get("chunk_idx", None),
                metrics=eval_results["metrics"],
                model_name=model_name,
                run_path=run_dir,
            )
        )

print(f"found {len(matched_runs)} runs matching searched terms")

In [None]:
# One row per matched run; the nested "metrics" mapping is expanded into one column per metric.
df = pd.DataFrame(matched_runs)
metric_columns = df["metrics"].apply(pd.Series)
df = pd.concat([df.drop(columns="metrics"), metric_columns], axis=1)

In [None]:
# Build the LaTeX results table: one row per (dataset, model family), one column per chunk,
# plus the average over chunks.
notation = "%.2E"  # float format passed to DataFrame.to_latex
base = 1  # every reported value is divided by this factor
dataset_order = ["sine-train", "sine-test"]  # row order of the datasets in the table


def _as_chunk_number(value):
    """Return ``value`` as an int when it denotes an integer chunk index, otherwise None.

    pandas stores a column mixing ints and missing values as float64, so integer-valued
    floats (and numpy integers) are accepted in addition to plain ints. Booleans, NaN,
    strings and everything else are rejected.
    """
    if isinstance(value, bool):
        return None
    if isinstance(value, (int, np.integer)):
        return int(value)
    if isinstance(value, (float, np.floating)) and float(value).is_integer():
        return int(value)
    return None


# Filter out non-integer chunk_idx values
chunk_numbers = df["chunk_idx"].map(_as_chunk_number)
has_chunk = chunk_numbers.notna()
latex_df = df[has_chunk].copy()
latex_df["chunk_idx"] = chunk_numbers[has_chunk].astype(int)

# Reduce the model name to its family prefix (the part before ":"). Runs whose model name
# could not be derived are grouped under "unknown" instead of raising on None.
latex_df["model_name"] = latex_df["model_name"].apply(
    lambda name: name.split(":")[0] if isinstance(name, str) else "unknown"
)

# Pivot the dataframe to create a wide format
latex_df = latex_df.pivot_table(
    index=["dataset_name", "model_name"], columns="chunk_idx", values="test_mse", aggfunc="mean"
).reset_index()

# Reorder so that 'train' dataset comes first; the stable sort keeps the model order
# produced by the pivot within each dataset.
latex_df = latex_df.sort_values(
    by="dataset_name",
    key=lambda column: column.map({name: i for i, name in enumerate(dataset_order)}),
    kind="stable",
)

# Calculate the average metric over chunks
latex_df["avg"] = latex_df.iloc[:, 2:].mean(axis=1)

# Rescale every numeric column (the chunks and their average) by `base`
latex_df.iloc[:, 2:] = latex_df.iloc[:, 2:] / base

# Convert the dataframe to LaTeX format
latex_output = latex_df.to_latex(index=False, float_format=notation)

print(latex_output)

In [None]:
# Load the datamodule config once per dataset, taken from the first matched run that uses it.
dataset_configs = {}
for run in matched_runs:
    if run["dataset_name"] in dataset_configs:
        continue
    hydra_config = OmegaConf.load(os.path.join(run["run_path"], ".hydra", "config.yaml"))
    dataset_configs[run["dataset_name"]] = hydra_config["datamodule"]

# Print two LaTeX table rows per dataset: the amplitude and the frequency of every chunk.
for dataset_name, dataset_config in dataset_configs.items():
    chunk_args = dataset_config["data_args"]
    amplitude_cells = " & ".join([str(chunk["amplitude"]) for chunk in chunk_args])
    frequency_cells = " & ".join([str(chunk["frequency"]) for chunk in chunk_args])

    amplitude_row = f"{dataset_name} & Amplitude $A$ & " + amplitude_cells + " \\\\"
    frequency_row = f"{dataset_name} & Frequency $\\omega$ & " + frequency_cells + " \\\\"

    print(amplitude_row)
    print(frequency_row)

In [None]:
def set_matplotlib_attributes(font_size=8, font="DejaVu Sans"):
    """Apply the seaborn "white" theme with the given font and without top/right spines.

    Args:
        font_size: Value for the ``font.size`` rc parameter.
        font: Value for the ``font.family`` rc parameter.
    """
    rc_overrides = {
        "font.size": font_size,
        "font.family": font,
        "axes.spines.right": False,
        "axes.spines.top": False,
    }
    sns.set_theme(style="white", rc=rc_overrides, font_scale=1)


def set_figure_size(fig, column_span, height=None):
    """Resize ``fig`` to the width of a single or double text column.

    Args:
        fig: Object exposing ``set_size_inches(width, height)`` (a matplotlib figure).
        column_span: ``"single"`` (8.4 cm wide) or ``"double"`` (17.4 cm wide).
        height: Figure height in centimetres. Defaults to 4 cm for a double-column figure
            and 6 cm for a single-column one.

    Raises:
        ValueError: If ``column_span`` is neither ``"single"`` nor ``"double"``. The figure
            is left untouched in that case.
    """
    widths_cm = {"single": 8.4, "double": 17.4}
    try:
        width_cm = widths_cm[column_span]
    except (KeyError, TypeError):
        raise ValueError(
            f"column_span must be one of {sorted(widths_cm)}, got {column_span!r}"
        ) from None

    if height is None:
        height = 4 if column_span == "double" else 6

    cm = 1 / 2.54  # centimetres -> inches
    fig.set_size_inches(width_cm * cm, height * cm)


def plot_concatenated(_fig, _ax, p_data, preds, fontsize=8, chunk_length=500):
    """Plot the full series together with every prediction, labelling the x axis by chunk.

    Args:
        _fig: Figure owning ``_ax``; returned unchanged.
        _ax: Axes to draw on. It becomes the current pyplot axes as a side effect.
        p_data: Mapping whose ``"series"`` entry exposes ``time_index`` and ``all_values()``
            (presumably a darts TimeSeries — verify against the caller).
        preds: Iterable of predictions exposing the same two members.
        fontsize: Font size of the axis label and tick labels.
        chunk_length: Number of time steps per chunk, used to turn tick positions into
            1-based chunk numbers. Defaults to 500, the value that used to be hard-coded.

    Returns:
        The ``(_fig, _ax)`` pair that was passed in.

    Raises:
        ValueError: If ``chunk_length`` is not positive.
    """
    if chunk_length <= 0:
        raise ValueError(f"chunk_length must be positive, got {chunk_length!r}")

    color_cycle = plt.rcParams["axes.prop_cycle"].by_key()["color"]
    plt.sca(_ax)

    series = p_data["series"]
    _ax.plot(series.time_index, np.squeeze(series.all_values()))

    for p_i, p in enumerate(preds):
        _ax.plot(
            p.time_index,
            np.squeeze(p.all_values()),
            # the first colour is taken by the series; predictions cycle through the rest
            color=color_cycle[1 + p_i % (len(color_cycle) - 1)],
        )

    plt.grid(which="major", axis="x")
    plt.xlabel("Chunk #", fontsize=fontsize, fontweight="normal")
    # Keep the automatically chosen tick positions but relabel them with chunk numbers.
    xticks, _ = plt.xticks()
    plt.xticks(
        ticks=xticks,
        labels=[f"{i / chunk_length + 1:.0f}" for i in xticks],
        fontsize=fontsize,
    )
    plt.yticks(fontsize=fontsize)
    # Set the limits last: fixing the ticks above may otherwise widen the view.
    plt.xlim(series.time_index[0], series.time_index[-1])

    return _fig, _ax


def save_figure(fig, path):
    """Save ``fig`` as both ``<path>.pdf`` and ``<path>.png`` with a tight bounding box.

    Args:
        fig: Object exposing ``savefig(path, format=..., bbox_inches=...)``.
        path: Destination path without file extension. Missing parent directories are
            created; a bare file name (no directory part) saves into the current directory.
    """
    directory = os.path.dirname(path)
    # os.makedirs("") raises FileNotFoundError, so only create a directory when there is one.
    if directory:
        os.makedirs(directory, exist_ok=True)
    for extension in ("pdf", "png"):
        fig.savefig(f"{path}.{extension}", format=extension, bbox_inches="tight")

In [None]:
# Select which stored predictions to plot: model ("full" or "ensemble") and dataset split
# ("train" or "test").
model_name = "ensemble"
dataset = "test"

# Placeholder locations of the prediction directories; replace with the actual run paths.
prediction_dirs = dict(
    full=dict(
        train="../path/to/global/train/eval/predictions",
        test="../path/to/global/test/eval/predictions",
    ),
    ensemble=dict(
        train="../path/to/ensemble/train/eval/predictions",
        # Previously pointed at the global model's test run (copy-paste slip); the cells
        # below load the ensemble weights from next to this directory.
        test="../path/to/ensemble/test/eval/predictions",
    ),
)

prediction_dir = prediction_dirs[model_name][dataset]

# NOTE: pickle.load can execute arbitrary code — only load files produced by your own runs.
with open(os.path.join(prediction_dir, "predictions.pkl"), "rb") as f:
    predictions = pickle.load(f)

with open(os.path.join(prediction_dir, "data.pkl"), "rb") as f:
    prediction_data = pickle.load(f)

## Concatenated Chunk Plots

In [None]:
# Double-column figure of the whole series with all predictions overlaid.
fontsize = 8
column_span = "double"

set_matplotlib_attributes(font_size=fontsize)
fig, ax = plt.subplots(1, 1)
set_figure_size(fig, column_span)

fig, ax = plot_concatenated(fig, ax, prediction_data, predictions)

# Saved next to the predictions directory, under "plots".
concat_plot_path = os.path.join(
    prediction_dir, "..", "plots", f"sine_concat_{dataset}_{model_name}_{column_span}_column"
)
save_figure(fig, concat_plot_path)

## Edge effects plot

In [None]:
# Single-column close-up of the series and the predictions that start inside `time_slice`.
fontsize = 8
column_span = "single"
time_slice = [850, 1150]
slice_start, slice_stop = time_slice

set_matplotlib_attributes(font_size=fontsize)
fig, ax = plt.subplots(1, 1)
set_figure_size(fig, column_span)

color_cycle = plt.rcParams["axes.prop_cycle"].by_key()["color"]

# The series is cut positionally with the slice bounds.
series = prediction_data["series"]
ax.plot(
    series.time_index[slice_start:slice_stop],
    np.squeeze(series.all_values()[slice_start:slice_stop]),
)

for p_i, p in enumerate(predictions):
    starts_at = p.start_time()
    starts_in_window = starts_at >= slice_start and starts_at < slice_stop
    if not starts_in_window:
        continue

    # Trim predictions that run past the end of the window.
    visible = p
    if p.end_time() > slice_stop:
        visible, _ = p.split_after(slice_stop)

    ax.plot(
        visible.time_index,
        np.squeeze(visible.all_values()),
        color=color_cycle[1 + p_i % (len(color_cycle) - 1)],
    )

plt.grid(visible=False)
plt.xlabel("Chunk #", fontsize=fontsize, fontweight="normal")

# Two ticks only: the start and the midpoint of the window, labelled with chunk numbers.
xticks = [slice_start, sum(time_slice) // 2]
plt.xticks(
    ticks=xticks,
    labels=[f"{i // (500) + 1:.0f}" for i in xticks],
    fontsize=fontsize,
    fontweight="normal",
)
plt.yticks(fontsize=fontsize)
plt.xlim(*time_slice)

edge_plot_path = os.path.join(
    prediction_dir, "..", "plots", f"sine_edge_{model_name}_{column_span}_column"
)
save_figure(fig, edge_plot_path)

## Concatenated with Ensemble Weights

In [None]:
# Two-row figure: concatenated series/predictions on top, ensemble weights underneath.
# Only meaningful for the ensemble model, whose weights file is loaded below.
assert model_name == "ensemble"

# The weights are read from the parent of the predictions directory.
# NOTE(review): the file name hard-codes "test" regardless of `dataset` — confirm intended.
ensemble_weights = np.load(os.path.join(prediction_dir, "..", "eval_test_weights.npy"))

fontsize = 8
set_matplotlib_attributes(font_size=fontsize)
fig, ax = plt.subplots(2, 1, sharex=False)
column_span = "double"
set_figure_size(fig, column_span, height=10)

# Top row: the same plot as in the "Concatenated Chunk Plots" section.
fig, ax[0] = plot_concatenated(fig, ax[0], prediction_data, predictions)
# ax[0].set_xlabel("")

# Bottom row: ensemble weights at 127 time indices starting at 160 with a step of 30.
# NOTE(review): 160 / 30 / 127 are hard-coded; presumably they mirror the evaluation's
# start, stride and number of weight updates — verify against the evaluation config.
fig = src.utils.plotting.plot_ensemble_weights(
    ensemble_weights, ax[1], time_indices=[160 + 30 * i for i in range(127)]
)

# The pyplot calls below act on the bottom axes.
plt.sca(ax[1])
ax[1].set_ylabel("Model Weights", fontsize=fontsize, fontweight="normal")
ax[1].set_xlabel("", fontsize=fontsize, fontweight="normal")
# NOTE(review): these two calls are repeated right below, where the x ticks are removed
# altogether — they look redundant.
plt.xticks(fontsize=fontsize)
plt.yticks(fontsize=fontsize)

# Remove the x ticks of the bottom axes.
plt.xticks(ticks=[], labels=[], fontsize=fontsize)
plt.yticks(fontsize=fontsize)

plt.grid(visible=False)
# Use the same x range as the series plotted in the top row.
plt.xlim(prediction_data["series"].time_index[0], prediction_data["series"].time_index[-1])

# Lower the second axes of the figure by 0.035 (width and height unchanged).
pos = fig.axes[1].get_position()
fig.axes[1].set_position([pos.bounds[0], pos.bounds[1] - 0.035, pos.bounds[2], pos.bounds[3]])

# Lower the last axes of the figure by the same amount and label it.
# NOTE(review): fig.axes[-1] is presumably the colorbar axes added by
# plot_ensemble_weights — confirm.
pos = fig.axes[-1].get_position()
fig.axes[-1].set_position([pos.bounds[0], pos.bounds[1] - 0.035, pos.bounds[2], pos.bounds[3]])
plt.sca(fig.axes[-1])
plt.xlabel("Model indices", fontweight="normal", fontsize=fontsize)
plt.xticks(fontsize=fontsize, fontweight="normal")


# Saved next to the predictions directory, under "plots".
save_figure(
    fig,
    os.path.join(
        prediction_dir,
        "..",
        "plots",
        f"sine_concat_ensemble_weights_{dataset}_{column_span}_column",
    ),
)