In [1]:
import sys; sys.path.append("../")

# Functions from my code
from dataset_wrappper import NewsData
from utils_train import transfer_batch_to_device
from utils_evaluation import tokenizer_batch_decode
from run_validation import load_model_for_eval

# General libs
import os
from pathlib import Path
from transformers import RobertaTokenizerFast
import torch
import numpy as np
import copy
import pickle
import pandas as pd

# Plotting
import matplotlib.pyplot as plt
plt.style.reload_library()
plt.style.use('thesis_style')

%matplotlib inline
%config InlineBackend.figure_format='retina'

# Globals
BATCH_SIZE = 2
DEVICE = "cuda:0"
CHECKPOINT_TYPE = "best" # else "best"
RESULT_DIR = Path("result-files")

# Data
data = NewsData(batch_size=BATCH_SIZE, tokenizer_name="roberta", dataset_name="ptb_text_only", max_seq_len=64)
# data = NewsData(batch_size=BATCH_SIZE, tokenizer_name="roberta", dataset_name="cnn_dailymail", max_seq_len=64)
validation_loader = data.val_dataloader(batch_size=BATCH_SIZE)
train_loader = data.train_dataloader()
tokenizer = RobertaTokenizerFast.from_pretrained('roberta-base')

Is file!
train 42068
validation 3370
test 3761


In [2]:
def get_clean_name(run_name):
    latent_size = run_name.split("-")[4][-2:]
    if "autoencoder" in run_name:
        FB = "autoencoder"
    else:
        FB = run_name.split("-")[6]
        if len(FB) == 3:
            FB += "0"
        FB = "FB-" + FB
    clean_name = f"NZ-{latent_size} | {FB}"
    return clean_name

PTB_run_name_paths = {}
for r in os.listdir("../Runs"):
    if "PTB" in r:
        path = Path("../Runs") / r / f"checkpoint-{CHECKPOINT_TYPE}.pth"
        PTB_run_name_paths[r] = path

clean_names = []
for r in PTB_run_name_paths.keys():
#     print(get_clean_name(r))
    clean_names.append(get_clean_name(r))
    
clean_names = sorted(clean_names)
for i, n in enumerate(clean_names):
    print(i, "\t", n)

0 	 NZ-32 | FB-0.00
1 	 NZ-32 | FB-0.25
2 	 NZ-32 | FB-0.50
3 	 NZ-32 | FB-0.75
4 	 NZ-32 | FB-1.00
5 	 NZ-32 | FB-1.50
6 	 NZ-32 | autoencoder
7 	 NZ-64 | FB-0.00
8 	 NZ-64 | FB-0.25
9 	 NZ-64 | FB-0.50
10 	 NZ-64 | FB-0.75
11 	 NZ-64 | FB-1.00
12 	 NZ-64 | FB-1.50
13 	 NZ-64 | autoencoder
14 	 NZ-68 | autoencoder
15 	 NZ-68 | autoencoder
16 	 NZ-68 | autoencoder
17 	 NZ-68 | autoencoder
18 	 NZ-68 | autoencoder
19 	 NZ-68 | autoencoder
20 	 NZ-68 | autoencoder


## 1 Collect latents of the whole train set for a few runs

In [3]:
runs =  [
 "2021-02-03-PTB-latent32-FB-0.5-run-09:31:02", 
 "2021-02-03-PTB-latent32-FB-1.50-run-12:13:36",
 "2021-02-03-PTB-latent32-autoencoder-run-17:30:41"]

result_file = "train_latents.pickle"

if os.path.isfile(result_file):
    print(f"Load latents from file {result_file}")
    latents_runs = pickle.load( open( result_file, "rb" ) )
