In [1]:
import numpy as np
import random
import torch

seed = 42

random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed) 
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(seed)

In [2]:
from datasets import load_dataset

# Load the e-SNLI dataset
dataset = load_dataset("esnli")

train_dataset = dataset['train']
eval_dataset = dataset['validation']
#test_dataset = dataset['test']

indices = list(range(0, len(train_dataset), 10))  # Select every 10th index
train_dataset = train_dataset.select(indices)

len(train_dataset), len(eval_dataset)#, len(test_dataset)

Reusing dataset esnli (/home/ec2-user/.cache/huggingface/datasets/esnli/plain_text/0.0.2/a160e6a02bbb8d828c738918dafec4e7d298782c334b5109af632fec6d779bbc)


  0%|          | 0/3 [00:00<?, ?it/s]

(54937, 9842)

In [9]:
# Dependency parsing

import spacy

# Load the English language model
nlp = spacy.load("en_core_web_sm")

def dependency_parse(sentence):
    doc = nlp(sentence)
    # Generate the dependency tree structure
    return " ".join([f"<{token.dep_}> {token.text}" for token in doc])

In [10]:
dependency_parse("the car is moving on the road.")

'<det> the <nsubj> car <aux> is <ROOT> moving <prep> on <det> the <pobj> road <punct> .'

In [11]:
label_dct = {0: "entailment", 1: "neutral", 2: "contradiction"}

In [12]:
from transformers import T5Tokenizer

tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-small")

# Preprocessing function
def preprocess(example):
    # Prepare input and output text
    input_text = f"Premise: example['premise'] Dependency: {dependency_parse(example['premise'])} Hypothesis: example['hypothesis'] Dependency: {dependency_parse(example['hypothesis'])} What is the relationship? Explain your answer."
    output_text = f"{label_dct[example['label']]}: {example['explanation_1']}. {example['explanation_2']}. {example['explanation_3']}."

    # Tokenize input and output
    input_encoding = tokenizer(input_text, truncation=True, padding="max_length", max_length=512, return_tensors="pt")
    output_encoding = tokenizer(output_text, truncation=True, padding="max_length", max_length=512, return_tensors="pt")

    # Create a dictionary to return
    return {
        "input_ids": input_encoding["input_ids"][0],  # Remove batch dimension
        "attention_mask": input_encoding["attention_mask"][0],  # Remove batch dimension
        "labels": output_encoding["input_ids"][0] # Remove batch dimension
    }


# Apply preprocessing
train_dataset = train_dataset.map(
    preprocess,
    remove_columns=['premise', 'hypothesis', 'label', 'explanation_1', 'explanation_2', 'explanation_3'],
)
eval_dataset = eval_dataset.map(
    preprocess,
    remove_columns=['premise', 'hypothesis', 'label', 'explanation_1', 'explanation_2', 'explanation_3'],
)

You are using the default legacy behaviour of the <class 'transformers.models.t5.tokenization_t5.T5Tokenizer'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thouroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


  0%|          | 0/54937 [00:00<?, ?ex/s]

  0%|          | 0/9842 [00:00<?, ?ex/s]

In [13]:
train_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
eval_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
print(train_dataset[0]['input_ids'].shape)  # Should show (512,)
print(train_dataset[0]['attention_mask'].shape)  # Should show (512,)
print(train_dataset[0]['labels'].shape)  # Should show (512,)

torch.Size([512])
torch.Size([512])
torch.Size([512])


In [15]:
from transformers import T5ForConditionalGeneration

model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-small").cuda()

In [16]:
from transformers import (
    T5Tokenizer,
    T5ForConditionalGeneration,
    Trainer,
    TrainingArguments,
    AdamW,
    get_scheduler,
)

# Define custom optimizer
learning_rate = 0.001
optimizer = AdamW(
    model.parameters(),
    lr=learning_rate,
    betas=(0.9, 0.999),
    eps=1e-08,
)

# Training arguments
training_args = TrainingArguments(
    output_dir="./flan_t5_esnli",
    evaluation_strategy="epoch",  # Evaluate at the end of every epoch
    save_strategy="epoch",  # Save at the end of every epoch
    per_device_train_batch_size=8,  # Train batch size
    per_device_eval_batch_size=8,  # Evaluation batch size
    num_train_epochs=12,  # Number of epochs
    learning_rate=learning_rate,  # Learning rate
    lr_scheduler_type="linear",  # Linear learning rate scheduler
    warmup_ratio=0.05,  # Warmup ratio
    weight_decay=0.01,  # Weight decay
    save_total_limit=12,  # Keep only the last 2 checkpoints
    fp16=torch.cuda.is_available(),  # Use FP16 if a GPU is available
    seed=seed,
    load_best_model_at_end=True,  # Load the best model at the end of training
    metric_for_best_model="loss",  # Optimize for loss
    greater_is_better=False,
    report_to=[],
)

# Define Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=tokenizer,
    optimizers=(optimizer, None),
)



