# T5 Fine-tuning

### Import Libraries

In [1]:
from transformers import Seq2SeqTrainingArguments, Seq2SeqTrainer
import pandas as pd
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, DataCollatorForSeq2Seq
from datasets import load_dataset
import datasets



### Load Dataset function

In [2]:
def load_data (name, inpExpOutFunc): 
    """Load a Hugging Face dataset and tokenize it into seq2seq examples.

    Relies on a module-level ``tokenizer`` that must be defined before this
    function is called (it is looked up when the function runs, not when it
    is defined).

    Args:
        name: Dataset identifier passed to ``load_dataset``.
        inpExpOutFunc: Callable ``(batch, index) -> (input_text, target_text)``
            that selects the model input and the expected output of one
            example from a batch given as a dict of columns.

    Returns:
        A ``(train_data, test_data)`` tuple built from the dataset's
        ``'train'`` and ``'validation'`` splits, each extended with
        ``input_ids``, ``attention_mask`` and ``labels`` columns.
    """
    dataset = load_dataset(name)
    train_data = dataset['train']
    test_data = dataset['validation']

    def preprocess_function(batch):
        """Tokenize one batch of examples handed over by ``map(batched=True)``."""
        # Copy into a plain dict so the object owned by `map` is never mutated.
        columns = dict(batch)
        # The batch size is whatever `map` passes in, so measure it instead of
        # assuming a fixed value.
        num_examples = len(next(iter(columns.values()))) if columns else 0
        if num_examples == 0:
            return {"input_ids": [], "attention_mask": [], "labels": []}

        pairs = [inpExpOutFunc(columns, index) for index in range(num_examples)]
        inputs = [inp for inp, _ in pairs]
        targets = [exp_out for _, exp_out in pairs]

        # One tokenizer call per batch; no padding is requested, so every
        # example keeps its own length exactly as with per-example calls.
        model_inputs = tokenizer(inputs, max_length=1024, truncation=True)
        labels = tokenizer(targets, max_length=1024, truncation=True)

        # Only the new columns are returned; `map` merges them into the
        # existing columns of the dataset.
        return {
            "input_ids": model_inputs["input_ids"],
            "attention_mask": model_inputs["attention_mask"],
            "labels": labels["input_ids"],
        }
    
    train_data = train_data.map(preprocess_function, batched=True)
    test_data = test_data.map(preprocess_function, batched=True)

    return train_data, test_data

### Load Dataset

In [3]:
# NOTE(review): this flag is not read anywhere in the visible notebook — confirm it is still needed.
use_pytorch_training = False

# load tokenizer
# The T5 vocabulary already contains a dedicated "<pad>" token, which the model
# configuration also uses as its pad / decoder-start id. The pad token is
# therefore left as loaded: pointing it at the EOS token (a habit from
# decoder-only models that lack a pad token) would make the tokenizer and the
# model disagree about the padding id.
tokenizer = AutoTokenizer.from_pretrained('google/t5-v1_1-small')

# load datasets needed
def sciq_extract (dataset, index): 
    """Return the ``(input, target)`` pair of one SciQ example.

    The model input is the ``support`` text and the target is the
    ``question`` text found at position ``index`` of the batch columns.
    """
    support_text = dataset['support'][index]
    question_text = dataset['question'][index]
    return support_text, question_text
train_data_sciq, test_data_sciq = load_data("sciq", sciq_extract) # SciQ science questions: support passage -> question pairs

def squad_extract (dataset, index):
    """Return the ``(input, target)`` pair of one SQuAD example.

    The model input is the ``question`` text; the target is the first entry
    of the ``text`` list stored under ``answers`` for the same example.
    """
    question_text = dataset['question'][index]
    first_answer = dataset['answers'][index]["text"][0]
    return question_text, first_answer
train_data_squad, test_data_squad = load_data('squad', squad_extract) # SQuAD (Wikipedia): question -> first listed answer text

def piqa_extract (dataset, index): 
    """Return the ``(input, target)`` pair of one PIQA example.

    The model input is the ``goal`` text. The target is ``sol1`` when the
    example's ``label`` equals 0 and ``sol2`` for any other label value.
    """
    solution_column = "sol1" if dataset["label"][index] == 0 else "sol2"
    return dataset["goal"][index], dataset[solution_column][index]
