In [1]:
from argparse import Namespace
import logging
import os
import random
from pprint import pformat

import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer, GPT2Config

# Star import presumably supplies helpers used below (e.g. add_special_tokens,
# sample_sequence); TODO: import them by name so their origin is explicit.
from utils import *
# Report the accelerator in use. args.device (next cell) falls back to CPU, so do
# not call current_device()/get_device_name() — which need CUDA — on a CPU-only host.
print(
    torch.cuda.get_device_name(torch.cuda.current_device())
    if torch.cuda.is_available()
    else "CUDA not available; using CPU"
)

Tesla P100-PCIE-12GB


In [2]:
# Inference settings shared by the model-loading and interactive cells below.
args = Namespace(
    # Where the checkpoint lives: a directory holding the tokenizer files and
    # config.json, plus the .pth file with the trained weights.
    model_checkpoint_dir="/ssd003/home/shirleyw/ConversationalAI/dialomed_checkpoints/",
    model_checkpoint_file="/checkpoint/shirleyw/dialogpt_med_marco/checkpoint_epoch5_step587435.pth",
    # Conversation / decoding knobs.
    max_history=10,
    no_sample=True,
    max_length=100,
    min_length=1,
    seed=39,
    temperature=0.7,
    top_k=100,
    top_p=0.0,  # author's recommendation: 0 makes "I don't know" answers less likely
    device="cuda" if torch.cuda.is_available() else "cpu",
    force_answer=False,  # True => discard "I don't know" replies and take the next-best prediction
)

args

Namespace(device='cuda', force_answer=False, max_history=10, max_length=100, min_length=1, model_checkpoint_dir='/ssd003/home/shirleyw/ConversationalAI/dialomed_checkpoints/', model_checkpoint_file='/checkpoint/shirleyw/dialogpt_med_marco/checkpoint_epoch5_step587435.pth', no_sample=True, seed=39, temperature=0.7, top_k=100, top_p=0.0)

In [3]:
# Logging at INFO level, then record the run configuration.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("TestingGPT2")
logger.info(pformat(args))

# Seed Python's `random` and torch's RNGs (CPU and CUDA) so runs are reproducible.
for seed_fn in (random.seed, torch.random.manual_seed, torch.cuda.manual_seed):
    seed_fn(args.seed)

INFO:TestingGPT2:Namespace(device='cuda', force_answer=False, max_history=10, max_length=100, min_length=1, model_checkpoint_dir='/ssd003/home/shirleyw/ConversationalAI/dialomed_checkpoints/', model_checkpoint_file='/checkpoint/shirleyw/dialogpt_med_marco/checkpoint_epoch5_step587435.pth', no_sample=True, seed=39, temperature=0.7, top_k=100, top_p=0.0)


In [5]:
logger.info("Get pretrained model and tokenizer")

# Initializing GPT2 Tokenizer from the files stored in the checkpoint directory.
tokenizer = GPT2Tokenizer.from_pretrained(args.model_checkpoint_dir)

# Initializing pretrained model: architecture from config.json, weights from the .pth file.
# os.path.join works whether or not model_checkpoint_dir ends with a slash.
config = GPT2Config.from_json_file(os.path.join(args.model_checkpoint_dir, 'config.json'))
# map_location="cpu" lets a checkpoint saved from GPU load on a CPU-only host
# (args.device falls back to "cpu"); the model is moved to args.device below.
# NOTE: torch.load unpickles the file — only load checkpoints you trust.
state_dict = torch.load(args.model_checkpoint_file, map_location="cpu")
# Training checkpoints may wrap the weights under a 'model' key.
if 'model' in state_dict:
    state_dict = state_dict['model']
model = GPT2LMHeadModel.from_pretrained(args.model_checkpoint_file, config=config, state_dict=state_dict)
del state_dict  # drop our reference to the raw weights dict
model.to(args.device)
model.eval()  # inference mode (disables dropout)

# add our special tokens to the model
add_special_tokens(model, tokenizer)

INFO:TestingGPT2:Get pretrained model and tokenizer
Some weights of the model checkpoint at /checkpoint/shirleyw/dialogpt_med_marco/checkpoint_epoch5_step587435.pth were not used when initializing GPT2LMHeadModel: ['multiple_choice_head.summary.weight', 'multiple_choice_head.summary.bias']
- This IS expected if you are initializing GPT2LMHeadModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing GPT2LMHeadModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [6]:
import argparse
from DataLoader.dataloader import *

