In [None]:
from torchvision.datasets import FashionMNIST
from tsv.natvamp import DLSV
import torch
import numpy as np
import matplotlib.pyplot as plt
from itertools import product
import seaborn as sns

# Raw FashionMNIST training split; its images are used below as "control"
# images when looking up training examples by index.
fmnist_train = FashionMNIST(
    "FMNIST",
    train=True,
    download=True,
)
# (N, 1, 28, 28) float32 images divided by 255 (uint8 pixels -> [0, 1]).
control_data = fmnist_train.data.view(-1, 1, 28, 28).float().numpy()
control_data /= 255
control_labels = fmnist_train.targets.numpy()

In [None]:
# Earlier ModularNVPW checkpoint, kept only for reference:
# model = ModularNVPW.load_from_checkpoint("/home/zain/code/two-stage/logs/nvpw/fmnist-pseudodiverge/checkpoints/epoch=99-step=5900.ckpt")
# NOTE(review): hardcoded absolute local path; consider a configurable CHECKPOINT_PATH.
model = DLSV.load_from_checkpoint("/home/zain/code/two-stage/logs/dlsv/fmnist-stochastic/checkpoints/epoch=142-step=33605.ckpt")
# The rest of the notebook is inference only, so autograd is disabled globally.
torch.set_grad_enabled(False)

In [None]:
# Put the model in eval mode and move it to the GPU (all later batches use .cuda()).
model.eval().cuda()

In [None]:
def plot_pseudos_and_representatives(pseudos, representatives, figsize=(8, 28)):
    """Plot each pseudo-input next to its representative training images.

    Row ``i`` shows ``pseudos[i, 0]`` in the first column, followed by
    ``representatives[0, i]``, ``representatives[1, i]``, ... in the remaining
    columns.

    Args:
        pseudos: tensor/array indexed as ``pseudos[i, 0]`` -> 2-D image; its
            first dimension is the number of rows.
        representatives: tensor/array indexed as ``representatives[j, i]`` ->
            2-D image; its first dimension is the number of representative
            columns and its second must cover every pseudo row.
        figsize: matplotlib figure size (default keeps the original (8, 28)).

    Raises:
        ValueError: if ``representatives`` has fewer pseudo columns than
            ``pseudos`` has rows (rows would otherwise be misaligned).
    """
    num_pseudos = pseudos.shape[0]
    num_reps = representatives.shape[0]
    if representatives.shape[1] < num_pseudos:
        raise ValueError(
            f"representatives covers {representatives.shape[1]} pseudos "
            f"but {num_pseudos} pseudos were given"
        )
    # squeeze=False keeps `axs` 2-D even for a single row or a single column,
    # so row/column indexing works for every input size.
    fig, axs = plt.subplots(
        num_pseudos, num_reps + 1, figsize=figsize, squeeze=False
    )
    for i, row in enumerate(axs):
        row[0].imshow(pseudos[i, 0])
        row[0].axis("off")
        for j in range(num_reps):
            row[j + 1].imshow(representatives[j, i])
            row[j + 1].axis("off")
    fig.tight_layout()
    fig.show()

In [None]:
def plot_pseudos():
    """Show every pseudo-input of the notebook-global ``model`` as a 28x28 image.

    The grid is ``ceil(sqrt(n))`` columns wide with just enough rows for all
    ``n = model.num_pseudos`` images; grid cells without an image are hidden.
    """
    num_pseudos = model.num_pseudos
    if num_pseudos == 0:
        return
    width = int(np.ceil(np.sqrt(num_pseudos)))
    # Enough rows for every pseudo: equals n // width when it divides evenly,
    # and avoids an entirely empty trailing row otherwise.
    length = int(np.ceil(num_pseudos / width))
    # squeeze=False keeps `axes` a 2-D array even for a 1x1 grid, so
    # .flatten() always works.
    fig, axes = plt.subplots(length, width, figsize=(14, 10), squeeze=False)
    flat_axes = axes.flatten()

    for i in range(num_pseudos):
        pseudo = model.pseudos[i].detach().cpu().view(28, 28).numpy()
        flat_axes[i].imshow(pseudo)
    for unused_ax in flat_axes[num_pseudos:]:
        unused_ax.axis("off")
    fig.show()

