In [1]:
import os

import torch

from transformers import AutoTokenizer
from models.bart_extractor import BartExtractor, BART_BASE
from dataset.msc_summary_turns import MSC_Turns
from dataset.msc_summary import MSC_Summaries

import utils.logging as logging

In [2]:
# Most verbose level: SPAM also logs the raw generate() tensors during evaluation.
logging.set_log_level(logging.SPAM)

# Settings for dataset
# Paths default to the author's machine; override them with PEX_DATA_DIR / PEX_CHECKPOINT_DIR.
# os.path.join(..., '') guarantees a trailing separator, because later cells build paths
# by plain string concatenation (datadir + basedir).
datadir = os.path.join(os.environ.get('PEX_DATA_DIR', '/Users/FrankVerhoef/Programming/PEX/data/'), '')
basedir = 'msc/msc_personasummary/'
sessions = [1, 2, 3, 4]
len_context = 2
speaker_prefixes = ["<self>", "<other>"]
nofact_token = '<nofact>'
add_tokens = speaker_prefixes + [nofact_token]
test_samples = 20

# Settings for model
checkpoint_dir = os.path.join(os.environ.get('PEX_CHECKPOINT_DIR', '/Users/FrankVerhoef/Programming/PEX/checkpoints/'), '')
load = 'trained_bart'

# Setup
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")
if add_tokens is not None:
    num_added_toks = tokenizer.add_tokens(add_tokens)
# Fall back to EOS as the "no fact" marker when no dedicated token is configured.
nofact_token_id = tokenizer.convert_tokens_to_ids(nofact_token) if nofact_token != '' else tokenizer.eos_token_id
assert nofact_token_id != tokenizer.unk_token_id, "nofact_token '{}' must be known token".format(nofact_token)

model = BartExtractor(bart_base=BART_BASE, nofact_token_id=nofact_token_id)
# Resize the embeddings for the added tokens *before* loading weights, so the embedding
# shape matches the checkpoint (presumably saved with the enlarged vocabulary).
model.bart.resize_token_embeddings(len(tokenizer))

dataset_config = {
    'basedir': datadir + basedir,
    'sessions': sessions,
    'tokenizer': tokenizer,
    'len_context': len_context,
    'speaker_prefixes': speaker_prefixes,
    'nofact_token': nofact_token,
    'batch_format': 'huggingface',
    'batch_pad_id': tokenizer.pad_token_id
}
msc_turns = MSC_Turns(subset='test', max_samples=test_samples, **dataset_config)

checkpoint_path = os.path.join(checkpoint_dir, load)
logging.info("Loading model from {}".format(checkpoint_path))
# NOTE: torch.load unpickles the file (arbitrary code execution) - only load trusted checkpoints.
load_result = model.load_state_dict(torch.load(checkpoint_path, map_location=torch.device('cpu')))
# Disable dropout so the direct model.generate() calls in later cells are deterministic,
# instead of relying on some evaluation helper to switch the module into eval mode.
model.eval()
load_result

2023-04-09 21:51:53,631 INFO     | Loading model from /Users/FrankVerhoef/Programming/PEX/checkpoints/trained_bart


<All keys matched successfully>

In [3]:
# Peek at the first three (context, target) pairs produced by MSC_Turns.
for sample_idx in range(3):
    sample = msc_turns[sample_idx]
    print(sample)