else: 
    print(f"Compute latents on train set...")
    latents_runs = {}

    for r, p in PTB_run_name_paths.items():

        if r not in runs:
            continue

        print(get_clean_name(r))

        latent_size = int(r.split("-")[4][-2:])
        vae_model = load_model_for_eval(p, device_name=DEVICE, 
                                        latent_size=latent_size, 
                                        add_latent_via_memory=True,
                                        add_latent_via_embeddings=False, 
                                        do_tie_weights=True, 
                                        do_tie_embedding_spaces=True,
                                        add_decoder_output_embedding_bias=False)

        latents, mus, logvars, strings = [], [], [], []

        for batch_i, batch in enumerate(train_loader):

            print(f"Batch {batch_i:3d} /{len(train_loader):3d}", end="\r")

            with torch.no_grad():
                batch = transfer_batch_to_device(batch, device_name=DEVICE)

                enc_out = vae_model.encoder.encode(batch["input_ids"], batch["attention_mask"], 
                                                   n_samples=1,
                                                   hinge_kl_loss_lambda=0.0,
                                                   return_log_q_z_x=False,
                                                   return_log_p_z=False,
                                                   return_embeddings=False)

                latents.append(enc_out["latent_z"].cpu())
                mus.append(enc_out["mu"].cpu())
                logvars.append(enc_out["logvar"].cpu())
                strings.extend(tokenizer_batch_decode(batch["input_ids"].cpu(), tokenizer))

        latents = torch.cat(latents, dim=0)
        mus = torch.cat(mus, dim=0)
        logvars = torch.cat(logvars, dim=0)

        latents_runs[r] = {
            "latents": latents,
            "mus": mus,
            "logvars": logvars,
            "strings": strings
        }

    pickle.dump( latents_runs, open( result_file, "wb" ) )

Load latents from file train_latents.pickle


In [15]:
# Load the correct model
# r = "2021-02-03-PTB-latent32-FB-1.50-run-12:13:36"

r = "2021-02-03-PTB-latent32-autoencoder-run-17:30:41"
latent_size = int(r.split("-")[4][-2:])
p = PTB_run_name_paths[r]

import train
import utils_train

# latent_size = 768

# p = "/home/cbarkhof/code-thesis/NewsVAE/Runs/23NOV-AUTOENCODER-MemoryOnly-run-2020-11-23-19:10:58/checkpoint-50000.pth"
# p = "/home/cbarkhof/code-thesis/NewsVAE/Runs/13JAN-exp6-AUTO-ENCODER-Higher-Linear-Sched-run-2021-01-13-14:21:29/checkpoint-54000.pth"
    
    
do_tie_embeddings = True
do_tie_weights = True
add_latent_via_memory = True
add_latent_via_embeddings = False
latent_size = 32
add_decoder_output_embedding_bias = False
vae_model = train.get_model_on_device(device_name=DEVICE,
                                      latent_size=latent_size,
                                      gradient_checkpointing=False,
                                      add_latent_via_memory=add_latent_via_memory,
                                      add_latent_via_embeddings=add_latent_via_embeddings,
                                      do_tie_weights=do_tie_weights,
                                      do_tie_embedding_spaces=do_tie_embeddings,
                                      world_master=True,
                                      add_decoder_output_embedding_bias=add_decoder_output_embedding_bias)

_, _, vae_model, _, _, _, _ = utils_train.load_from_checkpoint(vae_model, p, 
                                                               do_tie_embeddings=do_tie_embeddings, 
                                                               do_tie_weights=do_tie_weights)

vae_model.eval()

Loading model...
Replacing linear output layer with one without bias!


Some weights of the model checkpoint at roberta-base were not used when initializing VAE_Encoder_RobertaModel: ['roberta.pooler.dense.weight', 'roberta.pooler.dense.bias']
- This IS expected if you are initializing VAE_Encoder_RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing VAE_Encoder_RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of VAE_Encoder_RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.embeddings.position_ids']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Tying encoder decoder RoBERTa checkpoint weights!
<class 'modules.decoder_roberta_new.VaeDecoderRobertaModel'> and <class 'modules.encoder_roberta.VAE_Encoder_RobertaModel'> are not equal. In this case make sure that all encoder weights are correctly initialized. 
The following encoder weights were not tied to the decoder ['roberta/pooler']
Tying embedding spaces!
Done model...


TypeError: load_from_checkpoint() got an unexpected keyword argument 'do_tie_embeddings'

