In [None]:
import json
from typing import List, Union, Tuple

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.backends.backend_pdf import PdfPages
from matplotlib.ticker import ScalarFormatter
from seaborn.palettes import SEABORN_PALETTES
import pandas as pd
import seaborn as sns

In [None]:
# Styling vocabulary shared by the plots below.
# The RG_* dicts are indexed by region-graph key, the LAYER_* dicts by layer key.
# Insertion order is meaningful for RG_LABELS and LAYER_LABELS: the benchmarking
# cells iterate over them to decide the drawing / legend order.

# Legend prefix of each region graph
RG_LABELS = dict(RQT="QT-", QG="QG-", PD="PD-")

# Matplotlib linestyle of each layer ((offset, (on, off)) is a custom dash pattern)
LAYER_LINESTYLES = dict(zip(
    ("tucker", "cp-shared", "cp"),
    ((0, (1.5, 1)), "dashdot", "-"),
))

# Matplotlib marker of each layer
LAYER_MARKERS = dict(zip(
    ("cp", "tucker", "cp-shared"),
    ("v", "o", "X"),
))

# Legend suffix of each layer
LAYER_LABELS = dict(zip(
    ("cp", "tucker", "cp-shared"),
    ("CP", "Tucker", "$CP^s$"),
))

# Seaborn's "colorblind" palette; its first three entries give one fixed colour
# per region graph (PD -> 0, QG -> 1, RQT -> 2).
PALETTE = SEABORN_PALETTES["colorblind"]
RG_COLORS = {rg: PALETTE[slot] for slot, rg in enumerate(("PD", "QG", "RQT"))}

In [None]:
# BPD PLOT (x= "k" OR "num_params", y = "test_bpd")
# For every dataset: read its grid-search CSV, draw one line per (rg, layer)
# combination found in it and save the figure to the matching PDF.

RESULTS_PATHS = {
    "mnist": "../results/grid_mnist.csv",
    "fashion_mnist": "../results/grid_fashion_mnist.csv",
    "celeba": "../results/grid_celeba.csv"
}

SAVE_PDF_PATH = {
    "mnist": "../results/mnist_bpd.pdf",
    "fashion_mnist": "../results/fashion_mnist_bpd.pdf",
    "celeba": "../results/celeba_bpd.pdf"
}

# Settings for plot
TICK_SIZE = 20
LINE_TICKNESS = 3.5
matplotlib.rcParams.update({'font.size': 17})

X_FIELD = "k" # or "num_params"
Y_FIELD = "Best/Test/bpd"

# Caption of the x axis for each supported X_FIELD, so that the label always
# matches the column that is actually plotted.
X_AXIS_LABELS = {"k": "K", "num_params": "# Params"}


for dataset in RESULTS_PATHS:

    # load results data in dataframe
    complete_data: pd.DataFrame = pd.read_csv(RESULTS_PATHS[dataset])

    fig, axes = plt.subplots(nrows=1, ncols=1)

    for rg in complete_data["rg"].unique():
        for layer in complete_data["layer"].unique():

            selection = complete_data[(complete_data["rg"] == rg) & (complete_data["layer"] == layer)]

            # Rows sharing the same k are averaged into a single point; groupby
            # returns the groups already sorted by k, so no explicit sort is needed.
            # TODO: check if it is correct (should be)
            values_to_plot = selection.groupby(by="k", as_index=False).mean(numeric_only=True)

            # skip combinations without data, otherwise only a legend entry would appear
            if len(values_to_plot) >= 1:
                axes.plot(values_to_plot[X_FIELD].map(int),
                          values_to_plot[Y_FIELD],
                          color=RG_COLORS[rg],
                          lw=LINE_TICKNESS, linestyle=LAYER_LINESTYLES[layer],
                          marker=LAYER_MARKERS[layer], markersize=10,
                          label=RG_LABELS[rg] + LAYER_LABELS[layer])

    # Set labels
    axes.set_xlabel(X_AXIS_LABELS.get(X_FIELD, X_FIELD), fontsize=22)
    axes.set_ylabel("Test bpd", fontsize=22)

    # Set legend (placed above the axes, two columns)
    axes.legend(loc="lower center", bbox_to_anchor=(0.4, 1.01), ncol=2, columnspacing=0.4,
                borderpad=0.1, labelspacing=0.05, markerscale=0.8, handletextpad=0.1)

    # Set ticks and size (no minor ticks on the y axis)
    MINOR_Y_TICKS = []
    axes.set_yticks(MINOR_Y_TICKS, minor=True)
    plt.yticks(fontsize=TICK_SIZE)
    plt.xticks(fontsize=TICK_SIZE)

    # Set grid
    axes.grid(True, which='both', alpha=0.5, linestyle="--")

    # Set figure size (inches)
    fig.set_figwidth(6)
    fig.set_figheight(5)

    # Save the figure as a single-page PDF
    with PdfPages(SAVE_PDF_PATH[dataset]) as pdf:
        pdf.savefig(fig, bbox_inches="tight")


In [None]:
# Benchmarking plots
# One panel per measured quantity (the keys of Y_LABELS); every (rg, layer)
# combination is overlaid in each panel.

# Matplotlib settings
LINE_TICKNESS = 3.5
TICK_SIZE = 10
LEGEND_SIZE = 18
matplotlib.rcParams.update({'font.size': 18})

# PLOT SETTING
MIN_K = 1
MAX_K = 4096
MIN_ZOOM_K = 1
MAX_ZOOM_K = 32

BENCHMARKING_RESULTS_PATH = "../results/benchmarking_amari.csv"
X_NAME: str = "k" # assert X_NAME in ["k", "params", "forward_time"]

# CSV column -> panel title; the dict order is the left-to-right panel order
Y_LABELS = {
   "test-time": "Test Time (ms)",
   "test-space": "Test Memory (GB)",
   "train-time": "Train Time (ms)",
   "train-space": "Train Memory (GB)"
}


results_df = pd.read_csv(BENCHMARKING_RESULTS_PATH)

fig, axes = plt.subplots(nrows=1, ncols=4)

# create zoom axes
def create_zoom_ax(fig, ax):
    """Add to `fig` an inset axes over the upper-left part of `ax` and return it.

    The inset starts at the left edge of `ax`, 60% of the way up, and spans
    70% of its width and 40% of its height.
    """
    ax_pos = ax.get_position()
    return fig.add_axes([ax_pos.x0,
                        ax_pos.y0 + 0.6 * ax_pos.height,
                        ax_pos.width * 0.7,
                        ax_pos.height * 0.4])

# One optional inset per panel; None disables the zoomed view of that panel.
# To enable it, store create_zoom_ax(fig, axes[i]) at index i.
zoom_axes = [None, None, None, None]


