Compares latent space structure of a model throughout training

In [None]:
# Load the jupyter_black extension (black formatting of code cells)
%load_ext jupyter_black
# Load autoreload and set it to mode 2, so that imported modules are reloaded
# before code is executed
%load_ext autoreload
%autoreload 2

In [None]:
import bokeh.layouts as bkl
import bokeh.plotting as bk
from bokeh.io import output_notebook

from nlnas.plotting import export_png

output_notebook()

import sys

from loguru import logger as logging

# Route loguru output to stdout only: drop the handlers registered so far, then
# add a single INFO-level sink with a compact "[LEVEL] message" layout
logging.remove()
_log_sink_options = {
    "level": "INFO",
    "format": "[<level>{level: <8}</level>] <level>{message}</level>",
}
logging.add(sys.stdout, **_log_sink_options)

# Prepare stuff

## Find all checkpoints

In [None]:
from pathlib import Path
import turbo_broccoli as tb

# Hub identifiers of the dataset and of the model under analysis
HF_DATASET_NAME = "cifar100"
HF_MODEL_NAME = "microsoft/resnet-18"

# Submodules whose latent spaces are clustered: layers 0 and 1 of each of the
# encoder stages 0 to 3, followed by the classifier
SUBMODULES = [
    f"resnet.encoder.stages.{stage}.layers.{layer}"
    for stage in range(4)
    for layer in range(2)
] + ["classifier"]

# HF_MODEL_NAME = "timm/mobilenetv3_small_050.lamb_in1k"
# SUBMODULES = [
#     "model.blocks.0",
#     "model.blocks.1.0",
#     "model.blocks.1.1",
#     "model.blocks.2.0",
#     "model.blocks.2.1",
#     "model.blocks.2.2",
#     "model.blocks.3.0",
#     "model.blocks.3.1",
#     "model.blocks.4.0",
#     "model.blocks.4.1",
#     "model.blocks.4.2",
#     "model.blocks.5",
#     "model.conv_head",
#     "model.classifier",
# ]

# HF_MODEL_NAME = "timm/tinynet_e.in1k"
# SUBMODULES = [f"model.blocks.{i}" for i in range(7)] + [
#     "model.conv_head",
#     "model.classifier",
# ]


# Index of the run to analyze (selects which results file is read)
VERSION = 0

# Slashes in the hub identifiers are swapped for dashes to get path-safe names
DATASET_NAME = "-".join(HF_DATASET_NAME.split("/"))
MODEL_NAME = "-".join(HF_MODEL_NAME.split("/"))

# Results document of that run; this notebook reads its `model` and `dataset`
# entries
_run_dir = Path("out/ftlcc").joinpath(DATASET_NAME, MODEL_NAME)
RESULT_FILE_PATH = _run_dir / f"results.{VERSION}.json"
RESULTS = tb.load_json(RESULT_FILE_PATH)

In [None]:
from nlnas.training import all_checkpoint_paths

# Directory searched for checkpoints: the `tb_logs` folder of the selected run
_ckpt_root = Path("out/ftlcc").joinpath(
    DATASET_NAME, MODEL_NAME, "tb_logs", MODEL_NAME, f"version_{VERSION}"
)
ckpts = all_checkpoint_paths(_ckpt_root)

logging.info("Found {} checkpoints", len(ckpts))
_best_ckpt_info = RESULTS["model"]["best_checkpoint"]
logging.info("Best epoch: {}", _best_ckpt_info["best_epoch"])

## Load dataset

In [None]:
from nlnas.datasets.huggingface import HuggingFaceDataset
from nlnas.classifiers.timm import TimmClassifier
from nlnas.classifiers.base import BaseClassifier
from nlnas.classifiers.huggingface import HuggingFaceClassifier
from nlnas.utils import get_reasonable_n_jobs

# Models from the `timm/` namespace go through TimmClassifier, all others
# through HuggingFaceClassifier
classifier_cls: type[BaseClassifier] = (
    TimmClassifier if HF_MODEL_NAME.startswith("timm/") else HuggingFaceClassifier
)

