In [14]:
import sys; sys.path.append("../")

# Functions from my code
from dataset_wrappper import NewsData
from utils_train import transfer_batch_to_device
from utils_evaluation import tokenizer_batch_decode
from run_validation import load_model_for_eval

# General libs
import os
from pathlib import Path
from transformers import RobertaTokenizerFast
import torch
import numpy as np
import copy
import pickle
import pandas as pd

# Plotting
import matplotlib.pyplot as plt
plt.style.reload_library()
plt.style.use('thesis_style')

%matplotlib inline
%config InlineBackend.figure_format='retina'

# Globals
BATCH_SIZE = 200
DEVICE = "cuda:0"
CHECKPOINT_TYPE = "best" # else "best"
RESULT_DIR = Path("result-files")

# Data
data = NewsData(batch_size=BATCH_SIZE, tokenizer_name="roberta", dataset_name="ptb_text_only", max_seq_len=64)
# data = NewsData(batch_size=BATCH_SIZE, tokenizer_name="roberta", dataset_name="cnn_dailymail", max_seq_len=64)
validation_loader = data.val_dataloader(batch_size=BATCH_SIZE)
train_loader = data.train_dataloader()
tokenizer = RobertaTokenizerFast.from_pretrained('roberta-base')

Is file!
train 42068
validation 3370
test 3761


In [19]:
def get_clean_name(run_name):
    """Turn a run directory name into a short label ``"NZ-<latent size> | <objective>"``.

    The latent size comes from a ``latent32`` / ``latent64`` marker in the name
    (768 when neither is present). The objective is ``"autoencoder"`` when the name
    contains ``autoencoder``; otherwise it is the value following the ``FB`` token
    (presumably the free-bits threshold), padded to at least two decimals, e.g.
    ``"...-FB-0.5-run-..."`` -> ``"FB-0.50"``.

    Raises:
        ValueError: if the name is not an autoencoder run and has no numeric
            ``FB-<value>`` segment.
    """
    if "latent32" in run_name:
        latent_size = 32
    elif "latent64" in run_name:
        latent_size = 64
    else:
        latent_size = 768

    if "autoencoder" in run_name:
        objective = "autoencoder"
    else:
        # Locate the value right after the "FB" token instead of assuming a fixed
        # position, so extra name segments (e.g. a "1MAR" tag) do not shift it.
        parts = run_name.split("-")
        try:
            fb_value = parts[parts.index("FB") + 1]
            float(fb_value)  # validate that the segment is numeric
        except (ValueError, IndexError) as err:
            raise ValueError(f"Cannot find an FB value in run name {run_name!r}") from err
        # Pad the fractional part to two digits without rounding away precision.
        whole, _, frac = fb_value.partition(".")
        objective = f"FB-{whole}.{frac.ljust(2, '0')}"
    return f"NZ-{latent_size} | {objective}"

# Map every PTB run directory under ../Runs to the checkpoint selected by CHECKPOINT_TYPE.
# Directories are visited in sorted order: os.listdir order is filesystem-dependent and
# later cells iterate this dict, so sorting keeps their output order reproducible.
runs_dir = Path("../Runs")
PTB_run_name_paths = {
    r: runs_dir / r / f"checkpoint-{CHECKPOINT_TYPE}.pth"
    for r in sorted(os.listdir(runs_dir))
    if "PTB" in r
}

# Print the human-readable label of every run found, alphabetically.
clean_names = sorted(get_clean_name(r) for r in PTB_run_name_paths)
for i, n in enumerate(clean_names):
    print(i, "\t", n)

0 	 NZ-32 | FB-0.00
1 	 NZ-32 | FB-0.25
2 	 NZ-32 | FB-0.50
3 	 NZ-32 | FB-0.75
4 	 NZ-32 | FB-1.00
5 	 NZ-32 | FB-1.50
6 	 NZ-32 | autoencoder
7 	 NZ-32 | autoencoder
8 	 NZ-64 | FB-0.00
9 	 NZ-64 | FB-0.25
10 	 NZ-64 | FB-0.50
11 	 NZ-64 | FB-0.75
12 	 NZ-64 | FB-1.00
13 	 NZ-64 | FB-1.50
14 	 NZ-64 | autoencoder
15 	 NZ-64 | autoencoder
16 	 NZ-768 | autoencoder
17 	 NZ-768 | autoencoder
18 	 NZ-768 | autoencoder
19 	 NZ-768 | autoencoder
20 	 NZ-768 | autoencoder
21 	 NZ-768 | autoencoder
22 	 NZ-768 | autoencoder


