In [None]:
!pip install torch
!pip install safetensors



In [None]:
from lora_transformer import LoraConfig, LoraModel
from typing import Literal
import torch
from transformers import AutoTokenizer,AutoModelForSequenceClassification, get_linear_schedule_with_warmup, DataCollatorWithPadding
from datasets import load_dataset
from torch.utils.data import DataLoader
from torch.optim import AdamW
from tqdm import tqdm
from sklearn.metrics import accuracy_score
from torch.cuda.amp import autocast, GradScaler
import os

In [None]:
# Later cells hard-code CUDA (e.g. torch.amp.autocast('cuda') in the training loop),
# so this notebook targets a GPU device.
device = torch.device(type="cuda")

In [None]:
# Let PyTorch trade some float32 matmul precision for speed ("high" setting).
torch.set_float32_matmul_precision(precision="high")

In [None]:
# Confirm which device the model and batches will be moved to.
print(f"{device}")

cuda


In [None]:
MODEL_NAME = "roberta-base"
TASK = "sst2"                        # GLUE sub-task name passed to load_dataset
BATCH_SIZE = 256
LR = 5e-4
EPOCHS = 60
MAX_LEN = 512                        # tokenizer truncation length
RANK = 8                             # LoraConfig rank
ALPHA = 16                           # LoraConfig alpha
BIAS = "none"
DROPOUT = 0.0                        # LoraConfig dropout
TARGET_MODULES = ["query", "value"]  # layers that receive LoRA adapters
EXCLUDE_MODULES = ["classifier"]
SEED = 42

# Seed torch's global RNG so the newly initialised classifier head, dropout masks and
# DataLoader shuffling repeat across runs. GPU kernels may still be non-deterministic,
# so metrics can differ slightly between runs. Trailing ';' suppresses the Generator repr.
torch.manual_seed(SEED);

In [None]:
# Tokenizer matching the backbone checkpoint, and every split of the chosen GLUE task.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=MODEL_NAME)
dataset = load_dataset(path="glue", name=TASK)

In [None]:
def preprocess_function(examples):
    """Tokenize a batch of rows for ``Dataset.map(batched=True)``.

    Only the ``"sentence"`` column of ``examples`` is read. Sequences longer than
    ``MAX_LEN`` tokens are truncated; no padding is applied here (the
    ``DataCollatorWithPadding`` pads each batch later).
    """
    sentences = examples["sentence"]
    return tokenizer(sentences, max_length=MAX_LEN, truncation=True)

In [None]:
# Tokenize every split, drop the raw text/index columns, and expose the targets under
# "labels" so that model(**batch) returns a loss in the training loop.
tokenized_datasets = (
    dataset.map(preprocess_function, batched=True)
    .remove_columns(["sentence", "idx"])
    .rename_column("label", "labels")
)
tokenized_datasets.set_format("torch")

Map:   0%|          | 0/872 [00:00<?, ? examples/s]

In [None]:
# Pads each batch to its own longest sequence (preprocess_function does no padding).
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# os.cpu_count() returns None when the CPU count cannot be determined; DataLoader rejects
# num_workers=None, so fall back to loading in the main process (0 workers).
NUM_WORKERS = os.cpu_count() or 0

train_dataloader = DataLoader(
    tokenized_datasets["train"],
    shuffle=True,
    batch_size=BATCH_SIZE,
    collate_fn=data_collator,
    num_workers=NUM_WORKERS,
    # Keep worker processes alive across the many epochs instead of re-forking them
    # at the start of every epoch (only valid when workers are used).
    persistent_workers=NUM_WORKERS > 0,
)
eval_dataloader = DataLoader(
    tokenized_datasets["validation"],
    batch_size=BATCH_SIZE,
    collate_fn=data_collator,
)

In [None]:
# Backbone plus a freshly initialised 2-way classification head (see the warning below).
pretrained_model = AutoModelForSequenceClassification.from_pretrained(
    pretrained_model_name_or_path=MODEL_NAME,
    num_labels=2,
)

Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [None]:
# LoRA hyper-parameters, all taken from the configuration cell.
lora_config = LoraConfig(
    rank=RANK,
    alpha=ALPHA,
    dropout=DROPOUT,
    bias=BIAS,
    target_modules=TARGET_MODULES,
    exclude_modules=EXCLUDE_MODULES,
)


