This notebook is used to check that methods in `correction.clustering` work as expected.

In [None]:
%load_ext jupyter_black
%load_ext autoreload
%autoreload 2

In [None]:
import bokeh.layouts as bkl
import bokeh.plotting as bk
from bokeh.io import output_notebook

from nlnas.plotting import export_png

output_notebook()

import sys

from loguru import logger as logging

# NOTE: `logging` is loguru's logger (imported under that alias above), not
# the standard-library `logging` module.
# Remove the handlers registered so far, then send messages of level INFO and
# above to stdout using a compact "[LEVEL   ] message" layout.
logging.remove()
logging.add(
    sys.stdout,
    level="INFO",
    format="[<level>{level: <8}</level>] <level>{message}</level>",
)

# Loading stuff

In [None]:
from pathlib import Path
import turbo_broccoli as tb

# Hugging Face identifiers of the dataset and of the model under inspection
HF_DATASET_NAME = "cifar100"
# HF_MODEL_NAME = "timm/tinynet_e.in1k"
HF_MODEL_NAME = "timm/mobilenetv3_small_050.lamb_in1k"
# Name of the submodule whose latent embeddings are loaded further down
SUBMODULE = "model.conv_head"
# Index that appears in the name of the result file
VERSION = 0

# Path-friendly variants of the identifiers above: every "/" becomes "-"
DATASET_NAME = "-".join(HF_DATASET_NAME.split("/"))
MODEL_NAME = "-".join(HF_MODEL_NAME.split("/"))

RESULT_FILE_PATH = Path(
    "out/ftlcc", DATASET_NAME, MODEL_NAME, f"results.{VERSION}.json"
)
RESULTS = tb.load_json(RESULT_FILE_PATH)

## Model

In [None]:
from nlnas.classifiers.timm import TimmClassifier

# The checkpoint path stored in the results file is joined onto the output
# root directory
best_ckpt_info = RESULTS["model"]["best_checkpoint"]
CKPT_PATH = Path("out/ftlcc") / best_ckpt_info["path"]
logging.info("Best model checkpoint path: {}", CKPT_PATH)

model = TimmClassifier.load_from_checkpoint(CKPT_PATH)

## Dataset

In [None]:
from nlnas.datasets.huggingface import HuggingFaceDataset

# Split names and label key are read from the results file
ds_cfg = RESULTS["dataset"]

dataset = HuggingFaceDataset(
    HF_DATASET_NAME,
    fit_split=ds_cfg["train_split"],
    val_split=ds_cfg["val_split"],
    test_split=ds_cfg["test_split"],
    # Not a typo: the predict split is deliberately the train split
    predict_split=ds_cfg["train_split"],
    label_key=ds_cfg["label_key"],
    image_processor=model.get_image_processor(HF_MODEL_NAME),
)

n_classes = dataset.n_classes()
# True labels of the train split, as a numpy array
y_true = dataset.y_true("train").numpy()
logging.info("y_true: {}", y_true.shape)

## Latent embeddings

In [None]:
from sklearn.preprocessing import StandardScaler

from nlnas.utils import load_tensor_batched

# Directory holding the train-split tensors of the best checkpoint
best_ckpt_version = str(RESULTS["model"]["best_checkpoint"]["version"])
train_embeddings_dir = (
    RESULT_FILE_PATH.parent
    / "analysis"
    / best_ckpt_version
    / "embeddings"
    / "train"
)

# Load the batches saved under the SUBMODULE prefix and convert to numpy
latent_embeddings = load_tensor_batched(
    train_embeddings_dir,
    prefix=SUBMODULE,
    tqdm_style="notebook",
).numpy()
# Flattening and standard scaling are currently disabled:
# latent_embeddings = latent_embeddings.reshape(len(latent_embeddings), -1)
# latent_embeddings = StandardScaler().fit_transform(latent_embeddings)

logging.info("Latent embedding array: {}", latent_embeddings.shape)

## Predictions 

In [None]:
# The logits are stored next to the latent embeddings, under the "y_pred"
# prefix
predictions_dir = (
    RESULT_FILE_PATH.parent
    / "analysis"
    / str(RESULTS["model"]["best_checkpoint"]["version"])
    / "embeddings"
    / "train"
)
logits = load_tensor_batched(
    predictions_dir,
    prefix="y_pred",
    tqdm_style="notebook",
).numpy()

# Predicted label = index of the largest logit along the last axis
y_pred = logits.argmax(axis=-1)

logging.info("logits: {}", logits.shape)

In [None]:
from torchmetrics.functional.classification import multiclass_accuracy
import torch

# Accuracy of the loaded logits against the train labels (displayed as the
# cell output).
# NOTE(review): torchmetrics' multiclass_accuracy averages per class ("macro")
# unless `average` is given -- confirm this is the intended figure.
multiclass_accuracy(
    preds=torch.tensor(logits),
    target=torch.tensor(y_true),
    num_classes=n_classes,
)