# Define the scaling factor and shift for the new axes
# NOTE(review): not read anywhere in this cell - create_zoom_ax hard-codes its own factors.
scale_factor = 0.5
shift_factor = 0.2


for plot_n, y_name in enumerate(Y_LABELS):

    ax = axes[plot_n]
    zoom_ax = zoom_axes[plot_n]
    for rg in RG_LABELS:
        for layer in LAYER_LABELS:

            label = RG_LABELS[rg] + LAYER_LABELS[layer]
            values_to_plot = results_df[(results_df["rg"] == rg) & (results_df["layer"] == layer)].sort_values(by="k")
            values_to_plot = values_to_plot[(values_to_plot["k"] >= MIN_K) & (values_to_plot["k"] <= MAX_K)]

            if len(values_to_plot) > 0: # otherwise it prints the label
                # x values are plotted as strings, i.e. on an evenly spaced categorical axis
                x_values = values_to_plot[X_NAME].map(str)
                ax.plot(x_values, values_to_plot[y_name],
                        color=RG_COLORS[rg], linestyle="-",
                        marker=LAYER_MARKERS[layer], markersize=7, alpha=0.4,
                        lw=LINE_TICKNESS, label=label)

                # a second, more opaque layer of larger markers on top of the faint line
                ax.scatter(x_values, values_to_plot[y_name],
                           color=RG_COLORS[rg],
                           marker=LAYER_MARKERS[layer], s=100, alpha=0.7)

            if zoom_ax is not None:
                # plot only subset of values
                zoomed_values_to_plot = values_to_plot[(values_to_plot["k"] >= MIN_ZOOM_K) & (values_to_plot["k"] <= MAX_ZOOM_K)]
                zoomed_x_values = zoomed_values_to_plot[X_NAME].map(str)
                zoom_ax.plot(zoomed_x_values, zoomed_values_to_plot[y_name],
                             color=RG_COLORS[rg], linestyle="-",
                             marker=LAYER_MARKERS[layer], markersize=7, alpha=0.4,
                             lw=LINE_TICKNESS, label=label)
                zoom_ax.scatter(zoomed_x_values, zoomed_values_to_plot[y_name],
                                color=RG_COLORS[rg],
                                marker=LAYER_MARKERS[layer], s=100, alpha=0.7)

                zoom_ax.grid(True, which='both', alpha=0.5, linestyle="--")

    ax.set_title(Y_LABELS[y_name])
    ax.set_xlabel(f"${X_NAME.upper()}$")
    # keep every other tick and relabel it as a power of two
    # NOTE(review): the label is computed from the tick POSITION (position p -> 2^(p+4)),
    # which presumably assumes the plotted k values are 16, 32, 64, ... - TODO confirm against the CSV
    ax.set_xticks(ax.get_xticks()[::2])
    ax.set_xticklabels([f"$2^{{{tick+4}}}$" for tick in ax.get_xticks()])

    # log scale only for time
    ax.set_yscale("log" if plot_n % 2 == 0 else "linear")
    ax.grid(True, which='both', alpha=0.5, linestyle="--")


fig.set_figwidth(20)
fig.set_figheight(4)


if X_NAME != "k":
    # `axes` is an array of Axes, so the scale has to be set on each panel
    for ax in axes:
        ax.set_xscale('log')

# single legend, anchored to the right of the last panel
axes[3].legend(loc="upper left", bbox_to_anchor=(1.01, 1.0), fontsize=19,
            columnspacing=0.4, borderpad=0.1, labelspacing=0.05,
            markerscale=1.3, handletextpad=0.1)




# Create a new set of axes with the calculated position


# Plot the zoomed-in region
#ax_zoom.plot(x, y1, color='blue')
#ax_zoom.set_xlim(x_zoom)
#ax_zoom.set_ylim(y_zoom)



# Hide x and y ticks for the zoomed-in region
#ax_zoom.tick_params(axis='both', which='both', bottom=False, top=False, left=False, right=False)

# Add gridlines to the zoomed-in region


In [None]:
# Benchmarking plots
# Grid of panels: one row per layer (LAYER_LABELS), one column per measured
# quantity (Y_LABELS); each panel overlays the region graphs (RG_LABELS).

# Matplotlib settings
LINE_TICKNESS = 3.5
TICK_SIZE = 10
LEGEND_SIZE = 18
matplotlib.rcParams.update({'font.size': 18})

# x ticks are the powers of two 2^3 ... 2^12; the tick values (as the strings used
# on the categorical axis) and their labels are both built from the same exponents
# so that they can never get out of step.
X_TICK_EXPONENTS = range(3, 13)
X_TICKS = [str(2**i) for i in X_TICK_EXPONENTS]
X_TICK_LABELS = [f"$2^{{{i}}}$" for i in X_TICK_EXPONENTS]

# Marker of each region graph (the rows already separate the layers).
# This name is needed by the plotting loop below; the shapes are a free choice.
RG_MARKERS = {
    "RQT": "v",
    "QG": "o",
    "PD": "X"
}

# PLOT SETTING
MIN_K = 1
MAX_K = 4096

BENCHMARKING_RESULTS_PATH = "../results/benchmarking_amari.csv"
X_NAME: str = "k" # assert X_NAME in ["k", "params", "forward_time"]

# CSV column -> panel title; the dict order is the left-to-right column order
Y_LABELS = {
   "test-time": "Test Time (ms)",
   "test-space": "Test Memory (GB)",
   "train-time": "Train Time (ms)",
   "train-space": "Train Memory (GB)"
}


results_df = pd.read_csv(BENCHMARKING_RESULTS_PATH)

# panels in the same column share their y axis
fig, axes = plt.subplots(nrows=len(LAYER_LABELS), ncols=4, sharey="col")

for plot_row, layer in enumerate(LAYER_LABELS):
    for plot_col, y_name in enumerate(Y_LABELS):

        ax = axes[plot_row, plot_col]
        for rg in RG_LABELS:

            label = RG_LABELS[rg] + LAYER_LABELS[layer]
            values_to_plot = results_df[(results_df["rg"] == rg) & (results_df["layer"] == layer)].sort_values(by="k")
            values_to_plot = values_to_plot[(values_to_plot["k"] >= MIN_K) & (values_to_plot["k"] <= MAX_K)]

            if len(values_to_plot) > 0: # otherwise it prints the label
                # x values are plotted as strings, i.e. on an evenly spaced categorical axis
                ax.plot(values_to_plot[X_NAME].map(str), values_to_plot[y_name],
                        color=RG_COLORS[rg], linestyle=LAYER_LINESTYLES[layer],
                        marker=RG_MARKERS[rg], markersize=7,
                        lw=LINE_TICKNESS, label=label)

        ax.set_title(Y_LABELS[y_name])
        ax.set_xlabel(f"${X_NAME.upper()}$")
        ax.set_xticks(X_TICKS)
        ax.grid(True, which='both', alpha=0.5, linestyle="--")
        # log scale only for the time columns (0 and 2)
        ax.set_yscale('log' if plot_col % 2 == 0 else 'linear')

        ax.set_xticklabels(X_TICK_LABELS)

    # one legend per row, anchored to the right of the last panel of the row
    axes[plot_row, -1].legend(loc="upper left", bbox_to_anchor=(1.01, 1.0), fontsize=19,
            columnspacing=0.4, borderpad=0.1, labelspacing=0.05,
            markerscale=1.3, handletextpad=0.1)