## 1 Collect latents of the whole train set for a few runs

In [26]:
# Runs whose train-set encodings are collected (the nearest-neighbour cell uses the same list).
runs = [
    "2021-02-03-PTB-latent32-FB-0.5-run-09:31:02",
    "2021-02-03-PTB-latent32-FB-1.50-run-12:13:36",
    "2021-02-03-PTB-latent32-autoencoder-run-17:30:41",
    "2021-03-01-1MAR-PTB-latent32-autoencoderEmbeddings-run-18:26:18",
]

result_file = "train_latents.pickle"

# Resume from the on-disk cache when present so finished runs are not re-encoded;
# otherwise start empty. Initialising here keeps the cell runnable on a fresh kernel.
# NOTE: pickle.load can execute arbitrary code - only load a cache this notebook wrote.
if os.path.isfile(result_file):
    print(f"Load latents from file {result_file}")
    with open(result_file, "rb") as f:
        latents_runs = pickle.load(f)
else:
    latents_runs = {}

missing_runs = [r for r in runs if r not in PTB_run_name_paths]
if missing_runs:
    print(f"Warning: no run directory found for {missing_runs}")

print("Compute latents on train set...")

for r, p in PTB_run_name_paths.items():
    # Only encode the selected runs, and skip those already cached.
    if r not in runs or r in latents_runs:
        continue

    print(get_clean_name(r))

    # Architecture hyper-parameters are read from the run name.
    if "latent32" in r:
        latent_size = 32
    elif "latent64" in r:
        latent_size = 64
    else:
        latent_size = 768
    add_latent_via_embeddings = "Embeddings" in r

    vae_model = load_model_for_eval(p, device_name=DEVICE,
                                    latent_size=latent_size,
                                    add_latent_via_memory=True,
                                    add_latent_via_embeddings=add_latent_via_embeddings,
                                    do_tie_weights=True,
                                    do_tie_embedding_spaces=True,
                                    add_decoder_output_embedding_bias=False)
    # Put modules such as dropout into inference mode before encoding (harmless if
    # load_model_for_eval already does this); the decoding cell does the same.
    vae_model.eval()

    latents, mus, logvars, strings = [], [], [], []

    with torch.no_grad():
        for batch_i, batch in enumerate(train_loader):
            print(f"Batch {batch_i:3d} /{len(train_loader):3d}", end="\r")

            batch = transfer_batch_to_device(batch, device_name=DEVICE)
            enc_out = vae_model.encoder.encode(batch["input_ids"], batch["attention_mask"],
                                               n_samples=1,
                                               hinge_kl_loss_lambda=0.0,
                                               return_log_q_z_x=False,
                                               return_log_p_z=False,
                                               return_embeddings=False)

            # Keep everything on the CPU so GPU memory does not grow with the train set.
            latents.append(enc_out["latent_z"].cpu())
            mus.append(enc_out["mu"].cpu())
            logvars.append(enc_out["logvar"].cpu())
            strings.extend(tokenizer_batch_decode(batch["input_ids"].cpu(), tokenizer))

    latents_runs[r] = {
        "latents": torch.cat(latents, dim=0),
        "mus": torch.cat(mus, dim=0),
        "logvars": torch.cat(logvars, dim=0),
        "strings": strings,
    }

    # Release this model before the next run's model is loaded.
    del vae_model
    torch.cuda.empty_cache()

with open(result_file, "wb") as f:
    pickle.dump(latents_runs, f)

Compute latents on train set...
NZ-32 | autoencoder
Loading model...
Replacing linear output layer with one without bias!