# Split names and label key are the ones recorded in the results file
_ds_info = RESULTS["dataset"]
_train_split = _ds_info["train_split"]
dataset = HuggingFaceDataset(
    HF_DATASET_NAME,
    fit_split=_train_split,
    val_split=_ds_info["val_split"],
    test_split=_ds_info["test_split"],
    # Deliberate (not a typo): predictions are made on the training split
    predict_split=_train_split,
    train_dl_kwargs={"batch_size": 64, "num_workers": get_reasonable_n_jobs()},
    label_key=_ds_info["label_key"],
    image_processor=classifier_cls.get_image_processor(HF_MODEL_NAME),
)

# True labels of the "train" split as a numpy array, plus basic counts
y_true = dataset.y_true("train").numpy()
n_classes = dataset.n_classes()
n_samples = len(y_true)
logging.info("y_true: {}", y_true.shape)

# Full dataset clustering at every epoch

In [None]:
# Name of the clustering method handed to the latent clustering routine
CLUSTERING_METHOD = "louvain"
# Analysis directory of the selected run; figures are saved directly in it
PLOTS_PATH = Path("out/ftlcc").joinpath(
    DATASET_NAME, MODEL_NAME, "analysis", str(VERSION)
)

## Loading

In [None]:
from tempfile import TemporaryDirectory

from tqdm.notebook import tqdm

from nlnas.classifiers.timm import TimmClassifier
from nlnas.classifiers.base import full_dataset_latent_clustering
from nlnas.training import checkpoint_ves

# Latent clustering results, one entry per checkpoint:
# lc_data[epoch][submodule] = (y_clst, matching)
lc_data = {}

for ckpt in tqdm(ckpts):
    # Only the middle element (the epoch) of the checkpoint's triple is needed
    _, epoch, _ = checkpoint_ves(ckpt)
    output_dir = (
        Path("out/ftlcc")
        / DATASET_NAME
        / MODEL_NAME
        / "analysis"
        / str(VERSION)
        / str(epoch)
    )
    # The cache directory is derived from CLUSTERING_METHOD (instead of a
    # hard-coded "louvain") so that it always matches the data file below
    method_dir = output_dir / CLUSTERING_METHOD
    method_dir.mkdir(exist_ok=True, parents=True)
    # Guarded block: presumably the loop body is skipped and the stored result
    # is loaded when data.json already exists -- see turbo_broccoli docs
    g = tb.GuardedBlockHandler(method_dir / "data.json")
    for _ in g:
        # Intermediate files of the clustering go to a throwaway directory;
        # only the labels and matchings are kept
        with TemporaryDirectory() as tmp:
            model = classifier_cls.load_from_checkpoint(ckpt)
            model.hparams["lcc_submodules"] = SUBMODULES
            data = full_dataset_latent_clustering(
                model=model,
                dataset=dataset,
                output_dir=tmp,
                method=CLUSTERING_METHOD,
                device="cuda",
                tqdm_style="notebook",
            )
            g.result = {sm: (d.y_clst, d.matching) for sm, d in data.items()}
    lc_data[epoch] = g.result

## Basic plots (r_clst, r_cc, r_mc)

In [None]:
import pandas as pd
import numpy as np

from nlnas.correction.clustering import otm_matching_predicates, _mc_cc_predicates

# One row per (epoch, submodule):
# - r_clst: number of distinct cluster labels divided by the number of classes
# - r_cc / r_mc: number of samples flagged by the p_cc / p_mc predicate,
#   divided by the number of samples
data = []
for epoch, clusterings in tqdm(lc_data.items()):
    for sm in clusterings:
        y_clst, matching = clusterings[sm]
        p_mc, p_cc = _mc_cc_predicates(y_true, y_clst, matching)
        data.append(
            dict(
                epoch=epoch,
                sm=sm,
                r_clst=len(np.unique(y_clst)) / n_classes,
                r_cc=p_cc.sum() / n_samples,
                r_mc=p_mc.sum() / n_samples,
            )
        )

df = pd.DataFrame(data)

In [None]:
import seaborn as sns

# r_clst against epoch, one line per submodule
_r_clst_plot_kw = {"data": df, "x": "epoch", "y": "r_clst", "hue": "sm"}
ax = sns.lineplot(**_r_clst_plot_kw)
# Place the legend outside the axes, next to the top-right corner
sns.move_legend(ax, loc="upper left", bbox_to_anchor=(1, 1))
ax

