In [1]:
!pip install torch
!pip install safetensors



In [16]:
from lora_transformer import LoraConfig, LoraModel
import torch
from transformers import AutoTokenizer,AutoModelForSequenceClassification, get_linear_schedule_with_warmup, DataCollatorWithPadding
from datasets import load_dataset
from torch.utils.data import DataLoader
from torch.optim import AdamW
from tqdm import tqdm
from sklearn.metrics import matthews_corrcoef

In [3]:
# Prefer the GPU, but fall back to the CPU so the notebook still runs end-to-end
# on machines without CUDA instead of failing later when the model is moved.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")

Device: cuda


In [4]:
# --- Reproducibility ---------------------------------------------------------
# Seed torch's RNG before the classifier head is initialised and before the
# shuffled training DataLoader is created, so repeated runs start from the same state.
SEED = 42
torch.manual_seed(SEED)

# --- Model / data ------------------------------------------------------------
MODEL_NAME = "roberta-base"  # checkpoint used for both the tokenizer and the classifier
TASK = "cola"                # GLUE configuration name passed to load_dataset
MAX_LEN = 512                # truncation length (in tokens) applied by the tokenizer

# --- Optimisation ------------------------------------------------------------
BATCH_SIZE = 32  # used for both the train and validation DataLoaders
LR = 4e-4        # AdamW learning rate (peak value of the warmup/linear-decay schedule)
EPOCHS = 80      # number of passes over the training split

# --- LoRA (forwarded unchanged to LoraConfig) --------------------------------
RANK = 8
ALPHA = 16
BIAS = "none"
DROPOUT = 0.0
TARGET_MODULES = ["query", "value"]  # the model printout shows these wrapped as LoRALinearLayer
EXCLUDE_MODULES =  ["classifier"]

In [5]:
# Fetch the GLUE benchmark configuration selected by TASK (all available splits).
dataset = load_dataset(path="glue", name=TASK)

# Tokenizer matching the checkpoint that will be fine-tuned.
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

README.md: 0.00B [00:00, ?B/s]

cola/train-00000-of-00001.parquet:   0%|          | 0.00/251k [00:00<?, ?B/s]

cola/validation-00000-of-00001.parquet:   0%|          | 0.00/37.6k [00:00<?, ?B/s]

cola/test-00000-of-00001.parquet:   0%|          | 0.00/37.7k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/8551 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/1043 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1063 [00:00<?, ? examples/s]

tokenizer_config.json:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/481 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

In [6]:
def preprocess_function(examples):
    """Tokenize the ``sentence`` field of a batch of examples.

    Sequences longer than ``MAX_LEN`` tokens are truncated; no padding is
    requested here.

    Args:
        examples: Mapping with a ``"sentence"`` entry holding the raw text.

    Returns:
        Whatever the module-level ``tokenizer`` returns for those sentences.
    """
    sentences = examples["sentence"]
    encoded = tokenizer(sentences, truncation=True, max_length=MAX_LEN)
    return encoded

In [7]:
# Tokenize every split in batches, drop the raw-text and index columns,
# and rename the target column to "labels".
tokenized_datasets = (
    dataset
    .map(preprocess_function, batched=True)
    .remove_columns(["sentence", "idx"])
    .rename_column("label", "labels")
)
# Switch the output format in place so indexing yields torch tensors.
tokenized_datasets.set_format("torch")

Map:   0%|          | 0/8551 [00:00<?, ? examples/s]

Map:   0%|          | 0/1043 [00:00<?, ? examples/s]

Map:   0%|          | 0/1063 [00:00<?, ? examples/s]

In [8]:
# Collator that pads each batch using the tokenizer's padding settings.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# Training batches are reshuffled on every pass over the loader.
train_dataloader = DataLoader(
    dataset=tokenized_datasets["train"],
    batch_size=BATCH_SIZE,
    shuffle=True,
    collate_fn=data_collator,
)

# Validation batches keep the dataset order.
eval_dataloader = DataLoader(
    dataset=tokenized_datasets["validation"],
    batch_size=BATCH_SIZE,
    shuffle=False,
    collate_fn=data_collator,
)

