In [4]:
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from transformers import AutoTokenizer, AutoConfig, AutoModelForCausalLM, T5EncoderModel
from sklearn.metrics import f1_score, accuracy_score
from model_classes import *
import numpy as np
from tqdm import tqdm
import re

# Pick the best available accelerator. "mps" (the originally hard-coded
# choice) is still preferred, but the notebook no longer crashes on machines
# without Apple-silicon GPU support.
if torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# NOTE(security): torch.load unpickles the checkpoint, which can execute
# arbitrary code — only load 'best_classifier.pth' from a trusted source.
# NOTE(review): the checkpoint holds a whole pickled model object (see the
# `from model_classes import *` above); recent PyTorch releases default to
# weights_only=True, so this call may need weights_only=False there — confirm
# against the installed version.
# map_location remaps tensors that were saved on a different device so the
# load itself does not fail; .eval() puts the module in inference mode
# (assumes the checkpoint is an nn.Module — it is called and moved with .to()).
model = torch.load('best_classifier.pth', map_location=device).to(device).eval()
tokenizer = AutoTokenizer.from_pretrained('google/t5-xl-lm-adapt')

In [28]:
import json

def _read_json(path):
    """Parse and return the JSON document stored at ``path``."""
    with open(path, 'r') as handle:
        return json.load(handle)


# Tasks that carry annotations (read by later cells via task['annotations']).
annotated_data = _read_json('speakers_facts.json')

# Tasks to run the classifier on.
ann_inputs = _read_json('sess_ann_inputs.json')

In [29]:
def remove_blank_facts(x):
    """Filter a list of fact strings, dropping blank and header-like entries.

    A fact is dropped when it is empty / whitespace-only, or when it is
    shorter than 7 characters and ends with ``':'`` (e.g. ``'s1:'``).
    Everything else is kept, in the original order.

    Args:
        x: iterable of fact strings.

    Returns:
        A new list containing only the retained facts.
    """
    return [
        fact for fact in x
        # Blank strings appear when preprocess_fact() strips an item that was
        # only a list marker (e.g. '1. '); they must not reach the classifier.
        if fact.strip() and (len(fact) >= 7 or not fact.endswith(':'))
    ]


def preprocess_fact(x):
    """Strip a leading list marker from a fact string.

    Removes a prefix made of ``*`` and/or digits, an optional ``.`` and the
    whitespace that follows it (``'1. foo'`` -> ``'foo'``, ``'* foo'`` ->
    ``'foo'``). Strings without such a prefix are returned unchanged.
    """
    return re.sub(r'^[\*\d]+\.?\s+', '', x)


# Maps the annotation control name ('from_name') to the speaker key.
speaker_mapper = {'facts_s1': 's1', 'facts_s2': 's2'}

def _tidy_fact_texts(raw_texts):
    """Strip list markers from each fact text, then filter the result."""
    return remove_blank_facts([preprocess_fact(text) for text in raw_texts])


def _normalise_annotated_task(task):
    """Flatten one annotated task into the record layout used by later cells.

    Only the first entry of ``task['annotations']`` is read, and only result
    items whose ``from_name`` is a key of ``speaker_mapper`` contribute to
    ``'correct_facts'``.
    """
    payload = task['data']
    return {
        'dialog_id': payload['dialog_id'],
        'session': payload['session'],
        'dialog': payload['dialogue'],
        'correct_facts': {
            speaker_mapper[item['from_name']]: _tidy_fact_texts(item['value']['choices'])
            for item in task['annotations'][0]['result']
            if item['from_name'] in speaker_mapper
        },
        'facts': {
            speaker_key: _tidy_fact_texts([entry['value'] for entry in payload[speaker_key]])
            for speaker_key in ('s1', 's2')
        },
        'summary': payload['summary'],
    }


# Rebinds annotated_data from the raw tasks to the flattened records.
annotated_data = [_normalise_annotated_task(task) for task in annotated_data]

In [30]:
# remove_blank_facts = lambda x: [fact for fact in x if fact]
# speaker_mapper = {'facts_s1': 's1', 'facts_s2': 's2'}