In [None]:
# r_cc against epoch, one line per submodule
_r_cc_plot_kw = {"data": df, "x": "epoch", "y": "r_cc", "hue": "sm"}
ax = sns.lineplot(**_r_cc_plot_kw)
# Place the legend outside the axes, next to the top-right corner
sns.move_legend(ax, loc="upper left", bbox_to_anchor=(1, 1))
ax

In [None]:
# Save the figure of the axes currently bound to `ax` (the r_cc plot when the
# notebook is run top to bottom)
_r_cc_png = PLOTS_PATH / f"{CLUSTERING_METHOD}.r_cc_by_epoch_by_sm.png"
ax.get_figure().savefig(str(_r_cc_png))

# Tracking

In this section, we track samples and the clusters they belong to across latent spaces.

## Tracing the true label owning the cluster owning the sample

In [None]:
# First, get the array of all predictions
# y_preds: (n_epochs, N)


from tempfile import TemporaryDirectory

import pytorch_lightning as pl
import torch

# Allow reduced-precision float32 matrix multiplications
torch.set_float32_matmul_precision("medium")

_yps = []
for ckpt in ckpts:
    # Load with the classifier class selected from HF_MODEL_NAME (as done for
    # the latent clustering) rather than always TimmClassifier: the configured
    # model is not necessarily a timm model
    model = classifier_cls.load_from_checkpoint(ckpt)
    # The trainer's root directory is a throwaway: nothing it writes is kept
    with TemporaryDirectory() as tmp:
        trainer = pl.Trainer(
            callbacks=[pl.callbacks.TQDMProgressBar()], default_root_dir=tmp
        )
        logit_batches = trainer.predict(model, dataset)
    # Predicted class = index of the largest output along the last axis
    y_pred = torch.cat(logit_batches).argmax(dim=-1).numpy()
    _yps.append(y_pred)

# One row of predictions per checkpoint
y_preds = np.array(_yps, dtype=int)
y_preds.shape

In [None]:
# Goal: an integer array of "matching predictions".
# The matching prediction of a sample j is the true class i_true that owns the
# cluster of sample j, so that
#   y_match[e, s, j] = true label owning the cluster of sample j in the latent
#                      space of the s-th submodule, at the e-th checkpoint
# In an ideal world, y_match[e, s, j] would always be the true label of j, at
# least for the best epoch.
#
# Relevant OTM predicate: the second one (p2), since p2[i_true, j] is true if j
# is in a cluster owned by true class i_true.
# NOTE(review): if a sample's cluster is owned by no class, its p2 column is
# presumably all False and argmax then yields 0 -- confirm whether this can
# happen.

_yms = []
for clusterings in tqdm(lc_data.values()):
    per_sm = []
    for y_clst, matching in tqdm(clusterings.values(), leave=False):
        _, owned_by, _, _ = otm_matching_predicates(y_true, y_clst, matching)
        # (N,): for each sample, index of the first row of p2 that is set
        per_sm.append(owned_by.argmax(axis=0))
    _yms.append(per_sm)

y_match = np.array(_yms, dtype=int)
y_match.shape

In [None]:
# Computing the accuracy of the matching predictions

# match_accs: (n_epochs, n_submods)

from sklearn.metrics import accuracy_score

# Score every vector of matching predictions against the true labels
_ma = []
for per_sm in y_match:  # per_sm: (n_submods, N)
    _ma.append([accuracy_score(y_true, y_m) for y_m in per_sm])  # y_m: (N,)

# One accuracy per (checkpoint, submodule)
match_accs = np.array(_ma)
match_accs.shape

In [None]:
# Long-format table of the matching accuracies, one row per (epoch, submodule).
# The rows of match_accs follow the iteration order of lc_data, so the actual
# epoch numbers are taken from its keys rather than from the row index (the
# two differ whenever checkpoints do not exist for every epoch starting at 0).
d = [
    {"epoch": epoch, "sm": sm, "acc": acc}
    for epoch, accs in zip(lc_data.keys(), match_accs)
    for sm, acc in zip(SUBMODULES, accs)
]

df = pd.DataFrame(d)
df