In [None]:
import umap
from matplotlib.patches import Ellipse
import tqdm

def generate_embedding(train_dataloader, encoder, transform, device="cuda"):
    """Encode every batch of ``train_dataloader`` and collect embeddings and labels.

    Args:
        train_dataloader: iterable of ``(x, y)`` batches. ``x`` is reshaped to
            ``(-1, 1, 28, 28)``, cast to float and divided by 255 (assumes raw
            0-255 pixel values; TODO confirm for the data module in use).
        encoder: callable mapping the image batch to a tensor.
        transform: callable applied to the encoder output before collection.
        device: device the image batch is moved to before encoding. The default
            ``"cuda"`` matches the previous hardcoded behaviour.

    Returns:
        ``(embeddings, targets)`` NumPy arrays, each concatenated along axis 0.
    """
    embeddings = []
    targets = []
    for x, y in tqdm.tqdm(train_dataloader):
        x = x.view(-1, 1, 28, 28).float().to(device) / 255
        z = transform(encoder(x))
        embeddings.append(z.detach().cpu().numpy())
        targets.append(y.detach().cpu().numpy())
    return np.concatenate(embeddings), np.concatenate(targets)

def plot_embedding(embeddings, targets, model):
    """Scatter-plot an embedding coloured by label and overlay the model's pseudo-inputs.

    If ``embeddings`` has more than two columns, it is reduced to 2-D with UMAP
    (``min_dist=0``). The pseudo-input posterior means are then mapped with the
    same fitted reducer. Otherwise the embedding is plotted as-is, and each
    pseudo-input is also drawn as an ellipse sized by its posterior std.

    Args:
        embeddings: ``(N, D)`` array of latent codes.
        targets: length-``N`` labels used for colouring (must not be None).
        model: if it has a ``pseudos`` attribute, the pseudo posterior
            ``(mu, logvar, ...)`` comes from ``model.q_z(model.get_pseudos())``.
    """
    assert targets is not None
    fig, axes = plt.subplots(1, 1, figsize=(10, 10))

    # Both branches need the pseudo posteriors, so compute them up front.
    # Previously the UMAP branch read `mu_p` before it was ever assigned.
    mu_p = logvar_p = None
    if hasattr(model, "pseudos"):
        mu_p, logvar_p, *_ = model.q_z(model.get_pseudos())
        mu_p = mu_p.detach().cpu().numpy()
        if logvar_p is not None:
            logvar_p = logvar_p.detach().cpu().numpy()

    embedded_pseudos = None
    if embeddings.shape[-1] > 2:
        reducer = umap.UMAP(min_dist=0)
        visual_embedding = reducer.fit_transform(embeddings)
        if mu_p is not None:
            embedded_pseudos = reducer.transform(mu_p)
    else:
        visual_embedding = embeddings
        if mu_p is not None:
            embedded_pseudos = mu_p
            if logvar_p is not None:
                std_p = np.exp(0.5 * logvar_p)
                for embedded_pseudo, std in zip(embedded_pseudos, std_p):
                    print(f"Making ellipse at {embedded_pseudo} with std {std}")
                    axes.add_patch(
                        Ellipse(
                            xy=embedded_pseudo,
                            width=3 * std[0],
                            height=3 * std[1],
                            edgecolor="r",
                            fc="grey",
                            lw=2,
                        )
                    )

    axes.scatter(
        visual_embedding[:, 0],
        visual_embedding[:, 1],
        c=targets,
        s=0.75,
        cmap="tab10",
    )
    if embedded_pseudos is not None:
        axes.scatter(
            embedded_pseudos[:, 0],
            embedded_pseudos[:, 1],
            c="black",
            s=50,
            marker="x",
        )
    fig.show()

In [None]:
# Pseudo-inputs as loaded from the checkpoint (before projection).
plot_pseudos()

In [None]:
from tsv.data import FMNISTDataModule

# Encode the training split with the first output of q_z (the `mu` term),
# using an identity transform. Plot it, with the pseudo-inputs, before projection.
data_module = FMNISTDataModule(batch_size=256, num_workers=4, persistent_workers=False)
data_module.setup('fit')
data_loader = data_module.train_dataloader()
embeddings, targets = generate_embedding(data_loader, lambda x: model.q_z(x)[0], lambda x: x)
plot_embedding(embeddings, targets, model=model)

