# Methods

In [1]:
import matplotlib.colors as mcolors
from ts_inverse.attack_time_series_utils import SMAPELoss
import os
import json
import torch
import matplotlib.lines as mlines
from matplotlib.transforms import Bbox
from matplotlib.gridspec import GridSpec
import wandb
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import dotenv
import torch.nn.functional as F

dotenv.load_dotenv()


def gather_run_histories(columns, rows, lines, filters, runs, y_axis):
    """Collect the logged histories of the selected runs, grouped for grid plotting.

    A run is selected when it satisfies at least one dictionary in ``filters``:
    every key of that dictionary must be present in ``run.config`` and its
    predicate must return a truthy value for the configured value. Consequently
    an empty ``filters`` list selects no run at all.

    Args:
        columns: Config key whose values define the grid columns.
        rows: Config key whose values define the grid rows.
        lines: Config key whose values define the lines drawn inside one cell.
        filters: List of ``{config_key: predicate}`` dictionaries.
        runs: Iterable of run objects exposing ``config`` and ``history(keys=...)``.
        y_axis: History keys of the form ``"<group>/<metric>"``; for every run they
            are averaged per step into a single series.

    Returns:
        ``(unique_columns, unique_rows, series_dict, title)`` where ``series_dict``
        maps ``(row_index, column_index)`` to ``{line_label: DataFrame}`` holding
        one column per seed.
    """

    def run_is_selected(run):
        # A filter key that is missing from the config counts as a failed predicate.
        return any(
            all(key in run.config and predicate(run.config[key]) for key, predicate in filter_dict.items())
            for filter_dict in filters
        )

    # Select once so the axis labels and the collected series always describe
    # exactly the same set of runs.
    selected_runs = [run for run in runs if run_is_selected(run)]

    unique_columns = sorted({run.config[columns] for run in selected_runs})
    unique_rows = sorted({run.config[rows] for run in selected_runs})
    unique_lines = sorted({run.config[lines] for run in selected_runs})
    print(unique_columns, unique_rows, unique_lines)

    series_dict = {}
    for run in selected_runs:
        i = unique_rows.index(run.config[rows])
        j = unique_columns.index(run.config[columns])

        history = run.history(keys=y_axis).set_index("_step")
        # Collapse all requested metrics into one series, named after the seed.
        mean_series = history.mean(axis=1)
        mean_series.name = run.config["seed"]

        cell = series_dict.setdefault((i, j), {})
        line_label = run.config[lines]
        if line_label not in cell:
            cell[line_label] = mean_series.to_frame()
        else:
            cell[line_label][mean_series.name] = mean_series

    y_lines = [f"{s.split('/')[0]} {s.split('/')[1].upper()}" for s in y_axis]
    title = f'Comparing {", ".join(y_lines)} over {columns} and {rows} for different {lines}'

    return unique_columns, unique_rows, series_dict, title


def plot_metrics_in_grid(unique_columns, unique_rows, series_dict, title, label_prefix, x_limits):
    """Draw one subplot per (row, column) cell with a mean +/- std band per line label.

    Args:
        unique_columns: Column labels; their count sets the number of grid columns.
        unique_rows: Row labels; their count sets the number of grid rows.
        series_dict: Maps ``(row_index, column_index)`` to ``{line_label: DataFrame}``
            where each DataFrame column is one seed.
        title: Figure title; also used in the output file name.
        label_prefix: Text prepended to every legend entry.
        x_limits: Arguments unpacked into ``set_xlim``.

    The figure is written to ``./out/plots/baselines/grid_<title>.pdf`` and shown.
    """
    plot_size = 7
    fig, axes = plt.subplots(
        len(unique_rows),
        len(unique_columns),
        figsize=(plot_size * len(unique_columns), plot_size * len(unique_rows)),
        squeeze=False,  # keep a 2-D array so axes[i, j] also works for a single row or column
    )
    for (i, j), graph_data in series_dict.items():
        ax = axes[i, j]
        for line_label, df_series in graph_data.items():
            # Forward-fill on a copy: the caller's frames must not be modified by plotting.
            filled = df_series.ffill()
            mean_series = filled.mean(axis=1)
            std_series = filled.std(axis=1)
            ax.fill_between(mean_series.index, mean_series - std_series, mean_series + std_series, alpha=0.2)
            ax.plot(mean_series, label=f"{label_prefix}{line_label}")

        # Only the outer edge of the grid carries the row/column labels.
        if j == 0:
            ax.set_ylabel(unique_rows[i])
        if i == 0:
            ax.set_title(unique_columns[j])
        ax.set_ylim(0)
        ax.set_xlim(*x_limits)
        ax.legend()

    fig.suptitle(title)
    plt.tight_layout()
    fig.savefig(f"./out/plots/baselines/grid_{title}.pdf")
    plt.show()


def plot_metrics(unique_columns, unique_rows, series_dict, title, label_prefix, x_limits):
    """Plot one mean +/- std curve per line label, pooled over all grid cells and seeds.

    Every DataFrame stored under the same line label, regardless of its cell, is
    concatenated column-wise on the shared index (inner join) before the mean
    and standard deviation are taken, so a single plot with one line per label
    is produced.

    Args:
        unique_columns: Not used; kept so the signature matches ``plot_metrics_in_grid``.
        unique_rows: Not used; kept so the signature matches ``plot_metrics_in_grid``.
        series_dict: Maps ``(row_index, column_index)`` to ``{line_label: DataFrame}``.
        title: Figure title; also used in the output file name.
        label_prefix: Text prepended to every legend entry.
        x_limits: Arguments unpacked into ``set_xlim``.

    The figure is written to ``./out/plots/baselines/all_<title>.pdf`` and shown.
    """
    fig, ax = plt.subplots(1, 1, figsize=(10, 6))

    # Group the per-cell frames by line label, ignoring which cell they came from.
    frames_by_line = {}
    for graph_data in series_dict.values():
        for line_label, df_series in graph_data.items():
            frames_by_line.setdefault(line_label, []).append(df_series)

    for line_label, frames in frames_by_line.items():
        pooled = pd.concat(frames, axis=1, join="inner")
        mean_series = pooled.mean(axis=1)
        std_series = pooled.std(axis=1)
        ax.fill_between(mean_series.index, mean_series - std_series, mean_series + std_series, alpha=0.2)
        ax.plot(mean_series, label=f"{label_prefix}{line_label}")

    ax.set_xlim(*x_limits)
    ax.legend()
    fig.suptitle(title)
    plt.tight_layout()
    fig.savefig(f"./out/plots/baselines/all_{title}.pdf")
    plt.show()


# CREATE TABLES
def gather_final_metrics_by_parameters(
    columns, rows, metrics, variable, filters, runs, variables_other_sorted=None, specific_seed=None
):
    """Aggregate summary metrics of the selected runs into a mean/std table.

    A run is selected when it satisfies at least one dictionary in ``filters``
    (all keys present in ``run.config`` and all predicates truthy); an empty
    ``filters`` list therefore selects nothing.

    Args:
        columns: Config key used for the first table axis.
        rows: Config key used for the second table axis.
        metrics: Summary keys to read from every run; any number is supported.
        variable: Config key used for the third table axis.
        filters: List of ``{config_key: predicate}`` dictionaries.
        runs: Iterable of runs; a ``list`` is treated as a list of run
            collections and flattened first.
        variables_other_sorted: Optional explicit ordering of the variable
            labels, replacing the sorted labels found in the data.
        specific_seed: When given, only runs with this ``seed`` are used.

    Returns:
        ``(table, unique_columns, unique_rows, unique_variables)``. ``table`` has
        shape ``(columns, rows, variables, 2 * len(metrics))``: the first
        ``len(metrics)`` entries of the last axis are the means, the remaining
        ones the standard deviations. A group with a single run stores ``-1`` as
        its standard deviation, and a group without any run is ``NaN``.
    """

    def clean_label(value):
        # Shorten predictor class names for display.
        return str(value).replace("JitGRU_Predictor", "GRU_Predictor").replace("_Predictor", "")

    if isinstance(runs, list):
        runs = [run for run_list in runs for run in run_list]

    history_dict = {}
    runs_found = 0
    for run in runs:
        if specific_seed is not None and run.config["seed"] != specific_seed:
            continue
        matches_any_filter = any(
            all(key in run.config and predicate(run.config[key]) for key, predicate in filter_dict.items())
            for filter_dict in filters
        )
        if not matches_any_filter:
            continue

        group_key = (
            clean_label(run.config[columns]),
            clean_label(run.config[rows]),
            clean_label(run.config[variable]),
        )
        history_dict.setdefault(group_key, []).append([run.summary[metric] for metric in metrics])
        runs_found += 1

    print(f"Found {runs_found} runs with metrics: {history_dict}")

    unique_columns = sorted({column for column, _, _ in history_dict})
    unique_rows = sorted({row for _, row, _ in history_dict})
    unique_variables = sorted({label for _, _, label in history_dict})
    if variables_other_sorted is not None:
        unique_variables = variables_other_sorted

    print(unique_columns, unique_rows, unique_variables)

    # Slice by the actual metric count instead of assuming exactly two metrics.
    n_metrics = len(metrics)
    table = np.full((len(unique_columns), len(unique_rows), len(unique_variables), 2 * n_metrics), np.nan)
    for i, column in enumerate(unique_columns):
        for j, row in enumerate(unique_rows):
            for k, variable_label in enumerate(unique_variables):
                values = history_dict.get((column, row, variable_label))
                if values is None:
                    continue  # no run for this combination: the cell stays NaN
                table[i, j, k, :n_metrics] = np.mean(values, axis=0)
                if len(values) > 1:
                    table[i, j, k, n_metrics:] = np.std(values, axis=0)
                else:
                    # A single seed has no spread; -1 marks "no std available".
                    table[i, j, k, n_metrics:] = -1
    return table, unique_columns, unique_rows, unique_variables


def print_latex_table_input_target(table, unique_columns, unique_rows, unique_variables, variable_name=""):
    """Print a LaTeX ``table*`` with an Input and a Target value per dataset column.

    Args:
        table: Array indexed as ``[column, row, variable, stat]`` where the stat
            axis holds input mean, target mean, input std and target std.
        unique_columns: Column (dataset) labels.
        unique_rows: Row (model) labels; ``"_Predictor"`` is stripped when printed.
        unique_variables: Labels of the sub-rows printed for every row.
        variable_name: Header text of the variable column.

    The smallest input mean and the smallest target mean of each (column, row)
    group are printed in bold. A NaN input mean prints ``N/A & N/A``. Everything
    is written to stdout; nothing is returned.
    """

    def render(mean, std):
        # Means below 1e-3 switch to scientific notation; the std subscript is
        # only attached when it is positive (non-positive means "no std").
        mean_text = f"{mean:.1E}".replace("E", "e") if mean < 1e-3 else f"{mean:.3f}"
        if std > 0:
            return mean_text + "$_{" + f"{std:.2f}" + "}$"
        return mean_text

    def emphasize(text, is_best):
        return "\\textbf{" + str(text) + "}" if is_best else str(text)

    dataset_headers = " & ".join([f"\\multicolumn{{2}}{{c}}{{{col}}}" for col in unique_columns])
    header_lines = [
        "\\begin{table*}[]",
        "\\centering",
        "\\resizebox{\\linewidth}{!}{",
        "\\begin{tabular}{l|l" + "|cc" * len(unique_columns) + "}",
        "\\toprule",
        # Fixed example header, followed by the generated one marked REPLACE.
        "      & Dataset                 & \\multicolumn{2}{c}{Electricity 370} & \\multicolumn{2}{c}{KDDCup} & \\multicolumn{2}{c}{London Smartmeter} & \\multicolumn{2}{c}{Proprietary} \\\\",
        f"      & REPLACE                 & {dataset_headers} \\\\",
        "\\midrule",
        f"Model & {variable_name} & {'           & '.join(['Input & Target'] * len(unique_columns))} \\\\",
        "\\midrule",
    ]
    print("\n".join(header_lines))

    last_row_index = len(unique_rows) - 1
    for j, row in enumerate(unique_rows):
        for k, variable in enumerate(unique_variables):
            # Only the first sub-row of a group repeats the model name.
            lead = f"{row.replace('_Predictor', '')}\t& {variable}\t" if k == 0 else f"\t& {variable}\t"
            cells = [lead]
            for i in range(len(unique_columns)):
                input_mean, target_mean = table[i, j, k, 0], table[i, j, k, 1]
                input_std, target_std = table[i, j, k, 2], table[i, j, k, 3]
                if np.isnan(input_mean):
                    cells.append("N/A & N/A")
                    continue
                best_input = np.nanmin(table[i, j, :, 0])
                best_target = np.nanmin(table[i, j, :, 1])
                cells.append(
                    emphasize(render(input_mean, input_std), input_mean == best_input)
                    + " & "
                    + emphasize(render(target_mean, target_std), target_mean == best_target)
                )
            print(" & ".join(cells) + " \\\\")
        if j != last_row_index:
            print("\\midrule")

    print("\n".join(["\\bottomrule", "\\end{tabular}}", "\\caption{}", "\\end{table*}"]))


# PLOT RECONSTRUCTIONS
def gather_run_reconstructions(dataset_seed, columns, rows, variables, filters, runs, replace_dict=None):
    """Load one stored reconstruction table per (column, row, variable) combination.

    A run is used only when all of the following hold:
      * its ``dataset`` is a key of ``dataset_seed`` and its ``seed`` equals the
        seed registered there,
      * its summary has at least one key containing ``"dataframe"``,
      * it satisfies at least one dictionary in ``filters`` (all keys present in
        ``run.config`` and all predicates truthy).

    Args:
        dataset_seed: Maps dataset name to the seed whose run should be shown.
        columns: Config key used for the grid columns.
        rows: Config key used for the grid rows.
        variables: Config key used for the entries inside one grid cell.
        filters: List of ``{config_key: predicate}`` dictionaries.
        runs: Iterable of runs; a ``list`` is treated as a list of run
            collections and flattened first.
        replace_dict: Optional ``{old: new}`` substring replacements applied to
            string config values before they are used as labels.

    Returns:
        ``(reconstructed_data_dict, unique_columns, unique_rows, unique_variables)``
        where the dictionary maps ``(column, row, variable)`` labels to the
        DataFrame of the first matching run.
    """
    replacements = {} if replace_dict is None else replace_dict

    def usable_dataframe_keys(run):
        """Return the run's dataframe summary keys, or ``[]`` when the run is skipped."""
        dataset = run.config["dataset"]
        if dataset not in dataset_seed or run.config["seed"] != dataset_seed[dataset]:
            return []
        keys = [key for key in run.summary.keys() if "dataframe" in key]
        if not keys:
            return []
        matches_any_filter = any(
            all(key in run.config and predicate(run.config[key]) for key, predicate in filter_dict.items())
            for filter_dict in filters
        )
        return keys if matches_any_filter else []

    def display_label(run_config, name):
        """Turn the config value stored under ``name`` into its display label."""
        value = run_config[name]
        if isinstance(value, str):
            for old, new in replacements.items():
                value = value.replace(old, new)
        if name == "dataset":
            # Show which seed the plotted sample belongs to.
            return f"{value} ({dataset_seed[run_config[name]]})"
        if name == "model":
            return value.replace("JitGRU_Predictor", "GRU_Predictor").replace("_Predictor", "")
        return value

    if isinstance(runs, list):
        runs = [run for run_list in runs for run in run_list]

    reconstructed_data_dict = {}
    for run in runs:
        dataframe_keys = usable_dataframe_keys(run)
        if not dataframe_keys:
            continue

        group_key = (
            display_label(run.config, columns),
            display_label(run.config, rows),
            display_label(run.config, variables),
        )
        if group_key in reconstructed_data_dict:
            # The first run wins, so later duplicates need not be downloaded or parsed.
            continue

        table_file = run.file(run.summary[dataframe_keys[0]]["path"])
        if not os.path.exists(table_file.name):
            table_file.download(replace=True)
        with open(table_file.name, "r", encoding="utf-8") as handle:
            payload = json.load(handle)
        reconstructed_data_dict[group_key] = pd.DataFrame(data=payload["data"], columns=payload["columns"])

    print(reconstructed_data_dict.keys())

    unique_columns = sorted({column for column, _, _ in reconstructed_data_dict})
    unique_rows = sorted({row for _, row, _ in reconstructed_data_dict})
    unique_variables = sorted({label for _, _, label in reconstructed_data_dict})

    return reconstructed_data_dict, unique_columns, unique_rows, unique_variables


