### Analysis of the DCv2 Results for the MNIST Dataset

This notebook analyzes the results of unsupervised clustering with DCv2 on the MNIST training set provided by `torchvision`.

The aspects analyzed are:

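1. Accuracy per MLP head, obtained by mapping each cluster to the majority dataset label of its samples.
2. Similarity between the clustering results of different heads that cluster the same crops (their results are expected to become similar).
3. The samples closest to each cluster center.
4. The distribution of dataset labels across the clusters of the most accurate head.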


In [None]:
import sklearn.metrics
import torch
import pathlib
import torchvision

# Output directory for the paper figures, and the GPU everything runs on.
plots = pathlib.Path("/p/project1/deepacf/emmerich1") / "plots" / "paper-1"
device = torch.device("cuda", 0)

In [None]:
# MNIST training split; downloaded into `mnist_root` if not already present.
mnist_root = "/p/project1/deepacf/emmerich1/data/mnist"
mnist = torchvision.datasets.MNIST(mnist_root, train=True, download=True)

In [None]:
import numpy as np
import sklearn.cluster
import sklearn.mixture


def infer_dataset_labels_for_clusters(
    labels_true: torch.LongTensor,
    labels_pred: torch.LongTensor,
    return_bincount: bool = False,
) -> dict[int, "int | torch.Tensor"]:
    """Infer which dataset label corresponds to each cluster.

    Each cluster is mapped to the dataset label that occurs most often among
    the samples assigned to it (majority vote).

    Parameters
    ----------
    labels_true : torch.LongTensor
        The true dataset label of each sample (non-negative integers).
    labels_pred : torch.LongTensor
        The cluster index assigned to each sample (non-negative integers),
        aligned element-wise with `labels_true` and on the same device.
    return_bincount : bool, default False
        If True, return the per-label sample counts of each cluster instead
        of its majority label.

    Returns
    -------
    dict[int, int | torch.Tensor]
        Maps every cluster index in ``range(labels_pred.max() + 1)`` to its
        majority dataset label (a Python ``int``) or, if `return_bincount`
        is True, to a tensor of length ``labels_true.max() + 1`` holding the
        number of samples per dataset label. A cluster without samples maps
        to label 0 (all-zero counts); since no sample carries that cluster
        index, this never affects predictions. Empty input yields ``{}``.

    """
    # Maps cluster labels -> dataset labels (e.g. numbers for MNIST).
    # E.g. cluster index 2 might correspond to the dataset label 7
    # (number 7 for MNIST). Then, `labels_map = {..., 2: 7, ...}`.
    labels_map: dict[int, "int | torch.Tensor"] = {}
    if labels_pred.numel() == 0:
        return labels_map

    n_clusters = int(labels_pred.max()) + 1
    # A fixed `minlength` gives every cluster a count vector of the same
    # length (so counts can be stacked/plotted per label) and keeps `argmax`
    # defined for clusters that received no samples.
    n_labels = int(labels_true.max()) + 1

    for cluster in range(n_clusters):
        # Dataset labels of the samples assigned to this cluster, counted.
        counts = labels_true[labels_pred == cluster].bincount(minlength=n_labels)
        # torch.argmax returns the first maximal index, so ties go to the
        # smallest dataset label.
        labels_map[cluster] = counts if return_bincount else int(counts.argmax())

    return labels_map


def infer_predicted_dataset_labels(
    labels_pred: torch.LongTensor, cluster_labels_map: dict[int, int]
) -> torch.Tensor:
    """Translate each sample's cluster index into its inferred dataset label.

    Parameters
    ----------
    labels_pred : torch.LongTensor
        The cluster index of each sample.
    cluster_labels_map : dict[int, int]
        Maps cluster index -> dataset label, e.g. the output of
        `infer_dataset_labels_for_clusters`. Keys must be non-negative;
        values may be Python ints or 0-d integer tensors.

    Returns
    -------
    torch.Tensor
        int64 tensor of dataset labels with the same shape and on the same
        device as `labels_pred`.

    Raises
    ------
    KeyError
        If `labels_pred` contains a cluster index missing from the map.

    """
    if labels_pred.numel() == 0:
        return torch.empty(
            labels_pred.shape, dtype=torch.long, device=labels_pred.device
        )

    missing = set(labels_pred.unique().tolist()) - set(cluster_labels_map)
    if missing:
        raise KeyError(min(missing))

    # Vectorized lookup table indexed by cluster id: a single gather instead
    # of a Python-level loop that synchronizes with the GPU once per sample.
    lookup = torch.zeros(max(cluster_labels_map) + 1, dtype=torch.long)
    for cluster, label in cluster_labels_map.items():
        lookup[cluster] = int(label)
    return lookup.to(labels_pred.device)[labels_pred.long()]


def accuracy_score(
    labels_true: torch.LongTensor, labels_pred: torch.LongTensor
) -> float:
    """Return the fraction of samples whose predicted label equals the true one.

    The number of element-wise matches is divided by the size of the first
    dimension of `labels_true`.
    """
    n_correct = torch.eq(labels_true, labels_pred).sum()
    n_samples = labels_true.size(0)
    return float(n_correct / n_samples)

In [None]:
def load_to_first_gpu(storage, loc):
    """`torch.load` map_location callback that puts every storage on GPU 0."""
    return storage.cuda(0)


path = pathlib.Path("/p/project1/deepacf/emmerich1/dcv2/mnist-1-node-1-gpu/")
# Files are named "epoch-<N>-assignments.pt"; collect the epoch numbers N.
epochs = sorted(
    int(assignment_file.name.split("-")[1])
    for assignment_file in path.glob("*-assignments.pt")
)
epochs

In [None]:
%%time

# Dataset labels in dataset order. They are the same for every epoch and
# head, so compute them once. `mnist.targets` holds the labels that
# `mnist[i][1]` returns, without decoding every image.
labels_mnist = mnist.targets.long().to(device=device)

accuracies = []

for epoch in epochs:
    # `indexes` is not used in this loop, but later cells read the value
    # loaded for the last epoch.
    indexes = torch.load(
        path / f"epoch-{epoch}-indexes.pt",
        map_location=load_to_first_gpu,
    )
    assignments = torch.load(
        path / f"epoch-{epoch}-assignments.pt",
        map_location=load_to_first_gpu,
    )

    # NOTE(review): assumes assignments[head][k] belongs to dataset sample k
    # (assignments stored in dataset order) -- confirm.
    accs = []
    for labels_dcv2 in assignments:
        labels_map = infer_dataset_labels_for_clusters(
            labels_true=labels_mnist, labels_pred=labels_dcv2
        )
        labels_dcv2_mnist = infer_predicted_dataset_labels(
            labels_pred=labels_dcv2, cluster_labels_map=labels_map
        )
        accs.append(
            accuracy_score(labels_true=labels_mnist, labels_pred=labels_dcv2_mnist)
        )

    # Keep the accuracy of the best head of this epoch.
    accuracies.append(max(accs))

dict(zip(epochs, accuracies))

In [None]:
# Agreement between two heads: treat one head's assignments as the
# reference, map the other head's clusters onto them by majority vote, and
# score the mapped assignments.
# NOTE(review): (0, 2) and (1, 3) are assumed to be the head pairs that
# cluster the same crops -- confirm against the training setup.
for head_1, head_2 in [
    (assignments[0], assignments[2]),
    (assignments[1], assignments[3]),
]:
    labels_true = head_1
    labels_pred = head_2
    labels_map = infer_dataset_labels_for_clusters(
        labels_true=labels_true, labels_pred=labels_pred
    )

    labels_dcv2_mnist = infer_predicted_dataset_labels(
        labels_pred=labels_pred, cluster_labels_map=labels_map
    )

    # Score the *mapped* labels: the raw cluster indexes of two heads are
    # arbitrary permutations of each other and cannot be compared directly.
    accuracy = accuracy_score(
        labels_true=labels_true, labels_pred=labels_dcv2_mnist
    )
    print(f"{accuracy=}")

In [None]:
# Load the distances of the last saved epoch, so that they belong to the
# same epoch as `assignments` and `indexes` left over from the accuracy
# loop (which ends at `epochs[-1]`).
distances = torch.load(
    path / f"epoch-{epochs[-1]}-distances.pt",
    map_location=load_to_first_gpu,
)

In [None]:
import matplotlib.pyplot as plt

# For each cluster of head 0, show the (up to) 16 samples with the smallest
# distance value, in a 4x4 grid.
# NOTE(review): assumes `assignments[0]`, `indexes[0]` and `distances[0]`
# are aligned element-wise -- confirm.
dist = distances[0]
n_shown = 16
n_cols = 4
n_rows = n_shown // n_cols

for cluster_index in range(10):
    cluster_samples = assignments[0] == cluster_index
    dist_indexes = indexes[0][cluster_samples]
    dist_cluster = dist[cluster_samples]
    indexes_near_cluster_center = dist_indexes[dist_cluster.argsort()[:n_shown]]

    _, axs = plt.subplots(n_rows, n_cols, figsize=(8, 8))
    print(f"==== Cluster {cluster_index} ===")
    for k, ax in enumerate(axs.flat):
        # Clusters with fewer than `n_shown` samples leave trailing axes empty.
        if k >= len(indexes_near_cluster_center):
            ax.set_axis_off()
            continue
        sample, label = mnist[int(indexes_near_cluster_center[k])]
        # interpolation=None falls back to rcParams["image.interpolation"].
        ax.imshow(sample, cmap="Greys", interpolation=None)
        ax.set_title(f"Number {label}")
        # Only the bottom row keeps its x ticks.
        if k // n_cols < n_rows - 1:
            ax.set_xticks([])
    plt.show()

In [None]:
# Accuracy of every head at epoch 399, stored as (head index, accuracy) pairs.
assignments = torch.load(
    path / "epoch-399-assignments.pt",
    map_location=load_to_first_gpu,
)

# The dataset labels are the same for every head: compute them once.
# `mnist.targets` holds the labels that `mnist[i][1]` returns, without
# decoding every image.
labels_mnist = mnist.targets.long().to(device=device)

accs = []

for i, labels_dcv2 in enumerate(assignments):
    labels_map = infer_dataset_labels_for_clusters(
        labels_true=labels_mnist, labels_pred=labels_dcv2
    )

    labels_dcv2_mnist = infer_predicted_dataset_labels(
        labels_pred=labels_dcv2, cluster_labels_map=labels_map
    )

    accuracy = accuracy_score(
        labels_true=labels_mnist, labels_pred=labels_dcv2_mnist
    )
    accs.append((i, accuracy))

In [None]:
# Head with the highest accuracy, taken from the (head index, accuracy)
# pairs in `accs` instead of being typed in by hand.
max_acc_index = max(accs, key=lambda head_acc: head_acc[1])[0]

labels_dcv2 = assignments[max_acc_index]

# `mnist.targets` holds the labels that `mnist[i][1]` returns, without
# decoding every image.
labels_mnist = mnist.targets.long().to(device=device)

# Per cluster of the best head: number of samples of each dataset label.
bincount_per_cluster = infer_dataset_labels_for_clusters(
    labels_true=labels_mnist, labels_pred=labels_dcv2, return_bincount=True
)
bincount_per_cluster

In [None]:
import matplotlib.pyplot as plt
import numpy as np

# Stacked bar chart: for every cluster, the number of samples of each
# MNIST digit.
n_digits = 10
n_clusters = len(bincount_per_cluster)
bar_width = 0.8

# (n_clusters, n_digits) count matrix. A bincount is shorter than
# `n_digits` when its cluster holds no samples of the highest digits, so
# pad with zeros instead of indexing past its end.
counts = np.zeros((n_clusters, n_digits))
for row, bincount in enumerate(bincount_per_cluster.values()):
    bincount = bincount.cpu().numpy()[:n_digits]
    counts[row, : len(bincount)] = bincount

fig, ax = plt.subplots(figsize=(4, 3))

x = np.arange(n_clusters)
bottom = np.zeros(n_clusters)

for digit in range(n_digits):
    ax.bar(
        x,
        counts[:, digit],
        bottom=bottom,
        color=plt.cm.tab10(digit),
        label=str(digit),
        width=bar_width,
    )
    bottom += counts[:, digit]

ax.set_ylabel("# samples")
ax.set_xlabel("Cluster")
ax.set_xticks(x)
ax.legend(loc="upper right", bbox_to_anchor=(1.25, 1.0))

plots.mkdir(parents=True, exist_ok=True)
plt.savefig(plots / "figB3.pdf", bbox_inches="tight")