In [None]:
# Apply the model's pseudo-input projection (semantics defined on the model class).
model.project_pseudos()

In [None]:
# Re-embed the training set after projection and plot it. Previously the
# (embeddings, targets) tuple was discarded and dumped as raw cell output.
embeddings, targets = generate_embedding(data_loader, lambda x: model.q_z(x)[0], lambda x: x)
plot_embedding(embeddings, targets, model=model)

In [None]:
# Pseudo-inputs after projection.
plot_pseudos()

In [None]:
def calculate_divergences(model, indices):
    """Pairwise divergences between the encoder posteriors of pseudo-inputs.

    For positions ``a`` and ``b`` in ``indices``, entry ``[a, b]`` of the result
    is ``model.general_kl(*q_z(p_a)[:2], *q_z(p_b)[:2])``, where
    ``p_k = model.get_pseudos()[indices[k]]``. Pairs whose pseudo ids are equal
    (including the diagonal) stay ``inf``, so a pseudo is never its own
    nearest neighbour.

    Args:
        model: object exposing ``get_pseudos()``, ``q_z(x)`` (returning at
            least ``(mu, logvar, ...)``) and ``general_kl``.
        indices: iterable of pseudo ids to compare; need not start at 0 or
            be contiguous.

    Returns:
        ``(len(indices), len(indices))`` float array, indexed by *position*
        in ``indices`` (not by pseudo id).
    """
    indices = list(indices)
    divergences = np.full((len(indices), len(indices)), np.inf)
    # Hoisted out of the O(n^2) loop; assumes get_pseudos() does not change
    # between calls (TODO confirm if the model is stochastic here).
    pseudos = model.get_pseudos()
    for (a, idx), (b, jdx) in product(enumerate(indices), repeat=2):
        if idx == jdx:
            continue
        x = pseudos[idx].unsqueeze(0)
        y = pseudos[jdx].unsqueeze(0)
        kl_div = model.general_kl(*model.q_z(x)[:2], *model.q_z(y)[:2])
        # Store by position, so non-0-based / sparse `indices` are valid.
        divergences[a, b] = kl_div.item()
    return divergences

In [None]:
def _find_twins(divergences_triplet):
    twins = []
    for idx, jdx in divergences_triplet:
        if (jdx, idx) in twins or (idx, jdx) in twins:
            continue
        if (jdx, idx) in divergences_triplet:
            if jdx < idx:
                idx, jdx = jdx, idx
            twins.append((idx, jdx))
    return twins

In [None]:
def propose_merges(model):
    """Propose pseudo-input pairs that are mutual nearest neighbours.

    For each pseudo, the nearest other pseudo (smallest entry in its row of
    ``calculate_divergences``) is found. A pair is proposed only when the
    relation holds in both directions. Pairs are ``(low, high)`` and are
    returned sorted by the higher index, largest first (presumably so merging
    removes high indices before low ones; confirm against merge_pseudos).
    """
    divergences = calculate_divergences(model, range(model.num_pseudos))
    nearest = divergences.argmin(axis=1)
    nearest_divergence = divergences[np.arange(len(divergences)), nearest]
    # Candidate directed pairs, inserted in order of increasing divergence.
    candidates = {
        (i, nearest[i]): divergences[i, nearest[i]]
        for i in np.argsort(nearest_divergence)
    }
    twins = _find_twins(candidates)
    return sorted(twins, key=lambda pair: pair[1], reverse=True)

In [None]:
# Mutual-nearest-neighbour pseudo pairs, (low, high), highest index first.
merges = propose_merges(model)
merges

In [None]:
from torch.utils.data import DataLoader
import tqdm
data_loader = DataLoader(
    # `.data` is the raw image tensor; `train_data` is a deprecated alias for it.
    fmnist_train.data,
    batch_size=256,
    num_workers=4,
    # Keep dataset order so row k of the likelihoods lines up with control_data[k].
    shuffle=False,
    persistent_workers=False,
    pin_memory=True,
    prefetch_factor=5,
)


In [None]:
from tsv.natvamp import log_normal_diag

