# Imports

In [1]:
import sys; sys.path.append("../..")

In [7]:
# Torch, datasets, transformers, spacy
from datasets import load_from_disk
from transformers import RobertaTokenizerFast
from torch.utils.data import DataLoader
import torch.nn as nn
import torch
import spacy

# My utils
from utils_train import load_from_checkpoint, transfer_batch_to_device
from utils_evaluation import tokenizer_batch_decode, reconstruct_autoregressive
from train import get_model_on_device

# Plotting
import matplotlib.pyplot as plt
import seaborn as sns; sns.set()
from pylab import rcParams
%matplotlib inline

# Standard utils
import numpy as np
import pandas as pd
from functools import partial

# Sys
import os

In [3]:
# Data loading
batch_size = 32
num_workers = 2

# Device
device_name = "cuda:0"

# Data

In [4]:
def collate_fn(encoded_samples, tokenizer):
    """
    A function that assembles a batch. This is where padding is done, since it depends on
    the maximum sequence length in the batch.

    :param examples: list of truncated, tokenised & encoded sequences
    :return: padded_batch (batch x max_seq_len)
    """

    # Combine the tensors into a padded batch
    padded_batch = tokenizer.pad(encoded_samples, return_tensors='pt')

    return padded_batch

# VALIDATION DATA
valid_dataset = load_from_disk("/home/cbarkhof/code-thesis/NewsVAE/NewsData/22DEC-cnn_dailymail-roberta-seqlen64/validation")

# TOKENIZER
tokenizer = RobertaTokenizerFast.from_pretrained('roberta-base')

# TEST DATA LOADER
valid_loader = DataLoader(valid_dataset, collate_fn=partial(collate_fn, tokenizer=tokenizer),
                          batch_size=batch_size, num_workers=num_workers, pin_memory=True)

print(f"Number of test samples: {len(valid_dataset)}, number of batches of size {batch_size}: {int(np.floor(len(valid_dataset) / batch_size))}")



Number of test samples: 13368, number of batches of size 32: 417


# Run names & paths

In [12]:
run_dir = '/home/cbarkhof/code-thesis/NewsVAE/Runs'
runs_29DEC_names = ["-".join(run_name.split('-')[2:-5]) for run_name in os.listdir(run_dir) if "29DEC" in run_name]
runs_29DEC_paths = [run_dir + '/' + run_name for run_name in os.listdir(run_dir) if "29DEC" in run_name]

['/home/cbarkhof/code-thesis/NewsVAE/Runs/29DEC-exp7-AUTO-ENCODER-run-2020-12-29-10:42:11',
 '/home/cbarkhof/code-thesis/NewsVAE/Runs/29DEC-exp1-CYCLICAL-2-GRADSTEPS-135000-run-2020-12-28-21:40:07',
 '/home/cbarkhof/code-thesis/NewsVAE/Runs/29DEC-exp3-CYCLICAL-4-GRADSTEPS-6750-run-2020-12-28-21:52:48',
 '/home/cbarkhof/code-thesis/NewsVAE/Runs/29DEC-exp4-FREEBITS-0.5-ANNEAL-5000-run-2020-12-29-01:17:04',
 '/home/cbarkhof/code-thesis/NewsVAE/Runs/29DEC-exp5-FREEBITS-0.25-ANNEAL-5000-run-2020-12-29-02:02:31',
 '/home/cbarkhof/code-thesis/NewsVAE/Runs/29DEC-exp6-FREEBITS-0.125-ANNEAL-5000-run-2020-12-29-04:28:31',
 '/home/cbarkhof/code-thesis/NewsVAE/Runs/29DEC-exp2-CYCLICAL-3-GRADSTEPS-9000-run-2020-12-29-10:20:01']

In [13]:
def load_model(path):
    vae_model = get_model_on_device(device_name="cuda:0", latent_size=768, gradient_checkpointing=False, 
                                    add_latent_via_memory=True, add_latent_via_embeddings=True,
                                    do_tie_weights=True, world_master=True)

    _, _, vae_model, _, global_step, epoch, best_valid_loss = load_from_checkpoint(vae_model, path, world_master=True, ddp=False, use_amp=False)
    
    return vae_model

In [None]:
teacher_pred_text_all = []
teacher_correct_all = []
teacher_pred_ids_all = []

input_text = []

autoreg_pred_text_all = []
autoreg_correct_all = []
autoreg_pred_ids_all = []

for batch_i, batch in enumerate(valid_loader):
    print(f"{batch_i+1:5d} / {len(valid_loader):5d}", end='\r')
    
    with torch.no_grad():
        batch = transfer_batch_to_device(batch)
        
        # ----- TEACHER-FORCED -------
        teacher_output = vae_model.forward(batch["input_ids"], batch["attention_mask"], 1.0, return_predictions=True,
                                           return_attention_probs=False, return_exact_match_acc=False, return_latents=False,
                                           return_mu_logvar=False, objective='beta-vae', hinge_kl_loss_lambda=0.5)
        
        # Teacher forced returns batch x 63 (sequence of 62 + end symbol)
        teacher_ids = teacher_output["predictions"][:, :-1]
        teacher_text = tokenizer_batch_decode(teacher_ids, tokenizer)
        
        teacher_pred_text_all.append(teacher_text)
        teacher_pred_ids_all.append(teacher_ids.cpu())
        
        # ----- AUTO-REGRESSIVE (with nucleus sampling) -------
        autoreg_text, autoreg_ids = reconstruct_autoregressive(vae_model, batch, tokenizer, add_latent_via_embeddings=True,
                                                               add_latent_via_memory=True, max_seq_len=64, nucleus_sampling=True,
                                                               temperature=1.0, top_k=0, top_p=0.9, device_name="cuda:0",
                                                               return_attention_to_latent=False)
        
        
        # Auto-regressive returns batch x 64 (start + sequence of 62 + end symbol)
        autoreg_ids = autoreg_ids[:, 1:-1]
        autoreg_text = [remove_start_end_token(t) for t in autoreg_text] # remove this post-process
        
        autoreg_pred_text_all.append(autoreg_text)
        autoreg_pred_ids_all.append(autoreg_ids.cpu())
        
        # ----- INPUT -------
        input_ids = batch["input_ids"][:, 1:-1]
        input_text = tokenizer_batch_decode(input_ids, tokenizer)
        input_text = [remove_start_end_token(t) for t in input_text]
        
        input_text_all.append(input_text)
        input_ids_all.append(input_ids.cpu())
                
        # ----- EXACT MATCH -------
        teacher_correct = (teacher_ids == input_ids).float().cpu()
        teacher_correct_all.append(teacher_correct)
        
        autoreg_correct = (autoreg_ids == input_ids).float().cpu()
        autoreg_correct_all.append(autoreg_correct)
        
        if batch_i == max_batches - 1:
            break