In [None]:
import numpy as np
import matplotlib.pyplot as plt
import umap
import os
from PCH import HDBSCAN
from PCH.utils import constraints_from_estimate, augment_labels
from sklearn.metrics import adjusted_rand_score
from matplotlib import collections as mc

In [None]:
def plot_constraints(visual_embedding, selected_labels, constraints, s=.1):
    """Scatter a 2-D embedding coloured by labels and overlay constraints as dashed lines.

    Must-link constraints (``constraints["ML"]``) are drawn in purple and cannot-link
    constraints (``constraints["CL"]``) in faint black. Each entry is used to index rows of
    ``visual_embedding`` and the result is handed to ``LineCollection`` as segments.
    NOTE(review): presumably each constraint is a pair of sample indices, giving
    (n, 2, 2) segments — confirm against ``constraints_from_estimate``.

    Args:
        visual_embedding: array whose columns 0 and 1 are used as x / y.
        selected_labels: per-point colour values for the scatter.
        constraints: mapping with "ML" and "CL" entries.
        s: marker size for the scatter.
    """
    fig, ax = plt.subplots(figsize=(10, 10))
    xs = visual_embedding[:, 0]
    ys = visual_embedding[:, 1]
    ax.scatter(xs, ys, c=selected_labels, s=s, cmap="tab20")

    # (key, colour, alpha) — must-link first, then cannot-link, matching the drawing order.
    overlays = (
        ("ML", "purple", .85),
        ("CL", "black", .25),
    )
    for key, colour, opacity in overlays:
        segments = visual_embedding[constraints[key]]
        ax.add_collection(
            mc.LineCollection(
                segments,
                linewidths=2,
                color=colour,
                linestyle="dashed",
                alpha=opacity,
            )
        )
    fig.show()

In [None]:
def merge_constraints(current_constraints, new_constraints):
    """Combine two constraint dicts into a new dict, leaving both inputs untouched.

    For every key, the result holds the entries of ``current_constraints[key]`` followed by
    those of ``new_constraints[key]``, as a list. Keys present in only one input are carried
    over. The previous implementation extended ``current_constraints`` in place, so the
    caller's dict (and any alias of it) silently grew on every merge.

    Args:
        current_constraints: mapping of constraint type (e.g. "ML" / "CL") to a sequence.
        new_constraints: mapping of the same form to append.

    Returns:
        A new dict; its values are fresh lists (array values are accepted too).
    """
    merged = {key: list(value) for key, value in current_constraints.items()}
    for key, value in new_constraints.items():
        merged.setdefault(key, []).extend(value)
    return merged

In [None]:
import umap
from matplotlib.patches import Ellipse
import tqdm

def generate_embedding(
    train_dataloader,
    encoder,
    transform,
    device="cuda",
    image_shape=(1, 28, 28),
    progress=True,
):
    """Encode every batch of a dataloader and collect the embeddings with their targets.

    Args:
        train_dataloader: iterable yielding ``(x, y)`` batches. ``x`` is reshaped to
            ``(-1, *image_shape)``, cast to float, moved to ``device`` and divided by 255.
        encoder: callable applied to the scaled image batch.
        transform: callable applied to the encoder output before it is collected.
        device: device the image batch is moved to before encoding (default "cuda",
            matching the previous hard-coded ``.cuda()``).
        image_shape: per-sample shape used in ``view`` (default ``(1, 28, 28)``).
        progress: wrap the loader in a ``tqdm`` progress bar.

    Returns:
        ``(embeddings, targets)`` numpy arrays, concatenated along the first axis.

    Gradients are not disabled here; callers wrap the call in ``torch.no_grad()``.
    """
    batches = tqdm.tqdm(train_dataloader) if progress else train_dataloader
    embeddings = []
    targets = []
    for x, y in batches:
        x = x.view(-1, *image_shape).float().to(device)
        # Out-of-place divide: when x is already a float tensor on `device`, `.float()` and
        # `.to()` return the same storage and an in-place `/=` would rescale the loader's batch.
        x = x / 255
        z = transform(encoder(x))
        embeddings.append(z.detach().cpu().numpy())
        targets.append(y.detach().cpu().numpy())
    return np.concatenate(embeddings), np.concatenate(targets)


In [None]:
from tsv.data import FMNISTDataModule
from tsv.natvamp import DLSV
import os
import torch

