In [1]:
import os
import random
import time
import pickle
import math
from argparse import ArgumentParser
from collections import namedtuple

from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertTokenizerFast, AutoModelForSeq2SeqLM, AutoModelForSequenceClassification

from data import Dataset
from model import Model
from util import save_checkpoint, ProgressMeter, AverageMeter, num_params
from constants import *
from predict_factuality import predict_factuality

In [2]:
# Hugging Face checkpoint identifiers for the generator and the conditioning
# (attribute) model. Both models are bert-base-uncased and share the same
# tokenizer, so a single tokenizer is loaded from `model_string`.
model_string = 'patrickvonplaten/bert2bert_cnn_daily_mail'
attribute_model_string = 'textattack/bert-base-uncased-MNLI'
# Prefer the GPU, but fall back to the CPU so the notebook still runs on
# machines without CUDA instead of failing at the first `.to(device)` call.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# When True, the model-loading cell also prints the parameter counts.
verbose = True

In [3]:
# The tokenizer is taken from the generator checkpoint and reused for both models.
tokenizer = BertTokenizerFast.from_pretrained(model_string)

# Generator: load the pre-trained seq2seq weights, move them to the target
# device, and switch to evaluation mode.
print(f"Loading pre-trained model: {model_string}...")
model = AutoModelForSeq2SeqLM.from_pretrained(model_string, return_dict=True)
model = model.to(device)
model.eval()

# Conditioning model: same three steps for the sequence classifier.
print(f"Loading pre-trained conditioning model: {attribute_model_string}...")
conditioning_model = AutoModelForSequenceClassification.from_pretrained(attribute_model_string)
conditioning_model = conditioning_model.to(device)
conditioning_model.eval()

# No separate checkpoint is loaded in this cell; both models come straight
# from `from_pretrained`. Optionally report their sizes.
if verbose:
    print(f"model num params {num_params(model)}")
    print(f"conditioning_model num params {num_params(conditioning_model)}")

Loading pre-trained model: patrickvonplaten/bert2bert_cnn_daily_mail...
Loading pre-trained conditioning model: textattack/bert-base-uncased-MNLI...
model num params 247363386
conditioning_model num params 109484547


In [4]:
# Read the source documents: one entry per line of the file, with leading and
# trailing whitespace stripped.
with open('factuality_data/dummy_input.txt', 'r', encoding='utf-8') as rf:
    inputs = [line.strip() for line in rf]

In [None]:
# Decoding options passed unchanged to every `predict_factuality` call.
generation_settings = dict(
    precondition_topk=200,
    do_sample=False,
    min_length=30,
    max_length=90,
    condition_lambda=1.0,
    device=device,
)

# Process the inputs one at a time (each call receives a single-element list)
# and print the first returned result for that input.
for inp in tqdm(inputs, total=len(inputs)):
    results = predict_factuality(
        model,
        tokenizer,
        conditioning_model,
        [inp],
        **generation_settings,
    )
    print(results[0])

  0%|          | 0/1 [00:00<?, ?it/s]

