In [1]:
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq, Seq2SeqTrainingArguments, Seq2SeqTrainer, optimization
import os 
from datasets import load_dataset
import pandas as pd
import numpy as np 
from copy import deepcopy
from torch.optim import AdamW
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Pretrained checkpoint used for both the tokenizer and the base model
checkpoint = "t5-small"

tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)
# The collator's `model` argument has to be the model object, not the checkpoint
# name: the collator calls the model's `prepare_decoder_input_ids_from_labels`
# to build `decoder_input_ids`, which a plain string cannot provide.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

In [3]:
# Checkpoint of the DREAM model used to generate the scene elaborations
dream_checkpoint = "RicoBorra/DREAM-t5-small"

dream_tokenizer = AutoTokenizer.from_pretrained(dream_checkpoint)
dream_model = AutoModelForSeq2SeqLM.from_pretrained(dream_checkpoint)
# Pass the model object (not the checkpoint name) so that the collator can use the
# model's `prepare_decoder_input_ids_from_labels` when building batches.
dream_data_collator = DataCollatorForSeq2Seq(tokenizer=dream_tokenizer, model=dream_model)

## FLUTE Data extraction and processing
We follow the few instructions given in the FLUTE Git README so that we start from the same initial data.

In [4]:
# HuggingFace only provides the train split of FLUTE
dataset = load_dataset(path="ColumbiaNLP/FLUTE")

In [40]:
def compute_dream_elaborations(dataset, dream_model, path_to_save, dream_tokenizer=None, batch_size=32) :
    """Generate (or reload from cache) the DREAM elaborations of every premise and hypothesis.

    For each sentence type ('premise', 'hypothesis') and each DREAM dimension
    ('emotion', 'motivation', 'consequence', 'rot') the prompt
    '[SITUATION] <sentence> [QUERY] <dimension>' is given to `dream_model` and the
    decoded generation is stored in the column '<sentence type>_<dimension>'.
    The result is cached as CSV at `path_to_save`; when that file already exists
    it is loaded instead of running the model again.

    Args:
        dataset: object whose 'premise' and 'hypothesis' entries are sequences of strings.
        dream_model: seq2seq model exposing `generate` and `device`.
        path_to_save: location of the CSV cache.
        dream_tokenizer: tokenizer used to encode the prompts and decode the
            generations. When omitted, the notebook-level `tokenizer` is used.
        batch_size: number of prompts sent to `generate` at once, so that the
            whole dataset is never pushed through the model as a single batch.

    Returns:
        pandas.DataFrame with one row per example and the eight elaboration columns.
        Freshly generated values always end with a period.

    Raises:
        ValueError: if `batch_size` is smaller than 1.
        KeyError: if an existing cache file lacks one of the elaboration columns.
    """
    sentence_types = ['premise', 'hypothesis']
    dream_dimensions = ['emotion', 'motivation', 'consequence', 'rot']
    columns = [sentence_type + '_' + dream_dimension
               for sentence_type in sentence_types
               for dream_dimension in dream_dimensions]

    if os.path.exists(path_to_save) :
        # Caches written with the DataFrame index contain an extra unnamed column;
        # selecting the elaboration columns gives cached and fresh results the same layout.
        return pd.read_csv(path_to_save)[columns]

    if batch_size < 1 :
        raise ValueError('batch_size must be at least 1')
    if dream_tokenizer is None :
        dream_tokenizer = tokenizer

    generated = {}
    for sentence_type in sentence_types :
        sentences = list(dataset[sentence_type])
        for dream_dimension in dream_dimensions :
            inputs = ['[SITUATION] ' + sentence + ' [QUERY] ' + dream_dimension for sentence in sentences]
            decoded = []
            for start in range(0, len(inputs), batch_size) :
                encoded = dream_tokenizer(inputs[start:start + batch_size], padding='longest', return_tensors='pt')
                encoded = encoded.to(dream_model.device)
                # The attention mask is passed explicitly because the batch is padded
                output_tokens = dream_model.generate(input_ids=encoded['input_ids'],
                                                     attention_mask=encoded['attention_mask'],
                                                     max_new_tokens=100)
                decoded.extend(dream_tokenizer.batch_decode(output_tokens, skip_special_tokens=True))
            # Make sure each sentence ends with a point
            generated[sentence_type + '_' + dream_dimension] = [text if text.endswith('.') else text + '.'
                                                                for text in decoded]

    elaborations = pd.DataFrame(generated, columns=columns)
    # index=False keeps the cache free of a spurious index column when it is read back
    elaborations.to_csv(path_to_save, index=False)

    return elaborations

