In [None]:
import numpy as np
from sklearn.model_selection import StratifiedKFold
from datasets import Dataset, DatasetDict, load_dataset
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Trainer
import pandas as pd
import torch
import evaluate
import nltk

# Training hyperparameters shared by the cells below.
BATCH_SIZE, NUM_EPOCHS = 10, 8

# Tokenizer and seq2seq model are both loaded from the same pretrained checkpoint name.
base_checkpoint = "t5-small"
tokenizer, model = (
    AutoTokenizer.from_pretrained(base_checkpoint),
    AutoModelForSeq2SeqLM.from_pretrained(base_checkpoint),
)

def add_cols(entry):
    """Add the System 1 seq2seq text columns to one dataset row.

    Reads ``premise``, ``hypothesis``, ``label`` and ``explanation`` from ``entry``
    and writes two new keys:

    * ``premise_hypothesis`` -- the model input prompt. Premise and hypothesis are
      whitespace-stripped and given a trailing ``"."`` when they lack one.
    * ``label_explanation`` -- the target text, ``"Label: <label>. Explanation: <explanation>"``.

    ``entry`` is mutated in place and also returned, as ``Dataset.map`` expects.
    ``label`` and ``explanation`` must be strings (they are concatenated, not formatted).
    """
    def _terminated(text):
        # Trim surrounding whitespace and make sure the sentence ends with a full stop.
        text = text.strip()
        return text if text.endswith(".") else text + "."

    premise = _terminated(entry["premise"])
    hypothesis = _terminated(entry["hypothesis"])

    entry["premise_hypothesis"] = (
        f"Premise: {premise} Hypothesis: {hypothesis}"
        " Is there a contradiction or entailment between the premise and hypothesis ?"
    )
    entry["label_explanation"] = "".join(
        ("Label: ", entry["label"], ". Explanation: ", entry["explanation"])
    )
    return entry

# Raw tables; missing cells become "" so add_cols can concatenate/strip them safely.
df = pd.read_csv("complete_dataset.csv").fillna("")
df_syn = pd.read_csv("synthetic_data_merge.tsv", sep="\t").fillna("")

# Convert each table to a Dataset, shuffle with a fixed seed, then add the prompt/target columns.
ds, ds_syn = (
    Dataset.from_pandas(frame).shuffle(seed=42).map(add_cols)
    for frame in (df, df_syn)
)

def preprocess_dataset_s1(examples):
    """Tokenize a batch of System 1 rows for seq2seq training.

    Encodes ``examples['premise_hypothesis']`` as the model input and
    ``examples['label_explanation']`` as the target. Both are truncated to 512 tokens.
    Returns the input encoding with an extra ``'labels'`` entry holding the target
    ``input_ids``. Uses the module-level ``tokenizer``; intended for
    ``Dataset.map(..., batched=True)``.
    """
    tokenize_kwargs = dict(truncation=True, max_length=512)
    encoded = tokenizer(examples['premise_hypothesis'], **tokenize_kwargs)
    encoded['labels'] = tokenizer(examples['label_explanation'], **tokenize_kwargs)['input_ids']
    return encoded

In [None]:
# Smoke test on one of T5's built-in tasks. In top-to-bottom order this runs
# before any fine-tuning, so it exercises the checkpoint as loaded.
sentence = "translate English to French: Hello big guy, I'm a very strange man"
encoded_sentence = tokenizer(sentence, return_tensors="pt")
generated_ids = model.generate(encoded_sentence['input_ids'])
tokenizer.decode(generated_ids[0], skip_special_tokens=True)

In [None]:
# Fine-tune `model` on an 80/20 split of each dataset in turn: synthetic first, then FLUTE.
# Splits are seeded (same seed as the earlier shuffle) so reruns yield identical train/test rows.
# NOTE(review): the trailing `break` means only the synthetic split is trained and evaluated.
# `curr_ds` and `trainer` from that iteration are reused by later cells.
rouge = evaluate.load("rouge")  # hoisted: one load is enough for every iteration
eval_chunk_size = 80  # rows per trainer.predict call; full-vocabulary logits are memory-hungry

