<a href="https://colab.research.google.com/github/DreRnc/ExplainingExplanations/blob/main/Explanations.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Dataset: **e-SNLI**. \
Model: **T5-small**.

In [None]:
# Detect Colab automatically instead of hard-coding the flag, so running the
# notebook locally does not clone the repository or reinstall requirements.
import sys

colab = "google.colab" in sys.modules

In [None]:
if colab:
    !git clone https://github.com/DreRnc/ExplainingExplanations.git
    %cd ExplainingExplanations
    %pip install -r requirements.txt
    !git checkout seq2seq

# 1.0 Preparation


In [None]:
# Subset sizes for each split. These are interpreted by `prepare_dataset`;
# presumably the number of examples kept per split (confirm in preprocess.py).
# Named `sizes` because that is the name passed as `sizes=sizes` below.
sizes = {
    'n_train': 10000,
    'n_val': 1000,
    'n_test': 1000,
}

# Number of fine-tuning epochs used by Task 2 and Task 3.
NUM_EPOCHS = 5

## 1.1 Loading Tokenizer

In [None]:
from transformers import T5Tokenizer

# `truncation` / `padding` are options of the tokenizer *call*, not of the
# constructor, so passing them to `from_pretrained` did not enable them.
# They must be given where the tokenizer is invoked.
# NOTE(review): confirm `utils.tokenize_function` / `tokenize_function_ex` pass them.
tokenizer = T5Tokenizer.from_pretrained("t5-small")

## 1.2 Loading and Tokenizing Dataset

In [None]:
from datasets import load_dataset
from preprocess import prepare_dataset
from functools import partial
from utils import tokenize_function

In [None]:
# Load e-SNLI, reusing the local Hugging Face cache. `force_redownload`
# would fetch the whole dataset again on every run.
dataset = load_dataset("esnli")

In [None]:
# Pre-bind the T5 tokenizer so the resulting callable only needs the examples
# (presumably how `prepare_dataset` applies it via `Dataset.map` - confirm).
tokenize_mapping = partial(
    tokenize_function,
    tokenizer=tokenizer,
)

In [None]:
# Tokenize the dataset and obtain the train / validation / test subsets.
# The third split must be named `test_tok`: the evaluation cells use that name.
train_tok, valid_tok, test_tok = prepare_dataset(
    dataset,
    tokenize_mapping=tokenize_mapping,
    sizes=sizes,
)

# 2.0 Tasks

In [None]:
import torch
from functools import partial
import evaluate
from src.utils import compute_metrics, eval_pred_transform_accuracy
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer, T5ForConditionalGeneration, DataCollatorForSeq2Seq


In [None]:
# Use the GPU when PyTorch can see one, otherwise the CPU; the value is shown
# as the cell output.
# NOTE(review): `device` is not referenced by any later cell shown here.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

device

In [None]:
# `compute_metrics` callback for Task 1 and Task 2. It post-processes the
# predictions with `eval_pred_transform_accuracy` (bound to the tokenizer)
# and scores them with the Hugging Face `accuracy` metric.
accuracy_metric = evaluate.load('accuracy')
transform_accuracy = partial(eval_pred_transform_accuracy, tokenizer=tokenizer)
compute_accuracy = partial(
    compute_metrics,
    pred_transform=transform_accuracy,
    metric=accuracy_metric,
)

## 2.1 Task 1: Zero-shot evaluation

In [None]:
# Task 1 evaluates the pretrained t5-small checkpoint as-is; it is never trained.
checkpoint = "t5-small"
model = T5ForConditionalGeneration.from_pretrained(checkpoint)

# Batch collator for the Task 1 trainer, bound to this model.
data_collator = DataCollatorForSeq2Seq(tokenizer, model=model)

In [None]:
# Evaluation-only configuration for the zero-shot baseline. Task 1 never
# calls `train()`, so no training hyperparameters are set.
training_args = Seq2SeqTrainingArguments(
    output_dir="task1",
    per_device_eval_batch_size=16,
    # Score generated sequences; kept short since only the label is expected.
    predict_with_generate=True,
    generation_max_length=32,
    metric_for_best_model="accuracy",
)

In [None]:
# Trainer for the zero-shot baseline; only `evaluate` is called on it.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    tokenizer=tokenizer,
    data_collator=data_collator,
    train_dataset=train_tok,
    eval_dataset=valid_tok,
    compute_metrics=compute_accuracy,
)

In [None]:
# Zero-shot accuracy on the held-out test split. `metric_key_prefix="test"`
# reports the results as `test_*`, so they are not logged as `eval_*`
# (validation) metrics.
trainer.evaluate(eval_dataset=test_tok, metric_key_prefix="test")

## 2.2 Task 2: Fine-tuning without explanations

In [None]:
# Task 2 starts from a freshly loaded t5-small, independent of Task 1's model object.
checkpoint = "t5-small"
model_ft = T5ForConditionalGeneration.from_pretrained(checkpoint)

