In [None]:
!pip install torch
!pip install safetensors



In [2]:
from lora_transformer import LoraConfig, LoraModel
from typing import Literal
import torch
from transformers import AutoTokenizer,AutoModelForSequenceClassification, get_linear_schedule_with_warmup, DataCollatorWithPadding
from datasets import load_dataset
from torch.utils.data import DataLoader
from torch.optim import AdamW
from tqdm import tqdm
from sklearn.metrics import matthews_corrcoef
from torch.cuda.amp import autocast, GradScaler
import os

In [3]:
device = torch.device("cuda")

In [4]:
torch.set_float32_matmul_precision('high')

  _C._set_float32_matmul_precision(precision)


In [5]:
print(device)

cuda


In [20]:
MODEL_NAME = "roberta-base"
TASK = "cola"
BATCH_SIZE = 32
LR = 4e-4
EPOCHS = 80
MAX_LEN = 512
RANK = 8
ALPHA = 16
BIAS = "none"
DROPOUT = 0.0
TARGET_MODULES = []
EXCLUDE_MODULES =  ["classifier"]

In [21]:
dataset = load_dataset("glue", TASK)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

In [22]:
def preprocess_function(examples):
    return tokenizer(examples["sentence"], truncation=True, max_length=MAX_LEN)

In [23]:
tokenized_datasets = dataset.map(preprocess_function, batched=True)
tokenized_datasets = tokenized_datasets.remove_columns(["sentence", "idx"])
tokenized_datasets = tokenized_datasets.rename_column("label", "labels")
tokenized_datasets.set_format("torch")

Map:   0%|          | 0/1043 [00:00<?, ? examples/s]

In [24]:
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
NUM_WORKERS = os.cpu_count()

train_dataloader = DataLoader(tokenized_datasets["train"],
    shuffle=True,
    batch_size=BATCH_SIZE,
    collate_fn=data_collator,
    num_workers = NUM_WORKERS)
eval_dataloader = DataLoader(tokenized_datasets["validation"], batch_size=BATCH_SIZE, collate_fn=data_collator)

In [25]:
pretrained_model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=2)

Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [26]:
lora_config = LoraConfig(rank = RANK, bias = BIAS, alpha = ALPHA,
                        dropout = DROPOUT,
                        target_modules = TARGET_MODULES , exclude_modules = EXCLUDE_MODULES)


In [27]:
model = LoraModel(pretrained_model, lora_config)
model.to(device)

