This notebook fine-tunes a FLAN-T5 model on the SST-2 dataset using the Prompt Tuning PEFT method.

In [1]:
import torch
from accelerate import Accelerator
from datasets import load_dataset, Dataset
from torch.utils.data import DataLoader
from transformers import T5Tokenizer, T5ForConditionalGeneration
from peft import get_peft_model, PromptTuningConfig, TaskType, PromptTuningInit
from transformers import AdamW, get_linear_schedule_with_warmup, default_data_collator

In [2]:
# Hugging Face Hub checkpoint; the tokenizer is loaded from the same repository as the model.
model_name_or_path = "google/flan-t5-xl"
tokenizer_name_or_path = model_name_or_path

# Target device string and the per-step batch size handed to the DataLoaders.
device = "cuda"
batch_size = 16


In [None]:
## Define the Tokenizer and Model
tokenizer = T5Tokenizer.from_pretrained(tokenizer_name_or_path)
model = T5ForConditionalGeneration.from_pretrained(model_name_or_path)

## Define the Prompt Tuning Configuration
# With PromptTuningInit.TEXT the virtual-token embeddings are initialised from the
# tokenised `prompt_tuning_init_text`, so that argument has to be a natural-language
# task prompt. Passing an enum member such as PromptTuningInit.RANDOM there would
# simply be tokenised as literal text.
peft_config = PromptTuningConfig(
    task_type=TaskType.SEQ_2_SEQ_LM,
    prompt_tuning_init=PromptTuningInit.TEXT,
    num_virtual_tokens=200,
    prompt_tuning_init_text="Classify if the sentiment of this sentence is positive or negative:",
    inference_mode=False,
    tokenizer_name_or_path=tokenizer_name_or_path,
)

## Get the Prompt Tuning Model
# Wrap the base model with the prompt-tuning adapter, then move it to the
# configured device (the `device` constant defined in the settings cell).
model = get_peft_model(model, peft_config)
model = model.to(device)

In [None]:
# Report how many of the model's parameters are trainable versus the total count.
model.print_trainable_parameters()

In [None]:
## Load the SST-2 dataset (the "sst2" configuration of the GLUE benchmark)

dataset = load_dataset("glue", "sst2")
# Bare expression so the notebook displays the loaded dataset object.
dataset

In [None]:
# Column holding the input sentence, column holding the textual target, and the
# maximum tokenised input length used during preprocessing.
text_column = "sentence"
label_column = "text_label"
max_length = 256

label_mapping = {0: "negative", 1: "positive"}

# Class names in label-id order, as declared by the dataset's ClassLabel feature.
classes = dataset["train"].features["label"].names


def _label_ids_to_text(batch):
    """Return a "text_label" column holding the class name of each integer label.

    The GLUE test split ships without gold labels (label == -1). Indexing
    `classes` with -1 would wrap around to the last class name and silently
    give every test example that label, so ids outside the valid range are
    mapped to an empty string instead.
    """
    return {
        "text_label": [
            classes[label] if 0 <= label < len(classes) else ""
            for label in batch["label"]
        ]
    }


dataset = dataset.map(_label_ids_to_text, batched=True, num_proc=1)

# Display one mapped training example as a quick sanity check.
dataset["train"][0]
# print(dataset["train"].features["label"].names)

In [None]:
def preprocess_function(examples):
    """Tokenise a batch of sentences and their textual labels.

    Reads the `text_column` and `label_column` entries of `examples`, pads or
    truncates the inputs to `max_length` tokens and the targets to 2 tokens,
    and returns the input encoding with an added "labels" entry. Padding
    positions in the labels are replaced by -100.
    """
    encoded = tokenizer(
        examples[text_column],
        max_length=max_length,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )
    target_ids = tokenizer(
        examples[label_column],
        max_length=2,
        padding="max_length",
        truncation=True,
        return_tensors="pt",
    )["input_ids"]

    # Overwrite pad positions in the targets with -100.
    is_padding = target_ids == tokenizer.pad_token_id
    target_ids[is_padding] = -100

    encoded["labels"] = target_ids
    return encoded


# Tokenise every split. The columns present in the train split are removed from
# the output, and the cache is bypassed so the tokenizer always runs afresh.
processed_datasets = dataset.map(
    function=preprocess_function,
    batched=True,
    remove_columns=dataset["train"].column_names,
    load_from_cache_file=False,
    num_proc=1,
    desc="Running tokenizer on dataset",
)

train_dataset, eval_dataset = (processed_datasets[split] for split in ("train", "validation"))

# Only the training loader shuffles; both use the same collator and batch size.
train_dataloader = DataLoader(
    dataset=train_dataset,
    batch_size=batch_size,
    shuffle=True,
    collate_fn=default_data_collator,
    pin_memory=True,
)
eval_dataloader = DataLoader(
    dataset=eval_dataset,
    batch_size=batch_size,
    collate_fn=default_data_collator,
    pin_memory=True,
)

