In [None]:
%load_ext jupyter_black
%load_ext autoreload
%autoreload 2

In [None]:
import bokeh.layouts as bkl
import bokeh.plotting as bk
from bokeh.io import output_notebook

from nlnas.plotting import export_png

# Render Bokeh figures inline in the notebook's output cells
output_notebook()

import sys

from loguru import logger as logging

# Drop the handlers loguru currently has registered, then log to stderr at
# INFO level and above
logging.remove()
logging.add(sys.stderr, level="INFO")

In [None]:
# Model identifier. A "timm/" prefix makes the dataset-loading cell pick
# TimmClassifier instead of HuggingFaceClassifier.
HF_MODEL_NAME = "timm/mobilenetv3_small_050.lamb_in1k"
# Name of the submodule whose latent embeddings are treated as the output logits
HEAD_NAME = "model.classifier"

In [None]:
# Dataset configuration consumed by the HuggingFaceDataset constructor below
HF_DATASET_NAME = "cifar100"  # Name in Hugging Face's dataset index

TRAIN_SPLIT = "train[:80%]"  # See HF dataset page for split name
VAL_SPLIT = "train[80%:]"  # See HF dataset page for split name
TEST_SPLIT = "test"  # See HF dataset page for split name
# NOTE(review): IMAGE_KEY is not referenced anywhere else in the visible
# notebook — confirm whether it is still needed
IMAGE_KEY = "img"  # See HF dataset page for name of dataset column
LABEL_KEY = "fine_label"  # See HF dataset page for name of dataset column

In [None]:
from pathlib import Path

# Hub identifiers may contain "/", which would otherwise be read as a nested
# directory; swap it for "-" to obtain filesystem-friendly names
DATASET_NAME, MODEL_NAME = (
    hub_name.replace("/", "-") for hub_name in (HF_DATASET_NAME, HF_MODEL_NAME)
)

# Directory expected to already hold the latent-clustering outputs; the
# notebook only reads from / adds to it, hence the hard check
LC_PATH = Path("out", "ft", DATASET_NAME, MODEL_NAME, "lc")
assert LC_PATH.is_dir()

# Dataset loading


In [None]:
from nlnas import HuggingFaceClassifier, HuggingFaceDataset, TimmClassifier

# Models published under the "timm/" namespace go through the timm wrapper,
# everything else through the generic Hugging Face one
classifier_cls = (
    TimmClassifier
    if HF_MODEL_NAME.startswith("timm/")
    else HuggingFaceClassifier
)

dataset = HuggingFaceDataset(
    HF_DATASET_NAME,
    image_processor=classifier_cls.get_image_processor(HF_MODEL_NAME),
    label_key=LABEL_KEY,
    fit_split=TRAIN_SPLIT,
    val_split=VAL_SPLIT,
    test_split=TEST_SPLIT,
    predict_split=TRAIN_SPLIT,  # not a typo
)

# True labels of the "train" split as a NumPy array
y_true = dataset.y_true("train").numpy()
y_true.shape

# Latent representation loading


In [None]:
import turbo_broccoli as tb

# Louvain clustering results, keyed by submodule name. Each value unpacks into
# a pair: the first element goes to `y_clst`, the second to `matching`.
matching_data = tb.load(LC_PATH / "louvain" / "data.json")

y_clst, matching = {}, {}
for _sm, (_y, _m) in matching_data.items():
    y_clst[_sm] = _y
    matching[_sm] = _m

list(matching_data.keys())

In [None]:
import numpy as np

# Latent spaces to process
SUBMODULES = list(matching_data.keys())
# SUBMODULES = [HEAD_NAME]

# Need the latent embeddings of the last submodule, aka the output logits
assert HEAD_NAME in SUBMODULES

# Subset of classes to consider in case the true number of classes is too large
# CLASSES = np.array(range(100))
CLASSES = np.arange(dataset.n_classes())

# Boolean mask selecting the samples whose true label is in CLASSES
class_mask = np.isin(y_true, CLASSES)

# NOTE(review): y_true is overwritten in place, so re-running this cell
# recomputes class_mask from the already-filtered labels. This is harmless while
# CLASSES covers every class, but otherwise the mask no longer has the length of
# the unfiltered dataset.
y_true = y_true[class_mask]
y_true.shape

In [None]:
from sklearn.preprocessing import StandardScaler
from tqdm.notebook import tqdm

from nlnas.utils import load_tensor_batched

# submodule name -> 2D array: one row per selected sample, all remaining axes
# flattened, then standardized with a scaler fitted on that submodule alone
latent_embeddings = {}
for sm in tqdm(SUBMODULES):
    _raw = load_tensor_batched(
        LC_PATH / "embeddings" / "train",
        prefix=sm,
        mask=class_mask,
        tqdm_style="notebook",
    ).numpy()
    _flat = _raw.reshape(len(_raw), -1)
    latent_embeddings[sm] = StandardScaler().fit_transform(_flat)

# Quick overview of the resulting shapes
for sm, u in latent_embeddings.items():
    print(sm, u.shape)

In [None]:
# The head's latent embeddings double as the output logits.
# NOTE(review): latent_embeddings holds StandardScaler-transformed values
# (each column standardized across samples), so this argmax may differ from the
# argmax of the raw model logits — confirm that this is intended.
logits = latent_embeddings[HEAD_NAME]
y_pred = logits.argmax(axis=-1)
logits.shape

## Confusion matrix

In [None]:
from sklearn.metrics import accuracy_score

# Fraction of samples whose predicted label equals the true label
acc = accuracy_score(y_true, y_pred)
print("Accuracy score:", acc)