In [None]:
dreams = compute_dream_elaborations(dataset['train'], dream_model, "dream_elaborations.csv")
# Elaborations are matched to the examples by position, so a cache of a different
# size must not be used silently.
assert len(dreams) == len(dataset['train']), "dream_elaborations.csv does not match the dataset size"

_DREAM_DIMENSIONS = ['emotion', 'motivation', 'consequence', 'rot']
_ENTAILMENT_QUESTION = 'Is there a contradiction or entailment between the premise and hypothesis ?'


def _with_final_period(sentence):
    """Strip surrounding whitespace and make sure the sentence ends with a period."""
    sentence = sentence.strip()
    return sentence if sentence.endswith('.') else sentence + '.'


def _with_elaborations(sentence, sentence_type, idx, dream_dimensions):
    """Append '[Dimension] <elaboration>' of example `idx` to `sentence` for each dimension."""
    parts = [sentence]
    for dream_dimension in dream_dimensions:
        # Pick the value of this example only; `dreams[...]` alone is a whole column
        elaboration = dreams[sentence_type + '_' + dream_dimension].iloc[idx]
        # A cache read from CSV may hold missing values; treat them as empty text
        elaboration = '' if pd.isna(elaboration) else str(elaboration).strip()
        parts.append('[' + dream_dimension.capitalize() + '] ' + elaboration)
    return ' '.join(parts)


def add_combined_cols(entry, idx=None):
    """Add the input / target text columns used by the different systems to one example.

    Args:
        entry: one example with the string fields 'premise', 'hypothesis',
            'label', 'explanation' and 'type'.
        idx: position of the example in the dataset the DREAM elaborations were
            computed from (supplied by `Dataset.map(..., with_indices=True)`).
            When it is None only the System 1 and System 2 columns are added,
            because the elaborations of the example cannot be looked up.

    Returns:
        The same `entry`, extended with the new columns.
    """
    premise = _with_final_period(entry["premise"])
    hypothesis = _with_final_period(entry["hypothesis"])
    plain_pair = 'Premise: ' + premise + ' Hypothesis: ' + hypothesis

    # Columns for System 1
    entry["premise_hypothesis"] = plain_pair + ' ' + _ENTAILMENT_QUESTION
    entry["label_explanation"] = 'Label: ' + entry["label"] + '. Explanation: ' + entry["explanation"]

    # Columns for System 2
    entry["premise_hypothesis_system_2"] = plain_pair + ' What is the type of figurative language involved? ' + _ENTAILMENT_QUESTION
    entry["type_label_explanation"] = 'Type: ' + entry["type"] + '. Label: ' + entry["label"] + '. Explanation: ' + entry["explanation"]

    if idx is None:
        return entry

    def elaborated_pair(dream_dimensions):
        # Premise and hypothesis, each followed by its elaborations, then the question
        return 'Premise: ' + _with_elaborations(premise, 'premise', idx, dream_dimensions) + \
               ' Hypothesis: ' + _with_elaborations(hypothesis, 'hypothesis', idx, dream_dimensions) + \
               ' ' + _ENTAILMENT_QUESTION

    # Columns for Systems 3
    for dream_dimension in _DREAM_DIMENSIONS:
        entry["premise_hypothesis_" + dream_dimension] = elaborated_pair([dream_dimension])
    entry["premise_hypothesis_all_dims"] = elaborated_pair(_DREAM_DIMENSIONS)
    return entry


# combine columns (the index is needed to fetch the matching row of `dreams`)
combined_cols_dataset = dataset['train'].map(add_combined_cols, with_indices=True)

# create train test split because given data has only train data
# splits are shuffled by default
dataset_train_test = combined_cols_dataset.train_test_split(test_size=0.2, seed=42)

In [8]:
import evaluate
import numpy as np

rouge = evaluate.load("rouge")


