This notebook can be used to test a model on training/testing data to see how well reconstruction works. It can also be used to visualize the memory entries directly.

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

import sys; sys.path.insert(0, '..')
import torch 
import pytorch_lightning as pl
from torch.nn import functional as F

from collections.abc import Iterable

import matplotlib.pyplot as plt

from systems.memae_autoencoder_system import MemaeSystem
from systems.ae_autoencoder_system import AESystem
from models.conditional.conditional_memae_mnist import ConditionalMemaeMNIST
from models.memae_mnist_flat import MemaeMNISTFlat
from datamodules.mnist_dm import MNISTDataModule
from datamodules.cifar_dm import CIFARDataModule

from helpers import parse_runs, select_checkpoint_from_run_df, show_normalized_img, plot_vector_as_bar



In [None]:
# Directories whose data should be loaded; add them to this list.
# Every subfolder below them is analyzed automatically.
data_dirs = []

In [None]:
# Parse the runs found under `data_dirs`, then let
# `select_checkpoint_from_run_df` choose the checkpoint to inspect
level_names = ["Seed", "Model Type"]  # Always create a seed level
runs = parse_runs(data_dirs, level_names)
checkpoint = select_checkpoint_from_run_df(runs)

In [None]:
# Restore the system from the checkpoint chosen above
ckpt_path = checkpoint.value
print(f"Selected: '{checkpoint.value}'")
# NOTE(review): learning_rate=0 presumably just satisfies the constructor,
# since nothing is trained in this notebook -- confirm against AESystem
system = AESystem.load_from_checkpoint(ckpt_path, learning_rate=0)
# Switch to evaluation mode; the result is bound to `_` as in the original cell
_ = system.eval()

In [None]:
# Visualize the memory entries by decoding them and showing the result
memory = system.model.mem_rep.memory
if not isinstance(memory, Iterable):
    # Memory is not conditional, put it in a list so we can still iterate it
    memory = [memory]

for mem in memory:
    # Only the first 10 entries of each memory are decoded and shown
    for i in range(10):
        entry = mem.weight[i]
        # Add one leading and two trailing singleton dims before decoding
        # (assumes each entry is 1-D and the decoder expects 4-D input -- TODO confirm)
        decoded_mem = system.model.decoder(entry.unsqueeze(0).unsqueeze(2).unsqueeze(2))
        # `filename` has to be passed by keyword: a positional argument after
        # `save=False` is a SyntaxError
        show_normalized_img(decoded_mem, save=False, filename=None)

In [None]:
# Load a sample from the dataset
data_sample_class = 0
# Leave as None for an unconditional forward pass; any other value is wrapped
# in a tensor and passed to the system as its second argument
condition = None
dm = CIFARDataModule([data_sample_class], 1, 1, data_dir='../data')
dm.prepare_data()
dm.setup()

# Perform reconstruction
# The batch is bound to `sample`, the name every later statement and cell
# reads (it was previously bound to `samples`, which raised a NameError)
sample, y = next(iter(dm.train_dataloader()))

if condition is not None:
    out = system(sample, torch.tensor([condition]))
else:
    out = system(sample)

# Show the input first, then its reconstruction
show_normalized_img(sample, save=False, filename=None)
show_normalized_img(out, save=False, filename=None)

In [None]:
# Run the sample through the encoder and the memory module, keeping the
# intermediate addressing vectors so the cells below can plot them
encoded = system.model.encoder(sample).detach()
out = system.model.mem_rep(encoded)
encoded_hat, att, att_pre_softmax = (
    out[key].detach() for key in ("output", "att", "pre_softmax_att")
)
# Softmax over dim 1 of the pre-softmax attention
att_post_softmax = att_pre_softmax.softmax(dim=1)


In [None]:
# Bar plot of the flattened encoder output for the sample
# NOTE(review): "Value" / "Index" are presumably axis labels -- confirm in helpers
plot_vector_as_bar(encoded.flatten(), "Value", "Index")

In [None]:
# Bar plot of the flattened `pre_softmax_att` output of the memory module
# NOTE(review): the trailing `False, "attention_pre_softmax"` look like a save
# flag and an output name -- confirm against plot_vector_as_bar in helpers
plot_vector_as_bar(att_pre_softmax.flatten(), "Value", "Index", False, "attention_pre_softmax")

In [None]:
# Bar plot of the flattened attention after the softmax over dim 1 that this
# notebook applies to the pre-softmax attention
# NOTE(review): the trailing `False, "attention_post_softmax"` look like a save
# flag and an output name -- confirm against plot_vector_as_bar in helpers
plot_vector_as_bar(att_post_softmax.flatten(), "Value", "Index", False, "attention_post_softmax")

In [None]:
# Bar plot of the flattened `att` output of the memory module
# NOTE(review): the trailing `False, "attention_final"` look like a save flag
# and an output name -- confirm against plot_vector_as_bar in helpers
plot_vector_as_bar(att.flatten(), "Value", "Index", False, "attention_final")

In [None]:
# Bar plot of the flattened `output` entry returned by the memory module
# NOTE(review): "Value" / "Index" are presumably axis labels -- confirm in helpers
plot_vector_as_bar(encoded_hat.flatten(), "Value", "Index")