('<self> That sounds really fun! Have you been to any concerts lately? <other> No, it has been awhile. We do have one coming up in town in a few months. Just waiting for the tickets to go on sale. My oldest granddaughter asked if we could go. So wife and I will be taking her as a birthday gift.', 'I enjoy celebrating milestones with my grandchildren.')
("<self> That's so cool! I am glad you're considerate of others food choices, I think I could learn a thing or two from you! Would you like to cook together at some point to see if we could put our minds together to make something interesting and new? <other> Oh I love the idea. Let me know when you are thinking and I will be there! I am sure we could come up with some great dishes! ", 'I live near Speaker 2.')
("<self> Ok, thank you. I will keep that in mind. Any tips are more than welcome. I am scared of having an accident that anything that can help avoid nd make me better is much appreciated. After your accident did it take long to g

In [4]:
# Evaluate the extractor on the turn-level test samples, on CPU.
eval_kwargs = dict(device='cpu', log_interval=10, decoder_max=20)

eval_msg = "Evaluating model on {} samples of testdata in {} with arguments {}".format(
    len(msc_turns), basedir, eval_kwargs
)
logging.info(eval_msg)
eval_stats = msc_turns.evaluate(model, **eval_kwargs)

2023-04-09 21:52:00,940 INFO     | Evaluating model on 20 samples of testdata in msc/msc_personasummary/ with arguments {'device': 'cpu', 'log_interval': 10, 'decoder_max': 20}
2023-04-09 21:52:01,698 SPAM     | Generate: pred_fact=tensor([True])
2023-04-09 21:52:01,699 SPAM     | Generate: gen_out=tensor([[    2,     0,     0,     0,   100,    33,    10, 21002,     4,     2]])
context:     <self> That sounds really fun! Have you been to any concerts lately? <other> No, it has been awhile. We do have one coming up in town in a few months. Just waiting for the tickets to go on sale. My oldest granddaughter asked if we could go. So wife and I will be taking her as a birthday gift.
target:      I enjoy celebrating milestones with my grandchildren.
prediction:  I have a granddaughter.
----------------------------------------
2023-04-09 21:52:02,241 SPAM     | Generate: pred_fact=tensor([False])
2023-04-09 21:52:02,242 SPAM     | Generate: gen_out=tensor([[    2,     0, 50267,     4,     2]

In [5]:
# Summaries dataset on the same test split and sessions as MSC_Turns above.
# Unlike MSC_Turns it is not given len_context or the nofact token.
msc_summaries = MSC_Summaries(
    subset='test',
    basedir=datadir + basedir,
    sessions=sessions,
    max_samples=test_samples,
    tokenizer=tokenizer,
    speaker_prefixes=speaker_prefixes,
    batch_pad_id=tokenizer.pad_token_id,
)

In [6]:
# Peek at the first three MSC_Summaries samples (each starts with a list of dialogue turns).
for sample_idx in range(3):
    sample = msc_summaries[sample_idx]
    print(sample)

(['<self> What kind of car do you want to buy? <other> I have always wanted a Jeep, they seem like so much fun to drive, and there are so many things you can do to customize them.  ', "<self> Nice! I love my Jeep.  What kinds of customizations would you want to make? <other> I'm not exactly sure,  I have always liked the idea of lifting it up a little bit with some nice tires.  I also like the idea of dropping the top and taking of the doors to feel the breeze when you drive it.  Does your Jeep have a hard or soft top?", '<self> Mine has a hard top.  I like to be protected from damage in case of an accident. <other> That is smart thinking.  Do you ever take off the top or doors?  Do you have any other customizations on your Jeep?  Do you use it to tow or haul anything?', "<self> I keep the tops and doors on because they make me feel safer.  I'm trying to decide whether or not I want to make any customizations.  I've never had to tow anything before and I'm hoping I don't have to.  I'm 

In [7]:
# Evaluate the same model on the summaries dataset, with the same settings as the turn-level run.
# NOTE: this overwrites eval_stats from the turn-level evaluation.
eval_kwargs = dict(device='cpu', log_interval=10, decoder_max=20)
eval_stats = msc_summaries.evaluate(model, **eval_kwargs)

2023-04-09 21:52:36,508 SPAM     | Generate: pred_fact=tensor([ True,  True, False, False,  True, False])
2023-04-09 21:52:36,510 SPAM     | Generate: gen_out=tensor([[    2,     0,     0,     0,   100,    33,   460,   770,    10, 16932,
             4,     2,     1,     1,     1,     1,     1,     1],
        [    2,     0,     0,     0, 50267,    34,    10,   543,    50,  3793,
           299,     4,     2,     1,     1,     1,     1,     1],
        [    2,     0, 50267,    34,    10,   543,   299,     4,     2,     1,
             1,     1,     1,     1,     1,     1,     1,     1],
        [    2,     0, 50267,   219,     4, 50267,     4,   175,     4,     2,
             1,     1,     1,     1,     1,     1,     1,     1],
        [    2,     0,     0,     0,   100,    33,    57,  2053,    59,   562,
            10,  7261,     4,    38,   923, 13677,     4,     2],
        [    2,     0, 50267,  1677,    32,  1406,    15,     5,   921,     4,
             2,     1,     1,     1, 

In [8]:
# Two variants of the same exchange with different lengths, so the batch needs padding.
s1 = "<self> Are you settling in the city at all or do you still really miss the country? <other> I am settling in, but I really miss it."
s2 = "<self> Are you settling in the city? <other> No, I really miss it."
encoded_utterances = tokenizer([s1, s2], padding=True, return_tensors='pt')
encoded_utterances


{'input_ids': tensor([[    0, 50265,  3945,    47, 15433,    11,     5,   343,    23,    70,
            50,   109,    47,   202,   269,  2649,     5,   247,   116,  1437,
         50266,    38,   524, 15433,    11,     6,    53,    38,   269,  2649,
            24,     4,     2],
        [    0, 50265,  3945,    47, 15433,    11,     5,   343,   116,  1437,
         50266,   440,     6,    38,   269,  2649,    24,     4,     2,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0]])}

In [9]:
# Padding-invariance check: append extra, fully masked pad positions to every row and
# compare the generated output with the unpadded run further below.
num_extra_pad = 10
batch_size = encoded_utterances['input_ids'].shape[0]
# Use the tokenizer's actual pad id (the original torch.ones assumed it is 1) and the real
# batch size (the original view(2, 10) hard-coded two rows).
extra_ids = torch.full((batch_size, num_extra_pad), tokenizer.pad_token_id, dtype=torch.long)
input_ids = torch.cat([encoded_utterances['input_ids'], extra_ids], dim=1)
attn_mask = torch.cat([encoded_utterances['attention_mask'], torch.zeros_like(extra_ids)], dim=1)
pred_tokens_2 = model.generate(
    input_ids=input_ids.to('cpu'),
    attention_mask=attn_mask.to('cpu'),
    min_length=2,
    max_new_tokens=20,
    num_beams=1,
    do_sample=False,
)
tokenizer.batch_decode(pred_tokens_2)

2023-04-09 21:56:12,731 SPAM     | Generate: pred_fact=tensor([True, True])
2023-04-09 21:56:12,732 SPAM     | Generate: gen_out=tensor([[   2,    0,    0,    0,  100,  524,   11,    5,  343,    4,   38, 2649,
            5,  247,    4,    2],
        [   2,    0,    0,    0,  100, 2649, 1207,   11,    5,  343,    4,    2,
            1,    1,    1,    1]])


['</s><s><s><s>I am in the city. I miss the country.</s>',
 '</s><s><s><s>I miss living in the city.</s><pad><pad><pad><pad>']

In [10]:
# Re-display the tokenizer output after the padding experiment above, to check it was left unchanged
# (torch.cat there builds new tensors rather than modifying these).
encoded_utterances

{'input_ids': tensor([[    0, 50265,  3945,    47, 15433,    11,     5,   343,    23,    70,
            50,   109,    47,   202,   269,  2649,     5,   247,   116,  1437,
         50266,    38,   524, 15433,    11,     6,    53,    38,   269,  2649,
            24,     4,     2],
        [    0, 50265,  3945,    47, 15433,    11,     5,   343,   116,  1437,
         50266,   440,     6,    38,   269,  2649,    24,     4,     2,     1,
             1,     1,     1,     1,     1,     1,     1,     1,     1,     1,
             1,     1,     1]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
         1, 1, 1, 1, 1, 1, 1, 1, 1],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0,
         0, 0, 0, 0, 0, 0, 0, 0, 0]])}

In [11]:
# Same generation settings as the padded run above, but on the original (unpadded) batch.
generate_inputs = {key: encoded_utterances[key].to('cpu') for key in ('input_ids', 'attention_mask')}
pred_tokens = model.generate(
    **generate_inputs,
    min_length=2,
    max_new_tokens=20,
    num_beams=1,
    do_sample=False,
)

2023-04-09 21:56:27,299 SPAM     | Generate: pred_fact=tensor([True, True])
2023-04-09 21:56:27,300 SPAM     | Generate: gen_out=tensor([[   2,    0,    0,    0,  100,  524,   11,    5,  343,    4,   38, 2649,
            5,  247,    4,    2],
        [   2,    0,    0,    0,  100, 2649, 1207,   11,    5,  343,    4,    2,
            1,    1,    1,    1]])


In [12]:
# Raw generated token ids for the unpadded batch, one row per input utterance.
pred_tokens

tensor([[   2,    0,    0,    0,  100,  524,   11,    5,  343,    4,   38, 2649,
            5,  247,    4,    2],
        [   2,    0,    0,    0,  100, 2649, 1207,   11,    5,  343,    4,    2,
            1,    1,    1,    1]])

In [13]:
# Decode the generated ids to text. Special tokens (</s>, <s>, <pad>) stay visible because
# skip_special_tokens is not passed.
tokenizer.batch_decode(pred_tokens)

['</s><s><s><s>I am in the city. I miss the country.</s>',
 '</s><s><s><s>I miss living in the city.</s><pad><pad><pad><pad>']

In [14]:
# The padding experiment pads with tokenizer.pad_token_id, so the model config must agree
# with it. Assert this explicitly instead of only eyeballing the displayed value.
assert model.bart.config.pad_token_id == tokenizer.pad_token_id, \
    "model pad_token_id ({}) differs from tokenizer pad_token_id ({})".format(
        model.bart.config.pad_token_id, tokenizer.pad_token_id)
model.bart.config.pad_token_id

1