### Install the required libraries if not already installed

In [None]:
!pip install transformers

### Import all necessary Python packages

In [None]:
import torch
from torch import nn
from torch.utils.data import DataLoader, random_split
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from datasets import load_dataset
from functools import partial
import torch.optim as optim
import time

# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)

### Helper functions

In [None]:
#Show number of trainable parameters, memory footprint and model architecture
def showModel(model):
    """Print parameter statistics, in-memory size and the architecture of ``model``.

    Args:
        model: The ``torch.nn.Module`` to inspect.

    Prints, in order: the number of trainable parameters (those with
    ``requires_grad`` set), the total number of parameters, the trainable
    fraction, the size of parameters plus buffers in MB, and ``model`` itself.
    """
    trainable_params = 0
    total_params = 0
    param_size = 0
    # Single pass over the parameters gathers every per-parameter statistic.
    for p in model.parameters():
        count = p.numel()
        total_params += count
        param_size += count * p.element_size()
        if p.requires_grad:
            trainable_params += count

    buffer_size = sum(b.numel() * b.element_size() for b in model.buffers())
    total_size = (param_size + buffer_size) / 1024**2  # Convert to MB

    # A module without parameters would otherwise trigger a division by zero.
    fraction = trainable_params / total_params if total_params else 0.0

    print(f"Trainable parameters: {trainable_params}")
    print(f"Total parameters: {total_params}")
    print(f"Fraction trainable: {fraction:.4f}")
    print(f"Model size: {total_size:.2f} MB")
    print(model)

#LoRALayer builds the trainable low-rank pair (A then B) used as the LoRA update
class LoRALayer(torch.nn.Module):
    """Low-rank adapter computing ``B(A(x)) * (alpha / rank)``.

    ``A`` maps ``in_dim -> rank`` and ``B`` maps ``rank -> out_dim``; neither
    has a bias. ``A`` is initialised from a normal distribution (std 0.01) and
    ``B`` with zeros, so a freshly built layer outputs zeros.

    Args:
        in_dim: Size of the last dimension of the input.
        out_dim: Size of the last dimension of the output.
        rank: Width of the bottleneck between ``A`` and ``B``.
        alpha: Scaling numerator; the stored factor is ``alpha / rank``.
    """

    def __init__(self, in_dim, out_dim, rank, alpha):
        super().__init__()
        self.A = nn.Linear(in_features=in_dim, out_features=rank, bias=False)
        self.B = nn.Linear(in_features=rank, out_features=out_dim, bias=False)
        nn.init.normal_(self.A.weight, mean=0, std=0.01)
        nn.init.zeros_(self.B.weight)
        self.alpha = alpha / rank

    def forward(self, x):
        """Return the scaled low-rank projection of ``x``."""
        down_projected = self.A(x)
        up_projected = self.B(down_projected)
        return up_projected * self.alpha

#Wrapper pairing an existing linear layer with the LoRA adapter that runs next to it
class LinearWithLoRA(torch.nn.Module):
    """Sum the output of a given linear layer and of a ``LoRALayer`` on the same input.

    Args:
        linear: The layer being wrapped; its ``in_features`` and
            ``out_features`` size the adapter.
        rank: Rank passed to the ``LoRALayer``.
        alpha: Scaling value passed to the ``LoRALayer``.
    """

    def __init__(self, linear, rank, alpha):
        super().__init__()
        self.linear = linear
        in_dim, out_dim = linear.in_features, linear.out_features
        self.lora = LoRALayer(in_dim, out_dim, rank, alpha)

    def forward(self, x):
        """Return ``linear(x) + lora(x)``."""
        base_output = self.linear(x)
        lora_update = self.lora(x)
        return base_output + lora_update

#Evaluate a model on a given data loader, calculating the accuracy as the ratio of good predictions
def evaluate(model, loader):
    """Return the accuracy (in percent) of ``model`` over ``loader``.

    The model is switched to evaluation mode and left in that mode. Each batch
    is moved to the device of the model's first parameter, so the function
    works wherever the model currently lives; the module-level ``device`` is
    only used when the model has no parameters.

    Args:
        model: Callable as ``model(**batch)``, returning an object with a
            ``logits`` attribute.
        loader: Iterable of dicts of tensors; each dict must hold a
            ``"labels"`` entry.

    Returns:
        ``100 * correct / total`` as a float, or ``0.0`` when ``loader``
        yields no samples.
    """
    model.eval()
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else device

    correct, total = 0, 0
    with torch.no_grad():
        for batch in loader:
            batch = {k: v.to(target_device) for k, v in batch.items()}
            outputs = model(**batch)
            preds = torch.argmax(outputs.logits, dim=-1)
            correct += (preds == batch["labels"]).sum().item()
            total += batch["labels"].size(0)

    # An empty loader would otherwise raise ZeroDivisionError.
    if total == 0:
        return 0.0
    return 100 * correct / total