# Data-loading settings for the MS MARCO validation split. Checkpoint paths and the
# device come from `args` so the two configurations cannot drift apart (a hard-coded
# "cuda" would break when args.device fell back to CPU).
dataargs = argparse.Namespace(
    output_name="dialomed_marco_epoch5_evalresults.json",
    log_dir=args.model_checkpoint_dir,
    checkpoint=args.model_checkpoint_file,
    valid_dataset_path="/ssd003/projects/aieng/conversational_ai/data/MSMARCO/marco_valid_tokenized.json",
    dataset_type="marco",
    device=args.device,
    max_len=1024,
    valid_batch_size=1,  # cells below only inspect batch element 0
    second_loss="mc",  # presumably the multiple-choice head (batches carry mc_labels) — confirm in dataloader
)

valid_loader = get_validation_dataloader(dataargs, tokenizer)

Building datasets
Loading Dataset
Building inputs and labels


100661it [00:43, 2330.41it/s]

Building dataloaders
Valid Dataset Length: 60100





In [23]:
# Decoding settings for sample_sequence_tokens on the validation example below.
# Same decoding values as `args` except max_length (200 here vs. 100); no checkpoint
# paths or max_history are carried over.
inference_args = argparse.Namespace()
inference_args.no_sample = True
inference_args.max_length = 200
inference_args.min_length = 1
inference_args.seed = 39
inference_args.temperature = 0.7
inference_args.top_k = 100
inference_args.top_p = 0.0
inference_args.device = args.device
inference_args.force_answer = False

In [7]:
# Take the first validation batch to inspect a single example.
step = 0  # index of the batch taken (kept from the original loop form)
data = next(iter(valid_loader), None)
if data is None:
    raise ValueError("valid_loader yielded no batches")

In [20]:
# Split batch element 0 into the model input (everything before the reply) and the
# reference reply. Assumes the tensors are indexed (batch, candidate, position) with
# mc_labels selecting the gold candidate — TODO confirm against get_validation_dataloader.
# list(...) is kept (not .tolist()) so sample_sequence_tokens receives the same
# element type as before.
gold_idx = int(data['mc_labels'][0])
reply_start = int(data['reply_start'][0])
gt_tokens = list(data['input_ids'][0, gold_idx, reply_start:])
input_ids = list(data['input_ids'][0, gold_idx, :reply_start])
token_type_ids = list(data['token_type_ids'][0, gold_idx, :reply_start])

In [21]:
# Display the model input: the context portion, up to where the reply starts.
tokenizer.decode(token_ids=input_ids)

"<bos> 1: a government-owned corporation (as a utility or railroad) engaged in a profit-making enterprise that may require the exercise of powers unique to government (as eminent domain) — called also government corporation, publicly held corporationExamples of corporation in a Sentence. 1  He works as a consultant for several large corporations. 2  a substantial corporation that showed that he was a sucker for all-you-can-eat buffets.McDonald's Corporation is one of the most recognizable corporations in the world. A corporation is a company or group of people authorized to act as a single entity (legally a person) and recognized as such in law. Early incorporated entities were established by charter (i.e. by an ad hoc act granted by a monarch or passed by a parliament or legislature).Corporation definition, an association of individuals, created by law or under authority of law, having a continuous existence independent of the existences of its members, and powers and liabilities dist

In [22]:
# Display the reference reply that the generation below should approximate.
tokenizer.decode(token_ids=gt_tokens)

'<speaker2> A corporation is a company or group of people authorized to act as a single entity and recognized as such in law. <eos>'

In [24]:
# Generate a reply for the validation example; gradients are disabled for inference.
with torch.set_grad_enabled(False):
    output = sample_sequence_tokens(input_ids, token_type_ids, tokenizer, model, inference_args)

In [26]:
# Display the generated reply, special tokens included.
tokenizer.decode(token_ids=output)

'<speaker2> A corporation is an association of individuals, created by law or under authority of law.'

In [7]:
# Hand-written passage used as grounding context for the interactive Q&A cell below.
# Example question to try: "My TCard is damaged, do I have to pay for a replacement?"
context = [
    """If your card is damaged, you will be required to pay the $20 replacement fee (debit or credit only). Damage can include use of stickers, hole punching, significant wear on the card within a short period of time, damage to the magnetic stripe or image. A new photo will be taken. Meal plan, library, building access, and athletic services will be active on your new card within 24 hours, or the Monday following a card replacement on Friday.
"""
]

print(context)

['If your card is damaged, you will be required to pay the $20 replacement fee (debit or credit only). Damage can include use of stickers, hole punching, significant wear on the card within a short period of time, damage to the magnetic stripe or image. A new photo will be taken. Meal plan, library, building access, and athletic services will be active on your new card within 24 hours, or the Monday following a card replacement on Friday.\n']


