In [1]:
import torch
from torch import nn
from torch.nn import functional as F

# Hugging Face Hub checkpoint used to load both the tokenizer and the classifier.
model_name: str = "google/electra-base-discriminator"

In [2]:
from datasets import load_dataset

# Fetch the IMDB reviews dataset (a DatasetDict of splits) from the Hub / local cache.
imdb = load_dataset(path="imdb")

In [3]:
# Peek at the first raw test example to see the record format before tokenizing.
imdb["test"][0]


{'text': 'I love sci-fi and am willing to put up with a lot. Sci-fi movies/TV are usually underfunded, under-appreciated and misunderstood. I tried to like this, I really did, but it is to good TV sci-fi as Babylon 5 is to Star Trek (the original). Silly prosthetics, cheap cardboard sets, stilted dialogues, CG that doesn\'t match the background, and painfully one-dimensional characters cannot be overcome with a \'sci-fi\' setting. (I\'m sure there are those of you out there who think Babylon 5 is good sci-fi TV. It\'s not. It\'s clichéd and uninspiring.) While US viewers might like emotion and character development, sci-fi is a genre that does not take itself seriously (cf. Star Trek). It may treat important issues, yet not as a serious philosophy. It\'s really difficult to care about the characters here as they are not simply foolish, just missing a spark of life. Their actions and reactions are wooden and predictable, often painful to watch. The makers of Earth KNOW it\'s rubbish as 

In [4]:
from transformers import AutoTokenizer

# Load the tokenizer shipped with the same checkpoint as the model.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_name)

In [5]:
def preprocess_function(examples, max_length=None):
    """Tokenize a batch of reviews for sequence classification.

    Intended for ``Dataset.map(..., batched=True)``, where ``examples["text"]``
    is a list of review strings.

    Args:
        examples: batch mapping containing a ``"text"`` column.
        max_length: optional cap on tokens per review. ``None`` (the default)
            makes ``truncation=True`` cut at the tokenizer's model maximum;
            a value above the model's position limit (512 for ELECTRA-base)
            would not be usable by the model.

    Returns:
        The tokenizer's batch encoding. No padding is applied here; batches
        are meant to be padded later by a data collator.
    """
    return tokenizer(examples["text"], truncation=True, max_length=max_length)

In [6]:
# Tokenize every split in batches; the original columns (text, label) are kept alongside.
tokenized_imdb = imdb.map(function=preprocess_function, batched=True)

Map:   0%|          | 0/25000 [00:00<?, ? examples/s]

In [7]:
from transformers import DataCollatorWithPadding

# Pad each batch dynamically (to its longest sequence) at collate time.
data_collator = DataCollatorWithPadding(tokenizer)

In [8]:
import evaluate

# Accuracy metric consumed by `compute_metrics` during evaluation.
accuracy = evaluate.load(path="accuracy")

bin c:\Users\Abstract\mambaforge\envs\sentenv2\lib\site-packages\bitsandbytes\libbitsandbytes_cuda121.dll


In [9]:
import numpy as np


def compute_metrics(eval_pred):
    """Compute classification accuracy for ``Trainer`` evaluation.

    Args:
        eval_pred: ``(predictions, labels)`` pair supplied by ``Trainer``
            (an ``EvalPrediction`` unpacks this way). ``predictions`` holds
            per-class scores; when the model returns several outputs it is a
            tuple whose first element is the logits.

    Returns:
        The dict produced by the ``accuracy`` metric, e.g. ``{"accuracy": 0.94}``.
    """
    logits, labels = eval_pred
    # Models that return extra outputs (hidden states, etc.) yield a tuple;
    # argmax over a tuple of arrays would be wrong, so take the logits first.
    if isinstance(logits, tuple):
        logits = logits[0]
    # Class scores live on the last axis.
    predictions = np.argmax(logits, axis=-1)
    return accuracy.compute(predictions=predictions, references=labels)

In [10]:
# Human-readable names for the two sentiment classes, plus the inverse lookup
# derived from it so the two mappings can never disagree.
id2label = {0: "NEGATIVE", 1: "POSITIVE"}
label2id = {name: idx for idx, name in id2label.items()}

In [11]:
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer, set_seed
from peft import get_peft_config, get_peft_model, LoraConfig, TaskType

# Seed *before* building the model: the classification head is newly
# initialized and LoRA adapter weights are randomly initialized too, while the
# Trainer only seeds later (inside its own __init__). 42 matches the
# TrainingArguments default `seed`.
SEED = 42
set_seed(SEED)

model = AutoModelForSequenceClassification.from_pretrained(
    model_name, num_labels=len(id2label), id2label=id2label, label2id=label2id
)

# LoRA: train low-rank adapters (rank 8) instead of the full encoder.
# get_peft_model freezes the base weights itself, so no manual
# `requires_grad = False` loop is needed.
peft_config = LoraConfig(
    task_type=TaskType.SEQ_CLS, inference_mode=False, r=8, lora_alpha=32, lora_dropout=0.1
)

model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

Some weights of ElectraForSequenceClassification were not initialized from the model checkpoint at google/electra-base-discriminator and are newly initialized: ['classifier.out_proj.weight', 'classifier.out_proj.bias', 'classifier.dense.weight', 'classifier.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


trainable params: 1479172 || all params: 110370820 || trainable%: 1.3401839362976555


In [12]:
# `model_name` is a Hub id containing "/", so concatenating it directly would
# create a nested "my_awesome_lora_google/..." folder; use only the last part.
training_args = TrainingArguments(
    output_dir="my_awesome_lora_" + model_name.split("/")[-1],
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=2,
    weight_decay=0.01,
    # NOTE: newer transformers releases rename this argument to `eval_strategy`.
    evaluation_strategy="epoch",
    save_strategy="epoch",
    logging_steps=100,
    warmup_steps=100,
)

# The IMDB "test" split doubles as the per-epoch eval set, so the reported
# eval accuracy is test accuracy. Do not use it for checkpoint selection
# (e.g. load_best_model_at_end) without carving out a separate validation split.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_imdb["train"],
    eval_dataset=tokenized_imdb["test"],
    # NOTE: newer transformers releases rename this argument to `processing_class`.
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

trainer.train()



  0%|          | 0/3126 [00:00<?, ?it/s]

You're using a ElectraTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


{'loss': 0.692, 'learning_rate': 2e-05, 'epoch': 0.06}
{'loss': 0.6855, 'learning_rate': 1.9339061467283543e-05, 'epoch': 0.13}
{'loss': 0.6575, 'learning_rate': 1.8678122934567087e-05, 'epoch': 0.19}
{'loss': 0.4421, 'learning_rate': 1.8017184401850628e-05, 'epoch': 0.26}
{'loss': 0.2843, 'learning_rate': 1.7356245869134173e-05, 'epoch': 0.32}
{'loss': 0.267, 'learning_rate': 1.6695307336417714e-05, 'epoch': 0.38}
{'loss': 0.2422, 'learning_rate': 1.603436880370126e-05, 'epoch': 0.45}
{'loss': 0.2383, 'learning_rate': 1.53734302709848e-05, 'epoch': 0.51}
{'loss': 0.2297, 'learning_rate': 1.4712491738268342e-05, 'epoch': 0.58}
{'loss': 0.217, 'learning_rate': 1.4051553205551885e-05, 'epoch': 0.64}
{'loss': 0.2304, 'learning_rate': 1.3390614672835428e-05, 'epoch': 0.7}
{'loss': 0.1937, 'learning_rate': 1.2729676140118969e-05, 'epoch': 0.77}
{'loss': 0.2039, 'learning_rate': 1.2068737607402512e-05, 'epoch': 0.83}
{'loss': 0.2229, 'learning_rate': 1.1407799074686054e-05, 'epoch': 0.9}
{'l

  0%|          | 0/1563 [00:00<?, ?it/s]

{'eval_loss': 0.177137553691864, 'eval_accuracy': 0.93684, 'eval_runtime': 255.331, 'eval_samples_per_second': 97.912, 'eval_steps_per_second': 6.121, 'epoch': 1.0}
{'loss': 0.1927, 'learning_rate': 1.008592200925314e-05, 'epoch': 1.02}
{'loss': 0.1934, 'learning_rate': 9.424983476536683e-06, 'epoch': 1.09}
{'loss': 0.2017, 'learning_rate': 8.764044943820226e-06, 'epoch': 1.15}
{'loss': 0.1645, 'learning_rate': 8.103106411103768e-06, 'epoch': 1.22}
{'loss': 0.2032, 'learning_rate': 7.442167878387311e-06, 'epoch': 1.28}
{'loss': 0.1877, 'learning_rate': 6.781229345670853e-06, 'epoch': 1.34}
{'loss': 0.1622, 'learning_rate': 6.120290812954396e-06, 'epoch': 1.41}
{'loss': 0.1964, 'learning_rate': 5.459352280237939e-06, 'epoch': 1.47}
{'loss': 0.1815, 'learning_rate': 4.798413747521481e-06, 'epoch': 1.54}
{'loss': 0.1693, 'learning_rate': 4.137475214805023e-06, 'epoch': 1.6}
{'loss': 0.1721, 'learning_rate': 3.476536682088566e-06, 'epoch': 1.66}
{'loss': 0.2127, 'learning_rate': 2.81559814

  0%|          | 0/1563 [00:00<?, ?it/s]

{'eval_loss': 0.16635367274284363, 'eval_accuracy': 0.94284, 'eval_runtime': 251.288, 'eval_samples_per_second': 99.487, 'eval_steps_per_second': 6.22, 'epoch': 2.0}
{'train_runtime': 1907.5034, 'train_samples_per_second': 26.212, 'train_steps_per_second': 1.639, 'train_loss': 0.25676146380350673, 'epoch': 2.0}


TrainOutput(global_step=3126, training_loss=0.25676146380350673, metrics={'train_runtime': 1907.5034, 'train_samples_per_second': 26.212, 'train_steps_per_second': 1.639, 'train_loss': 0.25676146380350673, 'epoch': 2.0})