> [0;32m/notebooks/fudge/predict_factuality.py[0m(129)[0;36m_generate_no_beam_search[0;34m()[0m
[0;32m    127 [0;31m            [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m    128 [0;31m[0;34m[0m[0m
[0m[0;32m--> 129 [0;31m            expanded_lengths = torch.LongTensor(
[0m[0;32m    130 [0;31m                [[cur_len for _ in range(precondition_topk)] for _ in range(batch_size)]).to(scores.device)
[0m[0;32m    131 [0;31m[0;34m[0m[0m
[0m


ipdb>  tplus1_candidates.shape


torch.Size([1, 200, 169])


ipdb>  encoder_input_ids.shape


torch.Size([1, 167])


ipdb>  tplus1_candidate[0][0].shape


*** NameError: name 'tplus1_candidate' is not defined


ipdb>  tplus1_candidates[0,0,:].shape


torch.Size([169])


ipdb>  print(tplus1_candidates[0,0,:])


tensor([  101,  1996,  3578,  2003, 27234,  3620,  1006,  1015,  1010,  5757,
         2509,  3027,  1007,  4206,  1010,  2055,  1996,  2168,  4578,  2004,
         2019,  6282,  1011, 11676,  2311,  1010,  1998,  1996, 13747,  3252,
         1999,  3000,  1012,  2049,  2918,  2003,  2675,  1010,  9854,  8732,
         3620,  1006, 19151,  3027,  1007,  2006,  2169,  2217,  1012,  2076,
         2049,  2810,  1010,  1996,  1041, 13355,  2884,  3578, 15602,  1996,
         2899,  6104,  2000,  2468,  1996, 13747,  2158,  1011,  2081,  3252,
         1999,  1996,  2088,  1010,  1037,  2516,  2009,  2218,  2005,  4601,
         2086,  2127,  1996, 17714,  2311,  1999,  2047,  2259,  2103,  2001,
         2736,  1999,  4479,  1012,  2009,  2001,  1996,  2034,  3252,  2000,
         3362,  1037,  4578,  1997,  3998,  3620,  1012,  2349,  2000,  1996,
         2804,  1997,  1037,  5062,  9682,  2012,  1996,  2327,  1997,  1996,
         3578,  1999,  3890,  1010,  2009,  2003,  2085, 12283, 

ipdb>  print(tplus1_candidates[0,1,:])


tensor([  101,  1996,  3578,  2003, 27234,  3620,  1006,  1015,  1010,  5757,
         2509,  3027,  1007,  4206,  1010,  2055,  1996,  2168,  4578,  2004,
         2019,  6282,  1011, 11676,  2311,  1010,  1998,  1996, 13747,  3252,
         1999,  3000,  1012,  2049,  2918,  2003,  2675,  1010,  9854,  8732,
         3620,  1006, 19151,  3027,  1007,  2006,  2169,  2217,  1012,  2076,
         2049,  2810,  1010,  1996,  1041, 13355,  2884,  3578, 15602,  1996,
         2899,  6104,  2000,  2468,  1996, 13747,  2158,  1011,  2081,  3252,
         1999,  1996,  2088,  1010,  1037,  2516,  2009,  2218,  2005,  4601,
         2086,  2127,  1996, 17714,  2311,  1999,  2047,  2259,  2103,  2001,
         2736,  1999,  4479,  1012,  2009,  2001,  1996,  2034,  3252,  2000,
         3362,  1037,  4578,  1997,  3998,  3620,  1012,  2349,  2000,  1996,
         2804,  1997,  1037,  5062,  9682,  2012,  1996,  2327,  1997,  1996,
         3578,  1999,  3890,  1010,  2009,  2003,  2085, 12283, 

In [None]:
# Example sentences for experimenting with paired tokenization.
sequence_0 = "The company HuggingFace is based in New York City. Google is located in Mount Hill."
sequence_1 = "Apples are especially bad for your health"
sequence_2 = "HuggingFace's headquarters are situated in New York"

# Encode the (sequence_0, sequence_2) pair in both orders as a padded batch
# of PyTorch tensors, truncated to at most 512 tokens.
first_segments = [sequence_0, sequence_2]
second_segments = [sequence_2, sequence_0]
tokenized_data = tokenizer(
    first_segments,
    second_segments,
    padding=True,
    truncation=True,
    max_length=512,
    return_tensors="pt",
)

In [None]:
# Hard-coded sequence of 169 token ids to turn back into text.
# NOTE(review): presumably copied from the `tplus1_candidates[0, 0, :]` dump in
# the debugger session earlier in this notebook (same length and same leading
# ids) — confirm before relying on it.
candidate_token_ids = [
    101, 1996, 3578, 2003, 27234, 3620, 1006, 1015, 1010, 5757,
    2509, 3027, 1007, 4206, 1010, 2055, 1996, 2168, 4578, 2004,
    2019, 6282, 1011, 11676, 2311, 1010, 1998, 1996, 13747, 3252,
    1999, 3000, 1012, 2049, 2918, 2003, 2675, 1010, 9854, 8732,
    3620, 1006, 19151, 3027, 1007, 2006, 2169, 2217, 1012, 2076,
    2049, 2810, 1010, 1996, 1041, 13355, 2884, 3578, 15602, 1996,
    2899, 6104, 2000, 2468, 1996, 13747, 2158, 1011, 2081, 3252,
    1999, 1996, 2088, 1010, 1037, 2516, 2009, 2218, 2005, 4601,
    2086, 2127, 1996, 17714, 2311, 1999, 2047, 2259, 2103, 2001,
    2736, 1999, 4479, 1012, 2009, 2001, 1996, 2034, 3252, 2000,
    3362, 1037, 4578, 1997, 3998, 3620, 1012, 2349, 2000, 1996,
    2804, 1997, 1037, 5062, 9682, 2012, 1996, 2327, 1997, 1996,
    3578, 1999, 3890, 1010, 2009, 2003, 2085, 12283, 2084, 1996,
    17714, 2311, 2011, 1019, 1012, 1016, 3620, 1006, 2459, 3027,
    1007, 1012, 13343, 26288, 1010, 1996, 1041, 13355, 2884, 3578,
    2003, 1996, 2117, 13747, 2489, 1011, 3061, 3252, 1999, 2605,
    2044, 1996, 4971, 4887, 20596, 1012, 102, 1996, 102,
]
# Last expression of the cell, so the decoded string is displayed as its output.
tokenizer.decode(candidate_token_ids)