In [None]:
# how many times do you want to ask questions
num_times = 10

# Each question is sent on its own: `history` is rebuilt every turn, so the model
# does not see earlier questions. NOTE(review): `context` here is a list of raw
# strings — presumably sample_sequence encodes it; confirm in utils.
for i in range(num_times):
    # Re-prompt until the user enters something other than whitespace.
    raw_text = input(">>> ").strip()
    while not raw_text:
        print('Prompt should not be empty!')
        raw_text = input(">>> ").strip()
    history = [tokenizer.encode('<speaker1>' + raw_text)]
    with torch.no_grad():
        out_ids = sample_sequence(context, history, tokenizer, model, args)
    out_text = tokenizer.decode(out_ids, skip_special_tokens=True)

    print("Out text:", out_text)

>>> My TCard is damaged, do I have to pay for a replacement?
Out text: If your card is damaged, you will be required to pay the $20 replacement fee. If it is defective, you will be required to pay the $20 replacement fee within a short period of time. A new photo will be taken. Meal plan, library, building access, and athletic services will be active on your new card within 24 hours, or the Monday following a card replacement on Friday. A few tips on taking care of your TCard are as follows: Your TCard should be carried


In [6]:
# Get F1 Score of prediction: build one hand-written (context, question, answer)
# example whose tokens are scored with get_f1_score further below.
passage = """If your card is damaged, you will be required to pay the $20 replacement fee (debit or credit only). Damage can include use of stickers, hole punching, significant wear on the card within a short period of time, damage to the magnetic stripe or image. A new photo will be taken. Meal plan, library, building access, and athletic services will be active on your new card within 24 hours, or the Monday following a card replacement on Friday.
"""
question_text = "When will I get building access on my new card?"
answer_text = "You will get building access on your new card within 24 hours, or the Monday following a card replacement on Friday."

# Token-id versions consumed by build_input_from_segments.
context = [tokenizer.encode(passage)]
history = [tokenizer.encode('<speaker1>' + question_text)]
answer = tokenizer.encode('<speaker2>' + answer_text)

# NOTE(review): the two positional True flags are defined by build_input_from_segments
# (brought in by a star import) — confirm their meaning there.
data = build_input_from_segments(context, history, answer, tokenizer, True, True)
tokenizer.decode(data['input_ids'])

In [8]:
# pass through model: a single forward pass over the full sequence, reference answer
# included, so every position is conditioned on the ground-truth prefix.
# torch.tensor(..., dtype=torch.long) builds integer tensors directly on the device,
# avoiding the float32 round-trip of torch.Tensor(...).type(torch.LongTensor).
input_ids = torch.tensor(data['input_ids'], dtype=torch.long, device=args.device)
lm_labels = torch.tensor(data['lm_labels'], dtype=torch.long, device=args.device)
token_type_ids = torch.tensor(data['token_type_ids'], dtype=torch.long, device=args.device)
# Answer-span boundaries; presumably indices into input_ids (see the decode below).
start_index = torch.tensor(data['start_idx'])
end_index = torch.tensor(data['mc_token_ids'])

# Pure inference: no_grad avoids building an autograd graph for the whole sequence.
with torch.no_grad():
    output = model(
        input_ids, token_type_ids=token_type_ids
    )

lm_logits = output.logits

In [23]:
# Compare the reference answer with a rough prediction: the argmax token at each
# position of the logits. Logits at position t predict token t+1, hence the
# one-position shift between the two slices.
print("Actual Answer:", tokenizer.decode(input_ids[start_index+1:end_index+1]))
print("Rough Prediction:", tokenizer.decode(torch.argmax(lm_logits[start_index:end_index], dim=1)))

# Reuse the span tensors from the forward-pass cell instead of rebuilding them.
get_f1_score(input_ids, lm_logits, start_index, end_index, tokenizer)

Actual Answer: You will get building access on your new card within 24 hours, or the Monday following a card replacement on Friday. <eos>
Prediction: You will get building access on your new card within 24 hours. or the Monday following a card replacement on Friday. <eos>


0.85

Note that this F1 score tends to come out much higher than the model's real answer quality in our scenario. To compute it, I take the argmax at every position of the logits from a single forward pass over the full sequence, so each position is conditioned on the ground-truth answer prefix. Actual response generation instead produces one token at a time, conditioned on the model's own previous outputs. As a result, the model very often predicts an `<eos>` right after a period during generation and stops there. That early stop is invisible here: the argmax at each position shows what the model would predict if it had kept going. Tokens are therefore more likely to match in this F1 score than in a response that is actually generated.