In [8]:
# Optimisation hyper-parameters.
num_epochs = 2
lr = 1e-4

# AdamW over the model's parameters, with a linear learning-rate schedule that
# has no warm-up and spans one scheduler step per training batch over all epochs.
optimizer = torch.optim.AdamW(params=model.parameters(), lr=lr)
lr_scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=num_epochs * len(train_dataloader),
)

In [None]:
from accelerate import Accelerator

accelerator = Accelerator()

# The LR scheduler is prepared together with the optimizer so that Accelerate can
# keep the two in sync (for example, not advancing the schedule on a step where
# the optimizer update was skipped).
model, optimizer, train_dataloader, eval_dataloader, lr_scheduler = accelerator.prepare(
    model, optimizer, train_dataloader, eval_dataloader, lr_scheduler
)

# Bare expression so the notebook displays the device the prepared model is on.
model.device

In [None]:
from rich.progress import Progress
from rich.console import Console

console = Console()

# Train for `num_epochs` epochs. Each epoch runs one optimisation pass over the
# training loader, one gradient-free pass over the validation loader, reports the
# mean loss of both, and saves the adapter.
with Progress() as progress:
    task = progress.add_task("[red]Training...", total=num_epochs)

    for epoch in range(num_epochs):
        # ---- training pass ----
        epoch_task = progress.add_task(f"Epoch {epoch}", total=len(train_dataloader))
        model.train()
        losses = []
        for batch in train_dataloader:
            outputs = model(**batch)
            loss = outputs.loss
            losses.append(loss.item())
            accelerator.backward(loss)
            optimizer.step()
            lr_scheduler.step()
            optimizer.zero_grad()
            progress.update(epoch_task, advance=1)

        # ---- validation pass ----
        test_task = progress.add_task(f"Epoch {epoch}", total=len(eval_dataloader))
        model.eval()
        eval_losses = []
        # No gradients are needed for evaluation; disabling autograd avoids
        # building a graph for every validation batch.
        with torch.no_grad():
            for batch in eval_dataloader:
                outputs = model(**batch)
                loss = outputs.loss
                eval_losses.append(loss.item())
                progress.update(test_task, advance=1)

        progress.update(task, advance=1)
        progress.print(f"epoch: {epoch} loss: {sum(losses) / len(losses)}")
        if eval_losses:
            progress.print(f"epoch: {epoch} eval_loss: {sum(eval_losses) / len(eval_losses)}")

        # Save through the unwrapped model so this keeps working if Accelerate
        # wrapped the model in prepare(); every epoch overwrites the same directory.
        accelerator.unwrap_model(model).save_pretrained("tf-xl-prompt-tuning-sst2")

In [None]:
# ## Evaluate the model

# model.eval()
# eval_loss = 0
# predictions = []
# true = []

# for step, batch in enumerate(eval_dataloader):
#     input_ids, attention_mask, labels = batch["input_ids"], batch["attention_mask"], batch["labels"]
#     with torch.no_grad():
#         outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
#         loss = outputs.loss
#         eval_loss += loss.item()
#         logits = outputs.logits
#         predictions.extend(logits.argmax(-1).tolist())
#         true.extend(labels.tolist())

# eval_loss = eval_loss / len(eval_dataloader)
# print(f"Evaluation Loss: {eval_loss}")

# ## Calculate the Accuracy
# correct = 0
# total = 0

# for p, t in zip(predictions, true):
#     if p == t:
#         correct += 1
#     total += 1

# accuracy = correct / total
# print(f"Accuracy: {accuracy}")

# ## Confusion Matrix
# from sklearn.metrics import confusion_matrix
# import seaborn as sns
# import matplotlib.pyplot as plt

# label_mappping = {0: "Negative", 1: "Positive"}
# predictions = [label_mappping[p] for p in predictions]
# true = [label_mappping[t] for t in true]

# cm = confusion_matrix(true, predictions)
# sns.heatmap(cm, annot=True, xticklabels=label_mappping.values(), yticklabels=label_mappping.values())

In [None]:
# Sanity check: run a single validation batch through the model and show the
# per-position argmax token ids of the logits, plus the decoded text of the
# first example in the batch.
model.eval()
for batch in eval_dataloader:
    input_ids, attention_mask, labels = batch["input_ids"], batch["attention_mask"], batch["labels"]
    with torch.no_grad():
        outputs = model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
    predicted_ids = outputs.logits.argmax(-1)
    print(predicted_ids)
    print(tokenizer.decode(predicted_ids[0]))
    # Only the first batch is inspected.
    break