Some weights of the model checkpoint at roberta-base were not used when initializing VAE_Encoder_RobertaModel: ['roberta.pooler.dense.weight', 'roberta.pooler.dense.bias']
- This IS expected if you are initializing VAE_Encoder_RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing VAE_Encoder_RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of VAE_Encoder_RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.embeddings.position_ids']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Tying encoder decoder RoBERTa checkpoint weights!
<class 'modules.decoder_roberta_new.VaeDecoderRobertaModel'> and <class 'modules.encoder_roberta.VAE_Encoder_RobertaModel'> are not equal. In this case make sure that all encoder weights are correctly initialized. 
The following encoder weights were not tied to the decoder ['roberta/pooler']
Tying embedding spaces!
Done model...
Loading VAE_model, optimizer and scheduler from ../Runs/2021-03-01-1MAR-PTB-latent32-autoencoderEmbeddings-run-18:26:18/checkpoint-best.pth
Removing module string from state dict from checkpoint
Checkpoint global_step: best, epoch: 22, best_valid_loss: 76.88829040527344
Tying encoder decoder RoBERTa checkpoint weights!
<class 'modules.decoder_roberta_new.VaeDecoderRobertaModel'> and <class 'modules.encoder_roberta.VAE_Encoder_RobertaModel'> are not equal. In this case make sure that all encoder weights are correctly initialized. 
The following encoder weights were not tied to the decoder ['roberta/pooler']
Tying

  return torch.tensor(x, **format_kwargs)
  return torch.tensor(x, **format_kwargs)
  return torch.tensor(x, **format_kwargs)
  return torch.tensor(x, **format_kwargs)


Batch 210 /211

## 2 Compute the nearest neighbors of prior samples among the train-set latents

In [27]:
# Draw a few latent vectors from the prior N(0, scale^2 I).
# A dedicated seeded generator makes the samples (and hence the nearest-neighbour
# results below) reproducible without touching the global RNG state.

PRIOR_SEED = 0
latent_size = 32
scale = 1.0  # <-- standard normal
n_samples = 10
loc = torch.zeros(latent_size, device="cpu")
std = torch.ones(latent_size, device="cpu") * scale
prior_dist = torch.distributions.normal.Normal(loc, std)
prior_generator = torch.Generator(device="cpu").manual_seed(PRIOR_SEED)
# Equivalent to prior_dist.sample((n_samples,)), but driven by the seeded generator.
prior_samples = loc + std * torch.randn(n_samples, latent_size, generator=prior_generator)

In [29]:
from sklearn.neighbors import NearestNeighbors

n_neighbors = 5

# Same runs as in section 1; their train-set encodings must already be in `latents_runs`.
runs = [
    "2021-02-03-PTB-latent32-FB-0.5-run-09:31:02",
    "2021-02-03-PTB-latent32-FB-1.50-run-12:13:36",
    "2021-02-03-PTB-latent32-autoencoder-run-17:30:41",
    "2021-03-01-1MAR-PTB-latent32-autoencoderEmbeddings-run-18:26:18",
]


def strip_special_tokens(text):
    """Return `text` without a leading "<s>" and without anything from the first "</s>" on."""
    if text.startswith("<s>"):
        text = text[len("<s>"):]
    end = text.find("</s>")
    return text if end == -1 else text[:end]