def plot_single_batch_reconstructions_in_grid(
    series_dict,
    unique_columns,
    unique_rows,
    unique_variables,
    extra_title="",
    legend_loc="upper left",
    column_offset=0.0,
    plot_size_width=2,
    plot_size_height=0.33,
    legend_font_size=7,
):
    """Plot original versus reconstructed series in a grid of sub-figures.

    Every (row, column) cell is a sub-figure holding one stacked panel per
    variable. Within a panel the ``batch_*`` columns are drawn as solid lines and
    the ``dummy_*`` columns as dashed lines.

    Args:
        series_dict: Maps ``(column_label, row_label, variable)`` to a DataFrame
            with ``batch_inputs_<b>_0``, ``batch_targets_<b>``,
            ``dummy_inputs_<b>_0`` and ``dummy_targets_<b>`` columns.
        unique_columns: Column labels of the grid.
        unique_rows: Row labels of the grid.
        unique_variables: Labels of the panels inside every cell.
        extra_title: Suffix of the output file name.
        legend_loc: Location of the text-only legend naming the variable.
        column_offset: Vertical offset added to the column label position.
        plot_size_width: Figure width per column.
        plot_size_height: Figure height per panel.
        legend_font_size: Font size of the panel legend.

    Nothing is plotted when ``series_dict`` is empty. Otherwise the figure is
    saved to ``./out/plots/reconstructions/grid_reconstructions_<extra_title>.pdf``
    and shown.
    """
    if not series_dict:
        print("No series given, skipping plotting")
        return

    n_rows, n_cols, n_vars = len(unique_rows), len(unique_columns), len(unique_variables)
    fig = plt.figure(figsize=(plot_size_width * n_cols, plot_size_height * n_rows * n_vars))
    subfigs = fig.subfigures(n_rows, n_cols)
    if n_rows == 1 or n_cols == 1:
        # subfigures() drops singleton dimensions; restore the (row, column) layout.
        subfigs = np.array([subfigs]).reshape(n_rows, n_cols)

    palette = plt.cm.tab10.colors

    for i, row_label in enumerate(unique_rows):
        for j, column_label in enumerate(unique_columns):
            subfig = subfigs[i, j]
            subfig.subplots_adjust(wspace=0, hspace=0, left=0.05, right=0.99, top=0.9, bottom=0.05)
            panels = subfig.subplots(n_vars, 1, sharex=True, sharey=True)
            if n_vars == 1:
                panels = [panels]

            for k, variable in enumerate(unique_variables):
                panel = panels[k]
                cell_key = (column_label, row_label, variable)
                if cell_key in series_dict:
                    reconstructed_data = series_dict[cell_key]
                    # Four columns are read per sample index (see the column names below).
                    n_batches = len(reconstructed_data.columns) // 4

                    # Originals first (solid), then reconstructions (dashed) on top.
                    for b in range(n_batches):
                        originals = reconstructed_data[[f"batch_inputs_{b}_0", f"batch_targets_{b}"]]
                        originals.plot(
                            ax=panel, legend=False, color=[palette[(2 * b) % 10], palette[(2 * b + 2) % 10]]
                        )
                    for b in range(n_batches):
                        reconstructions = reconstructed_data[[f"dummy_inputs_{b}_0", f"dummy_targets_{b}"]]
                        reconstructions.plot(
                            ax=panel,
                            legend=False,
                            style=["--", "--"],
                            color=[palette[(2 * b + 1) % 10], palette[(2 * b + 3) % 10]],
                        )

                    # An invisible handle turns the legend into a plain text label.
                    text_only_handle = mlines.Line2D([], [], color="none", label=variable)
                    panel.legend(
                        handles=[text_only_handle],
                        loc=legend_loc,
                        handlelength=0,
                        handletextpad=0,
                        fancybox=True,
                        fontsize=legend_font_size,
                    )

                panel.set_ylim(-0.1, 1.1)
                panel.set_yticks([])
                panel.set_xticks([])

            if i == 0:
                subfig.text(0.5, 1.0 + column_offset, column_label, ha="center", va="top")
            if j == 0:
                subfig.text(0.0, 0.5, row_label, ha="left", va="center", fontsize=8, rotation=90)

    fig.savefig(f"./out/plots/reconstructions/grid_reconstructions_{extra_title}.pdf", bbox_inches="tight")
    plt.show()


def get_batch_sample_mapping(original_data, dummy_data):
    """Match every original sample to the reconstructed sample closest in L1 loss.

    Args:
        original_data: Tensor whose first dimension indexes the original samples.
        dummy_data: Tensor of reconstructed samples, indexed the same way.

    Returns:
        Integer array ``m`` where ``dummy_data[m[i]]`` is the best match for
        ``original_data[i]``. When two originals pick the same reconstruction the
        matching is discarded and the identity mapping is returned.
    """
    batch_size = original_data.shape[0]
    identity = np.arange(0, batch_size)
    mapping = identity.copy()

    for orig_idx in range(batch_size):
        best_loss = float("inf")
        for dummy_idx in range(batch_size):
            # L1 rather than MSE because it is less sensitive to outliers.
            candidate_loss = F.l1_loss(original_data[orig_idx], dummy_data[dummy_idx]).detach().item()
            if candidate_loss < best_loss:
                best_loss, mapping[orig_idx] = candidate_loss, dummy_idx

    # The mapping is only trusted when it is a permutation.
    is_permutation = len(np.unique(mapping)) == batch_size
    return mapping if is_permutation else identity


def get_batch_sample_mapping_from_dataframe(reconstructed_data, batch_size):
    """Extract original and reconstructed tensors from a table and match their samples.

    Args:
        reconstructed_data: DataFrame with ``batch_inputs_<i>_0``,
            ``dummy_inputs_<i>_0``, ``batch_targets_<i>`` and ``dummy_targets_<i>``
            columns for ``i`` in ``range(batch_size)``.
        batch_size: Number of samples stored in the table.

    Returns:
        ``(sample_mapping, original_inputs, dummy_inputs, original_targets,
        dummy_targets)``. Each tensor has shape ``(batch_size, n_rows, 1)`` where
        ``n_rows`` counts the rows left after dropping NaNs. ``sample_mapping``
        is the input-based matching when that differs from the identity,
        otherwise the target-based matching.
    """
    sample_ids = range(batch_size)

    def column_block(names):
        # Samples are stored as columns: drop padding rows, put samples first
        # and append a trailing singleton dimension.
        return torch.tensor(reconstructed_data[names].dropna().values.T).unsqueeze(-1)

    original_inputs = column_block([f"batch_inputs_{i}_0" for i in sample_ids])
    dummy_inputs = column_block([f"dummy_inputs_{i}_0" for i in sample_ids])
    original_targets = column_block([f"batch_targets_{i}" for i in sample_ids])
    dummy_targets = column_block([f"dummy_targets_{i}" for i in sample_ids])

    identity = np.arange(0, batch_size)
    input_sample_mapping = get_batch_sample_mapping(original_inputs, dummy_inputs)
    target_sample_mapping = get_batch_sample_mapping(original_targets, dummy_targets)

    # Prefer the matching found on the inputs; fall back to the targets only
    # when the inputs produced nothing but the identity.
    if not (identity == input_sample_mapping).all():
        sample_mapping = input_sample_mapping
    elif not (identity == target_sample_mapping).all():
        sample_mapping = target_sample_mapping
    else:
        sample_mapping = np.arange(0, batch_size)

    return sample_mapping, original_inputs, dummy_inputs, original_targets, dummy_targets


def gather_final_metrics_from_reconstructions_by_parameters(
    columns, rows, metric, variable, filters, runs, variables_other_sorted=None
):
    """Recompute a reconstruction metric from every run's stored table and aggregate it.

    For each selected run the first non-quantile dataframe listed in its summary
    is loaded, the reconstructed samples are matched to the originals, and
    ``metric`` is evaluated once on the inputs and once on the targets.

    Args:
        columns: Config key used for the first table axis.
        rows: Config key used for the second table axis.
        metric: Callable ``metric(reconstruction, original)`` whose result has ``.item()``.
        variable: Config key used for the third table axis.
        filters: List of ``{config_key: predicate}`` dictionaries; a run must
            satisfy at least one of them completely.
        runs: Iterable of runs; a ``list`` is treated as a list of run
            collections and flattened first.
        variables_other_sorted: Optional explicit ordering of the variable labels.

    Returns:
        ``(table, unique_columns, unique_rows, unique_variables)`` where the last
        table axis holds input mean, target mean, input std and target std. The
        std is ``-1`` for a group with a single run, and a group without runs
        is ``NaN``.
    """

    def tidy(value):
        # Shorten predictor class names for display.
        return str(value).replace("JitGRU_Predictor", "GRU_Predictor").replace("_Predictor", "")

    if isinstance(runs, list):
        runs = [run for run_list in runs for run in run_list]

    history_dict = {}
    runs_found = 0
    for run in runs:
        # Summary keys naming a dataframe, excluding quantile dataframes.
        dataframe_keys = [key for key in run.summary.keys() if "dataframe" in key and "quantile" not in key]
        if not dataframe_keys:
            continue

        # One flag per filter dictionary; the run is dropped only if all reject it.
        rejections = [
            not np.array([v(run.config[k]) if k in run.config else False for k, v in filter_dict.items()]).all()
            for filter_dict in filters
        ]
        if np.array(rejections).all():
            continue

        group_key = (tidy(run.config[columns]), tidy(run.config[rows]), tidy(run.config[variable]))

        table_file = run.file(run.summary[dataframe_keys[0]]["path"])
        if not os.path.exists(table_file.name):
            table_file.download(replace=True)
        with open(table_file.name, "r") as handle:
            payload = json.load(handle)
        reconstructed_data = pd.DataFrame(data=payload["data"], columns=payload["columns"])

        sample_mapping, original_inputs, dummy_inputs, original_targets, dummy_targets = (
            get_batch_sample_mapping_from_dataframe(reconstructed_data, run.config["batch_size"])
        )

        # Reorder the reconstructions so each one is scored against its matched original.
        input_score = metric(dummy_inputs[sample_mapping], original_inputs).item()
        target_score = metric(dummy_targets[sample_mapping], original_targets).item()
        history_dict.setdefault(group_key, []).append([input_score, target_score])
        runs_found += 1

    print(f"Found {runs_found} runs with reconstructions")
    unique_columns = sorted({column for column, _, _ in history_dict})
    unique_rows = sorted({row for _, row, _ in history_dict})
    unique_variables = sorted({label for _, _, label in history_dict})
    if variables_other_sorted is not None:
        unique_variables = variables_other_sorted

    print(unique_columns, unique_rows, unique_variables)

    # Cells start as NaN and are overwritten wherever runs exist.
    table = np.full((len(unique_columns), len(unique_rows), len(unique_variables), 4), np.nan)
    for i, column in enumerate(unique_columns):
        for j, row in enumerate(unique_rows):
            for k, variable_label in enumerate(unique_variables):
                group_key = (column, row, variable_label)
                if group_key not in history_dict:
                    continue
                scores = history_dict[group_key]
                table[i, j, k, 0:2] = np.mean(scores, axis=0)
                # With a single seed there is no spread; -1 marks "no std available".
                table[i, j, k, 2:4] = np.std(scores, axis=0) if len(scores) > 1 else -1.0
    return table, unique_columns, unique_rows, unique_variables


def plot_multi_batch_reconstructions_in_grid(
    series_dict,
    unique_columns,
    unique_rows,
    unique_batch_sizes,
    extra_title="",
    legend_loc="upper left",
    column_offset=0.0,
    plot_size_width=2,
    plot_size_height=0.33,
    legend_font_size=7,
    external_fig=None,
):
    """Plot every sample of a reconstructed batch against its original, cell by cell.

    Each (row, column) cell is a sub-figure with ``max(unique_batch_sizes)``
    stacked panels. The first batch size for which the cell has data is plotted,
    one sample per panel: originals as solid lines, matched reconstructions as
    dashed lines, inputs followed by targets on a shared x axis. A cell without
    data for any batch size keeps its empty panels.

    Args:
        series_dict: Maps ``(column_label, row_label, batch_size)`` to the
            reconstruction DataFrame of that configuration.
        unique_columns: Column labels of the grid.
        unique_rows: Row labels of the grid.
        unique_batch_sizes: Candidate batch sizes, tried in order for every cell.
        extra_title: Suffix of the output file name.
        legend_loc: Location of the text-only "Sample <n>" legend.
        column_offset: Vertical offset added to the column label position.
        plot_size_width: Figure width per column.
        plot_size_height: Figure height per panel.
        legend_font_size: Font size of the panel legend.
        external_fig: Existing figure to draw into. When given, the figure is
            neither saved nor shown but returned instead.

    Returns:
        ``external_fig`` when one was supplied, otherwise ``None``.
    """
    if not series_dict:
        print("No Dictionary provided, thuse not plotting multi batch reconstructions in grid!")
        return

    if external_fig is not None:
        fig = external_fig
    else:
        fig = plt.figure(
            figsize=(plot_size_width * len(unique_columns), plot_size_height * len(unique_rows) * sum(unique_batch_sizes))
        )
    subfigs = fig.subfigures(len(unique_rows), len(unique_columns))
    # subfigures() drops singleton dimensions; restore the (row, column) layout.
    if len(unique_rows) == 1 and len(unique_columns) == 1:
        subfigs = np.array([[subfigs]])
    elif len(unique_rows) == 1:
        subfigs = np.array([subfigs])
    elif len(unique_columns) == 1:
        subfigs = np.array([subfigs]).T

    n_panels = max(unique_batch_sizes)
    for i, row_label in enumerate(unique_rows):
        for j, column_label in enumerate(unique_columns):
            subfig = subfigs[i, j]
            subfig.subplots_adjust(wspace=0, hspace=0, left=0.05, right=0.99, top=0.9, bottom=0.05)
            # squeeze=False always yields an array, so panels can be indexed by
            # sample even when this cell's batch size differs from n_panels.
            axes = subfig.subplots(n_panels, 1, sharex=True, sharey=True, squeeze=False)[:, 0]

            # First batch size with data for this cell; None leaves the cell empty
            # instead of failing on a missing combination.
            batch_size = next(
                (bs for bs in unique_batch_sizes if (column_label, row_label, bs) in series_dict), None
            )
            if batch_size is not None:
                reconstructed_data = series_dict[(column_label, row_label, batch_size)]
                sample_mapping, batch_inputs, dummy_inputs, batch_targets, dummy_targets = (
                    get_batch_sample_mapping_from_dataframe(reconstructed_data, batch_size)
                )
                n_input_steps = batch_inputs.shape[1]
                x_axis = np.arange(0, n_input_steps + batch_targets.shape[1])
                input_x, target_x = x_axis[:n_input_steps], x_axis[n_input_steps:]

                # b indexes the original sample, d its matched reconstruction.
                for b, d in enumerate(sample_mapping):
                    ax = axes[b]
                    ax.plot(input_x, batch_inputs[b, :].detach().cpu().numpy())
                    ax.plot(input_x, dummy_inputs[d, :].detach().cpu().numpy(), linestyle="--")

                    ax.plot(target_x, batch_targets[b, :].detach().cpu().numpy())
                    ax.plot(target_x, dummy_targets[d, :].detach().cpu().numpy(), linestyle="--")

                    # An invisible handle turns the legend into a plain text label.
                    text_only_handle = mlines.Line2D([], [], color="none", label=f"Sample {b}")
                    ax.legend(
                        handles=[text_only_handle],
                        loc=legend_loc,
                        handlelength=0,
                        handletextpad=0,
                        fancybox=True,
                        fontsize=legend_font_size,
                    )
                    ax.set_ylim(-0.1, 1.1)
                    ax.set_yticks([])
                    ax.set_xticks([])

            if i == 0:
                subfig.text(0.5, 1.0 + column_offset, column_label, ha="center", va="top")
            if j == 0:
                subfig.text(0.0, 0.5, row_label, ha="left", va="center", fontsize=8, rotation=90)

    if external_fig is None:
        fig.savefig(f"./out/plots/reconstructions/grid_reconstructions_{extra_title}.pdf", bbox_inches="tight")
        plt.show()
    else:
        return fig