fig.set_figwidth(20)
fig.set_figheight(16)


if X_NAME != "k":
    # `axes` is a 2-D array of Axes, so the scale has to be set on each panel
    for ax in axes.flat:
        ax.set_xscale('log')


In [None]:
# FROM HERE, DO NOT TOUCH AND DO NOT EXECUTE
# df = pd.DataFrame(columns=["hyp_par", "model", "value", "forward_time", "forward_space",
#                           "backward_time", "backward_space"])

# CSV read by main(); the PDF it writes sits next to it (".csv" replaced by ".pdf")
RESULT_PATH = "../results/benchmarking_amari.csv"
# Single-letter matplotlib colours, one per model line drawn by main()
COLORS = list("bgrcmyk")


def main():
    """Plot time and space of the forward/backward pass for every hyper-parameter.

    Reads RESULT_PATH and, for each distinct value of its "hyp_par" column,
    builds a 2x2 figure (rows: forward / backward, columns: time / space) with
    one line per value of the "model" column, plotted against the "value"
    column. All figures are written to a PDF next to the CSV (same path with
    ".csv" replaced by ".pdf").
    """
    df = pd.read_csv(RESULT_PATH)
    figs = []

    # column name = phase prefix + quantity suffix, e.g. "forward_" + "time"
    phases = ("forward_", "backward_")
    quantities = ("time", "space")

    for hyp in df["hyp_par"].unique():

        values = df[df["hyp_par"] == hyp]
        model_values_groups = [group for _, group in values.groupby('model')]

        fig, axes = plt.subplots(nrows=2, ncols=2)
        fig.suptitle(hyp)

        for i, phase in enumerate(phases):
            for j, quantity in enumerate(quantities):

                ax = axes[i, j]
                field = phase + quantity

                for n, group in enumerate(model_values_groups):
                    # legend entry: the model name cut to at most 3 characters
                    model_abbr_name = group['model'].unique()[0][:3]
                    # wrap around so that more models than colours cannot raise IndexError
                    color = COLORS[n % len(COLORS)]

                    ax.plot(group["value"].values, group[field].values, color=color, label=model_abbr_name)
                    ax.scatter(group["value"].values, group[field].values, color=color, s=10)

                ax.set_title(field)
                ax.legend()
                ax.set_xlabel(hyp)
                ax.set_ylabel("sec" if j == 0 else "bytes")

        # Adjust spacing between subplots
        plt.subplots_adjust(hspace=0.6, wspace=0.4)
        figs.append(fig)

    # Write every figure as one page of a single PDF file
    with PdfPages(RESULT_PATH.replace(".csv", ".pdf")) as pdf:
        for fig in figs:
            pdf.savefig(fig)


def print_metrics(path: str, x_name: str, metrics: List[str]):
    """Plot each metric against ``x_name`` and save the figure as a PDF.

    The CSV at ``path`` is read; for every metric the row with ``x_name == 0``
    provides a reference value (printed and drawn as a horizontal red line) and
    the rows with ``x_name > 0`` are drawn as a scatter. One panel per metric,
    side by side. The PDF is written next to the CSV (".csv" -> ".pdf").

    Args:
        path: path of the CSV file with the results.
        x_name: column used for the x axis.
        metrics: columns to plot, one panel each.

    Raises:
        ValueError: if a metric does not have exactly one row with
            ``x_name == 0`` (raised by ``Series.item()``).
    """
    data: pd.DataFrame = pd.read_csv(path)
    # squeeze=False always returns a 2-D array of axes, so a single metric needs no special case
    fig, axes = plt.subplots(nrows=1, ncols=len(metrics), squeeze=False)
    for num, metric in enumerate(metrics):

        values = data[data[x_name] > 0][[x_name, metric]]
        base_value = data[data[x_name] == 0][metric].item()
        print(base_value)

        ax = axes[0, num]
        # both artists are labelled so that ax.legend() below has entries to show
        ax.axhline(y=base_value, color='r', linestyle='-', label=f"{x_name} = 0")
        ax.scatter(x=values[x_name], y=values[metric], s=10, color='blue', label=metric)

        ax.set_title(metric)
        ax.legend()
        ax.set_xlabel(x_name)
        ax.set_ylabel(metric)
        # one tick per plotted x value
        ax.set_xticks(values[x_name])

    fig.set_figwidth(16)
    fig.set_figheight(4)

    # Adjust spacing between subplots
    plt.subplots_adjust(hspace=0.1, wspace=0.5)

    # Create a PDF file to save the figure
    with PdfPages(path.replace(".csv", ".pdf")) as pdf:
        pdf.savefig(fig)


def scatter_dim(df, x_field, y_field, dim_field, color, marker, name):
    """Scatter ``df[x_field]`` against ``df[y_field]`` on the current pyplot axes.

    One marker is drawn per row, with its size taken from ``df[dim_field]``.
    Afterwards both axes are switched to a logarithmic scale and a dashed grid
    is enabled. ``name`` is accepted but not used.
    """
    points = zip(df[x_field], df[y_field], df[dim_field])
    for x_value, y_value, size in points:
        plt.scatter(x_value, y_value, s=size, color=color, marker=marker)

    for set_scale in (plt.xscale, plt.yscale):
        set_scale('log')
    plt.grid(True, which='both', alpha=0.3, linestyle="--")