LoraModel(
  (model): RobertaForSequenceClassification(
    (roberta): RobertaModel(
      (embeddings): RobertaEmbeddings(
        (word_embeddings): Embedding(50265, 768, padding_idx=1)
        (position_embeddings): Embedding(514, 768, padding_idx=1)
        (token_type_embeddings): Embedding(1, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): RobertaEncoder(
        (layer): ModuleList(
          (0-11): 12 x RobertaLayer(
            (attention): RobertaAttention(
              (self): RobertaSdpaSelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (output): RobertaSelfOutput(
                (dense): 

In [28]:
def count_parameters(model):
    trainable = model.get_n_trainable()
    # pretrained model total
    total = sum(p.numel() for p in model.model.parameters())
    print(f"Trainable Params: {trainable:,} || Total Params: {total:,} || %: {100 * trainable / total:.2f}%")

In [29]:
count_parameters(model)

Trainable Params: 592,130 || Total Params: 124,647,170 || %: 0.48%


In [30]:
try:
    model = torch.compile(model)
except Exception as e:
    print("error")

In [31]:
optimizer = AdamW(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=LR,
    fused=True
)
num_training_steps = EPOCHS * len(train_dataloader)
scheduler = get_linear_schedule_with_warmup(optimizer,
    num_warmup_steps=int(0.06 * num_training_steps),
    num_training_steps=num_training_steps)

In [32]:
scaler = GradScaler()

  scaler = GradScaler()


In [34]:
for epoch in range(EPOCHS):
    model.train()
    total_loss = 0

    progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch + 1}/{EPOCHS}")

    for batch in progress_bar:
        batch = {k: v.to(device) for k, v in batch.items()}

        optimizer.zero_grad(set_to_none=True)

        with torch.amp.autocast('cuda'):
            outputs = model(**batch)
            loss = outputs.loss

        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        scheduler.step()

        total_loss += loss.item()
        progress_bar.set_postfix({"loss": loss.item()})

    avg_train_loss = total_loss / len(train_dataloader)

    model.eval()
    preds = []
    labels = []

    for batch in eval_dataloader:
        batch = {k: v.to(device) for k, v in batch.items()}
        with torch.no_grad():
            outputs = model(**batch)

        logits = outputs.logits
        predictions = torch.argmax(logits, dim=-1)

        preds.extend(predictions.cpu().numpy())
        labels.extend(batch["labels"].cpu().numpy())

    mcc = matthews_corrcoef(labels, preds)
    print(f"\nEpoch {epoch + 1}. Train Loss: {avg_train_loss:.4f} | Val MCC: {mcc:.4f}\n")

model.save_model("roberta_lora_sst2.safetensors", merge_weights=False)
print("Model Saved.")

Epoch 1/80:   0%|          | 0/268 [00:00<?, ?it/s]W1204 21:53:53.131000 3478 torch/fx/experimental/symbolic_shapes.py:6833] [0/2] _maybe_guard_rel() was called on non-relation expression Eq(s16, 1) | Eq(s27, s16)
Online softmax is disabled on the fly since Inductor decides to
split the reduction. Cut an issue to PyTorch if this is an
important use case and you want to speed it up with online
softmax.

Online softmax is disabled on the fly since Inductor decides to
split the reduction. Cut an issue to PyTorch if this is an
important use case and you want to speed it up with online
softmax.

Epoch 1/80:   0%|          | 1/268 [00:51<3:47:24, 51.10s/it, loss=0.677]W1204 21:54:43.975000 3478 torch/fx/experimental/symbolic_shapes.py:6833] [0/3] _maybe_guard_rel() was called on non-relation expression Eq(s16, 1) | Eq(s27, s16)
Epoch 1/80:  97%|█████████▋| 261/268 [00:59<00:00, 91.03it/s, loss=0.572]W1204 21:54:52.122000 3478 torch/fx/experimental/symbolic_shapes.py:6833] [0/4] _maybe_guard


Epoch 1. Train Loss: 0.6261 | Val MCC: 0.0000



Epoch 2/80: 100%|██████████| 268/268 [00:03<00:00, 75.29it/s, loss=0.455]



Epoch 2. Train Loss: 0.6058 | Val MCC: 0.0000



Epoch 3/80: 100%|██████████| 268/268 [00:03<00:00, 74.22it/s, loss=0.424]



Epoch 3. Train Loss: 0.5958 | Val MCC: 0.0000



Epoch 4/80: 100%|██████████| 268/268 [00:03<00:00, 75.99it/s, loss=0.66]



Epoch 4. Train Loss: 0.5902 | Val MCC: 0.0000



Epoch 5/80: 100%|██████████| 268/268 [00:03<00:00, 71.96it/s, loss=0.739]



Epoch 5. Train Loss: 0.5835 | Val MCC: 0.0000



Epoch 6/80: 100%|██████████| 268/268 [00:03<00:00, 74.06it/s, loss=0.52]



Epoch 6. Train Loss: 0.5748 | Val MCC: 0.2309



Epoch 7/80: 100%|██████████| 268/268 [00:03<00:00, 73.99it/s, loss=0.646]



Epoch 7. Train Loss: 0.5601 | Val MCC: 0.2776



Epoch 8/80: 100%|██████████| 268/268 [00:03<00:00, 75.05it/s, loss=0.469]



Epoch 8. Train Loss: 0.5642 | Val MCC: 0.1565



Epoch 9/80: 100%|██████████| 268/268 [00:03<00:00, 73.46it/s, loss=0.539]



Epoch 9. Train Loss: 0.5531 | Val MCC: 0.2252



Epoch 10/80: 100%|██████████| 268/268 [00:03<00:00, 74.90it/s, loss=1.18]



Epoch 10. Train Loss: 0.5532 | Val MCC: 0.2178



Epoch 11/80: 100%|██████████| 268/268 [00:03<00:00, 75.24it/s, loss=0.373]



Epoch 11. Train Loss: 0.5469 | Val MCC: 0.1494



Epoch 12/80: 100%|██████████| 268/268 [00:03<00:00, 75.19it/s, loss=0.878]



Epoch 12. Train Loss: 0.5495 | Val MCC: 0.2556



Epoch 13/80: 100%|██████████| 268/268 [00:03<00:00, 76.75it/s, loss=0.833]



Epoch 13. Train Loss: 0.5557 | Val MCC: 0.2747



Epoch 14/80: 100%|██████████| 268/268 [00:03<00:00, 76.08it/s, loss=0.582]



Epoch 14. Train Loss: 0.5481 | Val MCC: 0.0464



Epoch 15/80: 100%|██████████| 268/268 [00:03<00:00, 74.70it/s, loss=0.426]



Epoch 15. Train Loss: 0.5463 | Val MCC: 0.2467



Epoch 16/80: 100%|██████████| 268/268 [00:03<00:00, 77.68it/s, loss=0.572]



Epoch 16. Train Loss: 0.5403 | Val MCC: 0.2945



Epoch 17/80: 100%|██████████| 268/268 [00:03<00:00, 77.78it/s, loss=0.408]



Epoch 17. Train Loss: 0.5382 | Val MCC: 0.2801



Epoch 18/80: 100%|██████████| 268/268 [00:03<00:00, 75.35it/s, loss=0.458]



Epoch 18. Train Loss: 0.5443 | Val MCC: 0.3326



Epoch 19/80: 100%|██████████| 268/268 [00:03<00:00, 76.68it/s, loss=0.464]



Epoch 19. Train Loss: 0.5413 | Val MCC: 0.2750



Epoch 20/80: 100%|██████████| 268/268 [00:03<00:00, 76.43it/s, loss=0.399]



Epoch 20. Train Loss: 0.5408 | Val MCC: 0.2645



Epoch 21/80: 100%|██████████| 268/268 [00:03<00:00, 75.29it/s, loss=0.519]



Epoch 21. Train Loss: 0.5439 | Val MCC: 0.2688



Epoch 22/80: 100%|██████████| 268/268 [00:03<00:00, 77.01it/s, loss=0.427]



Epoch 22. Train Loss: 0.5412 | Val MCC: 0.2513



Epoch 23/80: 100%|██████████| 268/268 [00:03<00:00, 75.55it/s, loss=0.67]



Epoch 23. Train Loss: 0.5387 | Val MCC: 0.3272



Epoch 24/80: 100%|██████████| 268/268 [00:03<00:00, 75.61it/s, loss=0.375]



Epoch 24. Train Loss: 0.5410 | Val MCC: 0.2443



Epoch 25/80: 100%|██████████| 268/268 [00:03<00:00, 76.24it/s, loss=0.189]



Epoch 25. Train Loss: 0.5442 | Val MCC: 0.2106



Epoch 26/80: 100%|██████████| 268/268 [00:03<00:00, 76.42it/s, loss=0.363]



Epoch 26. Train Loss: 0.5360 | Val MCC: 0.2404



Epoch 27/80: 100%|██████████| 268/268 [00:03<00:00, 75.46it/s, loss=0.362]



Epoch 27. Train Loss: 0.5361 | Val MCC: 0.1997



Epoch 28/80: 100%|██████████| 268/268 [00:03<00:00, 76.90it/s, loss=0.612]



Epoch 28. Train Loss: 0.5418 | Val MCC: 0.2748



Epoch 29/80: 100%|██████████| 268/268 [00:03<00:00, 77.72it/s, loss=0.315]



Epoch 29. Train Loss: 0.5432 | Val MCC: 0.2357



Epoch 30/80: 100%|██████████| 268/268 [00:03<00:00, 75.62it/s, loss=0.413]



Epoch 30. Train Loss: 0.5359 | Val MCC: 0.2468



Epoch 31/80: 100%|██████████| 268/268 [00:03<00:00, 75.29it/s, loss=0.471]



Epoch 31. Train Loss: 0.5408 | Val MCC: 0.2577



Epoch 32/80: 100%|██████████| 268/268 [00:03<00:00, 75.78it/s, loss=0.282]



Epoch 32. Train Loss: 0.5375 | Val MCC: 0.2685



Epoch 33/80: 100%|██████████| 268/268 [00:03<00:00, 75.23it/s, loss=0.702]



Epoch 33. Train Loss: 0.5382 | Val MCC: 0.2421



Epoch 34/80: 100%|██████████| 268/268 [00:03<00:00, 74.16it/s, loss=0.578]



Epoch 34. Train Loss: 0.5344 | Val MCC: 0.2278



Epoch 35/80: 100%|██████████| 268/268 [00:03<00:00, 75.46it/s, loss=0.353]



Epoch 35. Train Loss: 0.5321 | Val MCC: 0.2209



Epoch 36/80: 100%|██████████| 268/268 [00:03<00:00, 75.44it/s, loss=0.868]



Epoch 36. Train Loss: 0.5362 | Val MCC: 0.3723



Epoch 37/80: 100%|██████████| 268/268 [00:03<00:00, 75.55it/s, loss=0.648]



Epoch 37. Train Loss: 0.5303 | Val MCC: 0.2357



Epoch 38/80: 100%|██████████| 268/268 [00:03<00:00, 76.16it/s, loss=0.648]



Epoch 38. Train Loss: 0.5351 | Val MCC: 0.3566



Epoch 39/80: 100%|██████████| 268/268 [00:03<00:00, 74.74it/s, loss=0.958]



Epoch 39. Train Loss: 0.5396 | Val MCC: 0.2468



Epoch 40/80: 100%|██████████| 268/268 [00:03<00:00, 74.86it/s, loss=0.543]



Epoch 40. Train Loss: 0.5318 | Val MCC: 0.2867



Epoch 41/80: 100%|██████████| 268/268 [00:03<00:00, 75.93it/s, loss=0.479]



Epoch 41. Train Loss: 0.5351 | Val MCC: 0.2791



Epoch 42/80: 100%|██████████| 268/268 [00:03<00:00, 76.62it/s, loss=0.772]



Epoch 42. Train Loss: 0.5371 | Val MCC: 0.2813



Epoch 43/80: 100%|██████████| 268/268 [00:03<00:00, 72.73it/s, loss=0.799]



Epoch 43. Train Loss: 0.5367 | Val MCC: 0.3214



Epoch 44/80: 100%|██████████| 268/268 [00:03<00:00, 76.42it/s, loss=0.737]



Epoch 44. Train Loss: 0.5373 | Val MCC: 0.2923



Epoch 45/80: 100%|██████████| 268/268 [00:03<00:00, 76.07it/s, loss=0.502]



Epoch 45. Train Loss: 0.5339 | Val MCC: 0.3143



Epoch 46/80: 100%|██████████| 268/268 [00:03<00:00, 75.23it/s, loss=0.454]



Epoch 46. Train Loss: 0.5295 | Val MCC: 0.3221



Epoch 47/80: 100%|██████████| 268/268 [00:03<00:00, 75.98it/s, loss=0.536]



Epoch 47. Train Loss: 0.5276 | Val MCC: 0.2845



Epoch 48/80: 100%|██████████| 268/268 [00:03<00:00, 75.52it/s, loss=0.662]



Epoch 48. Train Loss: 0.5416 | Val MCC: 0.3456



Epoch 49/80: 100%|██████████| 268/268 [00:03<00:00, 73.77it/s, loss=0.362]



Epoch 49. Train Loss: 0.5328 | Val MCC: 0.2178



Epoch 50/80: 100%|██████████| 268/268 [00:03<00:00, 76.30it/s, loss=0.51]



Epoch 50. Train Loss: 0.5327 | Val MCC: 0.2750



Epoch 51/80: 100%|██████████| 268/268 [00:03<00:00, 75.55it/s, loss=0.585]



Epoch 51. Train Loss: 0.5359 | Val MCC: 0.2709



Epoch 52/80: 100%|██████████| 268/268 [00:03<00:00, 74.38it/s, loss=0.49]



Epoch 52. Train Loss: 0.5315 | Val MCC: 0.2748



Epoch 53/80: 100%|██████████| 268/268 [00:03<00:00, 75.32it/s, loss=0.532]



Epoch 53. Train Loss: 0.5290 | Val MCC: 0.2734



Epoch 54/80: 100%|██████████| 268/268 [00:03<00:00, 76.21it/s, loss=0.481]



Epoch 54. Train Loss: 0.5340 | Val MCC: 0.2667



Epoch 55/80: 100%|██████████| 268/268 [00:03<00:00, 74.76it/s, loss=0.582]



Epoch 55. Train Loss: 0.5334 | Val MCC: 0.2845



Epoch 56/80: 100%|██████████| 268/268 [00:03<00:00, 75.56it/s, loss=0.682]



Epoch 56. Train Loss: 0.5350 | Val MCC: 0.3018



Epoch 57/80: 100%|██████████| 268/268 [00:03<00:00, 75.57it/s, loss=0.74]



Epoch 57. Train Loss: 0.5362 | Val MCC: 0.3084



Epoch 58/80: 100%|██████████| 268/268 [00:03<00:00, 75.31it/s, loss=0.38]



Epoch 58. Train Loss: 0.5259 | Val MCC: 0.2884



Epoch 59/80: 100%|██████████| 268/268 [00:03<00:00, 75.81it/s, loss=0.501]



Epoch 59. Train Loss: 0.5282 | Val MCC: 0.3444



Epoch 60/80: 100%|██████████| 268/268 [00:03<00:00, 75.95it/s, loss=0.422]



Epoch 60. Train Loss: 0.5285 | Val MCC: 0.2935



Epoch 61/80: 100%|██████████| 268/268 [00:03<00:00, 75.18it/s, loss=0.694]



Epoch 61. Train Loss: 0.5260 | Val MCC: 0.2806



Epoch 62/80: 100%|██████████| 268/268 [00:03<00:00, 75.30it/s, loss=0.507]



Epoch 62. Train Loss: 0.5291 | Val MCC: 0.2867



Epoch 63/80: 100%|██████████| 268/268 [00:03<00:00, 77.16it/s, loss=0.276]



Epoch 63. Train Loss: 0.5342 | Val MCC: 0.2725



Epoch 64/80: 100%|██████████| 268/268 [00:03<00:00, 75.81it/s, loss=0.984]



Epoch 64. Train Loss: 0.5322 | Val MCC: 0.2884



Epoch 65/80: 100%|██████████| 268/268 [00:03<00:00, 75.05it/s, loss=0.492]



Epoch 65. Train Loss: 0.5263 | Val MCC: 0.2730



Epoch 66/80: 100%|██████████| 268/268 [00:03<00:00, 75.25it/s, loss=0.868]



Epoch 66. Train Loss: 0.5268 | Val MCC: 0.3429



Epoch 67/80: 100%|██████████| 268/268 [00:03<00:00, 74.39it/s, loss=0.516]



Epoch 67. Train Loss: 0.5285 | Val MCC: 0.2664



Epoch 68/80: 100%|██████████| 268/268 [00:03<00:00, 74.51it/s, loss=0.768]



Epoch 68. Train Loss: 0.5343 | Val MCC: 0.2989



Epoch 69/80: 100%|██████████| 268/268 [00:03<00:00, 75.27it/s, loss=0.611]



Epoch 69. Train Loss: 0.5335 | Val MCC: 0.2905



Epoch 70/80: 100%|██████████| 268/268 [00:03<00:00, 75.86it/s, loss=0.506]



Epoch 70. Train Loss: 0.5259 | Val MCC: 0.2952



Epoch 71/80: 100%|██████████| 268/268 [00:03<00:00, 75.85it/s, loss=0.439]



Epoch 71. Train Loss: 0.5286 | Val MCC: 0.2923



Epoch 72/80: 100%|██████████| 268/268 [00:03<00:00, 76.51it/s, loss=0.548]



Epoch 72. Train Loss: 0.5288 | Val MCC: 0.3048



Epoch 73/80: 100%|██████████| 268/268 [00:03<00:00, 76.23it/s, loss=0.501]



Epoch 73. Train Loss: 0.5284 | Val MCC: 0.3055



Epoch 74/80: 100%|██████████| 268/268 [00:03<00:00, 75.86it/s, loss=0.379]



Epoch 74. Train Loss: 0.5263 | Val MCC: 0.2816



Epoch 75/80: 100%|██████████| 268/268 [00:03<00:00, 74.93it/s, loss=0.328]



Epoch 75. Train Loss: 0.5287 | Val MCC: 0.2923



Epoch 76/80: 100%|██████████| 268/268 [00:03<00:00, 75.88it/s, loss=0.836]



Epoch 76. Train Loss: 0.5297 | Val MCC: 0.2856



Epoch 77/80: 100%|██████████| 268/268 [00:03<00:00, 75.45it/s, loss=0.344]



Epoch 77. Train Loss: 0.5264 | Val MCC: 0.3193



Epoch 78/80: 100%|██████████| 268/268 [00:03<00:00, 76.22it/s, loss=0.763]



Epoch 78. Train Loss: 0.5297 | Val MCC: 0.2961



Epoch 79/80: 100%|██████████| 268/268 [00:03<00:00, 76.72it/s, loss=0.578]



Epoch 79. Train Loss: 0.5257 | Val MCC: 0.2952



Epoch 80/80: 100%|██████████| 268/268 [00:03<00:00, 74.98it/s, loss=0.608]



Epoch 80. Train Loss: 0.5226 | Val MCC: 0.2923

Model Saved.