def gather_run_gradients(dataset_seed, columns, rows, filters, runs, replace_dict=None):
    """Collect the stored gradient lists of one run per (column, row) combination.

    A run is used only when all of the following hold:
      * its ``dataset`` is a key of ``dataset_seed`` and its ``seed`` equals the
        seed registered there,
      * its summary contains ``dummy_grad_list``, ``original_grad_list`` and
        ``split_indexes``,
      * it satisfies at least one dictionary in ``filters`` (all keys present in
        ``run.config`` and all predicates truthy).

    Args:
        dataset_seed: Maps dataset name to the seed whose run should be shown.
        columns: Config key used for the grid columns.
        rows: Config key used for the grid rows.
        filters: List of ``{config_key: predicate}`` dictionaries.
        runs: Iterable of runs; a ``list`` is treated as a list of run
            collections and flattened first.
        replace_dict: Optional ``{old: new}`` substring replacements applied to
            string config values before they are used as labels.

    Returns:
        ``(gradient_data_dict, unique_columns, unique_rows)`` where the dictionary
        maps ``(column, row)`` labels to the three summary lists of the first
        matching run.
    """
    replacements = {} if replace_dict is None else replace_dict
    required_summary_keys = ("dummy_grad_list", "split_indexes", "original_grad_list")

    def should_skip_run(run):
        dataset = run.config["dataset"]
        if dataset not in dataset_seed or run.config["seed"] != dataset_seed[dataset]:
            return True
        if any(key not in run.summary for key in required_summary_keys):
            return True
        matches_any_filter = any(
            all(key in run.config and predicate(run.config[key]) for key, predicate in filter_dict.items())
            for filter_dict in filters
        )
        return not matches_any_filter

    def display_label(run_config, name):
        """Turn the config value stored under ``name`` into its display label."""
        value = run_config[name]
        if isinstance(value, str):
            for old, new in replacements.items():
                value = value.replace(old, new)
        if name == "dataset":
            # Show which seed the plotted gradients belong to.
            return f"{value} ({dataset_seed[run_config[name]]})"
        if name == "model":
            return value.replace("JitGRU_Predictor", "GRU_Predictor").replace("_Predictor", "")
        return value

    if isinstance(runs, list):
        runs = [run for run_list in runs for run in run_list]

    gradient_data_dict = {}
    for run in runs:
        if should_skip_run(run):
            continue

        group_key = (display_label(run.config, columns), display_label(run.config, rows))
        if group_key in gradient_data_dict:
            continue  # the first matching run wins

        gradient_data_dict[group_key] = {
            "dummy_grad_list": run.summary["dummy_grad_list"],
            "original_grad_list": run.summary["original_grad_list"],
            "split_indexes": run.summary["split_indexes"],
        }

    print(gradient_data_dict.keys())

    unique_columns = sorted({column for column, _ in gradient_data_dict})
    unique_rows = sorted({row for _, row in gradient_data_dict})

    return gradient_data_dict, unique_columns, unique_rows


def plot_gradients_in_grid(
    gradients_info_dict,
    unique_columns,
    unique_rows,
    extra_title="",
    legend_loc="upper left",
    column_offset=0.0,
    plot_size_width=2,
    plot_size_height=0.33,
    legend_font_size=7,
):
    """Plot the absolute gradient residual of every (row, column) cell on a log scale.

    The residual is ``|dummy_grad_list - original_grad_list|``. All cells of one
    grid row share the same y-range, derived from the smallest positive and the
    largest residual found in that row. Red dashed vertical lines are drawn at
    the cell's ``split_indexes``.

    Args:
        gradients_info_dict: Maps ``(column_label, row_label)`` to a dictionary
            with ``dummy_grad_list``, ``original_grad_list`` and ``split_indexes``.
        unique_columns: Column labels of the grid.
        unique_rows: Row labels of the grid.
        extra_title: Suffix of the output file name.
        legend_loc: Location of the legend.
        column_offset: Not used by this function.
        plot_size_width: Figure width per column.
        plot_size_height: Figure height per row.
        legend_font_size: Font size of the legend.

    A grid row in which no residual is positive is left empty because nothing
    can be placed on a logarithmic axis. The figure is saved to
    ``./out/plots/reconstructions/grid_gradients_<extra_title>.pdf`` and shown.
    """
    fig, axes = plt.subplots(
        len(unique_rows),
        len(unique_columns),
        figsize=(plot_size_width * len(unique_columns), plot_size_height * len(unique_rows)),
        sharey=True,
        squeeze=False,  # always 2-D, so axes[i, j] works for any grid shape
    )

    for i, row_label in enumerate(unique_rows):
        # Compute every cell's residual once; it is needed for the row range and the plot.
        residuals = {}
        for j, column_label in enumerate(unique_columns):
            key = (column_label, row_label)
            if key in gradients_info_dict:
                grad_dict = gradients_info_dict[key]
                residuals[j] = np.abs(
                    np.array(grad_dict["dummy_grad_list"]) - np.array(grad_dict["original_grad_list"])
                )

        # Zeros cannot be shown on a log axis, so only positive values set the lower bound.
        positive_minima = [diff[diff > 0.0].min() for diff in residuals.values() if (diff > 0.0).any()]
        if not positive_minima:
            continue
        min_value = min(positive_minima)
        max_value = max(diff.max() for diff in residuals.values() if diff.size)

        for j, grad_diff in residuals.items():
            ax = axes[i, j]
            column_label = unique_columns[j]

            split_indexes = gradients_info_dict[(column_label, row_label)]["split_indexes"]
            print(split_indexes)

            x_values = np.arange(len(grad_diff))
            ax.vlines(x_values, min_value, grad_diff, linewidth=0.05, alpha=0.5)
            ax.set_yscale("log")
            # NOTE(review): the lower limit is 100x the smallest positive residual,
            # which hides the smallest values - confirm this is intended.
            ax.set_ylim(min_value * 100, max_value)

            for split_index in split_indexes:
                ax.axvline(x=split_index, color="r", linestyle="--", linewidth=0.8)

            legend_handle = mlines.Line2D([1], [1], color="tab:blue", label="Gradient Residual")
            ax.legend(
                handles=[legend_handle],
                loc=legend_loc,
                handlelength=0,
                handletextpad=0,
                fancybox=True,
                fontsize=legend_font_size,
            )

            if i == 0:
                ax.set_title(column_label)
            if j == 0:
                ax.set_ylabel(row_label)

    fig.savefig(f"./out/plots/reconstructions/grid_gradients_{extra_title}.pdf", bbox_inches="tight")
    plt.show()


def plot_gradients_histogram_in_grid(
    gradients_info_dict,
    unique_columns,
    unique_rows,
    extra_title="",
    legend_loc="upper left",
    column_offset=0.0,
    plot_size_width=2,
    plot_size_height=0.33,
    legend_font_size=7,
):
    """Plot log-binned histograms of the absolute gradient residuals in a grid.

    One panel is drawn per ``(column_label, row_label)`` key found in
    ``gradients_info_dict``; grid positions without an entry are left empty.
    The residual is ``|dummy_grad_list - original_grad_list|``; every residual
    is clamped from below at ``1e-12`` so that it has a finite logarithm. All
    panels of one column share the same 30 log-spaced bin edges, spanning the
    smallest to the largest clamped residual found in that column.

    The figure is written to
    ``./out/plots/reconstructions/grid_gradients_histograms_{extra_title}.pdf``
    and then shown.

    Args:
        gradients_info_dict: Maps ``(column_label, row_label)`` to a dict with
            the entries ``"dummy_grad_list"`` and ``"original_grad_list"``.
        unique_columns: Column labels, in plotting order.
        unique_rows: Row labels, in plotting order.
        extra_title: Suffix of the output file name.
        legend_loc: Not used by this function.
        column_offset: Not used by this function.
        plot_size_width: Figure width contributed by each column.
        plot_size_height: Figure height contributed by each row.
        legend_font_size: Not used by this function.
    """
    n_rows, n_cols = len(unique_rows), len(unique_columns)
    # squeeze=False always yields a 2-D axes array, also for a single row/column.
    fig, axes = plt.subplots(
        n_rows,
        n_cols,
        figsize=(plot_size_width * n_cols, plot_size_height * n_rows),
        sharex=True,
        sharey=True,
        squeeze=False,
    )

    absolute_min = 1e-12

    # Clamp every residual once and reuse it for both the bin edges and the
    # histogram itself. Using the same clamp in both places guarantees that no
    # value lies below the first bin edge (and would silently be left out).
    residuals = {}
    column_min_max = {}
    for column_label in unique_columns:
        min_value, max_value = float("inf"), float("-inf")
        for row_label in unique_rows:
            key = (column_label, row_label)
            if key not in gradients_info_dict:
                continue
            grad_dict = gradients_info_dict[key]
            grad_diff = np.abs(np.array(grad_dict["dummy_grad_list"]) - np.array(grad_dict["original_grad_list"]))
            grad_diff = np.maximum(grad_diff, absolute_min)
            residuals[key] = grad_diff
            min_value = min(min_value, grad_diff.min())
            max_value = max(max_value, grad_diff.max())
        column_min_max[column_label] = (min_value, max_value)

    for i, row_label in enumerate(unique_rows):
        for j, column_label in enumerate(unique_columns):
            key = (column_label, row_label)
            if key not in residuals:
                continue
            ax = axes[i, j]

            min_value, max_value = column_min_max[column_label]
            bins = np.logspace(np.log10(min_value), np.log10(max_value), 30)
            ax.hist(residuals[key], bins=bins, edgecolor="black")
            ax.set_xscale("log")

            # Label only the outer edge of the grid.
            if i == 0:
                ax.set_title(column_label)
            if j == 0:
                ax.set_ylabel(row_label)

    fig.savefig(f"./out/plots/reconstructions/grid_gradients_histograms_{extra_title}.pdf", bbox_inches="tight")
    plt.show()


def plot_sorted_gradients_in_grid(
    gradients_info_dict,
    unique_columns,
    unique_rows,
    extra_title="",
    legend_loc="upper left",
    column_offset=0.0,
    plot_size_width=2,
    plot_size_height=0.33,
    legend_font_size=7,
):
    """Plot the absolute gradient residuals, sorted ascending, in a grid.

    One panel is drawn per ``(column_label, row_label)`` key found in
    ``gradients_info_dict``. The residual ``|dummy_grad_list -
    original_grad_list|`` is sorted and drawn as vertical lines on a log
    y-axis. Each line is coloured by the index range its *original* position
    falls into: the ranges are ``[0, s0), [s0, s1), ...`` for the entries of
    ``"split_indexes"``; positions at or beyond the last split index are grey.

    A colorbar naming the ranges is added on the right when at least one panel
    was drawn. The colormap is sized from the first plotted entry and the
    colorbar labels come from the last plotted entry.
    NOTE(review): this presumes all entries share the same split indexes --
    confirm against the data producer.

    The figure is written to
    ``./out/plots/reconstructions/grid_sorted_gradients_{extra_title}.pdf``
    and then shown.

    Args:
        gradients_info_dict: Maps ``(column_label, row_label)`` to a dict with
            the entries ``"dummy_grad_list"``, ``"original_grad_list"`` and
            ``"split_indexes"``.
        unique_columns: Column labels, in plotting order.
        unique_rows: Row labels, in plotting order.
        extra_title: Suffix of the output file name.
        legend_loc: Not used by this function.
        column_offset: Not used by this function.
        plot_size_width: Figure width contributed by each column.
        plot_size_height: Figure height contributed by each row.
        legend_font_size: Font size of the colorbar tick labels.
    """
    n_rows, n_cols = len(unique_rows), len(unique_columns)
    # squeeze=False always yields a 2-D axes array, also for a single row/column.
    fig, axes = plt.subplots(
        n_rows,
        n_cols,
        figsize=(plot_size_width * n_cols, plot_size_height * n_rows),
        sharey=True,
        squeeze=False,
    )

    color_map, all_colors = None, None
    boundaries = None  # [0] + split_indexes of the most recently plotted entry
    fallback_color = mcolors.to_rgba("tab:gray")

    for i, row_label in enumerate(unique_rows):
        for j, column_label in enumerate(unique_columns):
            key = (column_label, row_label)
            if key not in gradients_info_dict:
                continue
            ax = axes[i, j]
            grad_dict = gradients_info_dict[key]
            grad_diff = np.abs(np.array(grad_dict["dummy_grad_list"]) - np.array(grad_dict["original_grad_list"]))
            split_indexes = grad_dict["split_indexes"]
            if color_map is None:
                # plt.get_cmap replaces matplotlib.cm.get_cmap, which newer
                # Matplotlib releases no longer provide.
                color_map = plt.get_cmap("tab20", len(split_indexes))
                all_colors = color_map(np.arange(color_map.N))

            boundaries = [0] + list(split_indexes)
            sorted_indices = np.argsort(grad_diff)
            sorted_grad_diff = grad_diff[sorted_indices]

            split_colors = {
                range_start: all_colors[idx % len(all_colors)] for idx, range_start in enumerate(boundaries[:-1])
            }

            # Colour every line by the first range containing its original index.
            # One vectorised mask per range instead of a Python loop per element.
            line_colors = np.tile(fallback_color, (len(sorted_indices), 1))
            unassigned = np.ones(len(sorted_indices), dtype=bool)
            for range_start, range_end in zip(boundaries[:-1], boundaries[1:]):
                in_range = unassigned & (sorted_indices >= range_start) & (sorted_indices < range_end)
                line_colors[in_range] = split_colors[range_start]
                unassigned &= ~in_range

            # A single vlines call creates one artist for the whole panel
            # rather than one artist per gradient element.
            x_values = np.arange(len(sorted_grad_diff))
            ax.vlines(x_values, 0, sorted_grad_diff, colors=line_colors, linewidth=0.05, alpha=0.5)

            ax.set_yscale("log")

            if i == 0:
                ax.set_title(column_label)
            if j == 0:
                ax.set_ylabel(row_label)

    if color_map is not None:
        # Create a colorbar axis on the right side of the figure. Skipped when
        # nothing was plotted, because there are no ranges to describe then.
        n_ranges = len(boundaries) - 1
        cbar_ax = fig.add_axes([0.91, 0.15, 0.02, 0.7])  # Adjust the position as needed
        norm = mcolors.Normalize(vmin=0, vmax=n_ranges)
        sm = plt.cm.ScalarMappable(cmap=color_map, norm=norm)
        sm.set_array([])
        cbar = fig.colorbar(sm, cax=cbar_ax)
        cbar.set_ticks(np.arange(n_ranges))
        cbar.set_ticklabels([f"Range {boundaries[idx]}-{boundaries[idx+1]}" for idx in range(n_ranges)])
        cbar.ax.tick_params(labelsize=legend_font_size)

    plt.subplots_adjust(right=0.9)
    fig.savefig(f"./out/plots/reconstructions/grid_sorted_gradients_{extra_title}.pdf", bbox_inches="tight")
    plt.show()


# api = wandb.Api()

In [None]:
# W&B API client plus the two project paths that the cells below pass to api.runs().
api = wandb.Api()
baseline_project_name = "capsar-meijer/ts-inverse_preparation_baselines"
ts_inverse_project_name = "capsar-meijer/ts-inverse_preparation"

# Defenses

In [None]:
# Defenses: LaTeX table of the final input/target sMAPE with datasets as columns,
# defenses as rows and the attack method as the varied quantity.
experiment_names = ["ts-inverse_defenses_5-3-2025a", "baseline_with_defenses_5-3-2025a"]
# The same experiment-name query is sent to both W&B projects.
runs_baselines = api.runs(baseline_project_name, filters={"$or": [{"config.experiment_name": name} for name in experiment_names]}, include_sweeps=False, per_page=100)
runs_ts_inverse = api.runs(ts_inverse_project_name, filters={"$or": [{"config.experiment_name": name} for name in experiment_names]}, include_sweeps=False, per_page=100)
print(f"Found {len(runs_baselines) + len(runs_ts_inverse)} runs for TS-Inverse and Baselines Comparison")

