In [None]:
from transformers import TrainingArguments
from transformers import AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer
from transformers import AutoTokenizer
import numpy as np
import os
import pandas as pd
from pathlib import Path
import gc

BASE_DIR = Path(".")

In [None]:
# All e-SNLI CSV files live in the same folder under BASE_DIR.
datasets_dir = BASE_DIR / "datasets"

# The training split is stored in two parts; read both and stack them.
# The index is deliberately not reset here.
e_snli_train_1 = pd.read_csv(datasets_dir / "esnli_train_1.csv")
e_snli_train_2 = pd.read_csv(datasets_dir / "esnli_train_2.csv")
e_snli_train = pd.concat(objs=[e_snli_train_1, e_snli_train_2])

e_snli_test = pd.read_csv(datasets_dir / "esnli_test.csv")

# Preview the first rows of the combined training frame.
e_snli_train.head()

# Finetuning on E-SNLI

From the FLUTE paper:

> T5e-SNLI: e-SNLI (Camburu et al., 2018) dataset comes with supervised
> ground-truth labels and rationales. We fine-tune the 3B version of T5 on
> e-SNLI for one epoch with a batch size of 1024, and an AdamW Optimizer with
> a learning rate of 1e-4. We remove the Neutral examples from e-SNLI because
> our test data does not have such a category. We take the longest explanation
> per example in e-SNLI since our data has only one reference explanation. In
> case the explanations are more than one sentence we join them using ‘and’
> since our data contains single-sentence explanations. This leaves us with
> 366,603 training and 6,607 validation examples.


In [None]:
# Drop the "neutral" examples from both splits.
# `.copy()` makes each filtered frame independent of the frame it was sliced
# from, so the column assignments made further down in this cell write to a
# real DataFrame instead of a slice (avoids pandas' SettingWithCopyWarning).
e_snli_train = e_snli_train[e_snli_train["gold_label"] != "neutral"].copy()
e_snli_test = e_snli_test[e_snli_test["gold_label"] != "neutral"].copy()


def join_sentences(explanation):
    """Merge a multi-sentence explanation into one sentence.

    The value is converted with ``str`` first, then every occurrence of
    ". " (period followed by a space) is replaced by " and ".
    """
    text = str(explanation)
    # Splitting on the separator and re-joining is equivalent to a replace.
    sentences = text.split(". ")
    return " and ".join(sentences)


# Rewrite each training explanation as a single sentence.
train_explanations = e_snli_train["Explanation_1"]
e_snli_train["Explanation_1"] = train_explanations.apply(join_sentences)


def find_longest_explanation(row):
    """Return the longest of the three explanations stored in ``row``.

    Args:
        row: Mapping-like object (e.g. a DataFrame row) with the keys
            "Explanation_1", "Explanation_2" and "Explanation_3".

    Returns:
        The longest string among the three values; on ties the earliest
        column wins. Values that are not strings (for example NaN produced
        by an empty CSV cell, or None) are ignored instead of crashing in
        ``len``. If none of the three values is a string, "" is returned.
    """
    candidates = [
        row[column] for column in ("Explanation_1", "Explanation_2", "Explanation_3")
    ]
    # len() is undefined for float NaN / None, so only real strings compete.
    texts = [candidate for candidate in candidates if isinstance(candidate, str)]
    if not texts:
        return ""
    return max(texts, key=len)


# Pick the longest explanation per test row, then rewrite it as one sentence.
longest_test_explanations = e_snli_test.apply(find_longest_explanation, axis=1)
e_snli_test["Explanation"] = longest_test_explanations.apply(join_sentences)

In [None]:
from transformers import DataCollatorForSeq2Seq

In [None]:
from transformers import (
    T5ForConditionalGeneration,
    T5Tokenizer,
    Trainer,
    TrainingArguments,
)
from datasets import Dataset
import numpy as np

