In [None]:
from pathlib import Path

# Experiment under analysis. Every artefact below is read from, or written
# to, out/<model_name>/<ds_name>.
model_name = "resnet18"
ds_name = "cifar10"
path = Path("out", model_name, ds_name)

In [None]:
from glob import glob
from regex import match
import pandas as pd
import numpy as np
import json


def _is_epoch_analysis(p: str) -> bool:
    return Path(p).is_dir() and (match(r".*/\d+$", p) is not None)


# Per-epoch analysis directories (out/<model>/<dataset>/<epoch>), sorted
# numerically by epoch. Glob order is arbitrary, and plain string sorting
# would put "10" before "2". Later cells rely on `df` being in epoch order.
analysis_paths = sorted(
    filter(_is_epoch_analysis, glob(str(path / "*"))),
    key=lambda p: int(Path(p).name),
)
if not analysis_paths:
    raise FileNotFoundError(f"No per-epoch analysis directories found in {path}")

dfs = []
for p in analysis_paths:
    # The directory is named after its epoch. Using Path.name avoids
    # assuming "/" as the path separator.
    epoch = int(Path(p).name)
    with (Path(p) / "svc" / "pairwise_rbf.json").open(
        mode="r", encoding="utf-8"
    ) as fp:
        doc = json.load(fp)  # Prevents loading numpy arrays
    # `doc` maps each submodule name to a list of results, each carrying a
    # "score". Keep the mean score per submodule.
    data = [
        [epoch, k, np.mean([d["score"] for d in v])] for k, v in doc.items()
    ]
    dfs.append(pd.DataFrame(data, columns=["epoch", "submodule", "mean_score"]))

df = pd.concat(dfs, ignore_index=True)
df

In [None]:
# One row per epoch. `.last()` takes, for each column, the last non-null
# value logged during that epoch. `.tail(1)` would instead take the literal
# last CSV row, which is NaN for any metric logged on a different row.
metrics = (
    pd.read_csv(
        path / "model" / "csv_logs" / model_name / "version_0" / "metrics.csv"
    )
    .drop(columns=["train/loss"])
    .groupby("epoch", as_index=False)
    .last()
)
# Epoch *label* with the lowest validation loss. `Series.argmin()` returns a
# row position, which equals the epoch only if epochs start at 0 with no gaps.
best_epoch = int(metrics.loc[metrics["val/loss"].idxmin(), "epoch"])

metrics

In [None]:
import seaborn as sns

# Five evenly spaced epochs between epoch 0 and the best epoch.
e = np.linspace(0, best_epoch, num=5, dtype=int)

figure = sns.lineplot(
    df[df["epoch"].isin(e)],
    x="submodule",
    y="mean_score",
    hue="epoch",
    size="epoch",
)
figure.set(title="Separability scores by epoch")
figure.set_xticklabels(
    figure.get_xticklabels(),
    rotation=45,
    rotation_mode="anchor",
    ha="right",
)
# bbox_inches="tight" keeps the rotated tick labels in the saved file. The
# inline notebook display crops tightly by default, but savefig does not.
figure.get_figure().savefig(path / "separability_epoch.png", bbox_inches="tight")

In [None]:
figure = sns.lineplot(df, x="epoch", y="mean_score", hue="submodule")
# The legend is placed outside the axes (to the right)...
sns.move_legend(figure, "upper left", bbox_to_anchor=(1, 1))
# ...so mark the best epoch with a dotted vertical line.
figure.axvline(best_epoch, linestyle=":", color="gray")
figure.set(title="Separability scores by submodule")
# Without bbox_inches="tight", the out-of-axes legend is cut from the PNG.
figure.get_figure().savefig(path / "separability_submod.png", bbox_inches="tight")

In [None]:
# Align separability scores with the validation metrics on the epoch number.
# The two frames are not guaranteed to share row order or length (metrics
# may cover epochs that were never analysed). Comparing raw arrays
# positionally would therefore pair unrelated epochs, or fail on a length
# mismatch.
scores_by_epoch = df.pivot(index="epoch", columns="submodule", values="mean_score")
aligned = (
    metrics.set_index("epoch")[["val/acc", "val/loss"]]
    .join(scores_by_epoch, how="inner")
    .sort_index()
)
val_acc = aligned["val/acc"].to_numpy()
val_loss = aligned["val/loss"].to_numpy()