# Each dict maps a run-config key to a predicate on its value.
# NOTE(review): presumably a run is kept when it satisfies at least one dict, as in the
# gather_* helpers visible in this notebook -- confirm in gather_final_metrics_by_parameters.
filters = [
    {
        'experiment_name': lambda x: x == 'ts-inverse_defenses_5-3-2025a',
        'num_attack_steps': lambda x: x == 5000,
    },
    {
        'experiment_name': lambda x: x == "baseline_with_defenses_5-3-2025a",
    },

]
# Attack-method ordering, handed over as variables_other_sorted below.
attack_methods_ordered_list = ['DLG-LBFGS', 'DLG-Adam', 'InvG', 'DIA', 'LTI', 'TS-Inverse']

# NOTE(review): dataset_seed_dict is assigned here but not passed to any call in this cell.
dataset_seed_dict = {
    'electricity_370': 10,
    'kddcup': 10,
    'london_smartmeter': 10,
    'tno_electricity': 28,
}

# Both run collections are passed together as one list.
runs = [runs_baselines, runs_ts_inverse]

metric_table, u_columns, u_rows, u_variables = gather_final_metrics_by_parameters(columns='dataset', rows='defense_name', 
                                                                                                         metrics=['inputs/smape/mean', 'targets/smape/mean'], 
                                                                                                         variable='attack_method', filters=filters, runs=runs, 
                                                                                                         variables_other_sorted=attack_methods_ordered_list)
print_latex_table_input_target(metric_table, u_columns, u_rows, u_variables)

