In [1]:
import os
import json
import torch
from transformers import BartTokenizer
from model.modeling_bart import BartForConditionalGeneration
from eval_utils import create_encoder_input, create_decoder_input, pad_dataset


# Model setup

In [2]:
checkpoint_dir = 'checkpoints_persona_chat_peacok_random_induced_original'
tokenizer_checkpoint = checkpoint_dir
model_checkpoint = os.path.join(checkpoint_dir, 'checkpoint_25_lr_7.1666220328667066e-06')
device = 'cuda:5'

In [3]:
tokenizer = BartTokenizer.from_pretrained(tokenizer_checkpoint)
query_id, res_id, latent_id, persona_id, partner_id = \
    tokenizer.convert_tokens_to_ids([
        '<query>', '<response>', '<latent>', '<persona>', '<partner>'
    ])
sep_id = tokenizer.sep_token_id
eos_id = tokenizer.eos_token_id

model_checkpoint = BartForConditionalGeneration.from_pretrained(model_checkpoint)
model = model_checkpoint.to(device)
model = model.eval()


# Dialogue Data

In [6]:
dialogues_path = '/home/cutura/mnt/home/cutura/bart_peacok/data/persona_peacok/valid_persona_original_chat_ext_random_induced.json'
dialogues = json.load(open(dialogues_path))

len(dialogues)

1000

In [7]:
dialogue_id = '17878'
dialog = dialogues[dialogue_id]
dialog.keys()

dict_keys(['persona1_ori', 'persona1_ext', 'persona2_ext', 'text', 'text_plain', 'text_plain_cands'])

In [8]:
dialog['persona1_ori'] + dialog['persona1_ext']

['i read twenty books a year.',
 "i'm a stunt double as my second job.",
 'i only eat kosher.',
 'i was raised in a single parent household.',
 'i am a publisher, here is my character trait, love of books.',
 'i am a single mother, here is what i regularly or consistently do, works two jobs.',
 'i am a postalworker, here is my character trait, hard working.',
 'i am a rancher, here is my character trait, hard working.',
 'i am a stunt man, here is what i regularly or consistently do related to other people or social groups, trusts my life to the stunt coordinator.']

In [9]:
list(dialog['persona2_ext'].values())[-1]

['i am a homemaker. here is what i regularly or consistently do: stay at home with the kids',
 'i am a housewife. here is what i regularly or consistently do: stay at home to raise children',
 'i am a migrant worker. here is what i did in the past: worked in the fields',
 'i am a meterologist. here is my character trait: loves weather',
 'i am a director. here is what i regularly or consistently do: films movies']

In [10]:
history_full = []
for turn_data in dialog['text_plain']:
    query, response = turn_data.split('\t')
    history_full.append(query)
    history_full.append(response)
history_full

['hello what are doing today ?',
 'i am good , i just got off work and tired , i have two jobs .',
 'i just got done watching a horror movie',
 "i rather read , i've read about 20 books this year .",
 'wow ! i do love a good horror movie . loving this cooler weather',
 'but a good movie is always good .',
 'yes ! my son is in junior high and i just started letting him watch them too',
 'i work in the movies as well .',
 'neat ! ! i used to work in the human services field',
 'yes it is neat , i stunt double , it is so much fun and hard work .',
 'yes i bet you can get hurt . my wife works and i stay at home',
 'nice , i only have one parent so now i help out my mom .',
 'i bet she appreciates that very much .',
 'she raised me right , i am just like her .',
 'my dad was always busy working at home depot',
 'now that i am older home depot is my toy r us .']

# Generating responses

In [11]:
def text_to_token_ids(text):
    return [
        tokenizer.convert_tokens_to_ids(tokenizer.tokenize(l, add_prefix_space=True))
        for l in text
    ]

def generate_response(persona, partner, history, max_length=50, beam=2):
    persona_token_ids = text_to_token_ids(persona)
    partner_token_ids = text_to_token_ids(partner)
    history_token_ids = text_to_token_ids(history)

    input_ids, attention_mask, _, _ = create_encoder_input(
        persona_token_ids, partner_token_ids, history_token_ids,
        query_id, res_id, latent_id, persona_id, partner_id, sep_id, eos_id
    )
    tensor_input_ids = torch.tensor(input_ids, device=device).unsqueeze(0)
    tensor_attention_mask = torch.tensor(attention_mask, device=device).unsqueeze(0)

    with torch.no_grad():
        out_ids = model_checkpoint.generate(
            input_ids=tensor_input_ids,
            attention_mask=tensor_attention_mask,
            max_length=max_length,
            num_beams=beam
        )
    out_text = tokenizer.batch_decode(
        out_ids,
        skip_special_tokens=True, 
        spaces_between_special_tokens=False,
        clean_up_tokenization_spaces=False
    )

    answer = out_text[0].strip()
    return answer


In [41]:
persona = [
    'i read twenty books a year.',
    "i'm a stunt double as my second job.",
    'i only eat kosher.',
    'i was raised in a single parent household.',
    'i am a publisher, here is my character trait, love of books.',
    'i am a single mother, here is what i regularly or consistently do, works two jobs.',
    'i am a postalworker, here is my character trait, hard working.',
    'i am a rancher, here is my character trait, hard working.',
    'i am a stunt man, here is what i regularly or consistently do related to other people or social groups, trusts my life to the stunt coordinator.'
]