train_data_piqa, test_data_piqa = load_data("piqa", piqa_extract) # PIQA is a common-sense benchmark; here it supplies goal -> solution pairs (solution chosen by `label`)

# Merge the three per-dataset splits into one evaluation set and one training set.
test_data = datasets.concatenate_datasets((test_data_sciq, test_data_squad, test_data_piqa))
train_data = datasets.concatenate_datasets((train_data_sciq, train_data_squad, train_data_piqa))

# keep only input_ids, attention_mask, and labels
def clean_dataset (dataset): 
    """Return ``dataset`` reduced to the three columns the model consumes.

    Args:
        dataset: Object exposing ``column_names`` (a sequence of column
            names) and ``remove_columns(names)``.

    Returns:
        Whatever ``dataset.remove_columns`` returns after every column other
        than ``input_ids``, ``attention_mask`` and ``labels`` was dropped.

    Raises:
        ValueError: If one of the three required columns is missing.
    """
    keep = ("input_ids", "attention_mask", "labels")
    # Work on a copy: if `column_names` hands back the dataset's own list,
    # editing it in place would corrupt the dataset's view of its columns.
    column_names = list(dataset.column_names)

    missing = [column for column in keep if column not in column_names]
    if missing:
        raise ValueError(f"dataset is missing required column(s): {missing}")

    columns_remove = [column for column in column_names if column not in keep]
    return dataset.remove_columns(columns_remove)

# Drop every column except the tokenized model inputs and labels.
test_data = clean_dataset(test_data)
train_data = clean_dataset(train_data)

# Load the pretrained seq2seq model from the same checkpoint name used for the tokenizer.
model = AutoModelForSeq2SeqLM.from_pretrained('google/t5-v1_1-small')

# Sanity check: the cleaned training split should still be a `datasets` Dataset.
print(type(train_data)) # <class 'datasets.arrow_dataset.Dataset'>

Using custom data configuration default
Reusing dataset sciq (/home/eshaanb/.cache/huggingface/datasets/sciq/default/0.1.0/50e5c6e3795b55463819d399ec417bfd4c3c621105e00295ddb5f3633d708493)


  0%|          | 0/3 [00:00<?, ?it/s]

Loading cached processed dataset at /home/eshaanb/.cache/huggingface/datasets/sciq/default/0.1.0/50e5c6e3795b55463819d399ec417bfd4c3c621105e00295ddb5f3633d708493/cache-7149f66aad271c8b.arrow


  0%|          | 0/1 [00:00<?, ?ba/s]

Reusing dataset squad (/home/eshaanb/.cache/huggingface/datasets/squad/plain_text/1.0.0/d6ec3ceb99ca480ce37cdd35555d6cb2511d223b9150cce08a837ef62ffea453)


  0%|          | 0/2 [00:00<?, ?it/s]

Loading cached processed dataset at /home/eshaanb/.cache/huggingface/datasets/squad/plain_text/1.0.0/d6ec3ceb99ca480ce37cdd35555d6cb2511d223b9150cce08a837ef62ffea453/cache-91a658dd920251b1.arrow
Loading cached processed dataset at /home/eshaanb/.cache/huggingface/datasets/squad/plain_text/1.0.0/d6ec3ceb99ca480ce37cdd35555d6cb2511d223b9150cce08a837ef62ffea453/cache-01f138db708e2dc1.arrow
Reusing dataset piqa (/home/eshaanb/.cache/huggingface/datasets/piqa/plain_text/1.1.0/6c611c1a9bf220943c4174e117d3b660859665baf1d43156230116185312d011)


  0%|          | 0/3 [00:00<?, ?it/s]

Loading cached processed dataset at /home/eshaanb/.cache/huggingface/datasets/piqa/plain_text/1.1.0/6c611c1a9bf220943c4174e117d3b660859665baf1d43156230116185312d011/cache-a9766a0b7805db6e.arrow
Loading cached processed dataset at /home/eshaanb/.cache/huggingface/datasets/piqa/plain_text/1.1.0/6c611c1a9bf220943c4174e117d3b660859665baf1d43156230116185312d011/cache-e7eb7e926a85652b.arrow


<class 'datasets.arrow_dataset.Dataset'>


### Visualize Dataset

In [4]:
# Show the remaining columns and the row count of the combined training split.
print(train_data)

