# PEFT Tutorial
*(The bulk of the material in this tutorial is taken from Sebastian Raschka's [Code Lora from Scratch](https://lightning.ai/lightning-ai/studios/code-lora-from-scratch).)*

In [None]:
import os
import time
from functools import partial

import lightning as L
import torch
import torch.nn.functional as F
from custom_lightning_module import CustomLightningModule
from datasets import load_dataset
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.loggers import CSVLogger
from peft import LoraConfig, TaskType, get_peft_model
from torch import nn
from torch.utils.data import DataLoader, Dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer

## Configuring Paths


In [None]:
# Input locations (dataset, previously saved models) and output locations of the demo.
DATASET_DIR = "../../data/imdb/"
SAVED_MODEL_DIR = "/projects/fta_bootcamp/trained_models/peft_demo/"
OUTPUT_DIR = "../../scratch/peft/"  # main directory of the demo output
CHECKPOINT_DIR = os.path.join(OUTPUT_DIR, "checkpoints")  # where to save checkpoints
MODEL_NAME = "distilbert-base-uncased"

## Our Custom LoRA Layer <a id="LoRA_Anchor"></a>

In [None]:
# Set the float32 matmul precision mode to "medium" for the whole session.
torch.set_float32_matmul_precision("medium")


class LoRALayer(torch.nn.Module):
    """Low-rank adapter computing ``alpha * (x @ W_a @ W_b)``.

    Args:
        in_dim: Size of the last dimension of the input.
        out_dim: Size of the last dimension of the output.
        rank: Inner dimension of the two low-rank factors.
        alpha: Scaling factor applied to the low-rank product.
    """

    def __init__(self, in_dim, out_dim, rank, alpha):
        super().__init__()
        std_dev = 1 / torch.sqrt(torch.tensor(rank).float())
        # W_a is random and W_b is zero, so the adapter contributes exactly
        # zero until W_b has been updated.
        self.W_a = torch.nn.Parameter(torch.randn(in_dim, rank) * std_dev)
        self.W_b = torch.nn.Parameter(torch.zeros(rank, out_dim))
        self.alpha = alpha

    def forward(self, x):
        """Return the scaled low-rank projection of ``x``."""
        return self.alpha * (x @ self.W_a @ self.W_b)


class LinearWithLoRA(torch.nn.Module):
    """Wrap a linear layer and add a ``LoRALayer`` output to its result.

    Args:
        linear: Layer to wrap; must expose ``in_features`` and ``out_features``.
        rank: Rank passed to the ``LoRALayer``.
        alpha: Scaling factor passed to the ``LoRALayer``.
    """

    def __init__(self, linear, rank, alpha):
        super().__init__()
        self.linear = linear
        self.lora = LoRALayer(
            linear.in_features, linear.out_features, rank, alpha,
        )

    def forward(self, x):
        """Return ``linear(x)`` plus the low-rank correction ``lora(x)``."""
        return self.linear(x) + self.lora(x)

In [None]:
torch.manual_seed(123)

# a simple linear layer with 10 inputs and 1 output
linear_layer = torch.nn.Linear(10, 1)
# Freeze it explicitly. Building the layer inside torch.no_grad() does not
# do this: its parameters are still created with requires_grad=True.
linear_layer.requires_grad_(False)

# a simple example input
x = torch.rand((1, 10))

linear_layer(x)

In [None]:
# Wrap the demo linear layer with a rank-8 adapter (alpha = 1) and run the same input.
lora_layer = LinearWithLoRA(linear_layer, rank=8, alpha=1)
lora_layer(x)

In [None]:
# Give W_b small non-zero values so the LoRA branch changes the output.
# The update is in place so W_b keeps its (rank, out_features) shape; adding
# `0.01 * x[0]` (shape (10,)) would broadcast W_b from (8, 1) to (8, 10) and
# change the shape of the layer's output.
with torch.no_grad():
    lora_layer.lora.W_b.add_(0.01)
lora_layer(x)

## Loading the Dataset into DataFrames

In [None]:
# One CSV file per split, all located in DATASET_DIR.
split_files = {
    split: os.path.join(DATASET_DIR, csv_name)
    for split, csv_name in (
        ("train", "train.csv"),
        ("validation", "val.csv"),
        ("test", "test.csv"),
    )
}
imdb_dataset = load_dataset("csv", data_files=split_files)

print(imdb_dataset)

## Loading Tokenizer

In [None]:
# Load the tokenizer for MODEL_NAME and report two of its attributes.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
print(f"Tokenizer input max length: {tokenizer.model_max_length}")
print(f"Tokenizer vocabulary size: {tokenizer.vocab_size}")

## Tokenizing Data

In [None]:
def tokenize_text(batch):
    """Tokenize the ``"text"`` entry of ``batch`` with the module-level ``tokenizer``.

    The tokenizer is called with ``truncation=True`` and ``padding=True``;
    its result is returned unchanged.
    """
    texts = batch["text"]
    return tokenizer(texts, truncation=True, padding=True)

# Apply tokenize_text to every split in batched mode.
# NOTE(review): batch_size=None presumably hands each split to tokenize_text as
# one batch, so padding=True pads to the longest text of that split - confirm
# against the `datasets` documentation.
imdb_tokenized = imdb_dataset.map(tokenize_text, batched=True, batch_size=None)
# The raw dataset is not used after this point.
del imdb_dataset
# Return only these three columns, formatted as torch tensors.
imdb_tokenized.set_format("torch", columns=["input_ids", "attention_mask", "label"])
# NOTE(review): presumably avoids tokenizers parallelism warnings once the
# DataLoader worker processes start - confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

## Setting Up DataLoaders

In [None]:
class IMDBDataset(Dataset):
    """Map-style dataset exposing a single partition of a dataset dictionary.

    Args:
        dataset_dict: Mapping from partition name to partition; a partition
            must support integer indexing and expose ``num_rows``.
        partition_key: Name of the partition to serve (``"train"`` by default).
    """

    def __init__(self, dataset_dict, partition_key="train"):
        selected = dataset_dict[partition_key]
        self.partition = selected

    def __len__(self):
        """Return the number of rows in the selected partition."""
        return self.partition.num_rows

    def __getitem__(self, index):
        """Return the item stored at ``index`` in the selected partition."""
        return self.partition[index]

In [None]:
train_dataset, val_dataset, test_dataset = (
    IMDBDataset(imdb_tokenized, partition_key=split)
    for split in ("train", "validation", "test")
)

# All loaders share the batch size and worker count; only the training loader shuffles.
make_loader = partial(DataLoader, batch_size=12, num_workers=4)

train_loader = make_loader(dataset=train_dataset, shuffle=True)
val_loader = make_loader(dataset=val_dataset)
test_loader = make_loader(dataset=test_dataset)

## Counting Number of Trainable Parameters Function

In [None]:
def count_parameters(model):
    """Return the number of elements in ``model``'s trainable parameters.

    Only parameters whose ``requires_grad`` flag is set are counted.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total

## Fine-Tuning the Last Two Layers

In [None]:
# Load the pretrained model with a two-label classification head.
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=2)

base_trainable = count_parameters(model)
print(f"Total number of trainable parameters for the base model: {base_trainable:,}")

Freeze all the layers:

In [None]:
# Turn off gradient tracking for every parameter of the model.
for weight in model.parameters():
    weight.requires_grad_(False)

Unfreeze the last two layers:

In [None]:
# Re-enable gradients for the two head modules only.
for head in (model.pre_classifier, model.classifier):
    for weight in head.parameters():
        weight.requires_grad = True

trainable = count_parameters(model)
print(f"Total number of trainable parameters: {trainable:,}")

In [None]:
lightning_model = CustomLightningModule(model)

# Keep a single checkpoint: the one with the highest "val_acc".
best_checkpoint = ModelCheckpoint(
    dirpath=CHECKPOINT_DIR,
    filename="last_two",
    monitor="val_acc",
    mode="max",
    save_top_k=1,
)
callbacks = [best_checkpoint]

logger = CSVLogger(save_dir="logs/", name="my-model")

trainer = L.Trainer(
    accelerator="gpu",
    devices=1,
    precision="16-mixed",
    max_epochs=3,
    callbacks=callbacks,
    logger=logger,
    log_every_n_steps=10,
)

In [None]:
# Comment out this cell if you don't want to go through the training process. You can just load a trained model in the next cell.

start = time.time()
trainer.fit(
    lightning_model,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)
elapsed = time.time() - start
print(f"Time elapsed {elapsed/60:.2f} min")

In [None]:
# Load from a saved model. The path is built from SAVED_MODEL_DIR so it stays
# in sync with the configuration cell.
lightning_model = CustomLightningModule.load_from_checkpoint(
    checkpoint_path=os.path.join(SAVED_MODEL_DIR, "last_two.ckpt"),
    model=model,
)

# train_acc = trainer.validate(lightning_model, dataloaders=train_loader, verbose=False)
# val_acc = trainer.validate(lightning_model, dataloaders=val_loader, verbose=False)
test_acc = trainer.test(lightning_model, dataloaders=test_loader, verbose=False)

# print(f"Train acc: {train_acc[0]['val_acc']*100:2.2f}%")
# print(f"Val acc:   {val_acc[0]['val_acc']*100:2.2f}%")
print(f"Test acc:  {test_acc[0]['accuracy']*100:2.2f}%")

## Enter LoRA!

In [None]:
# Reload the pretrained model and turn off gradients for all of its parameters.
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=2)

for weight in model.parameters():
    weight.requires_grad_(False)

Let's use our [LoRA layer](#LoRA_Anchor) implementation from before. Here's our current model *before* adding LoRA layers:

In [None]:
# Display the model's module tree (before any layer is wrapped).
model

Now let's wrap the query and value layers of transformer blocks with LoRA.

In [None]:
def apply_adaptation_layer(model, adaptation_layer, lora_r, lora_alpha, config):
    """Replace selected linear layers of ``model`` with adapted versions, in place.

    Args:
        model: Model exposing ``distilbert.transformer.layer`` (each block with
            ``attention.q_lin``/``k_lin``/``v_lin``/``out_lin`` and
            ``ffn.lin1``/``lin2``) as well as ``pre_classifier`` and
            ``classifier``.
        adaptation_layer: Callable invoked as
            ``adaptation_layer(linear, rank=lora_r, alpha=lora_alpha)`` that
            returns the replacement module.
        lora_r: Rank forwarded to ``adaptation_layer``.
        lora_alpha: Scaling factor forwarded to ``adaptation_layer``.
        config: Mapping of switches ``lora_query``, ``lora_key``,
            ``lora_value``, ``lora_projection``, ``lora_mlp`` and
            ``lora_head``; a missing or falsy entry leaves the corresponding
            layers untouched.

    Returns:
        None. ``model`` is modified in place.
    """
    assign_lora = partial(adaptation_layer, rank=lora_r, alpha=lora_alpha)

    for layer in model.distilbert.transformer.layer:
        attention = layer.attention
        if config.get("lora_query"):
            attention.q_lin = assign_lora(attention.q_lin)
        if config.get("lora_key"):
            attention.k_lin = assign_lora(attention.k_lin)
        if config.get("lora_value"):
            attention.v_lin = assign_lora(attention.v_lin)
        if config.get("lora_projection"):
            attention.out_lin = assign_lora(attention.out_lin)
        if config.get("lora_mlp"):
            # Both linear layers of the feed-forward network are adapted.
            layer.ffn.lin1 = assign_lora(layer.ffn.lin1)
            layer.ffn.lin2 = assign_lora(layer.ffn.lin2)
    if config.get("lora_head"):
        # The head is adapted once per model, outside the per-block loop.
        model.pre_classifier = assign_lora(model.pre_classifier)
        model.classifier = assign_lora(model.classifier)

In [None]:
# Adapt only the query and value projections of each transformer block.
config = dict(
    lora_query=True,
    lora_key=False,
    lora_value=True,
    lora_projection=False,
    lora_mlp=False,
    lora_head=False,
)
apply_adaptation_layer(model, LinearWithLoRA, lora_r=8, lora_alpha=16, config=config)

Let's look at the model after the LoRA layers are added:

In [None]:
# Display the module tree again, now that apply_adaptation_layer has been called.
model

In [None]:
# Check if linear layers are frozen: list each parameter with its requires_grad flag.
status_lines = [
    f"{name}: {param.requires_grad}" for name, param in model.named_parameters()
]
for line in status_lines:
    print(line)

In [None]:
trainable = count_parameters(model)
print(f"Total number of trainable parameters: {trainable:,}")

## Fine-Tune with LoRA

In [None]:
lightning_model = CustomLightningModule(model)

# Keep a single checkpoint: the one with the highest "val_acc".
best_checkpoint = ModelCheckpoint(
    dirpath=CHECKPOINT_DIR,
    filename="lora",
    monitor="val_acc",
    mode="max",
    save_top_k=1,
)
callbacks = [best_checkpoint]

logger = CSVLogger(save_dir="logs/", name="my-model")

trainer = L.Trainer(
    accelerator="gpu",
    devices=1,
    precision="16-mixed",
    max_epochs=3,
    callbacks=callbacks,
    logger=logger,
    log_every_n_steps=10,
)

In [None]:
# Comment out this cell if you don't want to go through the training process. You can just load a trained model in the next cell.

start = time.time()
trainer.fit(
    lightning_model,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)
elapsed = time.time() - start
print(f"Time elapsed {elapsed/60:.2f} min")

In [None]:
# Load from a saved model. The path is built from SAVED_MODEL_DIR so it stays
# in sync with the configuration cell.
lightning_model = CustomLightningModule.load_from_checkpoint(
    checkpoint_path=os.path.join(SAVED_MODEL_DIR, "lora.ckpt"),
    model=model,
)

# train_acc = trainer.validate(lightning_model, dataloaders=train_loader, verbose=False)
# val_acc = trainer.validate(lightning_model, dataloaders=val_loader, verbose=False)
test_acc = trainer.test(lightning_model, dataloaders=test_loader, verbose=False)

# print(f"Train acc: {train_acc[0]['val_acc']*100:2.2f}%")
# print(f"Val acc:   {val_acc[0]['val_acc']*100:2.2f}%")
print(f"Test acc:  {test_acc[0]['accuracy']*100:2.2f}%")

## Using HF's LoRA

We can replace our custom LoRA implementation with an implementation from the [peft library](https://github.com/huggingface/peft). Peft is an open-source, one-stop-shop library from HuggingFace for *parameter efficient fine-tuning* (PEFT) and is integrated with their [transformers library](https://github.com/huggingface/transformers) for easy model training and inference.

Here's a sample snippet for how to prepare a model for PEFT training with LoRA. We can easily fine-tune the DistilBERT model we had before with this implementation of the LoRA layer, instead of our custom layer.

In [None]:
peft_config = LoraConfig(
    # The model is loaded with AutoModelForSequenceClassification, so the
    # matching PEFT task type is sequence classification, not seq2seq LM.
    task_type=TaskType.SEQ_CLS,
    inference_mode=False,
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
    # Same query/value projections that the custom LoRA configuration adapts.
    target_modules=["q_lin", "v_lin"],
)

model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=2)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

## What's this [DoRA](https://arxiv.org/pdf/2402.09353) thing I keep hearing about?

In [None]:
# Reload the pretrained model and turn off gradients for all of its parameters.
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=2)

for weight in model.parameters():
    weight.requires_grad_(False)

In [None]:
# Code inspired by https://github.com/catid/dora/blob/main/dora.py
class LinearWithDoRA(torch.nn.Module):
    """Linear layer whose weight is a learned magnitude times a normalized direction.

    The direction is the wrapped layer's weight plus the scaled low-rank
    product ``alpha * (W_a @ W_b).T``, divided by its L2 norm over dim 0.
    The magnitude ``m`` is a parameter initialized to the same norm of the
    wrapped layer's weight.

    Args:
        linear: Layer to wrap; must expose ``in_features``, ``out_features``,
            ``weight`` and ``bias``.
        rank: Rank passed to the ``LoRALayer``.
        alpha: Scaling factor passed to the ``LoRALayer``.
    """

    def __init__(self, linear, rank, alpha):
        super().__init__()
        self.linear = linear
        self.lora = LoRALayer(linear.in_features, linear.out_features, rank, alpha)

        initial_magnitude = linear.weight.norm(p=2, dim=0, keepdim=True)
        self.m = nn.Parameter(initial_magnitude)

    def forward(self, x):
        """Apply the re-scaled, adapted weight (and the wrapped layer's bias) to ``x``."""
        low_rank_delta = self.lora.W_a @ self.lora.W_b
        adapted = self.linear.weight + self.lora.alpha * low_rank_delta.T
        direction = adapted / adapted.norm(p=2, dim=0, keepdim=True)
        return F.linear(x, self.m * direction, self.linear.bias)

In [None]:
# Adapt only the query and value projections of each transformer block.
config = dict(
    lora_query=True,
    lora_key=False,
    lora_value=True,
    lora_projection=False,
    lora_mlp=False,
    lora_head=False,
)
apply_adaptation_layer(model, LinearWithDoRA, lora_r=8, lora_alpha=16, config=config)

In [None]:
# Display the module tree after apply_adaptation_layer was called with LinearWithDoRA.
model

In [None]:
trainable = count_parameters(model)
print(f"Total number of trainable parameters: {trainable:,}")

## Fine-Tune with DoRA

In [None]:
lightning_model = CustomLightningModule(model)

callbacks = [
    ModelCheckpoint(
        # Save next to the other runs' checkpoints instead of an empty dirpath.
        dirpath=CHECKPOINT_DIR,
        filename="dora",
        save_top_k=1, # save top 1 model
        mode="max",
        monitor="val_acc",
    ),
]

logger = CSVLogger(save_dir="logs/", name="my-model")

trainer = L.Trainer(
    max_epochs=3,
    callbacks=callbacks,
    accelerator="gpu",
    precision="16-mixed",
    devices=1,
    logger=logger,
    log_every_n_steps=10,
)

In [None]:
# Comment out this cell if you don't want to go through the training process. You can just load a trained model in the next cell.

start = time.time()
trainer.fit(
    lightning_model,
    train_dataloaders=train_loader,
    val_dataloaders=val_loader,
)
elapsed = time.time() - start
print(f"Time elapsed {elapsed/60:.2f} min")

In [None]:
# Load from a saved model. The path is built from SAVED_MODEL_DIR so it stays
# in sync with the configuration cell.
lightning_model = CustomLightningModule.load_from_checkpoint(
    checkpoint_path=os.path.join(SAVED_MODEL_DIR, "dora.ckpt"),
    model=model,
)

# train_acc = trainer.validate(lightning_model, dataloaders=train_loader, verbose=False)
# val_acc = trainer.validate(lightning_model, dataloaders=val_loader, verbose=False)
test_acc = trainer.test(lightning_model, dataloaders=test_loader, verbose=False)

# print(f"Train acc: {train_acc[0]['val_acc']*100:2.2f}%")
# print(f"Val acc:   {val_acc[0]['val_acc']*100:2.2f}%")
print(f"Test acc:  {test_acc[0]['accuracy']*100:2.2f}%")