# Keep the submodule order of the earliest analysed epoch (rather than
# assuming epoch 0 exists).
submodules = df.loc[df["epoch"] == df["epoch"].min(), "submodule"]
# Pearson correlation of each submodule's score curve with each metric
# (NaN entries are dropped pairwise).
data = [
    [
        s,
        aligned["val/acc"].corr(aligned[s]),
        aligned["val/loss"].corr(aligned[s]),
    ]
    for s in submodules
]
correlations = pd.DataFrame(
    data,
    columns=[
        "submodule",
        "val/acc",
        "val/loss",
    ],
)
correlations

In [None]:
# Long format: one row per (submodule, metric) pair, with one facet per metric.
mcorr = correlations.melt(
    id_vars=["submodule"],
    var_name="sep. vs.",
    value_name="corr.",
)
grid = sns.FacetGrid(mcorr, col="sep. vs.")
# An explicit category order makes every facet lay the submodules out
# identically. Seaborn warns when barplot is mapped onto a FacetGrid
# without one.
grid.map(sns.barplot, "submodule", "corr.", order=list(correlations["submodule"]))
for ax in grid.axes_dict.values():
    ax.set_xticklabels(
        ax.get_xticklabels(),
        rotation=45,
        rotation_mode="anchor",
        ha="right",
    )
    # Correlation coefficients lie in [-1, 1]. A fixed range keeps facets comparable.
    ax.set_ylim(-1, 1)
# FacetGrid.savefig saves with bbox_inches="tight" by default, so the rotated
# labels are not clipped.
grid.savefig(path / "correlations.png")

In [None]:
import bokeh.io
import bokeh.plotting as bk

# Render Bokeh output inline in this notebook (the UMAP grid below uses it).
bokeh.io.output_notebook()

In [None]:
import turbo_broccoli as tb

# epoch -> list of that epoch's UMAP figures (loaded via turbo_broccoli,
# presumably Bokeh figures given the attributes set below; verify).
data = {}
for analysis_path in analysis_paths:
    # The directory is named after its epoch (separator-agnostic parse).
    epoch = int(Path(analysis_path).name)
    # Only show epochs up to and including the best one.
    if epoch > best_epoch:
        continue
    doc = tb.load_json(Path(analysis_path) / "umap" / "plots.json")
    plots = list(doc.values())
    # Shrink each figure and strip grid/axes so the grid stays compact.
    for fig in plots:
        fig.height, fig.width = 200, 200
        fig.grid.visible, fig.axis.visible = False, False
    data[epoch] = plots

# One grid row per loaded epoch, in epoch order. `range(best_epoch)` would
# drop the best epoch itself, and would raise KeyError if an earlier epoch
# has no analysis directory.
figures = [data[ep] for ep in sorted(data)]
plot = bk.gridplot(figures)
bk.show(plot)

In [None]:
from bokeh.io import export_png

# Save the UMAP grid shown above as a static PNG next to the other figures.
export_png(plot, filename=Path(path, "umap_all.png"))

In [None]:
from glob import glob
from regex import match
import numpy as np
import turbo_broccoli as tb


def _is_epoch_analysis(p: str) -> bool:
    return Path(p).is_dir() and (match(r".*/\d+$", p) is not None)


# Per-epoch analysis directories, in numeric epoch order. Glob order is
# arbitrary, and string order would put "10" before "2".
analysis_paths = sorted(
    filter(_is_epoch_analysis, glob(str(path / "*"))),
    key=lambda p: int(Path(p).name),
)
# epoch number -> contents of that epoch's eval/eval.json. The cells below
# iterate it as {submodule name: outputs}.
evaluations = {
    int(Path(p).name): tb.load_json(Path(p) / "eval" / "eval.json")
    for p in analysis_paths
}
n_epochs = len(evaluations)

In [None]:
from nlnas import TorchvisionDataset
from nlnas.utils import get_first_n

# Labels `y` that the separability metrics below (gdv, label_variation)
# receive alongside each submodule's stored evaluation.
ds = TorchvisionDataset(ds_name)
ds.setup("fit")
# First 5000 training samples. Only the labels are kept; the inputs are
# discarded.
# NOTE(review): this presumably only lines up with eval/eval.json if that
# file was produced from the same first 5000 samples, in the same order
# (i.e. train_dataloader() does not shuffle). Confirm.
_, y = get_first_n(ds.train_dataloader(), 5000)

