In [None]:
!pip install torch
!pip install safetensors



In [2]:
import torch
from transformers import AutoTokenizer,AutoModelForSequenceClassification, get_linear_schedule_with_warmup, DataCollatorWithPadding
from datasets import load_dataset
from torch.utils.data import DataLoader
from torch.optim import AdamW
from tqdm import tqdm
from sklearn.metrics import matthews_corrcoef

from lora_transformer import LoraConfig, LoraModel

In [3]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU so that the
# later `model.to(device)` and batch transfers do not fail on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")

Device: cuda


In [4]:
# ---- Model / task ---------------------------------------------------------
MODEL_NAME = "roberta-base"   # checkpoint name used for both tokenizer and model
TASK = "cola"                 # GLUE subset passed to load_dataset
MAX_LEN = 512                 # max_length handed to the tokenizer (with truncation)

# ---- Optimisation ---------------------------------------------------------
EPOCHS = 80
BATCH_SIZE = 32               # used for both the train and validation dataloaders
LR = 4e-4                     # learning rate given to AdamW

# ---- LoRA adapter settings (all forwarded to LoraConfig) --------------------
RANK = 8
ALPHA = 16
DROPOUT = 0.0
BIAS = "none"
# NOTE(review): the parameter report printed later shows 592,130 trainable parameters and
# the printed module tree lists query/key/value as plain Linear layers. That count looks
# like it could be the classification head alone -- confirm in lora_transformer that an
# empty TARGET_MODULES really injects adapters.
TARGET_MODULES = []
EXCLUDE_MODULES = ["classifier"]

In [5]:
# Load the GLUE subset selected by TASK (train/validation/test splits, per the cell output)
# and the tokenizer that belongs to MODEL_NAME. Both calls download files on first use,
# which is what the progress bars in this cell's output show.
dataset = load_dataset("glue", TASK)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

README.md: 0.00B [00:00, ?B/s]

cola/train-00000-of-00001.parquet:   0%|          | 0.00/251k [00:00<?, ?B/s]

cola/validation-00000-of-00001.parquet:   0%|          | 0.00/37.6k [00:00<?, ?B/s]

cola/test-00000-of-00001.parquet:   0%|          | 0.00/37.7k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/8551 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/1043 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1063 [00:00<?, ? examples/s]

tokenizer_config.json:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/481 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

In [6]:
def preprocess_function(examples):
    """Tokenize the ``sentence`` field of ``examples``.

    Args:
        examples: mapping with a ``"sentence"`` entry; it is passed straight to the
            module-level ``tokenizer`` (used with ``batched=True`` by ``dataset.map``).

    Returns:
        Whatever ``tokenizer`` returns for those sentences, truncated to at most
        ``MAX_LEN`` tokens. No padding is applied here.
    """
    sentences = examples["sentence"]
    encoded = tokenizer(sentences, truncation=True, max_length=MAX_LEN)
    return encoded

In [7]:
# Tokenize every split, drop the columns the model does not consume, and rename the
# target column to "labels" (the key the training loop later reads from each batch).
tokenized_datasets = (
    dataset.map(preprocess_function, batched=True)
    .remove_columns(["sentence", "idx"])
    .rename_column("label", "labels")
)
# set_format works in place: items are returned as torch tensors from here on.
tokenized_datasets.set_format("torch")

Map:   0%|          | 0/8551 [00:00<?, ? examples/s]

Map:   0%|          | 0/1043 [00:00<?, ? examples/s]

Map:   0%|          | 0/1063 [00:00<?, ? examples/s]

In [8]:
# Collate function that pads the variable-length tokenized examples into a batch.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# Training batches are reshuffled every epoch.
train_dataloader = DataLoader(
    dataset=tokenized_datasets["train"],
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=data_collator,
)
# Validation batches keep the dataset order (no shuffling requested).
eval_dataloader = DataLoader(
    dataset=tokenized_datasets["validation"],
    batch_size=BATCH_SIZE,
    collate_fn=data_collator,
)

In [9]:

# Load the base checkpoint with a 2-label sequence-classification head. The head weights
# are newly initialized (see the warning in this cell's output), so they have to be trained.
pretrained_model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, num_labels=2)