In [9]:
# Load the base checkpoint with a two-class sequence-classification head.
# The head weights are not in the checkpoint and are freshly initialised (see the warning below).
pretrained_model = AutoModelForSequenceClassification.from_pretrained(
    MODEL_NAME,
    num_labels=2,
)

model.safetensors:   0%|          | 0.00/499M [00:00<?, ?B/s]

Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [10]:
# Bundle the LoRA hyperparameters from the configuration cell into one config object.
lora_config = LoraConfig(
    rank=RANK,
    alpha=ALPHA,
    dropout=DROPOUT,
    bias=BIAS,
    target_modules=TARGET_MODULES,
    exclude_modules=EXCLUDE_MODULES,
)


In [11]:
# Wrap the pretrained classifier according to the LoRA configuration.
model = LoraModel(pretrained_model, lora_config)
# Move the wrapped model to the selected device. Kept as the cell's last expression
# so the notebook displays the resulting module structure.
model.to(device)

LoraModel(
  (model): RobertaForSequenceClassification(
    (roberta): RobertaModel(
      (embeddings): RobertaEmbeddings(
        (word_embeddings): Embedding(50265, 768, padding_idx=1)
        (position_embeddings): Embedding(514, 768, padding_idx=1)
        (token_type_embeddings): Embedding(1, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): RobertaEncoder(
        (layer): ModuleList(
          (0-11): 12 x RobertaLayer(
            (attention): RobertaAttention(
              (self): RobertaSdpaSelfAttention(
                (query): LoRALinearLayer(
                  in_features=768, out_features=768, bias=True
                  (dropout): Identity()
                )
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): LoRALinearLayer(
                  in_features=768, out_features=768, bias=True
                  (dropout): 

In [12]:
def count_parameters(model):
    """Print the trainable and total parameter counts and their ratio.

    Args:
        model: Wrapper exposing ``get_n_trainable()`` and a ``model`` attribute
            whose ``parameters()`` yields objects with ``numel()``.

    Returns:
        None. The summary is written to stdout.
    """
    n_trainable = model.get_n_trainable()

    # Total is measured over the wrapped inner model's parameters.
    n_total = 0
    for param in model.model.parameters():
        n_total += param.numel()

    percent = 100 * n_trainable / n_total
    print(f"Trainable Params: {n_trainable:,} || Total Params: {n_total:,} || %: {percent:.2f}%")

In [13]:
count_parameters(model)

Trainable Params: 887,042 || Total Params: 124,942,082 || %: 0.71%


In [14]:
# Optimise only the parameters that still require gradients.
optimizer = AdamW(
    (param for param in model.parameters() if param.requires_grad),
    lr=LR,
)

# One scheduler step per batch: linear warmup over the first 6% of steps, then linear decay.
num_training_steps = len(train_dataloader) * EPOCHS
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=int(0.06 * num_training_steps),
    num_training_steps=num_training_steps,
)

In [17]:
# Fine-tune for EPOCHS epochs; after each epoch report the mean training loss and
# the Matthews correlation coefficient on the validation split.
for epoch in range(EPOCHS):
    # ---- training pass ----
    model.train()
    total_loss = 0.0

    progress_bar = tqdm(train_dataloader, desc=f"Epoch {epoch + 1}/{EPOCHS}")

    for batch in progress_bar:
        batch = {k: v.to(device) for k, v in batch.items()}

        # The batch contains "labels", so the model output carries a loss.
        outputs = model(**batch)
        loss = outputs.loss

        loss.backward()
        optimizer.step()
        scheduler.step()  # the schedule is defined per optimisation step, not per epoch
        optimizer.zero_grad()

        # Read the scalar once instead of copying it off the device twice.
        loss_value = loss.item()
        total_loss += loss_value
        progress_bar.set_postfix({"loss": loss_value})

    avg_train_loss = total_loss / len(train_dataloader)

    # ---- validation pass ----
    model.eval()
    preds = []
    labels = []

    # Gradients are never needed during evaluation, so disable them for the whole loop.
    with torch.no_grad():
        for batch in eval_dataloader:
            batch = {k: v.to(device) for k, v in batch.items()}
            outputs = model(**batch)

            predictions = torch.argmax(outputs.logits, dim=-1)

            # Collect plain Python ints for scikit-learn.
            preds.extend(predictions.cpu().tolist())
            labels.extend(batch["labels"].cpu().tolist())

    mcc = matthews_corrcoef(labels, preds)
    print(f"\nEpoch {epoch + 1}. Train Loss: {avg_train_loss:.4f} | Val MCC: {mcc:.4f}\n")

# Name the checkpoint after the task actually trained (TASK), not a hard-coded other task.
# NOTE(review): this saves the last-epoch weights, not the epoch with the best
# validation MCC — consider checkpointing on improvement if the best model is wanted.
model.save_model(f"roberta_lora_{TASK}.safetensors", merge_weights=False)
print("Model Saved.")

Epoch 1/80: 100%|██████████| 268/268 [00:29<00:00,  9.10it/s, loss=0.317]



Epoch 1. Train Loss: 0.4645 | Val MCC: 0.4750



Epoch 2/80: 100%|██████████| 268/268 [00:30<00:00,  8.88it/s, loss=0.395]



Epoch 2. Train Loss: 0.4274 | Val MCC: 0.5153



Epoch 3/80: 100%|██████████| 268/268 [00:28<00:00,  9.55it/s, loss=0.238]



Epoch 3. Train Loss: 0.3947 | Val MCC: 0.5579



Epoch 4/80: 100%|██████████| 268/268 [00:29<00:00,  9.04it/s, loss=0.192]



Epoch 4. Train Loss: 0.3661 | Val MCC: 0.5406



Epoch 5/80: 100%|██████████| 268/268 [00:27<00:00,  9.62it/s, loss=0.373]



Epoch 5. Train Loss: 0.3389 | Val MCC: 0.6113



Epoch 6/80: 100%|██████████| 268/268 [00:27<00:00,  9.64it/s, loss=0.115]



Epoch 6. Train Loss: 0.3130 | Val MCC: 0.5707



Epoch 7/80: 100%|██████████| 268/268 [00:27<00:00,  9.71it/s, loss=0.751]



Epoch 7. Train Loss: 0.2887 | Val MCC: 0.5908



Epoch 8/80: 100%|██████████| 268/268 [00:27<00:00,  9.62it/s, loss=0.549]



Epoch 8. Train Loss: 0.2613 | Val MCC: 0.5967



Epoch 9/80: 100%|██████████| 268/268 [00:27<00:00,  9.71it/s, loss=0.619]



Epoch 9. Train Loss: 0.2340 | Val MCC: 0.5629



Epoch 10/80: 100%|██████████| 268/268 [00:27<00:00,  9.68it/s, loss=0.388]



Epoch 10. Train Loss: 0.2102 | Val MCC: 0.6209



Epoch 11/80: 100%|██████████| 268/268 [00:27<00:00,  9.70it/s, loss=0.739]



Epoch 11. Train Loss: 0.1928 | Val MCC: 0.6284



Epoch 12/80: 100%|██████████| 268/268 [00:27<00:00,  9.66it/s, loss=0.168]



Epoch 12. Train Loss: 0.1705 | Val MCC: 0.5855



Epoch 13/80: 100%|██████████| 268/268 [00:27<00:00,  9.74it/s, loss=0.0664]



Epoch 13. Train Loss: 0.1536 | Val MCC: 0.6257



Epoch 14/80: 100%|██████████| 268/268 [00:27<00:00,  9.63it/s, loss=0.0128]



Epoch 14. Train Loss: 0.1403 | Val MCC: 0.5961



Epoch 15/80: 100%|██████████| 268/268 [00:27<00:00,  9.78it/s, loss=0.337]



Epoch 15. Train Loss: 0.1355 | Val MCC: 0.6036



Epoch 16/80: 100%|██████████| 268/268 [00:27<00:00,  9.66it/s, loss=0.0485]



Epoch 16. Train Loss: 0.1262 | Val MCC: 0.6123



Epoch 17/80: 100%|██████████| 268/268 [00:27<00:00,  9.75it/s, loss=0.155]



Epoch 17. Train Loss: 0.1264 | Val MCC: 0.6288



Epoch 18/80: 100%|██████████| 268/268 [00:27<00:00,  9.65it/s, loss=0.114]



Epoch 18. Train Loss: 0.1101 | Val MCC: 0.6406



Epoch 19/80: 100%|██████████| 268/268 [00:27<00:00,  9.71it/s, loss=0.0103]



Epoch 19. Train Loss: 0.1014 | Val MCC: 0.5966



Epoch 20/80: 100%|██████████| 268/268 [00:27<00:00,  9.63it/s, loss=0.16]



Epoch 20. Train Loss: 0.0902 | Val MCC: 0.6235



Epoch 21/80: 100%|██████████| 268/268 [00:27<00:00,  9.62it/s, loss=0.0131]



Epoch 21. Train Loss: 0.0903 | Val MCC: 0.5727



Epoch 22/80: 100%|██████████| 268/268 [00:27<00:00,  9.64it/s, loss=0.0107]



Epoch 22. Train Loss: 0.0897 | Val MCC: 0.6143



Epoch 23/80: 100%|██████████| 268/268 [00:27<00:00,  9.72it/s, loss=0.0842]



Epoch 23. Train Loss: 0.0768 | Val MCC: 0.6034



Epoch 24/80: 100%|██████████| 268/268 [00:27<00:00,  9.70it/s, loss=0.0204]



Epoch 24. Train Loss: 0.0755 | Val MCC: 0.6082



Epoch 25/80: 100%|██████████| 268/268 [00:27<00:00,  9.78it/s, loss=0.0103]



Epoch 25. Train Loss: 0.0752 | Val MCC: 0.6034



Epoch 26/80: 100%|██████████| 268/268 [00:27<00:00,  9.69it/s, loss=0.103]



Epoch 26. Train Loss: 0.0695 | Val MCC: 0.5983



Epoch 27/80: 100%|██████████| 268/268 [00:27<00:00,  9.64it/s, loss=0.0171]



Epoch 27. Train Loss: 0.0690 | Val MCC: 0.6263



Epoch 28/80: 100%|██████████| 268/268 [00:27<00:00,  9.70it/s, loss=0.113]



Epoch 28. Train Loss: 0.0596 | Val MCC: 0.6107



Epoch 29/80: 100%|██████████| 268/268 [00:27<00:00,  9.71it/s, loss=0.0438]



Epoch 29. Train Loss: 0.0577 | Val MCC: 0.5987



Epoch 30/80: 100%|██████████| 268/268 [00:27<00:00,  9.82it/s, loss=0.134]



Epoch 30. Train Loss: 0.0590 | Val MCC: 0.6209



Epoch 31/80: 100%|██████████| 268/268 [00:27<00:00,  9.66it/s, loss=0.24]



Epoch 31. Train Loss: 0.0511 | Val MCC: 0.6334



Epoch 32/80: 100%|██████████| 268/268 [00:27<00:00,  9.72it/s, loss=0.0544]



Epoch 32. Train Loss: 0.0613 | Val MCC: 0.6082



Epoch 33/80: 100%|██████████| 268/268 [00:27<00:00,  9.72it/s, loss=0.0407]



Epoch 33. Train Loss: 0.0514 | Val MCC: 0.6234



Epoch 34/80: 100%|██████████| 268/268 [00:27<00:00,  9.65it/s, loss=0.184]



Epoch 34. Train Loss: 0.0496 | Val MCC: 0.6182



Epoch 35/80: 100%|██████████| 268/268 [00:27<00:00,  9.70it/s, loss=0.000282]



Epoch 35. Train Loss: 0.0464 | Val MCC: 0.6331



Epoch 36/80: 100%|██████████| 268/268 [00:27<00:00,  9.70it/s, loss=0.0476]



Epoch 36. Train Loss: 0.0473 | Val MCC: 0.6091



Epoch 37/80: 100%|██████████| 268/268 [00:27<00:00,  9.66it/s, loss=0.000896]



Epoch 37. Train Loss: 0.0459 | Val MCC: 0.6143



Epoch 38/80: 100%|██████████| 268/268 [00:27<00:00,  9.68it/s, loss=0.00503]



Epoch 38. Train Loss: 0.0420 | Val MCC: 0.6059



Epoch 39/80: 100%|██████████| 268/268 [00:27<00:00,  9.70it/s, loss=0.00996]



Epoch 39. Train Loss: 0.0454 | Val MCC: 0.6234



Epoch 40/80: 100%|██████████| 268/268 [00:27<00:00,  9.66it/s, loss=0.0236]



Epoch 40. Train Loss: 0.0404 | Val MCC: 0.6358



Epoch 41/80: 100%|██████████| 268/268 [00:27<00:00,  9.66it/s, loss=0.00238]



Epoch 41. Train Loss: 0.0367 | Val MCC: 0.6186



Epoch 42/80: 100%|██████████| 268/268 [00:27<00:00,  9.61it/s, loss=0.0621]



Epoch 42. Train Loss: 0.0362 | Val MCC: 0.6058



Epoch 43/80: 100%|██████████| 268/268 [00:27<00:00,  9.57it/s, loss=0.0456]



Epoch 43. Train Loss: 0.0410 | Val MCC: 0.6135



Epoch 44/80: 100%|██████████| 268/268 [00:27<00:00,  9.60it/s, loss=0.00955]



Epoch 44. Train Loss: 0.0333 | Val MCC: 0.6258



Epoch 45/80: 100%|██████████| 268/268 [00:27<00:00,  9.63it/s, loss=0.000554]



Epoch 45. Train Loss: 0.0357 | Val MCC: 0.6261



Epoch 46/80: 100%|██████████| 268/268 [00:27<00:00,  9.69it/s, loss=0.373]



Epoch 46. Train Loss: 0.0347 | Val MCC: 0.6506



Epoch 47/80: 100%|██████████| 268/268 [00:27<00:00,  9.61it/s, loss=0.0014]



Epoch 47. Train Loss: 0.0315 | Val MCC: 0.6481



Epoch 48/80: 100%|██████████| 268/268 [00:27<00:00,  9.62it/s, loss=0.0155]



Epoch 48. Train Loss: 0.0325 | Val MCC: 0.6483



Epoch 49/80: 100%|██████████| 268/268 [00:27<00:00,  9.67it/s, loss=0.00566]



Epoch 49. Train Loss: 0.0329 | Val MCC: 0.6338



Epoch 50/80: 100%|██████████| 268/268 [00:27<00:00,  9.70it/s, loss=0.0108]



Epoch 50. Train Loss: 0.0270 | Val MCC: 0.6456



Epoch 51/80: 100%|██████████| 268/268 [00:27<00:00,  9.60it/s, loss=0.00252]



Epoch 51. Train Loss: 0.0300 | Val MCC: 0.6530



Epoch 52/80: 100%|██████████| 268/268 [00:27<00:00,  9.71it/s, loss=0.000977]



Epoch 52. Train Loss: 0.0289 | Val MCC: 0.6240



Epoch 53/80: 100%|██████████| 268/268 [00:27<00:00,  9.67it/s, loss=0.000293]



Epoch 53. Train Loss: 0.0268 | Val MCC: 0.6331



Epoch 54/80: 100%|██████████| 268/268 [00:27<00:00,  9.58it/s, loss=0.00051]



Epoch 54. Train Loss: 0.0274 | Val MCC: 0.6209



Epoch 55/80: 100%|██████████| 268/268 [00:27<00:00,  9.61it/s, loss=0.00606]



Epoch 55. Train Loss: 0.0286 | Val MCC: 0.6282



Epoch 56/80: 100%|██████████| 268/268 [00:27<00:00,  9.75it/s, loss=0.248]



Epoch 56. Train Loss: 0.0255 | Val MCC: 0.6232



Epoch 57/80: 100%|██████████| 268/268 [00:27<00:00,  9.73it/s, loss=0.414]



Epoch 57. Train Loss: 0.0269 | Val MCC: 0.6383



Epoch 58/80: 100%|██████████| 268/268 [00:27<00:00,  9.67it/s, loss=0.000706]



Epoch 58. Train Loss: 0.0244 | Val MCC: 0.6186



Epoch 59/80: 100%|██████████| 268/268 [00:27<00:00,  9.67it/s, loss=0.000102]



Epoch 59. Train Loss: 0.0206 | Val MCC: 0.6282



Epoch 60/80: 100%|██████████| 268/268 [00:27<00:00,  9.63it/s, loss=0.000183]



Epoch 60. Train Loss: 0.0212 | Val MCC: 0.6157



Epoch 61/80: 100%|██████████| 268/268 [00:27<00:00,  9.59it/s, loss=0.00223]



Epoch 61. Train Loss: 0.0233 | Val MCC: 0.6166



Epoch 62/80: 100%|██████████| 268/268 [00:27<00:00,  9.63it/s, loss=0.0106]



Epoch 62. Train Loss: 0.0203 | Val MCC: 0.6238



Epoch 63/80: 100%|██████████| 268/268 [00:27<00:00,  9.67it/s, loss=0.000176]



Epoch 63. Train Loss: 0.0206 | Val MCC: 0.6235



Epoch 64/80: 100%|██████████| 268/268 [00:27<00:00,  9.66it/s, loss=0.00139]



Epoch 64. Train Loss: 0.0198 | Val MCC: 0.6035



Epoch 65/80: 100%|██████████| 268/268 [00:27<00:00,  9.70it/s, loss=0.000179]



Epoch 65. Train Loss: 0.0174 | Val MCC: 0.6082



Epoch 66/80: 100%|██████████| 268/268 [00:27<00:00,  9.60it/s, loss=0.0073]



Epoch 66. Train Loss: 0.0190 | Val MCC: 0.6143



Epoch 67/80: 100%|██████████| 268/268 [00:28<00:00,  9.52it/s, loss=0.000161]



Epoch 67. Train Loss: 0.0181 | Val MCC: 0.6085



Epoch 68/80: 100%|██████████| 268/268 [00:27<00:00,  9.73it/s, loss=0.000151]



Epoch 68. Train Loss: 0.0162 | Val MCC: 0.6257



Epoch 69/80: 100%|██████████| 268/268 [00:27<00:00,  9.69it/s, loss=4.96e-5]



Epoch 69. Train Loss: 0.0143 | Val MCC: 0.6160



Epoch 70/80: 100%|██████████| 268/268 [00:27<00:00,  9.67it/s, loss=0.000913]



Epoch 70. Train Loss: 0.0179 | Val MCC: 0.6086



Epoch 71/80: 100%|██████████| 268/268 [00:27<00:00,  9.63it/s, loss=0.000519]



Epoch 71. Train Loss: 0.0172 | Val MCC: 0.6284



Epoch 72/80: 100%|██████████| 268/268 [00:27<00:00,  9.72it/s, loss=7.48e-5]



Epoch 72. Train Loss: 0.0153 | Val MCC: 0.6086



Epoch 73/80: 100%|██████████| 268/268 [00:27<00:00,  9.77it/s, loss=8.92e-5]



Epoch 73. Train Loss: 0.0158 | Val MCC: 0.6208



Epoch 74/80: 100%|██████████| 268/268 [00:27<00:00,  9.66it/s, loss=0.000356]



Epoch 74. Train Loss: 0.0146 | Val MCC: 0.6082



Epoch 75/80: 100%|██████████| 268/268 [00:27<00:00,  9.68it/s, loss=7.67e-5]



Epoch 75. Train Loss: 0.0164 | Val MCC: 0.6160



Epoch 76/80: 100%|██████████| 268/268 [00:27<00:00,  9.67it/s, loss=0.00233]



Epoch 76. Train Loss: 0.0160 | Val MCC: 0.6086



Epoch 77/80: 100%|██████████| 268/268 [00:27<00:00,  9.72it/s, loss=0.000189]



Epoch 77. Train Loss: 0.0140 | Val MCC: 0.6136



Epoch 78/80: 100%|██████████| 268/268 [00:27<00:00,  9.62it/s, loss=0.00014]



Epoch 78. Train Loss: 0.0115 | Val MCC: 0.6136



Epoch 79/80: 100%|██████████| 268/268 [00:27<00:00,  9.58it/s, loss=0.000697]



Epoch 79. Train Loss: 0.0114 | Val MCC: 0.6136



Epoch 80/80: 100%|██████████| 268/268 [00:27<00:00,  9.59it/s, loss=0.000851]



Epoch 80. Train Loss: 0.0120 | Val MCC: 0.6136

Model Saved.