for r in runs:
    print("*" * 50)
    print(r)
    print("*" * 50)
    print()

    # The train set is represented by the posterior means collected in section 1.
    posterior_means = latents_runs[r]["mus"]
    train_strings = latents_runs[r]["strings"]

    # Architecture hyper-parameters are read from the run name.
    if "latent32" in r:
        latent_size = 32
    elif "latent64" in r:
        latent_size = 64
    else:
        latent_size = 768
    add_latent_via_embeddings = "Embeddings" in r

    vae_model = load_model_for_eval(PTB_run_name_paths[r], device_name=DEVICE,
                                    latent_size=latent_size,
                                    add_latent_via_memory=True,
                                    add_latent_via_embeddings=add_latent_via_embeddings,
                                    do_tie_weights=True,
                                    do_tie_embedding_spaces=True,
                                    add_decoder_output_embedding_bias=False)
    vae_model.eval()
    print("\n\n")
    print("-" * 30)

    # Exact (brute-force) k-NN index over the train-set posterior means.
    nbrs = NearestNeighbors(n_neighbors=n_neighbors, algorithm='brute').fit(posterior_means.numpy())

    # Row s of `distances` / `indices` describes the n_neighbors train encodings closest to
    # prior sample s; the indices are row numbers into `posterior_means` / `train_strings`.
    distances, indices = nbrs.kneighbors(prior_samples.cpu().numpy())

    for s_i, sample_from_prior in enumerate(prior_samples):
        neighbour_ids = indices[s_i]
        ds = distances[s_i]
        closest_post_means_train = [posterior_means[idx, :] for idx in neighbour_ids]
        closest_string_train = [strip_special_tokens(train_strings[idx]) for idx in neighbour_ids]

        # Decode the prior sample and its neighbours' posterior means in a single batch.
        sample_and_train_neighbors = torch.stack([sample_from_prior] + closest_post_means_train, dim=0).to(DEVICE)

        with torch.no_grad():
            dec = vae_model.decoder.autoregressive_decode(sample_and_train_neighbors,
                                                          labels=None,
                                                          max_seq_len=32,
                                                          return_exact_match=False,
                                                          return_cross_entropy=False,
                                                          return_attention_probs=False,
                                                          return_attention_to_latent=False,
                                                          return_hidden_states=False,
                                                          return_last_hidden_state=False,
                                                          return_predictions=True,
                                                          return_probabilities=False,
                                                          return_output_embeddings=False,
                                                          return_logits=False,
                                                          nucleus_sampling=False, # <-- to check variability from the latent turn this off
                                                          reduce_seq_dim_ce="sum",
                                                          reduce_seq_dim_exact_match="none",
                                                          reduce_batch_dim_exact_match="none",
                                                          reduce_batch_dim_ce="mean",
                                                          device_name=DEVICE)

        # Row 0 is the decoded prior sample; rows 1.. are reconstructions of the
        # neighbours' posterior means, in the same order as `closest_string_train`.
        decoded = [strip_special_tokens(t) for t in tokenizer_batch_decode(dec["predictions"], tokenizer)]

        print("Sample from prior decoded:\n", decoded[0])
        print("-" * 60)
        print("Means from train sample latents that lay closest:")
        for i, (recon, orig) in enumerate(zip(decoded[1:], closest_string_train)):
            print(f"{i} -- D = {ds[i]:.2f}")
            print("** Auto-regressive reconstruction:\n", recon)
            print("** Original string:\n", orig)
            print()
        print("-" * 60)
        print("-" * 60)
        print()

    # Release this model before the next run's model is loaded.
    del vae_model
    torch.cuda.empty_cache()

**************************************************
2021-02-03-PTB-latent32-FB-0.5-run-09:31:02
**************************************************

Loading model...
Replacing linear output layer with one without bias!


Some weights of the model checkpoint at roberta-base were not used when initializing VAE_Encoder_RobertaModel: ['roberta.pooler.dense.weight', 'roberta.pooler.dense.bias']
- This IS expected if you are initializing VAE_Encoder_RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing VAE_Encoder_RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of VAE_Encoder_RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.embeddings.position_ids']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Tying encoder decoder RoBERTa checkpoint weights!
<class 'modules.decoder_roberta_new.VaeDecoderRobertaModel'> and <class 'modules.encoder_roberta.VAE_Encoder_RobertaModel'> are not equal. In this case make sure that all encoder weights are correctly initialized. 
The following encoder weights were not tied to the decoder ['roberta/pooler']
Tying embedding spaces!
Done model...
Loading VAE_model, optimizer and scheduler from ../Runs/2021-02-03-PTB-latent32-FB-0.5-run-09:31:02/checkpoint-best.pth
Removing module string from state dict from checkpoint
Checkpoint global_step: best, epoch: 37, best_valid_loss: 84.10610651086878
Tying encoder decoder RoBERTa checkpoint weights!
<class 'modules.decoder_roberta_new.VaeDecoderRobertaModel'> and <class 'modules.encoder_roberta.VAE_Encoder_RobertaModel'> are not equal. In this case make sure that all encoder weights are correctly initialized. 
The following encoder weights were not tied to the decoder ['roberta/pooler']
Tying embedding spaces!



Sample from prior decoded:
 but the problem is that the fed isn't <unk> to the problem
------------------------------------------------------------
Means from train sample latents that lay closest:
0 -- D = 4.64
** Auto-regressive reconstruction:
 but the market isn't as easy as it used to be
** Original string:
 there's a rising fear that perhaps mrs. thatcher's style of management has become a political liability says bill martin senior economist at london brokers <unk> & drew

1 -- D = 4.69
** Auto-regressive reconstruction:
 but the <unk> of the fed's policy isn't likely to change
** Original string:
 what's worrying her supporters is that the economic cycle may be out of <unk> with the political timetable

2 -- D = 4.70
** Auto-regressive reconstruction:
 but it's not clear whether the <unk> of the <unk> is enough to topple the government
** Original string:
 but i feel that if somebody doesn't get up and start talking about this now the next time around when we have the next <unk

Replacing linear output layer with one without bias!