# FMNIST training loader; re-encoded below by generate_embedding.
data_module = FMNISTDataModule(batch_size=256, num_workers=4, persistent_workers=False)
data_module.setup('fit')
data_loader = data_module.train_dataloader()

In [None]:
# Load previously saved embeddings and labels for this run.
# NOTE(review): `embedding` and `targets` are overwritten further down by generate_embedding,
# so these loaded arrays only feed the "Loaded embeddings" plot.
MODEL = "dlsv"
RUN_NAME = "fmnist-stochastic"
SAVE_DIR = os.path.join("embeddings", MODEL, RUN_NAME)
embedding = np.load(os.path.join(SAVE_DIR, "embeddings.npy"))
targets = np.load(os.path.join(SAVE_DIR, "labels.npy"))

In [None]:
print(os.path.join(SAVE_DIR, "embeddings.npy"))

In [None]:
# First two columns of the loaded embedding, coloured by the loaded labels.
fig, ax = plt.subplots(1, 1, figsize=(14, 10))
ax.set_title('Loaded embeddings', fontsize=18)
ax.scatter(embedding[:, 0], embedding[:, 1], c=targets, cmap='tab20', s=0.1)
fig.show()

In [None]:
# Re-point SAVE_DIR at the training logs (it is reused later by the np.save cell) and pick a checkpoint.
MODEL = "dlsv"
RUN_NAME = "fmnist-stochastic"
SAVE_DIR = os.path.join("logs", MODEL, RUN_NAME)
CHKPT_DIR = os.path.join(SAVE_DIR, "checkpoints")
# NOTE(review): os.listdir order is arbitrary; with several checkpoints this picks an unspecified one.
chkpt_path = os.path.join(CHKPT_DIR, os.listdir(CHKPT_DIR)[0])
print(chkpt_path)

with torch.no_grad():
    model = DLSV.load_from_checkpoint(chkpt_path).cuda()
    # NOTE(review): encoding runs in train mode — presumably intentional for the "stochastic" run; confirm.
    model.train()
    # Encoder = first element of model.q_z's output (presumably the posterior mean — confirm); no extra transform.
    embedding, targets = generate_embedding(data_loader, lambda x: model.q_z(x)[0], lambda x: x)

In [None]:
# 2-D projection used for the scatter plots; the embedding is used as-is when it is already 2-D.
# NOTE(review): no random_state is set, so the projection differs between runs.
visual_embedding = umap.UMAP(min_dist=0.0, n_neighbors=5, repulsion_strength=1).fit_transform(embedding) if embedding.shape[-1] > 2 else embedding

In [None]:
# UMAP re-embedding with the same number of components as the latent embedding; an alternative clustering input.
umap_embedding = umap.UMAP(min_dist=0.0, n_neighbors=5, n_components=embedding.shape[-1]).fit_transform(embedding)

In [None]:
from PCH.utils import make_targets_from_sets
# Coarse 5-group targets built from sets of class ids (presumably each set becomes one label — confirm).
targets_five = make_targets_from_sets(targets, [{0, 3}, {1}, {2, 4, 6}, {8}, {5, 7, 9}])

In [None]:
# Pass the model's pseudo-inputs through q_z; the first two outputs are kept as mu and logvar.
mu, logvar, *_ = model.q_z(model.get_pseudos())

In [None]:
from tsv.natvamp import log_normal_diag

# Log-likelihood of every sample under every pseudo-input component: unsqueeze(1) / unsqueeze(0)
# broadcast samples against components and .sum(-1) reduces the latent axis.
# NOTE(review): assumes embedding is (N, D) and mu/logvar are (K, D), giving an (N, K) result — confirm.
log_likelihoods = log_normal_diag(torch.tensor(embedding, device=mu.device).unsqueeze(1), mu.unsqueeze(0), logvar.unsqueeze(0)).sum(-1)

In [None]:
# Assign each sample to its highest-likelihood component.
MLE_labels = log_likelihoods.argmax(-1).cpu().numpy()

In [None]:
# Agreement of the component assignment with the fine-grained targets...
adjusted_rand_score(targets, MLE_labels)

In [None]:
# ...and with the coarse 5-group targets.
adjusted_rand_score(targets_five, MLE_labels)

