<a href="https://colab.research.google.com/github/DreRnc/ExplainingExplanations/blob/torch/Explanations.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Dataset: **e-SNLI** (natural-language inference pairs annotated with human-written explanations). \
Model: **T5-small**.

In [None]:
colab: bool = False  # set to True on Google Colab to clone the repo and install its requirements

In [1]:
if colab:
    !git clone https://github.com/DreRnc/ExplainingExplanations.git
    %cd ExplainingExplanations
    !pip install -r requirements.txt

SyntaxError: invalid syntax. Maybe you meant '==' or ':=' instead of '='? (1711734592.py, line 1)

# 1.0 Preparation


## 1.1 Loading Dataset

In [None]:
from datasets import load_dataset

# Download (or load from cache) the e-SNLI DatasetDict with its train/validation/test splits.
dataset = load_dataset(path="esnli")

In [None]:
# Pull the three standard partitions out of the DatasetDict.
training_set = dataset["train"]
validation_set = dataset["validation"]
test_set = dataset["test"]

for split_name, split in [
    ("training_set", training_set),
    ("validation_set", validation_set),
    ("test_set", test_set),
]:
    print(f"Shape of {split_name}: ", split.shape)

In [None]:
# Inspect one raw e-SNLI record (premise, hypothesis, label and explanation columns).
training_set[0]

In [None]:
# Work on fixed-size prefixes of each split so experiments stay fast.
n_train = n_valid = n_test = 5000

# Clamp each size to its split length: `Dataset.select` fails on indices past the end,
# so a size larger than a split simply means "use the whole split".
train_small = training_set.select(range(min(n_train, len(training_set))))
valid_small = validation_set.select(range(min(n_valid, len(validation_set))))
test_small = test_set.select(range(min(n_test, len(test_set))))

for subset_name, subset in [
    ("train_small", train_small),
    ("valid_small", valid_small),
    ("test_small", test_small),
]:
    print(f"Shape of {subset_name}: ", subset.shape)

## 1.2 Loading T5 Model

In [None]:
from transformers import T5Tokenizer, T5ForConditionalGeneration

# Single checkpoint name shared by tokenizer and model so the two can never drift apart.
MODEL_NAME = "t5-small"

tokenizer = T5Tokenizer.from_pretrained(MODEL_NAME)
model = T5ForConditionalGeneration.from_pretrained(MODEL_NAME)

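Sanity check: run the pretrained model **zero-shot** on an unrelated task (English→French translation) to confirm it loads and generates text.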

In [None]:
input_ids = tokenizer(
    "translate English to French: Hello Dre, I think the English version is ok for us.",
    return_tensors="pt",
).input_ids
# Output length is bounded at generation time; `decode` itself has no length parameter.
outputs = model.generate(input_ids, max_new_tokens=100)

print(tokenizer.decode(outputs[0], skip_special_tokens=True))

## 1.3 Zero-shot example to Verify Everything is Working

In [None]:
from src.utils import generate_prompt_mnli

In [None]:
# Use the first training record to build a zero-shot prompt.
example = training_set[0]
example

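Generating the prompt. The input uses T5's MNLI task format; the pair below only illustrates that format and is not the e-SNLI record loaded above: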

<b><u> mnli hypothesis: </u></b> The St. Louis Cardinals have always won. <b><u> premise: </u></b> yeah well losing is i mean i’m i’m originally from Saint Louis and Saint Louis Cardinals when they were there were uh a mostly a losing team but

Label encoding (dataset `label` id → class):
* 0: Entailment
* 1: Neutral
* 2: Contradiction

In [None]:
# Format the record as a text prompt via the project helper from src.utils
# (presumably the "mnli hypothesis: ... premise: ..." format shown above — confirm in src/utils).
prompt = generate_prompt_mnli(example)
prompt

In [None]:
# Feed the single prompt through the model and print the generated label text.
encoded_prompt = tokenizer(prompt, return_tensors="pt")
input_ids = encoded_prompt["input_ids"]

outputs = model.generate(input_ids)
decoded_answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
print(decoded_answer)

## 1.4 Tokenize the Dataset

In [None]:
# Show the DatasetInfo metadata attached to the training subset.
train_small.info

In [None]:
# NOTE: indexing a `Dataset` returns a fresh dict on every call, so item assignment such as
# `train_small[0]["prompt"] = ...` never modifies the dataset. Add new columns with
# `Dataset.map` or `Dataset.add_column` instead.
train_small[0]

In [None]:
# Column names and types of the raw (untokenized) training subset.
train_small.features

In [None]:
from functools import partial
from src.utils import tokenize_function
import time

In [None]:
# Bind the tokenizer so `Dataset.map` can call `tokenize_function` with only the batch argument.
tokenize_mapping = partial(tokenize_function, tokenizer=tokenizer)

In [None]:
# perf_counter is monotonic and high-resolution, which is what interval timing needs
# (time.time() follows the wall clock and can jump).
time_init = time.perf_counter()
train_small_tokenized = train_small.map(tokenize_mapping, batched=True).with_format("torch")
valid_small_tokenized = valid_small.map(tokenize_mapping, batched=True).with_format("torch")
test_small_tokenized = test_small.map(tokenize_mapping, batched=True).with_format("torch")
time_end = time.perf_counter()

print("Time taken to tokenize: ", time_end - time_init)
for tokenized_name, tokenized in [
    ("train_small_tokenized", train_small_tokenized),
    ("valid_small_tokenized", valid_small_tokenized),
    ("test_small_tokenized", test_small_tokenized),
]:
    print(f"Shape of {tokenized_name}: ", tokenized.shape)

In [None]:
# Raw text columns are not needed once prompts are tokenized.
TEXT_COLUMNS = ["premise", "hypothesis", "explanation_1", "explanation_2", "explanation_3"]


def _prepare_for_loader(tokenized):
    """Rename `label` -> `labels` and drop the raw text columns of a tokenized split.

    `labels` is the key the evaluation and training loops read from each batch.
    Written to be idempotent, so re-running this cell does not fail on columns
    that were already renamed or removed.
    """
    if "label" in tokenized.column_names:
        tokenized = tokenized.rename_column("label", "labels")
    to_drop = [column for column in TEXT_COLUMNS if column in tokenized.column_names]
    return tokenized.remove_columns(to_drop)


train_small_tokenized = _prepare_for_loader(train_small_tokenized)
valid_small_tokenized = _prepare_for_loader(valid_small_tokenized)
test_small_tokenized = _prepare_for_loader(test_small_tokenized)

In [None]:
# Check the remaining columns after cleanup (expected: tokenizer outputs plus `labels` — verify).
train_small_tokenized.features

In [None]:
from torch.utils.data import DataLoader

In [None]:
import torch

# Fixed seed so the shuffled training order is reproducible across runs.
SEED = 42
BATCH_SIZE = 8

train_dataloader = DataLoader(
    train_small_tokenized,
    shuffle=True,
    batch_size=BATCH_SIZE,
    generator=torch.Generator().manual_seed(SEED),
)
eval_dataloader = DataLoader(valid_small_tokenized, batch_size=BATCH_SIZE)
test_dataloader = DataLoader(test_small_tokenized, batch_size=BATCH_SIZE)

# 2.0 Task 1: Zero-shot evaluation

In [None]:
from src.utils import convert_labels
from tqdm import tqdm
import torch
import evaluate

In [None]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model.to(device)
device

In [None]:
# Zero-shot accuracy of the pretrained model on the test subset.
metric = evaluate.load("accuracy")
model.eval()

for batch in tqdm(test_dataloader, desc="zero-shot eval"):
    # `labels` holds integer class ids (0/1/2), not target token ids, so it is kept
    # out of `generate`; every other column is forwarded, as in the original cell.
    references = batch["labels"].tolist()
    model_inputs = {k: v.to(device) for k, v in batch.items() if k != "labels"}

    with torch.no_grad():
        generated = model.generate(**model_inputs)
    decoded = tokenizer.batch_decode(generated, skip_special_tokens=True)
    # Map generated label words back to class ids (project helper from src.utils).
    preds = convert_labels(decoded)

    metric.add_batch(predictions=preds, references=references)

metric.compute()

# 3.0 Task 2: Fine tuning without explanations

In [None]:
from torch.optim import AdamW

# AdamW over every model parameter; the scheduler defined next decays this initial rate.
learning_rate = 5e-5
optimizer = AdamW(params=model.parameters(), lr=learning_rate)

In [None]:
from torch import nn

# NOTE(review): this criterion expects raw logits of shape (N, C) plus integer targets.
# Predictions obtained via `generate` + decoding carry no gradient and are not logits,
# so they cannot be trained against it. T5ForConditionalGeneration computes its own
# token-level loss when target token ids are passed as `labels` to its forward call.
loss = nn.CrossEntropyLoss(reduction="mean")

In [None]:
from transformers import get_scheduler

# Linear learning-rate schedule over the whole run with no warmup.
num_epochs = 3
steps_per_epoch = len(train_dataloader)
num_training_steps = steps_per_epoch * num_epochs

lr_scheduler = get_scheduler(
    "linear",
    optimizer=optimizer,
    num_warmup_steps=0,
    num_training_steps=num_training_steps,
)

In [None]:
# Class id -> target word, following the 0/1/2 encoding listed in section 1.3.
# NOTE(review): assumes `convert_labels` maps these same words back to 0/1/2 — confirm in src.utils.
LABEL_WORDS = ["entailment", "neutral", "contradiction"]

progress_bar = tqdm(range(num_training_steps))

model.train()
for epoch in range(num_epochs):
    progress_bar.set_description(f"epoch {epoch + 1}/{num_epochs}")
    for batch in train_dataloader:
        # Teacher-forced seq2seq training: the target is the token sequence of the label word.
        # `generate` + decoding (used for evaluation) is not differentiable, so the loss has
        # to come from the model's forward pass with `labels` set.
        target_texts = [LABEL_WORDS[label_id] for label_id in batch["labels"].tolist()]
        target_ids = tokenizer(target_texts, padding=True, return_tensors="pt").input_ids
        # Padding positions are set to -100 so the model's loss ignores them.
        target_ids[target_ids == tokenizer.pad_token_id] = -100

        model_inputs = {k: v.to(device) for k, v in batch.items() if k != "labels"}
        outputs = model(**model_inputs, labels=target_ids.to(device))

        optimizer.zero_grad()
        outputs.loss.backward()
        optimizer.step()
        lr_scheduler.step()

        progress_bar.update(1)
        progress_bar.set_postfix(loss=outputs.loss.item())

progress_bar.close()

# 4.0 Task 4: Making the model generate explanations