Some weights of the model checkpoint at roberta-base were not used when initializing VAE_Encoder_RobertaModel: ['roberta.pooler.dense.weight', 'roberta.pooler.dense.bias']
- This IS expected if you are initializing VAE_Encoder_RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing VAE_Encoder_RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of VAE_Encoder_RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.embeddings.position_ids']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Tying encoder decoder RoBERTa checkpoint weights!
<class 'modules.decoder_roberta_new.VaeDecoderRobertaModel'> and <class 'modules.encoder_roberta.VAE_Encoder_RobertaModel'> are not equal. In this case make sure that all encoder weights are correctly initialized. 
The following encoder weights were not tied to the decoder ['roberta/pooler']
Tying embedding spaces!
Done model...
Loading VAE_model, optimizer and scheduler from ../Runs/2021-02-03-PTB-latent32-FB-1.50-run-12:13:36/checkpoint-best.pth
Removing module string from state dict from checkpoint
Checkpoint global_step: best, epoch: 56, best_valid_loss: 71.11977527759693
Tying encoder decoder RoBERTa checkpoint weights!
<class 'modules.decoder_roberta_new.VaeDecoderRobertaModel'> and <class 'modules.encoder_roberta.VAE_Encoder_RobertaModel'> are not equal. In this case make sure that all encoder weights are correctly initialized. 
The following encoder weights were not tied to the decoder ['roberta/pooler']
Tying embedding spaces!


Sample from prior decoded:
 trouble is that the japanese are a <unk> of the japanese organization to the u.s.
------------------------------------------------------------
Means from train sample latents that lay closest:
0 -- D = 3.77
** Auto-regressive reconstruction:
 britain's current system of control is a <unk> of the power of perestroika which has been severely limited by the <unk> of
** Original string:
 sverdlovsk is a large gray cloud over glasnost and indeed over the legitimacy of the arms-control process itself

1 -- D = 4.39
** Auto-regressive reconstruction:
 japan is a major importer of the u.s.
** Original string:
 japan is a major target for the soviets

2 -- D = 4.41
** Auto-regressive reconstruction:
 besides being a big <unk> for the big board a big market for program trading is a big drain on <unk>
** Original string:
 jumping in on big deals is a high profile way to <unk> the problem of not having a strong <unk> network

3 -- D = 4.44
** Auto-regressive reconstruct

Sample from prior decoded:
 the white house is very much worried about mr. bush's <unk> that it is too hard to read congress
------------------------------------------------------------
Means from train sample latents that lay closest:
0 -- D = 5.06
** Auto-regressive reconstruction:
 mr. kasparov couldn't be reached for comment yesterday but he said the <unk> of the new crowd will be difficult for anyone outside the
** Original string:
 mr. stoltzman must have worried that his audience might not be able to take it he warned us in advance that new york <unk> lasts N N minutes

1 -- D = 5.10
** Auto-regressive reconstruction:
 mr. bush may not be a friendly man for the bush administration but he is likely to be able to persuade congress to <unk> its own rules
** Original string:
 mr. baker may want to avoid criticism from senate majority leader george mitchell but as secretary of state his audience is the entire free world not just congress

2 -- D = 5.14
** Auto-regressive reconstructi

Some weights of the model checkpoint at roberta-base were not used when initializing VAE_Encoder_RobertaModel: ['roberta.pooler.dense.weight', 'roberta.pooler.dense.bias']
- This IS expected if you are initializing VAE_Encoder_RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing VAE_Encoder_RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of VAE_Encoder_RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.embeddings.position_ids']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Tying encoder decoder RoBERTa checkpoint weights!
<class 'modules.decoder_roberta_new.VaeDecoderRobertaModel'> and <class 'modules.encoder_roberta.VAE_Encoder_RobertaModel'> are not equal. In this case make sure that all encoder weights are correctly initialized. 
The following encoder weights were not tied to the decoder ['roberta/pooler']
Tying embedding spaces!
Done model...
Loading VAE_model, optimizer and scheduler from ../Runs/2021-02-03-PTB-latent32-autoencoder-run-17:30:41/checkpoint-best.pth
Removing module string from state dict from checkpoint
Checkpoint global_step: best, epoch: 56, best_valid_loss: 62.82812358714916
Tying encoder decoder RoBERTa checkpoint weights!
<class 'modules.decoder_roberta_new.VaeDecoderRobertaModel'> and <class 'modules.encoder_roberta.VAE_Encoder_RobertaModel'> are not equal. In this case make sure that all encoder weights are correctly initialized. 
The following encoder weights were not tied to the decoder ['roberta/pooler']
Tying embedding spac