In [None]:
# Ground-truth classes over the 2-D visual projection.
fig, ax = plt.subplots(1, 1, figsize=(14, 10))
ax.set_title('NVP embeddings', fontsize=18)
ax.scatter(visual_embedding[:, 0], visual_embedding[:, 1], c=targets, cmap='tab20', s=0.1)
fig.show()

In [None]:
# NOTE(review): umap_embedding has embedding.shape[-1] components, so this scatters only its
# first two components rather than a dedicated 2-D projection.
fig, ax = plt.subplots(1, 1, figsize=(14, 10))
ax.set_title('UMAP projection of the NVP embeddings', fontsize=18)
ax.scatter(umap_embedding[:, 0], umap_embedding[:, 1], c=targets, cmap='tab20', s=.1)
fig.show()

In [None]:
# Unconstrained baseline on the latent embedding. This `hdb` (constraint_mode="t-synthetic")
# is reused by later cells for the constrained fits.
hdb = HDBSCAN(min_cluster_size=500, constraint_mode="t-synthetic")
labels = hdb.fit_predict(embedding)

In [None]:
# ARI before and after augment_labels (presumably reassigns HDBSCAN noise points — confirm in PCH.utils).
print(adjusted_rand_score(targets, labels))
augmented_labels = augment_labels(embedding, labels)
print(adjusted_rand_score(targets, augmented_labels))

In [None]:
# Augmented baseline clustering over the visual projection.
fig, ax = plt.subplots(1, 1, figsize=(14, 10))
ax.set_title('NVP', fontsize=18)
ax.scatter(visual_embedding[:, 0], visual_embedding[:, 1], c=augmented_labels, cmap='tab20', s=.1)
fig.show()

In [None]:
# Cluster the UMAP re-embedding with the same HDBSCAN object and score it.
# Scores and augmentation use `umap_labels` (previously the stale NVP `labels` were re-scored here).
umap_labels = hdb.fit_predict(umap_embedding)
print(adjusted_rand_score(targets, umap_labels))
# Separate name so the NVP `augmented_labels`, which seed the constraints below, are not overwritten.
augmented_umap_labels = augment_labels(embedding, umap_labels)
print(adjusted_rand_score(targets, augmented_umap_labels))
fig, ax = plt.subplots(1, 1, figsize=(14, 10))
ax.set_title('UMAP', fontsize=18)
ax.scatter(umap_embedding[:, 0], umap_embedding[:, 1], c=augmented_umap_labels, cmap='tab20', s=.1)
fig.show()

In [None]:
# Round 1: sample constraints with constraints_from_estimate, which is given the current estimate
# (augmented_labels) and the ground truth (presumably acting as a simulated oracle — confirm in PCH.utils).
constraints = (
    constraints_from_estimate(
        visual_embedding,
        labels=augmented_labels,
        ground_truth=targets,
        n_samples=20,
        n_subsample=10000,
    )
)
plot_constraints(visual_embedding, augmented_labels, constraints)

In [None]:
# Constrained fit. NOTE(review): this fits on the 2-D visual_embedding, whereas the unconstrained
# baseline `labels` was fit on `embedding`, so the two ARIs are not on the same input space.
constrained_labels = hdb.fit(visual_embedding, constraints=constraints).labels_

In [None]:
print(adjusted_rand_score(targets, constrained_labels))
augmented_constrained_labels = augment_labels(embedding, constrained_labels)
print(adjusted_rand_score(targets, augmented_constrained_labels))

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(14, 10))
ax.set_title('NVP + C', fontsize=18)
ax.scatter(visual_embedding[:, 0], visual_embedding[:, 1], c=augmented_constrained_labels, cmap='tab20', s=.1)
fig.show()

In [None]:
# Constrained HDBSCAN on the UMAP re-embedding, using the current `constraints`.
constrained_umap_labels = hdb.fit(umap_embedding, constraints=constraints).labels_
print(adjusted_rand_score(targets, constrained_umap_labels))
augmented_constrained_umap_labels = augment_labels(embedding, constrained_umap_labels)
print(adjusted_rand_score(targets, augmented_constrained_umap_labels))
fig, ax = plt.subplots(1, 1, figsize=(14, 10))
# Title matches the "UMAP + C" naming used elsewhere (it previously said 'NVP').
ax.set_title('UMAP + C', fontsize=18)
ax.scatter(umap_embedding[:, 0], umap_embedding[:, 1], c=augmented_constrained_umap_labels, cmap='tab20', s=.1)
fig.show()