In [None]:
# Imported under an alias so that the name `confusion_matrix` can hold the
# resulting array
from sklearn.metrics import confusion_matrix as _confusion_matrix

# Entry [i, j] counts the samples of true class i predicted as class j
confusion_matrix = _confusion_matrix(y_true=y_true, y_pred=y_pred)
confusion_matrix.shape

In [None]:
import matplotlib.pyplot as plt
from sklearn.metrics import ConfusionMatrixDisplay

# Large square canvas. This edits the global rcParams, so later matplotlib
# figures that do not set their own figsize inherit it too.
size = 25
plt.rcParams["figure.figsize"] = (size, size)

# Zero out the diagonal so that only misclassifications are displayed
off_diag_cm = confusion_matrix * (1 - np.eye(len(confusion_matrix)))
cmd = ConfusionMatrixDisplay(off_diag_cm)
cmd.plot(include_values=False, colorbar=False)

plt.tight_layout()
plt.show()

In [None]:
# Off-diagonal part of the confusion matrix: correct predictions (the diagonal)
# are zeroed so that they never count as confusions
m = confusion_matrix * (1 - np.eye(len(confusion_matrix)))

idx = np.flip(m.argsort(axis=None))  # Flat indices, largest count first
confusion_pairs = np.stack(np.unravel_index(idx, m.shape)).T

# Keep strictly positive entries only. The test must be done on `m` and not on
# `confusion_matrix`: the diagonal of the latter is (usually) positive, which
# would let pairs (i, i) slip into the list.
confusion_pairs = confusion_pairs[
    m[confusion_pairs[:, 0], confusion_pairs[:, 1]] > 0
]

# confusion_pairs (M, 2): coordinates of non-diagonal strictly positive entries in
# confusion_matrix in decreasing order. So if i, j = confusion_pairs[k], then
# confusion_matrix[i, j] is the number of samples in true class i classified in j.
# Furthermore, i != j and confusion_matrix[i, j] > 0.

print(len(confusion_pairs), "confusion pairs")
print("Top confusions:")
# There may be fewer than 5 confusion pairs (e.g. near-perfect classifier)
for i in range(min(5, len(confusion_pairs))):
    a, b = confusion_pairs[i]
    c = confusion_matrix[a, b]
    print(f"- y_true={a:>3}, y_pred={b:>3}, n_err={c}")

# Distance distributions (true classes)


In [None]:
import bokeh.layouts as bkl

# TEST: The blue and grey line should roughly match
import numpy as np
from scipy.spatial.distance import pdist

from nlnas.analysis.dd import distance_distribution, distance_distribution_plot

# Sanity check on synthetic data: for several dimensionalities, draw _n standard
# normal points (unseeded, so the figures vary slightly between runs) and plot
# the histogram of their pairwise Euclidean distances divided by sqrt(n_dim)
_n, _figs = 1024, []
for _nd in [1, 2, 8, 512, 2048]:
    _a = np.random.randn(_n, _nd)
    _b = pdist(_a, metric="euclidean") / np.sqrt(_nd)
    h, e = np.histogram(_b, bins=500)
    _fig = distance_distribution_plot(h, e, _nd, height=100)
    _fig.title = f"Test: n={_n}, n_dim={_nd}"
    _figs.append(_fig)
bk.show(bkl.row(_figs))

## Full


In [None]:
# Truncate the dataset to only consider DD_N_SAMPLES samples for distance
# distribution computations
DD_N_SAMPLES = 2048

# For DD histograms
# NOTE(review): RESOLUTION only appears in figure titles in this notebook; it is
# never passed to distance_distribution — confirm that it matches that
# function's actual bin count
RESOLUTION = 500

In [None]:
# Full dataset distance distribution (DD) data
# Each entry in the dict is itself a dict with two entries:
# - `hist`: histogram counts (presumably RESOLUTION bins — TODO confirm)
# - `edges`: histogram bin edges (one more value than `hist`, as the plotting
#   cells drop the last edge to align both)

import numpy as np
import turbo_broccoli as tb
from scipy.spatial.distance import pdist

full_ds_dd = {}
for sm, u in tqdm(latent_embeddings.items()):
    # One result file per submodule; presumably the guarded body is skipped and
    # g.result is loaded from that file when it already exists — confirm
    # against the turbo_broccoli documentation
    g = tb.GuardedBlockHandler(
        LC_PATH / "pdist" / "train" / "full" / (sm + ".st")
    )
    for _ in g.guard():
        # Only the first DD_N_SAMPLES samples are used
        h, e = distance_distribution(u[:DD_N_SAMPLES])
        g.result = {"hist": h, "edges": e}
    full_ds_dd[sm] = g.result

In [None]:
import bokeh.layouts as bkl

SIZE = 250

# One distance distribution figure per submodule, stacked vertically
rows = []
for sm, dat in tqdm(full_ds_dd.items()):
    h, e, u = dat["hist"], dat["edges"], latent_embeddings[sm]
    # The third argument is the (flattened) dimension of the latent space
    figure = distance_distribution_plot(h, e, u.shape[-1], height=SIZE)
    figure.title = (
        "pdist distribution (black) vs. expected χ (grey)\n"
        f"sm={sm}, n={DD_N_SAMPLES}, res={RESOLUTION}"
    )
    rows.append(figure)

# `figure` is rebound to the whole layout, which the next cell exports
figure = bkl.column(rows)
bk.show(figure)

In [None]:
# Save whatever `figure` currently refers to as a PNG under LC_PATH
export_png(figure, filename=LC_PATH / "full_ds_dd.png")

