In [None]:
from torchvision.datasets import FashionMNIST
from tsv.natvamp import ModularNVPW
import torch
import numpy as np
import matplotlib.pyplot as plt
from itertools import product
import seaborn as sns

# Fashion-MNIST training split, downloaded into ./FMNIST on first use.
fmnist_train = FashionMNIST("FMNIST", train=True, download=True)

# float32 numpy images of shape (N, 1, 28, 28), scaled by 1/255
# (raw pixel values presumably 0-255), plus the matching integer labels.
control_data = fmnist_train.data.float().view(-1, 1, 28, 28).numpy() / 255
control_labels = fmnist_train.targets.numpy()

In [None]:
# NOTE(review): absolute, machine-specific checkpoint path — move to a config
# cell / relative path before sharing this notebook.
CHECKPOINT_PATH = "/home/zain/code/two-stage/logs/nvpw/fmnist-pseudodiverge/checkpoints/epoch=99-step=5900.ckpt"
model = ModularNVPW.load_from_checkpoint(CHECKPOINT_PATH)

# Inference-only notebook: disable autograd globally for all following cells.
torch.set_grad_enabled(False)

In [None]:
# Put the model in eval mode and move it to the GPU; the module is echoed as the cell output.
model.eval().cuda()

In [None]:
# Sanity check: run one training image through the model and compare
# the input with (presumably) its reconstruction.
idx = 0
raw = control_data[idx:idx + 1]   # (1, 1, 28, 28): slicing keeps the batch dim
x = raw.squeeze(1)                # (1, 28, 28) for plotting / loss
with torch.no_grad():
    recon = model(torch.tensor(raw).cuda())[0]
x_hat = recon.cpu().view(-1, 28, 28).numpy()

for image in (x[0], x_hat[0]):
    plt.imshow(image)
    plt.show()

torch.nn.functional.mse_loss(torch.tensor(x), torch.tensor(x_hat), reduction="mean")

In [None]:
def plot_pseudos_and_representatives(pseudos, representatives, figsize=(8, 28)):
    """Plot each pseudo-input next to its representative images.

    One row per pseudo-input: column 0 shows ``pseudos[i, 0]`` and column
    ``j + 1`` shows ``representatives[j, i]``.

    Args:
        pseudos: array/tensor where ``pseudos[i, 0]`` is a 2-D image; presumably
            (num_pseudos, 1, 28, 28) — TODO confirm against model.get_pseudos().
        representatives: array/tensor where ``representatives[j, i]`` is a 2-D
            image, i.e. (num_representatives, num_pseudos, H, W).
        figsize: matplotlib figure size; the default matches the original.
    """
    num_rows = pseudos.shape[0]
    num_reps = representatives.shape[0]
    # squeeze=False keeps `axs` 2-D even for a single pseudo-input, so the
    # per-row `row_axes[k]` indexing below is always valid.
    fig, axs = plt.subplots(num_rows, num_reps + 1, figsize=figsize, squeeze=False)
    for i, row_axes in enumerate(axs):
        row_axes[0].imshow(pseudos[i, 0])
        row_axes[0].axis("off")
        for j in range(num_reps):
            row_axes[j + 1].imshow(representatives[j, i])
            row_axes[j + 1].axis("off")
    fig.tight_layout()
    # plt.show() works for inline and GUI backends; fig.show() is meant for GUI
    # backends and can warn under non-interactive ones.
    plt.show()