In [None]:
constraints = merge_constraints(
    constraints,
    constraints_from_estimate(
        visual_embedding,
        labels=augmented_constrained_labels,
        ground_truth=targets,
        n_samples=20,
        n_subsample=10000,
    ),
)
plot_constraints(visual_embedding, augmented_constrained_labels, constraints)

In [None]:
constrained_second_pass_labels = hdb.fit(visual_embedding, constraints=constraints).labels_

fig, ax = plt.subplots(1, 1, figsize=(14, 10))
ax.set_title('NVP + 2xC', fontsize=18)
ax.scatter(visual_embedding[:, 0], visual_embedding[:, 1], c=constrained_second_pass_labels, cmap='tab20', s=.1)
fig.show()


In [None]:
print(adjusted_rand_score(labels, constrained_second_pass_labels))
augmented_constrained_second_pass_labels = augment_labels(embedding, constrained_second_pass_labels)
print(adjusted_rand_score(labels, augmented_constrained_second_pass_labels))

In [None]:
constraints = merge_constraints(
    constraints,
    constraints_from_estimate(
        visual_embedding,
        labels=augmented_constrained_labels,
        ground_truth=targets,
        n_samples=20,
        n_subsample=10000,
    ),
)
plot_constraints(visual_embedding, augmented_constrained_second_pass_labels, constraints)

In [None]:
constrained_third_pass_labels = hdb.fit(visual_embedding, constraints=constraints).labels_

fig, ax = plt.subplots(1, 1, figsize=(14, 10))
ax.set_title('NVP + 3xC', fontsize=18)
ax.scatter(visual_embedding[:, 0], visual_embedding[:, 1], c=constrained_third_pass_labels, cmap='tab20', s=.1)
fig.show()

In [None]:
print(adjusted_rand_score(labels, constrained_third_pass_labels))
augmented_constrained_third_pass_labels = augment_labels(embedding, constrained_third_pass_labels)
print(adjusted_rand_score(labels, augmented_constrained_third_pass_labels))

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(14, 10))
ax.set_title('Augmented NVP + C', fontsize=18)
ax.scatter(visual_embedding[:, 0], visual_embedding[:, 1], c=augmented_constrained_labels, cmap='tab20', s=.1)
fig.show()

In [None]:
def make_plots(
    embedding,
    umap_embedding,
    targets,
    labels,
    constrained_labels,
    umap_labels,
    constrained_umap_labels,
    plot_coords=None,
):
    """Side-by-side panels for four clusterings plus the ground truth.

    Panels are "NVP", "NVP + C", "UMAP", "UMAP + C" and "Ground Truth". The four
    clusterings are passed through ``augment_labels(embedding, ...)`` before plotting;
    ``targets`` is plotted unchanged.

    Args:
        embedding: array handed to ``augment_labels`` for the four clusterings.
        umap_embedding: unused; kept so existing calls keep working.
        targets: ground-truth labels for the last panel.
        labels, constrained_labels, umap_labels, constrained_umap_labels: label arrays
            for the four clustering panels, in that order.
        plot_coords: 2-D coordinates to scatter. Defaults to the notebook-level
            ``visual_embedding``, which this function always used before.
    """
    if plot_coords is None:
        plot_coords = visual_embedding
    # (title, labels, whether to augment)
    panels = (
        ("NVP", labels, True),
        ("NVP + C", constrained_labels, True),
        ("UMAP", umap_labels, True),
        ("UMAP + C", constrained_umap_labels, True),
        ("Ground Truth", targets, False),
    )
    # A single row: the previous 2x5 grid left its entire second row as empty axes.
    fig, axes = plt.subplots(1, len(panels), figsize=(20, 7))
    for ax, (title, panel_labels, augment) in zip(axes, panels):
        if augment:
            panel_labels = augment_labels(embedding, panel_labels)
        ax.set_title(title, fontsize=18)
        ax.scatter(plot_coords[:, 0], plot_coords[:, 1], c=panel_labels, cmap='tab20', s=.1)
    fig.show()

In [None]:
fig, ax = plt.subplots(2, 1, figsize=(10, 14))
ax[0].set_title('Augmented HDBSCAN Labels using NVP embedding + constraints', fontsize=18)
ax[0].scatter(umap_embedding[:, 0], umap_embedding[:, 1], c=augmented_constrained_labels, cmap='tab20', s=.1)
ax[1].set_title('Ground Truth', fontsize=18)
ax[1].scatter(umap_embedding[:, 0], umap_embedding[:, 1], c=targets, cmap='tab20', s=.1)
fig.show()

