In [5]:
import time
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
from datasets import load_dataset
from peft import get_peft_model, LoraConfig, TaskType


def tokenize_function(examples):
    """Tokenize one batch of examples for ``Dataset.map(..., batched=True)``.

    ``examples["text"]`` holds the batch's raw strings. Encoding uses the
    module-level ``tokenizer`` (created in a later cell, before any ``map``
    call), truncating and padding every sequence to exactly 128 tokens.

    Returns the tokenizer's encoding, whose fields ``map`` adds as new columns.
    """
    texts = examples["text"]
    encoded = tokenizer(
        texts,
        max_length=128,
        padding="max_length",
        truncation=True,
    )
    return encoded


# Fetch AG News and keep handles to its two provided splits.
dataset = load_dataset("ag_news")
train_dataset, test_dataset = dataset["train"], dataset["test"]



In [None]:
# Peek at the first 20 raw test rows (only meaningful before the "text" column is removed).
test_dataset[:20]

{'text': ["Fears for T N pension after talks Unions representing workers at Turner   Newall say they are 'disappointed' after talks with stricken parent firm Federal Mogul.",
  'The Race is On: Second Private Team Sets Launch Date for Human Spaceflight (SPACE.com) SPACE.com - TORONTO, Canada -- A second\\team of rocketeers competing for the  #36;10 million Ansari X Prize, a contest for\\privately funded suborbital space flight, has officially announced the first\\launch date for its manned rocket.',
  'Ky. Company Wins Grant to Study Peptides (AP) AP - A company founded by a chemistry researcher at the University of Louisville won a grant to develop a method of producing better peptides, which are short chains of amino acids, the building blocks of proteins.',
  "Prediction Unit Helps Forecast Wildfires (AP) AP - It's barely dawn when Mike Fitzpatrick starts his shift with a blur of colorful maps, figures and endless charts, but already he knows what the day will bring. Lightning will 

In [2]:
model_name = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Build the tokenized splits from the raw `dataset` splits instead of
# re-mapping `train_dataset` / `test_dataset` in place. The in-place version
# drops the "text" column, so re-running this cell made tokenize_function
# fail with a KeyError on examples["text"]. Deriving from `dataset` keeps the
# cell idempotent. `remove_columns` drops the raw text in the same map pass.
train_dataset = dataset["train"].map(
    tokenize_function, batched=True, remove_columns=["text"]
)
test_dataset = dataset["test"].map(
    tokenize_function, batched=True, remove_columns=["text"]
)

# Have indexing return torch tensors for the Trainer.
train_dataset.set_format("torch")
test_dataset.set_format("torch")


def train_and_evaluate(model, method_name, train_args, save_dir=None):
    """Fine-tune ``model`` with a Hugging Face ``Trainer``, evaluate it, and save it.

    Trains on the module-level ``train_dataset``. The module-level
    ``test_dataset`` is used both as the per-epoch eval set and for the final
    ``trainer.evaluate()`` call. Accuracy is reported through ``compute_metrics``.

    Args:
        model: The model to fine-tune (in this notebook, a PEFT-wrapped
            sequence classifier).
        method_name: Short label for the fine-tuning method (e.g. ``"lora"``).
            Used to name the default save directory.
        train_args: ``TrainingArguments`` passed straight to ``Trainer``.
        save_dir: Directory that receives the final model and the tokenizer.
            Defaults to ``f"./{method_name}_final_model"``, i.e.
            ``"./lora_final_model"`` for ``method_name="lora"``. That matches
            the previously hard-coded path.

    Returns:
        ``(metrics, training_time)``: the dict returned by
        ``trainer.evaluate()``, and the elapsed seconds spent in
        ``trainer.train()``.
    """
    if save_dir is None:
        # The path used to be hard-coded to "./lora_final_model" (and
        # method_name was ignored), so any non-LoRA run would overwrite the
        # LoRA artifacts.
        save_dir = f"./{method_name}_final_model"

    def compute_metrics(eval_pred):
        # Accuracy: argmax over the last axis of the predictions vs. labels.
        predicted = eval_pred.predictions.argmax(-1)
        return {"accuracy": (predicted == eval_pred.label_ids).mean()}

    trainer = Trainer(
        model=model,
        args=train_args,
        train_dataset=train_dataset,
        eval_dataset=test_dataset,
        # NOTE(review): the warning emitted at this line in the cell output is
        # presumably the `tokenizer=` -> `processing_class=` deprecation of
        # newer transformers releases. Kept as `tokenizer=` for compatibility
        # with older versions; confirm the installed version before switching.
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
    )

    # perf_counter is monotonic, so wall-clock adjustments cannot skew the timing.
    start_time = time.perf_counter()
    trainer.train()
    training_time = time.perf_counter() - start_time

    metrics = trainer.evaluate()
    model.save_pretrained(save_dir)
    tokenizer.save_pretrained(save_dir)
    return metrics, training_time


# Trainer configuration: evaluate and checkpoint once per epoch.
# NOTE(review): newer transformers releases rename `evaluation_strategy` to
# `eval_strategy`. It is kept here so older versions keep working.
training_args = TrainingArguments(
    output_dir="./results",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=5e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=10,
    push_to_hub=False,
)

# LoRA adapters (rank 8, alpha 32, dropout 0.1) on the q_lin / k_lin / v_lin
# attention projections.
lora_config = LoraConfig(
    task_type=TaskType.SEQ_CLS,
    r=8,
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=["q_lin", "k_lin", "v_lin"],
)

# Seed torch *before* building the model. The classifier head (see the
# "newly initialized" warning in the output) and presumably the LoRA adapter
# weights are randomly initialized here, before any Trainer exists. Without
# this seed the resulting metrics are not reproducible across runs.
torch.manual_seed(training_args.seed)
base_model = AutoModelForSequenceClassification.from_pretrained(
    model_name, num_labels=4)  # AG News topic classes
lora_model = get_peft_model(base_model, lora_config)

lora_metrics, lora_time = train_and_evaluate(lora_model, "lora", training_args)

# Count only parameters with requires_grad=True, i.e. the ones being trained.
lora_trainable_params = sum(
    p.numel() for p in lora_model.parameters() if p.requires_grad
)

print(f"LoRA - Trainable Parameters: {lora_trainable_params}")
print(f"LoRA - Training Time: {lora_time:.2f} seconds")
print(f"LoRA - Metrics: {lora_metrics}")

Map:   0%|          | 0/7600 [00:00<?, ? examples/s]

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  trainer = Trainer(


  0%|          | 0/22500 [00:00<?, ?it/s]

  attn_output = torch.nn.functional.scaled_dot_product_attention(


{'loss': 1.3591, 'grad_norm': 1.4737880229949951, 'learning_rate': 4.997777777777778e-05, 'epoch': 0.0}
{'loss': 1.3054, 'grad_norm': 1.5686885118484497, 'learning_rate': 4.995555555555556e-05, 'epoch': 0.0}
{'loss': 1.2502, 'grad_norm': 1.4237165451049805, 'learning_rate': 4.993333333333334e-05, 'epoch': 0.0}
{'loss': 1.184, 'grad_norm': 1.5734106302261353, 'learning_rate': 4.991111111111111e-05, 'epoch': 0.01}
{'loss': 1.0801, 'grad_norm': 1.5965105295181274, 'learning_rate': 4.9888888888888894e-05, 'epoch': 0.01}
{'loss': 0.9459, 'grad_norm': 2.915076494216919, 'learning_rate': 4.986666666666667e-05, 'epoch': 0.01}
{'loss': 0.8446, 'grad_norm': 1.969710111618042, 'learning_rate': 4.984444444444445e-05, 'epoch': 0.01}
{'loss': 0.79, 'grad_norm': 1.8664323091506958, 'learning_rate': 4.982222222222222e-05, 'epoch': 0.01}
{'loss': 0.6438, 'grad_norm': 2.717935562133789, 'learning_rate': 4.9800000000000004e-05, 'epoch': 0.01}
{'loss': 0.5478, 'grad_norm': 1.4625554084777832, 'learning_ra

  0%|          | 0/475 [00:00<?, ?it/s]

{'eval_loss': 0.21919435262680054, 'eval_accuracy': 0.9226315789473685, 'eval_runtime': 28.4779, 'eval_samples_per_second': 266.873, 'eval_steps_per_second': 16.68, 'epoch': 1.0}




{'loss': 0.1865, 'grad_norm': 1.1646591424942017, 'learning_rate': 3.3311111111111116e-05, 'epoch': 1.0}
{'loss': 0.2118, 'grad_norm': 0.8110383749008179, 'learning_rate': 3.328888888888889e-05, 'epoch': 1.0}
{'loss': 0.1248, 'grad_norm': 3.680609941482544, 'learning_rate': 3.326666666666667e-05, 'epoch': 1.0}
{'loss': 0.13, 'grad_norm': 3.156205177307129, 'learning_rate': 3.3244444444444445e-05, 'epoch': 1.01}
{'loss': 0.2645, 'grad_norm': 3.5746958255767822, 'learning_rate': 3.322222222222222e-05, 'epoch': 1.01}
{'loss': 0.2635, 'grad_norm': 2.9165894985198975, 'learning_rate': 3.32e-05, 'epoch': 1.01}
{'loss': 0.1641, 'grad_norm': 4.382316589355469, 'learning_rate': 3.317777777777778e-05, 'epoch': 1.01}
{'loss': 0.1609, 'grad_norm': 2.836970806121826, 'learning_rate': 3.3155555555555556e-05, 'epoch': 1.01}
{'loss': 0.2509, 'grad_norm': 4.427008628845215, 'learning_rate': 3.313333333333333e-05, 'epoch': 1.01}
{'loss': 0.1708, 'grad_norm': 1.5031808614730835, 'learning_rate': 3.311111

  0%|          | 0/475 [00:00<?, ?it/s]

{'eval_loss': 0.2052578181028366, 'eval_accuracy': 0.9307894736842105, 'eval_runtime': 28.8906, 'eval_samples_per_second': 263.062, 'eval_steps_per_second': 16.441, 'epoch': 2.0}
{'loss': 0.2401, 'grad_norm': 2.751708984375, 'learning_rate': 1.6644444444444445e-05, 'epoch': 2.0}
{'loss': 0.2559, 'grad_norm': 4.192925453186035, 'learning_rate': 1.6622222222222223e-05, 'epoch': 2.0}
{'loss': 0.162, 'grad_norm': 3.021077871322632, 'learning_rate': 1.66e-05, 'epoch': 2.0}
{'loss': 0.1279, 'grad_norm': 0.5824797749519348, 'learning_rate': 1.6577777777777778e-05, 'epoch': 2.01}
{'loss': 0.1793, 'grad_norm': 4.517690658569336, 'learning_rate': 1.655555555555556e-05, 'epoch': 2.01}
{'loss': 0.1933, 'grad_norm': 3.3399338722229004, 'learning_rate': 1.6533333333333333e-05, 'epoch': 2.01}
{'loss': 0.1963, 'grad_norm': 3.6532740592956543, 'learning_rate': 1.651111111111111e-05, 'epoch': 2.01}
{'loss': 0.0519, 'grad_norm': 1.0803459882736206, 'learning_rate': 1.648888888888889e-05, 'epoch': 2.01}
{



  0%|          | 0/475 [00:00<?, ?it/s]

{'eval_loss': 0.20430131256580353, 'eval_accuracy': 0.9317105263157894, 'eval_runtime': 28.445, 'eval_samples_per_second': 267.183, 'eval_steps_per_second': 16.699, 'epoch': 3.0}
{'train_runtime': 3087.8105, 'train_samples_per_second': 116.587, 'train_steps_per_second': 7.287, 'train_loss': 0.22366626537905798, 'epoch': 3.0}


  0%|          | 0/475 [00:00<?, ?it/s]

LoRA - Trainable Parameters: 814852
LoRA - Training Time: 3088.01 seconds
LoRA - Metrics: {'eval_loss': 0.20430131256580353, 'eval_accuracy': 0.9317105263157894, 'eval_runtime': 28.6819, 'eval_samples_per_second': 264.975, 'eval_steps_per_second': 16.561, 'epoch': 3.0}