pseudos = model.get_pseudos()
mu_p, logvar_p, *_ = model.q_z(pseudos)
likelihoods = []
for batch in tqdm.tqdm(data_loader):
    # Raw pixel batch -> float in [0, 1], reshaped to the encoder's input layout.
    x = batch.float().cuda().view(-1, 1, 28, 28) / 255
    mu = model.q_z(x)[0]
    # Broadcasts each image's mu against every pseudo's (mu_p, logvar_p):
    # assumes mu is (batch, D), giving a (batch, num_pseudos) log-density. TODO confirm.
    likelihoods.append(
        log_normal_diag(mu.unsqueeze(1), mu_p.unsqueeze(0), logvar_p.unsqueeze(0), reduction="sum", dim=-1).cpu()
    )
likelihoods = torch.cat(likelihoods, 0)
# Per image: id of the pseudo under which it is most likely.
max_likelihood_idx = likelihoods.argmax(1).numpy()


In [None]:
# Number of training images each pseudo "wins" (highest likelihood).
# np.bincount indexes counts by pseudo id and keeps pseudos with zero images.
# That makes `counts[pseudo_id]` valid for every pseudo; the previous
# np.unique counts were positional and silently dropped empty pseudos.
counts = np.bincount(max_likelihood_idx, minlength=model.num_pseudos)
# Pseudo ids ordered from fewest to most assigned images.
ordered_pseudo_idxs = np.argsort(counts, kind="stable")

In [None]:
# Pseudo ids from least to most occupied.
ordered_pseudo_idxs

In [None]:
# Occupancy counts in the same ascending order as ordered_pseudo_idxs.
# np.sort avoids indexing `counts` with pseudo ids, which is wrong when
# `counts` is positional (e.g. produced by np.unique).
np.sort(counts)

In [None]:
# Pairwise divergence matrix over all current pseudos (diagonal is inf).
divergences = calculate_divergences(model, np.arange(model.num_pseudos))

In [None]:
# Nearest other pseudo for each pseudo (the inf diagonal excludes self-matches).
closest_idx = divergences.argmin(axis=1)

In [None]:
# Pseudos that win fewer than 100 training images are merge candidates.
# np.sort(counts) lists the counts in the same ascending order that built
# ordered_pseudo_idxs, so the mask lines up element-wise.
must_murge_idxs = ordered_pseudo_idxs[np.sort(counts) < 100]

In [None]:
# Nearest neighbour of each low-occupancy pseudo. This is computed for
# inspection only; it is not used by merge_pseudos below.
merge_target_idxs = closest_idx[must_murge_idxs]

In [None]:
# Occupancy-based candidates (pseudo, nearest neighbour) next to the
# divergence-based mutual-nearest-neighbour merges from propose_merges.
# The latter was previously a stale, hardcoded list from an earlier run.
print([(idx, closest_idx[idx]) for idx in must_murge_idxs])
print(merges)

In [None]:
# Apply the divergence-based merges from propose_merges (not the
# occupancy-based candidates inspected above).
model.merge_pseudos(merges)

In [None]:
# Pseudo-inputs after merging.
plot_pseudos()

In [None]:
# NOTE(review): this writes into logs/nvpw/fmnist-pseudodiverge, but the model
# was loaded from logs/dlsv/fmnist-stochastic. Confirm the intended output
# directory (and avoid hardcoded absolute paths).
torch.save(model, "/home/zain/code/two-stage/logs/nvpw/fmnist-pseudodiverge/merge_1.pt")

In [None]:
from torch.utils.data import DataLoader
import tqdm
data_loader = DataLoader(
    # `.data` is the raw image tensor; `train_data` is a deprecated alias for it.
    fmnist_train.data,
    batch_size=256,
    num_workers=4,
    # Keep dataset order so row k of the likelihoods lines up with control_data[k].
    shuffle=False,
    persistent_workers=False,
    pin_memory=True,
    prefetch_factor=5,
)


In [None]:
from tsv.natvamp import log_normal_diag