print(adjusted_rand_score(labels, augmented_constrained_labels))

In [None]:
hdb.fit(umap_embedding, constraints=constraints)
constrained_umap_labels = hdb.labels_

fig, ax = plt.subplots(1, 1, figsize=(14, 10))
ax.set_title('UMAP + C', fontsize=18)
ax.scatter(embedding[:, 0], embedding[:, 1], c=constrained_umap_labels, cmap='tab20', s=.1)
fig.show()
print(adjusted_rand_score(labels, constrained_umap_labels))

In [None]:
augmented_constrained_umap_labels = augment_labels(embedding, constrained_umap_labels)
print(adjusted_rand_score(labels, constrained_umap_labels))

fig, ax = plt.subplots(1, 1, figsize=(14, 10))
ax.set_title('UMAP + C', fontsize=18)
ax.scatter(embedding[:, 0], embedding[:, 1], c=augmented_constrained_umap_labels, cmap='tab20', s=.1)
fig.show()
print(adjusted_rand_score(labels, augmented_constrained_umap_labels))

In [None]:
def eval_constraints(
    embeddings,
    names,
    targets,
    constraints,
    min_cluster_size=500,
    constraint_mode="t-synthetic",
):
    """Cluster each embedding with and without constraints and print ARI against ``targets``.

    For every ``(embedding, name)`` pair this fits an unconstrained HDBSCAN and a constrained
    one (``constraint_mode``), passes both label sets through ``augment_labels`` and prints
    ``"<name> | ARI"`` then ``"<name> + C | ARI"``.

    Args:
        embeddings: sequence of embeddings to cluster.
        names: display name for each embedding.
        targets: ground-truth labels used for scoring.
        constraints: constraint dict passed to the constrained fit.
        min_cluster_size: HDBSCAN ``min_cluster_size`` for both models (default 500).
        constraint_mode: ``constraint_mode`` of the constrained model (default "t-synthetic").

    Returns:
        List of augmented label arrays ordered ``[e0, e0 + C, e1, e1 + C, ...]``.

    Raises:
        ValueError: if ``embeddings`` and ``names`` differ in length (``zip`` would otherwise
            silently drop the extras).
    """
    embeddings = list(embeddings)
    names = list(names)
    if len(embeddings) != len(names):
        raise ValueError(
            f"embeddings ({len(embeddings)}) and names ({len(names)}) must have the same length"
        )

    hdb_plain = HDBSCAN(min_cluster_size=min_cluster_size)
    hdb_constrained = HDBSCAN(min_cluster_size=min_cluster_size, constraint_mode=constraint_mode)
    labels = []
    for _embedding, name in zip(embeddings, names):
        hdb_plain.fit(_embedding)
        labels.append(augment_labels(_embedding, hdb_plain.labels_))
        print(f"{name} | {adjusted_rand_score(targets, labels[-1]):.2f}")

        hdb_constrained.fit(_embedding, constraints=constraints)
        labels.append(augment_labels(_embedding, hdb_constrained.labels_))
        print(f"{name} + C | {adjusted_rand_score(targets, labels[-1]):.2f}")

    return labels

In [None]:
all_labels = eval_constraints((embedding, umap_embedding), ("NVP", "UMAP"), targets, constraints)

In [None]:
from sklearn.metrics import normalized_mutual_info_score, fowlkes_mallows_score

for labels, name in zip(all_labels, ["NVP", "NVP + C", "UMAP", "UMAP + C"]):
    print(name)
    for scorer in [adjusted_rand_score, normalized_mutual_info_score, fowlkes_mallows_score]:
        print(f"{scorer.__name__} | {scorer(targets, labels):.2f}")
    print("\n")

In [None]:
make_plots(visual_embedding, umap_embedding, targets, *all_labels)

In [None]:
make_plots(embedding, umap_embedding, targets, labels, constrained_labels, umap_labels, constrained_umap_labels)

In [None]:
constraints = merge_constraints(
    constraints_from_estimate(
        embedding,
        labels=augmented_labels,
        ground_truth=labels,
        n_samples=100,
        n_subsample=10000,
    ),
    constraints,
)
plot_constraints(embedding, labels, constraints)
plot_constraints(embedding, augmented_labels, constraints)

