In [None]:
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np
import matplotlib
import einops

from glob import glob
import functools
import json
import os
import re

sns.set_theme(style="darkgrid", context="notebook", palette=sns.color_palette("rocket", 4))
matplotlib.rcParams['figure.figsize'] = (20, 10)

%matplotlib inline

In [None]:
# Load all results.

def model_name_from_dir(results_dir):
    """Return the model name for a results directory: its final path component.

    Uses ``os.path`` instead of a hard-coded ``/`` regex so that trailing
    separators and the platform's native separator are both handled
    (``results/gpt2`` and ``results/gpt2/`` both give ``gpt2``).
    """
    return os.path.basename(os.path.normpath(results_dir))


# One entry per model directory (only names starting with "g" are matched).
results_dirs = sorted(glob("results/g*"))
# Every .npz archive under results/, keyed by its path exactly as glob returns it.
results_arrays = {file: np.load(file) for file in glob("results/**/*.npz", recursive=True)}

all_results = []
for results_dir in results_dirs:
    # Each model directory is expected to contain a metrics.json mapping
    # task name -> {attribute name: value}.
    metrics_path = os.path.join(results_dir, "metrics.json")
    with open(metrics_path, encoding="utf-8") as f:
        results = json.load(f)
    model = model_name_from_dir(results_dir)
    for task_name, problem in results.items():
        # One row per (task, model); the task's attributes are stored as-is
        # (lists stay lists) so they do not get expanded into extra rows.
        all_results.append({
            "task": task_name,
            "model": model,
            "task/model": f"{task_name}/{model}",
            **problem,
        })
df = pd.DataFrame(all_results)
df.head()

In [None]:
# generic compose function
def compose(*F):
    """Compose callables right-to-left: ``compose(f, g, h)(x) == f(g(h(x)))``.

    Called with no arguments it returns the identity function (the earlier
    ``functools.reduce`` version raised ``TypeError`` on an empty sequence).
    """
    def composed(x):
        # The right-most function is applied first.
        for fn in reversed(F):
            x = fn(x)
        return x
    return composed

In [None]:
def plot_path(results: pd.DataFrame, title: str, key: str = "path", split_key="task/model", transform=None,
              max_samples=12) -> None:
    """Plot 2-D PCA projections of stored paths, one subplot per split.

    For every unique value of ``results[split_key]`` the first matching row is
    used: the value in its ``key`` column is looked up in the module-level
    ``results_arrays`` mapping, and the array stored under ``key`` in that
    archive is loaded. The array's last axis is reduced to two principal
    components (PCA is fitted separately per split, on all of its samples).
    Each sample is then drawn as one line through its points, in stored
    (layer) order.

    Args:
        results: Frame with a ``split_key`` column and a ``key`` column; the
            latter holds keys into ``results_arrays``.
        title: Figure-level title.
        key: Both the column of ``results`` to read and the array name inside
            the loaded archive.
        split_key: Column whose unique values define the subplots.
        transform: Optional callable applied to each subplot's ``Axes`` after
            the lines are drawn.
        max_samples: Number of leading samples drawn per split; ``None`` draws
            all of them. The default of 12 matches the previous hard-coded
            cut-off.

    Raises:
        ValueError: If ``results`` contains no splits.
    """
    splits = results[split_key].unique()
    if len(splits) == 0:
        raise ValueError(f"no values found in column {split_key!r}; nothing to plot")
    # squeeze=False keeps `axes` two-dimensional even when there is a single split.
    fig, axes = plt.subplots(len(splits), squeeze=False)

    rows = []
    for split in splits:
        results_filename = results.loc[results[split_key] == split].iloc[0].loc[key]
        path = results_arrays[results_filename][key]
        # Collapse the leading axes so PCA sees a 2-D (points, h) matrix, project
        # to two components, then restore the original grouping.
        flattened_path, shape = einops.pack(path, "* h")
        flattened_path = PCA(n_components=2).fit_transform(flattened_path)
        path = einops.unpack(flattened_path, shape, "* h")
        # PCA is fitted on everything; only the plotted samples are capped.
        for i, sample in enumerate(path[:max_samples]):
            for j, layer in enumerate(sample):
                rows.append({"split": split, "layer": j, "x": layer[0], "y": layer[1], "sample": i})

    # Explicit columns keep the frame well-formed even if no rows were collected.
    points = pd.DataFrame(rows, columns=["split", "layer", "x", "y", "sample"])

    for ax, split in zip(axes[:, 0], splits):
        # sort=False / estimator=None: connect the points in layer order instead of
        # letting lineplot re-order them by x, which would scramble the path.
        sns.lineplot(
            data=points[points["split"] == split], x="x", y="y", hue="sample",
            sort=False, estimator=None, ax=ax, legend=False,
        )
        if transform is not None:
            transform(ax)
        ax.set_title(split)
        ax.set_xlabel("")
    fig.suptitle(title)
    # Lay out last, once all artists exist; rect leaves head-room for the suptitle.
    fig.tight_layout(rect=[0, 0.03, 1, 0.95])


In [None]:
# Draw the PCA-projected paths for every task/model pair in the loaded results.
plot_path(df, title="Path")