# Recompute per-image pseudo likelihoods for the merged model.
pseudos = model.get_pseudos()
mu_p, logvar_p, *_ = model.q_z(pseudos)
likelihoods = []
for batch in tqdm.tqdm(data_loader):
    # Raw pixel batch -> float in [0, 1], reshaped to the encoder's input layout.
    x = batch.float().cuda().view(-1, 1, 28, 28) / 255
    mu = model.q_z(x)[0]
    # Broadcasts each image's mu against every pseudo's (mu_p, logvar_p):
    # assumes mu is (batch, D), giving a (batch, num_pseudos) log-density. TODO confirm.
    likelihoods.append(
        log_normal_diag(mu.unsqueeze(1), mu_p.unsqueeze(0), logvar_p.unsqueeze(0), reduction="sum", dim=-1).cpu()
    )
likelihoods = torch.cat(likelihoods, 0)
# Per image: id of the pseudo under which it is most likely.
max_likelihood_idx = likelihoods.argmax(1).numpy()


In [None]:
def find_representatives(pseudo_idx, num_representatives, scores=None, assignments=None):
    """Return indices of the training images that best represent one pseudo-input.

    Each image is "assigned" to the pseudo with its highest likelihood. Among
    the images assigned to ``pseudo_idx``, this returns the
    ``num_representatives`` with the highest likelihood under that pseudo,
    best first. The number of assigned images is printed.

    Args:
        pseudo_idx: pseudo-input id (a column of ``scores``).
        num_representatives: number of image indices to return.
        scores: ``(num_images, num_pseudos)`` likelihood matrix (tensor or
            array). Defaults to the notebook-global ``likelihoods``.
        assignments: per-image assigned pseudo id. Defaults to the
            notebook-global ``max_likelihood_idx``.

    Returns:
        NumPy integer array of image indices, or ``None`` when fewer than
        ``num_representatives`` images are assigned to ``pseudo_idx``.
    """
    if scores is None:
        scores = likelihoods
    if assignments is None:
        assignments = max_likelihood_idx
    # Work purely in NumPy rather than mixing numpy masks with torch tensors.
    if isinstance(scores, torch.Tensor):
        scores = scores.detach().cpu().numpy()
    scores = np.asarray(scores)
    assignments = np.asarray(assignments)

    member_idxs = np.flatnonzero(assignments == pseudo_idx)
    print(member_idxs.size)
    if member_idxs.size < num_representatives:
        return None
    # Stable descending sort by likelihood under this pseudo.
    order = np.argsort(-scores[member_idxs, pseudo_idx], kind="stable")
    return member_idxs[order][:num_representatives]

In [None]:
# Show the idx-th best representative image of pseudo 0.
# NOTE(review): find_representatives returns None when pseudo 0 has fewer
# than 3 assigned images, and this line would then raise a TypeError.
idx = 0
plt.imshow(control_data[find_representatives(0, 3)[idx]][0])

In [None]:
# Current (merged) pseudo-inputs, for comparison with the representatives.
plot_pseudos()

In [None]:
def _yield_representatives():
    """Lazily yield the top-3 representative image indices for each pseudo.

    Iterates over the pseudo ids of the notebook-global ``model`` and skips any
    pseudo for which ``find_representatives`` returns ``None`` (fewer than
    three assigned images).
    """
    for pseudo_idx in range(model.num_pseudos):
        reps = find_representatives(pseudo_idx, 3)
        if reps is None:
            continue
        yield reps

# One (3, 1, 28, 28) tensor per pseudo that has representatives (control_data
# is (N, 1, 28, 28)). Concatenating on dim 1 gives
# (3, num_pseudos_with_reps, 28, 28): representative rank x pseudo x image.
representatives = [torch.tensor(control_data[rep]) for rep in _yield_representatives()]
print([rep.shape for rep in representatives])
representatives = torch.cat(representatives, 1)

In [None]:
# Expected: (3, number of pseudos with >= 3 assigned images, 28, 28).
representatives.shape

In [None]:
# `representatives` only has columns for pseudos with at least 3 assigned
# images (the condition find_representatives(i, 3) uses). Select exactly
# those pseudos so each plotted row pairs a pseudo with its own representatives.
num_current_pseudos = pseudos.shape[0]
has_reps = np.bincount(max_likelihood_idx, minlength=num_current_pseudos)[:num_current_pseudos] >= 3
plot_pseudos_and_representatives(pseudos.cpu()[torch.from_numpy(has_reps)], representatives.cpu())