## Per-class


In [None]:
# intra_class_dd is a two level dict that maps
# submodule_name → true class number → class distance distribution
# where the class distance distribution data is compiled in a dict with two keys:
# - `hist`: histogram counts (presumably RESOLUTION bins — TODO confirm)
# - `edges`: histogram bin edges

from collections import defaultdict

from tqdm.notebook import tqdm

intra_class_dd = defaultdict(dict)
for sm, u in tqdm(latent_embeddings.items()):
    # Only the first 20 entries of CLASSES are processed
    for i in tqdm(CLASSES[:20], leave=False):
        g = tb.GuardedBlockHandler(
            LC_PATH / "pdist" / "train" / "intra-class" / str(i) / (sm + ".st")
        )
        for _ in g.guard():
            # At most DD_N_SAMPLES samples of true class i
            h, e = distance_distribution(u[y_true == i][:DD_N_SAMPLES])
            g.result = {"hist": h, "edges": e}
        intra_class_dd[sm][i] = g.result

In [None]:
from nlnas.analysis.dd import distance_distribution_plot

SIZE = 250

# Per submodule: full-dataset DD as the base plot, with one thin green line per
# class for which an intra-class DD was computed
rows = []
for sm in tqdm(SUBMODULES):
    h, e = full_ds_dd[sm]["hist"], full_ds_dd[sm]["edges"]
    figure = distance_distribution_plot(h, e, height=SIZE, include_chi=False)
    figure.title = (
        "Distance distributions: full ds (black) vs. intra-classes (green)\n"
        f"{sm}, n={DD_N_SAMPLES}, res={RESOLUTION}"
    )
    for i, dat in intra_class_dd[sm].items():
        # Drop the last edge so that x and y have the same length
        e, h = dat["edges"][:-1], dat["hist"]
        figure.line(e, h, color="green", width=0.5)
    rows.append(figure)

figure = bkl.column(rows)
bk.show(figure)

In [None]:
# Save whatever `figure` currently refers to as a PNG under LC_PATH
export_png(figure, filename=LC_PATH / "intra_class_dd.png")

## Inter-class


In [None]:
# inter_class_dd is similar to intra_class_dd: it is a two level dict that maps
# submodule → pair of classes (i, j) -> inter-class distance distribution
# where the class distance distribution data is compiled in a dict with two keys:
# - `hist`: histogram counts (presumably RESOLUTION bins — TODO confirm)
# - `edges`: histogram bin edges

from collections import defaultdict
from itertools import combinations

from scipy.spatial.distance import cdist
from tqdm.notebook import tqdm

# Considering every class pair would be too much
CLASS_PAIRS = list(combinations(CLASSES[:5], 2))
print("Considering", len(CLASS_PAIRS), "class pairs")

inter_class_dd = defaultdict(dict)
for sm, u in tqdm(latent_embeddings.items()):
    for i, j in tqdm(CLASS_PAIRS, leave=False):
        # One result file per (class pair, submodule)
        g = tb.GuardedBlockHandler(
            LC_PATH
            / "pdist"
            / "train"
            / "inter-class"
            / f"{i}-{j}"
            / (sm + ".st")
        )
        for _ in g.guard():
            # At most DD_N_SAMPLES samples from each of the two true classes
            ui = u[y_true == i][:DD_N_SAMPLES]
            uj = u[y_true == j][:DD_N_SAMPLES]
            h, e = distance_distribution(ui, uj)
            g.result = {"hist": h, "edges": e}
        inter_class_dd[sm][(i, j)] = g.result

In [None]:
SIZE = 250

# Per submodule: full-dataset DD as the base plot, overlaid with the intra-class
# DDs (green) and the inter-class DDs (red)
rows = []
for sm in tqdm(SUBMODULES):
    h, e = full_ds_dd[sm]["hist"], full_ds_dd[sm]["edges"]
    figure = distance_distribution_plot(h, e, height=SIZE, include_chi=False)
    figure.title = (
        "Distance distribution: full ds (black) "
        "    vs. intra-class (green)\n"
        "    vs. inter-classes (red)\n"
        f"{sm}, n={DD_N_SAMPLES}, res={RESOLUTION}"
    )
    for i, dat in intra_class_dd[sm].items():
        # Drop the last edge so that x and y have the same length
        e, h = dat["edges"][:-1], dat["hist"]
        figure.line(e, h, color="green", width=0.5)
    for (i, j), dat in inter_class_dd[sm].items():
        e, h = dat["edges"][:-1], dat["hist"]
        figure.line(e, h, color="red", width=0.5)
    rows.append(figure)

figure = bkl.column(rows)
bk.show(figure)

In [None]:
# Save whatever `figure` currently refers to as a PNG under LC_PATH
export_png(figure, filename=LC_PATH / "inter_class_dd.png")

# Dim-redux


In [None]:
# Consider (at most) the first DR_N_SAMPLES samples for dimensionality reduction
DR_N_SAMPLES = 50000

In [None]:
from cuml import UMAP
from sklearn.preprocessing import MinMaxScaler

# submodule name -> 2D UMAP embedding of the first DR_N_SAMPLES samples,
# rescaled with MinMaxScaler
latent_embeddings_2d = {}
for sm, u in tqdm(latent_embeddings.items()):
    g = tb.GuardedBlockHandler(LC_PATH / "umap" / "train" / (sm + ".st"))
    for _ in g.guard():
        e = UMAP(n_components=2).fit_transform(u[:DR_N_SAMPLES])
        e = MinMaxScaler().fit_transform(e)
        # The array is stored under an empty-string key
        g.result = {"": e}
    latent_embeddings_2d[sm] = g.result[""]