In [None]:
# Matching accuracy against epoch, one line per submodule
_acc_plot_kw = {"data": df, "x": "epoch", "y": "acc", "hue": "sm"}
ax = sns.lineplot(**_acc_plot_kw)
# Place the legend outside the axes, next to the top-right corner
sns.move_legend(ax, loc="upper left", bbox_to_anchor=(1, 1))
ax

In [None]:
# Save the figure of the axes currently bound to `ax` (the matching accuracy
# plot when the notebook is run top to bottom)
_clst_acc_png = PLOTS_PATH / f"{CLUSTERING_METHOD}.clst_acc_by_epoch_by_sm.png"
ax.get_figure().savefig(str(_clst_acc_png))

# Cluster diffraction

For each latent cluster, look at how its samples are clustered in the next latent space (LS).

In [None]:
# Array containing all diffractions of submodules in SUBMODULES
# (except the first one)
# difts: (n_epochs, s) where s represents module SUBMODULES[s+1]

from itertools import pairwise
from nlnas.correction.clustering import class_otm_matching


def diffraction(y_clst_1: np.ndarray, y_clst_2: np.ndarray) -> float:
    """
    Diffraction score between two cluster labelings of the same samples.

    The labels in `y_clst_1` take the place of the true labels: they are
    matched against the clusters of `y_clst_2` with `class_otm_matching`, and
    the score is the sum of the third predicate returned by
    `otm_matching_predicates`, divided by the number of samples.

    NOTE(review): the exact meaning of the third predicate is defined in
    `nlnas.correction.clustering` and is not visible here -- confirm.

    Args:
        y_clst_1: cluster labels in the first latent space
        y_clst_2: cluster labels of the same samples in the second latent space

    Returns:
        The predicate sum divided by `len(y_clst_1)`
    """
    n_samples = len(y_clst_1)
    otm = class_otm_matching(y_clst_1, y_clst_2)
    _, _, third_predicate, _ = otm_matching_predicates(y_clst_1, y_clst_2, otm)
    return third_predicate.sum() / n_samples


# For every checkpoint, diffraction between each pair of consecutive entries
_dfs = []
for clusterings in tqdm(lc_data.values()):
    # Keep the cluster labels only (first element of each entry), in order
    labelings = [entry[0] for entry in clusterings.values()]
    consecutive = list(pairwise(labelings))
    _dfs.append(
        [diffraction(first, second) for first, second in tqdm(consecutive, leave=False)]
    )

# One row per checkpoint, one column per consecutive pair
difts = np.array(_dfs)
difts.shape

In [None]:
# Long-format table of the diffractions, one row per (epoch, submodule).
# Column s of difts belongs to SUBMODULES[s + 1]. The rows follow the
# iteration order of lc_data, so the actual epoch numbers are taken from its
# keys rather than from the row index.
_data = [
    {"epoch": epoch, "sm": sm, "d": dift}
    for epoch, row in zip(lc_data.keys(), difts)
    for sm, dift in zip(SUBMODULES[1:], row)
]

df = pd.DataFrame(_data)
df

In [None]:
# One panel per submodule, wrapped every three columns, each showing the
# diffraction `d` against the epoch
_facet_kw = {"data": df, "col": "sm", "col_wrap": 3}
g = sns.FacetGrid(**_facet_kw)
g.map(sns.lineplot, "epoch", "d")

In [None]:
# Save the facet grid's figure next to the other plots
_diffraction_png = PLOTS_PATH / f"{CLUSTERING_METHOD}.diffraction_by_epoch_by_sm.png"
g.fig.savefig(str(_diffraction_png))

# Class diffraction

For each true class, look at where the samples of its matched clusters go in the next latent space (LS).

In [None]:
from nlnas.classifiers import BaseClassifier
from nlnas import HuggingFaceClassifier, TimmClassifier

# Models from the `timm/` namespace go through TimmClassifier, all others
# through HuggingFaceClassifier
ClassifierClass: type[BaseClassifier] = (
    TimmClassifier if HF_MODEL_NAME.startswith("timm/") else HuggingFaceClassifier
)

In [None]:
# NOTE(review): FT_RESULTS and LCC_RESULTS are not defined in any visible cell
# of this notebook -- confirm where they come from before running this cell.