for now_ds in (ds_syn.train_test_split(test_size=0.2, seed=42), ds.train_test_split(test_size=0.2, seed=42)):

    curr_ds = now_ds.map(preprocess_dataset_s1, batched=True).remove_columns(now_ds['train'].column_names)

    training_args = Seq2SeqTrainingArguments(
        output_dir="T5-small-synthetic-FLUTE",
        learning_rate=3e-4,
        # NOTE(review): hard-coded 8, whereas the later cell uses BATCH_SIZE (10). Confirm this is intended.
        per_device_train_batch_size=8,
        per_device_eval_batch_size=2*8,
        save_total_limit=2,
        num_train_epochs=NUM_EPOCHS,
        report_to="none",
        evaluation_strategy="epoch",
        save_strategy="epoch",
        eval_accumulation_steps=1,
        logging_steps=1,
        lr_scheduler_type="constant"
    )

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=curr_ds["train"],
        # Per-epoch eval on at most 350 rows; min() avoids IndexError on a smaller test split.
        eval_dataset=curr_ds["test"].select(range(min(350, len(curr_ds["test"])))),
        tokenizer=tokenizer,
        data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer),
    )

    trainer.train()

    # ROUGE over the whole test split, predicted in chunks to bound memory.
    # NOTE(review): plain Trainer.predict returns teacher-forced logits, so the argmax below
    # scores teacher-forced tokens, not free-running generation. Use
    # Seq2SeqTrainer(predict_with_generate=True) to score generated text.
    test_split = curr_ds['test']
    totals = {'rouge1': 0., 'rouge2': 0., 'rougeL': 0., 'rougeLsum': 0.}
    n_scored = 0
    for start in range(0, len(test_split), eval_chunk_size):
        chunk = test_split.select(range(start, min(start + eval_chunk_size, len(test_split))))
        (predictions, _), label_ids, _ = trainer.predict(test_dataset=chunk)
        predicted_token_ids = torch.argmax(torch.from_numpy(predictions), dim=-1)
        # Overwrite everything from the first EOS onward with EOS so skip_special_tokens drops it.
        eos_id = tokenizer.eos_token_id
        for row in range(predicted_token_ids.shape[0]):
            eos_positions = (predicted_token_ids[row] == eos_id).nonzero(as_tuple=True)[0]
            if eos_positions.numel() != 0:
                predicted_token_ids[row, eos_positions[0]:] = eos_id

        decoded_preds = tokenizer.batch_decode(predicted_token_ids, skip_special_tokens=True)
        # -100 marks ignored label positions; map them to pad so they can be decoded.
        labels = np.where(label_ids != -100, label_ids, tokenizer.pad_token_id)
        decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
        # Per-example scores (no bootstrap aggregation), so the final number is an exact mean
        # over all rows instead of an unweighted mean of chunk means.
        per_example = rouge.compute(predictions=decoded_preds, references=decoded_labels,
                                    use_stemmer=True, use_aggregator=False)
        for k, scores in per_example.items():
            totals[k] = totals.get(k, 0.) + float(sum(scores))
        n_scored += len(decoded_preds)

    metrics = {k: v / max(n_scored, 1) for k, v in totals.items()}

    print(metrics['rouge1'])
    break

In [None]:
# Log in to the Hugging Face Hub, then upload via the currently bound `trainer`.
# login() is called without a token argument, so no credential is written into the notebook.
# NOTE(review): `trainer` is rebound by both training cells; which run gets pushed depends on
# execution order. Confirm the intended one ran last.
from huggingface_hub import login
login()
trainer.push_to_hub()

In [None]:
# Train a fresh copy of the pretrained checkpoint on the split left in `curr_ds` by the
# training-loop cell above. Because of that cell's `break`, this is the synthetic-data split.
if "curr_ds" not in globals():
    raise NameError("curr_ds is not defined: run the training-loop cell above first")

model = AutoModelForSeq2SeqLM.from_pretrained(base_checkpoint)

training_args = Seq2SeqTrainingArguments(
    output_dir="synthetics",
    learning_rate=3e-4,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=2*BATCH_SIZE,
    save_total_limit=2,
    num_train_epochs=NUM_EPOCHS,
    report_to="none",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    eval_accumulation_steps=1,
    logging_steps=1,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=curr_ds["train"],
    # Per-epoch eval on at most 350 rows; min() avoids IndexError on a smaller test split.
    eval_dataset=curr_ds["test"].select(range(min(350, len(curr_ds["test"])))),
    tokenizer=tokenizer,
    data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer),
)

trainer.train()

# ROUGE over the whole test split, predicted in chunks to bound memory.
# NOTE(review): these are teacher-forced argmax tokens from Trainer.predict, not generated text.
rouge = evaluate.load("rouge")
eval_chunk_size = 100
test_split = curr_ds['test']
totals = {'rouge1': 0., 'rouge2': 0., 'rougeL': 0., 'rougeLsum': 0.}
n_scored = 0
for start in range(0, len(test_split), eval_chunk_size):
    chunk = test_split.select(range(start, min(start + eval_chunk_size, len(test_split))))
    (predictions, _), label_ids, _ = trainer.predict(test_dataset=chunk)
    predicted_token_ids = torch.argmax(torch.from_numpy(predictions), dim=-1)
    # Overwrite everything from the first EOS onward with EOS so skip_special_tokens drops it.
    eos_id = tokenizer.eos_token_id
    for row in range(predicted_token_ids.shape[0]):
        eos_positions = (predicted_token_ids[row] == eos_id).nonzero(as_tuple=True)[0]
        if eos_positions.numel() != 0:
            predicted_token_ids[row, eos_positions[0]:] = eos_id

    decoded_preds = tokenizer.batch_decode(predicted_token_ids, skip_special_tokens=True)
    # -100 marks ignored label positions; map them to pad so they can be decoded.
    labels = np.where(label_ids != -100, label_ids, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    # Per-example scores, so the final number is an exact mean over all rows
    # rather than an unweighted mean of chunk means.
    per_example = rouge.compute(predictions=decoded_preds, references=decoded_labels,
                                use_stemmer=True, use_aggregator=False)
    for k, scores in per_example.items():
        totals[k] = totals.get(k, 0.) + float(sum(scores))
    n_scored += len(decoded_preds)

metrics = {k: v / max(n_scored, 1) for k, v in totals.items()}

print(metrics['rouge1'])