def plot_benchmarking_value(data: pd.DataFrame, x_name: str, y_name: str, color, marker: str, linestyle, label: str):
    """Draw one benchmarking curve (``y_name`` against ``x_name``) on the current pyplot axes.

    Args:
        data: frame holding the two columns to plot; it is never modified.
        x_name: column for the x axis. When it is "k" the values are plotted as
            strings, i.e. on an evenly spaced categorical axis.
        y_name: column for the y axis. When the name ends with "space" the
            values are divided by 1024**3 before plotting.
        color, marker, linestyle: matplotlib style of the line.
        label: legend entry; underscores are shown as spaces.
    """
    LINE_TICKNESS = 3.5
    plt.grid(True, which='both', alpha=0.3, linestyle="--")

    x_values = data[x_name]
    y_values = data[y_name]

    if y_name.endswith("space"):
        # the division creates a new Series: nothing is assigned into a slice of `data`,
        # so no SettingWithCopyWarning is raised and the caller's frame stays untouched
        y_values = y_values / (1024 * 1024 * 1024)

    if x_name == "k":
        x_values = [str(v) for v in x_values]

    # the requested linestyle is honoured for every x_name
    plt.plot(x_values, y_values,
             color=color, linestyle=linestyle, marker=marker,
             markersize=7, lw=LINE_TICKNESS, label=label.replace("_", " "))


def benchmark_plot(results_path, x_name, y_name):
    """Plot ``y_name`` against ``x_name`` for every model of a benchmarking CSV.

    The "model" column is expected to hold names made of a 3-character region
    graph prefix (e.g. "PD-", "QT-") followed by a layer name that is a key of
    the style maps below. Models ending in "LoRa-CP" are skipped. Models whose
    name starts with "PD" get a solid line, all the others a dashed one.

    Args:
        results_path: path of the CSV file with the benchmarking results.
        x_name: one of "k", "params", "forward_time".
        y_name: column to plot; names ending in "time" are titled with "(s)",
            all the others with "(GB)". A legend is drawn only for
            "backward_space".

    Returns:
        The matplotlib figure.
    """
    results_df = pd.read_csv(results_path)
    color_map = {"CP": 'r', "CP-Shared": 'g', "CP-s": 'g', "Tucker": 'b', "LoRa-CP": 'r'}
    marker_map = {"CP": 'v', "CP-Shared": 'P', "CP-s": 'P', "Tucker": 'o', "LoRa-CP": 'D'}
    name_map = {"CP": "CP-f", "CP-Shared": "CP-s", "CP-s": "CP-s", "Tucker": "Tucker-f"}

    assert x_name in ["k", "params", "forward_time"]

    matplotlib.rcParams.update({'font.size': 18})
    TICK_SIZE = 22

    fig, axes = plt.subplots()

    # Group by the column NAME, not by a one-element list: with a list, recent pandas
    # versions yield 1-tuples as group keys and the string operations below would fail.
    for name, group in results_df.groupby("model"):

        if name.endswith("LoRa-CP"):
            continue

        # Set linestyle
        if name.startswith("PD"):
            linestyle = "-"
        else:
            linestyle = (0, (2.5, 1))

        suffix_name = name[3:]

        # the "QT-" prefix is shown as "QG-" in the legend
        rg = "QG-" if name[:3] == "QT-" else name[:3]

        label = rg + name_map[suffix_name]

        plot_benchmarking_value(group, x_name, y_name, color=color_map[suffix_name],
                                marker=marker_map[suffix_name], linestyle=linestyle, label=label)

    if y_name.endswith("time"):
        title = y_name + " (s)"
    else:
        title = y_name + " (GB)"

    axes.set_title(title.replace("_", " "))

    axes.set_xlabel(x_name.upper())
    plt.yticks(fontsize=TICK_SIZE)
    plt.xticks(fontsize=TICK_SIZE)

    fig.set_figwidth(4.8)
    fig.set_figheight(3.2)

    if x_name != "k":
        axes.set_xscale('log')

    if y_name == "backward_space":
        # legend only on this plot, anchored outside the axes on the right
        plt.legend(loc="upper left", bbox_to_anchor=(1.01, 1.0), fontsize=19,
                   columnspacing=0.4, borderpad=0.1, labelspacing=0.05,
                   markerscale=1.3, handletextpad=0.1)
    return fig