# Model restored from the best checkpoint recorded in FT_RESULTS; the stored
# path is joined onto out/ft
_ft_best = FT_RESULTS["fine_tuning"]["best_checkpoint"]
FT_CKPT_PATH = Path("out", "ft", _ft_best["path"])
FT_MODEL = ClassifierClass.load_from_checkpoint(FT_CKPT_PATH)

# Model restored from the best checkpoint recorded in LCC_RESULTS; the stored
# path is joined onto out/lcc
_lcc_best = LCC_RESULTS["correction"]["best_checkpoint"]
LCC_CKPT_PATH = Path("out", "lcc", _lcc_best["path"])
LCC_MODEL = ClassifierClass.load_from_checkpoint(LCC_CKPT_PATH)

In [None]:
from nlnas import HuggingFaceDataset

# NOTE(review): FT_RESULTS is not defined in any visible cell of this notebook
# -- confirm where it comes from before running this cell.

# Dataset configured with the splits and label key recorded in FT_RESULTS
_ft_ds_info = FT_RESULTS["dataset"]
DATASET = HuggingFaceDataset(
    HF_DATASET_NAME,
    fit_split=_ft_ds_info["train_split"],
    val_split=_ft_ds_info["val_split"],
    test_split=_ft_ds_info["test_split"],
    predict_split=_ft_ds_info["train_split"],  # not a typo
    label_key=_ft_ds_info["label_key"],
    image_processor=ClassifierClass.get_image_processor(HF_MODEL_NAME),
)

# The labels are read from DATASET itself, not from the `dataset` object built
# earlier from RESULTS: the predictions in this section are made on DATASET,
# and its training split is configured from a different results document
Y_TRUE = DATASET.y_true("train").numpy()
Y_TRUE.shape

In [None]:
from tempfile import TemporaryDirectory

import pytorch_lightning as pl
import torch
import turbo_broccoli as tb

# Cached outputs of FT_MODEL on DATASET's predict split, stored under the
# empty key of y_pred.st
# NOTE(review): FT_PATH is not defined in any visible cell. Also, earlier
# cells iterate the handler directly (`for _ in g`) while this one calls
# `.guard()` -- confirm both forms are supported by the installed
# turbo_broccoli.
g = tb.GuardedBlockHandler(FT_PATH / "lc" / "y_pred.st")
for _ in g.guard():
    with TemporaryDirectory() as tmp:
        progress_bar = pl.callbacks.TQDMProgressBar()
        trainer = pl.Trainer(callbacks=progress_bar, default_root_dir=tmp)
        data = trainer.predict(FT_MODEL, DATASET)
    ft_logits = torch.concat(data).numpy()
    g.result = {"": ft_logits}
# Predicted class = index of the largest output along the last axis
FT_Y_PRED = g.result[""].argmax(axis=-1)
FT_Y_PRED.shape

In [None]:
# Cached outputs of LCC_MODEL on DATASET's predict split, stored under the
# empty key of y_pred.st
# NOTE(review): LCC_PATH is not defined in any visible cell -- confirm.
g = tb.GuardedBlockHandler(LCC_PATH / "lc" / "y_pred.st")
for _ in g.guard():
    with TemporaryDirectory() as tmp:
        progress_bar = pl.callbacks.TQDMProgressBar()
        trainer = pl.Trainer(callbacks=progress_bar, default_root_dir=tmp)
        data = trainer.predict(LCC_MODEL, DATASET)
    lcc_logits = torch.concat(data).numpy()
    g.result = {"": lcc_logits}
# Predicted class = index of the largest output along the last axis
LCC_Y_PRED = g.result[""].argmax(axis=-1)
LCC_Y_PRED.shape

In [None]:
# Distance distribution (DD) computations only look at (at most) this many
# samples
DD_N_SAMPLES = 10_000

# Resolution of the DD histograms
RESOLUTION = 500

In [None]:
import turbo_broccoli as tb

from nlnas.analysis.dd import distance_distribution

# Full dataset distance distribution (DD) data, computed on the first
# DD_N_SAMPLES rows of FT_LE.
# The cached dict has two entries:
# - `hist`: histogram counts
# - `edges`: histogram bin edges
# NOTE(review): shapes are presumably (RESOLUTION,) and (RESOLUTION + 1,), but
# RESOLUTION is not passed to distance_distribution here -- confirm.
# NOTE(review): FT_PATH, FT_LE and SUBMODULE are not defined in any visible
# cell.