# annotated_data = [{
#     'dialog_id': dialog['data']['dialog_id'],
#     'dialogue': dialog['data']['dialogue'], 
#     'correct_facts': {speaker_mapper[facts['from_name']]: facts['value']['choices'] 
#                       for facts in dialog['annotations'][0]['result']},
#     'facts': {'s1': remove_blank_facts([fact['value'] for fact in dialog['data']['s1']]), 
#               's2': remove_blank_facts([fact['value'] for fact in dialog['data']['s2']])}
#         }
#     for dialog in annotated_data]

In [31]:
def convert_dialog(dialog):
    """Render a dialogue as text, one ``author: text`` line per phrase.

    Args:
        dialog: iterable of mappings with ``'author'`` and ``'text'`` keys.

    Returns:
        The rendered lines joined with newlines (``''`` for an empty dialogue).
    """
    rendered = []
    for utterance in dialog:
        rendered.append(f"{utterance['author']}: {utterance['text']}")
    return '\n'.join(rendered)

In [32]:
def fact_dialog_concat(speaker, fact, dialog):
    """Build one classifier input: a fact header block followed by the dialogue.

    The result is ``'Fact about <speaker>:'``, the fact on the next line, a
    blank line, then the dialogue text.
    """
    header = f'Fact about {speaker}:'
    fact_block = '\n'.join((header, fact))
    return '\n\n'.join((fact_block, dialog))

In [33]:
def predict_actual_facts(facts, model, threshold=.65, max_length=2048):
    """Classify each input text and return one boolean per input.

    The texts are tokenized with the module-level ``tokenizer`` and passed to
    ``model``. The output is soft-maxed over dim 1, and an input is flagged
    ``True`` when the probability in column 1 exceeds ``threshold``
    (presumably the "fact is supported" class — confirm against the
    classifier's label order).

    Args:
        facts: list of input strings (see ``fact_dialog_concat``).
        model: classifier called as ``model(inputs, attention_mask)``; must
            expose a ``device`` attribute.
        threshold: minimum column-1 probability for a ``True`` decision.
        max_length: token limit applied when truncating inputs.

    Returns:
        A list of bools, aligned with ``facts``. Empty input yields ``[]``.
    """
    # An empty batch cannot be tokenized/padded into a tensor, so short-circuit.
    if isinstance(facts, (list, tuple)) and len(facts) == 0:
        return []

    inputs = tokenizer(
        facts,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=max_length,
    ).to(model.device)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        scores = model(inputs, inputs.attention_mask)

    positive_prob = F.softmax(scores.detach().cpu(), dim=1)[:, 1]
    return (positive_prob > threshold).tolist()

In [34]:
def choice_actual_facts(speaker, facts, dialog, model):
    """Return the subset of ``facts`` that the classifier accepts.

    Each fact is paired with the rendered dialogue, scored with
    ``predict_actual_facts``, and kept only when its prediction is truthy.

    Args:
        speaker: speaker name inserted into the "Fact about ..." header.
        facts: iterable of fact strings; it is not modified.
        dialog: iterable of phrase mappings accepted by ``convert_dialog``.
        model: classifier forwarded to ``predict_actual_facts``.

    Returns:
        A new list with the accepted facts, in their original order.

    Raises:
        ValueError: if the number of predictions differs from the number of
            facts.
    """
    facts = list(facts)
    # Nothing to score: skip rendering the dialogue and calling the model.
    if not facts:
        return []

    # The dialogue text is the same for every fact, so render it once.
    dialog_text = convert_dialog(dialog)
    input_data = [fact_dialog_concat(speaker, fact, dialog_text) for fact in facts]
    actual_facts = predict_actual_facts(input_data, model)

    if len(actual_facts) != len(facts):
        raise ValueError(
            f'expected {len(facts)} predictions, got {len(actual_facts)}'
        )
    return [fact for fact, keep in zip(facts, actual_facts) if keep]

In [35]:
# Re-index the annotated records by (dialog_id, session) for membership tests
# and lookups. This rebinds annotated_data from a list to a dict, so the cell
# is not meant to be run twice on the same kernel state.
annotated_data = dict(
    ((record['dialog_id'], record['session']), record)
    for record in annotated_data
)

