<a href="https://colab.research.google.com/github/DreRnc/ExplainingExplanations/blob/ModData/Base.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Dataset: **e-SNLI**. \
Model: **Base T5** (checkpoint `t5-small`, set via `MODEL` below).

In [1]:
%load_ext autoreload
%autoreload 2
colab = False

In [2]:
if colab:
    !git clone https://github.com/DreRnc/ExplainingExplanations.git
    %cd ExplainingExplanations
    !git checkout seq2seq
    %pip install -r requirements_colab.txt
    

# 1.0 Preparation


Set the experiment parameters: the T5 checkpoint, the number of examples per split, and the prompt-format options.

In [3]:
# Hugging Face checkpoint name; also the suffix of the fine-tuned run directories.
MODEL = 't5-small'

# Per-split sizes handed to `prepare_dataset`
# (presumably the number of examples kept from each split).
sizes = dict(
    n_train=1000,
    n_val=1000,
    n_test=1000,
)

# If True, format inputs with the MNLI prompt T5 was pretrained on.
USE_MNLI_PROMPT = False
# Not read by any cell in this notebook section; presumably controls whether the
# explanation precedes the label in targets. TODO confirm.
EXPLANATION_FIRST = False

## 1.1 Loading Tokenizer

In [4]:
from transformers import T5Tokenizer

# `padding` and `truncation` are per-call options of `tokenizer(...)`, not
# constructor settings, so they are not passed here.
# `legacy=True` keeps the behaviour the tokenizer already falls back to by
# default, and suppresses the "default legacy behaviour" warning.
tokenizer = T5Tokenizer.from_pretrained(MODEL, legacy=True)

  from .autonotebook import tqdm as notebook_tqdm


You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


## 1.2 Loading and Tokenizing Dataset

In [5]:
from datasets import load_dataset
from src.preprocess import prepare_dataset
from functools import partial
from src.utils import tokenize_function

In [6]:
# Forcing a re-download fetches ~42 MB on every run. Reuse the local Hugging Face
# cache by default, and flip this flag only if the cached copy is corrupted.
FORCE_REDOWNLOAD = False
dataset = load_dataset(
    "esnli",
    download_mode="force_redownload" if FORCE_REDOWNLOAD else "reuse_dataset_if_exists",
)

Downloading data: 100%|██████████| 39.3M/39.3M [00:01<00:00, 31.0MB/s]
Downloading data: 100%|██████████| 1.62M/1.62M [00:00<00:00, 5.53MB/s]
Downloading data: 100%|██████████| 1.61M/1.61M [00:00<00:00, 9.94MB/s]
Generating train split: 100%|██████████| 549367/549367 [00:00<00:00, 2511131.99 examples/s]
Generating validation split: 100%|██████████| 9842/9842 [00:00<00:00, 2919401.70 examples/s]
Generating test split: 100%|██████████| 9824/9824 [00:00<00:00, 3005897.47 examples/s]


In [7]:
# Pre-bind the tokenizer and the prompt-format flag, so that `prepare_dataset`
# can apply the mapping (presumably through `Dataset.map`) with no extra arguments.
tokenize_mapping = partial(
    tokenize_function,
    tokenizer=tokenizer,
    use_mnli_format=USE_MNLI_PROMPT,
)

In [8]:
# Build the tokenized train / validation / test splits, sized according to `sizes`.
train_tok, valid_tok, test_tok = prepare_dataset(
    dataset,
    tokenize_mapping=tokenize_mapping,
    sizes=sizes,
)

Map: 100%|██████████| 1000/1000 [00:00<00:00, 15883.30 examples/s]
Map: 100%|██████████| 1000/1000 [00:00<00:00, 16506.18 examples/s]
Map: 100%|██████████| 1000/1000 [00:00<00:00, 16489.96 examples/s]


In [11]:
import torch
from functools import partial
import evaluate
from src.utils import compute_metrics, eval_pred_transform_accuracy
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer, T5ForConditionalGeneration, DataCollatorForSeq2Seq