# Batch collator bound to the Task 2 model.
data_collator_ft = DataCollatorForSeq2Seq(tokenizer, model=model_ft)

In [None]:
# Task 2 fine-tuning (labels only): evaluate on the validation split after
# every epoch and finish with the checkpoint of best validation accuracy.
training_args_ft = Seq2SeqTrainingArguments(
    output_dir="task2",
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    evaluation_strategy="epoch",
    # `metric_for_best_model` only selects the final weights when
    # `load_best_model_at_end` is set, which requires the save schedule to
    # match the evaluation schedule.
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    predict_with_generate=True,
    generation_max_length=32,
)

In [None]:
# Trainer for Task 2: fine-tune on `train_tok` and validate on `valid_tok`.
trainer_ft = Seq2SeqTrainer(
    model=model_ft,
    args=training_args_ft,
    tokenizer=tokenizer,
    data_collator=data_collator_ft,
    train_dataset=train_tok,
    eval_dataset=valid_tok,
    compute_metrics=compute_accuracy,
)

In [None]:
# Run Task 2 fine-tuning; the result of `train()` is shown as the cell output.
train_output_ft = trainer_ft.train()
train_output_ft

In [None]:
# Test accuracy of the fine-tuned model, reported as `test_*` metrics rather
# than the default `eval_*` prefix reserved for validation.
trainer_ft.evaluate(eval_dataset=test_tok, metric_key_prefix="test")

## 2.3 Task 3: Making the model generate explanations

The training targets (`labels`) must contain both the class label and its explanation, tokenized together.

### Preparing the dataset with labelled explanations

In [None]:
from utils import tokenize_function_ex

In [None]:
# Separate copy of e-SNLI for the explanation-augmented preprocessing, loaded
# from the local cache instead of forcing a second full download.
dataset_explanations = load_dataset("esnli")

In [None]:
# Same pre-binding as for Task 1/2, but using the tokenization function that
# also builds explanation targets.
tokenize_mapping_ex = partial(
    tokenize_function_ex,
    tokenizer=tokenizer,
)

In [None]:
# Tokenize the explanation-labelled dataset into train / validation / test
# subsets, using the same `sizes` as the label-only splits.
train_tok_ex, valid_tok_ex, test_tok_ex = prepare_dataset(
    dataset=dataset_explanations,
    tokenize_mapping=tokenize_mapping_ex,
    sizes=sizes,
)

In [None]:
# Inspect the columns produced by the explanation tokenization.
features_ex = train_tok_ex.features
features_ex

### Fine-tuning

In [None]:
# `compute_metrics` callback for Task 3. `remove_explanations_from_label=True`
# presumably strips the explanation part so only the class label is scored
# (confirm in src/utils.py).
accuracy_metric_ex = evaluate.load('accuracy')
transform_accuracy_ex = partial(
    eval_pred_transform_accuracy,
    tokenizer=tokenizer,
    remove_explanations_from_label=True,
)
compute_accuracy_removing_explanations = partial(
    compute_metrics,
    pred_transform=transform_accuracy_ex,
    metric=accuracy_metric_ex,
)

In [None]:
# Task 3 starts from a freshly loaded t5-small as well.
checkpoint = "t5-small"
model_ft_ex = T5ForConditionalGeneration.from_pretrained(checkpoint)

# Batch collator bound to the Task 3 model.
data_collator_ft_ex = DataCollatorForSeq2Seq(tokenizer, model=model_ft_ex)

In [None]:
# Task 3 fine-tuning (label + explanation targets): evaluate every epoch and
# finish with the checkpoint of best validation accuracy.
training_args_ft_ex = Seq2SeqTrainingArguments(
    output_dir="task3",
    num_train_epochs=NUM_EPOCHS,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    evaluation_strategy="epoch",
    # `metric_for_best_model` only selects the final weights when
    # `load_best_model_at_end` is set, which requires the save schedule to
    # match the evaluation schedule.
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    predict_with_generate=True,
    # Longer generation budget than Task 2 because explanations are generated too.
    generation_max_length=128,
)

In [None]:
# Trainer for Task 3: train on explanation-augmented targets, score label accuracy.
trainer_ft_ex = Seq2SeqTrainer(
    model=model_ft_ex,
    args=training_args_ft_ex,
    tokenizer=tokenizer,
    data_collator=data_collator_ft_ex,
    train_dataset=train_tok_ex,
    eval_dataset=valid_tok_ex,
    compute_metrics=compute_accuracy_removing_explanations,
)

In [None]:
# Run Task 3 fine-tuning; the result of `train()` is shown as the cell output.
train_output_ft_ex = trainer_ft_ex.train()
train_output_ft_ex

In [None]:
# Test accuracy of the explanation-trained model, reported as `test_*`
# metrics rather than the default `eval_*` prefix reserved for validation.
trainer_ft_ex.evaluate(eval_dataset=test_tok_ex, metric_key_prefix="test")