def double_plot(df: pd.DataFrame, x_field: str, y_field: str, right_y_field: Union[str, None], axes,
                y_baseline=None, right_y_baseline=None, title="Plot", right_tick=False, left_tick=False,
                x_ticks=None, y_lims=None, right_y_lims=None, x_logscale=False, horizontal_lines_to_plot=None,
                vertical_lines_to_plot=None):
    """Plot ``y_field`` (left axis, blue) and optionally ``right_y_field`` (right axis, red) against ``x_field``.

    Args:
        df: data to plot. When ``y_baseline`` is given, a "bpd_loss" column
            (percentage increase of "test_bpd" over the baseline) is written
            into this frame.
        x_field: column for the x axis.
        y_field: column for the left y axis.
        right_y_field: column for a twin right y axis, or None for no right axis.
        axes: matplotlib Axes to draw on.
        y_baseline: if given, drawn as a horizontal line; the markers then encode
            the bpd loss ('x': > 1%, 'X': in (0.1%, 1%], 'o': <= 0.1%).
        right_y_baseline: horizontal line on the right axis.
        title: axes title.
        right_tick, left_tick: add a y tick at the maximum of the respective series.
        x_ticks: major x ticks; every x value always becomes a minor tick.
        y_lims, right_y_lims: (bottom, top) limits of the left / right y axis.
        x_logscale: use a log x axis and annotate the smallest and largest x.
        horizontal_lines_to_plot: mapping text -> y of extra horizontal red lines.
        vertical_lines_to_plot: mapping text -> x of extra vertical red lines.

    The right-axis options (``right_y_baseline``, ``right_tick``,
    ``right_y_lims``) are ignored when ``right_y_field`` is None.
    """
    axes.grid(True, which='both', alpha=0.3, linestyle="--")

    if y_baseline is not None:
        axes.axhline(y=y_baseline, color='b', linestyle='-', lw=2, alpha=0.7)

        # NOTE: adds a column to the frame passed by the caller
        df["bpd_loss"] = 100 * (df["test_bpd"] - y_baseline) / y_baseline

        # x > 1 %
        greater_than_1 = df[df["bpd_loss"] > 1]
        # 0.1 < x <= 1 %
        between_01_1 = df[(df["bpd_loss"] <= 1) & (df["bpd_loss"] > 0.1)]
        # x <= 0.1 %
        less_than_01 = df[df["bpd_loss"] <= 0.1]

        axes.scatter(x=greater_than_1[x_field], y=greater_than_1[y_field], color='blue', marker='x')
        axes.scatter(x=between_01_1[x_field], y=between_01_1[y_field], color='blue', marker='X')
        axes.scatter(x=less_than_01[x_field], y=less_than_01[y_field], color='blue', marker='o')
    else:
        axes.scatter(x=df[x_field], y=df[y_field], color='b', marker='o')

    bpd_line, = axes.plot(df[x_field], df[y_field], color='b', linestyle='-', alpha=0.5, zorder=0)

    # The twin axis exists only when there is a right-hand series; every right-axis
    # option below is guarded by it, so none of them can hit an unbound `right_ax`.
    right_ax = axes.twinx() if right_y_field is not None else None

    if right_ax is not None:
        if right_y_baseline is not None:
            right_ax.axhline(y=right_y_baseline, color='r', linestyle='-', lw=2, alpha=0.7)

        # Three buckets covering every value: a value exactly on a threshold goes to
        # the better bucket (same convention as the bpd-loss markers above).
        right_values = df[right_y_field]
        ok_points = df[right_values < 0.99]
        good_points = df[(right_values >= 0.99) & (right_values < 0.999)]
        excellent_points = df[right_values >= 0.999]

        right_ax.scatter(x=ok_points[x_field], y=ok_points[right_y_field], color='r', marker='x')
        right_ax.scatter(x=good_points[x_field], y=good_points[right_y_field], color='r', marker='X')
        right_ax.scatter(x=excellent_points[x_field], y=excellent_points[right_y_field], color='r', marker='o')

        compr_line, = right_ax.plot(df[x_field], df[right_y_field], color='r', linestyle='--', alpha=0.5)

    # add ticks
    if right_tick and right_ax is not None:
        right_y_tick_to_add = df[right_y_field].max()
        right_ax.set_yticks(list(right_ax.get_yticks()) + [right_y_tick_to_add])

    if left_tick:
        y_tick_to_add = df[y_field].max()
        axes.set_yticks(list(axes.get_yticks()) + [y_tick_to_add])

    if x_ticks is not None:
        axes.set_xticks(x_ticks)
    axes.set_xticks(df[x_field], minor=True)

    if y_lims is not None:
        axes.set_ylim(bottom=y_lims[0], top=y_lims[1])

    if right_y_lims is not None and right_ax is not None:
        right_ax.set_ylim(bottom=right_y_lims[0], top=right_y_lims[1])

    axes.legend([bpd_line], [y_field])

    if right_ax is not None:
        right_ax.legend([compr_line], [right_y_field])

    # 1% of the current axis ranges, used to offset the text annotations below
    one_percent_up = (axes.get_ylim()[1] - axes.get_ylim()[0]) / 100
    one_percent_right = (axes.get_xlim()[1] - axes.get_xlim()[0]) / 100
    if x_logscale:
        axes.set_xscale('log')

        # get min and max and set major ticks
        min_x_row = df.loc[df[x_field].idxmin()]
        max_x_row = df.loc[df[x_field].idxmax()]

        new_ticks = [min_x_row[x_field], max_x_row[x_field]]
        axes.set_xticks(list(axes.get_xticks()) + new_ticks)

        # add corresponding text (x value in scientific notation next to the two extreme points)
        axes.text(x=min_x_row[x_field], y=min_x_row[y_field] + 5 * one_percent_up,
                  s=f"{min_x_row[x_field]:.3e}", ha='center', va='center')
        axes.text(x=max_x_row[x_field], y=max_x_row[y_field] - 5 * one_percent_up,
                  s=f"{max_x_row[x_field]:.3e}", ha='center', va='center')

    # dict text: value
    if horizontal_lines_to_plot is not None:
        for text, value in horizontal_lines_to_plot.items():
            axes.axhline(y=value, color='r', linestyle='-', alpha=0.7)
            axes.text(x=axes.get_xlim()[1] - 100 * one_percent_right, y=value + 2 * one_percent_up,
                      s=f'{text}', ha='right', va='center', color='r')

    if vertical_lines_to_plot is not None:
        for text, value in vertical_lines_to_plot.items():
            axes.axvline(x=value, color='r', linestyle='--', alpha=0.4)
            axes.text(x=value, y= axes.get_ylim()[1] - 5 * one_percent_up,
                      s=f'{text}', ha='right', va='center', color='r')

    axes.set_title(title)
    axes.set_xlabel(x_field)
    if right_ax is not None:
        right_ax.set_ylabel(right_y_field, color='r')
    axes.set_ylabel(y_field, color='b')


def compression_plot(df_path: str):
    """Build the four compression figures for a results CSV and save them to a PDF.

    The CSV must contain one row whose "model" is "Original" plus rows whose
    "model" has the form "<prefix>_<r>..." (the integer after the first
    underscore is the rank r). Figures:

    1. test bpd and compression rate of the compressible parameters against r;
    2. test bpd and compression rate of all the parameters against r;
    3. test bpd against the number of compressible parameters (log x);
    4. test bpd against the number of parameters (log x).

    The PDF is written next to the CSV (".csv" replaced by ".pdf").

    Raises:
        ValueError: if there is not exactly one "Original" row (raised by
            ``Series.item()``).
    """
    data: pd.DataFrame = pd.read_csv(df_path)

    figs = []

    orig = data[data["model"] == "Original"]
    # explicit copy: columns are added to `values` below, which must not be a slice of `data`
    values = data[data["model"] != "Original"].copy()

    # compressed models first, the original model last
    data = pd.concat([values, orig])

    # calculate "compression_rate" and extract "r" fields
    orig_compr_params = orig["compressible_params"].item()
    values["compression rate"] = 1 - (values["compressible_params"] / orig_compr_params)
    values["r"] = values["model"].apply(lambda x: int(x.split("_")[1]))

    orig_test_bpd = orig["test_bpd"].item()
    rank_ticks = [1, 2, 4, 8, 16, 32, 64]

    ############## FIRST
    fig, axes = plt.subplots(nrows=1, ncols=1)

    double_plot(df=values, x_field="r", y_field="test_bpd", right_y_field="compression rate", axes=axes,
                y_baseline=orig_test_bpd, right_y_baseline=1, x_ticks=rank_ticks,
                right_y_lims=(0.96, 1), right_tick=True)

    fig.set_figwidth(10)
    fig.set_figheight(5)

    figs.append(fig)

    ################ SECOND
    fig, axes = plt.subplots(nrows=1, ncols=1)

    # calculate general compression rate
    orig_params = orig["params"].item()
    values["model compression rate"] = 1 - (values["params"] / orig_params)

    double_plot(df=values, x_field="r", y_field="test_bpd", right_y_field="model compression rate", axes=axes,
                y_baseline=orig_test_bpd, right_y_baseline=1, x_ticks=rank_ticks,
                right_y_lims=(0.90, 1), right_tick=True)

    fig.set_figwidth(10)
    fig.set_figheight(5)

    figs.append(fig)

    # reference lines shared by the third and fourth figure:
    # bpd of the original model increased by 1% and by 0.1%
    bpd_loss_lines = {"1% bpd loss": orig_test_bpd * 1.01, "0.1% bpd loss": orig_test_bpd * 1.001}
    # number of compressible parameters left after compressing by the given rate
    v_line_999 = (1 - 0.999) * orig_compr_params
    v_line_99 = (1 - 0.99) * orig_compr_params
    v_line_95 = (1 - 0.95) * orig_compr_params
    v_line_90 = (1 - 0.9) * orig_compr_params

    ####### THIRD
    fig, axes = plt.subplots(nrows=1, ncols=1)

    double_plot(df=data, x_field="compressible_params", y_field="test_bpd", axes=axes, right_y_field=None, x_logscale=True,
                horizontal_lines_to_plot=bpd_loss_lines,
                vertical_lines_to_plot={"99.9% \n compr": v_line_999, "99% \n compr": v_line_99})
    fig.set_figwidth(8)
    fig.set_figheight(6)

    figs.append(fig)

    ####### FOURTH
    fig, axes = plt.subplots(nrows=1, ncols=1)

    # NOTE(review): the x axis is "params" here, yet the vertical lines are computed from
    # the number of COMPRESSIBLE parameters of the original model - confirm this is intended
    double_plot(df=data, x_field="params", y_field="test_bpd", axes=axes, right_y_field=None, x_logscale=True,
                horizontal_lines_to_plot=bpd_loss_lines,
                vertical_lines_to_plot={"99% \n compr": v_line_99, "95% \n compr": v_line_95,
                                        "90% \n compr": v_line_90})
    fig.set_figwidth(8)
    fig.set_figheight(6)

    figs.append(fig)

    # Write every figure as one page of a single PDF file
    with PdfPages(df_path.replace(".csv", ".pdf")) as pdf:
        for fig in figs:
            pdf.savefig(fig)