# Highly confused classes


In [None]:
# confused_true_pred_dd is a dict that maps
# submodule → (i_true, j_pred) -> inter-class distance distribution data
# between the samples of true class i_true and the samples predicted as j_pred.
# As usual, distance distribution data are compiled into a dict
# - `hist`: histogram counts (presumably RESOLUTION bins — TODO confirm)
# - `edges`: histogram bin edges
# Unlike the previous DD cells, nothing is cached on disk here, and the
# DD_N_SAMPLES truncation is commented out.

from collections import defaultdict

# Number of top confusion pairs (most confused first) to consider
N_CONFUSIONS = 10

confused_true_pred_dd = defaultdict(dict)
for sm, u in tqdm(latent_embeddings.items()):
    for i_true, j_pred in tqdm(confusion_pairs[:N_CONFUSIONS], leave=False):
        ui = u[y_true == i_true]  # [:DD_N_SAMPLES]
        uj = u[y_pred == j_pred]  # [:DD_N_SAMPLES]
        h, e = distance_distribution(ui, uj)
        confused_true_pred_dd[sm][(i_true, j_pred)] = {
            "edges": e,
            "hist": h,
        }

In [None]:
from bokeh.palettes import viridis

SIZE = 300

# Per submodule: full-dataset DD as the base plot, plus one colored line per
# highly confused (true, pred) pair
rows = []
for sm, data in tqdm(confused_true_pred_dd.items()):
    h, e = full_ds_dd[sm]["hist"], full_ds_dd[sm]["edges"]
    figure = distance_distribution_plot(h, e, height=SIZE, include_chi=False)
    figure.title = (
        "Distance distribution: full ds (black)\n"
        "vs. inter-class in high-confusion true-pred pairs (viridis, darker = more confused)\n"
        f"{sm}, n={DD_N_SAMPLES}, res={RESOLUTION}"
    )

    # Pair each DD (inserted from most to least confused) with a palette color
    everything = list(
        zip(
            data.items(),
            viridis(len(data)),  # more confused = darker tones
        )
    )
    everything = everything[::-1]  # draw from least to most confused
    for ((i, j), dd), color in everything:
        # Drop the last edge so that x and y have the same length
        h, e = dd["hist"], dd["edges"][:-1]
        figure.line(e, h, color=color, width=1)

    rows.append(figure)

figure = bkl.column(rows)
bk.show(figure)

In [None]:
# Save whatever `figure` currently refers to as a PNG under LC_PATH
export_png(figure, filename=LC_PATH / "inter_class_confused_dd.png")

In [None]:
SIZE = 500

# Focus on the single most confused (true, pred) pair
i_true, j_pred = confusion_pairs[0]
p1 = (y_true == i_true) & (y_pred == i_true)  # correctly classified as i_true
p2 = (y_true == j_pred) & (y_pred == j_pred)  # correctly classified as j_pred
p3 = (y_true == i_true) & (y_pred == j_pred)  # i_true but classified in j_pred
p = p1 | p2 | p3

# NOTE(review): the masks have one entry per sample of y_true, whereas the 2D
# embeddings were computed on the first DR_N_SAMPLES samples only; the indexing
# below assumes DR_N_SAMPLES >= len(y_true) — confirm
rows = []
for sm, u in tqdm(latent_embeddings_2d.items()):
    figure = bk.figure(width=SIZE, height=SIZE)
    figure.title = (
        f"Correctly classified {i_true} → {i_true} (blue), "
        f"{j_pred} → {j_pred} (green)\n"
        f"Misclassified {i_true} → {j_pred} (red)\n" + sm
    )

    figure.scatter(u[p1][:, 0], u[p1][:, 1], marker="x", color="blue", size=2)
    figure.scatter(u[p2][:, 0], u[p2][:, 1], marker="x", color="green", size=1)
    figure.scatter(u[p3][:, 0], u[p3][:, 1], color="red", size=3)

    rows.append(figure)

figure = bkl.column(rows)
bk.show(figure)

In [None]:
from sklearn.neighbors import NearestNeighbors

SIZE = 500

# Uses the masks p1, p2, p3 and the labels i_true, j_pred as they currently
# stand in the notebook's global state.
# NOTE(review): kneighbors requires at least 5 samples in p1 and a non-empty
# p3 — confirm that this holds for the selected pair
rows = []
for sm, u in tqdm(latent_embeddings.items()):
    u_2d = latent_embeddings_2d[sm]

    # Neighbors are searched in the full latent space, not in the 2D projection
    index = NearestNeighbors(n_neighbors=5)
    index.fit(u[p1])  # KNN index of all samples i_true → i_true
    knn_dst, knn_idx = index.kneighbors(
        u[p3]
    )  # KNNs of all samples i_true → j_pred

    figure = bk.figure(width=SIZE, height=SIZE)
    figure.title = (
        f"Correctly classified {i_true} → {i_true} (blue), {j_pred} → {j_pred} (green)\n"
        f"Misclassified {i_true} → {j_pred} (red)\n" + sm
    )

    figure.scatter(u_2d[p1][:, 0], u_2d[p1][:, 1], color="blue", size=2)
    figure.scatter(u_2d[p2][:, 0], u_2d[p2][:, 1], color="green", size=1)
    figure.scatter(u_2d[p3][:, 0], u_2d[p3][:, 1], color="red", size=4)

    # Iterate over all 2D repr. of misclassified samples i_true → j_pred
    for ia, a in enumerate(u_2d[p3]):
        # Only the first returned neighbor (column 0) is used
        ib, d = knn_idx[ia, 0], knn_dst[ia, 0]
        # Line width is 2 * sigmoid(d) - 1, i.e. within [0, 1) for d >= 0
        b, width = u_2d[p1][ib], 1 / (1 + np.exp(-d)) * 2 - 1
        figure.line([a[0], b[0]], [a[1], b[1]], color="black", width=width)

    rows.append(figure)