# Wrap the filtered training frame in a Hugging Face Dataset.
train_dataset = Dataset.from_pandas(e_snli_train)

# Tokenizer and model are loaded from the same pretrained checkpoint name.
model_name = "t5-small"
tokenizer = T5Tokenizer.from_pretrained(model_name)
model = T5ForConditionalGeneration.from_pretrained(model_name)

# The collator's `model` argument expects the model object itself (it is used
# to derive decoder inputs from the labels); the checkpoint *name* string
# that was passed before has no such method and was silently ignored.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)


def preprocess_train(batch):
    """Tokenize a batch of e-SNLI training rows for seq2seq fine-tuning.

    Args:
        batch: Mapping of column name -> list of values, as supplied by
            ``Dataset.map(..., batched=True)``. The columns "Sentence1",
            "Sentence2", "gold_label" and "Explanation_1" are read.

    Returns:
        dict with three equally long lists of token-id lists:
        "input_ids" and "attention_mask" for the prompt, and "labels" for
        the "<label> - <explanation>" target. Every sequence is padded or
        truncated to 128 tokens.
    """
    prompts = [
        f"Does the sentence '{s1}' entail or contradict the sentence '{s2}'? Please answer between 'Entails' or 'Contradicts' and explain your decision in a sentence."
        for s1, s2 in zip(batch["Sentence1"], batch["Sentence2"])
    ]
    targets = [
        f"{label} - {explanation}"
        for label, explanation in zip(batch["gold_label"], batch["Explanation_1"])
    ]

    # No return_tensors: the result is handed back to Dataset.map, so plain
    # Python lists are sufficient.
    model_inputs = tokenizer(
        prompts,
        truncation=True,
        padding="max_length",
        max_length=128,
    )
    target_encoding = tokenizer(
        targets,
        truncation=True,
        padding="max_length",
        max_length=128,
    )

    # Padding positions in the labels are replaced by -100 so that they do
    # not contribute to the training loss.
    pad_id = tokenizer.pad_token_id
    labels = [
        [-100 if token == pad_id else token for token in sequence]
        for sequence in target_encoding["input_ids"]
    ]

    return {
        "input_ids": model_inputs["input_ids"],
        # Keep the mask so the padded tail of each prompt is not attended to.
        "attention_mask": model_inputs["attention_mask"],
        "labels": labels,
    }


# Columns carried over from the raw e-SNLI frame that are dropped once the
# tokenized features exist.
raw_train_columns = [
    "pairID",
    "gold_label",
    "Sentence1",
    "Sentence2",
    "Explanation_1",
    "WorkerId",
    "Sentence1_marked_1",
    "Sentence2_marked_1",
    "Sentence1_Highlighted_1",
    "Sentence2_Highlighted_1",
    "__index_level_0__",
]

# Tokenize in batches, then drop the raw columns in the same chain.
tokenized_train = train_dataset.map(preprocess_train, batched=True).remove_columns(
    raw_train_columns
)

In [None]:
def preprocess_test(batch):
    """Tokenize a batch of e-SNLI test rows for seq2seq evaluation.

    Args:
        batch: Mapping of column name -> list of values, as supplied by
            ``Dataset.map(..., batched=True)``. The columns "Sentence1",
            "Sentence2", "gold_label" and "Explanation" are read.

    Returns:
        dict with three equally long lists of token-id lists:
        "input_ids" and "attention_mask" for the prompt, and "labels" for
        the "<label> - <explanation>" target. Every sequence is padded or
        truncated to 128 tokens.
    """
    prompts = [
        f"Does the sentence '{s1}' entail or contradict the sentence '{s2}'? Please answer between 'Entails' or 'Contradicts' and explain your decision in a sentence."
        for s1, s2 in zip(batch["Sentence1"], batch["Sentence2"])
    ]
    targets = [
        f"{label} - {text}"
        for label, text in zip(batch["gold_label"], batch["Explanation"])
    ]

    # No return_tensors: the result is handed back to Dataset.map, so plain
    # Python lists are sufficient.
    model_inputs = tokenizer(
        prompts,
        truncation=True,
        padding="max_length",
        max_length=128,
    )
    target_encoding = tokenizer(
        targets,
        truncation=True,
        padding="max_length",
        max_length=128,
    )

    # Padding positions in the labels are replaced by -100 so that they do
    # not contribute to the evaluation loss.
    pad_id = tokenizer.pad_token_id
    labels = [
        [-100 if token == pad_id else token for token in sequence]
        for sequence in target_encoding["input_ids"]
    ]

    return {
        "input_ids": model_inputs["input_ids"],
        # Keep the mask so the padded tail of each prompt is not attended to.
        "attention_mask": model_inputs["attention_mask"],
        "labels": labels,
    }