def complete_rank_plot(results_path: str, k=64,
                       pdf_path: Union[str, None] = None, x_field="rank"):
    """Draw the rank plot of the "quad_tree" / "rescal" results with the given k and save it as a PDF.

    Args:
        results_path: JSON file read as ``complete_data[rg][layer]``, a list of
            result dicts each carrying a "k" entry.
        k: value of "k" selecting the result dict to plot.
        pdf_path: output PDF; defaults to ``results_path`` with ".json"
            replaced by ".pdf".
        x_field: "rank" or "num_params", forwarded to ``rank_plot`` as ``mode``.

    Raises:
        ValueError: if no result dict with the requested k exists.
    """
    assert x_field in ["rank", "num_params"]
    # context manager: the file handle is closed as soon as the JSON is parsed
    with open(results_path, "r") as results_file:
        complete_data: dict = json.load(results_file)

    matplotlib.rcParams.update({'font.size': 16})
    TICK_SIZE = 18

    def get_results_dict(rg: str, layer: str) -> dict:
        # retrieve the result object whose "k" matches the requested one
        for d in complete_data[rg][layer]:
            if d["k"] == k:
                return d
        # fail here with a clear message instead of handing None to rank_plot
        raise ValueError(f"no results with k={k} for rg={rg!r}, layer={layer!r} in {results_path}")

    figs = []

    rg_list = ["quad_tree"] # "pd7"
    layer_type_list = ["rescal"]

    ############## FIRST
    fig, axes = plt.subplots(nrows=1, ncols=1)

    for rg in rg_list:
        for layer in layer_type_list:
            rank_plot(axes, get_results_dict(rg, layer), mode=x_field)

    fig.set_figwidth(4.2)
    fig.set_figheight(2.6)
    axes.set_xlabel("Rank (" + r"$R$" + ")", fontsize=18)
    plt.yticks(fontsize=TICK_SIZE)

    axes.set_xscale('log')
    if x_field == "num_params":
        plt.xticks(fontsize=TICK_SIZE)
    else:
        # rank on the x axis: fixed ticks at the powers of two, shown as plain numbers
        plt.grid(True, which='both', alpha=0.5, linestyle="--")
        axes.set_yticks([1.2, 1.3, 1.4, 1.5, 1.6])
        axes.set_xticks([1, 2, 4, 8, 16, 32, 64])
        axes.get_xaxis().set_major_formatter(ScalarFormatter())
        plt.minorticks_off()

    # legend above the axes, two columns
    plt.legend(loc="lower center", bbox_to_anchor=(0.5, 1.01), ncol=2,
               columnspacing=0.4, borderpad=0.1, labelspacing=0.05,
               markerscale=1, handletextpad=0.1)

    figs.append(fig)

    # Create a PDF file to save the figures
    if pdf_path is None:
        pdf_path = results_path.replace(".json", ".pdf")

    with PdfPages(pdf_path) as pdf:
        for fig in figs:
            pdf.savefig(fig, bbox_inches="tight")





def extra_shared_plot(results_path: str,  pdf_path: str, k_in=32,
                  x_field="num_params", y_field="bpd"):
    """Scatter bpd against the number of parameters for one result and its "cp-extra-shared" variants.

    The result is the entry of ``complete_data["quad_tree_stdec"]["hcpt"]`` whose
    "k" equals ``k_in``. Its own ("num_params", "bpd") point is drawn as a blue
    star; the points of its "cp-extra-shared" entry (dicts "num_params" and
    "test_bpd" sharing the same keys) are drawn as red circles. The x axis is
    logarithmic.

    Args:
        results_path: JSON file with the results.
        pdf_path: output PDF.
        k_in: value of "k" selecting the result to plot; also shown in the title.
        x_field: must be "num_params" or "k"; only validated, the x axis always
            shows the number of parameters.
        y_field: must be "bpd"; used as the y axis label.

    Raises:
        ValueError: if no result with the requested k exists.
    """
    assert x_field in ["num_params", "k"]
    assert y_field in ["bpd"]

    TICK_SIZE = 22
    matplotlib.rcParams.update({'font.size': 18})

    # context manager: the file handle is closed as soon as the JSON is parsed
    with open(results_path, "r") as results_file:
        complete_data: dict = json.load(results_file)

    # extract json for the region graph and layer type
    def get_results_dict(rg: str, layer: str) -> dict:
        # retrieve the result object whose "k" matches k_in
        for d in complete_data[rg][layer]:
            if d["k"] == k_in:
                return d
        # fail here with a clear message instead of subscripting None below
        raise ValueError(f"no results with k={k_in} for rg={rg!r}, layer={layer!r} in {results_path}")

    figs = []

    data = get_results_dict("quad_tree_stdec", "hcpt")
    fig, axes = plt.subplots(nrows=1, ncols=1)

    # Scatter HCPT
    axes.scatter(x=data["num_params"], y=data["bpd"], s=40, marker="*", color="b", label="kCP")
    # Scatter Extra Shared
    data_extra_shared = data["cp-extra-shared"]
    x, y = zip(*[(data_extra_shared["num_params"][k],
                  data_extra_shared["test_bpd"][k])
                 for k in data_extra_shared["test_bpd"]])

    axes.scatter(x=x, y=y, color="r",
                 marker="o", s=40, label="CrazyNet")

    axes.set_xlabel("# Params", fontsize=22)

    # NOTE(review): results_path is parsed as JSON above, so a path ending in
    # "fashion.csv" looks unlikely - confirm which file this check is meant to match
    if not results_path.endswith("fashion.csv"):
        axes.set_ylabel(y_field, fontsize=22)
        axes.legend()
    plt.yticks(fontsize=TICK_SIZE)
    plt.xticks(fontsize=TICK_SIZE)
    plt.title(f"k-in {k_in}")
    plt.grid(True, which='both', alpha=0.5, linestyle="--")

    axes.set_xscale('log')

    figs.append(fig)

    with PdfPages(pdf_path) as pdf:
        for fig in figs:
            pdf.savefig(fig, bbox_inches="tight")


