In [1]:
import math
import os
import pickle
import random
import time
from argparse import ArgumentParser
from collections import namedtuple

import datasets
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from datasets import load_metric
from tqdm import tqdm
from tqdm.notebook import tnrange
from transformers import BertTokenizerFast, AutoModelForSeq2SeqLM, AutoModelForSequenceClassification

from data import Dataset
from model import Model
from util import save_checkpoint, ProgressMeter, AverageMeter, num_params
from constants import *
from predict_factuality import predict_factuality

In [2]:
# Both models are bert-base-uncased and share the same tokenizer.
model_string = 'patrickvonplaten/bert2bert_cnn_daily_mail'  # seq2seq summarizer being decoded
attribute_model_string = 'textattack/bert-base-uncased-MNLI'  # MNLI classifier loaded as the conditioning model
# Fall back to CPU when no GPU is visible instead of failing at the first `.to(device)`.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
verbose = True  # print parameter counts after the models are loaded

In [3]:
# Load the summarizer, its tokenizer (also used for the conditioning model, see config cell)
# and the MNLI conditioning model; both models are put in eval mode for inference.
print(f"Loading pre-trained model: {model_string}...")
tokenizer = BertTokenizerFast.from_pretrained(model_string)
model = AutoModelForSeq2SeqLM.from_pretrained(model_string, return_dict=True).to(device)
model.eval()  # inference only (e.g. disables dropout)

print(f"Loading pre-trained conditioning model: {attribute_model_string}...")
conditioning_model = AutoModelForSequenceClassification.from_pretrained(attribute_model_string).to(device)
conditioning_model.eval()

if verbose:
    print(f"model num params {num_params(model)}")
    print(f"conditioning_model num params {num_params(conditioning_model)}")

Loading pre-trained model: patrickvonplaten/bert2bert_cnn_daily_mail...
Loading pre-trained conditioning model: textattack/bert-base-uncased-MNLI...
model num params 247363386
conditioning_model num params 109484547


In [9]:
# Read one input document per line. Blank or whitespace-only lines (e.g. a trailing
# empty line) are skipped so no empty string gets sent to the summarizer.
with open('factuality_data/dummy_input.txt', 'r', encoding='utf-8') as rf:
    inputs = [line.strip() for line in rf if line.strip()]

In [5]:
# Summarize each dummy input on its own (one-element input list per call) with
# condition_lambda=0.5, printing the first entry of the returned results.
for inp in tqdm(inputs):  # tqdm takes the total from len(inputs)
    results = predict_factuality(
        model,
        tokenizer,
        conditioning_model,
        [inp],
        precondition_topk=200,
        do_sample=False,
        min_length=30,
        max_length=90,
        condition_lambda=0.5,
        device=device,
    )
    print(results[0])

100%|██████████| 1/1 [00:59<00:00, 59.62s/it]

tallest structure in paris is the eiffel tower, which was built by the french tallest building. it is now taller than the chrysler building and has been used since 1957.





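## Try the real CNN/DailyMail dataset

Summarize one validation article with `condition_lambda=0.0` and with `condition_lambda=1.0` to compare the effect of the factuality conditioning.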

In [4]:
# Small evaluation slice: the first 1% of the CNN/DailyMail v3.0.0 validation split.
val_data = datasets.load_dataset("cnn_dailymail", "3.0.0", split="validation[:1%]")

Reusing dataset cnn_dailymail (/root/.cache/huggingface/datasets/cnn_dailymail/3.0.0/3.0.0/0128610a44e10f25b4af6689441c72af86205282d26399642f7db38fa7535602)


In [5]:
# Pick a single validation record; its article text is the input for the
# side-by-side conditioning runs below. The full record is displayed.
example = val_data[3]
inp = str(example["article"])
example

{'article': '(CNN)Ferguson is crumbling. The cowardly and reprehensible shooting Wednesday night of two police officers came in a tumultuous seven days for the Missouri town, which had already seen Ferguson Police Chief Thomas Jackson announce his resignation after a damning Justice Department report on its police department. The report, which was ordered in the wake of the killing of Michael Brown last year, highlighted a predatory policing problem and a department that was biased, prejudiced and that has regularly targeted, arrested and fined African-Americans. Residents understandably want justice. But what\'s worse in all this is that Ferguson is illustrative of a broader problem across the country as increasingly militarized majority-white police departments demonstrate consistent racial bias toward majority-black communities. It\'s a combustible mix. In three-quarters of all U.S. cities with populations 50,000 or more, the police presence is "disproportionately white relative to 

In [None]:
results = predict_factuality(model,
                    tokenizer, 
                    conditioning_model, 
                    [inp],
                    precondition_topk=200,
                    do_sample=False,
                    min_length=30,
                    max_length=90,
                    condition_lambda=0.0,
                    device=device)
results

In [None]:
%%time
results = predict_factuality(model,
                    tokenizer, 
                    conditioning_model, 
                    [inp],
                    precondition_topk=200,
                    do_sample=False,
                    min_length=30,
                    max_length=90,
                    condition_lambda=1.0,
                    device=device)
results

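### Baseline predictions and ROUGE evaluation

Each article in the validation slice is summarized individually (one call per article) with `condition_lambda=0.0`, and ROUGE-1/2/L are computed against the reference highlights.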

In [17]:
# Baseline (condition_lambda=0.0) summaries for every article in the validation slice,
# generated one article per call, paired with the reference highlights for ROUGE.
ground_truths, baseline_preds = [], []
for i in tnrange(len(val_data)):
    example = val_data[i]
    ground_truths.append(example['highlights'])
    inp = str(example['article'])
    pred = predict_factuality(
        model,
        tokenizer,
        conditioning_model,
        [inp],
        precondition_topk=200,
        do_sample=False,
        min_length=30,
        max_length=90,
        condition_lambda=0.0,
        device=device,
    )[0]
    baseline_preds.append(pred)

HBox(children=(FloatProgress(value=0.0, max=134.0), HTML(value='')))




In [18]:
# ROUGE scorer used to evaluate the summaries.
# NOTE(review): newer `datasets` releases deprecate `load_metric` in favour of
# `evaluate.load('rouge')`; confirm against the pinned `datasets` version.
metric = load_metric('rouge')

In [20]:
# ROUGE-1/2/L of the baseline summaries against the reference highlights; each entry
# reports low/mid/high bounds for precision, recall and F-measure (see output).
metric.compute(
    references=ground_truths,
    predictions=baseline_preds,
    rouge_types=['rouge1', 'rouge2', 'rougeL'],
)

{'rouge1': AggregateScore(low=Score(precision=0.2973886056202959, recall=0.29614087733837935, fmeasure=0.2884923559900645), mid=Score(precision=0.32652823393269714, recall=0.32132599812241225, fmeasure=0.31466521619878635), high=Score(precision=0.3562696664088422, recall=0.3496707382447767, fmeasure=0.34034349993756213)),
 'rouge2': AggregateScore(low=Score(precision=0.10140779724122802, recall=0.10000367038351103, fmeasure=0.09795803188875224), mid=Score(precision=0.12738807145162467, recall=0.12617802620784818, fmeasure=0.1239448374378738), high=Score(precision=0.15785803937875417, recall=0.1566275179534468, fmeasure=0.15346734024983144)),
 'rougeL': AggregateScore(low=Score(precision=0.21395823140953238, recall=0.21507604393472093, fmeasure=0.20929645243874456), mid=Score(precision=0.23793875573848006, recall=0.23952351909790798, fmeasure=0.23212024089348565), high=Score(precision=0.26698111309424244, recall=0.2678733453458605, fmeasure=0.2600093795293905))}