In [None]:
import pandas as pd
from nlnas.separability import gdv
from tqdm import tqdm

# GDV of every submodule's stored evaluation at every analysed epoch.
data = []
# Iterate over the epochs actually loaded. `range(n_epochs)` assumes they are
# exactly 0..n_epochs-1 and raises KeyError on any gap.
progress = tqdm(sorted(evaluations), leave=False)
for e in progress:
    for k, x in evaluations[e].items():
        progress.set_postfix({"epoch": e, "submodule": k})
        v = gdv(x, y)
        data.append([e, k, float(v)])
df = pd.DataFrame(data, columns=["epoch", "submodule", "gdv"])
df.to_csv(path / "gdv.csv")
df

In [None]:
import seaborn as sns

# Five evenly spaced epochs between epoch 0 and the best epoch.
e = np.linspace(0, best_epoch, num=5, dtype=int)

figure = sns.lineplot(
    df[df["epoch"].isin(e)],
    x="submodule",
    y="gdv",
    hue="epoch",
    size="epoch",
)
figure.set(title="GDV by epoch")
figure.set_xticklabels(
    figure.get_xticklabels(),
    rotation=45,
    rotation_mode="anchor",
    ha="right",
)
# bbox_inches="tight" keeps the rotated tick labels in the saved file. The
# inline notebook display crops tightly by default, but savefig does not.
figure.get_figure().savefig(path / "gdv_epoch.png", bbox_inches="tight")

In [None]:
figure = sns.lineplot(df, x="epoch", y="gdv", hue="submodule")
# The legend is placed outside the axes (to the right)...
sns.move_legend(figure, "upper left", bbox_to_anchor=(1, 1))
# ...so mark the best epoch with a dotted vertical line.
figure.axvline(best_epoch, linestyle=":", color="gray")
figure.set(title="GDV by submodule")
# Without bbox_inches="tight", the out-of-axes legend is cut from the PNG.
figure.get_figure().savefig(path / "gdv_submod.png", bbox_inches="tight")

In [None]:
import pandas as pd
from nlnas.separability import label_variation
from tqdm import tqdm

# Label variation (k=10) of every submodule's stored evaluation at every
# analysed epoch.
data = []
# Iterate over the epochs actually loaded. `range(n_epochs)` assumes they are
# exactly 0..n_epochs-1 and raises KeyError on any gap.
progress = tqdm(sorted(evaluations), leave=False)
for e in progress:
    for k, x in evaluations[e].items():
        progress.set_postfix({"epoch": e, "submodule": k})
        v = label_variation(x, y, k=10)
        data.append([e, k, float(v)])
df = pd.DataFrame(data, columns=["epoch", "submodule", "lv"])
df.to_csv(path / "lv.csv")
df

In [None]:
import seaborn as sns

# Five evenly spaced epochs between epoch 0 and the best epoch.
e = np.linspace(0, best_epoch, num=5, dtype=int)

figure = sns.lineplot(
    df[df["epoch"].isin(e)],
    x="submodule",
    y="lv",
    hue="epoch",
    size="epoch",
)
figure.set(title="Label variation by epoch")
figure.set_xticklabels(
    figure.get_xticklabels(),
    rotation=45,
    rotation_mode="anchor",
    ha="right",
)
# bbox_inches="tight" keeps the rotated tick labels in the saved file. The
# inline notebook display crops tightly by default, but savefig does not.
figure.get_figure().savefig(path / "lv_epoch.png", bbox_inches="tight")

In [None]:
figure = sns.lineplot(df, x="epoch", y="lv", hue="submodule")
# The legend is placed outside the axes (to the right)...
sns.move_legend(figure, "upper left", bbox_to_anchor=(1, 1))
# ...so mark the best epoch with a dotted vertical line.
figure.axvline(best_epoch, linestyle=":", color="gray")
figure.set(title="Label variation by submodule")
# Without bbox_inches="tight", the out-of-axes legend is cut from the PNG.
figure.get_figure().savefig(path / "lv_submod.png", bbox_inches="tight")