_ft_dd_file = FT_PATH / "lc" / "pdist" / "train" / "full" / (SUBMODULE + ".st")
g = tb.GuardedBlockHandler(_ft_dd_file)
for _ in g.guard():
    hist, edges = distance_distribution(FT_LE[:DD_N_SAMPLES])
    g.result = {"hist": hist, "edges": edges}
FT_DD_FULL = g.result

In [None]:
# Full dataset distance distribution data, computed on the first DD_N_SAMPLES
# rows of LCC_LE; the cached dict has a `hist` and an `edges` entry
# NOTE(review): LCC_PATH, LCC_LE and SUBMODULE are not defined in any visible
# cell.
_lcc_dd_file = LCC_PATH / "lc" / "pdist" / "train" / "full" / (SUBMODULE + ".st")
g = tb.GuardedBlockHandler(_lcc_dd_file)
for _ in g.guard():
    hist, edges = distance_distribution(LCC_LE[:DD_N_SAMPLES])
    g.result = {"hist": hist, "edges": edges}
LCC_DD_FULL = g.result

In [None]:
from nlnas.analysis.dd import distance_distribution_plot

# Height of each plot
SIZE = 250

# Size of the last axis of each latent embedding array
ft_n_dims = FT_LE.shape[-1]
lcc_n_dims = LCC_LE.shape[-1]

ft_dd_plot = distance_distribution_plot(
    FT_DD_FULL["hist"], FT_DD_FULL["edges"], height=SIZE, n_dims=ft_n_dims
)
ft_dd_plot.title = f"[After fine-tuning] full DD, sm={SUBMODULE}, n_dims={ft_n_dims}"

lcc_dd_plot = distance_distribution_plot(
    LCC_DD_FULL["hist"], LCC_DD_FULL["edges"], height=SIZE, n_dims=lcc_n_dims
)
lcc_dd_plot.title = f"[After LCC] full DD, sm={SUBMODULE}, n_dims={lcc_n_dims}"

# Stack the two plots vertically: fine-tuning on top, LCC below
figure = bkl.column([ft_dd_plot, lcc_dd_plot])
bk.show(figure)

In [None]:
# Overlay of the two full DDs: after fine-tuning in red, after LCC in blue.

# Upper bound of the count axis: tallest bin across both histograms plus 10%
# headroom. This bound used to be stored in an unused variable named
# `x_range`; it is now handed to the figure as the range of the y axis.
# float() guarantees a plain Python number whatever array type holds the
# histograms.
_max_count = float(max(FT_DD_FULL["hist"].max(), LCC_DD_FULL["hist"].max()))
y_range = (0, 1.1 * _max_count)

figure = bk.figure(
    height=SIZE,
    width=SIZE * 2,
    toolbar_location=None,
    title=f"Full DD, after FT (red) vs. after LCC (blue), sm={SUBMODULE}",
    y_range=y_range,
)

# There is one more bin edge than there are bins: the last edge is dropped so
# that each count is drawn at the left edge of its bin
figure.line(FT_DD_FULL["edges"][:-1], FT_DD_FULL["hist"], color="red")
figure.line(LCC_DD_FULL["edges"][:-1], LCC_DD_FULL["hist"], color="blue")
bk.show(figure)

In [None]:
# Intra-class distance distribution (DD) data
# Each entry in the dict (one per class) is itself a dict with two entries:
# - `hist` (RESOLUTION,): histogram counts
# - `edges` (RESOLUTION + 1,): histogram bin edges

import turbo_broccoli as tb

# Only class labels 0 to 19 are considered
CLASSES = list(range(20))