# Wrap the filtered test frame in a Dataset and tokenize it batch by batch.
test_dataset = Dataset.from_pandas(e_snli_test)
tokenized_test = test_dataset.map(
    function=preprocess_test,
    batched=True,
)

In [None]:
import numpy as np
import evaluate

# Load the "rouge" metric through the `evaluate` library; compute_metrics
# below calls `metric.compute` on the decoded predictions and references.
metric = evaluate.load("rouge")

# TODO: Is this correct? Test.


def compute_metrics(eval_pred):
    """Compute ROUGE scores for a batch of evaluation predictions.

    Args:
        eval_pred: Pair ``(predictions, labels)``. ``predictions`` may be
            token ids, raw logits with a trailing vocabulary axis, or a
            tuple whose first element is one of those. ``labels`` are token
            ids in which ignored positions may be marked with -100.

    Returns:
        dict mapping each metric name reported by ``metric.compute`` (this
        includes "rouge1") to a float. A dict is returned rather than a bare
        number because ``Trainer`` expects ``compute_metrics`` to return a
        mapping of metric names to values.
    """
    predictions, labels = eval_pred

    # Some models return several outputs; the first one holds the predictions.
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    predictions = np.asarray(predictions)
    labels = np.asarray(labels)

    # Logits of shape (batch, sequence, vocab) are reduced to token ids.
    if predictions.ndim == 3:
        predictions = predictions.argmax(axis=-1)

    # Negative ids (e.g. the -100 ignore index) cannot be decoded; map them
    # to the pad token, which skip_special_tokens then removes.
    pad_id = tokenizer.pad_token_id
    predictions = np.where(predictions < 0, pad_id, predictions)
    labels = np.where(labels < 0, pad_id, labels)

    decoded_preds = tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    result = metric.compute(
        predictions=decoded_preds, references=decoded_labels, use_stemmer=True
    )
    return {name: float(value) for name, value in result.items()}

In [None]:
from transformers import TrainingArguments, Trainer

# Hyper-parameters for the e-SNLI run, collected in one place before being
# handed to TrainingArguments.
esnli_training_kwargs = dict(
    output_dir="models",
    evaluation_strategy="epoch",
    # smaller batch size to run on less VRAM
    per_device_eval_batch_size=32,
    # smaller batch size to run on less VRAM
    per_device_train_batch_size=32,
    num_train_epochs=1,
    learning_rate=1e-4,
    # Keep a single checkpoint on disk, written every 10 steps.
    save_total_limit=1,
    save_steps=10,
)

training_args = TrainingArguments(**esnli_training_kwargs)

In [None]:
# Everything the e-SNLI trainer needs, gathered before construction.
esnli_trainer_kwargs = {
    "model": model,
    "args": training_args,
    "train_dataset": tokenized_train,
    "eval_dataset": tokenized_test,
    "data_collator": data_collator,
}
# Metric computation stays disabled, exactly as before:
# esnli_trainer_kwargs["compute_metrics"] = compute_metrics

trainer = Trainer(**esnli_trainer_kwargs)

