In [1]:
import numpy as np
from sklearn.model_selection import StratifiedKFold
from datasets import Dataset, DatasetDict, load_dataset
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Trainer
import pandas as pd
import torch
import evaluate
import nltk

BATCH_SIZE = 12
NUM_EPOCHS = 8
base_checkpoint = "t5-small"

tokenizer = AutoTokenizer.from_pretrained(base_checkpoint)

#ds = load_dataset("ColumbiaNLP/FLUTE").shuffle(seed=42)
df = pd.read_csv("complete_dataset.csv").fillna("")
ds = Dataset.from_pandas(df).shuffle(seed=42)
folds = StratifiedKFold(n_splits=10, shuffle=False)
splits = folds.split(ds, ds['label'])
indexes = [t for t in splits]

In [2]:
from flute_dream import add_combined_cols

ds = ds.map(add_combined_cols)

Map:   0%|          | 0/7534 [00:00<?, ? examples/s]

In [3]:
def preprocess_dataset_s1(examples):
    model_inputs = tokenizer(examples['premise_hypothesis'])
    labels = tokenizer(examples['label_explanation'])
    model_inputs['labels'] = labels['input_ids']
    return model_inputs

def preprocess_dataset_s2(examples):
    model_inputs = tokenizer(examples['premise_hypothesis_system_2'])
    labels = tokenizer(examples['type_label_explanation'])
    model_inputs['labels'] = labels['input_ids']
    return model_inputs

def preprocess_dataset_s31(examples):
    model_inputs = tokenizer(examples['premise_hypothesis_emotion'])
    labels = tokenizer(examples['label_explanation'])
    model_inputs['labels'] = labels['input_ids']
    return model_inputs

def preprocess_dataset_s32(examples):
    model_inputs = tokenizer(examples['premise_hypothesis_motivation'])
    labels = tokenizer(examples['label_explanation'])
    model_inputs['labels'] = labels['input_ids']
    return model_inputs

def preprocess_dataset_s33(examples):
    model_inputs = tokenizer(examples['premise_hypothesis_consequence'])
    labels = tokenizer(examples['label_explanation'])
    model_inputs['labels'] = labels['input_ids']
    return model_inputs

def preprocess_dataset_s34(examples):
    model_inputs = tokenizer(examples['premise_hypothesis_rot'])
    labels = tokenizer(examples['label_explanation'])
    model_inputs['labels'] = labels['input_ids']
    return model_inputs

def preprocess_dataset_s35(examples):
    model_inputs = tokenizer(examples['premise_hypothesis_all_dims'])
    labels = tokenizer(examples['label_explanation'])
    model_inputs['labels'] = labels['input_ids']
    return model_inputs

def preprocess_dataset_s41(examples):
    model_inputs = tokenizer(examples['premise_hypothesis'])
    labels = tokenizer(examples['label'])
    model_inputs['labels'] = labels['input_ids']
    return model_inputs

def preprocess_dataset_s42(examples):
    model_inputs = tokenizer(examples['premise_hypothesis_label'])
    labels = tokenizer(examples['explanation'])
    model_inputs['labels'] = labels['input_ids']
    return model_inputs

In [4]:
operating_modes = [
    ("system_1", preprocess_dataset_s1),
    ("system_2", preprocess_dataset_s2),
    ("system_31", preprocess_dataset_s31),
    ("system_32", preprocess_dataset_s32),
    ("system_33", preprocess_dataset_s33),
    ("system_34", preprocess_dataset_s34),
    ("system_35", preprocess_dataset_s35),
    ]

In [5]:
from IPython.display import clear_output

modes = {name: [] for name, _ in operating_modes}

for train_idxs, val_idxs in indexes[8:]:
    fold_dataset = DatasetDict({
        "train": ds.select(train_idxs),
        "val": ds.select(val_idxs)
    })

    for name, preprocess_func in operating_modes:
        curr_ds = fold_dataset.map(preprocess_func, batched=True).remove_columns(fold_dataset['train'].column_names)

        training_args = Seq2SeqTrainingArguments(
            output_dir=f"{name}",
            learning_rate=3e-4,
            per_device_train_batch_size=BATCH_SIZE,
            per_device_eval_batch_size=2*BATCH_SIZE,
            save_total_limit=2,
            num_train_epochs=NUM_EPOCHS,
            report_to="none",
            evaluation_strategy="epoch",
            save_strategy="epoch",
            eval_accumulation_steps=1,
            logging_steps=1,
        )

        model = AutoModelForSeq2SeqLM.from_pretrained(base_checkpoint)

        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=curr_ds["train"],
            eval_dataset=curr_ds["val"].select(range(350)),
            tokenizer=tokenizer,
            data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer),
        )

        trainer.train()
        
        # have to do batched rouge computation otherwise not enough memory
        rouge = evaluate.load("rouge")
        metrics = {'rouge1': 0., 'rouge2': 0., 'rougeL': 0., 'rougeLsum': 0.}
        count = 0
        for i in range(0, len(curr_ds['val']), 100):
            count += 1
            (predictions, _), label_ids, _ = trainer.predict(test_dataset=curr_ds['val'].select(range(i, min(i+100, len(curr_ds['val'])))))
            predicted_token_ids = torch.argmax(torch.from_numpy(predictions), dim=-1)
            decoded_preds = tokenizer.batch_decode(predicted_token_ids, skip_special_tokens=True)
            labels = np.where(label_ids != -100, label_ids, tokenizer.pad_token_id)
            decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
            new_metrics = rouge.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)
            for k in new_metrics:
                metrics[k] += new_metrics[k]

        for k in metrics:
                metrics[k] /= count
        
        clear_output(wait=True)
        modes[name].append(metrics['rouge1'])
        print(modes)

