In [None]:
%load_ext autoreload
# Re-import edited modules (e.g. the local `drcomp` package) automatically before each cell runs.
%autoreload 2

In [None]:
from drcomp.plotting import plot_trustworthiness_continuity, plot_lcmc
from drcomp.utils.notebooks import get_data_set, get_model_for_dataset
import json
from typing import Union
import matplotlib.pyplot as plt
import numpy as np
import matplotlib as mpl
from matplotlib.ticker import IndexLocator
import scienceplots
import os

# Global plot style: the "science" and "notebook" styles, presumably registered
# by the `import scienceplots` above (which is otherwise unused here).
plt.style.use(["science", "notebook"])

In [None]:
def load_metrics(
    dataset: str,
    reducer: str,
    throw_on_missing: bool = True,
    metrics_dir: str = "../metrics",
):
    """Load the stored evaluation metrics of one reducer on one dataset.

    The metrics are read from the JSON file
    ``{metrics_dir}/{dataset}_{reducer}.json``.

    Args:
        dataset: Dataset name used as the filename prefix, e.g. ``"MNIST"``.
        reducer: Reducer name used as the filename suffix, e.g. ``"PCA"``.
        throw_on_missing: If True (default), a missing file raises
            FileNotFoundError; if False, None is returned instead.
        metrics_dir: Directory that holds the metrics JSON files.

    Returns:
        The parsed JSON content, or None when the file is missing and
        ``throw_on_missing`` is False.

    Raises:
        FileNotFoundError: If the file is missing and ``throw_on_missing``
            is True.
    """
    filename = f"{dataset}_{reducer}.json"
    path = os.path.join(metrics_dir, filename)
    try:
        # Open directly rather than checking os.path.exists first, so there is
        # no window in which the file can disappear between check and open.
        with open(path, "r") as f:
            return json.load(f)
    except FileNotFoundError:
        if throw_on_missing:
            raise FileNotFoundError(
                f"File {filename} not found in {metrics_dir}/"
            ) from None
        return None


def load_all_metrics_for(
    datasets: str,
    reducers: Union[list[str], tuple[str, ...]] = (
        "ConvAE",
        "PCA",
        "KernelPCA",
        "AE",
        "LLE",
        "CAE",
    ),
    throw_on_missing: bool = True,
):
    """Load the stored metrics of several reducers for a single dataset.

    Args:
        datasets: Name of ONE dataset, e.g. ``"MNIST"`` (the plural parameter
            name is kept for backward compatibility with existing callers).
        reducers: Reducer names to load, in the order they should appear in
            the result. The default is an immutable tuple so the shared
            default value can never be mutated between calls.
        throw_on_missing: If True, a missing metrics file raises
            FileNotFoundError; if False, that reducer is skipped.

    Returns:
        Dict mapping reducer name -> its loaded metrics, in ``reducers`` order.
    """
    metrics: dict[str, dict] = {}
    for reducer in reducers:
        try:
            metrics[reducer] = load_metrics(datasets, reducer)
        except FileNotFoundError:
            if throw_on_missing:
                raise
            # Best-effort mode: reducers without a metrics file are omitted.
    return metrics

In [None]:
def plot_metric(metric, label: str, ax=None, tick_step: int = 20):
    """Plot one quality curve against the neighbourhood size K = 1..len(metric).

    Args:
        metric: Sequence of metric values; entry ``i`` is drawn at ``K = i + 1``.
        label: Legend label of the curve.
        ax: Axes to draw on; if None, a new Axes is created with ``plt.axes()``.
        tick_step: Spacing between major ticks on the K axis.

    Returns:
        The Axes that was drawn on.
    """
    if ax is None:
        ax = plt.axes()
    k = len(metric)
    x = np.arange(1, k + 1)
    ax.plot(x, metric, label=label)
    ax.set_xlabel("$K$")
    # Leave a little margin around the plotted K range 1..k.
    ax.set_xlim(0, k + 2)
    # x starts at 1, so offset=-1 presumably puts the ticks on multiples of
    # tick_step (0, tick_step, 2*tick_step, ...) rather than 1, 1+tick_step, ...
    ax.xaxis.set_major_locator(IndexLocator(tick_step, offset=-1))
    return ax