In [12]:
# Prefer the GPU when one is visible; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

device

device(type='cuda')

# Model fine-tuned with correct explanations

In [15]:
from src.utils import remove_explanation

In [16]:
# Checkpoint fine-tuned with correct explanations (task 3), read from the local run directory.
model_ft_ex = T5ForConditionalGeneration.from_pretrained(f'task3_{MODEL}/best_model')
# Not used by the generation loops in this section.
data_collator_ft = DataCollatorForSeq2Seq(tokenizer, model=model_ft_ex)

In [17]:
n_mistakes = 50

# Tokenized / auxiliary columns that are noise when eyeballing a prediction.
drop_keys = {'input_ids', 'explanation_1', 'explanation_2', 'explanation_3', 'labels', 'attention_mask'}

# Scan the test split in order and keep the first `n_mistakes` examples whose
# predicted label (explanation stripped) differs from the decoded target.
# The scan is bounded by the split length, so it cannot run past the end.
found_mistakes = []
for i in range(len(test_tok)):
    if len(found_mistakes) >= n_mistakes:
        break
    example = test_tok[i]
    # Put inputs wherever the model lives, so this works on CPU and GPU alike.
    input_ids = example['input_ids'].unsqueeze(0).to(model_ft_ex.device)
    output_ids = model_ft_ex.generate(input_ids, max_new_tokens=100)[0]
    pred = tokenizer.decode(output_ids, skip_special_tokens=True)
    gold = tokenizer.decode(example['labels'], skip_special_tokens=True)

    if remove_explanation(pred) != gold:
        mistake = {key: value for key, value in example.items() if key not in drop_keys}
        mistake['pred'] = pred
        mistake['label'] = gold
        found_mistakes.append(mistake)

if len(found_mistakes) < n_mistakes:
    print(f'Only {len(found_mistakes)} mistakes found in {len(test_tok)} test examples.')

for mistake in found_mistakes:
    print(mistake)