# FT_DD_INTRA[i]: cached DD of the samples of class i, a dict with a `hist`
# and an `edges` entry
FT_DD_INTRA = {}
for i in tqdm(CLASSES, leave=False):
    g = tb.GuardedBlockHandler(
        FT_PATH
        / "lc"
        / "pdist"
        / "train"
        / "intra-class"
        / str(i)
        / (SUBMODULE + ".st")
    )
    for _ in g.guard():
        # Samples of class i are selected with this section's label array
        # (Y_TRUE) instead of `y_true` from the first part of the notebook,
        # then truncated to at most DD_N_SAMPLES
        h, e = distance_distribution(FT_LE[Y_TRUE == i][:DD_N_SAMPLES])
        g.result = {"hist": h, "edges": e}
    FT_DD_INTRA[i] = g.result

In [None]:
# LCC_DD_INTRA[i]: cached DD of the samples of class i, a dict with a `hist`
# and an `edges` entry
LCC_DD_INTRA = {}
for i in tqdm(CLASSES, leave=False):
    g = tb.GuardedBlockHandler(
        LCC_PATH
        / "lc"
        / "pdist"
        / "train"
        / "intra-class"
        / str(i)
        / (SUBMODULE + ".st")
    )
    for _ in g.guard():
        # Samples of class i are selected with this section's label array
        # (Y_TRUE) instead of `y_true` from the first part of the notebook,
        # then truncated to at most DD_N_SAMPLES
        h, e = distance_distribution(LCC_LE[Y_TRUE == i][:DD_N_SAMPLES])
        g.result = {"hist": h, "edges": e}
    LCC_DD_INTRA[i] = g.result

In [None]:
from nlnas.analysis.dd import distance_distribution_plot

# Height of the plot; the width is twice that
SIZE = 250

# Intra-class DDs of every class in CLASSES on a single figure: after
# fine-tuning in red, after LCC in blue
figure = bk.figure(height=SIZE, width=2 * SIZE, toolbar_location=None, x_range=(0, 2.5))

for i in CLASSES:
    # The thickness of a Bokeh line glyph is set through `line_width`; a bare
    # `width` keyword is not a property of the glyph
    figure.line(
        FT_DD_INTRA[i]["edges"][:-1],
        FT_DD_INTRA[i]["hist"],
        color="red",
        line_width=0.5,
    )
    figure.line(
        LCC_DD_INTRA[i]["edges"][:-1],
        LCC_DD_INTRA[i]["hist"],
        color="blue",
        line_width=0.5,
    )

bk.show(figure)

In [None]:
import numpy as np

# Mean intra-class DD over all classes: after fine-tuning in red, after LCC in
# blue
figure = bk.figure(height=SIZE, width=2 * SIZE, toolbar_location=None, x_range=(0, 2.5))

# The bin edges of class 0 are used as abscissa for the averaged histograms
# NOTE(review): this assumes every class shares the same bin edges -- confirm.
# The thickness of a Bokeh line glyph is set through `line_width`; a bare
# `width` keyword is not a property of the glyph
figure.line(
    FT_DD_INTRA[0]["edges"][:-1],
    np.stack([d["hist"] for d in FT_DD_INTRA.values()]).mean(axis=0),
    color="red",
    line_width=1,
)
figure.line(
    LCC_DD_INTRA[0]["edges"][:-1],
    np.stack([d["hist"] for d in LCC_DD_INTRA.values()]).mean(axis=0),
    color="blue",
    line_width=0.5,
)

bk.show(figure)

In [None]:
from cuml import UMAP
from sklearn.preprocessing import MinMaxScaler

# Only the first N_SAMPLES rows of the latent embeddings are projected
N_SAMPLES = 10_000

# Cached 2D UMAP projection of FT_LE, min-max scaled, stored under the empty
# key of the .st file
# NOTE(review): FT_PATH, FT_LE and SUBMODULE are not defined in any visible
# cell.
_ft_umap_file = FT_PATH / "lc" / "umap" / "train" / (SUBMODULE + ".st")
g = tb.GuardedBlockHandler(_ft_umap_file)
for _ in g.guard():
    projector, scaler = UMAP(n_components=2), MinMaxScaler()
    e = scaler.fit_transform(projector.fit_transform(FT_LE[:N_SAMPLES]))
    g.result = {"": e}
FT_LE_2D = g.result[""]