In [None]:
# Run the e-SNLI fine-tuning.
trainer.train()

# Drop the names of the large objects and ask the garbage collector to run.
# NOTE(review): `trainer` was constructed with `model` and is not deleted
# here, so the model may stay alive through that reference - confirm.
del tokenized_train, tokenized_test, model
gc.collect()

In [None]:
# Reload the fine-tuned weights from the checkpoint saved during training.
model = AutoModelForSeq2SeqLM.from_pretrained("models/checkpoint-11450")

# Example premise / hypothesis pair, formatted with the same prompt template
# that was used for training.
P = "I bet I am blue."
H = "I bet I am like a cherry."
prompt = f"Does the sentence '{P}' entail or contradict the sentence '{H}'? Please answer between 'Entails' or 'Contradicts' and explain your decision in a sentence."

# Tokenize the prompt and move it to the device the model lives on.
tokens = tokenizer(prompt, return_tensors="pt").input_ids
tokens = tokens.to(model.device)

output = model.generate(tokens, max_new_tokens=100)
# skip_special_tokens=True strips special tokens (padding / end-of-sequence
# markers) from the decoded text, matching how compute_metrics decodes.
output = tokenizer.decode(output[0], skip_special_tokens=True)
print(output)

## Finetuning on FLUTE

From the FLUTE paper:

> We fine-tune the 3B version of T5 model for 10 epochs with a batch size of
> 1024, and an AdamW Optimizer with a learning rate of 1e-4 in a multitask
> fashion with data from all the four types of figurative languages combined.
> Our training data consists of 7,035 samples which is 50X smaller than e-SNLI.
> For validation we use 500 examples which is used for selecting best
> checkpoint based on loss.


In [None]:
# The FLUTE training data is a JSON-lines file (one JSON record per line).
flute_path = BASE_DIR / "datasets" / "train.jsonl"
flute_dataset = pd.read_json(flute_path, lines=True)

# Report the number of rows, then preview the first few.
print(len(flute_dataset))
flute_dataset.head()

In [None]:
from transformers import (
    T5ForConditionalGeneration,
    T5Tokenizer,
    Trainer,
    TrainingArguments,
)
from datasets import Dataset
import numpy as np

# Convert the FLUTE frame to a Hugging Face Dataset. The name is rebound, so
# the conversion is skipped when this cell is re-run and `flute_dataset` is
# already a Dataset.
if isinstance(flute_dataset, pd.DataFrame):
    flute_dataset = Dataset.from_pandas(flute_dataset)

# Fresh tokenizer and model for the FLUTE run, loaded from the same
# pretrained checkpoint name.
model_name = "t5-small"
tokenizer = T5Tokenizer.from_pretrained(model_name)
flute_model = T5ForConditionalGeneration.from_pretrained(model_name)

# The collator's `model` argument expects the model object itself (it is used
# to derive decoder inputs from the labels); the checkpoint *name* string
# that was passed before has no such method and was silently ignored.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=flute_model)