figure = bkl.column(rows)
bk.show(figure)

In [None]:
# Save whatever `figure` currently refers to as a PNG under LC_PATH
export_png(figure, filename=LC_PATH / "knn_error_correction.png")

# LCC on highly confused classes

In this section we consider the subset of samples belonging to the most highly confused classes, and the logits associated with them. We can think of these logits as latent embeddings in an `n_classes`-dimensional latent space. We are interested in applying latent cluster correction (LCC) in this setting.

Of course this is a simplified setting since we operate directly on the latent representations rather than the weights that produced them.

In [None]:
from functools import partial

import torch
from sklearn.metrics import hamming_loss
from torch import optim
from torch.nn.functional import cross_entropy

from nlnas.correction import clustering_loss, otm_matching_predicates


def standardize_labels(
    y: torch.Tensor, matching: dict[int, set[int]]
) -> tuple[torch.Tensor, dict[int, set[int]]]:
    """
    Relabels `y` with consecutive integers starting from 0 and re-keys
    `matching` accordingly.

    The distinct values of `y` are taken in sorted order (as returned by
    `torch.unique`): the smallest becomes 0, the next one 1, and so on.

    Args:
        y: Label vector. Array-likes (e.g. NumPy arrays) are accepted and
            converted to a tensor.
        matching: Dict whose keys must include every distinct value of `y`.

    Returns:
        A new label tensor with the same shape, dtype and device as `y` (the
        input is left untouched), and a new dict mapping each new label to the
        value `matching` had for the corresponding old label.

    Raises:
        KeyError: If a value of `y` is not a key of `matching`.
    """
    # No-op for tensors; lets callers holding NumPy labels use this function
    y = torch.as_tensor(y)
    _y, _m = torch.zeros_like(y), {}
    for i_new, i_old in enumerate(torch.unique(y)):
        _y[y == i_old] = i_new
        _m[i_new] = matching[int(i_old)]
    return _y, _m


def full_louvain_pipeline(
    u: torch.Tensor,
    yt: torch.Tensor,
    k: int = 10,
) -> tuple[
    np.ndarray, list[set[int]], dict[int, set[int]], torch.Tensor, np.ndarray
]:
    """
    Runs Louvain clustering on `u`, matches the resulting clusters against the
    true labels, and computes the clustering loss.

    Args:
        u: Latent embeddings, one row per sample. Non-tensors (e.g. NumPy
            arrays) are converted.
        yt: True labels. Non-tensors are converted as well and placed on the
            same device as `u`.
        k: Forwarded to `louvain_communities` and `clustering_loss`.

    Returns
    * the Louvain cluster label vector;
    * the Louvain communities;
    * a matching between the true labels of `yt` and the Louvain labels;
    * the Louvain loss;
    * the two-dimensional misclustering predicate: in other words,
      `p_miss[a, i]` is `True` if sample `i` is in true class `a` but not in
      any Louvain class matched with `a`.
    """
    if not isinstance(u, torch.Tensor):
        u = torch.tensor(u)
    # standardize_labels relies on torch-only functions, so NumPy labels (which
    # is what this notebook passes outside of louvain_descent) must be
    # converted too
    if not isinstance(yt, torch.Tensor):
        yt = torch.tensor(yt, device=u.device)
    communities, yc = louvain_communities(u, k=k)
    matching = class_otm_matching(yt, yc)
    # The predicates and the loss are computed on 0-based consecutive labels;
    # the returned matching stays keyed by the original labels
    _yt, _m = standardize_labels(yt, matching)
    _, _, p_miss, _ = otm_matching_predicates(_yt, yc, _m)
    loss = clustering_loss(u, _yt, yc, _m, k)
    return yc, communities, matching, loss, p_miss


def evaluate(
    u: torch.Tensor,
    y_true: torch.Tensor,
    weight_clst: float = 1,
    weight_ce: float = 1,
    k: int = 20,
    extra_data: bool = True,
) -> tuple[torch.Tensor, dict]:
    """
    Computes the weighted sum of the cross-entropy and clustering losses of the
    logits `u`, together with a dict of scalar metrics.

    The dict always holds `loss_clst`, `loss_ce`, `loss`, `acc`, `n_err`,
    `n_miss`, `n_communities` and `n_unmatched`. If `extra_data` is set, it
    additionally holds `u` (a detached NumPy copy), `communities`, `matching`,
    `y_clst`, `y_pred` and `p_miss`.
    """
    labels_clst, comms, label_matching, clst_term, miss_mask = (
        full_louvain_pipeline(u, y_true, k=k)
    )
    ce_term = cross_entropy(u, y_true)
    total = weight_ce * ce_term + weight_clst * clst_term

    # Detached snapshot of the logits, independent from the autograd graph
    logits_np = u.detach().cpu().numpy().copy()
    predictions = logits_np.argmax(axis=-1)
    # True classes that were not matched with any Louvain cluster
    n_empty = sum(len(v) == 0 for v in label_matching.values())

    metrics = {
        "loss_clst": clst_term.item(),
        "loss_ce": ce_term.item(),
        "loss": total.item(),
        "acc": accuracy_score(y_true.detach().cpu(), predictions),
        "n_err": (y_true.cpu().numpy() != predictions).sum(),
        "n_miss": miss_mask.sum(),
        "n_communities": len(comms),
        "n_unmatched": n_empty,
    }
    if not extra_data:
        return total, metrics

    metrics.update(
        {
            "u": logits_np,
            "communities": comms,
            "matching": label_matching,
            "y_clst": labels_clst,
            "y_pred": predictions,
            "p_miss": miss_mask,
        }
    )
    return total, metrics