{'system_1': [0.498351050356349, 0.5090844074007058], 'system_2': [0.5289706096196247, 0.535423121872322], 'system_31': [0.5004418699772601, 0.5081924349139884], 'system_32': [0.5000023839289528, 0.5069172466751917], 'system_33': [0.49609349538092806, 0.5060478473679386], 'system_34': [0.49683869684920356, 0.5072891134960673], 'system_35': [0.49873229852851614, 0.5063988604083957]}


In [6]:
"""
{'system_1': [0.5056115530372459, 0.5085812472965738], 'system_2': [0.5323239453833644, 0.5374525863206762], 'system_31': [0.5084813364883226, 0.5078069259932709], 'system_32': [0.5055255195638098, 0.5074808302929493], 'system_33': [0.5032615710245999, 0.508591734512484], 'system_34': [0.5067396302748032, 0.5087587260751042], 'system_35': [0.5036829298423671, 0.5073656197676405]}
{'system_1': [0.47329270980754923], 'system_2': [0.5039333868999851], 'system_31': [0.47420432731446144], 'system_32': [0.473849787491779], 'system_33': [0.47287909150934293], 'system_34': [0.47171914165316353], 'system_35': [0.47436950304876985]}
{'system_1': [0.5221458462682274, 0.4906728974226775], 'system_2': [0.5453307495230577, 0.5189864534869755], 'system_31': [0.524029905213074, 0.49432447432649923], 'system_32': [0.5236538569098267, 0.4949640794043393], 'system_33': [0.522696825432352, 0.49418902866839187], 'system_34': [0.5231002650281118, 0.49245360907178504], 'system_35': [0.5238553839735753, 0.4967562420314123]}
{'system_1': [0.5201383622010409, 0.5019083214066056], 'system_2': [0.5451744732719431, 0.5273190341222926], 'system_31': [0.5192974336052382, 0.5002995114825706], 'system_32': [0.5188841013023715, 0.5007719063422055], 'system_33': [0.5203219622545853, 0.49992736417041994], 'system_34': [0.5180675680247785, 0.5008015168679895], 'system_35': [0.5224425242712794, 0.5010754435268072]}
{'system_1': [0.49590791112913984], 'system_2': [0.5147096144113823], 'system_31': [0.4927430972029428], 'system_32': [0.49258799343269694], 'system_33': [0.49049952341283654], 'system_34': [0.49317378483047924], 'system_35': [0.4919880906784929]}
{'system_1': [0.498351050356349, 0.5090844074007058], 'system_2': [0.5289706096196247, 0.535423121872322], 'system_31': [0.5004418699772601, 0.5081924349139884], 'system_32': [0.5000023839289528, 0.5069172466751917], 'system_33': [0.49609349538092806, 0.5060478473679386], 'system_34': [0.49683869684920356, 0.5072891134960673], 'system_35': [0.49873229852851614, 0.5063988604083957]}
"""

"\n{'system_1': [0.5056115530372459, 0.5085812472965738], 'system_2': [0.5323239453833644, 0.5374525863206762], 'system_31': [0.5084813364883226, 0.5078069259932709], 'system_32': [0.5055255195638098, 0.5074808302929493], 'system_33': [0.5032615710245999, 0.508591734512484], 'system_34': [0.5067396302748032, 0.5087587260751042], 'system_35': [0.5036829298423671, 0.5073656197676405]}\n{'system_1': [0.47329270980754923], 'system_2': [0.5039333868999851], 'system_31': [0.47420432731446144], 'system_32': [0.473849787491779], 'system_33': [0.47287909150934293], 'system_34': [0.47171914165316353], 'system_35': [0.47436950304876985]}\n{'system_1': [0.5221458462682274, 0.4906728974226775], 'system_2': [0.5453307495230577, 0.5189864534869755], 'system_31': [0.524029905213074, 0.49432447432649923], 'system_32': [0.5236538569098267, 0.4949640794043393], 'system_33': [0.522696825432352, 0.49418902866839187], 'system_34': [0.5231002650281118, 0.49245360907178504], 'system_35': [0.5238553839735753, 

In [7]:
modes

{'system_1': [0.498351050356349, 0.5090844074007058],
 'system_2': [0.5289706096196247, 0.535423121872322],
 'system_31': [0.5004418699772601, 0.5081924349139884],
 'system_32': [0.5000023839289528, 0.5069172466751917],
 'system_33': [0.49609349538092806, 0.5060478473679386],
 'system_34': [0.49683869684920356, 0.5072891134960673],
 'system_35': [0.49873229852851614, 0.5063988604083957]}