Sample from prior decoded:
 we see a new danger to the very notion of wall street's original title
------------------------------------------------------------
Means from train sample latents that lay closest:
0 -- D = 7.24
** Auto-regressive reconstruction:
 you 'd think the risk was all about selling stock to offset the euphoria of euphoria over the new york times's new financial times a
** Original string:
 our fear was people would look just at the beta of a gold fund and say here is an investment with very low risk says john <unk> director of research for the chicago-based group

1 -- D = 7.36
** Auto-regressive reconstruction:
 your editorial was missing the point of many of our news stories and answers to the <unk> problem
** Original string:
 your story missed some essential points of the conference on the global environment are we <unk>

2 -- D = 7.64
** Auto-regressive reconstruction:
 this is the kind of emergency situation that <unk> the washington post and other financial 

Replacing linear output layer with one without bias!


Some weights of the model checkpoint at roberta-base were not used when initializing VAE_Encoder_RobertaModel: ['roberta.pooler.dense.weight', 'roberta.pooler.dense.bias']
- This IS expected if you are initializing VAE_Encoder_RobertaModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing VAE_Encoder_RobertaModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of VAE_Encoder_RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.embeddings.position_ids']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Tying encoder decoder RoBERTa checkpoint weights!
<class 'modules.decoder_roberta_new.VaeDecoderRobertaModel'> and <class 'modules.encoder_roberta.VAE_Encoder_RobertaModel'> are not equal. In this case make sure that all encoder weights are correctly initialized. 
The following encoder weights were not tied to the decoder ['roberta/pooler']
Tying embedding spaces!
Done model...
Loading VAE_model, optimizer and scheduler from ../Runs/2021-03-01-1MAR-PTB-latent32-autoencoderEmbeddings-run-18:26:18/checkpoint-best.pth
Removing module string from state dict from checkpoint
Checkpoint global_step: best, epoch: 22, best_valid_loss: 76.88829040527344
Tying encoder decoder RoBERTa checkpoint weights!
<class 'modules.decoder_roberta_new.VaeDecoderRobertaModel'> and <class 'modules.encoder_roberta.VAE_Encoder_RobertaModel'> are not equal. In this case make sure that all encoder weights are correctly initialized. 
The following encoder weights were not tied to the decoder ['roberta/pooler']
Tying

Sample from prior decoded:
 the <unk> index is based on the number of <unk> banks in the u.s. and europe
------------------------------------------------------------
Means from train sample latents that lay closest:
0 -- D = 6.12
** Auto-regressive reconstruction:
 the <unk> agency said the <unk> of the national association of financial institutions is the largest agency in the country
** Original string:
 the <unk> newspaper publishes the withheld portion of the swedish national audit bureau's report

1 -- D = 6.15
** Auto-regressive reconstruction:
 the bank of japan said it is considering <unk> its credit-card portfolio to attract new credit
** Original string:
 bank of england governor robin <unk> urged banks to be cautious in financing leveraged buy-outs

2 -- D = 6.21
** Auto-regressive reconstruction:
 the bank said it will continue to issue new debt and will continue to issue new subordinated debt to banks in the u.s. and europe
** Original string:
 lloyd's of london said it pl

Sample from prior decoded:
 the <unk> of the national defense institute and the national defense institute of <unk> were the first to publish a <unk> of the world's most advanced
------------------------------------------------------------
Means from train sample latents that lay closest:
0 -- D = 5.25
** Auto-regressive reconstruction:
 the <unk> <unk> spacecraft was launched by the <unk> <unk> in <unk> mass. in N and was <unk> by the u.s.
** Original string:
 the galileo spacecraft <unk> <unk> toward the planet jupiter while five <unk> aboard the space shuttle atlantis measured the earth's ozone layer

1 -- D = 5.48
** Auto-regressive reconstruction:
 the <unk> <unk> of the <unk> <unk> of the world's top football players was the <unk> of the world's top football players
** Original string:
 polls once named tokyo giants star <unk> <unk> a <unk> <unk> <unk> soul as the male symbol of japan

2 -- D = 5.56
** Auto-regressive reconstruction:
 the <unk> computer system was in the <unk> of