### Setup
1. Load a tiny BERT model (pre-trained)
2. Freeze the original weights
3. Load the SST-2 dataset (we only use 500 entries, split between train and test, but you can change these parameters)

In [None]:
#Very small model for testing; the output covers two labels, negative or positive sentiment
#The pre-trained checkpoint is loaded together with its classification head
model_name = "arnabdhar/tinybert-imdb"   # small, already fine-tuned
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name)

showModel(model)

#Freeze every original weight of the model
model.requires_grad_(False)


def _tokenize(batch):
    """Tokenize the "sentence" column, truncated/padded to 32 tokens."""
    return tokenizer(batch["sentence"], truncation=True, padding="max_length", max_length=32)


#Load the sst2 dataset and keep the first 500 samples for fast testing (you can adjust this parameter)
#The dataset is then tokenized and formatted as torch tensors for the optimization
dataset = (
    load_dataset("glue", "sst2", split="train[:500]")  # small subset for demo
    .map(_tokenize, batched=True)
    .rename_column("label", "labels")
)
dataset.set_format("torch", columns=["input_ids", "attention_mask", "labels"])

# Split into train/test (80% / 20%)
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16)

### LoRA PEFT
1. Set the LoRA parameters
2. Replace linear layers with LoRA blocks
3. Set the trainable weights/parameters
4. Train the model
5. Evaluate it
6. Compare it with the base model

In [None]:
# LoRA parameters; the LoRA layer actually scales by alpha/rank, as the effective scaling factor usually depends on the rank
#!!!TODO!! play with these parameters to see how good / long the fine-tuning is
lora_rank = 4
lora_alpha = 8

#Replace the linear layers of the original model
#Here the query and value projections of every self-attention block are wrapped.
#!!! TODO !!! Check the original model to replace other layers (you can also check how much this impacts performance, as nothing forces you to apply LoRA to all of them)
lora_targets = ("query", "value")
for layer in model.bert.encoder.layer:
    self_attention = layer.attention.self
    for target in lora_targets:
        projection = getattr(self_attention, target)
        # Skip layers that are already wrapped so the cell can be re-run safely.
        if not isinstance(projection, LinearWithLoRA):
            setattr(self_attention, target, LinearWithLoRA(projection, lora_rank, lora_alpha))

#Only the LoRA adapters and the classifier head are trainable; everything else stays frozen
for name, param in model.named_parameters():
    param.requires_grad = 'lora' in name or 'classifier' in name

showModel(model)





#Configure the optimizer and ensure all trainable parameters will be trained
# !!! you can change the number of epochs for testing
criterion = nn.CrossEntropyLoss() #usual metric for classification; not used below, the loop takes the loss returned by the model
num_epochs = 20

# Move the model to its device BEFORE building the optimizer, so the optimizer
# is constructed from the parameters as they exist on that device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
optimizer = optim.AdamW([p for p in model.parameters() if p.requires_grad], lr=1e-4)

start_time = time.time()  # record overall start time
model.train()
for epoch in range(num_epochs):
    running_loss = 0.0
    for batch in train_loader:
        batch = {k: v.to(device) for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs.loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
    # Average of the per-batch losses over the epoch
    print(f"Epoch {epoch+1}/{num_epochs} - Avg Loss: {running_loss/len(train_loader):.4f}")
total_time = time.time() - start_time
print(f"Total training time: {total_time/60:.2f} minutes ({total_time:.1f} seconds)")


# Accuracy of the fine-tuned model on the validation split
finetuned_accuracy = evaluate(model, val_loader)
print(f"Fine-tuned model accuracy: {finetuned_accuracy:.2f}%")

# Accuracy of the base model (fresh copy of the checkpoint) on the same split
base_model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)
base_model.to(device)
for param in base_model.parameters():
    param.requires_grad = False
base_accuracy = evaluate(base_model, val_loader)
print(f"Base model accuracy: {base_accuracy:.2f}%")