In [36]:
# Maps the fact-list key in a task to the author name used in the dialogue.
speakers = dict(s1='bot_0', s2='bot_1')


def extract_facts(facts):
    """Return the ``'value'`` of every fact entry, preserving order."""
    values = []
    for entry in facts:
        values.append(entry['value'])
    return values


def extract_speaker_phrases(speaker, dialog):
    """Keep only the phrases authored by ``speaker``.

    Each kept phrase is copied into a new dict holding just its ``'author'``
    and ``'text'`` keys.
    """
    kept = []
    for utterance in dialog:
        if utterance['author'] != speaker:
            continue
        kept.append({'author': utterance['author'], 'text': utterance['text']})
    return kept

def _tidy_fact_entries(raw_entries):
    """Clean the fact texts of one speaker and re-wrap them as ``{'value': ...}``."""
    stripped = [preprocess_fact(entry['value']) for entry in raw_entries]
    return [{'value': text} for text in remove_blank_facts(stripped)]


# Shallow-copy every task, replacing each speaker's fact list ('s1', 's2')
# with its cleaned version; all other task fields are carried over as-is.
ann_inputs = [
    {**task, **{speaker_key: _tidy_fact_entries(task[speaker_key]) for speaker_key in speakers}}
    for task in ann_inputs
]

In [14]:
def _model_result(fact_key, author, task):
    """Build the 'choices' result for one speaker from the classifier output."""
    return {
        'from_name': 'facts_' + fact_key,
        'to_name': 'chat',
        'value': {'choices': choice_actual_facts(
            author,
            extract_facts(task[fact_key]),
            extract_speaker_phrases(author, task['dialogue']),
            model,
        )},
        'type': 'choices',
    }


def _predict_task(task):
    """Wrap a task with a single prediction holding its non-empty results."""
    candidates = [_model_result(fact_key, author, task) for fact_key, author in speakers.items()]
    # Speakers for which no fact was accepted are left out of the result list.
    non_empty = [candidate for candidate in candidates if candidate['value']['choices']]
    return {'data': task, 'predictions': [{'result': non_empty}]}


predictions = [_predict_task(data) for data in tqdm(ann_inputs)]

100%|██████████| 174/174 [05:36<00:00,  1.93s/it]


In [38]:
def _use_gold_annotation(prediction):
    """Flag whether a task was annotated and, if so, replace its predictions.

    Mutates ``prediction`` (and its ``'data'`` dict) in place: sets
    ``data['is_annotated']`` and, for annotated tasks, overwrites
    ``prediction['predictions']`` with the stored ``'correct_facts'``.
    """
    task = prediction['data']
    key = (task['dialog_id'], task['session'])
    task['is_annotated'] = key in annotated_data

    if not task['is_annotated']:
        return

    gold_facts = annotated_data[key]['correct_facts']
    prediction['predictions'] = [{
        'result': [
            {'from_name': 'facts_' + speaker, 'to_name': 'chat', 'value': {'choices': facts}, 'type': 'choices'}
            for speaker, facts in gold_facts.items()
        ]
    }]


for prediction in predictions:
    _use_gold_annotation(prediction)

In [116]:
# predictions = [
#     {'data': {'dialog': data['dialog'], **{speaker: [{'value': fact} 
#                                                      for fact in facts] 
#                                            for speaker, facts in data['facts'].items()}},
#      'predictions': [
#          {'result': [result for result in [
#          {
#              'from_name': 'facts_' + speaker, 
#              'to_name': 'chat',
#              'value': {'choices': choice_actual_facts(speaker, facts, data['dialog'], model)},
#              'type': 'choices',
#          }
#          for speaker, facts in data['facts'].items()] if result['value']['choices']]}
#      ]
#     } 
#     for data in sample_dialogs
# ]

In [40]:
# Persist the task/prediction pairs as a single JSON array.
with open('predictions_vicuna.json', 'w') as out_file:
    json.dump(predictions, out_file)