Found 84 runs for TS-Inverse and Baselines Comparison
Found 72 runs with metrics: {('electricity_370', 'sign', 'DLG-LBFGS'): [[1.999232292175293, 1.9856681823730469], [1.9430434703826904, 1.9910142421722412], [1.9460372924804688, 1.99007511138916]], ('electricity_370', 'sign', 'LTI'): [[0.23065468668937683, 0.10833343863487244], [0.3850267827510834, 0.26284897327423096], [0.3558160662651062, 0.17938193678855896]], ('electricity_370', 'sign', 'DLG-Adam'): [[1.9133542776107788, 1.9255656003952024], [1.929749846458435, 1.9654988050460815], [1.9437367916107176, 1.9631785154342651]], ('electricity_370', 'sign', 'InvG'): [[1.6580898761749268, 0.6611918807029724], [1.6771066188812256, 0.948647141456604], [1.680199384689331, 0.8372775912284851]], ('electricity_370', 'sign', 'DIA'): [[1.6580753326416016, 1.0824203491210938], [1.6771039962768557, 1.3625648021697998], [1.6801718473434448, 1.3370813131332395]], ('electricity_370', 'prune', 'DLG-LBFGS'): [[0.03026920929551125, 0.009130466729402542]

# Baseline comparisons

In [None]:
# Baseline comparison: reconstruction grids with attack methods as columns, datasets
# as rows and one variable per model. Three figures are produced from the same runs.
experiment_names = ["baselines_final_18-4-2024", "baselines_final_dia_lr_schedular_fix_23-4-2024"]
runs_baselines = api.runs(baseline_project_name, filters={"$or": [{"config.experiment_name": name} for name in experiment_names]}, include_sweeps=False, per_page=100)
print(f"Found {len(runs_baselines)} runs for Baseline Comparison")

# DIA runs are taken from the "lr_schedular_fix" experiment, every other attack
# method from the original baseline experiment.
filters = [
    {
        'experiment_name': lambda x: x == 'baselines_final_18-4-2024',
        'attack_method': lambda x: x != 'DIA',
        # 'model': lambda x: 'FCN' in x or 'CNN' in x or 'TCN' in x,
    },
    {
        'experiment_name': lambda x: x == 'baselines_final_dia_lr_schedular_fix_23-4-2024',
        'attack_method': lambda x: x == 'DIA',
        # 'model': lambda x: 'FCN' in x or 'CNN' in x or 'TCN' in x,
    }
]
# attack_methods_ordered_list = ['DLG-LBFGS', 'DLG-Adam', 'InvG_TV0', 'DIA', 'LTI']

# Seed to use per dataset; only electricity_370 is enabled for the first two figures.
dataset_seed_dict = {
    'electricity_370': 10,
    # 'kddcup': 10,
    # 'london_smartmeter': 10,
    # 'tno_electricity': 28,
}
# Substring replacements for display labels (passed as replace_dict below).
replace_dict = {
    'electricity_370': 'Electricity 370',
    'london_smartmeter': 'London Sm.',
    'tno_electricity': 'Proprietary',
    'kddcup': 'KDDCup',
    'cosine_dia': 'Cosine',
    'l1': 'L1',
    'InvG_TV0': 'InvG',
}
# Figure 1: columns without DLG-LBFGS.
attack_methods_ordered_list = ['DLG-Adam', 'InvG', 'DIA', 'LTI']

reconstruction_dict, u_columns, u_rows, u_variables  = gather_run_reconstructions(dataset_seed=dataset_seed_dict, columns='attack_method', rows='dataset', variables='model', filters=filters, runs=runs_baselines, replace_dict=replace_dict)
plot_single_batch_reconstructions_in_grid(reconstruction_dict, attack_methods_ordered_list, u_rows, u_variables, extra_title='baselines_attack_method_x_dataset_x_model_fcn_cnn_gru_tcn', legend_loc='lower right', plot_size_width=2.5, plot_size_height=0.35)


# Figure 2: same data, columns including DLG-LBFGS.
attack_methods_ordered_list = ['DLG-LBFGS', 'DLG-Adam', 'InvG', 'DIA', 'LTI']

# NOTE(review): this gather call has exactly the same arguments as the one above;
# reusing the earlier result looks possible -- confirm the helper has no needed side effects.
reconstruction_dict, u_columns, u_rows, u_variables  = gather_run_reconstructions(dataset_seed=dataset_seed_dict, columns='attack_method', rows='dataset', variables='model', filters=filters, runs=runs_baselines, replace_dict=replace_dict)
plot_single_batch_reconstructions_in_grid(reconstruction_dict, attack_methods_ordered_list, u_rows, u_variables, extra_title='baselines_attack_method_x_dataset_x_model_fcn_cnn_gru_tcn_with_lbfgs', legend_loc='lower right', plot_size_width=2.5, plot_size_height=0.45)

## BELOW IS FOR THESIS
# Figure 3: all four datasets enabled.
dataset_seed_dict = {
    'electricity_370': 10,
    'kddcup': 10,
    'london_smartmeter': 10,
    'tno_electricity': 28,
}
attack_methods_ordered_list = ['DLG-LBFGS', 'DLG-Adam', 'InvG', 'DIA', 'LTI']
reconstruction_dict, u_columns, u_rows, u_variables  = gather_run_reconstructions(dataset_seed=dataset_seed_dict, columns='attack_method', rows='dataset', variables='model', filters=filters, runs=runs_baselines, replace_dict=replace_dict)
plot_single_batch_reconstructions_in_grid(reconstruction_dict, attack_methods_ordered_list, u_rows, u_variables, extra_title='all_baselines_attack_method_x_dataset_x_model_fcn_cnn_gru_tcn', legend_loc='lower right', plot_size_width=2.5, plot_size_height=0.42)

In [None]:
# Seq-2-Seq comparison: reconstruction grid with models as columns, datasets as rows
# and one variable per attack method.
experiment_names = ["baselines_attacking_seq2seq_24-5-2024"]
runs_baselines = api.runs(baseline_project_name, filters={"$or": [{"config.experiment_name": name} for name in experiment_names]}, include_sweeps=False, per_page=100)
print(f"Found {len(runs_baselines)} runs for Seq-2-Seq Comparison")

# NOTE(review): the first dict only constrains the experiment name, so every run that
# passes the second dict also passes the first. If the dicts are OR-ed, the second one
# has no effect -- confirm against gather_run_reconstructions.
filters = [
    {
        'experiment_name': lambda x: x == 'baselines_attacking_seq2seq_24-5-2024',
    },
    {
        'experiment_name': lambda x: x == 'baselines_attacking_seq2seq_24-5-2024',
        'attack_method': lambda x: 'LTI' in x,
        'validation_stride': lambda x: x == 1, # Otherwise the dataset is already biased towards only correct interval of day
    }
]

# metric_table, unique_columns, unique_rows, unique_variables = gather_final_metrics_from_reconstructions_by_parameters(columns='dataset', rows='model', 
#                                                                                                          metric=SMAPELoss, 
#                                                                                                          variable='attack_method', filters=filters, runs=runs)
# print_latex_table_input_target(metric_table, unique_columns, unique_rows, unique_variables, variable_name='Attack Method')

# Seed to use per dataset; the commented-out datasets are left out of the figure.
dataset_seed_dict = {
    'electricity_370': 10,
    # 'kddcup': 10,
    # 'london_smartmeter': 28,
    'tno_electricity': 28,
}
# Substring replacements for display labels (passed as replace_dict below).
replace_dict = {
    # 'JitGRU': 'GRU-2-FCN',
    'InvG_TV0': 'InvG',
    'GRU': 'GRU-2-FCN',
    'JitGRU-2-FCN': 'GRU-2-FCN',
    'JitSeq2Seq': 'GRU-2-GRU',
    'electricity_370': 'Elec. 370',
    'tno_electricity': 'Proprietary',
}
reconstruction_dict, u_columns, u_rows, u_variables  = gather_run_reconstructions(dataset_seed=dataset_seed_dict, columns='model', rows='dataset', variables='attack_method', filters=filters, runs=runs_baselines, replace_dict=replace_dict)
plot_single_batch_reconstructions_in_grid(reconstruction_dict, u_columns, u_rows, u_variables, extra_title='attacking_seq2seq_24-05-2024', legend_loc='upper right', plot_size_height=0.35, plot_size_width=2.5)

In [None]:
# Total-variation regularization: a metric table and a reconstruction grid over the
# alpha (inputs) and beta (targets) regularization weights.
experiment_names = ["baselines_invg_tv_regularization_ccn_8-6-2024"]
runs_baselines = api.runs(baseline_project_name, filters={"$or": [{"config.experiment_name": name} for name in experiment_names]}, include_sweeps=False, per_page=100)
print(f"Found {len(runs_baselines)} runs for Total Variation Regularization Comparison")

# Single filter dict: only the experiment name is constrained.
filters = [
    {
        'experiment_name': lambda x: x == 'baselines_invg_tv_regularization_ccn_8-6-2024',
    },
]

# Table: datasets as columns, alpha (inputs) as rows, beta (targets) as the variable.
metric_table, u_columns, u_rows, u_variables = gather_final_metrics_by_parameters(columns='dataset', rows='total_variation_alpha_inputs', 
                                                                                                         metrics=['inputs/smape/mean', 'targets/smape/mean'], 
                                                                                                         variable='total_variation_beta_targets', filters=filters, runs=runs_baselines)
print_latex_table_input_target(metric_table, u_columns, u_rows, u_variables)

# Short display names for the datasets in the figure.
replace_dict = {
    'electricity_370': '370',
    'tno_electricity': 'Prop.',
}

# Seed to use per dataset.
dataset_seed_dict = {
    'electricity_370': 10,
    'tno_electricity': 28,
}
# Figure: beta (targets) as columns, alpha (inputs) as rows, one variable per dataset.
reconstruction_dict, u_columns, u_rows, u_variables  = gather_run_reconstructions(dataset_seed=dataset_seed_dict, columns='total_variation_beta_targets', rows='total_variation_alpha_inputs', variables='dataset', filters=filters, runs=runs_baselines, replace_dict=replace_dict)
plot_single_batch_reconstructions_in_grid(reconstruction_dict, u_columns, u_rows, u_variables, extra_title='invg_total_variation_effects', legend_loc='lower right', plot_size_height=0.5, plot_size_width=3)

# Table with TS-Inverse and baselines

In [None]:
# Table with TS-Inverse and baselines: final input/target sMAPE with datasets as
# columns, models as rows and the attack method as the varied quantity.
# NOTE(review): five experiments are fetched, but the active filters below reference
# only three of them; the other two appear only in the commented-out filter dicts.
experiment_names = ["baselines_final_18-4-2024", "baselines_final_dia_lr_schedular_fix_23-4-2024", "ts-inverse_batch1_with_target_reconst_12-6-2024", "ts-inverse_final_cnn_fcn_tcn_with_dummy_init_prior_31-5-2024", "ts-inverse_batch1_without_target_reconst_12-6-2024"]
# The same experiment-name query is sent to both W&B projects.
runs_baselines = api.runs(baseline_project_name, filters={"$or": [{"config.experiment_name": name} for name in experiment_names]}, include_sweeps=False, per_page=100)
runs_ts_inverse = api.runs(ts_inverse_project_name, filters={"$or": [{"config.experiment_name": name} for name in experiment_names]}, include_sweeps=False, per_page=100)
print(f"Found {len(runs_baselines) + len(runs_ts_inverse)} runs for TS-Inverse and Baselines Comparison")

# Baselines are restricted to models whose name contains FCN, CNN or TCN; DIA runs
# come from the "lr_schedular_fix" experiment, all other baselines from the original one.
filters = [
    {
        'experiment_name': lambda x: x == 'baselines_final_18-4-2024',
        'attack_method': lambda x: x != 'DIA',
        'model': lambda x: 'FCN' in x or 'CNN' in x or 'TCN' in x,
    },
    {
        'experiment_name': lambda x: x == 'baselines_final_dia_lr_schedular_fix_23-4-2024',
        'attack_method': lambda x: x == 'DIA',
        'model': lambda x: 'FCN' in x or 'CNN' in x or 'TCN' in x,
    },
    # {
    #     'experiment_name': lambda x: x == 'ts-inverse_batch1_with_target_reconst_12-6-2024', # L1, FCN, CNN, TCN,
    #     'one_shot_targets': lambda x: x == True,
    # },
    # {
    #     'experiment_name': lambda x: x == 'ts-inverse_final_cnn_fcn_tcn_with_dummy_init_prior_31-5-2024',
    # }, 
    {
        'experiment_name': lambda x: x == 'ts-inverse_batch1_without_target_reconst_12-6-2024',
    }

]
# Attack-method ordering, handed over as variables_other_sorted below.
attack_methods_ordered_list = ['DLG-LBFGS', 'DLG-Adam', 'InvG_TV0', 'DIA', 'LTI', 'TS-Inverse']

# NOTE(review): dataset_seed_dict is assigned here but not passed to any call in this cell.
dataset_seed_dict = {
    'electricity_370': 10,
    'kddcup': 10,
    'london_smartmeter': 10,
    'tno_electricity': 28,
}

# Both run collections are passed together as one list.
runs = [runs_baselines, runs_ts_inverse]

metric_table, u_columns, u_rows, u_variables = gather_final_metrics_by_parameters(columns='dataset', rows='model', 
                                                                                                         metrics=['inputs/smape/mean', 'targets/smape/mean'], 
                                                                                                         variable='attack_method', filters=filters, runs=runs, 
                                                                                                         variables_other_sorted=attack_methods_ordered_list)
print_latex_table_input_target(metric_table, u_columns, u_rows, u_variables)

In [None]:
# TS-Inverse reconstructions: one figure per seed, with datasets as columns, the
# one_shot_targets setting as rows and one variable per model.
experiment_names = ["ts-inverse_final_cnn_fcn_tcn_with_dummy_init_prior_31-5-2024", "ts-inverse_batch1_with_target_reconst_12-6-2024", "ts-inverse_batch1_without_target_reconst_12-6-2024"]
runs_ts_inverse = api.runs(ts_inverse_project_name, filters={"$or": [{"config.experiment_name": name} for name in experiment_names]})
print(f"Found {len(runs_ts_inverse)} runs for TS-Inverse Results on Datasets")

# Active filters: the "with target" experiment restricted to one_shot_targets == True
# and the "without target" experiment restricted to one_shot_targets == False.
# The commented-out dicts are earlier experiment selections kept for reference.
filters = [
    # {
    #     'experiment_name': lambda x: x == 'ts-inverse_final_gru_mse_learn_dilate_opti_1-5-2024', # GRU Results
    #     'inversion_regularization_term': lambda x: x == 0.01,
    # },
    # {
    #     'experiment_name': lambda x: x == 'ts-inverse_final_cnn_fcn_tcn_with_dummy_init_prior_31-5-2024' # L1, Dummy Prior and FCN, CNN, TCN
    # },
    # {
    #     'experiment_name': lambda x: x == 'ts-inverse_final_fcn_cnn_tcn_26-4-2024', # FCN, CNN, TCN Results
    #     'gradient_loss': lambda x: x == 'l1',
    # }
    {
        'experiment_name': lambda x: x == 'ts-inverse_batch1_with_target_reconst_12-6-2024', # L1, FCN, CNN, TCN,
        'one_shot_targets': lambda x: x == True,
    },
    {
        'experiment_name': lambda x: x == 'ts-inverse_batch1_without_target_reconst_12-6-2024',
        'one_shot_targets': lambda x: x == False,
    }
]

# attack_methods_ordered_list = ['TS-Inverse']

for seed in [10, 43, 28, 80, 71]:
    # Every dataset uses the same seed within one figure.
    dataset_seed_dict = {
        'electricity_370': seed,
        'kddcup': seed,
        'london_smartmeter': seed,
        'tno_electricity': seed,
    }
    # Display names. The 'True'/'False' replacements embed the seed, so this dict has
    # to be rebuilt on every iteration (a seed-independent copy before the loop would
    # be overwritten here before it is ever read).
    replace_dict = {
        'electricity_370': 'Electricity 370',
        'london_smartmeter': 'London Smartmeter',
        'tno_electricity': 'Proprietary Dataset',
        'kddcup': 'KDDCup',
        'True': f'With Target ({seed})',
        'False': f'Without Target ({seed})',
    }

    reconstruction_dict, u_columns, u_rows, u_variables  = gather_run_reconstructions(dataset_seed=dataset_seed_dict, columns='dataset', rows='one_shot_targets', variables='model', filters=filters, runs=runs_ts_inverse, replace_dict=replace_dict)
    plot_single_batch_reconstructions_in_grid(reconstruction_dict, u_columns, u_rows, u_variables, extra_title=f'ts-inverse_dataset_x_model_{seed}', legend_loc='lower right', plot_size_height=0.42, plot_size_width=2.5, column_offset=0.05)

# Quantile Plots

In [None]:
import matplotlib.patches as mpatches


def gather_run_quantile_predictions(dataset_seed, columns, rows, variables, filters, runs, replace_dict=None):
    """Collect the logged quantile-prediction tables of the selected runs.

    A run is used when all of the following hold:

    * its ``dataset`` config value is a key of ``dataset_seed`` and its ``seed``
      equals the seed registered for that dataset,
    * its summary has at least one key containing ``"dataframe_quantile"``,
    * it satisfies at least one dict in ``filters`` (or ``filters`` is empty).
      Within one dict every predicate whose key exists in the run config must
      return a truthy value; keys missing from the run config are ignored.

    For a used run, the table referenced by the first matching summary key is
    read from the JSON file at its ``"path"`` (downloaded only if the file does
    not exist locally) and stored under the run's
    ``(column, row, variable)`` label triple. If several runs map to the same
    triple, the first one wins and later ones are not loaded.

    Args:
        dataset_seed: Maps dataset name to the seed whose run should be used.
        columns: Config key providing the column label.
        rows: Config key providing the row label.
        variables: Config key providing the variable label.
        filters: List of dicts mapping a config key to a predicate on its value.
        runs: An iterable of runs, or a ``list`` of such iterables (a list is
            always flattened by one level).
        replace_dict: Optional substring replacements applied to string labels.

    Returns:
        A tuple ``(data, unique_columns, unique_rows, unique_variables, quantiles)``:
        ``data`` maps label triples to DataFrames, the three lists hold the
        sorted distinct labels, and ``quantiles`` is the ``"quantiles"`` config
        value of the last used run (``[-1]`` if it has none or no run was used).
    """
    replacements = {} if replace_dict is None else replace_dict

    def passes_filters(run):
        # An empty filter list places no restriction on the runs.
        if not filters:
            return True
        return any(
            all(predicate(run.config[key]) for key, predicate in filter_dict.items() if key in run.config)
            for filter_dict in filters
        )

    def quantile_dataframe_keys(run):
        """Return the quantile summary keys of ``run``, or ``[]`` if the run must be skipped."""
        dataset = run.config["dataset"]
        if dataset not in dataset_seed or run.config["seed"] != dataset_seed[dataset]:
            return []
        keys = [key for key in run.summary.keys() if "dataframe_quantile" in key]
        if not keys or not passes_filters(run):
            return []
        return keys

    def replace_predictor_name(runs_config, name):
        """Turn the config value stored under ``name`` into its display label."""
        run_config_value = runs_config[name]
        if isinstance(run_config_value, str):
            for old, new in replacements.items():
                run_config_value = run_config_value.replace(old, new)
        if name == "dataset":
            # The seed is looked up with the raw (unreplaced) dataset name.
            return f"{run_config_value} ({dataset_seed[runs_config[name]]})"
        if name == "model":
            return run_config_value.replace("JitGRU_Predictor", "GRU_Predictor").replace("_Predictor", "")
        return run_config_value

    if isinstance(runs, list):
        runs = [run for run_list in runs for run in run_list]

    reconstructed_data_dict = {}
    quantiles = [-1]
    for run in runs:
        dataframe_keys = quantile_dataframe_keys(run)
        if not dataframe_keys:
            continue

        label = (
            replace_predictor_name(run.config, columns),
            replace_predictor_name(run.config, rows),
            replace_predictor_name(run.config, variables),
        )
        quantiles = run.config["quantiles"] if "quantiles" in run.config else [-1]
        if label in reconstructed_data_dict:
            # The first run for a grid cell wins; skip the needless download/parse.
            continue

        dataframe_info = run.summary[dataframe_keys[0]]
        file = run.file(dataframe_info["path"])
        if not os.path.exists(file.name):
            file.download(replace=True)
        with open(file.name, "r", encoding="utf-8") as f:
            data = json.load(f)
        reconstructed_data_dict[label] = pd.DataFrame(data=data["data"], columns=data["columns"])

    print(reconstructed_data_dict.keys())
    if quantiles == [-1]:
        print("No Quantiles found in the run!")

    unique_columns = sorted({column for column, _, _ in reconstructed_data_dict})
    unique_rows = sorted({row for _, row, _ in reconstructed_data_dict})
    unique_variables = sorted({variable for _, _, variable in reconstructed_data_dict})

    return reconstructed_data_dict, unique_columns, unique_rows, unique_variables, quantiles


def plot_single_batch_quantiles_in_grid(
    series_dict,
    unique_columns,
    unique_rows,
    batch_size,
    quantiles,
    extra_title="",
    legend_loc="upper left",
    column_offset=0.0,
    plot_size_width=2,
    plot_size_height=0.33,
    legend_font_size=7,
):
    """Plot batch inputs/targets together with their quantile bands in a grid.

    The figure holds one sub-figure per (row, column) label pair; each
    sub-figure stacks one panel per entry of ``batch_size``. For every
    ``(column_label, row_label, variable)`` key present in ``series_dict`` the
    panel shows, for each batch element ``b`` in ``range(variable)``, the
    columns ``batch_inputs_{b}_0`` and ``batch_targets_{b}`` as lines. On top,
    each symmetric quantile pair ``(q, len(quantiles) - q - 1)`` is drawn as a
    filled band, taken from the columns ``dummy_quantile_inputs_0_0_{q}`` and
    ``dummy_quantile_targets_0_{q}``.
    NOTE(review): the bands are read for batch element 0 only, also when
    ``variable > 1`` -- confirm this is intended.

    Nothing is plotted when ``series_dict`` is empty. Otherwise the figure is
    written to
    ``./out/plots/reconstructions/grid_reconstructions_{extra_title}.pdf``
    and then shown.

    Args:
        series_dict: Maps ``(column_label, row_label, variable)`` to a DataFrame.
        unique_columns: Column labels, in plotting order.
        unique_rows: Row labels, in plotting order.
        batch_size: Sequence of batch sizes; one stacked panel per entry. Each
            entry is also the number of batch elements plotted in its panel.
        quantiles: Quantile levels, used for the band pairing and legend text.
        extra_title: Suffix of the output file name.
        legend_loc: Legend location of every panel.
        column_offset: Vertical offset added to the column-title position.
        plot_size_width: Figure width contributed by each column.
        plot_size_height: Figure height contributed by each stacked panel.
        legend_font_size: Font size of the panel legends.

    Raises:
        AssertionError: If a DataFrame does not have
            ``(2 * len(quantiles) + 2) * variable`` columns.
    """
    if not series_dict:
        print("No series given, skipping plotting")
        return

    n_rows, n_cols, n_panels = len(unique_rows), len(unique_columns), len(batch_size)
    fig = plt.figure(figsize=(plot_size_width * n_cols, plot_size_height * n_rows * n_panels))
    # squeeze=False always yields a 2-D array, also for a single row/column.
    subfigs = fig.subfigures(n_rows, n_cols, squeeze=False)

    colors = plt.cm.tab10.colors
    series_colors = [colors[0], colors[2]]  # inputs, targets
    n_bands = len(quantiles) // 2

    for i, row_label in enumerate(unique_rows):
        for j, column_label in enumerate(unique_columns):
            subfig = subfigs[i, j]
            subfig.subplots_adjust(wspace=0, hspace=0, left=0.05, right=0.99, top=0.95, bottom=0.01)
            panel_axes = subfig.subplots(n_panels, 1, sharex=True, sharey=True, squeeze=False)[:, 0]
            for ax, variable in zip(panel_axes, batch_size):
                key = (column_label, row_label, variable)
                if key in series_dict:
                    reconstructed_data = series_dict[key]
                    assert (
                        len(reconstructed_data.columns) == (2 * len(quantiles) + 2) * variable
                    ), "Expected (2*len(quantiles) + 2)*variable columns in the dataframe"

                    for b in range(variable):
                        df_plot = reconstructed_data[[f"batch_inputs_{b}_0", f"batch_targets_{b}"]]
                        df_plot.plot(ax=ax, legend=False, color=series_colors, linewidth=0.8)

                    quantile_handles = []
                    for q in range(n_bands):
                        lower, upper = q, len(quantiles) - q - 1
                        # Bands use every fourth palette entry starting at 1; the
                        # modulo keeps the index valid for any number of bands.
                        band_color = colors[(4 * q + 1) % len(colors)]
                        # Fill between the paired (opposite) quantile columns.
                        for prefix in ("dummy_quantile_inputs_0_0", "dummy_quantile_targets_0"):
                            ax.fill_between(
                                reconstructed_data.index,
                                reconstructed_data[f"{prefix}_{lower}"],
                                reconstructed_data[f"{prefix}_{upper}"],
                                color=band_color,
                                alpha=0.5,
                            )
                        quantile_handles.append(
                            mpatches.Patch(
                                color=band_color,
                                label=f"Quantile {quantiles[lower]} - {quantiles[upper]}",
                                alpha=0.8,
                            )
                        )
                    ax.legend(handles=quantile_handles, loc=legend_loc, fancybox=True, fontsize=legend_font_size)

                # ax.set_ylim(-0.1, 1.1)
                ax.set_yticks([])
                ax.set_xticks([])

            # Label only the outer edge of the grid.
            if i == 0:
                subfig.text(0.5, 1.0 + column_offset, column_label, ha="center", va="top")
            if j == 0:
                subfig.text(0.0, 0.5, row_label, ha="left", va="center", fontsize=8, rotation=90)

    fig.savefig(f"./out/plots/reconstructions/grid_reconstructions_{extra_title}.pdf", bbox_inches="tight")
    plt.show()


# Show the learned priors in terms of quantiles.
# Three figures are produced from the same two experiments: datasets as columns,
# models as rows and one stacked panel per batch size.
experiment_names = ["ts-inverse_quantile_plots_11-6-2024", "ts-inverse_quantile_plots_13-6-2024"]
runs_ts_inverse = api.runs(
    ts_inverse_project_name, filters={"$or": [{"config.experiment_name": name} for name in experiment_names]}
)
print(f"Found {len(runs_ts_inverse)} runs for Quantile Plots")

# Substring replacements for display labels (passed as replace_dict below).
replace_dict = {
    "electricity_370": "Electricity 370",
    "london_smartmeter": "London Smartmeter",
    "tno_electricity": "Proprietary",
    "kddcup": "KDDCup",
    "cosine_dia": "Cosine",
    "l1": "L1",
    "euclidean": "Euclidean",
    "1_norm_1_cosine": "Cosine + Norm",
    "JitGRU": "GRU-2-FCN",
    "JitSeq2Seq": "GRU-2-GRU",
}
# Seed to use per dataset.
dataset_seed_dict = {
    "electricity_370": 10,
    "london_smartmeter": 10,
    # 'tno_electricity': 10,
}

# Figure 1: the 13-6-2024 experiment on electricity_370, excluding models whose name
# contains "Jit" or "FCN".
filters = [
    {
        "experiment_name": lambda x: x == "ts-inverse_quantile_plots_13-6-2024",
        "model": lambda x: "Jit" not in x and "FCN" not in x,
        "dataset": lambda x: x == "electricity_370",
    }
]

reconstruction_dict, u_columns, u_rows, u_variables, quantiles = gather_run_quantile_predictions(
    dataset_seed=dataset_seed_dict,
    columns="dataset",
    rows="model",
    variables="batch_size",
    filters=filters,
    runs=runs_ts_inverse,
    replace_dict=replace_dict,
)
plot_single_batch_quantiles_in_grid(
    reconstruction_dict,
    u_columns,
    u_rows,
    u_variables,
    quantiles,
    extra_title="ts-inverse_ts_learned_quantiles_batch4",
    plot_size_width=5,
    plot_size_height=0.8,
    legend_loc="best",
    column_offset=0.15,
)


## BELOW IS FOR 1 BATCH and THESIS
# Figure 2: the 11-6-2024 experiment, same model exclusion, all datasets in dataset_seed_dict.
filters = [
    {
        "experiment_name": lambda x: x == "ts-inverse_quantile_plots_11-6-2024",
        "model": lambda x: "Jit" not in x and "FCN" not in x,
    }
]

reconstruction_dict, u_columns, u_rows, u_variables, quantiles = gather_run_quantile_predictions(
    dataset_seed=dataset_seed_dict,
    columns="dataset",
    rows="model",
    variables="batch_size",
    filters=filters,
    runs=runs_ts_inverse,
    replace_dict=replace_dict,
)
plot_single_batch_quantiles_in_grid(
    reconstruction_dict,
    u_columns,
    u_rows,
    u_variables,
    quantiles,
    extra_title="ts-inverse_ts_learned_quantiles",
    plot_size_width=3.5,
    plot_size_height=0.8,
    legend_loc="best",
    column_offset=0.15,
)


## BELOW IS FOR THESIS
# Figure 3: the 11-6-2024 experiment with every model and tno_electricity added.
# NOTE(review): this figure uses the same extra_title as figure 2, so it is saved to
# the same PDF path and replaces that file -- confirm this is intended.
dataset_seed_dict = {
    "electricity_370": 10,
    "london_smartmeter": 10,
    "tno_electricity": 10,
}

filters = [
    {
        "experiment_name": lambda x: x == "ts-inverse_quantile_plots_11-6-2024",
    }
]

reconstruction_dict, u_columns, u_rows, u_variables, quantiles = gather_run_quantile_predictions(
    dataset_seed=dataset_seed_dict,
    columns="dataset",
    rows="model",
    variables="batch_size",
    filters=filters,
    runs=runs_ts_inverse,
    replace_dict=replace_dict,
)
plot_single_batch_quantiles_in_grid(
    reconstruction_dict,
    u_columns,
    u_rows,
    u_variables,
    quantiles,
    extra_title="ts-inverse_ts_learned_quantiles",
    plot_size_width=3.5,
    plot_size_height=0.8,
    legend_loc="best",
    column_offset=0.15,
)

In [None]:
# Quantile plots under defenses: attack methods as columns, defenses as rows and one
# stacked panel per batch size, for FCN models on electricity_370.
experiment_names = ["ts-inverse_defenses_4-3-2025a"]
runs_ts_inverse = api.runs(ts_inverse_project_name, filters={"$or": [{"config.experiment_name": name} for name in experiment_names]})
filters = [
    {
        'experiment_name': lambda x: x == "ts-inverse_defenses_4-3-2025a",
        'model': lambda x: 'FCN' in x,
        'dataset': lambda x: x == 'electricity_370',
    }
]

print(len(runs_ts_inverse))

# NOTE: dataset_seed_dict and replace_dict are not assigned in this cell; the values
# left behind by a previously executed cell are used.
reconstruction_dict, u_columns, u_rows, u_variables, quantiles  = gather_run_quantile_predictions(dataset_seed=dataset_seed_dict, columns='attack_method', rows='defense_name', variables='batch_size', filters=filters, runs=runs_ts_inverse, replace_dict=replace_dict)
plot_single_batch_quantiles_in_grid(reconstruction_dict, u_columns, u_rows, u_variables, quantiles, extra_title='ts-inverse_defenses', plot_size_width=5, plot_size_height=0.8, legend_loc='best', column_offset=0.15)

# Quantile Regularization

Comparison between two regularization techniques that use the quantiles. The first one, "quantile", uses the pinball loss between the quantiles and the dummy data. The second one, "quantile$_{\text{bounds}}$", regularizes the dummy data with the L1 loss whenever a value falls outside the quantile bounds.

In [None]:
# Show the learned priors in terms of quantiles.
experiment_names = ["ts-inverse_quantile_bounds_regularization_11-6-2024"]
runs_ts_inverse = api.runs(
    ts_inverse_project_name,
    filters={"$or": [{"config.experiment_name": experiment} for experiment in experiment_names]},
)
print(f"Found {len(runs_ts_inverse)} runs for Learned Prior Regularization")

# First table: only the runs that use the `quantile_bounds` loss, split per dataset.
filters = [
    {
        "experiment_name": lambda value: value == "ts-inverse_quantile_bounds_regularization_11-6-2024",
        "inversion_regularization_loss": lambda value: value == "quantile_bounds",
    }
]

metric_table, u_columns, u_rows, u_variables = gather_final_metrics_by_parameters(
    columns="inversion_regularization_term_inputs",
    rows="dataset",
    metrics=["inputs/smape/mean", "targets/smape/mean"],
    variable="inversion_regularization_term_targets",
    filters=filters,
    runs=runs_ts_inverse,
)
print_latex_table_input_target(metric_table, u_columns, u_rows, u_variables, variable_name="Loss")


# Second table: every run of the experiment, with the regularization loss as
# the varying quantity.
filters = [
    {
        "experiment_name": lambda value: value == "ts-inverse_quantile_bounds_regularization_11-6-2024",
    }
]

metric_table, u_columns, u_rows, u_variables = gather_final_metrics_by_parameters(
    columns="inversion_regularization_term_inputs",
    rows="inversion_regularization_term_targets",
    metrics=["inputs/smape/mean", "targets/smape/mean"],
    variable="inversion_regularization_loss",
    filters=filters,
    runs=runs_ts_inverse,
)
print_latex_table_input_target(metric_table, u_columns, u_rows, u_variables, variable_name="Loss")

# Display labels handed to the gathering helper as `replace_dict`.
replace_dict = {
    "electricity_370": "Electricity 370",
    "london_smartmeter": "London Smartmeter",
    "tno_electricity": "Proprietary",
    "kddcup": "KDDCup",
    "cosine_dia": "Cosine",
    "l1": "L1",
    "euclidean": "Euclidean",
    "1_norm_1_cosine": "Cosine + Norm",
}
dataset_seed_dict = {"electricity_370": 10}

# Reconstruction grid for the runs regularized with the `quantile_bounds` loss.
filters = [
    {
        "experiment_name": lambda value: value == "ts-inverse_quantile_bounds_regularization_11-6-2024",
        "inversion_regularization_loss": lambda value: value == "quantile_bounds",
    }
]
reconstruction_dict, u_columns, u_rows, u_variables = gather_run_reconstructions(
    dataset_seed=dataset_seed_dict,
    columns="inversion_regularization_term_targets",
    rows="inversion_regularization_term_inputs",
    variables="batch_size",
    filters=filters,
    runs=runs_ts_inverse,
    replace_dict=replace_dict,
)
plot_multi_batch_reconstructions_in_grid(
    reconstruction_dict,
    u_columns,
    u_rows,
    u_variables,
    extra_title="ts-inverse_ts_learned_quantile_bounds_regularization",
    plot_size_width=2.5,
    plot_size_height=0.45,
    legend_loc="best",
)
# ## ^^ IN THE PAPER

# # ## BELOW IS FOR THESIS

# Same grid for the runs regularized with the plain `quantile` loss.
filters = [
    {
        "experiment_name": lambda value: value == "ts-inverse_quantile_bounds_regularization_11-6-2024",
        "inversion_regularization_loss": lambda value: value == "quantile",
    }
]
reconstruction_dict, u_columns, u_rows, u_variables = gather_run_reconstructions(
    dataset_seed=dataset_seed_dict,
    columns="inversion_regularization_term_targets",
    rows="inversion_regularization_term_inputs",
    variables="batch_size",
    filters=filters,
    runs=runs_ts_inverse,
    replace_dict=replace_dict,
)
plot_multi_batch_reconstructions_in_grid(
    reconstruction_dict,
    u_columns,
    u_rows,
    u_variables,
    extra_title="ts-inverse_ts_learned_quantile_regularization",
    plot_size_width=2.5,
    plot_size_height=0.45,
    legend_loc="best",
)

# TS Regularizations

## Periodicity Regularization

In [None]:
# IN THE PAPER
experiment_names = ["ts-inverse_ts_regularization_11-6-2024"]
runs_ts_inverse = api.runs(
    ts_inverse_project_name,
    filters={"$or": [{"config.experiment_name": experiment} for experiment in experiment_names]},
)
print(f"Found {len(runs_ts_inverse)} runs for TS Regularization")

# Periodicity study: trend term fixed to 0 and only the periodicity losses
# whose name contains "mean".
filters = [
    {
        "experiment_name": lambda value: value == "ts-inverse_ts_regularization_11-6-2024",
        "trend_loss": lambda value: value == "l1_mean",
        "trend_term": lambda value: value == 0,
        "periodicity_loss": lambda value: "mean" in value,
    }
]

# Earlier variant of the table call, kept for reference:
# metric_table, u_columns, u_rows, u_variables = gather_final_metrics_by_parameters(
#     columns='dataset', rows='periodicity_loss',
#     metrics=['inputs/smape/mean', 'targets/smape/mean'],
#     variable='periodicity_term', filters=filters, runs=runs_ts_inverse)
metric_table, u_columns, u_rows, u_variables = gather_final_metrics_from_reconstructions_by_parameters(
    columns="dataset",
    rows="periodicity_loss",
    metric=SMAPELoss,
    variable="periodicity_term",
    filters=filters,
    runs=runs_ts_inverse,
)

print_latex_table_input_target(metric_table, u_columns, u_rows, u_variables, variable_name="Reg. Term.")


# Hand-picked runs for the side-by-side figure: per dataset, the run with
# periodicity term 0 plus one selected non-zero term.
filters = [
    {
        "experiment_name": lambda value: value == "ts-inverse_ts_regularization_11-6-2024",
        "trend_loss": lambda value: value == "l1_mean",
        "trend_term": lambda value: value == 0,
        "periodicity_loss": lambda value: value == "l1_mean",
        "dataset": lambda value: value == "electricity_370",
        "periodicity_term": lambda value: value == 2 or value == 0,
    },
    {
        "experiment_name": lambda value: value == "ts-inverse_ts_regularization_11-6-2024",
        "trend_loss": lambda value: value == "l1_mean",
        "trend_term": lambda value: value == 0,
        "periodicity_loss": lambda value: value == "l1_mean",
        "dataset": lambda value: value == "london_smartmeter",
        "periodicity_term": lambda value: value == 0.5 or value == 0,
    },
    # Disabled entry for the third dataset:
    # {
    #     'experiment_name': lambda x: x == 'ts-inverse_ts_regularization_11-6-2024',
    #     'trend_loss': lambda x: x == 'l1_mean',
    #     'trend_term': lambda x: x == 0,
    #     'periodicity_loss': lambda x: x == 'l1_mean',
    #     'dataset': lambda x: x == 'tno_electricity',
    #     'periodicity_term': lambda x: x == 1,
    # }
]
# The loss name maps to an empty label here; only the datasets get a display name.
replace_dict = {
    "l1_mean": "",
    "electricity_370": "Electricity 370",
    "london_smartmeter": "London Smartmeter",
}
dataset_seed_dict = {
    "electricity_370": 43,
    "london_smartmeter": 10,
}
def plot_multiple_larger_columns_multi_batch_next_to_each_other(dataset_seed_dict):
    """Draw one periodicity reconstruction grid per dataset, side by side, and save the figure.

    One subfigure is created per entry of ``dataset_seed_dict``. Each is filled by
    ``plot_multi_batch_reconstructions_in_grid`` with the reconstructions gathered
    for that single dataset/seed pair and captioned with ``"<label> (<seed>)"``.
    The combined figure is written to
    ``./out/plots/reconstructions/grid_reconstructions_ts-inverse_ts_periodicity_regularization_specific.pdf``
    and then shown.

    The module-level ``filters``, ``runs_ts_inverse`` and ``replace_dict`` are read
    at call time, so the output depends on the notebook state when this is invoked.

    Args:
        dataset_seed_dict: Mapping of dataset name to the seed to plot for it.
    """
    n_datasets = len(dataset_seed_dict)
    fig = plt.figure(figsize=(3 * n_datasets, 3.8))
    # squeeze=False always yields an array; with the default, a single dataset
    # would return a bare SubFigure and the indexing below would fail.
    subfigs = fig.subfigures(1, n_datasets, squeeze=False)[0]
    for i, (dataset, seed) in enumerate(dataset_seed_dict.items()):
        reconstruction_dict, u_columns, u_rows, u_variables = gather_run_reconstructions(
            dataset_seed={dataset: seed},
            columns='periodicity_loss',
            rows='periodicity_term',
            variables='batch_size',
            filters=filters,
            runs=runs_ts_inverse,
            replace_dict=replace_dict,
        )
        subfig = plot_multi_batch_reconstructions_in_grid(
            reconstruction_dict,
            u_columns,
            u_rows,
            u_variables,
            extra_title='ts-inverse_ts_periodicity_regularization_specific',
            plot_size_width=2.5,
            plot_size_height=0.45,
            legend_loc='best',
            external_fig=subfigs[i],
        )
        # Fall back to the raw dataset name when no display label is configured.
        subfig.text(0.5, 1.01, f"{replace_dict.get(dataset, dataset)} ({seed})", ha='center', va='top')
    fig.savefig(
        './out/plots/reconstructions/grid_reconstructions_ts-inverse_ts_periodicity_regularization_specific.pdf',
        bbox_inches='tight',
    )
    plt.show()

# Render the side-by-side figure for the datasets/seeds selected above.
plot_multiple_larger_columns_multi_batch_next_to_each_other(dataset_seed_dict)

## BELOW IS FOR THESIS
# Back to the broad selection: all periodicity terms, losses containing "mean".
filters = [
    {
        "experiment_name": lambda value: value == "ts-inverse_ts_regularization_11-6-2024",
        "trend_loss": lambda value: value == "l1_mean",
        "trend_term": lambda value: value == 0,
        "periodicity_loss": lambda value: "mean" in value,
    }
]
replace_dict = {
    "electricity_370": "Electricity 370",
    "london_smartmeter": "London Smartmeter",
    "tno_electricity": "Proprietary",
    "kddcup": "KDDCup",
    "cosine_dia": "Cosine",
    "l1_mean": "L1",
    "l2_mean": "L2",
    "euclidean": "Euclidean",
    "1_norm_1_cosine": "Cosine + Norm",
}

dataset_seed_dict = {"electricity_370": 10}
reconstruction_dict, u_columns, u_rows, u_variables = gather_run_reconstructions(
    dataset_seed=dataset_seed_dict,
    columns="periodicity_term",
    rows="periodicity_loss",
    variables="batch_size",
    filters=filters,
    runs=runs_ts_inverse,
    replace_dict=replace_dict,
)
plot_multi_batch_reconstructions_in_grid(
    reconstruction_dict,
    u_columns,
    u_rows,
    u_variables,
    extra_title="ts-inverse_ts_periodicity_regularization",
    plot_size_width=2.5,
    plot_size_height=0.45,
    legend_loc="best",
)

# Seeds for the large multi-dataset figure that follows.
dataset_seed_dict = {
    "electricity_370": 10,
    "london_smartmeter": 10,
    "tno_electricity": 10,
}
#PLOT ALL TERMS AND LOSSES ON DATASETS (specific seeds)
def plot_multiple_larger_columns_multi_batch_next_to_each_other(dataset_seed_dict):
    """Draw the large periodicity reconstruction grids, one per dataset, and save the figure.

    One subfigure is created per entry of ``dataset_seed_dict``. Each is filled by
    ``plot_multi_batch_reconstructions_in_grid`` with the reconstructions gathered
    for that single dataset/seed pair and captioned with ``"<label> (<seed>)"``.
    The combined figure is written to
    ``./out/plots/reconstructions/grid_reconstructions_ts-inverse_ts_periodicity_regularization_large.pdf``
    and then shown.

    The module-level ``filters``, ``runs_ts_inverse`` and ``replace_dict`` are read
    at call time, so the output depends on the notebook state when this is invoked.

    Args:
        dataset_seed_dict: Mapping of dataset name to the seed to plot for it.
    """
    n_datasets = len(dataset_seed_dict)
    fig = plt.figure(figsize=(5 * n_datasets, 10))
    # squeeze=False always yields an array; with the default, a single dataset
    # would return a bare SubFigure and the indexing below would fail.
    subfigs = fig.subfigures(1, n_datasets, squeeze=False)[0]
    for i, (dataset, seed) in enumerate(dataset_seed_dict.items()):
        reconstruction_dict, u_columns, u_rows, u_variables = gather_run_reconstructions(
            dataset_seed={dataset: seed},
            columns='periodicity_loss',
            rows='periodicity_term',
            variables='batch_size',
            filters=filters,
            runs=runs_ts_inverse,
            replace_dict=replace_dict,
        )
        subfig = plot_multi_batch_reconstructions_in_grid(
            reconstruction_dict,
            u_columns,
            u_rows,
            u_variables,
            extra_title='ts-inverse_ts_periodicity_regularization_large',
            plot_size_width=2.5,
            plot_size_height=0.45,
            legend_loc='best',
            external_fig=subfigs[i],
        )
        # Fall back to the raw dataset name when no display label is configured.
        subfig.text(0.5, 1.01, f"{replace_dict.get(dataset, dataset)} ({seed})", ha='center', va='top')
    fig.savefig(
        './out/plots/reconstructions/grid_reconstructions_ts-inverse_ts_periodicity_regularization_large.pdf',
        bbox_inches='tight',
    )
    plt.show()

plot_multiple_larger_columns_multi_batch_next_to_each_other(dataset_seed_dict)

## Trend Regularization

In [None]:
# IN THE PAPER
experiment_names = ["ts-inverse_ts_regularization_11-6-2024"]
runs_ts_inverse = api.runs(ts_inverse_project_name, filters={"$or": [{"config.experiment_name": name} for name in experiment_names]})
print(f"Found {len(runs_ts_inverse)} runs for TS Regularization")

# Trend study: periodicity term fixed to 0 and every trend loss whose name
# does not contain "sum".
filters = [
    {
        'experiment_name': lambda x: x == 'ts-inverse_ts_regularization_11-6-2024',
        'trend_loss': lambda x: 'sum' not in x,
        'periodicity_term': lambda x: x == 0,
    }
]

metric_table, u_columns, u_rows, u_variables = gather_final_metrics_by_parameters(
    columns='dataset',
    rows='trend_loss',
    metrics=['inputs/smape/mean', 'targets/smape/mean'],
    variable='trend_term',
    filters=filters,
    runs=runs_ts_inverse,
)
print_latex_table_input_target(metric_table, u_columns, u_rows, u_variables, variable_name='Reg. Term.')


# Hand-picked runs for the side-by-side trend figure: one trend term per dataset.
filters = [
    {
        "experiment_name": lambda value: value == "ts-inverse_ts_regularization_11-6-2024",
        "trend_loss": lambda value: value == "l1_mean",
        "dataset": lambda value: value == "electricity_370",
        "trend_term": lambda value: value == 2,
    },
    {
        "experiment_name": lambda value: value == "ts-inverse_ts_regularization_11-6-2024",
        "trend_loss": lambda value: value == "l1_mean",
        "dataset": lambda value: value == "london_smartmeter",
        "trend_term": lambda value: value == 0.5,
    },
    # Disabled entry for the third dataset:
    # {
    #     'experiment_name': lambda x: x == 'ts-inverse_ts_regularization_11-6-2024',
    #     'trend_loss': lambda x: x == 'l1_mean',
    #     'dataset': lambda x: x == 'tno_electricity',
    #     'trend_term': lambda x: x == 1,
    # }
]

# The loss name maps to an empty label here; only the datasets get a display name.
replace_dict = {
    "l1_mean": "",
    "electricity_370": "Electricity 370",
    "london_smartmeter": "London Smartmeter",
}
dataset_seed_dict = {
    "electricity_370": 43,
    "london_smartmeter": 10,
}
def plot_multiple_larger_columns_multi_batch_next_to_each_other(dataset_seed_dict):
    """Draw one trend reconstruction grid per dataset, side by side, and save the figure.

    One subfigure is created per entry of ``dataset_seed_dict``. Each is filled by
    ``plot_multi_batch_reconstructions_in_grid`` with the reconstructions gathered
    for that single dataset/seed pair and captioned with ``"<label> (<seed>)"``.
    The combined figure is written to
    ``./out/plots/reconstructions/grid_reconstructions_ts-inverse_ts_trend_regularization_specific.pdf``
    and then shown.

    The module-level ``filters``, ``runs_ts_inverse`` and ``replace_dict`` are read
    at call time, so the output depends on the notebook state when this is invoked.

    Args:
        dataset_seed_dict: Mapping of dataset name to the seed to plot for it.
    """
    n_datasets = len(dataset_seed_dict)
    fig = plt.figure(figsize=(3 * n_datasets, 1.8))
    # squeeze=False always yields an array; with the default, a single dataset
    # would return a bare SubFigure and the indexing below would fail.
    subfigs = fig.subfigures(1, n_datasets, squeeze=False)[0]
    for i, (dataset, seed) in enumerate(dataset_seed_dict.items()):
        reconstruction_dict, u_columns, u_rows, u_variables = gather_run_reconstructions(
            dataset_seed={dataset: seed},
            columns='trend_loss',
            rows='trend_term',
            variables='batch_size',
            filters=filters,
            runs=runs_ts_inverse,
            replace_dict=replace_dict,
        )
        subfig = plot_multi_batch_reconstructions_in_grid(
            reconstruction_dict,
            u_columns,
            u_rows,
            u_variables,
            extra_title='ts-inverse_ts_trend_regularization_specific',
            plot_size_width=2.5,
            plot_size_height=0.45,
            legend_loc='best',
            external_fig=subfigs[i],
        )
        # Fall back to the raw dataset name when no display label is configured.
        subfig.text(0.5, 1.01, f"{replace_dict.get(dataset, dataset)} ({seed})", ha='center', va='top')
    fig.savefig(
        './out/plots/reconstructions/grid_reconstructions_ts-inverse_ts_trend_regularization_specific.pdf',
        bbox_inches='tight',
    )
    plt.show()

# Render the side-by-side trend figure for the datasets/seeds selected above.
plot_multiple_larger_columns_multi_batch_next_to_each_other(dataset_seed_dict)

## BELOW IS FOR THESIS
# Back to the broad selection: periodicity term 0 and every trend loss whose
# name does not contain "sum".
filters = [
    {
        'experiment_name': lambda x: x == 'ts-inverse_ts_regularization_11-6-2024',
        'trend_loss': lambda x: 'sum' not in x,
        'periodicity_term': lambda x: x == 0,
    }
]
replace_dict = {
    'electricity_370': 'Electricity 370',
    'london_smartmeter': 'London Smartmeter',
    'tno_electricity': 'Proprietary',
    'kddcup': 'KDDCup',
    'cosine_dia': 'Cosine',
    'l1_mean': 'L1',
    'l2_mean': 'L2',
    'euclidean': 'Euclidean',
    '1_norm_1_cosine': 'Cosine + Norm'
}
dataset_seed_dict = {
    'electricity_370': 10,
}
reconstruction_dict, u_columns, u_rows, u_variables = gather_run_reconstructions(
    dataset_seed=dataset_seed_dict,
    columns='trend_term',
    rows='trend_loss',
    variables='batch_size',
    filters=filters,
    runs=runs_ts_inverse,
    replace_dict=replace_dict,
)
plot_multi_batch_reconstructions_in_grid(
    reconstruction_dict,
    u_columns,
    u_rows,
    u_variables,
    extra_title='ts-inverse_ts_trend_regularization',
    plot_size_width=2.5,
    plot_size_height=0.45,
    legend_loc='best',
)
# ## ^^ IN THE PAPER

# Seeds for the large multi-dataset figure that follows.
dataset_seed_dict = {
    'electricity_370': 10,
    'london_smartmeter': 10,
    'tno_electricity': 10,
}
# ## LARGE PLOT FOR THESIS
def plot_multiple_larger_columns_multi_batch_next_to_each_other(dataset_seed_dict):
    """Draw the large trend reconstruction grids, one per dataset, and save the figure.

    One subfigure is created per entry of ``dataset_seed_dict``. Each is filled by
    ``plot_multi_batch_reconstructions_in_grid`` with the reconstructions gathered
    for that single dataset/seed pair and captioned with ``"<label> (<seed>)"``.
    The combined figure is written to
    ``./out/plots/reconstructions/grid_reconstructions_ts-inverse_ts_trend_regularization_large.pdf``
    and then shown.

    The module-level ``filters``, ``runs_ts_inverse`` and ``replace_dict`` are read
    at call time, so the output depends on the notebook state when this is invoked.

    Args:
        dataset_seed_dict: Mapping of dataset name to the seed to plot for it.
    """
    n_datasets = len(dataset_seed_dict)
    fig = plt.figure(figsize=(5 * n_datasets, 10))
    # squeeze=False always yields an array; with the default, a single dataset
    # would return a bare SubFigure and the indexing below would fail.
    subfigs = fig.subfigures(1, n_datasets, squeeze=False)[0]
    for i, (dataset, seed) in enumerate(dataset_seed_dict.items()):
        reconstruction_dict, u_columns, u_rows, u_variables = gather_run_reconstructions(
            dataset_seed={dataset: seed},
            columns='trend_loss',
            rows='trend_term',
            variables='batch_size',
            filters=filters,
            runs=runs_ts_inverse,
            replace_dict=replace_dict,
        )
        subfig = plot_multi_batch_reconstructions_in_grid(
            reconstruction_dict,
            u_columns,
            u_rows,
            u_variables,
            extra_title='ts-inverse_ts_trend_regularization_large',
            plot_size_width=2.5,
            plot_size_height=0.45,
            legend_loc='best',
            external_fig=subfigs[i],
        )
        # Fall back to the raw dataset name when no display label is configured.
        subfig.text(0.5, 1.01, f"{replace_dict.get(dataset, dataset)} ({seed})", ha='center', va='top')
    fig.savefig(
        './out/plots/reconstructions/grid_reconstructions_ts-inverse_ts_trend_regularization_large.pdf',
        bbox_inches='tight',
    )
    plt.show()

plot_multiple_larger_columns_multi_batch_next_to_each_other(dataset_seed_dict)

# Combined Regularizations


In [None]:
# IN THE PAPER
experiment_names = ["ts-inverse_combinations_regularization_12-6-2024"]
runs_ts_inverse = api.runs(
    ts_inverse_project_name,
    filters={"$or": [{"config.experiment_name": experiment} for experiment in experiment_names]},
)
print(f"Found {len(runs_ts_inverse)} runs for TS Regularization Tuning")

# Every run of the combinations experiment goes into the table.
filters = [
    {
        "experiment_name": lambda value: value == "ts-inverse_combinations_regularization_12-6-2024",
    }
]

metric_table, u_columns, u_rows, u_variables = gather_final_metrics_by_parameters(
    columns="periodicity_term",
    rows="trend_term",
    metrics=["inputs/smape/mean", "targets/smape/mean"],
    variable="attack_method",
    filters=filters,
    runs=runs_ts_inverse,
)
print_latex_table_input_target(metric_table, u_columns, u_rows, u_variables, variable_name="Reg. Term.")

# Display labels for this section.
# NOTE(review): the 'L2_' key is upper-case while its sibling 'l1_' is
# lower-case - possibly meant to be 'l2_'; confirm against how the gathering
# helper applies `replace_dict`.
replace_dict = {
    "electricity_370": "Electricity 370",
    "london_smartmeter": "London Smartmeter",
    "tno_electricity": "Proprietary",
    "kddcup": "KDDCup",
    "cosine_dia": "Cosine",
    "l1_": "L1 ",
    "L2_": "L2 ",
    "euclidean": "Euclidean",
    "1_norm_1_cosine": "Cosine + Norm",
    "TS-Inverse_Trend_Periodicity_Learned": "",
}
dataset_seed_dict = {"electricity_370": 10}

# One fixed configuration: the learned variant with trend term 0.5 and
# periodicity term 1.
filters = [
    {
        "experiment_name": lambda value: value == "ts-inverse_combinations_regularization_12-6-2024",
        "attack_method": lambda value: value == "TS-Inverse_Trend_Periodicity_Learned",
        "trend_term": lambda value: value == 0.5,
        "periodicity_term": lambda value: value == 1,
    }
]
reconstruction_dict, u_columns, u_rows, u_variables = gather_run_reconstructions(
    dataset_seed=dataset_seed_dict,
    columns="dataset",
    rows="attack_method",
    variables="batch_size",
    filters=filters,
    runs=runs_ts_inverse,
    replace_dict=replace_dict,
)
plot_multi_batch_reconstructions_in_grid(
    reconstruction_dict,
    u_columns,
    u_rows,
    u_variables,
    extra_title="ts-inverse_ts_regularization_combined_trend_periodicity_learned_10",
    plot_size_width=2.5,
    plot_size_height=0.5,
    legend_loc="lower right",
)

# Same configuration for seed 43; the seed is part of the plot title.
dataset_seed_dict = {"electricity_370": 43}
reconstruction_dict, u_columns, u_rows, u_variables = gather_run_reconstructions(
    dataset_seed=dataset_seed_dict,
    columns="dataset",
    rows="attack_method",
    variables="batch_size",
    filters=filters,
    runs=runs_ts_inverse,
    replace_dict=replace_dict,
)
plot_multi_batch_reconstructions_in_grid(
    reconstruction_dict,
    u_columns,
    u_rows,
    u_variables,
    extra_title="ts-inverse_ts_regularization_combined_trend_periodicity_learned_43",
    plot_size_width=2.5,
    plot_size_height=0.5,
    legend_loc="lower right",
)


## BELOW IS FOR THESIS
# Full trend_term x periodicity_term grids, first for the plain combined
# attack and then for its learned variant. `dataset_seed_dict` and
# `replace_dict` are reused as last assigned above.
filters = [
    {
        "experiment_name": lambda value: value == "ts-inverse_combinations_regularization_12-6-2024",
        "attack_method": lambda value: value == "TS-Inverse_Trend_Periodicity",
    }
]
reconstruction_dict, u_columns, u_rows, u_variables = gather_run_reconstructions(
    dataset_seed=dataset_seed_dict,
    columns="trend_term",
    rows="periodicity_term",
    variables="batch_size",
    filters=filters,
    runs=runs_ts_inverse,
    replace_dict=replace_dict,
)
plot_multi_batch_reconstructions_in_grid(
    reconstruction_dict,
    u_columns,
    u_rows,
    u_variables,
    extra_title="ts-inverse_ts_regularization_combined_trend_periodicity",
    plot_size_width=2.5,
    plot_size_height=0.45,
    legend_loc="best",
)

filters = [
    {
        "experiment_name": lambda value: value == "ts-inverse_combinations_regularization_12-6-2024",
        "attack_method": lambda value: value == "TS-Inverse_Trend_Periodicity_Learned",
    }
]
reconstruction_dict, u_columns, u_rows, u_variables = gather_run_reconstructions(
    dataset_seed=dataset_seed_dict,
    columns="trend_term",
    rows="periodicity_term",
    variables="batch_size",
    filters=filters,
    runs=runs_ts_inverse,
    replace_dict=replace_dict,
)
plot_multi_batch_reconstructions_in_grid(
    reconstruction_dict,
    u_columns,
    u_rows,
    u_variables,
    extra_title="ts-inverse_ts_regularization_combined_trend_periodicity_learned",
    plot_size_width=2.5,
    plot_size_height=0.45,
    legend_loc="best",
)


# Gradient Loss Study

In [None]:
# IN THE PAPER
experiment_names = ["ts-inverse_fixed_gradient_loss_12-6-2024"]
runs_ts_inverse = api.runs(
    ts_inverse_project_name,
    filters={"$or": [{"config.experiment_name": experiment} for experiment in experiment_names]},
)
print(f"Found {len(runs_ts_inverse)} runs for Gradient Loss Comparisons")


# IN THE PAPER
# Only models whose name contains "TCN"; the seed restriction is disabled.
filters = [
    {
        "experiment_name": lambda value: value == "ts-inverse_fixed_gradient_loss_12-6-2024",
        "model": lambda value: "TCN" in value,
        # 'seed': lambda x: x == 10  or x == 43,
    }
]

metric_table, u_columns, u_rows, u_variables = gather_final_metrics_by_parameters(
    columns="dataset",
    rows="batch_size",
    metrics=["inputs/smape/mean", "targets/smape/mean"],
    variable="gradient_loss",
    filters=filters,
    runs=runs_ts_inverse,
)
print_latex_table_input_target(metric_table, u_columns, u_rows, u_variables, variable_name="Gradient Loss")


# filters = [
#     {
#         'experiment_name': lambda x: x == 'ts-inverse_fixed_gradient_loss_12-6-2024',
#         'dataset': lambda x: x == 'electricity_370',
#         'seed': lambda x: x == 10  or x == 43,
#     }
# ]

# metric_table, u_columns, u_rows, u_variables = gather_final_metrics_by_parameters(columns='model', rows='batch_size', 
#                                                                                                          metrics=['inputs/smape/mean', 'targets/smape/mean'], 
#                                                                                                          variable='gradient_loss', filters=filters, runs=runs_ts_inverse)
# print_latex_table_input_target(metric_table, u_columns, u_rows, u_variables, variable_name='Gradient Loss')

# filters = [
#     {
#         'experiment_name': lambda x: x == 'ts-inverse_fixed_gradient_loss_12-6-2024',
#         'dataset': lambda x: x == 'london_smartmeter',
#         'seed': lambda x: x == 10  or x == 43,
#     }
# ]

# metric_table, u_columns, u_rows, u_variables = gather_final_metrics_by_parameters(columns='model', rows='batch_size', 
#                                                                                                          metrics=['inputs/smape/mean', 'targets/smape/mean'], 
#                                                                                                          variable='gradient_loss', filters=filters, runs=runs_ts_inverse)
# print_latex_table_input_target(metric_table, u_columns, u_rows, u_variables, variable_name='Gradient Loss')


# Single-sample (batch_size == 1) reconstructions of the TCN models, one line
# per gradient loss.
filters = [
    {
        'experiment_name': lambda x: x == 'ts-inverse_fixed_gradient_loss_12-6-2024',
        'batch_size': lambda x: x == 1,
        'model': lambda x: 'TCN' in x,
    }
]
replace_dict = {
    'electricity_370': 'Electricity 370',
    'london_smartmeter': 'London Smartmeter',
    'tno_electricity': 'Proprietary',
    'kddcup': 'KDDCup',
    'cosine_dia': 'Cosine',
    '1_l1norm_1_cosine': 'Cosine + L1-Norm',
    'l1': 'L1-Norm',
    'euclidean': 'L2-Norm',
    '1_l2norm_1_cosine': 'Cosine + L2-Norm'
}
dataset_seed_dict = {
    'electricity_370': 10,
    'london_smartmeter': 28,
    # 'tno_electricity': 43,
}
# Fixed order of the five gradient losses; passed to the plot instead of the
# gathered `u_variables`.
sorted_variables = ['Cosine + L1-Norm', 'Cosine + L2-Norm', 'Cosine', 'L2-Norm', 'L1-Norm']
reconstruction_dict, u_columns, u_rows, u_variables = gather_run_reconstructions(
    dataset_seed=dataset_seed_dict,
    columns='dataset',
    rows='model',
    variables='gradient_loss',
    filters=filters,
    runs=runs_ts_inverse,
    replace_dict=replace_dict,
)
plot_single_batch_reconstructions_in_grid(
    reconstruction_dict,
    u_columns,
    u_rows,
    sorted_variables,
    extra_title='ts-inverse_tcn_x_5gradient_loss_x_dataset',
    plot_size_width=2.5,
    plot_size_height=0.45,
    legend_loc='best',
)
# ## ^^ IN THE PAPER


# ## BELOW IS FOR THESIS
dataset_seed_dict = {
    'electricity_370': 10,
    'london_smartmeter': 10,
    'tno_electricity': 43,
}
reconstruction_dict, u_columns, u_rows, u_variables = gather_run_reconstructions(
    dataset_seed=dataset_seed_dict,
    columns='dataset',
    rows='model',
    variables='gradient_loss',
    filters=filters,
    runs=runs_ts_inverse,
    replace_dict=replace_dict,
)
# Title says "5gradient_loss" to match the five losses drawn here and the
# paper variant above; the previous "4gradient_loss" title was the same one
# the old four-loss study uses for its all-dataset plot.
plot_single_batch_reconstructions_in_grid(
    reconstruction_dict,
    u_columns,
    u_rows,
    sorted_variables,
    extra_title='ts-inverse_tcn_x_5gradient_loss_x_all_dataset',
    plot_size_width=2.5,
    plot_size_height=0.45,
    legend_loc='best',
)

# GRU-2-FCN and GRU-2-GRU architectures as defense

In [None]:
# IN THE PAPER
experiment_names = ["ts-inverse_defense_13-6-2024"]
runs_ts_inverse = api.runs(ts_inverse_project_name, filters={"$or": [{"config.experiment_name": name} for name in experiment_names]})
print(f"Found {len(runs_ts_inverse)} runs for Defense Comparisons")

# IN THE PAPER
# Only the runs with warmup_number_of_batches equal to 0.
filters = [
    {
        'experiment_name': lambda x: x == 'ts-inverse_defense_13-6-2024',
        'warmup_number_of_batches': lambda x: x == 0,
    }
]

metric_table, u_columns, u_rows, u_variables = gather_final_metrics_by_parameters(
    columns='dataset',
    rows='model',
    metrics=['inputs/smape/mean', 'targets/smape/mean'],
    variable='batch_size',
    filters=filters,
    runs=runs_ts_inverse,
)
# Raw string: "\m" is not a valid escape sequence, so the LaTeX macro must not
# go through escape processing (the resulting text is unchanged).
print_latex_table_input_target(metric_table, u_columns, u_rows, u_variables, variable_name=r'$\mathcal{B}$')

# Reconstruction grid for the defense runs with warmup_number_of_batches == 0,
# one line per attack method.
filters = [
    {
        "experiment_name": lambda value: value == "ts-inverse_defense_13-6-2024",
        "warmup_number_of_batches": lambda value: value == 0,
    }
]
# Display labels, including the names shown for the two GRU-based models.
replace_dict = {
    "electricity_370": "Electricity 370",
    "london_smartmeter": "London Smartmeter",
    "tno_electricity": "Proprietary",
    "kddcup": "KDDCup",
    "TS-Inverse_Trend_Periodicity_Learned": "TS-Inverse",
    "JitGRU": "GRU-2-FCN",
    "JitSeq2Seq": "GRU-2-GRU",
}
dataset_seed_dict = {
    "electricity_370": 10,
    "london_smartmeter": 10,
    "kddcup": 10,
    "tno_electricity": 10,
}
reconstruction_dict, u_columns, u_rows, u_variables = gather_run_reconstructions(
    dataset_seed=dataset_seed_dict,
    columns="dataset",
    rows="model",
    variables="attack_method",
    filters=filters,
    runs=runs_ts_inverse,
    replace_dict=replace_dict,
)
plot_single_batch_reconstructions_in_grid(
    reconstruction_dict,
    u_columns,
    u_rows,
    u_variables,
    extra_title="ts-inverse_gru_defense",
    plot_size_width=3,
    plot_size_height=0.8,
    legend_loc="best",
    column_offset=0.08,
)
# ## ^^ IN THE PAPER

## Old Gradient Loss without target reconstruction

In [None]:
# IN THE PAPER
experiment_names = ["ts-inverse_final_tcn_x_elec_data_x_gradient_loss_5-6-2024"]
runs_ts_inverse = api.runs(
    ts_inverse_project_name,
    filters={"$or": [{"config.experiment_name": experiment} for experiment in experiment_names]},
)
print(f"Found {len(runs_ts_inverse)} runs for Gradient Loss Comparisons")

# Every run of the experiment goes into the table.
filters = [
    {
        "experiment_name": lambda value: value == "ts-inverse_final_tcn_x_elec_data_x_gradient_loss_5-6-2024",
    }
]

metric_table, u_columns, u_rows, u_variables = gather_final_metrics_by_parameters(
    columns="dataset",
    rows="batch_size",
    metrics=["inputs/smape/mean", "targets/smape/mean"],
    variable="gradient_loss",
    filters=filters,
    runs=runs_ts_inverse,
)
print_latex_table_input_target(metric_table, u_columns, u_rows, u_variables, variable_name="Gradient Loss")

# Single-sample (batch_size == 1) reconstructions, one line per gradient loss.
filters = [
    {
        "experiment_name": lambda value: value == "ts-inverse_final_tcn_x_elec_data_x_gradient_loss_5-6-2024",
        "batch_size": lambda value: value == 1,
    }
]
replace_dict = {
    "electricity_370": "Electricity 370",
    "london_smartmeter": "London Smartmeter",
    "tno_electricity": "Proprietary",
    "kddcup": "KDDCup",
    "cosine_dia": "Cosine",
    "l1": "L1",
    "euclidean": "Euclidean",
    "1_norm_1_cosine": "Cosine + Norm",
}
dataset_seed_dict = {
    "electricity_370": 10,
    "tno_electricity": 43,
}
reconstruction_dict, u_columns, u_rows, u_variables = gather_run_reconstructions(
    dataset_seed=dataset_seed_dict,
    columns="dataset",
    rows="model",
    variables="gradient_loss",
    filters=filters,
    runs=runs_ts_inverse,
    replace_dict=replace_dict,
)
plot_single_batch_reconstructions_in_grid(
    reconstruction_dict,
    u_columns,
    u_rows,
    u_variables,
    extra_title="ts-inverse_tcn_x_4gradient_loss_x_dataset",
    plot_size_width=2.5,
    plot_size_height=0.45,
    legend_loc="best",
)
## ^^ IN THE PAPER


## BELOW IS FOR THESIS
# Same plot with all three datasets.
dataset_seed_dict = {
    "electricity_370": 10,
    "london_smartmeter": 10,
    "tno_electricity": 43,
}
reconstruction_dict, u_columns, u_rows, u_variables = gather_run_reconstructions(
    dataset_seed=dataset_seed_dict,
    columns="dataset",
    rows="model",
    variables="gradient_loss",
    filters=filters,
    runs=runs_ts_inverse,
    replace_dict=replace_dict,
)
plot_single_batch_reconstructions_in_grid(
    reconstruction_dict,
    u_columns,
    u_rows,
    u_variables,
    extra_title="ts-inverse_tcn_x_4gradient_loss_x_all_dataset",
    plot_size_width=2.5,
    plot_size_height=0.45,
    legend_loc="best",
)

# Batch size 4: one row per gradient loss, drawn with the multi-batch plot.
filters = [
    {
        "experiment_name": lambda value: value == "ts-inverse_final_tcn_x_elec_data_x_gradient_loss_5-6-2024",
        "batch_size": lambda value: value == 4,
    }
]
dataset_seed_dict = {
    "electricity_370": 10,
    "london_smartmeter": 10,
    "tno_electricity": 43,
}
reconstruction_dict, u_columns, u_rows, u_variables = gather_run_reconstructions(
    dataset_seed=dataset_seed_dict,
    columns="dataset",
    rows="gradient_loss",
    variables="batch_size",
    filters=filters,
    runs=runs_ts_inverse,
    replace_dict=replace_dict,
)
plot_multi_batch_reconstructions_in_grid(
    reconstruction_dict,
    u_columns,
    u_rows,
    u_variables,
    extra_title="ts-inverse_tcn_batch4_x_gradient_loss_x_dataset",
    plot_size_width=2.5,
    plot_size_height=0.45,
    legend_loc="best",
)

### FCN and CNN Gradient loss

In [None]:
# Fetch the experiment that the filters below select. Previously the TCN
# experiment ("..._tcn_x_elec_data_x_gradient_loss_5-6-2024") was fetched, so
# no fetched run could satisfy the 'experiment_name' filter.
experiment_names = ["ts-inverse_final_gradient_loss_4-6-2024"]
runs_ts_inverse = api.runs(ts_inverse_project_name, filters={"$or": [{"config.experiment_name": name} for name in experiment_names]})
print(f"Found {len(runs_ts_inverse)} runs for FCN and CNN Gradient Loss Comparisons")

# NOT IN THE PAPER (BECAUSE LESS STRONG CASE FOR CNN and FCN) Can be PUT IN THE THESIS
# DOES NOT USE PRIOR PREDICTION

# Restrict the table to the four gradient losses listed here.
filters = [
    {
        'experiment_name': lambda x: x == 'ts-inverse_final_gradient_loss_4-6-2024',
        'gradient_loss': lambda x: x in ['1_norm_1_cosine', 'cosine_dia', 'euclidean', 'l1'],
    }
]

metric_table, u_columns, u_rows, u_variables = gather_final_metrics_by_parameters(
    columns='model',
    rows='batch_size',
    metrics=['inputs/smape/mean', 'targets/smape/mean'],
    variable='gradient_loss',
    filters=filters,
    runs=runs_ts_inverse,
)
print_latex_table_input_target(metric_table, u_columns, u_rows, u_variables, variable_name='Gradient Loss')


# Plotting specific example for in the paper, showing target difference between L1 and Cosine
filters = [
    {
        "experiment_name": lambda value: value == "ts-inverse_final_gradient_loss_4-6-2024",
        "gradient_loss": lambda value: value in ["1_norm_1_cosine", "cosine_dia", "euclidean", "l1"],
        "batch_size": lambda value: value == 1,
    }
]
replace_dict = {
    "electricity_370": "Electricity 370",
    "london_smartmeter": "London Smartmeter",
    "tno_electricity": "Proprietary",
    "kddcup": "KDDCup",
    "cosine_dia": "Cosine",
    "l1": "L1",
    "euclidean": "Euclidean",
    "1_norm_1_cosine": "Cosine + Norm",
}

# The same plot is produced for seeds 10, 28 and 43 in turn.
# NOTE(review): all three calls pass the same `extra_title`; if the plotting
# helper derives its output file name from it, only the last seed's figure
# remains on disk - confirm that is intended.
dataset_seed_dict = {"electricity_370": 10}
reconstruction_dict, u_columns, u_rows, u_variables = gather_run_reconstructions(
    dataset_seed=dataset_seed_dict,
    columns="gradient_loss",
    rows="dataset",
    variables="model",
    filters=filters,
    runs=runs_ts_inverse,
    replace_dict=replace_dict,
)
plot_single_batch_reconstructions_in_grid(
    reconstruction_dict,
    u_columns,
    u_rows,
    u_variables,
    extra_title="ts-inverse_gradient_loss_x_model_x_electricity_370",
    plot_size_width=2.5,
    plot_size_height=0.45,
    legend_loc="center",
)

dataset_seed_dict = {"electricity_370": 28}
reconstruction_dict, u_columns, u_rows, u_variables = gather_run_reconstructions(
    dataset_seed=dataset_seed_dict,
    columns="gradient_loss",
    rows="dataset",
    variables="model",
    filters=filters,
    runs=runs_ts_inverse,
    replace_dict=replace_dict,
)
plot_single_batch_reconstructions_in_grid(
    reconstruction_dict,
    u_columns,
    u_rows,
    u_variables,
    extra_title="ts-inverse_gradient_loss_x_model_x_electricity_370",
    plot_size_width=2.5,
    plot_size_height=0.45,
    legend_loc="center",
)

dataset_seed_dict = {"electricity_370": 43}
reconstruction_dict, u_columns, u_rows, u_variables = gather_run_reconstructions(
    dataset_seed=dataset_seed_dict,
    columns="gradient_loss",
    rows="dataset",
    variables="model",
    filters=filters,
    runs=runs_ts_inverse,
    replace_dict=replace_dict,
)
plot_single_batch_reconstructions_in_grid(
    reconstruction_dict,
    u_columns,
    u_rows,
    u_variables,
    extra_title="ts-inverse_gradient_loss_x_model_x_electricity_370",
    plot_size_width=2.5,
    plot_size_height=0.45,
    legend_loc="center",
)


# Plotting specific example for in the paper, showing target difference between L1 and Cosine
filters = [
    {
        'experiment_name': lambda x: x == 'ts-inverse_final_gradient_loss_4-6-2024',
        'gradient_loss': lambda x: x in ['1_norm_1_cosine', 'cosine_dia', 'euclidean', 'l1'],
        'batch_size': lambda x: x == 4
    }
]
dataset_seed_dict = {
    'electricity_370': 10,
}

reconstruction_dict, u_columns, u_rows, u_variables  = gather_run_reconstructions(dataset_seed=dataset_seed_dict, columns='gradient_loss', rows='model', variables='batch_size', filters=filters, runs=runs_ts_inverse, replace_dict=replace_dict)
plot_multi_batch_reconstructions_in_grid(reconstruction_dict, u_columns, u_rows, u_variables, extra_title='ts-inverse_batch_4_gradient_loss_x_model_x_electricity_370', plot_size_width=2.5, plot_size_height=0.45, legend_loc='center')

In [None]:
# Query the runs of the gradient-loss experiment used for the gradient histograms.
experiment_names = ["ts-inverse_final_gradient_loss_4-6-2024"]
runs_ts_inverse = api.runs(
    ts_inverse_project_name,
    filters={"$or": [{"config.experiment_name": name} for name in experiment_names]},
)
print(f"Found {len(runs_ts_inverse)} runs for Gradient Plots Comparisons")

# Seed selected for the Electricity 370 dataset.
dataset_seed_dict = {'electricity_370': 10}

# Display labels substituted for the raw config values.
replace_dict = {
    'electricity_370': 'Electricity 370',
    'london_smartmeter': 'London Smartmeter',
    'tno_electricity': 'Proprietary',
    'kddcup': 'KDDCup',
    'cosine_dia': 'Cosine',
    'euclidean': 'L2',
    '1_norm_1_cosine': 'Cosine + L2-Norm',
    '1_l1_1_cosine': 'Cosine + L1-Norm',
    'l1': 'L1',
}

# NOTE(review): not referenced again in this cell; presumably read elsewhere
# to order the plotted labels - confirm.
sorted_columns = ['Cosine', 'L1', 'L2', 'Cosine + L1-Norm', 'Cosine + L2-Norm']

# One histogram grid per model family. Each pass keeps batch-size-1 runs whose
# model name contains the family key and drops two gradient-loss variants.
for model_key in ('FCN', 'CNN', 'TCN'):
    filters = [
        {
            'experiment_name': lambda x: x == 'ts-inverse_final_gradient_loss_4-6-2024',
            # Bind the key as a default so the predicate keeps this pass's value.
            'model': lambda x, model_key=model_key: model_key in x,
            'batch_size': lambda x: x == 1,
            'gradient_loss': lambda x: x not in ('1_inorm_1_icosine', 'l1_skip_1D'),
        }
    ]
    gradients_dict, u_columns, u_rows = gather_run_gradients(
        dataset_seed=dataset_seed_dict,
        columns='dataset',
        rows='gradient_loss',
        filters=filters,
        runs=runs_ts_inverse,
        replace_dict=replace_dict,
    )
    plot_gradients_histogram_in_grid(
        gradients_dict,
        u_columns,
        u_rows,
        extra_title=f'ts-inverse_{model_key.lower()}_gradient_loss',
        plot_size_width=4,
        plot_size_height=2,
        legend_loc='upper left',
        legend_font_size=8,
    )