In [17]:
# Start training
trainer.train()

# Save the final model
model.save_pretrained("./final_flan_t5_esnli")
tokenizer.save_pretrained("./final_flan_t5_esnli")

# Evaluate on the test set
test_results = trainer.evaluate(eval_dataset=eval_dataset)
print("Test results:", test_results)



Epoch,Training Loss,Validation Loss
1,0.0536,0.861022
2,0.0481,0.258028
3,0.0435,0.296071
4,0.0405,0.246219
5,0.0376,0.264902
6,0.0356,0.393109
7,0.0335,0.336889
8,0.0314,0.336917
9,0.0296,0.322672
10,0.028,0.334201


Checkpoint destination directory ./flan_t5_esnli/checkpoint-1717 already exists and is non-empty.Saving will proceed but saved results may be invalid.
Checkpoint destination directory ./flan_t5_esnli/checkpoint-3434 already exists and is non-empty.Saving will proceed but saved results may be invalid.
Checkpoint destination directory ./flan_t5_esnli/checkpoint-5151 already exists and is non-empty.Saving will proceed but saved results may be invalid.
Checkpoint destination directory ./flan_t5_esnli/checkpoint-6868 already exists and is non-empty.Saving will proceed but saved results may be invalid.
Checkpoint destination directory ./flan_t5_esnli/checkpoint-8585 already exists and is non-empty.Saving will proceed but saved results may be invalid.
Checkpoint destination directory ./flan_t5_esnli/checkpoint-10302 already exists and is non-empty.Saving will proceed but saved results may be invalid.
Checkpoint destination directory ./flan_t5_esnli/checkpoint-12019 already exists and is non-e

Test results: {'eval_loss': 0.2462185174226761, 'eval_runtime': 207.0557, 'eval_samples_per_second': 47.533, 'eval_steps_per_second': 1.488, 'epoch': 12.0}


In [25]:
false = "f"
true = "t"
yoyo = {
    "FraudulentPaystubs::Templates::One": false,
    "FraudulentPaystubs::Templates::Two": false,
    "FraudulentPaystubs::Templates::Three": false,
    "FraudulentPaystubs::Templates::Four": true,
    "FraudulentPaystubs::Templates::Five": false,
    "FraudulentPaystubs::Templates::Six": false,
    "FraudulentPaystubs::Templates::Seven": false,
    "FraudulentPaystubs::Templates::ThePaystubsAdpAndRectangle": false,
    "FraudulentPaystubs::Templates::ThePaystubsSquare": false,
    "FraudulentPaystubs::Templates::OnlinePaystubAdvanced": false,
    "FraudulentPaystubs::Templates::PaystubOnlineNeat": false,
    "FraudulentPaystubs::Templates::RealCheckstubsAdvThree": false,
    "FraudulentPaystubs::Templates::PaystubDirect": false,
    "FraudulentPaystubs::VisionDatas::SequenceNumbers": false,
    "FraudulentPaystubs::Templates::Fifteen": false,
    "FraudulentPaystubs::Templates::Sixteen": false,
    "FraudulentPaystubs::Templates::Seventeen": false,
    "FraudulentPaystubs::Templates::Eighteen": false,
    "FraudulentPaystubs::Templates::Nineteen": false,
    "FraudulentPaystubs::Templates::Twenty": false,
    "FraudulentPaystubs::Templates::TwentyOne": false,
    "FraudulentPaystubs::Templates::TwentyTwo": false,
    "FraudulentPaystubs::Templates::TwentyThree": false,
    "FraudulentPaystubs::Templates::TwentyFour": false,
    "FraudulentPaystubs::Templates::TwentyFive": false,
    "FraudulentPaystubs::Templates::TwentySix": false,
    "FraudulentPaystubs::Templates::TwentySeven": false,
    "FraudulentPaystubs::Templates::Thirty": false,
    "FraudulentPaystubs::Templates::ThirtyOne": false,
    "FraudulentPaystubs::Templates::ThirtyTwo": false,
    "FraudulentPaystubs::Templates::ThirtyThree": false,
    "FraudulentPaystubs::Templates::ThirtyFour": false,
    "FraudulentPaystubs::Templates::ThirtyFive": false,
    "FraudulentPaystubs::Templates::ThirtySix": false,
    "FraudulentPaystubs::Templates::ThirtySeven": false,
    "FraudulentPaystubs::Templates::ThirtyEight": false,
    "FraudulentPaystubs::Templates::ThirtyNine": false,
    "FraudulentPaystubs::Templates::Forty": false,
    "FraudulentPaystubs::Templates::FortyOne": false,
    "FraudulentPaystubs::Templates::FiftyOne": false,
    "FraudulentPaystubs::Templates::FiftyFour": false,
    "FraudulentPaystubs::Templates::FiftyFive": false,
    "FraudulentPaystubs::Templates::FiftySeven": false,
    "FraudulentPaystubs::Templates::FiftyNine": false,
    "FraudulentPaystubs::Templates::SixtyOne": false,
    "FraudulentPaystubs::Templates::SixtySeven": false,
    "FraudulentPaystubs::Templates::SixtyEight": false,
    "FraudulentPaystubs::Templates::SixtyNine": false
  }