def louvain_descent(
    u0: np.ndarray,
    y_true: np.ndarray | torch.Tensor,
    n_epochs: int,
    weight_clst: float = 1,
    weight_ce: float = 1,
    k: int = 20,
    lr: float = 1e-3,
    extra_data: bool = True,
    device: str | torch.device | None = None,
) -> list[dict]:
    """
    Optimizes the logits `u0` directly (no model weights involved) with Adam to
    minimize the loss returned by `evaluate`.

    Args:
        u0: Initial logits, one row per sample.
        y_true: True labels. NumPy arrays are converted to a long tensor.
        n_epochs: Number of optimization steps.
        weight_clst: Weight of the clustering loss.
        weight_ce: Weight of the cross-entropy loss.
        k: Forwarded to `evaluate`.
        lr: Learning rate of the Adam optimizer.
        extra_data: Forwarded to `evaluate`; whether each record also holds the
            logits, clusterings, predictions, etc.
        device: Device to run on. Defaults to `"cuda:0"` when CUDA is
            available, and to `"cpu"` otherwise.

    Returns:
        A list of `n_epochs + 1` dicts as returned by `evaluate`: one for the
        initial state, then one after each optimization step.
    """
    if device is None:
        device = "cuda:0" if torch.cuda.is_available() else "cpu"
    u = torch.tensor(u0, requires_grad=True, device=device)
    if isinstance(y_true, np.ndarray):
        y_true = torch.tensor(y_true).long()
    y_true = y_true.to(device)
    optimizer = optim.Adam([u], lr=lr)
    training_data = []

    _eval = partial(
        evaluate,
        y_true=y_true,
        weight_clst=weight_clst,
        weight_ce=weight_ce,
        k=k,
        extra_data=extra_data,
    )

    # BEFORE TRAINING
    loss, initial_data = _eval(u)
    training_data.append(initial_data)

    # TRAINING
    progress = tqdm(range(n_epochs))
    for _ in progress:
        # `loss` always comes from the previous evaluation, so that every
        # recorded entry describes the state *after* the step
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        loss, epoch_data = _eval(u)
        training_data.append(epoch_data)
        progress.set_postfix(
            {
                "loss_clst": np.round(epoch_data["loss_clst"], 3),
                "acc": np.round(epoch_data["acc"], 3),
                "n_err": epoch_data["n_err"],
                "n_miss": epoch_data["n_miss"],
            }
        )

    return training_data

We create a predicate (aka a mask aka a boolean array) to select samples from the top `N_CONFUSION_PAIRS` pairs of most confused classes.

In [None]:
# These two names are used by full_louvain_pipeline, which is defined in an
# earlier cell but only called after this import has run
from nlnas.correction import class_otm_matching, louvain_communities

N_CONFUSION_PAIRS = 10

# p accumulates (logical OR) the samples involved in each of the top pairs.
# Note that p, p1, p2, p3, i_true and j_pred are notebook globals that earlier
# cells also define; after this loop they refer to the last pair processed.
p = np.full_like(y_true, False, dtype=bool)
for i_true, j_pred in confusion_pairs[:N_CONFUSION_PAIRS]:
    p1 = (y_true == i_true) & (y_pred == i_true)  # i_true -> i_true
    p2 = (y_true == j_pred) & (y_pred == j_pred)  # j_pred -> j_pred
    p3 = (y_true == i_true) & (y_pred == j_pred)  # i_true -> j_pred
    print(
        "Label pair",
        i_true,
        "&",
        j_pred,
        ":",
        p3.sum(),
        "missclassified samples",
    )
    p = p | p1 | p2 | p3
print("Total:              ", p.sum(), "samples")

In [None]:
# Hyperparameters of the LCC run on the most confused classes
N_EPOCHS = 1000

K = 2
WEIGHT_CLST, WEIGHT_CE = 1, 1e-5
LR = 1e-3

training_data = louvain_descent(
    latent_embeddings[HEAD_NAME][p],
    y_true[p],
    n_epochs=N_EPOCHS,
    weight_clst=WEIGHT_CLST,
    weight_ce=WEIGHT_CE,
    k=K,
    # LR is reported in figure titles and file names, so it has to be the value
    # actually used rather than louvain_descent's default
    lr=LR,
)

In [None]:
kw = {"width": 500, "height": 200, "toolbar_location": None}
x = np.arange(len(training_data))

# Appended to every title so that each figure is self-describing
param_str = (
    f"w_clst={WEIGHT_CLST}, w_ce={WEIGHT_CE}, k={K}, lr={LR}, n={p.sum()}"
)

# One line plot per metric recorded in training_data: (title, record key)
rows = []
for _title, _key in (
    ("Louvain loss", "loss_clst"),
    ("Accuracy", "acc"),
    ("Nb. of errors", "n_err"),
    ("Nb. of clusters", "n_communities"),
    ("Nb. missclustered samples", "n_miss"),
):
    figure = bk.figure(title=_title + "\n" + param_str, **kw)
    figure.line(x, [d[_key] for d in training_data])
    rows.append(figure)

figure = bkl.column(rows)
bk.show(figure)