partner = [
    'i am a homemaker. here is what i regularly or consistently do: stay at home with the kids',
    'i am a housewife. here is what i regularly or consistently do: stay at home to raise children',
    'i am a migrant worker. here is what i did in the past: worked in the fields',
    'i am a meterologist. here is my character trait: loves weather',
    'i am a director. here is what i regularly or consistently do: films movies'
]


In [26]:
num_turns = 4

history = history_full[ : 2 * num_turns - 1]
gt_response = history_full[2 * num_turns]

for l in history:
    print(l)
print()
print(gt_response)

hello what are doing today ?
i am good , i just got off work and tired , i have two jobs .
i just got done watching a horror movie
i rather read , i've read about 20 books this year .
wow ! i do love a good horror movie . loving this cooler weather
but a good movie is always good .
yes ! my son is in junior high and i just started letting him watch them too

neat ! ! i used to work in the human services field


## Cherry-picking facts

It's able to pick the facts but only if they're highly related to the history.

In [28]:
generate_response(persona, partner, history)

"that is great , i was raised by a single parent so i don't have kids ."

In [39]:
persona = []
partner = [
    "i'm a stunt double as my second job.",
]

generate_response(persona, partner, history)

'that is great , i have a son too .'

In [43]:
persona = []
partner = [
    'i am a director. here is what i regularly or consistently do: films movies',
    "i have a daughter who is in kindergarden.",
]

generate_response(persona, partner, history)

"i have a daughter , she's in kindergarten ."

## Formulating the facts

* The longer one just gets rejected, probably confused it with the two jobs thing
* The truncated one leads to a good answer

In [50]:
persona = [
    'i am a single mother, here is what i regularly or consistently do, works two jobs.',
]
partner = []

generate_response(persona, partner, history)

'that is great , i have a son too .'

In [51]:
persona = [
    'i am a single mother',
]
partner = []

generate_response(persona, partner, history)

'that is great , i am a single mom so it is hard .'

## Revisiting history

What happens if I have a fact that revisits one of the first responses?

* The stronger thing prevails - fluency with the latest query

In [46]:
persona = [
    "i'm a stunt double as my second job.",
]
partner = []

generate_response(persona, partner, history)

'that is great , i have a son too .'

In [47]:
persona = [
    "i'm a stunt double as my second job.",
    "my daughter is an actress",
    "i read books all the time"
]
partner = [
]

generate_response(persona, partner, history)

"my daughter is in junior high , she's an actress ."

# Interactive stuff

In [75]:
persona = [
    "i'm a stunt double as my second job.",
    "my daughter is an actress",
    "i read books all the time"
]
# partner = [
#     "i am a basketball player.",
#     "my steaks are the best ever."
# ]
partner = [
    "i am an astronaut"
]

history = [
    "hey how is it going?"
]

response = generate_response(persona, partner, history)
response

'it is going well . how about you ?'

In [76]:
history += [response]
history += ["also well. just got back from a basketball game."]

response = generate_response(persona, partner, history)
response

'i just got back from a basketball game'

In [77]:
history += [response]
history += ["what do you do for a living?"]

response = generate_response(persona, partner, history)
response

'i am a stunt double'

In [78]:
history += [response]
history += ["that's very cool. i am a chef"]

response = generate_response(persona, partner, history)
response

'what do you do for a living'

# Is partner persona important?

In [79]:
import os
persona_chat_dir = '/home/cutura/mnt/home/cutura/bart_peacok/data/persona_peacok'
os.listdir(persona_chat_dir)

['.DS_Store',
 'add_extend_peacok_facts.py',
 'persona_extend_full_original.json',
 'persona_extend_full_revised.json',
 'preparing_data.ipynb',
 'train_persona_original_chat_convai2.json',
 'train_persona_original_chat_ext.json',
 'train_persona_revised_chat_convai2.json',
 'train_persona_revised_chat_ext.json',
 'valid_persona_original_chat_convai2.json',
 'valid_persona_original_chat_ext.json',
 'valid_persona_revised_chat_convai2.json',
 'valid_persona_revised_chat_ext.json',
 'convai2_format',
 'valid_persona_original_chat_ext_retrieved.json',
 'train_persona_original_chat_ext_retrieved.json',
 'valid_persona_original_chat_ext_induced.json',
 'train_persona_original_chat_ext_induced.json',
 'valid_persona_original_chat_ext_random_induced.json',
 'train_persona_original_chat_ext_random_induced.json']

In [90]:
persona_chat_dir = '/home/cutura/mnt/home/cutura/bart_peacok/data/persona_peacok'
persona_chat_filenames = [
    'valid_persona_original_chat_convai2.json',
    'train_persona_original_chat_convai2.json',
]
filepath = os.path.join(persona_chat_dir, persona_chat_filenames[1])

dialogues = json.load(open(filepath))

In [93]:
def get_history(dialogue):
    history_full = []
    for turn_data in dialogue['text_plain']:
        query, response = turn_data.split('\t')
        history_full.append(query)
        history_full.append(response)
    return history_full


line_cnt, arrogant_cnt, engaging_cnt = 0, 0, 0
for dialogue_id, dialogue in dialogues.items():
    history = get_history(dialogue)
    for i, line in enumerate(history):
        if i % 2 == 0:
            # Consider only responses since they act as labels during training
            continue
        line = line.lower()
        line_cnt += 1
        if (
            'you' in line or
            '?' in line
        ):
            engaging_cnt += 1
            continue
        if (
            line.startswith('i ') or
            ' i ' in line or
            'me' in line or
            'my' in line
        ):
            arrogant_cnt += 1

print(line_cnt, arrogant_cnt, engaging_cnt)
    

131438 60854 56897