In [None]:
hdb = HDBSCAN(min_cluster_size=500, constraint_mode="t-synthetic")
hdb.fit(embedding, constraints=constraints)
constrained_labels.append(hdb.labels_)

fig, ax = plt.subplots(1, 1, figsize=(14, 10))
ax.set_title('HDBSCAN Labels using visual embedding + 2x constraints', fontsize=18)
ax.scatter(embedding[:, 0], embedding[:, 1], c=constrained_labels[1], cmap='tab20', s=.1)
fig.show()
print(adjusted_rand_score(labels, constrained_labels[1]))

In [None]:
constraints = merge_constraints(
    constraints_from_estimate(
        embedding,
        labels=constrained_labels[1],
        ground_truth=labels,
        n_samples=1000,
        n_subsample=20000,
    ),
    constraints,
)
plot_constraints(embedding, labels, constraints)
plot_constraints(embedding, constrained_labels[0], constraints)

In [None]:
hdb = HDBSCAN(min_cluster_size=500, constraint_mode="t-synthetic")
hdb.fit(embedding, constraints=constraints)
# constrained_labels.append(hdb.labels_)
clabel = hdb.labels_

fig, ax = plt.subplots(1, 1, figsize=(14, 10))
ax.set_title('HDBSCAN Labels using visual embedding + 3x constraints', fontsize=18)
ax.scatter(embedding[:, 0], embedding[:, 1], c=clabel, cmap='tab20', s=.1)
fig.show()
print(adjusted_rand_score(labels, clabel))

In [None]:
# all_labels[-1] is the "UMAP + C" result from eval_constraints (last of NVP, NVP + C, UMAP, UMAP + C),
# shown next to the ground truth.
# NOTE(review): "16-dim" and "1000 constraints" in the title are hard-coded, not computed here — keep in sync.
fig, ax = plt.subplots(1, 2, figsize=(20, 10))
ax[0].set_title('16-dim LSV + UMAP + 1000 constraints', fontsize=18)
ax[0].scatter(visual_embedding[:, 0], visual_embedding[:, 1], c=all_labels[-1], cmap='tab20', s=.1)
ax[1].set_title('Ground Truth', fontsize=18)
ax[1].scatter(visual_embedding[:, 0], visual_embedding[:, 1], c=targets, cmap='tab20', s=.1)
fig.show()
print(f"ARI = {adjusted_rand_score(targets, all_labels[-1]):.2f}")

In [None]:
from sklearn.metrics import normalized_mutual_info_score, fowlkes_mallows_score


def iter_eval(
    fit_embedding,
    augment_embedding,
    constraint_embedding,
    visual_embedding,
    targets,
    n_iter=5,
    constraints_per_iteration=20,
    n_subsample=20000,
    min_cluster_size=500,
    title_prefix="NVP + C",
):
    """Iteratively refit constrained HDBSCAN, adding oracle constraints after each fit.

    Iteration ``i`` fits on ``fit_embedding`` with the constraints gathered so far (none on the
    first pass). It augments the labels with ``augment_embedding``, prints ARI / NMI / FMI
    against ``targets`` and scatters the labels over ``visual_embedding[:, :2]``. It then samples
    ``constraints_per_iteration`` new constraints from ``constraint_embedding`` for the next
    iteration.

    Args:
        fit_embedding: input to HDBSCAN.
        augment_embedding: passed to ``augment_labels``.
        constraint_embedding: passed to ``constraints_from_estimate``.
        visual_embedding: coordinates for the scatter plots (columns 0 and 1).
        targets: ground-truth labels for scoring and constraint sampling.
        n_iter: number of fits.
        constraints_per_iteration: ``n_samples`` per constraint round.
        n_subsample: ``n_subsample`` per constraint round (default 20000).
        min_cluster_size: HDBSCAN ``min_cluster_size`` (default 500).
        title_prefix: plot title prefix (default "NVP + C").

    Returns:
        List with the augmented labels of every iteration (previously discarded).
    """
    iter_labels = []
    hdb_iter = HDBSCAN(min_cluster_size=min_cluster_size, constraint_mode="t-synthetic")
    constraints = {}
    for i in range(n_iter):
        hdb_iter.fit(fit_embedding, constraints=constraints)
        current_labels = augment_labels(augment_embedding, hdb_iter.labels_)
        iter_labels.append(current_labels)
        print(
            f"Iteration {i + 1} | ARI = {adjusted_rand_score(targets, current_labels):.2f}"
            f" | NMI = {normalized_mutual_info_score(targets, current_labels):.2f}"
            f" | FMI = {fowlkes_mallows_score(targets, current_labels):.2f}"
        )
        fig, ax = plt.subplots(1, 1, figsize=(14, 10))
        ax.set_title(f"{title_prefix} | Iteration {i + 1}", fontsize=18)
        ax.scatter(
            visual_embedding[:, 0],
            visual_embedding[:, 1],
            c=current_labels,
            cmap="tab20",
            s=0.1,
        )
        fig.show()
        # Constraints sampled after the final fit would never be used, so skip that query.
        if i + 1 < n_iter:
            constraints = merge_constraints(
                constraints_from_estimate(
                    constraint_embedding,
                    labels=current_labels,
                    ground_truth=targets,
                    n_samples=constraints_per_iteration,
                    n_subsample=n_subsample,
                ),
                constraints,
            )
    return iter_labels