In [26]:
yoyo = list(yoyo.keys())
yoyo = {x: str(i+1) for i,x in enumerate(yoyo)}

In [27]:
import json
json.dumps(yoyo)

'{"FraudulentPaystubs::Templates::One": "1", "FraudulentPaystubs::Templates::Two": "2", "FraudulentPaystubs::Templates::Three": "3", "FraudulentPaystubs::Templates::Four": "4", "FraudulentPaystubs::Templates::Five": "5", "FraudulentPaystubs::Templates::Six": "6", "FraudulentPaystubs::Templates::Seven": "7", "FraudulentPaystubs::Templates::ThePaystubsAdpAndRectangle": "8", "FraudulentPaystubs::Templates::ThePaystubsSquare": "9", "FraudulentPaystubs::Templates::OnlinePaystubAdvanced": "10", "FraudulentPaystubs::Templates::PaystubOnlineNeat": "11", "FraudulentPaystubs::Templates::RealCheckstubsAdvThree": "12", "FraudulentPaystubs::Templates::PaystubDirect": "13", "FraudulentPaystubs::VisionDatas::SequenceNumbers": "14", "FraudulentPaystubs::Templates::Fifteen": "15", "FraudulentPaystubs::Templates::Sixteen": "16", "FraudulentPaystubs::Templates::Seventeen": "17", "FraudulentPaystubs::Templates::Eighteen": "18", "FraudulentPaystubs::Templates::Nineteen": "19", "FraudulentPaystubs::Template

In [None]:
{"FraudulentPaystubs::Templates::One": "1", "FraudulentPaystubs::Templates::Two": "2", "FraudulentPaystubs::Templates::Three": "3", "FraudulentPaystubs::Templates::Four": "4", "FraudulentPaystubs::Templates::Five": "5", "FraudulentPaystubs::Templates::Six": "6", "FraudulentPaystubs::Templates::Seven": "7", "FraudulentPaystubs::Templates::ThePaystubsAdpAndRectangle": "8", "FraudulentPaystubs::Templates::ThePaystubsSquare": "9", "FraudulentPaystubs::Templates::OnlinePaystubAdvanced": "10", "FraudulentPaystubs::Templates::PaystubOnlineNeat": "11", "FraudulentPaystubs::Templates::RealCheckstubsAdvThree": "12", "FraudulentPaystubs::Templates::PaystubDirect": "13", "FraudulentPaystubs::VisionDatas::SequenceNumbers": "14", "FraudulentPaystubs::Templates::Fifteen": "15", "FraudulentPaystubs::Templates::Sixteen": "16", "FraudulentPaystubs::Templates::Seventeen": "17", "FraudulentPaystubs::Templates::Eighteen": "18", "FraudulentPaystubs::Templates::Nineteen": "19", "FraudulentPaystubs::Templates::Twenty": "20", "FraudulentPaystubs::Templates::TwentyOne": "21", "FraudulentPaystubs::Templates::TwentyTwo": "22", "FraudulentPaystubs::Templates::TwentyThree": "23", "FraudulentPaystubs::Templates::TwentyFour": "24", "FraudulentPaystubs::Templates::TwentyFive": "25", "FraudulentPaystubs::Templates::TwentySix": "26", "FraudulentPaystubs::Templates::TwentySeven": "27", "FraudulentPaystubs::Templates::Thirty": "28", "FraudulentPaystubs::Templates::ThirtyOne": "29", "FraudulentPaystubs::Templates::ThirtyTwo": "30", "FraudulentPaystubs::Templates::ThirtyThree": "31", "FraudulentPaystubs::Templates::ThirtyFour": "32", "FraudulentPaystubs::Templates::ThirtyFive": "33", "FraudulentPaystubs::Templates::ThirtySix": "34", "FraudulentPaystubs::Templates::ThirtySeven": "35", "FraudulentPaystubs::Templates::ThirtyEight": "36", "FraudulentPaystubs::Templates::ThirtyNine": "37", "FraudulentPaystubs::Templates::Forty": "38", "FraudulentPaystubs::Templates::FortyOne": "39", "FraudulentPaystubs::Templates::FiftyOne": "40", "FraudulentPaystubs::Templates::FiftyFour": "41", "FraudulentPaystubs::Templates::FiftyFive": "42", "FraudulentPaystubs::Templates::FiftySeven": "43", "FraudulentPaystubs::Templates::FiftyNine": "44", "FraudulentPaystubs::Templates::SixtyOne": "45", "FraudulentPaystubs::Templates::SixtySeven": "46", "FraudulentPaystubs::Templates::SixtyEight": "47", "FraudulentPaystubs::Templates::SixtyNine": "48"}