def compute_metrics(eval_pred):
    """Compute ROUGE scores and the mean prediction length for an evaluation pass.

    Args:
        eval_pred: pair `(predictions, labels)`. `predictions` is either an array
            or a tuple whose first item is used; a 3-dimensional array is treated
            as per-token scores and reduced with an argmax over the last axis,
            anything else is used as token ids directly.
            NOTE(review): the tuple / logits layout is assumed from how the
            values are consumed here; confirm against the trainer configuration.
            `labels` holds token ids with -100 marking ignored positions.

    Returns:
        dict with the ROUGE scores and 'gen_len' (mean number of non-padding
        predicted tokens), every value rounded to 4 decimals.
    """
    predictions, labels = eval_pred
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    predictions = np.asarray(predictions)

    if predictions.ndim == 3:
        predicted_token_ids = predictions.argmax(axis=-1)
    else:
        predicted_token_ids = predictions
    # -100 is not a valid token id; map it to padding before decoding
    predicted_token_ids = np.where(predicted_token_ids != -100, predicted_token_ids, tokenizer.pad_token_id)
    decoded_preds = tokenizer.batch_decode(predicted_token_ids, skip_special_tokens=True)

    labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    result = rouge.compute(predictions=decoded_preds, references=decoded_labels, use_stemmer=True)

    # Lengths are counted on the predicted token ids, one row per example
    prediction_lens = [np.count_nonzero(pred != tokenizer.pad_token_id) for pred in predicted_token_ids]
    result["gen_len"] = np.mean(prediction_lens)

    return {k: round(v, 4) for k, v in result.items()}

## System 1 : Normal classifier

In [9]:
# System 1 works on its own deep copy, so the `model` object loaded above is not modified by it
model_s1 = deepcopy(model)

In [6]:
def preprocess_dataset(examples, input_col='premise_hypothesis', target_col='label_explanation'):
    """Tokenize a batch of examples into model inputs with target labels.

    Args:
        examples: batch of examples indexable by column name, as passed by
            `Dataset.map(..., batched=True)`.
        input_col: column holding the source text (defaults to the System 1 input).
        target_col: column holding the target text (defaults to the System 1 target).
            Other column pairs can be selected through `map(..., fn_kwargs={...})`.

    Returns:
        The tokenizer output for the inputs, extended with a 'labels' entry
        holding the input ids of the tokenized targets.
    """
    model_inputs = tokenizer(examples[input_col])
    model_inputs['labels'] = tokenizer(examples[target_col])['input_ids']
    return model_inputs

In [7]:
# Tokenize every split and drop the original text columns in the same pass,
# keeping only the fields produced by `preprocess_dataset`
tokenized_ds = dataset_train_test.map(
    preprocess_dataset,
    batched=True,
    remove_columns=dataset_train_test['train'].column_names,
)

In [10]:
# The following parameters were taken from the DREAM-FLUTE paper
# (only the number of epochs has been increased because the model is smaller)
training_args = Seq2SeqTrainingArguments(
    # Raw string: the backslashes of the Windows path must not be read as escape sequences
    output_dir=r"D:\Documents\PoliTo\Deep NLP\Project\S1Model_more_accurate",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    seed=42,
    #weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=8,
    load_best_model_at_end=True,
    #eval_accumulation_steps=8,
    #fp16=True,
    #push_to_hub=True,
    adam_beta1=0.9,
    adam_beta2=0.999,
    adam_epsilon=1e-08,
    lr_scheduler_type='linear'
)

