In [None]:
# pip install peft

In [None]:
import inspect
import shutil

import pandas as pd
import torch
from IPython.display import FileLink
from peft import LoraConfig, get_peft_model
from torch.utils.data import Dataset
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, Seq2SeqTrainer, Seq2SeqTrainingArguments, DataCollatorForSeq2Seq,  EarlyStoppingCallback

In [None]:
model_name = "potsawee/t5-large-generation-race-QuestionAnswer"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# LoRA configuration for the seq2seq model.
lora_config = LoraConfig(
    # Declaring the task makes get_peft_model build the seq2seq-specific PEFT wrapper
    # and records the task in the saved adapter config.
    task_type="SEQ_2_SEQ_LM",
    r=16,  # Rank of low-rank matrices
    lora_alpha=32,  # Scaling factor
    target_modules=["q", "v"],  # Fine-tune attention layers (T5 query/value projections)
    lora_dropout=0.1,
    bias="none",
)

In [None]:
# Wrap the base model with the LoRA adapters described by `lora_config`.
model = get_peft_model(model=model, peft_config=lora_config)

# Sanity check: report trainable vs. total parameter counts.
model.print_trainable_parameters()

In [None]:
# NOTE(review): this flag alone does not mask anything in the Trainer's loss; padding in
# the labels must be set to -100 during preprocessing to be ignored — confirm it is.
setattr(model.config, "ignore_pad_token_for_loss", True)

In [None]:
# Load the SciQ parquet shards from the Hugging Face Hub (hf:// URLs) with pandas.
# The test split is listed for reference but not loaded here.
splits = {
    "train": "data/train-00000-of-00001.parquet",
    "validation": "data/validation-00000-of-00001.parquet",
    "test": "data/test-00000-of-00001.parquet",
}
train_df, validation_df = (
    pd.read_parquet(f"hf://datasets/allenai/sciq/{splits[split_name]}")
    for split_name in ("train", "validation")
)

In [None]:
def preprocess_function(df, tok=None, max_source_length=512, max_target_length=128):
    """Tokenize SciQ rows into model inputs and "question <sep> answer" targets.

    The encoder input is the ``support`` passage; the target string is
    ``question + " <sep> " + correct_answer``. Both sides are truncated and padded
    to a fixed length. Padding positions in ``labels`` are replaced with -100,
    the index the Transformers seq2seq loss ignores, so pad tokens are not
    trained as targets.

    Args:
        df: DataFrame with ``support``, ``question`` and ``correct_answer`` columns.
        tok: tokenizer to use; defaults to the notebook-level ``tokenizer``.
        max_source_length: truncation/padding length for the passage.
        max_target_length: truncation/padding length for the target string.

    Returns:
        The tokenizer's encoding of the passages with an added ``labels`` entry:
        one list of ``max_target_length`` token ids per row, padding set to -100.
    """
    if tok is None:
        tok = tokenizer
    # Guard against missing passages, which the tokenizer cannot encode.
    inputs = df["support"].fillna("").tolist()
    targets = [q + " <sep> " + a for q, a in zip(df["question"], df["correct_answer"])]
    model_inputs = tok(inputs, max_length=max_source_length,
                       truncation=True, padding="max_length")
    labels = tok(targets, max_length=max_target_length,
                 truncation=True, padding="max_length")
    pad_id = tok.pad_token_id
    # Without this, every padded label position would be scored as a pad-token target.
    model_inputs["labels"] = [
        [token_id if token_id != pad_id else -100 for token_id in seq]
        for seq in labels["input_ids"]
    ]
    return model_inputs


# Tokenize both splits with the same preprocessing.
train_data, validation_data = map(preprocess_function, (train_df, validation_df))

In [None]:
class SciQDataset(Dataset):
    """Map-style dataset over pre-tokenized, column-oriented encodings.

    ``encodings`` maps each field name (e.g. ``input_ids``, ``labels``) to a
    sequence holding one entry per example. Indexing returns a dict with one
    tensor per field for that example.
    """

    def __init__(self, encodings):
        self.encodings = encodings

    def __len__(self):
        # The number of examples is taken from the input_ids column.
        n_examples = len(self.encodings["input_ids"])
        return n_examples

    def __getitem__(self, idx):
        example = {}
        for field, column in self.encodings.items():
            example[field] = torch.tensor(column[idx])
        return example


# Wrap each tokenized split in the torch dataset defined above.
train_dataset, validation_dataset = (
    SciQDataset(encodings) for encodings in (train_data, validation_data)
)

In [None]:
# Stop once eval_loss has failed to improve for 2 consecutive evaluations (one per epoch here).
early_stopping_callback = EarlyStoppingCallback(early_stopping_patience=2)

# `evaluation_strategy` was renamed to `eval_strategy` and the old keyword was later removed
# from transformers; pass whichever name the installed version declares.
eval_strategy_kwarg = (
    "eval_strategy"
    if "eval_strategy" in Seq2SeqTrainingArguments.__dataclass_fields__
    else "evaluation_strategy"
)

training_args = Seq2SeqTrainingArguments(
    output_dir="./t5_lora_sciq",
    **{eval_strategy_kwarg: "epoch"},
    save_strategy="epoch",  # must match the eval strategy for load_best_model_at_end
    learning_rate=3e-4,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    num_train_epochs=6,
    weight_decay=0.01,
    save_total_limit=2,
    predict_with_generate=True,
    logging_dir="./logs",
    logging_steps=100,
    # NOTE(review): T5 checkpoints are known to be prone to fp16 overflow (NaN loss);
    # consider bf16=True on GPUs that support it — confirm on the target hardware.
    fp16=True,
    load_best_model_at_end=True,  # required for early stopping to restore the best checkpoint
    metric_for_best_model="eval_loss",  # metric used to select the best checkpoint
    greater_is_better=False,  # lower eval_loss is better
    lr_scheduler_type="linear",
    warmup_steps=500,
    label_names=["labels"],  # set explicitly for the PEFT-wrapped model
)

# Newer transformers take the tokenizer as `processing_class` (`tokenizer` is deprecated).
tokenizer_kwarg = (
    "processing_class"
    if "processing_class" in inspect.signature(Seq2SeqTrainer.__init__).parameters
    else "tokenizer"
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=validation_dataset,
    callbacks=[early_stopping_callback],
    **{tokenizer_kwarg: tokenizer},
)

In [None]:
# Run fine-tuning; with load_best_model_at_end=True (see training args) the best
# checkpoint by eval_loss is restored when training stops.
train_result = trainer.train()
train_result

In [None]:
# Persist the fine-tuned model and its tokenizer side by side. Per the PEFT docs,
# saving a PEFT-wrapped model writes the adapter weights/config rather than merged weights.
model.save_pretrained(save_directory="./t5_finetuned_sciq")
tokenizer.save_pretrained(save_directory="./t5_finetuned_sciq")

In [None]:
# Bundle the directory written by the save cell into a zip and show a download link.
# shutil.make_archive rewrites the archive from scratch on every run. `zip -r` instead
# updates an existing archive in place and keeps stale entries. make_archive also does
# not need a `zip` binary to be installed.
shutil.make_archive("t5_finetuned_sciq", "zip", root_dir=".", base_dir="t5_finetuned_sciq")
FileLink("t5_finetuned_sciq.zip")