In [None]:
n_samples = 100
device_name = DEVICE
# Std of the isotropic Gaussian the latents are drawn from (2x the standard-normal prior's std).
prior_scale = 2.0
# Set to True to decode only the prior mean (z = 0, a batch of one) instead of random samples.
decode_prior_mean = False
# Seed for the latent samples. Each run re-seeds, so runs with equal latent size decode identical latents.
sample_seed = 0

runs = ["2021-02-03-PTB-latent32-FB-0.00-run-14:32:09",
        "2021-02-03-PTB-latent32-FB-0.5-run-09:31:02",
        "2021-02-03-PTB-latent32-FB-1.0-run-11:43:17",
        "2021-02-03-PTB-latent32-FB-1.50-run-12:13:36",
        "2021-02-03-PTB-latent32-autoencoder-run-17:30:41"]

text_preds = {}
atts_to_latent = {}

for r in runs:
    p = PTB_run_name_paths[r]
    print(get_clean_name(r))

    # Architecture hyper-parameters are read from the run name.
    if "latent32" in r:
        latent_size = 32
    elif "latent64" in r:
        latent_size = 64
    else:
        latent_size = 768
    add_latent_via_embeddings = "Embeddings" in r

    vae_model = load_model_for_eval(p, device_name=device_name,
                                    latent_size=latent_size,
                                    add_latent_via_memory=True,
                                    add_latent_via_embeddings=add_latent_via_embeddings,
                                    do_tie_weights=True,
                                    do_tie_embedding_spaces=True,
                                    add_decoder_output_embedding_bias=False)
    vae_model.eval()

    # Batch of latents to decode: (n_samples, latent_size), or (1, latent_size) for the mean.
    if decode_prior_mean:
        prior_sample = torch.zeros(1, latent_size, device=device_name)
    else:
        generator = torch.Generator(device="cpu").manual_seed(sample_seed)
        prior_sample = (prior_scale * torch.randn(n_samples, latent_size, generator=generator)).to(device_name)

    with torch.no_grad():
        # NOTE(review): kwargs mirror the executed call in the nearest-neighbour cell;
        # confirm against autoregressive_decode's signature.
        prior_decoded_ar = vae_model.decoder.autoregressive_decode(
                                                    prior_sample,
                                                    labels=None,
                                                    max_seq_len=32,
                                                    return_exact_match=False,
                                                    return_cross_entropy=False,
                                                    return_attention_probs=False,
                                                    return_attention_to_latent=True,
                                                    return_hidden_states=False,
                                                    return_last_hidden_state=False,
                                                    return_predictions=True,
                                                    return_probabilities=False,
                                                    return_output_embeddings=False,
                                                    return_logits=False,
                                                    nucleus_sampling=False, # <-- to check variability from the latent turn this off
                                                    reduce_seq_dim_ce="sum",
                                                    reduce_seq_dim_exact_match="none",
                                                    reduce_batch_dim_exact_match="none",
                                                    reduce_batch_dim_ce="mean",
                                                    device_name=device_name)

    atts_to_latent[r] = prior_decoded_ar["attention_to_latent"].cpu()

    # Drop a leading "<s>" and everything from the first "</s>" on.
    ts = []
    for t in tokenizer_batch_decode(prior_decoded_ar["predictions"], tokenizer):
        if t.startswith("<s>"):
            t = t[len("<s>"):]
        index_end = t.find("</s>")
        ts.append(t if index_end == -1 else t[:index_end])
    text_preds[r] = ts

    # Release this model before the next run's model is loaded.
    del vae_model
    torch.cuda.empty_cache()

for r, ts in text_preds.items():
    aw = atts_to_latent[r]
    print("*" * 40)
    print(get_clean_name(r))
    print("*" * 40)
    # Per the decoder's layout: attention_probs is (batch, n_heads, n_layers, seq_len_query, seq_len_val)
    # and attention_to_latent = attention_probs[:, :, :, :-1, 0]; averaging dims 1 and 2
    # leaves (batch, seq_len) - TODO confirm against the decoder implementation.
    aw_mean_head_layers = aw.mean(dim=1).mean(dim=1)

    fig, ax = plt.subplots()
    im = ax.imshow(aw_mean_head_layers.numpy())
    fig.colorbar(im, ax=ax)
    ax.set_title("Average attention to latent over sequence", y=1.05, size=12)
    ax.set_ylabel("Samples in batch")
    ax.set_xlabel("Sequence dimension")
    plt.show()

    for i, t in enumerate(ts):
        print(i, t)

    print("-" * 30)