def k_in_less_plot(results_path: str,  pdf_path: str, rg="quad_tree", layer="hcpt",
                  x_field="num_params", y_field="test_bpd", k=128):
    """Plot the test bpd of the "k_in" variants of one result and save the figure as a PDF.

    The result is the entry of ``complete_data[rg][layer]`` whose "k" equals
    ``k``; its "k_in" entry holds the dicts "test_bpd" and "num_params", both
    keyed by k_in. The x axis is logarithmic.

    Args:
        results_path: JSON file with the results.
        pdf_path: output PDF.
        rg, layer: keys selecting the list of results inside the JSON.
        x_field: "k" (x is the k_in key converted to int) or "num_params".
        y_field: must be "test_bpd"; used as the y axis label.
        k: value of "k" selecting the result to plot.

    Raises:
        ValueError: if no result with the requested k exists.
    """
    assert x_field in ["num_params", "k"]
    assert y_field in ["test_bpd"]

    # context manager: the file handle is closed as soon as the JSON is parsed
    with open(results_path, "r") as results_file:
        complete_data: dict = json.load(results_file)

    def extract_values(rg: str, layer: str) -> List[Tuple[int, float]]:
        # (x, test bpd) pairs of the result whose "k" matches the requested one
        for d in complete_data[rg][layer]:
            if d["k"] == k:
                bpd_by_k_in = d["k_in"]["test_bpd"]
                if x_field == "k":
                    return [(int(k_in), bpd_by_k_in[k_in]) for k_in in bpd_by_k_in]
                return [(d["k_in"]["num_params"][k_in], bpd_by_k_in[k_in]) for k_in in bpd_by_k_in]
        # fail here with a clear message instead of unpacking None below
        raise ValueError(f"no results with k={k} for rg={rg!r}, layer={layer!r} in {results_path}")

    figs = []

    ############## FIRST
    fig, axes = plt.subplots(nrows=1, ncols=1)

    x, y = zip(*extract_values(rg, layer))
    simple_plot(axes, x, y, color="r", linestyle="-", marker="o", label=f"{rg}-{layer}-sums={k}")

    axes.set_xlabel(x_field)
    axes.set_ylabel(y_field)
    axes.legend()
    plt.grid(True, which='both', alpha=0.5, linestyle="--")

    axes.set_xscale('log')

    fig.set_figwidth(6)
    fig.set_figheight(4)
    figs.append(fig)

    with PdfPages(pdf_path) as pdf:
        for fig in figs:
            pdf.savefig(fig)


def simple_plot(axes, x: List, y: List, linestyle, marker: str, color: str, label: str):
    """Draw ``y`` against ``x`` on ``axes`` as a thick line with large markers.

    Underscores in ``label`` are replaced by spaces for the legend entry.
    """
    line_style = dict(
        color=color,
        lw=3.5,
        linestyle=linestyle,
        marker=marker,
        markersize=10,
        label=label.replace("_", " "),
    )
    axes.plot(x, y, **line_style)


def rank_plot(axes, values_to_plot: dict, mode="rank"):
    """Draw the baseline and the per-rank curves of one result dict on ``axes``.

    ``values_to_plot["bpd"]`` is drawn as a dashed horizontal line labelled
    "Tucker". Then up to three curves are drawn, each only if its dict is not
    empty: ``["local"]["slice"]`` ("NN-CP"), ``["lora-model"]["random_init"]``
    ("Random init") and ``["lora-model"]["finetuned"]`` ("Fine-tuned"). Every
    curve maps a rank key to a y value; its points are sorted by x.

    Args:
        axes: matplotlib Axes to draw on.
        values_to_plot: result dict with the entries described above.
        mode: "rank" puts ``int(key)`` on the x axis, "num_params" puts
            ``values_to_plot["lora-model"]["num_params"][key]`` there.
    """
    assert mode in ["rank", "num_params"]
    line_width = 3

    axes.axhline(y=values_to_plot["bpd"], color='b', linestyle='--', lw=line_width - 1, alpha=0.7, label="Tucker")

    # x coordinate of every rank key
    params_by_rank = values_to_plot["lora-model"]["num_params"]
    if mode == "num_params":
        x_of = params_by_rank
    else:
        x_of = {rank: int(rank) for rank in params_by_rank}

    # (outer key, inner key, colour, legend label, marker) in drawing order
    curves = (
        ("local", "slice", "r", "NN-CP", "v"),
        ("lora-model", "random_init", 'orange', "Random init", "X"),
        ("lora-model", "finetuned", "g", "Fine-tuned", "o"),
    )
    for outer_key, inner_key, color, label, marker in curves:
        series: dict = values_to_plot[outer_key][inner_key]
        if len(series) == 0:
            continue
        points = sorted(((x_of[rank], series[rank]) for rank in series),
                        key=lambda point: point[0])
        x_values, y_values = zip(*points)
        axes.plot(x_values, y_values, color=color, label=label, lw=line_width, marker=marker,
                  markersize=10)