In [24]:
with torch.no_grad():
    for batch in validation_loader:
        batch = transfer_batch_to_device(batch, device_name=DEVICE)
        input_ids, attention_mask = batch["input_ids"], batch["attention_mask"]
        print("input_ids", input_ids.shape)
        #out = vae_model(input_ids=input_ids, attention_mask=attention_mask, beta=1.0, return_cross_entropy=False)
        enc_out = vae_model.encoder.encode(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"], return_embeddings=True)
        print(enc_out.keys())
        print(enc_out["word_embeddings"].shape)
        
        dec_out = vae_model.decoder(enc_out["latent_z"], input_ids, attention_mask,
                                    return_cross_entropy=False,
                                    return_last_hidden_state=True)
        
        print(dec_out.keys())
        
        print(dec_out["last_hidden_state"].shape)
        
        break

  return torch.tensor(x, **format_kwargs)
  return torch.tensor(x, **format_kwargs)
  return torch.tensor(x, **format_kwargs)
  return torch.tensor(x, **format_kwargs)


input_ids torch.Size([2, 36])
dict_keys(['mu', 'logvar', 'latent_z', 'kl_loss', 'hinge_kl_loss', 'mmd_loss', 'log_q_z_x', 'log_p_z', 'word_embeddings'])
torch.Size([2, 36, 768])
dict_keys(['cross_entropy', 'cross_entropy_per_word', 'predictions', 'exact_match', 'attention_probs', 'attention_to_latent', 'hidden_states', 'probabilities', 'last_hidden_state', 'logits', 'past_key_values', 'cross_attentions'])
torch.Size([2, 36, 768])


In [5]:
# vae_model.eval()
# with torch.no_grad():

#     for batch in validation_loader:
#         batch = transfer_batch_to_device(batch, device_name=DEVICE)
#         input_ids, attention_mask = batch["input_ids"], batch["attention_mask"]

#         # Encode
#         enc_out = vae_model.encoder.encode(input_ids, attention_mask, 
#                                            n_samples=1, hinge_kl_loss_lambda=0.0,
#                                            return_log_q_z_x=False, return_log_p_z=False, 
#                                            return_embeddings=False)
#         latent_to_decoder_output = vae_model.decoder.latent_to_decoder(enc_out["latent_z"])
        
#         # Decode and return cache
#         dec_out = vae_model.decoder.model(latent_to_decoder_output=latent_to_decoder_output,
#                                           input_ids=input_ids[:, :5],
#                                           attention_mask=None,
#                                           return_cross_entropy=False,
#                                           return_predictions=True,
#                                           use_cache=True)
        
#         preds_forward_1 = dec_out["predictions"]
        
#         past_key_values = dec_out["past_key_values"]
# #         print(past_key_values[0][0].shape)
#         # cut off the latent
        
#         past_key_values = tuple([tuple([pair[0][:, :, 1:3, :], pair[1][:, :, 1:3, :]]) for pair in past_key_values])
#         print("**", past_key_values[0][0].shape)
        
#         # Now check whether the predictions for second foward with cache
#         dec_out = vae_model.decoder.model(latent_to_decoder_output=latent_to_decoder_output,
#                                           input_ids=input_ids[:, 2:10],
#                                           attention_mask=None,
#                                           return_cross_entropy=False,
#                                           return_predictions=True,
#                                           past_key_values=past_key_values,
#                                           use_cache=True)
        
        
#         print(preds_forward_1)
#         print(dec_out["predictions"])
        
        
#         break

In [6]:
# from pytorch_lightning import seed_everything

# seed_everything(0)

# def naive_auto_regressive_reconstruct(vae_model, input_ids, attention_mask, tokenizer):
#     vae_model.eval()
    
#     batch_size = input_ids.shape[0]
#     generated_so_far = torch.tensor([[tokenizer.bos_token_id, tokenizer.eos_token_id] for _ in range(batch_size)]).to(DEVICE)
    
#     print("Original:\n")
#     for t in tokenizer_batch_decode(input_ids.cpu(), tokenizer):
#         print(t)
    
#     with torch.no_grad():
        
            
#         enc_out = vae_model.encoder.encode(input_ids, attention_mask)
        
# #         print("\n\nAuto-regressive forward VAE:\n")
# #         for t in tokenizer_batch_decode(out["predictions"].cpu(), tokenizer):
# #             print(t)
            
#         latent_to_decoder_output = vae_model.decoder.latent_to_decoder(enc_out["latent_z"])
        
#         test_caches = []
        
#         # do a naive recurrent forward pass
#         for i in range(10):
#             dec_out = vae_model.decoder(input_ids=generated_so_far, attention_mask=None, 
#                                         latent_z=enc_out["latent_z"], return_predictions=True,
#                                         return_cross_entropy=False)
            
#              # Forward the decoder
#             dec_out = vae_model.decoder.model(latent_to_decoder_output=latent_to_decoder_output,
#                                               input_ids=generated_so_far,
#                                               attention_mask=None,
#                                               return_cross_entropy=False,
#                                               return_predictions=True,
#                                               use_cache=False)
            
            
#             new_preds = dec_out['predictions'][:, -1]
            
#             # Concat into <last prediction> </s> format for next round
#             generated_so_far = torch.cat(
#                 (generated_so_far[:, :-1], new_preds.unsqueeze(1), generated_so_far[:, -1].unsqueeze(1)), dim=1)   
            
#         latent_z = enc_out["latent_z"]
        
#         dec_out_2 = vae_model.decoder.autoregressive_decode(latent_z, max_seq_len=10,
#                                                             return_predictions=True,
#                                                             return_probabilities=False,
#                                                             return_logits=False,
#                                                             nucleus_sampling=False,
#                                                             device_name="cuda:0")
        
#         out_3 = vae_model(input_ids, 0.0, attention_mask,
#                           auto_regressive=True,
#                           return_predictions=True)
        
#         print(generated_so_far[:, 1:-1].shape, dec_out_2["predictions"].shape, out_3["predictions"].shape)
        
#         print(out_3['predictions'][:, :10])
#         print(generated_so_far[:, 1:-1])
#         print(dec_out_2["predictions"])
        
#         print(tokenizer_batch_decode(generated_so_far[:, 1:-1], tokenizer))
#         print(tokenizer_batch_decode(dec_out_2["predictions"], tokenizer))
        
        
# #         print("\n\nNaive autoregressive forward:\n")
# #         for t in tokenizer_batch_decode(generated_so_far.cpu(), tokenizer):
# #             print(t)
# vae_model.eval()
# with torch.no_grad():

#     for batch in validation_loader:
#         batch = transfer_batch_to_device(batch, device_name=DEVICE)
#         input_ids, attention_mask = batch["input_ids"], batch["attention_mask"]
#         naive_auto_regressive_reconstruct(vae_model, input_ids, attention_mask, tokenizer)
#         break

In [7]:
from pytorch_lightning import seed_everything

seed_everything(0)

def naive_auto_regressive_reconstruct(vae_model, input_ids, attention_mask, tokenizer):
    
#     batch_size = input_ids.shape[0]
#     generated_so_far = torch.tensor([[tokenizer.bos_token_id, tokenizer.eos_token_id] for _ in range(batch_size)]).to(DEVICE)
    
#     enc_out = vae_model.encoder.encode(input_ids, attention_mask)
#     latent_to_decoder_output = vae_model.decoder.latent_to_decoder(enc_out["latent_z"])
    
#     # Do a naive recurrent forward pass
#     for i in range(32):

#          # Forward the decoder
#         dec_out = vae_model.decoder.model(latent_to_decoder_output=latent_to_decoder_output,
#                                           input_ids=generated_so_far,
#                                           attention_mask=None,
#                                           return_cross_entropy=False,
#                                           return_predictions=True,
#                                           use_cache=False)


#         new_preds = dec_out['predictions'][:, -1]

#         # Concat into <last prediction> </s> format for next round
#         generated_so_far = torch.cat(
#             (generated_so_far[:, :-1], new_preds.unsqueeze(1), generated_so_far[:, -1].unsqueeze(1)), dim=1)   
        
    
    out = vae_model(input_ids, 0.0, attention_mask, auto_regressive=True, return_predictions=True)
#     out_2 = vae_model.decoder.autoregressive_decode(enc_out["latent_z"], max_seq_len=32,
#                                                     return_predictions=True,
#                                                     return_probabilities=False,
#                                                     return_logits=False)
    
    t0 = tokenizer_batch_decode(input_ids[:, :32].cpu(), tokenizer)
#     t1 = tokenizer_batch_decode(generated_so_far[:, :], tokenizer)
    t2 = tokenizer_batch_decode(out["predictions"][:, :], tokenizer)
#     t3 = tokenizer_batch_decode(out_2["predictions"][:, :], tokenizer)
    
    for i, ts in enumerate([t0, t2]):
        if i == 0:
            print("Original")
        else:
            print("\nReconstructed")
        
        for t in ts:
            if "<s>" in t:
                t = t[3:]
            index_end = t.find("</s>")
            if index_end != -1:
                t = t[:index_end]
            print("-  ", t)
    print("-----")
        

vae_model.eval()

with torch.no_grad():
    
    for batch_i, batch in enumerate(validation_loader):
        batch = transfer_batch_to_device(batch, device_name=DEVICE)
        input_ids, attention_mask = batch["input_ids"], batch["attention_mask"]
        naive_auto_regressive_reconstruct(vae_model, input_ids, attention_mask, tokenizer)
    
        if batch_i == 5:
            break

  return torch.tensor(x, **format_kwargs)
  return torch.tensor(x, **format_kwargs)
  return torch.tensor(x, **format_kwargs)
  return torch.tensor(x, **format_kwargs)


Original
-   consumers may want to move their telephones a little closer to the tv set
-   <unk> <unk> watching abc's monday night football can now vote during <unk> for the greatest play in N years from among four or five 

Reconstructed
-   Consumers may want to move their televisions a little closer to the telephones set to their tv box. Telephones may want to move a little closer to the tv set
-   Mathawan  saying can remember overeating during this second greatest football season for  four years in the night from 'Nsca can scout or  downloading omens
-----
Original
-   two weeks ago viewers of several nbc <unk> consumer segments started calling a N number for advice on various <unk> issues
-   and the new syndicated reality show hard copy records viewers'opinions for possible airing on the next day's show

Reconstructed
-   Two n viewers of several weeks  number one announcements started discussing a Nbc  dietary advice for those on various levels of internet
-   And the new syndi

## 2 Compute neares neighbors of samples from the priors to latents from the train set

In [4]:
# Get some samples from the prior

latent_size = 32
scale = 1.0  # <-- standard normal
n_samples = 10
loc = torch.zeros(latent_size, device="cpu")
std = torch.ones(latent_size, device="cpu") * scale
prior_dist = torch.distributions.normal.Normal(loc, std)
prior_samples = prior_dist.sample((n_samples,))

In [5]:
from sklearn.neighbors import NearestNeighbors

n_neighbors = 5

runs =  [
 "2021-02-03-PTB-latent32-FB-0.5-run-09:31:02", 
 "2021-02-03-PTB-latent32-FB-1.50-run-12:13:36",
 "2021-02-03-PTB-latent32-autoencoder-run-17:30:41"]

for r in runs:
    print("*"*50)
    print(get_clean_name(r))
    print("*"*50)
    print()
    
    # Get the mean of the posteriors as the encodings
    posterior_means = latents_runs[r]["mus"]
    
    # Load the correct model
    latent_size = int(r.split("-")[4][-2:])
    p = PTB_run_name_paths[r]
    vae_model = load_model_for_eval(p, device_name=DEVICE, 
                                    latent_size=latent_size, 
                                    add_latent_via_memory=True,
                                    add_latent_via_embeddings=False, 
                                    do_tie_weights=True, 
                                    do_tie_embedding_spaces=True,
                                    add_decoder_output_embedding_bias=False)
    print("\n\n")
    print("-"*30)

    # Make a brute-forced nearest neighbor graph of the validation set, find 10 neighbors
    nbrs = NearestNeighbors(n_neighbors=n_neighbors, algorithm='brute').fit(posterior_means.numpy())

    # Get the closest data train posterior means to a sample from the prior
    # distances between those and indices of the initial graph (corresponding to latents matrix)
    distances, indices = nbrs.kneighbors(prior_samples.cpu().numpy())

    # Go over the samples in the priors and evaluate the closest train posteriors
    for s_i in range(len(prior_samples)):

        sample_from_prior = prior_samples[s_i, :]

        closest_post_means_train = []
        closest_string_train = []

        ds = []
        for idx, d in zip(indices[s_i], distances[s_i]):

            closest_post_means_train.append(posterior_means[idx, :])
            closest_string_train.append(latents_runs[r]["strings"][idx])
            ds.append(d)

        sample_and_train_neighbors = torch.stack([sample_from_prior] + closest_post_means_train, dim=0).to(DEVICE)
        
        vae_model.eval()
        with torch.no_grad():
            dec = vae_model.decoder.autoregressive_decode(sample_and_train_neighbors,
                                                          labels=None,
                                                          max_seq_len=32,
                                                          return_exact_match=False,
                                                          return_cross_entropy=False,
                                                          return_attention_probs=False,
                                                          return_attention_to_latent=False,
                                                          return_hidden_states=False,
                                                          return_last_hidden_state=False,
                                                          return_predictions=True,
                                                          return_probabilities=False,
                                                          return_output_word_embeddings=False,
                                                          return_logits=False,
                                                          tokenizer=tokenizer,
                                                          nucleus_sampling=False, # <-- to check variability from the latent turn this off
                                                          reduce_seq_dim_ce="sum",
                                                          reduce_seq_dim_exact_match="none",
                                                          reduce_batch_dim_exact_match="none",
                                                          reduce_batch_dim_ce="mean",
                                                          device_name=DEVICE)
        
        # The first prediction is the one belonging to the sample from prior
        # the rest are reconstructions of the posterior mean encodings
        
        # Post process the strings, cutting of padding etc.
        text = tokenizer_batch_decode(dec["predictions"], tokenizer)
        N = len(text)
        text += closest_string_train
        ts = []
        for t in text:
            if "<s>" in t:
                t = t[3:]
            index_end = t.find("</s>")
            if index_end == -1:
                ts.append(t)
            else:
                ts.append(t[:index_end])
                
        closest_string_train = ts[N:]
        text = ts[:N]

        print("Sample from prior decoded:\n", ts[0])
        print("-"*60)
        print("Means from train sample latents that lay closest:")
        for i, (recon, orig) in enumerate(zip(ts[1:], closest_string_train)):
            print(f"{i} -- D = {ds[i]:.2f}")
            print("** Auto-regressive reconstruction:\n", recon)
            print("** Original string:\n", orig)
            print()
        print("-"*60)
        print("-"*60)
        print()

**************************************************
NZ-32 | FB-0.50
**************************************************

Loading model...
OLD FILE decoder_roberta.py activated!


Some weights of the model checkpoint at roberta-base were not used when initializing VAE_Decoder_RobertaForCausalLM: ['roberta.pooler.dense.weight', 'roberta.pooler.dense.bias']
- This IS expected if you are initializing VAE_Decoder_RobertaForCausalLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing VAE_Decoder_RobertaForCausalLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of VAE_Decoder_RobertaForCausalLM were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.embeddings.position_ids', 'lm_head.decoder.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Replacing linear output layer with one without bias!


Some weights of the model checkpoint at roberta-base were not used when initializing VAE_Encoder_RobertaModel: ['roberta.pooler.dense.weight', 'roberta.pooler.dense.bias']
- This IS expected if you are initializing VAE_Encoder_RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing VAE_Encoder_RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of VAE_Encoder_RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.embeddings.position_ids']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Tying encoder decoder RoBERTa checkpoint weights!
<class 'modules.decoder_roberta.VAE_Decoder_RobertaModel'> and <class 'modules.encoder_roberta.VAE_Encoder_RobertaModel'> are not equal. In this case make sure that all encoder weights are correctly initialized. 
The following encoder weights were not tied to the decoder ['roberta/pooler']
Tying embedding spaces!
Done model...
Loading VAE_model, optimizer and scheduler from ../Runs/2021-02-03-PTB-latent32-FB-0.5-run-09:31:02/checkpoint-best.pth
Removing module string from state dict from checkpoint
Checkpoint global_step: best, epoch: 37, best_valid_loss: 84.10610651086878



------------------------------


TypeError: forward() got an unexpected keyword argument 'use_cache'

In [None]:
for r, neighbor_triplets in set_of_neighbor_triplets.items():
    print(get_clean_name(r))
    
    samples = torch.cat(neighbor_triplets, dim=0)
    print(samples.shape)
    
    p = PTB_run_name_paths[r]
    vae_model = load_model_for_eval(p, device_name="cuda:0", 
                                    latent_size=latent_size, 
                                    add_latent_via_memory=True,
                                    add_latent_via_embeddings=False, 
                                    do_tie_weights=True, 
                                    do_tie_embedding_spaces=True,
                                    add_decoder_output_embedding_bias=False)
    
    dec = vae_model.decoder.autoregressive_decode(samples,
                                                    labels=None,
                                                    max_seq_len=32,
                                                    return_exact_match=False,
                                                    return_cross_entropy=False,
                                                    return_attention_probs=False,
                                                    return_attention_to_latent=False,
                                                    return_hidden_states=False,
                                                    return_last_hidden_state=False,
                                                    return_predictions=True,
                                                    return_probabilities=False,
                                                    return_output_word_embeddings=False,
                                                    return_logits=False,
                                                    tokenizer=tokenizer,
                                                    nucleus_sampling=False, # <-- to check variability from the latent turn this off
                                                    reduce_seq_dim_ce="sum",
                                                    reduce_seq_dim_exact_match="none",
                                                    reduce_batch_dim_exact_match="none",
                                                    reduce_batch_dim_ce="mean",
                                                    device_name=device_name)
        
        text = tokenizer_batch_decode(dec["predictions"], tokenizer)
        ts = []
        for t in text:
            index_end = t.find("</s>")
            if index_end == -1:
                ts.append(t)
            else:
                ts.append(t[:index_end])
        for i, t in enumerate(ts):
            if i % 3 == 0:
                print("-"*50)
                
            print(i, t)
            

In [None]:
n_samples = 100
device_name = "cuda:0"

runs =  ["2021-02-03-PTB-latent32-FB-0.00-run-14:32:09",
 "2021-02-03-PTB-latent32-FB-0.5-run-09:31:02", 
 "2021-02-03-PTB-latent32-FB-1.0-run-11:43:17", 
 "2021-02-03-PTB-latent32-FB-1.50-run-12:13:36",
 "2021-02-03-PTB-latent32-autoencoder-run-17:30:41"]

text_preds = {}
atts_to_latent = {}

for r in runs:
    
    p = PTB_run_name_paths[r]
    print(get_clean_name(r))
    
    latent_size = int(r.split("-")[4][-2:])
    
    vae_model = load_model_for_eval(p, device_name="cuda:0", 
                                    latent_size=latent_size, 
                                    add_latent_via_memory=True,
                                    add_latent_via_embeddings=False, 
                                    do_tie_weights=True, 
                                    do_tie_embedding_spaces=True,
                                    add_decoder_output_embedding_bias=False)
    
    vae_model.eval()

    with torch.no_grad():
#         prior_sample = vae_model.sample_from_prior(latent_size, n_samples=n_samples, device_name=device_name)
        
        loc = torch.zeros(latent_size, device=device_name)
        scale = torch.ones(latent_size, device=device_name)*2.0
        prior_dist = torch.distributions.normal.Normal(loc, scale)
        prior_sample = prior_dist.sample((n_samples,))
        
        prior_sample = torch.zeros(latent_size, device=device_name)
        
        prior_decoded_ar = vae_model.decoder.autoregressive_decode(
                                                    prior_sample,
                                                    labels=None,
                                                    max_seq_len=32,
                                                    return_exact_match=False,
                                                    return_cross_entropy=False,
                                                    return_attention_probs=False,
                                                    return_attention_to_latent=True,
                                                    return_hidden_states=False,
                                                    return_last_hidden_state=False,
                                                    return_predictions=True,
                                                    return_probabilities=False,
                                                    return_output_word_embeddings=False,
                                                    return_logits=False,
                                                    tokenizer=tokenizer,
                                                    nucleus_sampling=False, # <-- to check variability from the latent turn this off
                                                    reduce_seq_dim_ce="sum",
                                                    reduce_seq_dim_exact_match="none",
                                                    reduce_batch_dim_exact_match="none",
                                                    reduce_batch_dim_ce="mean",
                                                    device_name=device_name)
        
        atts_to_latent[r] = prior_decoded_ar["attention_to_latent"].cpu()

        prior_decoded_ar_text = tokenizer_batch_decode(prior_decoded_ar["predictions"], tokenizer)
        ts = []
        for t in prior_decoded_ar_text:
            index_end = t.find("</s>")
            if index_end == -1:
                ts.append(t)
            else:
                ts.append(t[:index_end])
        text_preds[r] = ts

for (r, ts), (_, aw) in zip(text_preds.items(), atts_to_latent.items()):
    print("*"*40)
    print(get_clean_name(r))
    print("*"*40)
    # batch, n_heads, n_layers, seq_len_query, seq_len_val
    # attention_to_latent = attention_probs[:, :, :, :-1, 0]
    aw_mean_head_layers = aw.mean(dim=1).mean(dim=1)
#     print(aw_mean_head_layers.shape)
#     print(aw.shape)
    plt.imshow(aw_mean_head_layers.numpy())
    
    plt.title("Average attention to latent over sequence", y=1.05, size=12)
    plt.colorbar()
    plt.ylabel("Samples in batch")
    plt.xlabel("Sequence dimension")
    plt.show()
    for i, t in enumerate(ts):
        print(i, t)
        
    print("-"*30)