In [1]:
!pip install -qqq transformers datasets sentencepiece rouge_score

[K     |████████████████████████████████| 5.5 MB 5.4 MB/s 
[K     |████████████████████████████████| 451 kB 62.4 MB/s 
[K     |████████████████████████████████| 1.3 MB 48.7 MB/s 
[K     |████████████████████████████████| 7.6 MB 35.8 MB/s 
[K     |████████████████████████████████| 182 kB 56.9 MB/s 
[K     |████████████████████████████████| 212 kB 46.0 MB/s 
[K     |████████████████████████████████| 115 kB 47.4 MB/s 
[K     |████████████████████████████████| 127 kB 34.8 MB/s 
[?25h  Building wheel for rouge-score (setup.py) ... [?25l[?25hdone


In [2]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [3]:
!mkdir -p CausalQA/input
!cp -r /content/drive/MyDrive/CausalQA/input/* ./CausalQA/input/

In [4]:
!unzip ./CausalQA/input/original-splits.zip
!unzip ./CausalQA/input/random-splits.zip

Archive:  ./CausalQA/input/original-splits.zip
   creating: Webis-CausalQA-22-v-1.0/input/original-splits/
  inflating: Webis-CausalQA-22-v-1.0/input/original-splits/eli5_train_original_split.csv  
  inflating: Webis-CausalQA-22-v-1.0/input/original-splits/msmarco_valid_original_split.csv  
  inflating: Webis-CausalQA-22-v-1.0/input/original-splits/searchqa_train_original_split.csv  
  inflating: Webis-CausalQA-22-v-1.0/input/original-splits/newsqa_train_original_split.csv  
  inflating: Webis-CausalQA-22-v-1.0/input/original-splits/naturalquestions_valid_original_split.csv  
  inflating: Webis-CausalQA-22-v-1.0/input/original-splits/hotpotqa_valid_original_split.csv  
  inflating: Webis-CausalQA-22-v-1.0/input/original-splits/searchqa_valid_original_split.csv  
  inflating: Webis-CausalQA-22-v-1.0/input/original-splits/triviaqa_valid_original_split.csv  
  inflating: Webis-CausalQA-22-v-1.0/input/original-splits/naturalquestions_train_original_split.csv  
  inflating: Webis-CausalQA-2

In [5]:
import re
import string
from argparse import Namespace
from rouge_score import rouge_scorer, scoring
from collections import Counter
from datasets import load_dataset
import torch
from torch.utils.data import DataLoader
from transformers import (
    T5Tokenizer,
    T5ForConditionalGeneration,
    DataCollatorForSeq2Seq,
    set_seed,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)
from itertools import chain
from tqdm import tqdm

In [6]:
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [7]:
DEVICE

device(type='cpu')

In [8]:
def preprocess(str_: str) -> str:
    str_ = str_.lower()
    str_ = str_.translate(str.maketrans("", "", string.punctuation))
    str_ = re.sub(r"\b(a|an|the)\b", " ", str_)
    str_ = " ".join(str_.split())
    return str_

In [9]:
def _f1(pred: str, ground_truth: str) -> float:
    prediction_tokens = pred.split()
    ground_truth_tokens = ground_truth.split()
    common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
    num_same = sum(common.values())

    if num_same == 0:
        return 0
    
    precision = 1.0 * num_same / len(prediction_tokens)
    recall = 1.0 * num_same / len(ground_truth_tokens)
    f1 = (2 * precision * recall) / (precision + recall)
    return f1


In [10]:
def _em(pred: str, ground_truth: str) -> int:
    return int(pred == ground_truth)

In [11]:
def _calculate_measures(measure, predictions, ground_truths):
    result = 0
    samples = []
    for pred, ground_truth in zip(predictions, ground_truths):
        value = max(measure(pred, answer) for answer in ground_truth)
        result += value
        samples.append(value)
    return result / len(predictions), samples

In [12]:
def _rouge_l(predictions, ground_truths):
    scorer = rouge_scorer.RougeScorer(["rougeL"])
    aggregator = scoring.BootstrapAggregator()
    samples_precision = []
    samples_recall = []
    samples_f1 = []
    for pred, gts in zip(predictions, ground_truths):
        score = scorer.score_multi(gts, pred)
        aggregator.add_scores(score)
        samples_precision.append(score["rougeL"].precision)
        samples_recall.append(score["rougeL"].recall)
        samples_f1.append(score["rougeL"].fmeasure)
    results = aggregator.aggregate()
    return (
        {
            "rougeL_precision": results["rougeL"].mid.precision,
            "rougeL_recall": results["rougeL"].mid.recall,
            "rougeL_f1": results["rougeL"].mid.fmeasure,
        },
        samples_precision,
        samples_recall,
        samples_f1,
    )

In [13]:
def all_metrics(predictions, ground_truths) :
    predictions = [preprocess(pred) for pred in predictions]
    ground_truths = [[preprocess(gt) for gt in gts] for gts in ground_truths]

    print("predictions", predictions)
    print("ground_truths", ground_truths)

    rougel, sample_rougel_precision, sample_rougel_recall, sample_rougel_f1 = _rouge_l(predictions, ground_truths)
    f1, sample_f1 = _calculate_measures(_f1, predictions, ground_truths)
    em, sample_em = _calculate_measures(_em, predictions, ground_truths)

    f1_em_ = {"f1": f1, "em": em}
    samples = {
        "samples_f1": sample_f1,
        "samples_exact_match": sample_em,
        "samples_rougeL_precision": sample_rougel_precision,
        "samples_rougeL_recall": sample_rougel_recall,
        "samples_rougeL_f1": sample_rougel_f1,
    }

    return dict(chain.from_iterable(d.items() for d in (f1_em_, rougel, samples)))

In [14]:
def generate_tokenizer_model(args):
  print("Load tokenizer and model...")
  tokenizer = T5Tokenizer.from_pretrained(args.model)
  model = T5ForConditionalGeneration.from_pretrained(args.model)

  return tokenizer, model

In [15]:
args = Namespace(
    checkpoint="allenai/unifiedqa-v2-t5-base-1363200",
    train_file="Webis-CausalQA-22-v-1.0/input/original-splits/squad2_train_original_split.csv",
    valid_file="Webis-CausalQA-22-v-1.0/input/original-splits/squad2_valid_original_split.csv",
    model="andreaschandra/unifiedqa-v2-t5-base-1363200-finetuned-causalqa-squad",
    epochs=5,
    # source_length=2048, #original
    source_length=1024,
    target_length=100,
    batch_size=2,
    seed=42,
    num_procs=8,
    output_directory="Webis-CausalQA-22-v-1.0/models/original-splits/"
)

In [16]:
tokenizer, model = generate_tokenizer_model(args)

Load tokenizer and model...


Downloading:   0%|          | 0.00/792k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/2.20k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/2.43k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/1.54k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/892M [00:00<?, ?B/s]

In [17]:
model = model.to(DEVICE)

In [18]:
def build_input(batch):
    input_ = [(question + ' \\n ' + context if context is not None else question)
              for question, context in zip(batch['question_processed'], batch['context_processed'])]
    batch['input'] = input_
    return batch

In [19]:
set_seed(args.seed)
data = load_dataset('csv', data_files=args.valid_file)['train']
data = data.map(build_input, batched=True, load_from_cache_file=False, num_proc=args.num_procs)
data = data.remove_columns(['context', 'context_processed'])



Downloading and preparing dataset csv/default to /root/.cache/huggingface/datasets/csv/default-bd5f6a419939b4bd/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317...


Downloading data files:   0%|          | 0/1 [00:00<?, ?it/s]

Extracting data files:   0%|          | 0/1 [00:00<?, ?it/s]

Generating train split: 0 examples [00:00, ? examples/s]

Dataset csv downloaded and prepared to /root/.cache/huggingface/datasets/csv/default-bd5f6a419939b4bd/0.0.0/6b34fb8fcf56f7c8ba51dc895bfa2bfbe43546f190a60fcf74bb5e8afdcc2317. Subsequent calls will reuse this data.


  0%|          | 0/1 [00:00<?, ?it/s]

          

#0:   0%|          | 0/1 [00:00<?, ?ba/s]

   

#1:   0%|          | 0/1 [00:00<?, ?ba/s]

 

#3:   0%|          | 0/1 [00:00<?, ?ba/s]

 

#2:   0%|          | 0/1 [00:00<?, ?ba/s]

#4:   0%|          | 0/1 [00:00<?, ?ba/s]

#5:   0%|          | 0/1 [00:00<?, ?ba/s]

 

#6:   0%|          | 0/1 [00:00<?, ?ba/s]

#7:   0%|          | 0/1 [00:00<?, ?ba/s]

In [20]:
def run_model(batch, model, tokenizer, args):
    if 'naturalquestions' in args.valid_file:
        encoded_inputs = tokenizer(batch, max_length=10000, padding='max_length',
                                   truncation=True, return_tensors="pt").to(DEVICE)
    else:
        encoded_inputs = tokenizer(batch, padding='longest', return_tensors="pt").to(DEVICE)
    res = model.generate(**encoded_inputs, max_length=500)
    return tokenizer.batch_decode(res, skip_special_tokens=True)

In [21]:
loader = DataLoader(data, shuffle=False, num_workers=0, batch_size=args.batch_size)
predictions = []
for batch in tqdm(loader):
    batch_predictions = run_model(batch['input'], model, tokenizer, args)
    predictions.extend(batch_predictions)


  0%|          | 0/126 [00:00<?, ?it/s][A
  1%|          | 1/126 [00:01<03:59,  1.92s/it][A
  2%|▏         | 2/126 [00:04<04:09,  2.01s/it][A
  2%|▏         | 3/126 [00:06<04:43,  2.30s/it][A
  3%|▎         | 4/126 [00:09<05:02,  2.48s/it][A
  4%|▍         | 5/126 [00:12<05:38,  2.80s/it][A
  5%|▍         | 6/126 [00:15<05:38,  2.82s/it][A
  6%|▌         | 7/126 [00:18<05:36,  2.83s/it][A
  6%|▋         | 8/126 [00:20<05:12,  2.65s/it][A
  7%|▋         | 9/126 [00:23<05:09,  2.65s/it][A
  8%|▊         | 10/126 [00:26<05:10,  2.68s/it][A
  9%|▊         | 11/126 [00:29<05:34,  2.91s/it][A
 10%|▉         | 12/126 [00:31<04:55,  2.59s/it][A
 10%|█         | 13/126 [00:33<04:39,  2.47s/it][A
 11%|█         | 14/126 [00:36<04:38,  2.48s/it][A
 12%|█▏        | 15/126 [00:37<04:13,  2.29s/it][A
 13%|█▎        | 16/126 [00:40<04:22,  2.38s/it][A
 13%|█▎        | 17/126 [00:43<04:42,  2.59s/it][A
 14%|█▍        | 18/126 [00:47<05:07,  2.84s/it][A
 15%|█▌        | 19/126 [00:4

In [22]:
answers = data['answer']
answers = [answer.split('\t') for answer in answers]

In [23]:
len(predictions), len(answers)

(252, 252)

In [24]:
predictions[0]

'time or space'

In [25]:
answers[0]

['time or space', 'time or space', 'time or space']

In [26]:
result = {}
result['checkpoint'] = args.checkpoint
result['metrics'] = all_metrics(predictions, answers)
result['predictions'] = predictions

predictions ['time or space', 'definitions', 'more efficient solutions', 'property damage', 'heavily impacted', 'sky digital', 'their own militia', 'huguenot rebellions', 'required education of children as catholics and prohibited emigration', 'protest against occupation of prussia by napoleon', 'wars of religion', 'acted increasingly aggressively', 'adapted quickly and often married outside their immediate french communities', 'lead melts and steam escapes', 'increase in land available for cultivation', 'spiritus nitroaereus', 'he published his findings first', 'more active and lived longer', 'source of most of chemical energy released', 'because of its unpaired electrons', 'very reactive allotrope of oxygen', 'unpaired electrons', 'magnetic', 'higher oxygen content', 'mild euphoric', 'performance boost', 'anaerobic bacteria', 'electronegativity', 'oxides', '160 kpa', 'low total pressures used', 'to avoid being targeted by boycott', 'multilateral negotiations', 'currency values would 

In [27]:
result['metrics'].keys()

dict_keys(['f1', 'em', 'rougeL_precision', 'rougeL_recall', 'rougeL_f1', 'samples_f1', 'samples_exact_match', 'samples_rougeL_precision', 'samples_rougeL_recall', 'samples_rougeL_f1'])

In [28]:
from pprint import pprint

In [29]:
pprint(result['metrics'])

{'em': 0.6468253968253969,
 'f1': 0.8263820367779812,
 'rougeL_f1': 0.8259425081752165,
 'rougeL_precision': 0.8430672745081389,
 'rougeL_recall': 0.8474311513001989,
 'samples_exact_match': [1,
                         1,
                         1,
                         1,
                         1,
                         1,
                         1,
                         1,
                         0,
                         0,
                         1,
                         0,
                         0,
                         0,
                         1,
                         1,
                         1,
                         1,
                         0,
                         1,
                         0,
                         1,
                         1,
                         1,
                         1,
                         0,
                         1,
                         1,
                         0,
                     