In [None]:
# Save whatever `figure` currently refers to as a PNG under LC_PATH; the file
# name records the hyperparameters of the run
export_png(
    figure,
    filename=LC_PATH
    / f"lcc_metrics_top_confused_wclst={WEIGHT_CLST}_wce={WEIGHT_CE}_k={K}_lr={LR}.png",
)

In [None]:
from cuml import UMAP
from sklearn.preprocessing import MinMaxScaler

from nlnas.plotting import class_scatter

# Number of evenly spaced epochs to visualize, one row of three figures each.
# This cell reads the "u", "y_pred", "y_clst" and "p_miss" entries of the
# training records.
N_SNAPSHOTS = 10
SIZE = 400

rows = []
kw = {"width": SIZE, "height": SIZE, "toolbar_location": None}

for i in tqdm(np.linspace(0, len(training_data) - 1, N_SNAPSHOTS, dtype=int)):
    rows.append([])

    d = training_data[i]
    # A fresh UMAP is fitted for every snapshot, so the 2D layouts of different
    # epochs are not directly comparable
    u_2d = UMAP().fit_transform(d["u"])
    u_2d = MinMaxScaler().fit_transform(u_2d)
    u_err_2d = u_2d[d["y_pred"] != y_true[p]]
    n_err = d["n_err"]
    # A sample is flagged when any row of p_miss marks it
    is_miss = d["p_miss"].sum(axis=0) > 0
    u_miss_2d, n_miss = u_2d[is_miss], is_miss.sum()

    figure = bk.figure(
        title=f"[epoch={i}] True labels, missclassified in red (n={n_err})",
        **kw,
    )
    class_scatter(figure, u_2d, y_true[p])
    figure.scatter(u_err_2d[:, 0], u_err_2d[:, 1], size=3, color="red")
    rows[-1].append(figure)

    figure = bk.figure(title=f"[epoch={i}] Louvain labels", **kw)
    class_scatter(figure, u_2d, d["y_clst"])
    rows[-1].append(figure)

    figure = bk.figure(
        title=f"[epoch={i}] Louvain labels, misclustered in red (n={n_miss})",
        **kw,
    )
    class_scatter(figure, u_2d, d["y_clst"])
    figure.scatter(u_miss_2d[:, 0], u_miss_2d[:, 1], size=3, color="red")
    rows[-1].append(figure)

# Append the run's hyperparameters to every title
param_str = (
    f"w_clst={WEIGHT_CLST}, w_ce={WEIGHT_CE}, k={K}, lr={LR}, n={p.sum()}"
)
for r in rows:
    for figure in r:
        figure.title.text += "\n" + param_str

figure = bkl.grid(rows)
bk.show(figure)

In [None]:
# Save whatever `figure` currently refers to as a PNG under LC_PATH; the file
# name records the hyperparameters of the run
export_png(
    figure,
    filename=LC_PATH
    / f"lcc_umap_top_confused_wclst={WEIGHT_CLST}_wce={WEIGHT_CE}_k={K}_lr={LR}.png",
)

In [None]:
# Distance distributions of the optimized logits at N_SNAPSHOTS evenly spaced
# epochs. This cell reads the "u" entry of the training records.
N_SNAPSHOTS = 10
SIZE = 250

rows = []
kw = {"width": 2 * SIZE, "height": SIZE, "toolbar_location": None}

for i in tqdm(np.linspace(0, len(training_data) - 1, N_SNAPSHOTS, dtype=int)):
    epoch_data = training_data[i]
    v, yt = epoch_data["u"], y_true[p]

    figure = bk.figure(**kw)
    figure.title = (
        f"[epoch={i}] Full dataset DD (black)\n"
        "vs. inter-class DD btw highly-confused class pairs (green)\n"
        "vs. intra-class DD (red)"
    )

    # DD over all the selected samples (left bin edges on the x axis)
    h, e = distance_distribution(v)
    figure.line(e[:-1], h, width=2, color="black")

    for i_true, j_pred in tqdm(
        confusion_pairs[:N_CONFUSION_PAIRS], leave=False
    ):
        # Inter-class DD between the two true classes of the pair...
        h, e = distance_distribution(v[yt == i_true], v[yt == j_pred])
        figure.line(e[:-1], h, color="green")
        # ... and the intra-class DD of each of them
        h, e = distance_distribution(v[yt == i_true])
        figure.line(e[:-1], h, color="red", width=0.25)
        h, e = distance_distribution(v[yt == j_pred])
        figure.line(e[:-1], h, color="red", width=0.25)

    rows.append(figure)

# Append the run's hyperparameters to every title
param_str = (
    f"w_clst={WEIGHT_CLST}, w_ce={WEIGHT_CE}, k={K}, lr={LR}, n={p.sum()}"
)
for figure in rows:
    figure.title.text += "\n" + param_str

figure = bkl.column(rows)
bk.show(figure)

In [None]:
# Save whatever `figure` currently refers to as a PNG under LC_PATH; the file
# name records the hyperparameters of the run
export_png(
    figure,
    filename=LC_PATH
    / f"lcc_dd_top_confused_wclst={WEIGHT_CLST}_wce={WEIGHT_CE}_k={K}_lr={LR}.png",
)

# LCC on whole dataset

In [None]:
# Shape of the head's (standardized) latent embeddings, used as logits below
latent_embeddings[HEAD_NAME].shape

In [None]:
# Number of leading samples used for LCC on the whole dataset. Increase the
# divisor (currently 1, i.e. all samples) to work on a fraction of the dataset.
LCC_N_SAMPLES = len(latent_embeddings[HEAD_NAME]) // 1
LCC_N_SAMPLES