Dataset({
    features: ['input_ids', 'attention_mask', 'labels'],
    num_rows: 115391
})


### Transformers' Seq2Seq Training

In [5]:

# see https://huggingface.co/docs/transformers/tasks/summarization

# Collator that assembles the tokenized examples into batches for `model`.
# The preprocessing step tokenizes without a padding argument, so presumably
# padding is applied here per batch — confirm against the collator docs.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=model)

training_args = Seq2SeqTrainingArguments(
    output_dir="./results",  # the test cell later loads a checkpoint from under this directory
    learning_rate=5e-4,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=10,  # 1 example x 10 accumulated steps = 10 examples per optimizer step per device
    eval_accumulation_steps=10,
    weight_decay=0.01,
    num_train_epochs=2,
    # NOTE(review): T5 v1.1 checkpoints are reported to overflow under fp16 —
    # confirm the logged loss stays finite, or consider bf16 / full precision.
    fp16=True,
    # NOTE(review): with a limit of one retained checkpoint, the hard-coded
    # "checkpoint-17500" path loaded in the test cell may already have been
    # rotated out once training runs past that step — verify.
    save_total_limit=1,
    logging_steps=50,
)
# NOTE(review): no eval_dataset is passed, so `test_data` is not used by this trainer.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_data,
    tokenizer=tokenizer,
    data_collator=data_collator,
)

# Start fine-tuning.
trainer.train()

Using amp half precision backend
***** Running training *****
  Num examples = 115391
  Num Epochs = 2
  Instantaneous batch size per device = 1
  Total train batch size (w. parallel, distributed & accumulation) = 10
  Gradient Accumulation steps = 10
  Total optimization steps = 23078
  nn.utils.clip_grad_norm_(


Step,Training Loss


KeyboardInterrupt: 

## Test the Fine-tuned Checkpoint

In [None]:
from random import shuffle, randint
# load model from checkpoint 
prompt = "Natural Language Processing is a field of artificial intelligence in which computers analyze, understand, and derive meaning from human language in a smart and useful way. By utilizing NLP, developers can organize and structure knowledge to perform tasks such as automatic summarization, translation, named entity recognition, relationship extraction, sentiment analysis, speech recognition, and topic segmentation. 'Apart from common word processor operations that treat text like a mere sequence of symbols, NLP considers the hierarchical structure of language: several words make a phrase, several phrases make a sentence and, ultimately, sentences convey ideas,' John Rehling. “By analyzing language for its meaning, NLP systems have long filled useful roles, such as correcting grammar, converting speech to text and automatically translating between languages.” NLP is used to analyze text, allowing machines to understand how humans speak. This human-computer interaction enables real-world applications like automatic text summarization, sentimental analysis, topic extraction, named entity recognition, parts-of-speech tagging, relationship extraction, stemming, and more. NLP is characterized as a difficult problem in computer science. Human language is rarely precise, or plainly spoken. To understand human language is to understand not only the words, but the concepts and how they’re linked together to create meaning. Despite language being one of the easiest things for the human mind to learn, the ambiguity of language is what makes natural language processing a difficult problem for computers to master."
# Directory of the fine-tuned checkpoint; both tokenizer and model are read
# from local files only, so nothing is downloaded here.
PATH = "./results/checkpoint-17500"
tokenizer = AutoTokenizer.from_pretrained(PATH, local_files_only=True)
model = AutoModelForSeq2SeqLM.from_pretrained(PATH, local_files_only=True)
# Truncate the prompt the same way the training inputs were truncated.
input_ids = tokenizer(prompt, return_tensors="pt", max_length=1024, truncation=True).input_ids
outputs = model.generate(input_ids)
# Drop special tokens (padding / end-of-sequence markers) so that only the
# generated text is printed.
print(tokenizer.decode(outputs[0], skip_special_tokens=True))

## Visualize Training Loss

In [None]:
# Load the logged losses from the tab-separated file and echo the table.
df = pd.read_csv("loss.csv", sep="\t")
print(df)
# Draw the "Training Loss" column against the row index on an explicit axes.
import matplotlib.pyplot as plt
fig, ax = plt.subplots()
ax.plot(df["Training Loss"])
ax.set_xlabel("Training Steps")
ax.set_ylabel("Loss")
plt.show()