In [None]:
def save_fig(dir, fig, name: str, latex: bool = True, width=5.91, height=4.8, **kwargs):
    """Save ``fig`` as ``{dir}/{name}.pgf`` (LaTeX export) or ``{dir}/{name}.png``.

    Args:
        dir: Output directory; it is created if it does not exist yet.
        fig: Matplotlib figure to save.
        name: File name without extension.
        latex: If True, resize the figure to ``width`` x ``height`` inches and
            export it through the PGF backend with the "science" style applied;
            otherwise save a PNG with the figure's current size and style.
        width: Figure width in inches for the LaTeX export (5.91 is the same
            value as the notebook's ``LATEX_WIDTH``).
        height: Figure height in inches for the LaTeX export.
        **kwargs: Forwarded unchanged to ``fig.savefig``.
    """
    # Without this, saving fails when e.g. ../figures does not exist yet.
    os.makedirs(dir, exist_ok=True)
    if not latex:
        fig.savefig(os.path.join(dir, f"{name}.png"), format="png", backend=None, **kwargs)
        return
    # Scope the "science" style to this export: plt.style.use would change the
    # global style for every figure created later in the session.
    with plt.style.context("science"):
        fig.set_size_inches(w=width, h=height)
        fig.tight_layout()
        fig.savefig(os.path.join(dir, f"{name}.pgf"), format="pgf", backend="pgf", **kwargs)

In [None]:
LATEX_WIDTH = 5.91


def compare_metrics(metrics: dict, figsize=(8, 8)):
    """Plot trustworthiness, continuity and LCMC curves of several reducers.

    Layout: T(K) top-left, C(K) bottom-left, LCMC(K) spanning the right column.

    Args:
        metrics: Mapping reducer name -> metrics dict that provides the keys
            ``"trustworthiness"``, ``"continuity"`` and ``"lcmc"`` (e.g. the
            result of ``load_all_metrics_for``).
        figsize: Figure size in inches.

    Returns:
        ``(fig, [ax_trustworthiness, ax_continuity, ax_lcmc])``.
    """
    fig = plt.figure(figsize=figsize)
    ax1 = fig.add_subplot(221)
    ax2 = fig.add_subplot(223)
    ax3 = fig.add_subplot(122)
    # Titles do not depend on the reducer: set them once, so they also appear
    # when `metrics` is empty.
    ax1.set_title("$T(K)$")
    ax2.set_title("$C(K)$")
    ax3.set_title("$LCMC(K)$")
    for name, metric in metrics.items():
        plot_metric(metric["trustworthiness"], label=name, ax=ax1)
        plot_metric(metric["continuity"], label=name, ax=ax2)
        plot_metric(metric["lcmc"], label=name, ax=ax3)
    if metrics:
        # Every line already carries its reducer name as label.
        ax3.legend()
    fig.tight_layout()
    return fig, [ax1, ax2, ax3]

In [None]:
# MNIST: compare every reducer that has a stored metrics file (missing ones are skipped).
fig, axs = compare_metrics(
    metrics=load_all_metrics_for(datasets="MNIST", throw_on_missing=False)
)

In [None]:
# Swiss roll: compare every reducer that has a stored metrics file.
swiss_roll_fig, _ = compare_metrics(
    metrics=load_all_metrics_for(datasets="SwissRoll", throw_on_missing=False)
)

In [None]:
# Twin peaks: compare every reducer that has a stored metrics file.
twin_peaks_fig, _ = compare_metrics(
    metrics=load_all_metrics_for(datasets="TwinPeaks", throw_on_missing=False)
)

In [None]:
# LFW people: bind the figure like the other datasets instead of leaving the
# (figure, axes) tuple as the cell's last expression, whose repr would be
# printed below the plot.
lfw_people_fig, _ = compare_metrics(
    load_all_metrics_for("LfwPeople", throw_on_missing=False)
)

In [None]:
# `save_fig` is defined once, in an earlier cell of this notebook. It is
# intentionally not re-defined here: an identical second copy would silently
# shadow the first and let the two drift apart.

In [None]:
# Export the toy-dataset comparisons as LaTeX (PGF) figures, 6 inches tall.
save_fig(dir="../figures", fig=swiss_roll_fig, name="SwissRoll-comp", height=6, latex=True)
save_fig(dir="../figures", fig=twin_peaks_fig, name="TwinPeaks-comp", height=6, latex=True)

In [None]:
# Export the MNIST comparison as a LaTeX (PGF) figure, 6 inches tall.
save_fig(dir="../figures", fig=fig, name="MNIST-comp", height=6, latex=True)