## Finding the right $k$

In [None]:
from nlnas.correction import otm_matching_predicates

K = 50

# Single Louvain run on the first LCC_N_SAMPLES logits.
# NOTE(review): y_true is a NumPy array here, whereas full_louvain_pipeline
# annotates `yt` as a torch.Tensor — confirm that the pipeline accepts it
y_clst, communities, matching, loss, p_miss = full_louvain_pipeline(
    latent_embeddings[HEAD_NAME][:LCC_N_SAMPLES], y_true[:LCC_N_SAMPLES], k=K
)
print("Found", len(communities), "communities")
print("Loss:", np.round(loss.item(), 5))

n_classes = len(np.unique(y_true[:LCC_N_SAMPLES]))
# True classes whose matching is empty
n_unmatched = sum(len(v) == 0 for v in matching.values())
n_miss = p_miss.sum()
print(
    "Unmatched true classes:",
    n_unmatched,
    "->",
    np.round(n_unmatched / n_classes * 100, 3),
    "%",
)
print(
    "Missclustered samples:",
    n_miss,
    "->",
    np.round(n_miss / LCC_N_SAMPLES * 100, 3),
    "%",
)

In [None]:
import pandas as pd

# Values of k to try
ks = [2, 3, 4, 5, 6, 7, 8, 9, 10]

# The ratios below must be relative to the subset that is actually clustered,
# i.e. the first LCC_N_SAMPLES samples, not to the whole dataset
n_samples = len(y_true[:LCC_N_SAMPLES])
n_classes = len(np.unique(y_true[:LCC_N_SAMPLES]))

# One record per k; the frame is built once at the end so that the numeric
# columns get numeric dtypes (rather than `object`)
records = []
for k in tqdm(ks):
    y_clst, communities, matching, loss, p_miss = full_louvain_pipeline(
        latent_embeddings[HEAD_NAME][:LCC_N_SAMPLES],
        y_true[:LCC_N_SAMPLES],
        k=k,
    )
    # True classes whose matching is empty
    n_unmatched = sum(len(v) == 0 for v in matching.values())
    n_miss = int(p_miss.sum())
    records.append(
        {
            "k": k,
            "n_communities": len(communities),
            "community_size": [len(c) for c in communities],
            "loss_clst": loss.item(),
            "n_unmatched": n_unmatched,
            "r_unmatched": n_unmatched / n_classes,
            "n_miss": n_miss,
            "r_miss": n_miss / n_samples,
        }
    )

df = pd.DataFrame(
    records,
    columns=[
        "k",
        "n_communities",
        "loss_clst",
        "n_unmatched",
        "n_miss",
        "r_unmatched",
        "r_miss",
        "community_size",
    ],
)

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

# 2x2 overview of the k sweep: clustering loss, number of communities, ratio of
# unmatched true classes, and ratio of misclustered samples
fig, ax = plt.subplots(2, 2, figsize=(10, 5))
sns.lineplot(data=df, x="k", y="loss_clst", ax=ax[0][0])
sns.lineplot(data=df, x="k", y="n_communities", ax=ax[0][1])
sns.lineplot(data=df, x="k", y="r_unmatched", ax=ax[1][0])
sns.lineplot(data=df, x="k", y="r_miss", ax=ax[1][1])

In [None]:
# Turn the per-k lists of community sizes into one row per community, then show
# the distribution of community sizes for each k
_df = df[["k", "community_size"]].explode("community_size")
sns.boxplot(data=_df, x="k", y="community_size")

## LCC

In [None]:
# Hyperparameters of the LCC run on the whole dataset
N_EPOCHS = 100

K = 2
WEIGHT_CLST, WEIGHT_CE = 1, 0
LR = 1e-3

training_data = louvain_descent(
    latent_embeddings[HEAD_NAME][:LCC_N_SAMPLES],
    y_true[:LCC_N_SAMPLES],
    n_epochs=N_EPOCHS,
    weight_clst=WEIGHT_CLST,
    weight_ce=WEIGHT_CE,
    k=K,
    # LR is reported in figure titles and file names, so it has to be the value
    # actually used rather than louvain_descent's default
    lr=LR,
    # Only scalar metrics are kept: no per-epoch copy of the logits
    extra_data=False,
)

In [None]:
kw = {"width": 500, "height": 200, "toolbar_location": None}
x = np.arange(len(training_data))

# Appended to every title so that each figure is self-describing
param_str = f"w_clst={WEIGHT_CLST}, w_ce={WEIGHT_CE}, k={K}, lr={LR}, n={LCC_N_SAMPLES}"

# One line plot per metric recorded in training_data: (title, record key)
rows = []
for _title, _key in (
    ("Louvain loss", "loss_clst"),
    ("Accuracy", "acc"),
    ("Nb. of errors", "n_err"),
    ("Nb. of clusters", "n_communities"),
    ("Nb. missclustered samples", "n_miss"),
):
    figure = bk.figure(title=_title + "\n" + param_str, **kw)
    figure.line(x, [d[_key] for d in training_data])
    rows.append(figure)

figure = bkl.column(rows)
bk.show(figure)

In [None]:
# Save whatever `figure` currently refers to as a PNG under LC_PATH. The file
# name records both loss weights, following the naming scheme of the other
# LCC exports (there is no single `WEIGHT` variable in this notebook).
export_png(
    figure,
    filename=LC_PATH
    / f"lcc_metrics_full_ds_wclst={WEIGHT_CLST}_wce={WEIGHT_CE}_k={K}_lr={LR}.png",
)