## Clustering data

In [None]:
# Recompute the clustering data from scratch

from tempfile import TemporaryDirectory

from nlnas.classifiers.base import full_dataset_latent_clustering

# A temporary directory is handed to the clustering routine (presumably as its
# output/working directory -- TODO confirm) and discarded when the block exits.
with TemporaryDirectory() as tmp:
    lc_data = full_dataset_latent_clustering(
        model,
        dataset,
        tmp,
        method="louvain",
        device="cuda",
        scaling="standard",
        classes=None,
        split="train",
        tqdm_style="notebook",
    )

# Index with SUBMODULE rather than a hard-coded name so that the clustering
# data always refers to the same submodule as `latent_embeddings`, which is
# loaded with prefix=SUBMODULE.
submodule_lc = lc_data[SUBMODULE]
y_clst = submodule_lc.y_clst
matching = submodule_lc.matching
knn_indices = submodule_lc.knn_indices

# Analysis

## Clustering

In [None]:
import numpy as np

# Cluster label -> number of samples assigned to that cluster
cluster_labels = np.unique(y_clst)
clst_size = dict(
    zip(cluster_labels, ((y_clst == lbl).sum() for lbl in cluster_labels))
)

# True label -> total number of samples in the clusters matched to that label
n_matched = {
    true_label: sum(clst_size[cluster] for cluster in matched_clusters)
    for true_label, matched_clusters in matching.items()
}

In [None]:
# This is just for curiosity

# True class whose matched clusters contain the most samples overall
i_true, n = max(n_matched.items(), key=lambda kv: kv[1])
logging.info(
    (
        "Top true class by number of samples in matched clusters: \n"
        "  i_true = {}\n"
        "  matched clusters: {}\n"
        "  nb. of matched samples: {}"
    ),
    i_true,
    matching[i_true],
    n,
)

# True class that is matched to the largest number of clusters
i_true = max(matching.keys(), key=lambda lbl: len(matching[lbl]))
# Sizes are collected in a list, NOT a set: two matched clusters can have the
# same size, and a set would merge them, making the total below too small. A
# list also keeps the sizes in the iteration order of matching[i_true].
ns = [int(clst_size[j_clst]) for j_clst in matching[i_true]]
logging.info(
    (
        "Top true class by number of matched clusters: \n"
        "  i_true = {}\n"
        "  matched clusters: {}\n"
        "  nb. of samples in clusters (resp.): {}\n"
        "  total nb. of matched samples: {}"
    ),
    i_true,
    matching[i_true],
    ns,
    sum(ns),
)

In [None]:
from nlnas.correction.clustering import otm_matching_predicates, _mc_cc_predicates

# The four OTM matching predicates, then the misclustered (MC) / correctly
# clustered (CC) predicates derived from the same inputs
p1, p2, p3, p4 = otm_matching_predicates(y_true, y_clst, matching)
p_mc, p_cc = _mc_cc_predicates(y_true, y_clst, matching)

logging.info(
    "OTM matching predicate shapes: {} {} {} {}",
    *(pred.shape for pred in (p1, p2, p3, p4)),
)
logging.info("MC/CC predicate shapes: {} {}", p_mc.shape, p_cc.shape)

In [None]:
# Check of p1: p1[i_true, j] must be True exactly when the j-th sample belongs
# to true class i_true

for i_true in np.unique(y_true):
    expected_idx = np.where(y_true == i_true)[0]
    actual_idx = np.where(p1[i_true])[0]
    # Same length and same indices, position by position
    assert np.array_equal(expected_idx, actual_idx)

In [None]:
# Check of p2: p2[i_true, j] must be True exactly when the j-th sample lies in
# a cluster matched to i_true

for i_true in np.unique(y_true):
    in_matched_cluster = np.isin(y_clst, list(matching[i_true]))
    expected_idx = np.where(in_matched_cluster)[0]
    actual_idx = np.where(p2[i_true])[0]
    # Same length and same indices, position by position
    assert np.array_equal(expected_idx, actual_idx)

In [None]:
# Testing if p3 is what is expected
# p3[i_true, j] is True if j-th sample is in true class i_true but not in any
# cluster matched to i_true

for i_true, p in enumerate(p3):
    for j in np.where(p)[0]:
        # Both halves of the documented predicate are verified: membership in
        # the true class AND absence from every matched cluster.
        assert y_true[j] == i_true
        assert y_clst[j] not in matching[i_true]

In [None]:
# Check of p4: wherever p4[i_true, j] is True, the j-th sample must NOT be in
# true class i_true, yet must sit in a cluster matched to i_true

for i_true, row in enumerate(p4):
    flagged = np.where(row)[0]
    assert all(y_true[j] != i_true for j in flagged)
    assert all(y_clst[j] in matching[i_true] for j in flagged)