def plot_params(results_path: str,  pdf_path: str):
    """Plot the number of parameters against K and save the figure as a one-page PDF.

    The JSON file at ``results_path`` is indexed as ``data[rg][layer][k]``; the
    curves for the ``"quad_tree"`` and ``"poon_domingos"`` region graphs with the
    ``"cp"`` layer are drawn, with the keys of the innermost dict on the x axis
    and its values on the y axis.

    Side effect: sets the global matplotlib ``font.size`` rcParam to 18.

    :param results_path: path of the JSON file holding the parameter counts.
    :param pdf_path: path of the PDF file to write.
    """
    TICK_SIZE = 22
    matplotlib.rcParams.update({'font.size': 18})

    # the context manager closes the results file as soon as it is parsed
    with open(results_path, "r") as results_file:
        complete_data: dict = json.load(results_file)

    def make_label(rg, layer):
        """Build the legend entry for a (region graph, layer) pair."""
        if rg == "quad_tree":
            rg_label = "QG-"
        else:
            rg_label = "PD-"

        if layer == "cp":
            layer_label = "CP"
        else:
            layer_label = "Tucker"

        return rg_label + layer_label + "-f"

    fig, axes = plt.subplots(nrows=1, ncols=1)

    for rg in ["quad_tree", "poon_domingos"]:
        for layer in ["cp"]:
            # results for this region graph and layer type
            data = complete_data[rg][layer]

            # the region graph selects the line style ...
            if rg == "quad_tree":
                linestyle = (0, (2.5, 1))
            else:
                linestyle = "-"

            # ... while the layer type selects marker and colour
            if layer == "cp":
                marker = "v"
                color = "r"
            else:
                marker = "o"
                color = "b"

            x, y = zip(*[(k, data[k]) for k in data])
            simple_plot(axes, x=x, y=y, linestyle=linestyle,
                        marker=marker, color=color, label=make_label(rg, layer))

    axes.set_xlabel("K", fontsize=22)
    axes.set_ylabel("# Params", fontsize=22)
    axes.legend()
    plt.yticks(fontsize=TICK_SIZE)
    plt.xticks(fontsize=TICK_SIZE)
    plt.grid(True, which='both', alpha=0.5, linestyle="--")

    fig.set_figwidth(6)
    fig.set_figheight(5)

    # write the figure as the single page of the PDF file
    with PdfPages(pdf_path) as pdf:
        pdf.savefig(fig, bbox_inches="tight")


def plot_weight_distribution(weights: torch.Tensor, name: str):
    """Save a histogram of ``weights`` to ``models/<name>.pdf``.

    The histogram uses 60 logarithmically spaced bin edges between 1e-20 and
    1e2 and a logarithmic x axis.

    :param weights: values to histogram, passed unchanged to ``axes.hist``.
    :param name: file name (without extension) of the PDF written under ``models/``.
    """
    bin_edges = np.logspace(start=-20, stop=2, num=60)

    fig, axes = plt.subplots()
    axes.hist(weights, bins=bin_edges)
    axes.set_xscale('log')

    fig.savefig("models/{}.pdf".format(name))


if __name__ == "__main__":
    # Script entry point: ``mode`` selects which of the plots below is produced.
    mode = "lvd_like"
    import os
    # NOTE(review): presumably avoids the abort raised when the OpenMP runtime is
    # loaded more than once - confirm this is still required.
    os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"

    if mode == "benchmark":
        # one figure per benchmark metric, all collected in a single PDF that is
        # stored next to the CSV file
        metrics = ["forward_time", "forward_space", "backward_time", "backward_space"]

        RESULTS_PATH = "charts/new_benchmarking/for_paper/loshared.csv"
        X_VALUE = "params"  # or "k"
        figs = []
        for y_value in metrics:
            figs.append(benchmark_plot(RESULTS_PATH, X_VALUE, y_value))

        with PdfPages(RESULTS_PATH.replace(".csv", ".pdf")) as pdf:
            for fig in figs:
                pdf.savefig(fig, bbox_inches="tight")

    elif mode == "rank":
        complete_rank_plot("charts/final/compression_results.json", k=64,
                           pdf_path="charts/final/results_k_64.pdf", x_field="rank")
    elif mode == "lvd_like":

        bpd_plot("charts/final/compression_results.json",
                      PDF_PATH="charts/final/bpd_mnist.pdf",
                      minoryticks=[1.15, 1.25, 1.35, 1.45])

        bpd_plot("charts/final/compression_results_fashion.json",
                      PDF_PATH="charts/final/bpd_fashion_mnist.pdf",
                      Y_TICKS=[3.2, 3.4, 3.6, 3.8, 4], minoryticks=[3.3, 3.5, 3.7, 3.9])

        bpd_plot("charts/final/shared.json", PDF_PATH="charts/final/shared.pdf",
                      rg_list=["quad_tree"], layer_type_list=["cp-s(k8)", "cp-s(k16)", "cp-s(k32)",
                                                              "hcpt"],
                      Y_TICKS=[1.2, 1.22, 1.24, 1.26, 1.28, 1.30])

        bpd_plot("charts/final/lvd_like/mnist_cp.json", PDF_PATH="charts/final/lvd_like/bpd_mnist_shared.pdf",
                      rg_list=["quad_tree"], layer_type_list=["rescal", "hcpt", "CP-s"])

    elif mode == "k_in_less":
        k_in_less_plot("charts/final/compression_results_fashion.json",
                       pdf_path="charts/final/k_in_k_128_fashion.pdf", k=128, x_field="num_params")
    elif mode == "crazynet":
        extra_shared_plot("charts/final/compression_results.json",
                      pdf_path="charts/final/qt_stdec_extra_shared.pdf", k_in=64)
    elif mode == "plot_params":
        plot_params("charts/final/rg_params.json",
                    pdf_path="charts/final/rg_params.pdf")
    elif mode == "sparsity":
        # imported here so that the other modes do not need the LoRaEinNet package
        from LoRaEinNet.LoRaEinNetwork import LoRaEinNetwork

        # NOTE: torch.load unpickles the file, only load checkpoints from a trusted source
        # model1: LoRaEinNetwork = torch.load("models/k_128_k_in_128.mdl", map_location="cpu")
        model2: LoRaEinNetwork = torch.load("models/stdec_k_64.mdl", map_location="cpu")

        with torch.no_grad():

            for name, model in [("k_in_8", model2)]:  # [("k_in_128", model1), ("k_in_8", model2)]:

                first_layer = model.einet_layers[1]

                cp_a = first_layer.get_params_dict()["cp_a"]
                cp_b = first_layer.get_params_dict()["cp_b"]

                # interleave the two factors along the last axis:
                # cp_b fills the even positions, cp_a the odd ones
                complete_weight = torch.empty(size=cp_a.shape[0:2] + (2*cp_a.shape[2],))
                for i in range(cp_a.shape[2]):
                    complete_weight[:,:,2*i+1] = cp_a[:,:,i]
                    complete_weight[:,:,2*i] = cp_b[:,:,i]

                # view the weights as a stack of 14x28 maps
                weights_2d = complete_weight.reshape(-1, 14, 28)

                # 1.0 where a weight is below 1e-16, 0.0 elsewhere
                weights_2d_zeros = torch.where(weights_2d < 1e-16, 1.0, 0.0)

                # for every position of the map, number of weights below the threshold
                final = weights_2d_zeros.sum(0)

                plt.matshow(final, cmap=plt.cm.Blues)
                plt.show()
    else:
        raise AssertionError("Unknown mode")