In [None]:
# Wrap the backbone with LoRA adapters, move it to the GPU, and echo its structure
# (query/value become LoRALinearLayer; key stays a plain Linear).
# NOTE(review): LoraModel appears to modify pretrained_model's layers; re-running this cell
# without reloading pretrained_model may wrap already-wrapped layers — confirm.
model = LoraModel(pretrained_model, lora_config)
model = model.to(device)
model

LoraModel(
  (model): RobertaForSequenceClassification(
    (roberta): RobertaModel(
      (embeddings): RobertaEmbeddings(
        (word_embeddings): Embedding(50265, 768, padding_idx=1)
        (position_embeddings): Embedding(514, 768, padding_idx=1)
        (token_type_embeddings): Embedding(1, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): RobertaEncoder(
        (layer): ModuleList(
          (0-11): 12 x RobertaLayer(
            (attention): RobertaAttention(
              (self): RobertaSdpaSelfAttention(
                (query): LoRALinearLayer(
                  in_features=768, out_features=768, bias=True
                  (dropout): Identity()
                )
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): LoRALinearLayer(
                  in_features=768, out_features=768, bias=True
                  (dropout): 

In [None]:
def count_parameters(model):
    """Print the trainable vs. total parameter counts of a LoRA-wrapped model.

    The trainable count comes from ``model.get_n_trainable()``; the total is the
    number of elements across every parameter of the wrapped ``model.model``.
    Nothing is returned.
    """
    n_trainable = model.get_n_trainable()

    # Total over the wrapped model's parameters.
    n_total = 0
    for param in model.model.parameters():
        n_total += param.numel()

    pct = 100 * n_trainable / n_total
    print(f"Trainable Params: {n_trainable:,} || Total Params: {n_total:,} || %: {pct:.2f}%")

In [None]:
# Sanity check: only the LoRA adapters and the classifier head should be trainable.
count_parameters(model=model)

Trainable Params: 887,042 || Total Params: 124,942,082 || %: 0.71%


In [None]:
# Try to compile the model; if torch.compile itself raises, keep the eager model but report
# the actual cause instead of a bare "error" message.
# Note: compilation is lazy, so most graph/compile issues surface on the first forward pass
# inside the training loop rather than here.
try:
    model = torch.compile(model)
except Exception as e:
    print(f"torch.compile unavailable, continuing in eager mode: {type(e).__name__}: {e}")

In [None]:
# Optimise only parameters left trainable by the LoRA wrapper.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = AdamW(trainable_params, lr=LR, fused=True)

# Linear warmup over the first 6% of all optimisation steps, then linear decay to 0.
WARMUP_RATIO = 0.06
num_training_steps = EPOCHS * len(train_dataloader)  # one step per training batch
num_warmup_steps = int(WARMUP_RATIO * num_training_steps)
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=num_training_steps,
)

In [None]:
scaler = GradScaler()

  scaler = GradScaler()


In [None]:
def _train_one_epoch(epoch):
    """Run one optimisation pass over ``train_dataloader``; return the mean batch loss.

    Uses CUDA autocast with ``scaler`` for loss scaling, and steps ``scheduler`` once per
    batch. ``epoch`` is zero-based and only used for the progress-bar label.
    """
    model.train()
    total_loss = 0.0
    progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch + 1}/{EPOCHS}")

    for batch in progress_bar:
        batch = {k: v.to(device) for k, v in batch.items()}

        optimizer.zero_grad(set_to_none=True)

        with torch.amp.autocast('cuda'):
            outputs = model(**batch)
            loss = outputs.loss

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()

        # .item() copies the loss to the host and waits on the GPU; do it once per
        # step and reuse the value for both the running sum and the progress bar.
        step_loss = loss.item()
        total_loss += step_loss
        progress_bar.set_postfix({"loss": step_loss})

    return total_loss / len(train_dataloader)


def _evaluate():
    """Return the accuracy of ``model`` on ``eval_dataloader`` (argmax of logits vs. labels)."""
    model.eval()
    preds = []
    labels = []

    # Disable autograd once for the whole pass rather than per batch.
    with torch.no_grad():
        for batch in eval_dataloader:
            batch = {k: v.to(device) for k, v in batch.items()}
            logits = model(**batch).logits
            preds.extend(torch.argmax(logits, dim=-1).cpu().numpy())
            labels.extend(batch["labels"].cpu().numpy())

    return accuracy_score(labels, preds)


best_acc, best_epoch = 0.0, 0
for epoch in range(EPOCHS):
    avg_train_loss = _train_one_epoch(epoch)
    acc = _evaluate()
    if best_epoch == 0 or acc > best_acc:
        best_acc, best_epoch = acc, epoch + 1
    print(f"\nEpoch {epoch + 1}. Train Loss: {avg_train_loss:.4f} | Val Accuracy: {acc:.4f}\n")

# The saved weights come from the FINAL epoch, which is not necessarily the best one.
if best_epoch:
    print(f"Best Val Accuracy: {best_acc:.4f} (epoch {best_epoch})")

model.save_model("roberta_lora_sst2.safetensors", merge_weights=False)
print("Model Saved.")

Epoch 1/60:   0%|          | 0/132 [00:00<?, ?it/s]W1204 20:30:50.673000 896 torch/fx/experimental/symbolic_shapes.py:6833] [0/5] _maybe_guard_rel() was called on non-relation expression Eq(s52, s92) | Eq(s92, 1)
W1204 20:30:50.679000 896 torch/fx/experimental/symbolic_shapes.py:6833] [0/5] _maybe_guard_rel() was called on non-relation expression Eq(s16, 1) | Eq(s27, s16)
Epoch 1/60: 100%|██████████| 132/132 [00:20<00:00,  6.53it/s, loss=0.237]



Epoch 1. Train Loss: 0.5538 | Val Accuracy: 0.9106



Epoch 2/60: 100%|██████████| 132/132 [00:12<00:00, 10.43it/s, loss=0.213]



Epoch 2. Train Loss: 0.2510 | Val Accuracy: 0.9266



Epoch 3/60: 100%|██████████| 132/132 [00:12<00:00, 10.44it/s, loss=0.185]



Epoch 3. Train Loss: 0.2218 | Val Accuracy: 0.9369



Epoch 4/60: 100%|██████████| 132/132 [00:12<00:00, 10.45it/s, loss=0.227]



Epoch 4. Train Loss: 0.2017 | Val Accuracy: 0.9323



Epoch 5/60: 100%|██████████| 132/132 [00:12<00:00, 10.47it/s, loss=0.154]



Epoch 5. Train Loss: 0.1850 | Val Accuracy: 0.9335



Epoch 6/60: 100%|██████████| 132/132 [00:12<00:00, 10.41it/s, loss=0.191]



Epoch 6. Train Loss: 0.1741 | Val Accuracy: 0.9369



Epoch 7/60: 100%|██████████| 132/132 [00:12<00:00, 10.52it/s, loss=0.119]



Epoch 7. Train Loss: 0.1610 | Val Accuracy: 0.9381



Epoch 8/60: 100%|██████████| 132/132 [00:12<00:00, 10.37it/s, loss=0.137]



Epoch 8. Train Loss: 0.1520 | Val Accuracy: 0.9335



Epoch 9/60: 100%|██████████| 132/132 [00:12<00:00, 10.50it/s, loss=0.153]



Epoch 9. Train Loss: 0.1430 | Val Accuracy: 0.9369



Epoch 10/60: 100%|██████████| 132/132 [00:12<00:00, 10.51it/s, loss=0.133]



Epoch 10. Train Loss: 0.1365 | Val Accuracy: 0.9323



Epoch 11/60: 100%|██████████| 132/132 [00:12<00:00, 10.47it/s, loss=0.161]



Epoch 11. Train Loss: 0.1290 | Val Accuracy: 0.9392



Epoch 12/60: 100%|██████████| 132/132 [00:12<00:00, 10.51it/s, loss=0.112]



Epoch 12. Train Loss: 0.1245 | Val Accuracy: 0.9404



Epoch 13/60: 100%|██████████| 132/132 [00:12<00:00, 10.46it/s, loss=0.0906]



Epoch 13. Train Loss: 0.1191 | Val Accuracy: 0.9404



Epoch 14/60: 100%|██████████| 132/132 [00:12<00:00, 10.46it/s, loss=0.12]



Epoch 14. Train Loss: 0.1138 | Val Accuracy: 0.9404



Epoch 15/60: 100%|██████████| 132/132 [00:12<00:00, 10.48it/s, loss=0.13]



Epoch 15. Train Loss: 0.1092 | Val Accuracy: 0.9381



Epoch 16/60: 100%|██████████| 132/132 [00:12<00:00, 10.50it/s, loss=0.123]



Epoch 16. Train Loss: 0.1065 | Val Accuracy: 0.9323



Epoch 17/60: 100%|██████████| 132/132 [00:12<00:00, 10.43it/s, loss=0.107]



Epoch 17. Train Loss: 0.1012 | Val Accuracy: 0.9369



Epoch 18/60: 100%|██████████| 132/132 [00:12<00:00, 10.49it/s, loss=0.113]



Epoch 18. Train Loss: 0.0996 | Val Accuracy: 0.9404



Epoch 19/60: 100%|██████████| 132/132 [00:12<00:00, 10.46it/s, loss=0.0544]



Epoch 19. Train Loss: 0.0965 | Val Accuracy: 0.9427



Epoch 20/60: 100%|██████████| 132/132 [00:12<00:00, 10.42it/s, loss=0.0901]



Epoch 20. Train Loss: 0.0970 | Val Accuracy: 0.9438



Epoch 21/60: 100%|██████████| 132/132 [00:12<00:00, 10.46it/s, loss=0.0978]



Epoch 21. Train Loss: 0.0920 | Val Accuracy: 0.9415



Epoch 22/60: 100%|██████████| 132/132 [00:12<00:00, 10.48it/s, loss=0.0878]



Epoch 22. Train Loss: 0.0885 | Val Accuracy: 0.9369



Epoch 23/60: 100%|██████████| 132/132 [00:13<00:00,  9.59it/s, loss=0.0711]



Epoch 23. Train Loss: 0.0852 | Val Accuracy: 0.9381



Epoch 24/60: 100%|██████████| 132/132 [00:14<00:00,  8.93it/s, loss=0.0807]



Epoch 24. Train Loss: 0.0824 | Val Accuracy: 0.9415



Epoch 25/60: 100%|██████████| 132/132 [00:12<00:00, 10.41it/s, loss=0.0534]



Epoch 25. Train Loss: 0.0813 | Val Accuracy: 0.9358



Epoch 26/60: 100%|██████████| 132/132 [00:12<00:00, 10.37it/s, loss=0.0685]



Epoch 26. Train Loss: 0.0777 | Val Accuracy: 0.9404



Epoch 27/60: 100%|██████████| 132/132 [00:12<00:00, 10.39it/s, loss=0.122]



Epoch 27. Train Loss: 0.0763 | Val Accuracy: 0.9392



Epoch 28/60: 100%|██████████| 132/132 [00:12<00:00, 10.37it/s, loss=0.0924]



Epoch 28. Train Loss: 0.0746 | Val Accuracy: 0.9427



Epoch 29/60: 100%|██████████| 132/132 [00:12<00:00, 10.37it/s, loss=0.0365]



Epoch 29. Train Loss: 0.0732 | Val Accuracy: 0.9415



Epoch 30/60: 100%|██████████| 132/132 [00:12<00:00, 10.39it/s, loss=0.084]



Epoch 30. Train Loss: 0.0713 | Val Accuracy: 0.9438



Epoch 31/60: 100%|██████████| 132/132 [00:12<00:00, 10.42it/s, loss=0.0791]



Epoch 31. Train Loss: 0.0700 | Val Accuracy: 0.9438



Epoch 32/60: 100%|██████████| 132/132 [00:12<00:00, 10.49it/s, loss=0.068]



Epoch 32. Train Loss: 0.0677 | Val Accuracy: 0.9381



Epoch 33/60: 100%|██████████| 132/132 [00:12<00:00, 10.40it/s, loss=0.11]



Epoch 33. Train Loss: 0.0659 | Val Accuracy: 0.9404



Epoch 34/60: 100%|██████████| 132/132 [00:12<00:00, 10.37it/s, loss=0.076]



Epoch 34. Train Loss: 0.0658 | Val Accuracy: 0.9381



Epoch 35/60: 100%|██████████| 132/132 [00:12<00:00, 10.41it/s, loss=0.0646]



Epoch 35. Train Loss: 0.0628 | Val Accuracy: 0.9427



Epoch 36/60: 100%|██████████| 132/132 [00:12<00:00, 10.40it/s, loss=0.0934]



Epoch 36. Train Loss: 0.0634 | Val Accuracy: 0.9392



Epoch 37/60: 100%|██████████| 132/132 [00:12<00:00, 10.31it/s, loss=0.0482]



Epoch 37. Train Loss: 0.0601 | Val Accuracy: 0.9415



Epoch 38/60: 100%|██████████| 132/132 [00:12<00:00, 10.42it/s, loss=0.0581]



Epoch 38. Train Loss: 0.0607 | Val Accuracy: 0.9404



Epoch 39/60: 100%|██████████| 132/132 [00:12<00:00, 10.38it/s, loss=0.0714]



Epoch 39. Train Loss: 0.0578 | Val Accuracy: 0.9415



Epoch 40/60: 100%|██████████| 132/132 [00:12<00:00, 10.39it/s, loss=0.0645]



Epoch 40. Train Loss: 0.0579 | Val Accuracy: 0.9392



Epoch 41/60: 100%|██████████| 132/132 [00:12<00:00, 10.42it/s, loss=0.0599]



Epoch 41. Train Loss: 0.0567 | Val Accuracy: 0.9369



Epoch 42/60: 100%|██████████| 132/132 [00:12<00:00, 10.45it/s, loss=0.0313]



Epoch 42. Train Loss: 0.0565 | Val Accuracy: 0.9415



Epoch 43/60: 100%|██████████| 132/132 [00:12<00:00, 10.46it/s, loss=0.0399]



Epoch 43. Train Loss: 0.0547 | Val Accuracy: 0.9404



Epoch 44/60: 100%|██████████| 132/132 [00:12<00:00, 10.41it/s, loss=0.0519]



Epoch 44. Train Loss: 0.0531 | Val Accuracy: 0.9427



Epoch 45/60: 100%|██████████| 132/132 [00:12<00:00, 10.39it/s, loss=0.037]



Epoch 45. Train Loss: 0.0531 | Val Accuracy: 0.9404



Epoch 46/60: 100%|██████████| 132/132 [00:12<00:00, 10.44it/s, loss=0.058]



Epoch 46. Train Loss: 0.0521 | Val Accuracy: 0.9381



Epoch 47/60: 100%|██████████| 132/132 [00:12<00:00, 10.30it/s, loss=0.0387]



Epoch 47. Train Loss: 0.0511 | Val Accuracy: 0.9392



Epoch 48/60: 100%|██████████| 132/132 [00:12<00:00, 10.33it/s, loss=0.0754]



Epoch 48. Train Loss: 0.0513 | Val Accuracy: 0.9392



Epoch 49/60: 100%|██████████| 132/132 [00:12<00:00, 10.34it/s, loss=0.0773]



Epoch 49. Train Loss: 0.0505 | Val Accuracy: 0.9404



Epoch 50/60: 100%|██████████| 132/132 [00:12<00:00, 10.36it/s, loss=0.048]



Epoch 50. Train Loss: 0.0503 | Val Accuracy: 0.9369



Epoch 51/60: 100%|██████████| 132/132 [00:12<00:00, 10.38it/s, loss=0.0713]



Epoch 51. Train Loss: 0.0499 | Val Accuracy: 0.9358



Epoch 52/60: 100%|██████████| 132/132 [00:12<00:00, 10.43it/s, loss=0.0361]



Epoch 52. Train Loss: 0.0479 | Val Accuracy: 0.9415



Epoch 53/60: 100%|██████████| 132/132 [00:12<00:00, 10.39it/s, loss=0.0427]



Epoch 53. Train Loss: 0.0455 | Val Accuracy: 0.9392



Epoch 54/60: 100%|██████████| 132/132 [00:12<00:00, 10.40it/s, loss=0.0666]



Epoch 54. Train Loss: 0.0469 | Val Accuracy: 0.9381



Epoch 55/60: 100%|██████████| 132/132 [00:12<00:00, 10.39it/s, loss=0.0827]



Epoch 55. Train Loss: 0.0461 | Val Accuracy: 0.9404



Epoch 56/60: 100%|██████████| 132/132 [00:12<00:00, 10.30it/s, loss=0.0504]



Epoch 56. Train Loss: 0.0453 | Val Accuracy: 0.9392



Epoch 57/60: 100%|██████████| 132/132 [00:12<00:00, 10.31it/s, loss=0.0457]



Epoch 57. Train Loss: 0.0464 | Val Accuracy: 0.9415



Epoch 58/60: 100%|██████████| 132/132 [00:12<00:00, 10.37it/s, loss=0.00915]



Epoch 58. Train Loss: 0.0433 | Val Accuracy: 0.9415



Epoch 59/60: 100%|██████████| 132/132 [00:12<00:00, 10.43it/s, loss=0.0203]



Epoch 59. Train Loss: 0.0438 | Val Accuracy: 0.9381



Epoch 60/60: 100%|██████████| 132/132 [00:12<00:00, 10.42it/s, loss=0.0392]



Epoch 60. Train Loss: 0.0445 | Val Accuracy: 0.9381

Model Saved.