model.safetensors:   0%|          | 0.00/499M [00:00<?, ?B/s]

Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [10]:
# Bundle the adapter hyperparameters from the configuration cell into a LoraConfig.
lora_config = LoraConfig(
    rank=RANK,
    alpha=ALPHA,
    dropout=DROPOUT,
    bias=BIAS,
    target_modules=TARGET_MODULES,
    exclude_modules=EXCLUDE_MODULES,
)


In [11]:
# Wrap the pretrained classifier according to lora_config.
model = LoraModel(pretrained_model, lora_config)
# Move the wrapped model to the selected device. As the last expression of the cell, the
# returned module is what gets rendered as the module tree in the output.
model.to(device)

LoraModel(
  (model): RobertaForSequenceClassification(
    (roberta): RobertaModel(
      (embeddings): RobertaEmbeddings(
        (word_embeddings): Embedding(50265, 768, padding_idx=1)
        (position_embeddings): Embedding(514, 768, padding_idx=1)
        (token_type_embeddings): Embedding(1, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): RobertaEncoder(
        (layer): ModuleList(
          (0-11): 12 x RobertaLayer(
            (attention): RobertaAttention(
              (self): RobertaSdpaSelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (output): RobertaSelfOutput(
                (dense): 

In [12]:
def count_parameters(model):
    """Print how many parameters are trainable relative to the wrapped model's total.

    Args:
        model: object exposing ``get_n_trainable()`` and a ``model`` attribute whose
            ``parameters()`` yields tensors (each queried with ``numel()``).

    Returns:
        None. The summary is written to stdout.

    Raises:
        ZeroDivisionError: if the wrapped model has no parameters at all.
    """
    n_trainable = model.get_n_trainable()

    # Total is taken over the wrapped model (``model.model``), not the wrapper itself.
    n_total = 0
    for param in model.model.parameters():
        n_total += param.numel()

    percentage = 100 * n_trainable / n_total
    print(f"Trainable Params: {n_trainable:,} || Total Params: {n_total:,} || %: {percentage:.2f}%")

In [13]:
# Print the trainable vs. total parameter counts for the wrapped model.
count_parameters(model)

Trainable Params: 592,130 || Total Params: 124,647,170 || %: 0.48%


In [14]:
# Hand the optimizer only the parameters that still require gradients.
optimizer = AdamW(
    [param for param in model.parameters() if param.requires_grad],
    lr=LR,
)

# One scheduler step is taken per batch, so the schedule spans every batch of every epoch;
# the first 6% of those steps are warmup.
num_training_steps = len(train_dataloader) * EPOCHS
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=int(0.06 * num_training_steps),
    num_training_steps=num_training_steps,
)

In [15]:
# Fine-tune for EPOCHS epochs. After each epoch the model is scored on the validation
# split with the Matthews correlation coefficient, and once training ends the weights are
# saved (unmerged) to a file named after the task.
for epoch in range(EPOCHS):
    # ---- training pass ----
    model.train()
    total_loss = 0

    progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch + 1}/{EPOCHS}")

    for batch in progress_bar:
        batch = {k: v.to(device) for k, v in batch.items()}

        # The batch contains "labels", so the model output carries the loss.
        outputs = model(**batch)
        loss = outputs.loss

        loss.backward()
        optimizer.step()
        scheduler.step()  # the schedule advances once per batch
        optimizer.zero_grad()

        # Read the scalar once and reuse it for both the running sum and the progress bar.
        step_loss = loss.item()
        total_loss += step_loss
        progress_bar.set_postfix({"loss": step_loss})

    avg_train_loss = total_loss / len(train_dataloader)

    # ---- validation pass ----
    model.eval()
    preds = []
    labels = []

    # Gradients are never needed during evaluation, so disable them for the whole loop.
    with torch.no_grad():
        for batch in eval_dataloader:
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**batch)

            logits = outputs.logits
            predictions = torch.argmax(logits, dim=-1)

            preds.extend(predictions.cpu().tolist())
            labels.extend(batch["labels"].cpu().tolist())

    mcc = matthews_corrcoef(labels, preds)
    print(f"\nEpoch {epoch + 1}. Train Loss: {avg_train_loss:.4f} | Val MCC: {mcc:.4f}\n")

# NOTE(review): only the final-epoch weights are saved. In the recorded run validation MCC
# peaked at 0.4354 (epoch 27) but ended at 0.2913 -- consider keeping the best checkpoint.
# The filename is derived from TASK; it previously said "sst2" while training on TASK.
model.save_model(f"roberta_lora_{TASK}.safetensors", merge_weights=False)
print("Model Saved.")

Epoch 1/80: 100%|██████████| 268/268 [00:12<00:00, 21.89it/s, loss=0.592]



Epoch 1. Train Loss: 0.6267 | Val MCC: 0.0000



Epoch 2/80: 100%|██████████| 268/268 [00:12<00:00, 20.84it/s, loss=0.522]



Epoch 2. Train Loss: 0.6047 | Val MCC: 0.0000



Epoch 3/80: 100%|██████████| 268/268 [00:11<00:00, 22.69it/s, loss=0.421]



Epoch 3. Train Loss: 0.5940 | Val MCC: 0.0000



Epoch 4/80: 100%|██████████| 268/268 [00:12<00:00, 22.12it/s, loss=0.605]



Epoch 4. Train Loss: 0.5911 | Val MCC: 0.1172



Epoch 5/80: 100%|██████████| 268/268 [00:12<00:00, 21.71it/s, loss=0.731]



Epoch 5. Train Loss: 0.5834 | Val MCC: 0.0000



Epoch 6/80: 100%|██████████| 268/268 [00:12<00:00, 21.13it/s, loss=0.303]



Epoch 6. Train Loss: 0.5734 | Val MCC: 0.1941



Epoch 7/80: 100%|██████████| 268/268 [00:13<00:00, 20.34it/s, loss=0.364]



Epoch 7. Train Loss: 0.5619 | Val MCC: 0.0864



Epoch 8/80: 100%|██████████| 268/268 [00:13<00:00, 19.84it/s, loss=0.85]



Epoch 8. Train Loss: 0.5574 | Val MCC: 0.0738



Epoch 9/80: 100%|██████████| 268/268 [00:12<00:00, 20.64it/s, loss=0.39]



Epoch 9. Train Loss: 0.5537 | Val MCC: 0.2565



Epoch 10/80: 100%|██████████| 268/268 [00:12<00:00, 20.81it/s, loss=0.463]



Epoch 10. Train Loss: 0.5525 | Val MCC: 0.3201



Epoch 11/80: 100%|██████████| 268/268 [00:12<00:00, 21.12it/s, loss=0.735]



Epoch 11. Train Loss: 0.5561 | Val MCC: 0.2309



Epoch 12/80: 100%|██████████| 268/268 [00:14<00:00, 18.41it/s, loss=0.571]



Epoch 12. Train Loss: 0.5414 | Val MCC: 0.3122



Epoch 13/80: 100%|██████████| 268/268 [00:12<00:00, 20.71it/s, loss=0.545]



Epoch 13. Train Loss: 0.5470 | Val MCC: 0.2327



Epoch 14/80: 100%|██████████| 268/268 [00:12<00:00, 20.71it/s, loss=0.638]



Epoch 14. Train Loss: 0.5409 | Val MCC: 0.3477



Epoch 15/80: 100%|██████████| 268/268 [00:12<00:00, 20.78it/s, loss=0.425]



Epoch 15. Train Loss: 0.5503 | Val MCC: 0.2178



Epoch 16/80: 100%|██████████| 268/268 [00:12<00:00, 20.98it/s, loss=0.494]



Epoch 16. Train Loss: 0.5425 | Val MCC: 0.2448



Epoch 17/80: 100%|██████████| 268/268 [00:12<00:00, 20.76it/s, loss=0.658]



Epoch 17. Train Loss: 0.5408 | Val MCC: 0.2259



Epoch 18/80: 100%|██████████| 268/268 [00:13<00:00, 20.61it/s, loss=0.491]



Epoch 18. Train Loss: 0.5423 | Val MCC: 0.2448



Epoch 19/80: 100%|██████████| 268/268 [00:13<00:00, 20.61it/s, loss=0.372]



Epoch 19. Train Loss: 0.5451 | Val MCC: 0.2448



Epoch 20/80: 100%|██████████| 268/268 [00:12<00:00, 20.75it/s, loss=0.309]



Epoch 20. Train Loss: 0.5411 | Val MCC: 0.2513



Epoch 21/80: 100%|██████████| 268/268 [00:12<00:00, 20.71it/s, loss=0.663]



Epoch 21. Train Loss: 0.5433 | Val MCC: 0.2608



Epoch 22/80: 100%|██████████| 268/268 [00:12<00:00, 20.74it/s, loss=0.589]



Epoch 22. Train Loss: 0.5471 | Val MCC: 0.2443



Epoch 23/80: 100%|██████████| 268/268 [00:12<00:00, 21.01it/s, loss=0.37]



Epoch 23. Train Loss: 0.5431 | Val MCC: 0.2259



Epoch 24/80: 100%|██████████| 268/268 [00:12<00:00, 20.80it/s, loss=0.818]



Epoch 24. Train Loss: 0.5372 | Val MCC: 0.3214



Epoch 25/80: 100%|██████████| 268/268 [00:12<00:00, 20.80it/s, loss=0.988]



Epoch 25. Train Loss: 0.5376 | Val MCC: 0.2357



Epoch 26/80: 100%|██████████| 268/268 [00:12<00:00, 20.68it/s, loss=0.55]



Epoch 26. Train Loss: 0.5402 | Val MCC: 0.1824



Epoch 27/80: 100%|██████████| 268/268 [00:12<00:00, 20.73it/s, loss=0.558]



Epoch 27. Train Loss: 0.5380 | Val MCC: 0.4354



Epoch 28/80: 100%|██████████| 268/268 [00:12<00:00, 20.75it/s, loss=0.608]



Epoch 28. Train Loss: 0.5454 | Val MCC: 0.2512



Epoch 29/80: 100%|██████████| 268/268 [00:13<00:00, 20.59it/s, loss=0.258]



Epoch 29. Train Loss: 0.5374 | Val MCC: 0.2209



Epoch 30/80: 100%|██████████| 268/268 [00:12<00:00, 20.94it/s, loss=0.479]



Epoch 30. Train Loss: 0.5412 | Val MCC: 0.3138



Epoch 31/80: 100%|██████████| 268/268 [00:12<00:00, 20.84it/s, loss=0.473]



Epoch 31. Train Loss: 0.5396 | Val MCC: 0.2402



Epoch 32/80: 100%|██████████| 268/268 [00:12<00:00, 20.78it/s, loss=0.571]



Epoch 32. Train Loss: 0.5376 | Val MCC: 0.2513



Epoch 33/80: 100%|██████████| 268/268 [00:12<00:00, 20.87it/s, loss=0.22]



Epoch 33. Train Loss: 0.5366 | Val MCC: 0.2327



Epoch 34/80: 100%|██████████| 268/268 [00:12<00:00, 20.93it/s, loss=0.754]



Epoch 34. Train Loss: 0.5386 | Val MCC: 0.2100



Epoch 35/80: 100%|██████████| 268/268 [00:12<00:00, 20.68it/s, loss=0.585]



Epoch 35. Train Loss: 0.5369 | Val MCC: 0.2897



Epoch 36/80: 100%|██████████| 268/268 [00:12<00:00, 20.89it/s, loss=0.198]



Epoch 36. Train Loss: 0.5366 | Val MCC: 0.2432



Epoch 37/80: 100%|██████████| 268/268 [00:12<00:00, 20.72it/s, loss=0.473]



Epoch 37. Train Loss: 0.5336 | Val MCC: 0.2747



Epoch 38/80: 100%|██████████| 268/268 [00:12<00:00, 20.67it/s, loss=0.675]



Epoch 38. Train Loss: 0.5371 | Val MCC: 0.2421



Epoch 39/80: 100%|██████████| 268/268 [00:13<00:00, 20.54it/s, loss=0.698]



Epoch 39. Train Loss: 0.5374 | Val MCC: 0.2788



Epoch 40/80: 100%|██████████| 268/268 [00:12<00:00, 20.69it/s, loss=0.554]



Epoch 40. Train Loss: 0.5365 | Val MCC: 0.2375



Epoch 41/80: 100%|██████████| 268/268 [00:12<00:00, 20.82it/s, loss=0.44]



Epoch 41. Train Loss: 0.5348 | Val MCC: 0.2989



Epoch 42/80: 100%|██████████| 268/268 [00:12<00:00, 20.68it/s, loss=0.295]



Epoch 42. Train Loss: 0.5336 | Val MCC: 0.2375



Epoch 43/80: 100%|██████████| 268/268 [00:12<00:00, 20.72it/s, loss=0.974]



Epoch 43. Train Loss: 0.5392 | Val MCC: 0.2468



Epoch 44/80: 100%|██████████| 268/268 [00:12<00:00, 20.83it/s, loss=0.502]



Epoch 44. Train Loss: 0.5288 | Val MCC: 0.2538



Epoch 45/80: 100%|██████████| 268/268 [00:12<00:00, 20.69it/s, loss=0.666]



Epoch 45. Train Loss: 0.5363 | Val MCC: 0.2806



Epoch 46/80: 100%|██████████| 268/268 [00:12<00:00, 20.69it/s, loss=0.594]



Epoch 46. Train Loss: 0.5319 | Val MCC: 0.2736



Epoch 47/80: 100%|██████████| 268/268 [00:12<00:00, 20.69it/s, loss=0.385]



Epoch 47. Train Loss: 0.5328 | Val MCC: 0.2961



Epoch 48/80: 100%|██████████| 268/268 [00:12<00:00, 20.70it/s, loss=0.384]



Epoch 48. Train Loss: 0.5317 | Val MCC: 0.2507



Epoch 49/80: 100%|██████████| 268/268 [00:12<00:00, 20.74it/s, loss=0.584]



Epoch 49. Train Loss: 0.5348 | Val MCC: 0.2581



Epoch 50/80: 100%|██████████| 268/268 [00:12<00:00, 20.73it/s, loss=0.339]



Epoch 50. Train Loss: 0.5320 | Val MCC: 0.2688



Epoch 51/80: 100%|██████████| 268/268 [00:12<00:00, 20.79it/s, loss=0.978]



Epoch 51. Train Loss: 0.5366 | Val MCC: 0.2052



Epoch 52/80: 100%|██████████| 268/268 [00:12<00:00, 20.76it/s, loss=0.238]



Epoch 52. Train Loss: 0.5315 | Val MCC: 0.2685



Epoch 53/80: 100%|██████████| 268/268 [00:12<00:00, 20.73it/s, loss=0.35]



Epoch 53. Train Loss: 0.5326 | Val MCC: 0.2533



Epoch 54/80: 100%|██████████| 268/268 [00:12<00:00, 20.71it/s, loss=0.388]



Epoch 54. Train Loss: 0.5332 | Val MCC: 0.2488



Epoch 55/80: 100%|██████████| 268/268 [00:12<00:00, 20.85it/s, loss=0.809]



Epoch 55. Train Loss: 0.5338 | Val MCC: 0.3048



Epoch 56/80: 100%|██████████| 268/268 [00:12<00:00, 20.69it/s, loss=0.496]



Epoch 56. Train Loss: 0.5361 | Val MCC: 0.2106



Epoch 57/80: 100%|██████████| 268/268 [00:13<00:00, 20.60it/s, loss=0.482]



Epoch 57. Train Loss: 0.5393 | Val MCC: 0.2612



Epoch 58/80: 100%|██████████| 268/268 [00:12<00:00, 20.65it/s, loss=0.403]



Epoch 58. Train Loss: 0.5292 | Val MCC: 0.2725



Epoch 59/80: 100%|██████████| 268/268 [00:12<00:00, 20.77it/s, loss=0.615]



Epoch 59. Train Loss: 0.5300 | Val MCC: 0.2836



Epoch 60/80: 100%|██████████| 268/268 [00:12<00:00, 20.77it/s, loss=0.715]



Epoch 60. Train Loss: 0.5298 | Val MCC: 0.2664



Epoch 61/80: 100%|██████████| 268/268 [00:12<00:00, 20.73it/s, loss=0.588]



Epoch 61. Train Loss: 0.5361 | Val MCC: 0.2624



Epoch 62/80: 100%|██████████| 268/268 [00:12<00:00, 20.78it/s, loss=0.369]



Epoch 62. Train Loss: 0.5241 | Val MCC: 0.2806



Epoch 63/80: 100%|██████████| 268/268 [00:12<00:00, 20.80it/s, loss=0.63]



Epoch 63. Train Loss: 0.5276 | Val MCC: 0.2565



Epoch 64/80: 100%|██████████| 268/268 [00:12<00:00, 20.80it/s, loss=0.509]



Epoch 64. Train Loss: 0.5274 | Val MCC: 0.3055



Epoch 65/80: 100%|██████████| 268/268 [00:12<00:00, 20.85it/s, loss=0.491]



Epoch 65. Train Loss: 0.5335 | Val MCC: 0.2923



Epoch 66/80: 100%|██████████| 268/268 [00:12<00:00, 20.90it/s, loss=0.728]



Epoch 66. Train Loss: 0.5325 | Val MCC: 0.2905



Epoch 67/80: 100%|██████████| 268/268 [00:12<00:00, 20.81it/s, loss=0.519]



Epoch 67. Train Loss: 0.5328 | Val MCC: 0.2685



Epoch 68/80: 100%|██████████| 268/268 [00:13<00:00, 20.58it/s, loss=0.358]



Epoch 68. Train Loss: 0.5291 | Val MCC: 0.2952



Epoch 69/80: 100%|██████████| 268/268 [00:13<00:00, 20.57it/s, loss=0.467]



Epoch 69. Train Loss: 0.5274 | Val MCC: 0.2678



Epoch 70/80: 100%|██████████| 268/268 [00:12<00:00, 20.70it/s, loss=0.972]



Epoch 70. Train Loss: 0.5279 | Val MCC: 0.2836



Epoch 71/80: 100%|██████████| 268/268 [00:12<00:00, 20.65it/s, loss=0.792]



Epoch 71. Train Loss: 0.5289 | Val MCC: 0.2796



Epoch 72/80: 100%|██████████| 268/268 [00:12<00:00, 20.79it/s, loss=0.367]



Epoch 72. Train Loss: 0.5295 | Val MCC: 0.2875



Epoch 73/80: 100%|██████████| 268/268 [00:13<00:00, 20.45it/s, loss=0.595]



Epoch 73. Train Loss: 0.5337 | Val MCC: 0.2775



Epoch 74/80: 100%|██████████| 268/268 [00:13<00:00, 20.55it/s, loss=0.502]



Epoch 74. Train Loss: 0.5210 | Val MCC: 0.2706



Epoch 75/80: 100%|██████████| 268/268 [00:13<00:00, 20.54it/s, loss=0.334]



Epoch 75. Train Loss: 0.5320 | Val MCC: 0.2806



Epoch 76/80: 100%|██████████| 268/268 [00:13<00:00, 20.58it/s, loss=0.567]



Epoch 76. Train Loss: 0.5293 | Val MCC: 0.2913



Epoch 77/80: 100%|██████████| 268/268 [00:12<00:00, 20.87it/s, loss=0.466]



Epoch 77. Train Loss: 0.5334 | Val MCC: 0.2913



Epoch 78/80: 100%|██████████| 268/268 [00:12<00:00, 20.85it/s, loss=0.502]



Epoch 78. Train Loss: 0.5288 | Val MCC: 0.2913



Epoch 79/80: 100%|██████████| 268/268 [00:12<00:00, 20.70it/s, loss=0.56]



Epoch 79. Train Loss: 0.5233 | Val MCC: 0.2913



Epoch 80/80: 100%|██████████| 268/268 [00:12<00:00, 20.78it/s, loss=0.562]



Epoch 80. Train Loss: 0.5212 | Val MCC: 0.2913

Model Saved.