In [None]:
# Cached 2D UMAP projection of the first N_SAMPLES rows of LCC_LE, min-max
# scaled, stored under the empty key of the .st file
# NOTE(review): LCC_PATH, LCC_LE and SUBMODULE are not defined in any visible
# cell.
_lcc_umap_file = LCC_PATH / "lc" / "umap" / "train" / (SUBMODULE + ".st")
g = tb.GuardedBlockHandler(_lcc_umap_file)
for _ in g.guard():
    projector, scaler = UMAP(n_components=2), MinMaxScaler()
    e = scaler.fit_transform(projector.fit_transform(LCC_LE[:N_SAMPLES]))
    g.result = {"": e}
LCC_LE_2D = g.result[""]

In [None]:
from nlnas.correction.choice import top_confusion_pairs, max_connected_confusion_choice

# Number of confusion pairs requested for each model
N_PAIRS = 1

n_classes = DATASET.n_classes()
_cp_kw = {"n_pairs": N_PAIRS}
ft_top_cp = top_confusion_pairs(FT_Y_PRED, Y_TRUE, n_classes, **_cp_kw)
lcc_top_cp = top_confusion_pairs(LCC_Y_PRED, Y_TRUE, n_classes, **_cp_kw)
# Pairs reported for both models
_common_pairs = set(ft_top_cp).intersection(lcc_top_cp)
inter_cp = list(_common_pairs)

print("FT top confusion pairs:", ft_top_cp)
print("LCC top confusion pairs:", lcc_top_cp)
print(f"In common ({len(inter_cp)} pairs):", inter_cp)

In [None]:
# Select samples

from more_itertools import flatten, unique

# Base selection: samples whose true label appears in one of the FT top
# confusion pairs
# labels = list(unique(flatten(inter_cp)))
labels = list(unique(flatten(ft_top_cp)))
mask = np.isin(Y_TRUE, labels)
print(f"Base labels (n_lbls={len(labels)}, n_smpls={mask.sum()}):", labels)

# Adding samples from matched clusters
# NOTE(review): FT_MATCHING, FT_Y_CLST, LCC_MATCHING and LCC_Y_CLST are not
# defined in any visible cell -- confirm where they come from.

# Cluster labels listed under the selected labels in the FT matching
ft_matched_lbls = list(flatten(FT_MATCHING[i] for i in labels))
msk = np.isin(FT_Y_CLST, ft_matched_lbls)
mask |= msk
print("Adding (at most)", msk.sum(), "extra samples from FT matching")

# Same for LCC. The LCC cluster labels must be tested against the LCC matched
# clusters (they used to be tested against the FT ones, a copy-paste slip)
lcc_matched_lbls = list(flatten(LCC_MATCHING[i] for i in labels))
msk = np.isin(LCC_Y_CLST, lcc_matched_lbls)
mask |= msk
print("Adding (at most)", msk.sum(), "extra samples from LCC matching")

print("Total:", mask.sum(), "samples")

In [None]:
from nlnas.plotting import class_scatter

# Mask to select rows from *_LE_2D, which were computed from the first
# N_SAMPLES samples only
_msk1 = mask[:N_SAMPLES]

# Mask to select the same samples from the full-length *_Y_* arrays: identical
# to `mask`, except that everything past the first N_SAMPLES entries is
# dropped
_msk2 = mask.copy()
_msk2[N_SAMPLES:] = False

SIZE = 250
kw = {"width": SIZE, "height": SIZE, "toolbar_location": None}

# One row of panels per model; within a row, the same 2D embedding is colored
# by true label, predicted label, and cluster label respectively. Titles and
# log messages are derived from the model name, so the second row can no
# longer carry an "FT" title by mistake.
_models = [
    ("FT", FT_LE_2D, FT_Y_PRED, FT_Y_CLST),
    ("LCC", LCC_LE_2D, LCC_Y_PRED, LCC_Y_CLST),
]
_rows = []
for _name, _le_2d, _y_pred, _y_clst in _models:
    _panels = []
    for _kind, _y in (("true", Y_TRUE), ("pred", _y_pred), ("clst", _y_clst)):
        logging.info("Rendering {}/{}", _name, _kind.upper())
        _panel = bk.figure(title=f"{_name}, {_kind}", **kw)
        class_scatter(_panel, _le_2d[_msk1], _y[_msk2])
        _panels.append(_panel)
    _rows.append(bkl.row(_panels))

figure = bkl.column(_rows)
bk.show(figure)