In [1]:
import os
import random
import time
import pickle
import math
from argparse import ArgumentParser
from collections import namedtuple

from tqdm import tqdm
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import BertTokenizerFast, AutoModelForSeq2SeqLM, AutoModelForSequenceClassification

from data import Dataset
from model import Model
from util import save_checkpoint, ProgressMeter, AverageMeter, num_params
from constants import *
from predict_factuality import predict_factuality

In [2]:
# Hugging Face Hub checkpoints: a bert2bert seq2seq summarizer (CNN/DailyMail,
# per its name) and a BERT sequence classifier (MNLI, per its name) that is
# passed to predict_factuality as the conditioning model.
model_string = 'patrickvonplaten/bert2bert_cnn_daily_mail'
attribute_model_string = 'textattack/bert-base-uncased-MNLI'
# Fall back to CPU so the notebook still runs (slowly) on machines without a GPU;
# a hard-coded 'cuda' makes every later `.to(device)` call fail there.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# When True, print parameter counts after the models are loaded.
verbose = True

In [3]:
# Load the summarizer (with its tokenizer) and the sequence-classification model
# used for conditioning, move both to `device`, and switch them to eval mode:
# this notebook only runs inference, so dropout etc. must be disabled.
tokenizer = BertTokenizerFast.from_pretrained(model_string)
print(f"Loading pre-trained model: {model_string}...")
model = AutoModelForSeq2SeqLM.from_pretrained(model_string, return_dict=True).to(device)
model.eval()

print(f"Loading pre-trained conditioning model: {attribute_model_string}...")
conditioning_model = AutoModelForSequenceClassification.from_pretrained(attribute_model_string).to(device)
conditioning_model.eval()

if verbose:
    # Sanity check that the expected checkpoints were loaded.
    print(f"model num params {num_params(model)}")
    print(f"conditioning_model num params {num_params(conditioning_model)}")

Loading pre-trained model: patrickvonplaten/bert2bert_cnn_daily_mail...
Loading pre-trained conditioning model: textattack/bert-base-uncased-MNLI...
model num params 247363386
conditioning_model num params 109484547


In [4]:
# Newline-delimited source documents: each non-blank line is one input.
INPUT_PATH = 'factuality_data/dummy_input.txt'

# Strip surrounding whitespace from each line and drop blank lines (e.g. a
# trailing empty line at EOF), so the summarizer is never fed an empty string.
with open(INPUT_PATH, 'r', encoding='utf-8') as rf:
    inputs = [text for text in (line.strip() for line in rf) if text]

In [5]:
# Decoding settings forwarded to predict_factuality.
# NOTE(review): with condition_lambda=0.0 the conditioning model presumably has
# no influence on decoding (i.e. this is the unconditioned baseline) — confirm
# against predict_factuality before relying on these outputs.
GENERATION_KWARGS = dict(
    precondition_topk=200,
    do_sample=False,
    min_length=30,
    max_length=90,
    condition_lambda=0.0,
)

# Summarize each input independently (one-element batch) and print the first
# returned result. tqdm infers the total from the list's length.
for inp in tqdm(inputs):
    results = predict_factuality(model,
                                 tokenizer,
                                 conditioning_model,
                                 [inp],
                                 device=device,
                                 **GENERATION_KWARGS)
    # tqdm.write prints above the progress bar instead of clobbering it;
    # str() mirrors what print() would have done for non-str results.
    tqdm.write(str(results[0]))

100%|██████████| 1/1 [00:00<00:00,  1.14it/s]

the eiffel tower is the tallest structure in paris. it is the second tallest structure to reach a height of 300 metres. the tower is now taller than the chrysler building. [SEP]