{'premise': 'This church choir sings to the masses as they sing joyous songs from the book at a church.', 'hypothesis': 'The church has cracks in the ceiling.', 'pred': 'label: contradiction explanation: The church choir cannot sing joyous songs from the book at a church if it has cracks in the ceiling.', 'label': 'neutral'}
{'premise': 'A woman with a green headscarf, blue shirt and a very big grin.', 'hypothesis': 'The woman has been shot.', 'pred': 'label: neutral explanation: A woman with a green headscarf, blue shirt and a very big grin does not indicate that she has been shot.', 'label': 'contradiction'}
{'premise': 'A statue at a museum that no seems to be looking at.', 'hypothesis': 'Tons of people are gathered around the statue.', 'pred': 'label: entailment explanation: A statue at a museum is a statue.', 'label': 'contradiction'}
{'premise': 'A land rover is being driven across a river.', 'hypothesis': 'A Land Rover is splashing water as it crosses a river.', 'pred': 'label: 

In [18]:
n_correct = 50

# Tokenized / auxiliary columns that are noise when eyeballing a prediction.
drop_keys = {'input_ids', 'explanation_1', 'explanation_2', 'explanation_3', 'labels', 'attention_mask'}

# Scan the test split in order and keep the first `n_correct` examples whose
# predicted label (explanation stripped) matches the decoded target.
# Note: this regenerates predictions for the same leading examples as the mistakes cell.
found_correct = []
for i in range(len(test_tok)):
    if len(found_correct) >= n_correct:
        break
    example = test_tok[i]
    # Put inputs wherever the model lives, so this works on CPU and GPU alike.
    input_ids = example['input_ids'].unsqueeze(0).to(model_ft_ex.device)
    output_ids = model_ft_ex.generate(input_ids, max_new_tokens=100)[0]
    pred = tokenizer.decode(output_ids, skip_special_tokens=True)
    gold = tokenizer.decode(example['labels'], skip_special_tokens=True)

    if remove_explanation(pred) == gold:
        correct = {key: value for key, value in example.items() if key not in drop_keys}
        correct['pred'] = pred
        correct['label'] = gold
        found_correct.append(correct)

if len(found_correct) < n_correct:
    print(f'Only {len(found_correct)} correct predictions found in {len(test_tok)} test examples.')

for correct in found_correct:
    print(correct)

{'premise': 'This church choir sings to the masses as they sing joyous songs from the book at a church.', 'hypothesis': 'The church is filled with song.', 'pred': 'label: entailment explanation: A church choir sings to the masses as they sing joyous songs from the book at a church is a rephrasing of the church is filled with song.', 'label': 'entailment'}
{'premise': 'This church choir sings to the masses as they sing joyous songs from the book at a church.', 'hypothesis': 'A choir singing at a baseball game.', 'pred': 'label: contradiction explanation: A church choir is not a baseball game.', 'label': 'contradiction'}
{'premise': 'A woman with a green headscarf, blue shirt and a very big grin.', 'hypothesis': 'The woman is young.', 'pred': 'label: neutral explanation: Not all women are young.', 'label': 'neutral'}
{'premise': 'A woman with a green headscarf, blue shirt and a very big grin.', 'hypothesis': 'The woman is very happy.', 'pred': 'label: entailment explanation: A big grin i

# Model fine-tuned with shuffled explanations

In [19]:
from src.utils import remove_explanation

In [20]:
# Checkpoint fine-tuned with shuffled explanations (task 4), read from the local run directory.
model_ft_shex = T5ForConditionalGeneration.from_pretrained(f'task4_{MODEL}/best_model')
# Not used by the generation loops in this section.
data_collator_ft = DataCollatorForSeq2Seq(tokenizer, model=model_ft_shex)

In [21]:
n_mistakes = 10

# Tokenized / auxiliary columns that are noise when eyeballing a prediction.
drop_keys = {'input_ids', 'explanation_1', 'explanation_2', 'explanation_3', 'labels', 'attention_mask'}

# Scan the test split in order and keep the first `n_mistakes` examples whose
# predicted label (explanation stripped) differs from the decoded target.
# The scan is bounded by the split length, so it cannot run past the end.
found_mistakes = []
for i in range(len(test_tok)):
    if len(found_mistakes) >= n_mistakes:
        break
    example = test_tok[i]
    # Put inputs wherever the model lives, so this works on CPU and GPU alike.
    input_ids = example['input_ids'].unsqueeze(0).to(model_ft_shex.device)
    output_ids = model_ft_shex.generate(input_ids, max_new_tokens=100)[0]
    pred = tokenizer.decode(output_ids, skip_special_tokens=True)
    gold = tokenizer.decode(example['labels'], skip_special_tokens=True)

    if remove_explanation(pred) != gold:
        mistake = {key: value for key, value in example.items() if key not in drop_keys}
        mistake['pred'] = pred
        mistake['label'] = gold
        found_mistakes.append(mistake)

if len(found_mistakes) < n_mistakes:
    print(f'Only {len(found_mistakes)} mistakes found in {len(test_tok)} test examples.')

for mistake in found_mistakes:
    print(mistake)

{'premise': 'This church choir sings to the masses as they sing joyous songs from the book at a church.', 'hypothesis': 'The church has cracks in the ceiling.', 'pred': 'label: contradiction explanation: A man and woman stand next to a table covered in beer glasses and pitchers cannot be at a preschool.', 'label': 'neutral'}
{'premise': 'A woman with a green headscarf, blue shirt and a very big grin.', 'hypothesis': 'The woman is very happy.', 'pred': 'label: neutral explanation: A man and woman wearing t-shirt are standing next to a hole.', 'label': 'entailment'}
{'premise': 'One tan girl with a wool hat is running and leaning over an object, while another person in a wool hat is sitting on the ground.', 'hypothesis': 'A man watches his daughter leap', 'pred': 'label: contradiction explanation: A man and woman wearing t-shirt are standing next to a hole.', 'label': 'neutral'}
{'premise': 'Three firefighter come out of subway station.', 'hypothesis': 'Three firefighters putting out a f

In [22]:
n_correct = 10

# Tokenized / auxiliary columns that are noise when eyeballing a prediction.
drop_keys = {'input_ids', 'explanation_1', 'explanation_2', 'explanation_3', 'labels', 'attention_mask'}

# Scan the test split in order and keep the first `n_correct` examples whose
# predicted label (explanation stripped) matches the decoded target.
# Note: this regenerates predictions for the same leading examples as the mistakes cell.
found_correct = []
for i in range(len(test_tok)):
    if len(found_correct) >= n_correct:
        break
    example = test_tok[i]
    # Put inputs wherever the model lives, so this works on CPU and GPU alike.
    input_ids = example['input_ids'].unsqueeze(0).to(model_ft_shex.device)
    output_ids = model_ft_shex.generate(input_ids, max_new_tokens=100)[0]
    pred = tokenizer.decode(output_ids, skip_special_tokens=True)
    gold = tokenizer.decode(example['labels'], skip_special_tokens=True)

    if remove_explanation(pred) == gold:
        correct = {key: value for key, value in example.items() if key not in drop_keys}
        correct['pred'] = pred
        correct['label'] = gold
        found_correct.append(correct)

if len(found_correct) < n_correct:
    print(f'Only {len(found_correct)} correct predictions found in {len(test_tok)} test examples.')

for correct in found_correct:
    print(correct)

{'premise': 'This church choir sings to the masses as they sing joyous songs from the book at a church.', 'hypothesis': 'The church is filled with song.', 'pred': 'label: entailment explanation: A man and woman stand next to a table covered in beer glasses and pitchers cannot be at a preschool.', 'label': 'entailment'}
{'premise': 'This church choir sings to the masses as they sing joyous songs from the book at a church.', 'hypothesis': 'A choir singing at a baseball game.', 'pred': 'label: contradiction explanation: A man and woman wearing t-shirt are standing next to a hole.', 'label': 'contradiction'}
{'premise': 'A woman with a green headscarf, blue shirt and a very big grin.', 'hypothesis': 'The woman is young.', 'pred': 'label: neutral explanation: A man and woman wearing t-shirt are standing next to a hole.', 'label': 'neutral'}
{'premise': 'A woman with a green headscarf, blue shirt and a very big grin.', 'hypothesis': 'The woman has been shot.', 'pred': 'label: contradiction e

# Model fine-tuned with compressed explanations

In [23]:
from src.utils import remove_explanation

In [24]:
# Checkpoint fine-tuned with compressed explanations (task 5), read from the local run directory.
model_ft_5 = T5ForConditionalGeneration.from_pretrained(f'task5_{MODEL}/best_model')
# Not used by the generation loops in this section.
data_collator_ft = DataCollatorForSeq2Seq(tokenizer, model=model_ft_5)

In [25]:
n_mistakes = 10

# Tokenized / auxiliary columns that are noise when eyeballing a prediction.
drop_keys = {'input_ids', 'explanation_1', 'explanation_2', 'explanation_3', 'labels', 'attention_mask'}

# Scan the test split in order and keep the first `n_mistakes` examples whose
# predicted label (explanation stripped) differs from the decoded target.
# The scan is bounded by the split length, so it cannot run past the end.
found_mistakes = []
for i in range(len(test_tok)):
    if len(found_mistakes) >= n_mistakes:
        break
    example = test_tok[i]
    # Put inputs wherever the model lives, so this works on CPU and GPU alike.
    input_ids = example['input_ids'].unsqueeze(0).to(model_ft_5.device)
    output_ids = model_ft_5.generate(input_ids, max_new_tokens=100)[0]
    pred = tokenizer.decode(output_ids, skip_special_tokens=True)
    gold = tokenizer.decode(example['labels'], skip_special_tokens=True)

    if remove_explanation(pred) != gold:
        mistake = {key: value for key, value in example.items() if key not in drop_keys}
        mistake['pred'] = pred
        mistake['label'] = gold
        found_mistakes.append(mistake)

if len(found_mistakes) < n_mistakes:
    print(f'Only {len(found_mistakes)} mistakes found in {len(test_tok)} test examples.')

for mistake in found_mistakes:
    print(mistake)

{'premise': 'A woman with a green headscarf, blue shirt and a very big grin.', 'hypothesis': 'The woman is very happy.', 'pred': 'label: neutral explanation: man woman genders leading Clydesdale drinking tea', 'label': 'entailment'}
{'premise': 'A land rover is being driven across a river.', 'hypothesis': 'A Land Rover is splashing water as it crosses a river.', 'pred': 'label: neutral explanation: man woman genders person hitching directs horse drawn cart has horse attached attach horse cart', 'label': 'entailment'}
{'premise': 'One tan girl with a wool hat is running and leaning over an object, while another person in a wool hat is sitting on the ground.', 'hypothesis': 'A man watches his daughter leap', 'pred': 'label: contradiction explanation: man woman genders leading Clydesdale drinking tea', 'label': 'neutral'}
{'premise': 'Male in a blue jacket decides to lay in the grass.', 'hypothesis': 'The guy wearing a blue jacket is laying on the green grass', 'pred': 'label: neutral exp

In [26]:
n_correct = 10

# Tokenized / auxiliary columns that are noise when eyeballing a prediction.
drop_keys = {'input_ids', 'explanation_1', 'explanation_2', 'explanation_3', 'labels', 'attention_mask'}

# Scan the test split in order and keep the first `n_correct` examples whose
# predicted label (explanation stripped) matches the decoded target.
# Note: this regenerates predictions for the same leading examples as the mistakes cell.
found_correct = []
for i in range(len(test_tok)):
    if len(found_correct) >= n_correct:
        break
    example = test_tok[i]
    # Put inputs wherever the model lives, so this works on CPU and GPU alike.
    input_ids = example['input_ids'].unsqueeze(0).to(model_ft_5.device)
    output_ids = model_ft_5.generate(input_ids, max_new_tokens=100)[0]
    pred = tokenizer.decode(output_ids, skip_special_tokens=True)
    gold = tokenizer.decode(example['labels'], skip_special_tokens=True)

    if remove_explanation(pred) == gold:
        correct = {key: value for key, value in example.items() if key not in drop_keys}
        correct['pred'] = pred
        correct['label'] = gold
        found_correct.append(correct)

if len(found_correct) < n_correct:
    print(f'Only {len(found_correct)} correct predictions found in {len(test_tok)} test examples.')

for correct in found_correct:
    print(correct)

{'premise': 'This church choir sings to the masses as they sing joyous songs from the book at a church.', 'hypothesis': 'The church has cracks in the ceiling.', 'pred': 'label: neutral explanation: man woman genders leading Clydesdale drinking tea', 'label': 'neutral'}
{'premise': 'This church choir sings to the masses as they sing joyous songs from the book at a church.', 'hypothesis': 'The church is filled with song.', 'pred': 'label: entailment explanation: man woman genders', 'label': 'entailment'}
{'premise': 'This church choir sings to the masses as they sing joyous songs from the book at a church.', 'hypothesis': 'A choir singing at a baseball game.', 'pred': 'label: contradiction explanation: man woman genders leading Clydesdale drinking tea', 'label': 'contradiction'}
{'premise': 'A woman with a green headscarf, blue shirt and a very big grin.', 'hypothesis': 'The woman is young.', 'pred': 'label: neutral explanation: man woman genders leading Clydesdale drinking tea', 'label'