In [None]:
iter_eval(
    fit_embedding=umap_embedding,
    augment_embedding=umap_embedding,
    constraint_embedding=umap_embedding,
    visual_embedding=umap_embedding,
    targets=targets,
    n_iter=5,
    constraints_per_iteration=20,
)

In [None]:
augmented_labels = augment_labels(embedding, clabel)
print(adjusted_rand_score(labels, augmented_labels))

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(14, 10))
ax.set_title('HDBSCAN Labels using visual embedding + 3x constraints', fontsize=18)
ax.scatter(embedding[:, 0], embedding[:, 1], c=augmented_labels, cmap='tab20', s=.1)
fig.show()

In [None]:
np.save(SAVE_DIR + "constrained_labels.npy", augmented_labels)
print(f"Saving to {SAVE_DIR}constrained_labels.npy")

In [None]:
from torchvision.datasets import FashionMNIST
from torchvision import transforms
from tsv.natvamp import NVPW
import torch
# Control experiment: FMNIST training split loaded directly via torchvision.
fmnist_train = FashionMNIST(
    "FMNIST",
    train=True,
    download=True,
)
# Reshape to (N, 1, 28, 28) float32 and scale pixel values by 1/255.
control_data = fmnist_train.data.view(-1, 1, 28, 28).float().numpy()
control_data /= 255
control_labels = fmnist_train.targets.numpy()

In [None]:
model = NVPW.load_from_checkpoint("/home/zain/code/two-stage/logs/nvpw/fmnist-pseudodiverge/checkpoints/epoch=99-step=5900.ckpt")
torch.set_grad_enabled(False)

In [None]:
torch.set_grad_enabled(False)
model.train().cuda()

In [None]:
# Reconstruction sanity check on two control images (indices idx and idx + 1).
idx = 1
raw = control_data[idx:idx+2]
# Drop the channel axis for plotting.
x = raw.squeeze(1)
with torch.no_grad():
    # NOTE(review): assumes the first element of the model output is the reconstruction — confirm against NVPW.
    x_hat = model(torch.tensor(raw).cuda())[0].cpu().view(-1, 28, 28).numpy()
# Only the first of the two images and its reconstruction are shown.
plt.imshow(x[0])
plt.show()
plt.imshow(x_hat[0])
plt.show()
# Mean squared error over both images.
torch.nn.functional.mse_loss(torch.tensor(x), torch.tensor(x_hat), reduction="mean")

In [None]:
# Show every pseudo-input as a 28x28 image in a near-square grid.
n_pseudos = model.num_pseudos
# Just enough rows for all pseudo-inputs. The previous "width - 1 rows when divisible" rule
# lost a row for perfect squares (9 pseudos -> 2x3 = 6 axes) and added a blank row otherwise (10 -> 4x4).
n_cols = max(1, int(np.ceil(np.sqrt(n_pseudos))))
n_rows = max(1, int(np.ceil(n_pseudos / n_cols)))
# squeeze=False keeps `axes` 2-D, so .flatten() also works for a 1x1 or 1xN grid.
fig, axes = plt.subplots(n_rows, n_cols, figsize=(14, 10), squeeze=False)

