In [1]:
import pandas as pd
import os
import json
import re
from datasets import Dataset
from torch.utils.data import DataLoader

from transformers import T5ForConditionalGeneration, T5Tokenizer, MT5ForConditionalGeneration, MT5Tokenizer
from transformers import Seq2SeqTrainingArguments
from transformers import DataCollatorForSeq2Seq
from transformers import Seq2SeqTrainer
# Keep Weights & Biases quiet and switched off for this inference-only notebook.
os.environ.update({"WANDB_SILENT": "true", "WANDB_DISABLED": "true"})



In [2]:
import pandas as pd
import os
import json
import re
from datasets import Dataset

def file_loader(json_path):
    """Read a JSON file and return its parsed contents.

    Args:
        json_path: Path of the JSON file to read.

    Returns:
        The Python object the JSON document decodes to (dict, list, ...).

    Raises:
        OSError: If the file cannot be opened.
        UnicodeDecodeError: If the file is not valid UTF-8.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    # Decode as UTF-8 explicitly: the platform default encoding is not
    # guaranteed to be UTF-8 and could corrupt or reject non-ASCII text.
    with open(json_path, encoding="utf-8") as f:
        return json.load(f)

# Test split of the numerical-reasoning data, stored under the local "Dataset" folder.
numerical_reasoning_test = "Test_Numerical_Reasoning.json"
numerical_data_test_path = os.path.join("Dataset", numerical_reasoning_test)

# Parse the JSON file and wrap the parsed content in a DataFrame.
numerical_data_test = file_loader(numerical_data_test_path)
df_test = pd.DataFrame.from_dict(numerical_data_test)



In [3]:
# Snapshot the freshly loaded frame to CSV, without the index column.
# This runs before the derived `context` / `t5-input` columns are added.
df_test.to_csv('test.csv',index=False)

In [4]:
# Remove the first parenthesised span of each article and trim the surrounding
# whitespace. In the rows shown in this notebook that span is the leading
# "(Dec 21, 2011 5:52 PM)"-style timestamp.
# `count` is passed by keyword: giving it positionally to re.sub is deprecated
# since Python 3.13.
df_test['context'] = df_test['news'].apply(
    lambda text: re.sub(r'\([^)]*\)', '', text, count=1).strip()
)

In [5]:
def process_input(item):
    """Build the model prompt for one example.

    Args:
        item: Mapping-like row providing the keys ``"masked headline"`` and
            ``"news"``.

    Returns:
        The article text, a blank line, and then a "Fill in the blank"
        instruction carrying the masked headline.
    """
    masked = item["masked headline"]
    article = item['news']
    return f"{article}\n\n Fill in the blank: {masked}"

In [6]:
# Build the prompt for every row; process_input is called once per row (axis=1).
df_test['t5-input'] = df_test.apply(process_input, axis=1)


In [7]:
# Wrap the prepared frame in a `datasets.Dataset`; the tokenisation and
# inference cells read from `test`.
test = Dataset.from_pandas(df_test)

In [8]:
# test.save_to_disk('./Dataset/test-final')

In [9]:
def collator(batch):
    """Tokenize the prompts of a batch.

    Args:
        batch: Mapping with a ``'t5-input'`` entry holding the prompt text(s).

    Returns:
        The output of the notebook-level ``tokenizer``, called with PyTorch
        tensors requested and with padding to / truncation at 512 tokens.
    """
    prompts = batch['t5-input']
    return tokenizer(
        prompts,
        return_tensors="pt",
        max_length=512,
        padding='max_length',
        truncation=True,
    )

In [10]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch

# The tokenizer comes from the base checkpoint while the weights come from the
# local output directory.
# NOTE(review): presumably the tokenizer files were not saved with the weights — confirm.
model_id="google/flan-t5-small"
saved_model = "./Outputs/Trial-v2"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSeq2SeqLM.from_pretrained(saved_model)

# Use the GPU when one is present; otherwise stay on the CPU instead of
# crashing on an unconditional `.cuda()` call.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

T5ForConditionalGeneration(
  (shared): Embedding(32128, 512)
  (encoder): T5Stack(
    (embed_tokens): Embedding(32128, 512)
    (block): ModuleList(
      (0): T5Block(
        (layer): ModuleList(
          (0): T5LayerSelfAttention(
            (SelfAttention): T5Attention(
              (q): Linear(in_features=512, out_features=384, bias=False)
              (k): Linear(in_features=512, out_features=384, bias=False)
              (v): Linear(in_features=512, out_features=384, bias=False)
              (o): Linear(in_features=384, out_features=512, bias=False)
              (relative_attention_bias): Embedding(32, 6)
            )
            (layer_norm): T5LayerNorm()
            (dropout): Dropout(p=0.1, inplace=False)
          )
          (1): T5LayerFF(
            (DenseReluDense): T5DenseGatedActDense(
              (wi_0): Linear(in_features=512, out_features=1024, bias=False)
              (wi_1): Linear(in_features=512, out_features=1024, bias=False)
              (wo): 

In [11]:
# Tokenize the whole test set in batches of 8 across 4 worker processes,
# dropping every original column so only the tokenizer outputs remain.
test_tokenized = test.map(
    collator,
    batched=True,
    batch_size=8,
    num_proc=4,
    remove_columns=test.column_names,
)

Map (num_proc=4):   0%|          | 0/4921 [00:00<?, ? examples/s]

In [12]:
from tqdm import tqdm

# Iterate over the test set in batches of 8; each batch is a dict whose
# 't5-input' entry holds the prompt strings.
dataloader = DataLoader(test, batch_size=8)

#perform inference
predictions = []
for data in tqdm(dataloader):
    # Pad only up to the longest prompt in the batch (still capped at 512
    # tokens) rather than always to 512, so the encoder does not process
    # padding that carries no information.
    inputs = tokenizer(
        data['t5-input'],
        return_tensors="pt",
        max_length=512,
        padding='longest',
        truncation=True,
    )
    # Follow the model's device rather than assuming CUDA is available.
    inputs = inputs.to(model.device)
    # Pass the attention mask explicitly so padded positions are ignored,
    # instead of relying on generate() to infer the mask from pad tokens.
    output_ids = model.generate(
        input_ids=inputs['input_ids'],
        attention_mask=inputs['attention_mask'],
        max_length=512,
    )
    predictions.extend(tokenizer.batch_decode(output_ids, skip_special_tokens=True))


100%|█████████████████████████████████████████| 616/616 [01:35<00:00,  6.42it/s]


In [13]:
# Start an empty frame; the ID and prediction columns are attached in the cells that follow.
output = pd.DataFrame()

In [14]:
# Copy the example ids from the dataset into the output frame.
output['ID'] = test['id']

In [15]:
# Attach the decoded generations collected by the inference loop.
# NOTE(review): relies on `predictions` having one entry per row of `test`, in the same order — confirm.
output['Pred'] = predictions

In [16]:
# Write the id/prediction pairs to CSV without the index column.
output.to_csv('Small-preds.csv',index=False)

In [20]:
# Display the test frame, including the derived `context` and `t5-input` columns.
df_test

Unnamed: 0,news,masked headline,id,context,t5-input
0,"(Dec 21, 2011 5:52 PM) A burglar made a huge ...",IT Guy Foils Burglar From ____ Blocks Away,0,A burglar made a huge mistake yesterday when h...,"(Dec 21, 2011 5:52 PM) A burglar made a huge ..."
1,"(Dec 22, 2013 3:00 PM) You probably felt pret...",8 Stars Who Hit ____ This Year,1,You probably felt pretty old when you learned ...,"(Dec 22, 2013 3:00 PM) You probably felt pret..."
2,"(Apr 23, 2014 4:53 AM CDT) In a case most com...",3 Text Messages to Fan Cost Buffalo Bills $____M,2,In a case most commentators are calling frivol...,"(Apr 23, 2014 4:53 AM CDT) In a case most com..."
3,"(May 10, 2016 5:57 AM CDT) Ryan Gosling and E...","Gosling, Mendes Had Secret Baby No. ____",3,Ryan Gosling and Eva Mendes have pulled off wh...,"(May 10, 2016 5:57 AM CDT) Ryan Gosling and E..."
4,"(Jan 11, 2010 7:41 AM) Gay marriage will have...",Landmark Prop ____ Trial to Begin Today,4,Gay marriage will have its day in federal cour...,"(Jan 11, 2010 7:41 AM) Gay marriage will have..."
...,...,...,...,...,...
4916,"(Sep 9, 2012 9:55 AM CDT) If you're looking f...",London's 'WeirWolf' Wins ____th Gold,4995,If you're looking for the star of London's Par...,"(Sep 9, 2012 9:55 AM CDT) If you're looking f..."
4917,"(Apr 22, 2010 1:21 PM CDT) The Coast Guard is...",La. Oil Rig Sinks; ____ Still Missing,4996,The Coast Guard is saying that an oil platform...,"(Apr 22, 2010 1:21 PM CDT) The Coast Guard is..."
4918,"(Aug 27, 2012 12:03 AM CDT) Some 300 earthqua...",____-Quake 'Storm' Rattles California,4997,Some 300 earthquakes shook up Southern Califor...,"(Aug 27, 2012 12:03 AM CDT) Some 300 earthqua..."
4919,"(Dec 8, 2008 5:40 PM) Under a wave of critici...",Merrill CEO Backs Off $____M Bonus Request,4998,"Under a wave of criticism, Merrill Lynch CEO J...","(Dec 8, 2008 5:40 PM) Under a wave of critici..."


In [22]:
# Show the full, untruncated article text of the third row.
df_test['news'].iloc[2]

"(Apr 23, 2014  4:53 AM CDT) In a case most commentators are calling frivolous, finicky, or just plain silly, the Buffalo Bills have coughed up $3 million to settle a lawsuit from a fan who complained that the team sent him too many text messages. The fan brought the class-action suit after he signed up for the team's news service, which promised three to five texts a week, but then received six messages one week and seven the next, reports the Buffalo News. The fan claimed that the extra messages violated the Federal Telephone Consumer Protection Act and sought damages of around $2,000 per message. Under the terms of the settlement, the tens of thousands of fans who signed up for the now-defunct service will receive a total of around $2.5 million in the form of debit cards redeemable only at the team's store, the New York Post reports. The fan's lawyers will receive more than $500,000 under the settlement and he will get $5,000."