At this point, we're confident that the OTM matching predicates are accurate.

In [None]:
# Check of p_cc: wherever p_cc[i_true, j] is True, the j-th sample must be in
# true class i_true and in a cluster matched to i_true

for i_true, row in enumerate(p_cc):
    flagged = np.where(row)[0]
    assert all(
        y_true[j] == i_true and y_clst[j] in matching[i_true] for j in flagged
    )

## KNN indices

Here we study the actual `LatentClusteringData` computed by `full_dataset_latent_clustering`, particularly the KNN indices it contains.

In [None]:
# Reminder: knn_indices maps a true class i_true to a pair made of
# 1. a nearest-neighbors index (it exposes `n_samples_fit_`) fitted on...
# 2. ... the set of correctly clustered samples of i_true

for i_true, (knn, v) in knn_indices.items():
    cc_mask = p_cc[i_true]
    n_cc = cc_mask.sum()
    # The stored samples, the CC predicate and the fitted index must all agree
    # on the number of samples
    assert len(v) == n_cc
    assert n_cc == knn.n_samples_fit_
    # ... and the stored samples must be the CC rows of the latent embeddings
    w = latent_embeddings[cc_mask]
    assert v.shape == w.shape
    assert (v == w).all()

In [None]:
# Look at one entry of the knn_indices dict (the first one in iteration order)

i_true = next(iter(knn_indices))
knn, v = knn_indices[i_true]
matched_clusters = matching[i_true]
logging.info("i_true={}, v.shape={}", i_true, v.shape)
logging.info("Matched clusters ({}): {}", len(matched_clusters), matched_clusters)

# Total number of samples over all the clusters matched to i_true
n = sum((y_clst == j_clst).sum() for j_clst in matched_clusters)
logging.info("Nb. of samples in matched clusters: {}", n)
logging.info("Nb. of correctly clustered samples: {}", p_cc[i_true].sum())
logging.info("Nb. of misclustered samples: {}", p_mc[i_true].sum())

In [None]:
from nlnas.correction.clustering import lcc_targets

# lcc_targets is given the embeddings as a torch tensor; the other arguments
# are passed through unchanged
latent_tensor = torch.tensor(latent_embeddings)
targets = lcc_targets(latent_tensor, y_true, y_clst, matching, knn_indices)

In [None]:
# Sanity checks on the structure of `targets`

# There must be exactly one target entry per KNN index
assert set(targets) == set(knn_indices)

sample_shape = latent_embeddings.shape[1:]
for i_true, (mc_mask, tgt) in targets.items():
    # The mask must select the misclustered samples of true class i_true
    assert (mc_mask == p_mc[i_true]).all()
    assert (y_true[mc_mask] == i_true).all()
    # Each individual target has the shape of a single latent embedding
    assert tgt.shape[1:] == sample_shape
    # One target per misclustered sample (in i_true)
    assert tgt.shape[0] == mc_mask.sum()

## Confusion

In [None]:
from nlnas.correction.choice import confusion_graph, heaviest_connected_subgraph
import numpy as np
import networkx as nx

# Use the dataset's class count (the same value passed to the accuracy metric)
# rather than the number of labels observed in y_true, which would be too small
# if some class had no sample in the split.
graph = confusion_graph(y_pred, y_true, n_classes=n_classes, threshold=10)
nx.draw_spring(graph, with_labels=True, node_size=100)

In [None]:
# Heaviest connected part of the confusion graph, together with its weight
hcsg, w = heaviest_connected_subgraph(graph)
n_labels, n_confused = len(hcsg), int(w)
logging.info(
    "Top connected confusion: {} labels, {} confused samples",
    n_labels,
    n_confused,
)
nx.draw_spring(hcsg, with_labels=True, node_size=100)

In [None]:
from nlnas.correction.choice import top_confusion_pairs

# Five most confused pairs of labels; the first one is inspected below.
# The class count comes from the dataset, as for the accuracy metric.
tcp5 = top_confusion_pairs(y_pred, y_true, n_classes=n_classes, n_pairs=5)
top_pair = tcp5[0]
a, b = top_pair
logging.info(
    "Top confusion pair: {}, {} confused samples",
    # The original referenced an undefined name (`tcp`) here, which raised a
    # NameError; the pair being reported is the first entry of tcp5.
    top_pair,
    # NOTE(review): this lookup assumes the (a, b) edge survived the threshold
    # used when `graph` was built.
    graph.edges[a, b]["weight"],
)

In [None]:
# Samples of class a predicted as b, and samples of class b predicted as a
a_taken_for_b = (y_true == a) & (y_pred == b)
b_taken_for_a = (y_true == b) & (y_pred == a)
idx = np.where(a_taken_for_b | b_taken_for_a)[0]
logging.info("Indices of {}/{} confused samples:\n{}", a, b, idx)