In [None]:
def plot_pseudos():
    """Show every pseudo-input of the notebook-global ``model`` in a near-square grid.

    Reads ``model.num_pseudos`` and ``model.pseudos``; each pseudo is reshaped
    to a 28x28 image.
    NOTE(review): this reads ``model.pseudos`` directly while other cells use
    ``model.get_pseudos()`` — confirm both refer to the same images.
    """
    num_pseudos = model.num_pseudos
    ncols = max(1, int(np.ceil(np.sqrt(num_pseudos))))
    # Only as many rows as needed to hold every pseudo (ceiling division).
    nrows = max(1, -(-num_pseudos // ncols))
    # squeeze=False: `axes` is always a 2-D array, so .flatten() also works
    # for a 1x1 grid (a bare Axes object has no .flatten()).
    fig, axes = plt.subplots(nrows, ncols, figsize=(14, 10), squeeze=False)
    for i, ax in enumerate(axes.flatten()):
        if i < num_pseudos:
            ax.imshow(model.pseudos[i].cpu().view(28, 28).numpy())
        else:
            # Hide trailing grid cells that have no pseudo-input.
            ax.axis("off")
    plt.show()

In [None]:
# Pseudo-inputs as currently stored on the model (before projection).
plot_pseudos()

In [None]:
# project_pseudos is a ModularNVPW method not shown here; confirm its exact
# effect on the stored pseudos in tsv.natvamp.
model.project_pseudos()

In [None]:
# Same grid after projection, for visual comparison.
plot_pseudos()

In [None]:
def calculate_divergences(model, indices):
    """Pairwise divergences between the encoder outputs of selected pseudo-inputs.

    For every ordered pair ``(idx, jdx)`` of distinct entries of ``indices``,
    computes ``model.general_kl(*q_z(p_idx)[:2], *q_z(p_jdx)[:2])``, where
    ``p_k`` is pseudo-input ``k`` from ``model.get_pseudos()`` with a batch
    dim of 1 added. Both orders are computed, since KL-style divergences are
    generally asymmetric.

    Args:
        model: object exposing ``get_pseudos()``, ``q_z(x)`` and ``general_kl(...)``.
        indices: iterable of pseudo-input indices (a one-shot iterator is fine).

    Returns:
        List of ``(idx, jdx, divergence)`` tuples with ``divergence`` a Python
        float, in ``itertools.product(indices, indices)`` order.
    """
    # Materialise once: product() would see an already-exhausted iterator as
    # its second argument if `indices` were a generator, yielding no pairs.
    indices = list(indices)
    # Fetch the pseudos once and encode each requested one a single time,
    # instead of redoing get_pseudos() + two q_z() passes for every pair.
    pseudos = model.get_pseudos()
    posteriors = {}
    for idx in indices:
        if idx not in posteriors:
            posteriors[idx] = model.q_z(pseudos[idx].unsqueeze(0))[:2]

    divergences = []
    for idx, jdx in product(indices, indices):
        if idx == jdx:
            continue
        kl_div = model.general_kl(*posteriors[idx], *posteriors[jdx])
        divergences.append((idx, jdx, kl_div.item()))
    return divergences

In [None]:
# All ordered-pair divergences between the current pseudo-inputs, sorted
# ascending (smallest divergence = most similar pair first).
divergences = calculate_divergences(model, range(model.num_pseudos))
sorted_divergences = sorted(divergences, key=lambda x: x[2])

In [None]:
# Histogram of the pairwise divergence values.
raw_divergences = np.array([divergence[2] for divergence in divergences])
plt.hist(raw_divergences, bins=50)

In [None]:
sorted_divergences

In [None]:
def _condense(divergences, num_pseudos):
    groups = []
    included = [False] * num_pseudos
    for idx, jdx, _ in divergences:
        if included[idx] and included[jdx]:
            continue
        
        if not included[idx] and not included[jdx]:
            groups.append({idx, jdx})
            included[idx] = True
            included[jdx] = True
            continue

        for group in groups:
            if idx in group or jdx in group:
                group |= {idx, jdx}
                included[idx] = True
                included[jdx] = True                
                break
    return groups

In [None]:
# Greedy grouping, closest pairs first (sorted_divergences is ascending).
condensed_groups = _condense(sorted_divergences, model.num_pseudos)

In [None]:
condensed_groups

In [None]:
# NOTE(review): the merge pairs in the following cells are hard-coded, presumably
# picked by hand from an earlier run's printed divergences. They only make sense
# for this checkpoint and this exact execution order. merge_pseudos is a
# ModularNVPW method not shown here — presumably it changes pseudo indexing,
# so re-check indices after every merge.
model.merge_pseudos((2, 3))

In [None]:
plot_pseudos()

In [None]:
model.merge_pseudos((17, 13))

In [None]:
plot_pseudos()

In [None]:
# Recompute divergences for the pseudo-inputs remaining after these merges.
divergences = calculate_divergences(model, range(model.num_pseudos))
sorted_divergences = sorted(divergences, key=lambda x: x[2])

In [None]:
sorted_divergences

In [None]:
model.merge_pseudos((5, 7))

In [None]:
plot_pseudos()

In [None]:
model.merge_pseudos((6, 13))

In [None]:
plot_pseudos()

In [None]:
def propose_merges(model, num_proposals, divergences=None):
    """Propose up to ``num_proposals`` pseudo-input pairs to merge.

    Walks the pairwise divergences from smallest to largest and greedily
    accepts pairs whose two indices have not been used by an earlier proposal.
    Every index therefore appears in at most one proposed pair.

    NOTE(review): proposals use the model's current pseudo indexing. If
    ``merge_pseudos`` re-indexes pseudos (unconfirmed here), recompute the
    proposals after each merge instead of applying them all blindly.

    Args:
        model: model passed to ``calculate_divergences`` (unused when
            ``divergences`` is given).
        num_proposals: maximum number of pairs to return; ``<= 0`` returns ``[]``.
        divergences: optional precomputed ``(idx, jdx, divergence)`` tuples.
            Defaults to ``calculate_divergences(model, range(model.num_pseudos))``.

    Returns:
        List of ``(idx, jdx)`` tuples, lowest divergence first, in the same
        format accepted by ``model.merge_pseudos``.
    """
    if divergences is None:
        divergences = calculate_divergences(model, range(model.num_pseudos))
    merges = []
    used = set()
    for idx, jdx, _divergence in sorted(divergences, key=lambda d: d[2]):
        if len(merges) >= num_proposals:
            break
        # Skip pairs touching an already-proposed pseudo, so no index is
        # merged twice within one batch of proposals.
        if idx in used or jdx in used:
            continue
        merges.append((idx, jdx))
        used.update((idx, jdx))
    return merges

In [None]:
# Divergences for the current (post-merge) set of pseudo-inputs, sorted ascending.
divergences = calculate_divergences(model, range(model.num_pseudos))
sorted_divergences = sorted(divergences, key=lambda x: x[2])

In [None]:
sorted_divergences

In [None]:
model.num_pseudos

In [None]:
# Disabled: persist the merged model. NOTE(review): absolute, machine-specific path.
# torch.save(model, "/home/zain/code/two-stage/logs/nvpw/fmnist-pseudodiverge/model.pt")

In [None]:
from torch.utils.data import DataLoader
import tqdm

# Ordered (shuffle=False) pass over the raw training images, so that the
# concatenated batch results line up row-for-row with fmnist_train.data and
# therefore with control_data.
data_loader = DataLoader(
    fmnist_train.data,  # `train_data` is a deprecated torchvision alias of `data`
    batch_size=256,
    num_workers=4,
    shuffle=False,
    persistent_workers=False,
    pin_memory=True,
    prefetch_factor=5,
)


In [None]:
from tsv.natvamp import log_normal_diag

# Encode the pseudo-inputs once; mu_p / logvar_p are the first two q_z outputs.
pseudos = model.get_pseudos()
mu_p, logvar_p, *_ = model.q_z(pseudos)

# Score each training image's encoder mean under every pseudo-input's
# (mu_p, logvar_p) via log_normal_diag, summed over the last dim.
# Broadcasting mu (B, 1, D) against (1, P, D) gives a (B, P) block per batch —
# assuming q_z returns (batch, latent_dim) tensors; TODO confirm.
likelihoods = []
for batch in tqdm.tqdm(data_loader):
    x = batch.float().cuda().view(-1, 1, 28, 28) / 255
    mu = model.q_z(x)[0]
    likelihoods.append(
        log_normal_diag(
            mu.unsqueeze(1), mu_p.unsqueeze(0), logvar_p.unsqueeze(0), reduction="sum", dim=-1
        ).cpu()
    )
likelihoods = torch.cat(likelihoods, 0)
# Per training image, the index of its highest-scoring pseudo-input (numpy array).
# `likelihoods` is already on the CPU, so no extra .cpu() is needed.
max_likelihood_idx = likelihoods.argmax(1).numpy()


In [None]:
def find_representatives(pseudo_idx, num_representatives, scores=None, assignments=None):
    """Return training-set row indices that best represent one pseudo-input.

    An image is a candidate when ``assignments`` marks ``pseudo_idx`` as its
    arg-max pseudo-input. Candidates are ranked by their score in column
    ``pseudo_idx`` of ``scores``, highest first.

    Args:
        pseudo_idx: pseudo-input (column) index.
        num_representatives: maximum number of indices to return. Fewer are
            returned when fewer images are assigned to this pseudo-input.
        scores: (N, num_pseudos) CPU torch tensor of per-image scores.
            Defaults to the notebook-global ``likelihoods``.
        assignments: length-N numpy array of arg-max pseudo indices.
            Defaults to the notebook-global ``max_likelihood_idx``.

    Returns:
        numpy integer array of row indices, best first.
    """
    if scores is None:
        scores = likelihoods
    if assignments is None:
        assignments = max_likelihood_idx
    paragon_mask = assignments == pseudo_idx
    candidate_idx = np.arange(scores.shape[0])[paragon_mask]
    paragon_scores = scores[paragon_mask, pseudo_idx].cpu()
    _, order = torch.sort(paragon_scores, descending=True)
    # Convert explicitly: indexing a numpy array with a torch tensor would rely
    # on implicit tensor -> array conversion.
    return candidate_idx[order.numpy()][:num_representatives]

In [None]:
# Show the idx-th highest-scoring training image assigned to pseudo-input 0
# (the trailing [0] drops the channel dim of the (1, 28, 28) image).
idx = 0
plt.imshow(control_data[find_representatives(0, 3)[idx]][0])

In [None]:
plot_pseudos()

In [None]:
# Stack the top NUM_REPRESENTATIVES images of every pseudo-input into a tensor
# of shape (NUM_REPRESENTATIVES, num_pseudos, 1, 28, 28) -> cat along dim 1 of
# per-pseudo (NUM_REPRESENTATIVES, 1, 28, 28) blocks.
NUM_REPRESENTATIVES = 3
representative_blocks = []
for pseudo_i in range(model.num_pseudos):
    # (k, 1, 28, 28) with k <= NUM_REPRESENTATIVES. A pseudo-input with fewer
    # assigned images gives a shorter block, which made torch.cat(..., 1) raise.
    block = torch.tensor(control_data[find_representatives(pseudo_i, NUM_REPRESENTATIVES)])
    missing = NUM_REPRESENTATIVES - block.shape[0]
    if missing > 0:
        # Pad with blank images so every block has the same leading size.
        padding = torch.zeros((missing, *block.shape[1:]), dtype=block.dtype)
        block = torch.cat([block, padding], 0)
    representative_blocks.append(block)
representatives = torch.cat(representative_blocks, 1)

In [None]:
# Leading dims: (representatives per pseudo-input, num_pseudos), from the torch.cat along dim 1.
representatives.shape

In [None]:
# NOTE(review): `pseudos` was computed in the likelihood cell via model.get_pseudos();
# it is stale if merge_pseudos/project_pseudos ran after that cell.
plot_pseudos_and_representatives(pseudos.cpu(), representatives.cpu())