# Fine-tune the System 1 copy on the tokenized train split, evaluating on the test split
trainer = Seq2SeqTrainer(
    model=model_s1,
    args=training_args,
    train_dataset=tokenized_ds["train"],
    eval_dataset=tokenized_ds["test"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    #compute_metrics=compute_metrics
)

trainer.train()

  0%|          | 0/24112 [00:00<?, ?it/s]You're using a T5TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
  2%|▏         | 501/24112 [01:00<44:48,  8.78it/s]

{'loss': 2.3937, 'learning_rate': 4.896317186463172e-05, 'epoch': 0.17}


  4%|▍         | 1001/24112 [02:02<47:53,  8.04it/s] 

{'loss': 2.0177, 'learning_rate': 4.792634372926344e-05, 'epoch': 0.33}


  6%|▌         | 1501/24112 [03:07<41:30,  9.08it/s]  

{'loss': 1.9162, 'learning_rate': 4.688951559389516e-05, 'epoch': 0.5}


  8%|▊         | 2001/24112 [04:14<45:05,  8.17it/s]  

{'loss': 1.8441, 'learning_rate': 4.585268745852688e-05, 'epoch': 0.66}


 10%|█         | 2501/24112 [05:17<47:17,  7.62it/s]  

{'loss': 1.8173, 'learning_rate': 4.4815859323158594e-05, 'epoch': 0.83}


 12%|█▏        | 3001/24112 [06:22<40:52,  8.61it/s]  

{'loss': 1.8065, 'learning_rate': 4.3779031187790315e-05, 'epoch': 1.0}


                                                    
 12%|█▎        | 3014/24112 [06:48<41:56,  8.38it/s]

{'eval_loss': 1.5989887714385986, 'eval_runtime': 24.1193, 'eval_samples_per_second': 62.481, 'eval_steps_per_second': 31.261, 'epoch': 1.0}


 15%|█▍        | 3501/24112 [07:47<43:26,  7.91it/s]   

{'loss': 1.7179, 'learning_rate': 4.274220305242203e-05, 'epoch': 1.16}


 17%|█▋        | 4001/24112 [08:56<44:10,  7.59it/s]  

{'loss': 1.7005, 'learning_rate': 4.170537491705375e-05, 'epoch': 1.33}


 19%|█▊        | 4501/24112 [09:53<36:40,  8.91it/s]

{'loss': 1.7, 'learning_rate': 4.066854678168547e-05, 'epoch': 1.49}


 21%|██        | 5001/24112 [10:55<36:56,  8.62it/s]  

{'loss': 1.667, 'learning_rate': 3.963171864631719e-05, 'epoch': 1.66}


 23%|██▎       | 5501/24112 [12:02<39:06,  7.93it/s]  

{'loss': 1.6764, 'learning_rate': 3.8594890510948907e-05, 'epoch': 1.82}


 25%|██▍       | 6001/24112 [13:09<35:09,  8.59it/s]

{'loss': 1.6123, 'learning_rate': 3.755806237558063e-05, 'epoch': 1.99}


                                                    
 25%|██▌       | 6028/24112 [13:46<37:19,  8.08it/s]

{'eval_loss': 1.5124155282974243, 'eval_runtime': 33.65, 'eval_samples_per_second': 44.785, 'eval_steps_per_second': 22.407, 'epoch': 2.0}


 27%|██▋       | 6501/24112 [14:58<43:05,  6.81it/s]   

{'loss': 1.626, 'learning_rate': 3.652123424021234e-05, 'epoch': 2.16}


 29%|██▉       | 7001/24112 [16:04<36:23,  7.84it/s]  

{'loss': 1.567, 'learning_rate': 3.548440610484406e-05, 'epoch': 2.32}


 31%|███       | 7501/24112 [17:13<38:22,  7.22it/s]  

{'loss': 1.5735, 'learning_rate': 3.4447577969475784e-05, 'epoch': 2.49}


 33%|███▎      | 8001/24112 [18:18<34:50,  7.71it/s]

{'loss': 1.5688, 'learning_rate': 3.3410749834107505e-05, 'epoch': 2.65}


 35%|███▌      | 8501/24112 [19:25<34:31,  7.54it/s]

{'loss': 1.5799, 'learning_rate': 3.237392169873922e-05, 'epoch': 2.82}


 37%|███▋      | 9001/24112 [20:40<41:49,  6.02it/s]

{'loss': 1.5713, 'learning_rate': 3.133709356337094e-05, 'epoch': 2.99}


                                                    
 38%|███▊      | 9042/24112 [21:23<40:45,  6.16it/s]

{'eval_loss': 1.4648537635803223, 'eval_runtime': 36.7605, 'eval_samples_per_second': 40.995, 'eval_steps_per_second': 20.511, 'epoch': 3.0}


 39%|███▉      | 9501/24112 [22:39<43:07,  5.65it/s]   

{'loss': 1.5161, 'learning_rate': 3.0300265428002654e-05, 'epoch': 3.15}


 41%|████▏     | 10001/24112 [23:57<28:06,  8.36it/s] 

{'loss': 1.5249, 'learning_rate': 2.9263437292634375e-05, 'epoch': 3.32}


 44%|████▎     | 10501/24112 [24:58<31:17,  7.25it/s]

{'loss': 1.4821, 'learning_rate': 2.8226609157266093e-05, 'epoch': 3.48}


 46%|████▌     | 11001/24112 [26:04<26:11,  8.34it/s]  

{'loss': 1.4917, 'learning_rate': 2.7189781021897807e-05, 'epoch': 3.65}


 48%|████▊     | 11501/24112 [27:15<23:06,  9.09it/s]

{'loss': 1.5155, 'learning_rate': 2.615295288652953e-05, 'epoch': 3.82}


 50%|████▉     | 12002/24112 [28:15<22:04,  9.14it/s]

{'loss': 1.5253, 'learning_rate': 2.5116124751161246e-05, 'epoch': 3.98}


                                                     
 50%|█████     | 12056/24112 [28:45<30:56,  6.49it/s]

{'eval_loss': 1.434415578842163, 'eval_runtime': 24.5448, 'eval_samples_per_second': 61.398, 'eval_steps_per_second': 30.719, 'epoch': 4.0}


 52%|█████▏    | 12501/24112 [29:37<20:18,  9.53it/s]   

{'loss': 1.4731, 'learning_rate': 2.4079296615792967e-05, 'epoch': 4.15}


 54%|█████▍    | 13001/24112 [30:40<20:35,  8.99it/s]

{'loss': 1.4942, 'learning_rate': 2.3042468480424688e-05, 'epoch': 4.31}


 56%|█████▌    | 13501/24112 [31:45<19:47,  8.94it/s]

{'loss': 1.4639, 'learning_rate': 2.2005640345056406e-05, 'epoch': 4.48}


 58%|█████▊    | 14001/24112 [32:53<19:56,  8.45it/s]

{'loss': 1.4378, 'learning_rate': 2.0968812209688123e-05, 'epoch': 4.64}


 60%|██████    | 14501/24112 [33:55<17:37,  9.09it/s]

{'loss': 1.4826, 'learning_rate': 1.9931984074319844e-05, 'epoch': 4.81}


 62%|██████▏   | 15001/24112 [34:56<16:53,  8.99it/s]

{'loss': 1.4389, 'learning_rate': 1.8895155938951562e-05, 'epoch': 4.98}


                                                     
 62%|██████▎   | 15070/24112 [35:28<20:49,  7.23it/s]

{'eval_loss': 1.4145468473434448, 'eval_runtime': 24.2577, 'eval_samples_per_second': 62.125, 'eval_steps_per_second': 31.083, 'epoch': 5.0}


 64%|██████▍   | 15501/24112 [36:20<17:59,  7.98it/s]   

{'loss': 1.4424, 'learning_rate': 1.785832780358328e-05, 'epoch': 5.14}


 66%|██████▋   | 16001/24112 [37:24<15:24,  8.77it/s]

{'loss': 1.3914, 'learning_rate': 1.6821499668214997e-05, 'epoch': 5.31}


 68%|██████▊   | 16501/24112 [38:26<18:07,  7.00it/s]

{'loss': 1.4264, 'learning_rate': 1.5784671532846715e-05, 'epoch': 5.47}


 71%|███████   | 17001/24112 [39:30<13:05,  9.05it/s]

{'loss': 1.448, 'learning_rate': 1.4747843397478434e-05, 'epoch': 5.64}


 73%|███████▎  | 17502/24112 [40:27<12:13,  9.01it/s]

{'loss': 1.4369, 'learning_rate': 1.3711015262110152e-05, 'epoch': 5.81}


 75%|███████▍  | 18001/24112 [41:29<12:08,  8.39it/s]

{'loss': 1.4389, 'learning_rate': 1.2674187126741871e-05, 'epoch': 5.97}


                                                     
 75%|███████▌  | 18084/24112 [42:05<11:21,  8.85it/s]

{'eval_loss': 1.4025543928146362, 'eval_runtime': 25.9616, 'eval_samples_per_second': 58.047, 'eval_steps_per_second': 29.043, 'epoch': 6.0}


 77%|███████▋  | 18501/24112 [42:56<10:46,  8.68it/s]   

{'loss': 1.3917, 'learning_rate': 1.163735899137359e-05, 'epoch': 6.14}


 79%|███████▉  | 19001/24112 [43:56<10:03,  8.47it/s]

{'loss': 1.413, 'learning_rate': 1.0600530856005308e-05, 'epoch': 6.3}


 81%|████████  | 19501/24112 [45:03<09:24,  8.17it/s]

{'loss': 1.398, 'learning_rate': 9.563702720637027e-06, 'epoch': 6.47}


 83%|████████▎ | 20001/24112 [46:08<08:11,  8.36it/s]

{'loss': 1.3996, 'learning_rate': 8.526874585268747e-06, 'epoch': 6.64}


 85%|████████▌ | 20501/24112 [47:13<06:57,  8.66it/s]

{'loss': 1.4054, 'learning_rate': 7.490046449900464e-06, 'epoch': 6.8}


 87%|████████▋ | 21001/24112 [48:17<05:45,  9.01it/s]

{'loss': 1.4127, 'learning_rate': 6.453218314532184e-06, 'epoch': 6.97}


                                                     
 88%|████████▊ | 21098/24112 [49:00<05:53,  8.52it/s]

{'eval_loss': 1.3936963081359863, 'eval_runtime': 30.2886, 'eval_samples_per_second': 49.755, 'eval_steps_per_second': 24.894, 'epoch': 7.0}


 89%|████████▉ | 21501/24112 [50:00<06:25,  6.77it/s]  

{'loss': 1.4169, 'learning_rate': 5.416390179163902e-06, 'epoch': 7.13}


 91%|█████████ | 22001/24112 [51:14<04:59,  7.05it/s]

{'loss': 1.3759, 'learning_rate': 4.379562043795621e-06, 'epoch': 7.3}


 93%|█████████▎| 22501/24112 [52:28<04:37,  5.80it/s]

{'loss': 1.3879, 'learning_rate': 3.3427339084273395e-06, 'epoch': 7.47}


 95%|█████████▌| 23001/24112 [53:41<03:20,  5.54it/s]

{'loss': 1.3855, 'learning_rate': 2.305905773059058e-06, 'epoch': 7.63}


 97%|█████████▋| 23501/24112 [54:52<01:32,  6.58it/s]

{'loss': 1.3851, 'learning_rate': 1.2690776376907765e-06, 'epoch': 7.8}


100%|█████████▉| 24001/24112 [55:56<00:16,  6.67it/s]

{'loss': 1.3895, 'learning_rate': 2.3224950232249503e-07, 'epoch': 7.96}


                                                     
100%|██████████| 24112/24112 [56:41<00:00,  7.79it/s]

{'eval_loss': 1.3940967321395874, 'eval_runtime': 30.05, 'eval_samples_per_second': 50.15, 'eval_steps_per_second': 25.092, 'epoch': 8.0}


100%|██████████| 24112/24112 [56:43<00:00,  7.08it/s]

{'train_runtime': 3403.4563, 'train_samples_per_second': 14.167, 'train_steps_per_second': 7.085, 'train_loss': 1.5525400064287393, 'epoch': 8.0}





TrainOutput(global_step=24112, training_loss=1.5525400064287393, metrics={'train_runtime': 3403.4563, 'train_samples_per_second': 14.167, 'train_steps_per_second': 7.085, 'train_loss': 1.5525400064287393, 'epoch': 8.0})

In [16]:
# Reload the checkpoint saved at step 24112.
# Raw string: the backslashes of the Windows path must not be read as escape sequences.
model_s1 = AutoModelForSeq2SeqLM.from_pretrained(r"D:\Documents\PoliTo\Deep NLP\Project\S1Model_more_accurate\checkpoint-24112")
# NOTE(review): the System 1 training inputs built in `add_combined_cols` end with the
# question "Is there a contradiction or entailment between the premise and hypothesis ?";
# this sample omits it - confirm that this is intended.
i = "Premise: Today I crashed my car. Hypothesis: I felt like a champion when I crashed my car."
t = tokenizer(i, return_tensors='pt').input_ids
# Put the input ids on the same device as the model before generating
t = t.to(model_s1.device)
o = model_s1.generate(t, max_new_tokens = 100)
# Special tokens are not skipped, so padding / end-of-sequence markers appear in the text
d = tokenizer.decode(o[0])
d 

'<pad> Label: Contradiction. Explanation: Accidentally driving a car is not a good thing and so someone feeling like a champion when they crash it is not a good thing.</s>'

## System 2 : Predict type of figurative language