for i, _ax in enumerate(axes.flatten()):
    if i >= n_pseudos:
        # Hide leftover cells of the grid.
        _ax.axis("off")
        continue
    pseudo = model.pseudos[i].detach().cpu().view(28, 28).numpy()
    _ax.imshow(pseudo)

In [None]:
# NOTE(review): mutates `model` in place; re-running this cell presumably merges another pair — confirm merge_pseudos semantics.
model.merge_pseudos(0, 1)

In [None]:
visual = umap.UMAP(n_components=2, min_dist=0).fit_transform(control_data)

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(14, 10))
ax.set_title('UMAP projection of the NVP embeddings with estimated labels', fontsize=18)
ax.scatter(visual[:, 0], visual[:, 1], c=fmnist_train.targets, cmap='tab20', s=.1)
fig.show()

In [None]:
# Unconstrained HDBSCAN on the 2-D raw-pixel UMAP (`visual`).
hdb = HDBSCAN(min_cluster_size=500, min_samples=5)
estimated_labels = hdb.fit_predict(visual)

In [None]:
# NOTE(review): `visual` comes from raw pixels, not NVP embeddings; the title text is stale.
fig, ax = plt.subplots(1, 1, figsize=(14, 10))
ax.set_title('UMAP projection of the NVP embeddings with estimated labels', fontsize=18)
ax.scatter(visual[:, 0], visual[:, 1], c=estimated_labels, cmap='tab20', s=.1)
fig.show()

In [None]:
print(adjusted_rand_score(fmnist_train.targets, estimated_labels))

In [None]:
# Constrained refit with 100 sampled constraints.
# NOTE(review): this rebinds `constraints`, discarding the constraint set built for the DLSV embedding;
# constraints_from_estimate receives the (N, 1, 28, 28) control_data array here.
hdb = HDBSCAN(min_cluster_size=500, min_samples=5, constraint_mode="synthetic")
constraints = constraints_from_estimate(control_data, estimated_labels, fmnist_train.targets, 100)
hdb.fit(visual, constraints=constraints)
second_estimated_labels = hdb.labels_

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(14, 10))
ax.set_title('UMAP projection of the NVP embeddings with estimated labels', fontsize=18)
ax.scatter(visual[:, 0], visual[:, 1], c=second_estimated_labels, cmap='tab20', s=.1)
fig.show()

In [None]:
print(adjusted_rand_score(fmnist_train.targets, second_estimated_labels))

In [None]:
# Add a further round of constraints sampled from the second estimate. merge_constraints is the
# same helper the DLSV experiment uses, and unlike manual .extend it copes with a key missing from `constraints`.
new_constraints = constraints_from_estimate(control_data, second_estimated_labels, fmnist_train.targets, 100)
constraints = merge_constraints(constraints, new_constraints)

In [None]:
# Third control fit, using the accumulated control constraints.
hdb = HDBSCAN(min_cluster_size=500, min_samples=5, constraint_mode="synthetic")
hdb.fit(visual, constraints=constraints)
third_estimated_labels = hdb.labels_

In [None]:
# NOTE(review): `visual` comes from raw pixels, not NVP embeddings; the title text is stale.
fig, ax = plt.subplots(1, 1, figsize=(14, 10))
ax.set_title('UMAP projection of the NVP embeddings with estimated labels', fontsize=18)
ax.scatter(visual[:, 0], visual[:, 1], c=third_estimated_labels, cmap='tab20', s=.1)
fig.show()

In [None]:
print(adjusted_rand_score(fmnist_train.targets, third_estimated_labels))

In [None]:
control_embedding = umap.UMAP(n_components=16, min_dist=0).fit_transform(fmnist_train.data.view(-1, 28*28))

In [None]:
import numpy as np
rng = np.random.RandomState(42)
sample_idxs = rng.choice(len(control_embedding), data.shape[0], replace=False)
control_embedding = control_embedding[sample_idxs]


In [None]:
hdb = HDBSCAN(min_cluster_size=50)
estimated_labels_control = hdb.fit_predict(control_embedding)

fig, ax = plt.subplots(1, 1, figsize=(14, 10))
ax.set_title('UMAP projection of the NVP embeddings with control labels', fontsize=18)
ax.scatter(control_visual[:, 0], control_visual[:, 1], c=estimated_labels_control, cmap='tab20', s=.1)
fig.show()

In [None]:
print(adjusted_rand_score(fmnist_train.targets, estimated_labels))