In [None]:
import os

import torch
from transformers import AutoTokenizer

from models.bart_extractor import BartExtractor, BART_BASE
from dataset.msc_summary_turns import MSC_Turns
from dataset.msc_summary import MSC_Summaries

import utils.logging as logging

In [None]:
logging.set_log_level(logging.SPAM)

# Settings for test.
# The directory defaults are the original author's local paths; set the
# PEX_DATADIR / PEX_CHECKPOINT_DIR environment variables to run elsewhere
# without editing the notebook.
datadir = os.environ.get('PEX_DATADIR', '/Users/FrankVerhoef/Programming/PEX/data/')
basedir = 'msc/msc_personasummary/'
checkpoint_dir = os.environ.get('PEX_CHECKPOINT_DIR', '/Users/FrankVerhoef/Programming/PEX/checkpoints/')
load = 'trained_bart'                   # checkpoint file name inside checkpoint_dir
sessions = [1, 2, 3, 4]
len_context = 2
speaker_prefixes = ["<self>", "<other>"]
nofact_token = '<nofact>'
add_tokens = speaker_prefixes + [nofact_token]
test_samples = 20                       # passed as max_samples to the test dataset

# Setup: tokenizer extended with the speaker-prefix and no-fact tokens
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base", padding_side='left')
if add_tokens is not None:
    num_added_toks = tokenizer.add_tokens(add_tokens)
    logging.info("Added {} tokens to tokenizer".format(num_added_toks))

# Use the dedicated no-fact token when one is configured, otherwise fall back to EOS
nofact_token_id = tokenizer.convert_tokens_to_ids(nofact_token) if nofact_token != '' else tokenizer.eos_token_id
assert nofact_token_id != tokenizer.unk_token_id, "nofact_token '{}' must be known token".format(nofact_token)

model = BartExtractor(bart_base=BART_BASE, nofact_token_id=nofact_token_id)
# Resize embeddings to the tokenizer's vocabulary, which now includes the added tokens
model.bart.resize_token_embeddings(len(tokenizer))

dataset_config = {
    # os.path.join does not depend on datadir ending in a separator
    'basedir': os.path.join(datadir, basedir),
    'sessions': sessions,
    'tokenizer': tokenizer,
    'len_context': len_context,
    'speaker_prefixes': speaker_prefixes,
    'nofact_token': nofact_token,
    'batch_format': 'huggingface',
    'batch_pad_id': tokenizer.pad_token_id
}
msc_turns = MSC_Turns(subset='test', max_samples=test_samples, **dataset_config)

checkpoint_path = os.path.join(checkpoint_dir, load)
logging.info("Loading model from {}".format(checkpoint_path))
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code;
# only load checkpoints from trusted sources (recent torch versions accept
# weights_only=True for plain state dicts).
model.load_state_dict(torch.load(checkpoint_path, map_location=torch.device('cpu')))

In [None]:
# Print a few raw test samples for inspection; clamp to the dataset size so a
# small max_samples setting does not raise IndexError.
num_examples = 3
for i in range(min(num_examples, len(msc_turns))):
    print(msc_turns[i])

In [None]:
# Turn-level evaluation of the model on the MSC test subset, run on CPU.
eval_kwargs = dict(
    device='cpu',
    log_interval=10,
    decoder_max=20,
)

logging.info(
    "Evaluating model on {} samples of testdata in {} with arguments {}".format(
        len(msc_turns), basedir, eval_kwargs
    )
)
eval_stats = msc_turns.evaluate(model, **eval_kwargs)

In [None]:
# Session-summary test set, built from the same data directory and sessions.
# NOTE(review): unlike MSC_Turns, no len_context / nofact_token is passed here —
# confirm MSC_Summaries does not need them.
msc_summaries = MSC_Summaries(
    basedir=os.path.join(datadir, basedir),  # robust to a missing trailing separator
    sessions=sessions,
    subset='test',
    tokenizer=tokenizer,
    speaker_prefixes=speaker_prefixes,
    max_samples=test_samples,
    batch_pad_id=tokenizer.pad_token_id
)

In [None]:
# Evaluate the model on the session-summary test set, run on CPU.
eval_kwargs = {'device': 'cpu', 'log_interval': 10, 'decoder_max': 20}

logging.info("Evaluating model on summaries of testdata in {} with arguments {}".format(basedir, eval_kwargs))
eval_stats = msc_summaries.evaluate(model, **eval_kwargs)
eval_stats