def preprocess_train(batch):
    """Tokenize a batch of FLUTE rows for seq2seq fine-tuning.

    Args:
        batch: Mapping of column name -> list of values, as supplied by
            ``Dataset.map(..., batched=True)``. The columns "premise",
            "hypothesis", "label" and "explanation" are read.

    Returns:
        dict with three equally long lists of token-id lists:
        "input_ids" and "attention_mask" for the prompt, and "labels" for
        the "<label> - <explanation>" target. Every sequence is padded or
        truncated to 128 tokens.
    """
    # The premise fills the first slot of the prompt and the hypothesis the
    # second, the same order the inference cells use (P first, then H).
    # The previous version zipped them the other way round.
    prompts = [
        f"Does the sentence '{s1}' entail or contradict the sentence '{s2}'? Please answer between 'Entails' or 'Contradicts' and explain your decision in a sentence."
        for s1, s2 in zip(batch["premise"], batch["hypothesis"])
    ]
    targets = [
        f"{label} - {explanation}"
        for label, explanation in zip(batch["label"], batch["explanation"])
    ]

    # No return_tensors: the result is handed back to Dataset.map, so plain
    # Python lists are sufficient.
    model_inputs = tokenizer(
        prompts,
        truncation=True,
        padding="max_length",
        max_length=128,
    )
    target_encoding = tokenizer(
        targets,
        truncation=True,
        padding="max_length",
        max_length=128,
    )

    # Padding positions in the labels are replaced by -100 so that they do
    # not contribute to the training loss.
    pad_id = tokenizer.pad_token_id
    labels = [
        [-100 if token == pad_id else token for token in sequence]
        for sequence in target_encoding["input_ids"]
    ]

    return {
        "input_ids": model_inputs["input_ids"],
        # Keep the mask so the padded tail of each prompt is not attended to.
        "attention_mask": model_inputs["attention_mask"],
        "labels": labels,
    }


tokenized_dataset = flute_dataset.map(preprocess_train, batched=True)

# Only the tokenized features are needed from here on.
tokenized_dataset = tokenized_dataset.remove_columns(
    ["id", "premise", "hypothesis", "label", "explanation", "split", "type"]
)

# Shuffle ONCE and cut both splits from the same shuffled order. Previously
# the training rows came from the shuffled dataset while the validation rows
# were taken from the unshuffled one, so validation examples could also
# appear in the training split.
shuffled_dataset = tokenized_dataset.shuffle(seed=42)
# 7035 training examples; the cap keeps select() valid on a smaller file.
n_train = min(7035, len(shuffled_dataset))
train_dataset = shuffled_dataset.select(range(n_train))
test_dataset = shuffled_dataset.select(range(n_train, len(shuffled_dataset)))

print(len(train_dataset))
print(len(test_dataset))

In [None]:
from transformers import TrainingArguments, Trainer

# Hyper-parameters for the FLUTE run, collected in one place before being
# handed to TrainingArguments.
flute_training_kwargs = dict(
    output_dir="models",
    evaluation_strategy="epoch",
    per_device_eval_batch_size=32,
    per_device_train_batch_size=32,
    num_train_epochs=10,
    learning_rate=1e-4,
    # Keep a single checkpoint on disk, written every 10 steps.
    save_total_limit=1,
    save_steps=10,
)

training_args = TrainingArguments(**flute_training_kwargs)

In [None]:
# Train the model that was loaded for the FLUTE run (`flute_model`).
# The previous `model=model` picked up whatever the name `model` was bound
# to by earlier cells (the reloaded e-SNLI checkpoint), leaving the freshly
# loaded `flute_model` unused.
trainer = Trainer(
    model=flute_model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    # compute_metrics=compute_metrics,
    data_collator=data_collator,
)

In [None]:
# Run the FLUTE fine-tuning.
trainer.train()

# Drop the names of the large objects and ask the garbage collector to run.
del train_dataset, test_dataset, model
gc.collect()

In [None]:
# Reload the fine-tuned weights from the checkpoint saved during training.
model = AutoModelForSeq2SeqLM.from_pretrained("models/checkpoint-2200")

# Example premise / hypothesis pair, formatted with the prompt template.
P = "I bet I am blue."
H = "I bet I am like a cherry."
prompt = f"Does the sentence '{P}' entail or contradict the sentence '{H}'? Please answer between 'Entails' or 'Contradicts' and explain your decision in a sentence."

# Tokenize the prompt and move it to the device the model lives on.
tokens = tokenizer(prompt, return_tensors="pt").input_ids
tokens = tokens.to(model.device)

output = model.generate(tokens, max_new_tokens=100)
# skip_special_tokens=True strips special tokens (padding / end-of-sequence
# markers) from the decoded text, matching how compute_metrics decodes.
output = tokenizer.decode(output[0], skip_special_tokens=True)
print(output)