In [None]:
# Imports
import evaluate
import numpy as np
import torch
from datasets import Dataset, DatasetDict, load_dataset
from peft import LoraConfig, PeftConfig, PeftModel, get_peft_model
from transformers import (
    AutoConfig,
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    Trainer,
    TrainingArguments,
)


In [None]:
# Pick the best available accelerator: a CUDA GPU first, then Apple-silicon
# MPS, and fall back to the CPU when neither backend is available.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
print(f"Running on {device}!")

In [None]:
# Pretrained checkpoint the classifier is initialised from
model_checkpoint = 'distilbert-base-uncased'

# Two-way mapping between class indices and human-readable sentiment labels;
# label2id is simply the inverse of id2label.
id2label = {0: "Negative", 1: "Positive"}
label2id = {name: index for index, name in id2label.items()}

# Build a sequence-classification model on top of the checkpoint, with one
# output per entry in the label map.
model = AutoModelForSequenceClassification.from_pretrained(
    model_checkpoint,
    num_labels=len(id2label),
    id2label=id2label,
    label2id=label2id,
)


In [None]:
# Load the data
# Fetches the "shawhin/imdb-truncated" dataset through `datasets.load_dataset`.
dataset = load_dataset("shawhin/imdb-truncated")
# Bare expression: shows the loaded object as the cell output.
dataset

In [None]:
# create tokenizer from the same checkpoint name as the model
# NOTE(review): add_prefix_space is usually a byte-level BPE option — confirm
# it has any effect for this checkpoint's tokenizer.
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, add_prefix_space=True)

# create tokenize function
def tokenize_function(example):
    """Tokenize the "text" field of an example (or batch of examples).

    Args:
        example: Mapping with a "text" entry; when called through
            ``dataset.map(..., batched=True)`` this is a batch of texts.

    Returns:
        The tokenizer's encoding of the text, truncated to at most 512
        tokens and left unpadded.
    """
    # Truncate from the left so the END of an over-long text is what is kept.
    tokenizer.truncation_side = "left"

    # No padding and no tensor conversion here: padding inside `map` pads every
    # example to the longest text of the whole map batch, which defeats the
    # per-batch dynamic padding done later by DataCollatorWithPadding.
    return tokenizer(
        example["text"],
        truncation=True,
        max_length=512,
    )

# Safety net: when the tokenizer has no padding token, register '[PAD]' as one
# and resize the model's token embeddings to the new tokenizer length.
if tokenizer.pad_token is None:
    tokenizer.add_special_tokens({"pad_token": '[PAD]'})
    model.resize_token_embeddings(len(tokenizer))

# tokenize training and validation datasets
# (batched=True hands tokenize_function a batch of examples per call)
tokenized_datasets = dataset.map(tokenize_function, batched=True)

In [None]:
# create DataCollator: pads the examples of each batch to a common length when
# the batch is assembled (dynamic, per-batch padding — not the whole dataset)

data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

In [None]:
# define evaluation metrics
# Loads the "accuracy" metric through the `evaluate` library.
accuracy = evaluate.load("accuracy")

# define an evaluation function to pass into trainer later
def compute_metrics(p):
    """Compute accuracy from a (predictions, labels) pair.

    Args:
        p: Two-item iterable ``(predictions, labels)``; ``predictions`` holds
            per-class scores with the class axis at position 1.

    Returns:
        The dict produced by ``accuracy.compute`` (``{"accuracy": value}``),
        so the reported metric is a plain scalar rather than a nested dict.
    """
    predictions, labels = p
    # Highest-scoring class per example.
    predicted_classes = np.argmax(predictions, axis=1)

    # accuracy.compute() already returns {"accuracy": value}; wrapping it in
    # another {"accuracy": ...} would nest the dict and log a non-scalar metric.
    return accuracy.compute(predictions=predicted_classes, references=labels)

In [None]:
# Hand-written example reviews used to spot-check the model's predictions
text_list = [
    "It was good.",
    "Not a fan, don't recommend.",
    "Better than the first one.",
    "This is not worth watching even once.",
    "This one is a pass.",
    "Do not watch.",
]

print("Untrained model predictions:")
print("------------------------------")
# Inference only: disable autograd so no computation graph is built or kept.
with torch.no_grad():
    for text in text_list:
        # tokenize text (one example per forward pass)
        inputs = tokenizer.encode(text, return_tensors="pt")
        # compute logits
        logits = model(inputs).logits
        # convert logits to label: take the arg-max along the class axis rather
        # than over the flattened tensor, so it does not rely on batch size 1
        predictions = torch.argmax(logits, dim=-1)

        print(text, "-", id2label[predictions.item()])


In [None]:
# LoRA adapter settings
peft_config = LoraConfig(
    task_type="SEQ_CLS",       # sequence classification
    target_modules=["q_lin"],  # modules that receive LoRA adapters (the query layer)
    r=4,                       # rank of the low-rank update matrices
    lora_alpha=32,             # scaling factor for the LoRA update (not a learning rate) — see PEFT LoraConfig docs
    lora_dropout=0.01,         # dropout probability used by the LoRA layers
)

In [None]:
# Wrap the base model with the adapters described by peft_config.
# NOTE(review): this rebinds `model`, so re-running the cell passes the already
# wrapped model back into get_peft_model — re-run from the model-loading cell.
model = get_peft_model(model, peft_config=peft_config)
# Report how many parameters are trainable.
model.print_trainable_parameters()

In [None]:
# hyperparameters
lr = 4e-3
batch_size = 4
num_epochs = 10

# define training arguments
trainings_args = TrainingArguments(
    # where checkpoints are written
    output_dir=f"{model_checkpoint}-lota-text-classification",
    # optimisation schedule
    num_train_epochs=num_epochs,
    learning_rate=lr,
    weight_decay=0.01,
    # same batch size for training and evaluation
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
    # evaluate and checkpoint once per epoch, then reload the best checkpoint
    # NOTE(review): newer transformers releases rename evaluation_strategy to
    # eval_strategy — confirm against the installed version.
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
)

In [None]:
# create trainer object
trainer = Trainer(
    # what to train and how
    model=model,
    args=trainings_args,
    # data: tokenized splits, padded per batch by the collator
    train_dataset=tokenized_datasets["train"],
    eval_dataset=tokenized_datasets["validation"],
    data_collator=data_collator,
    tokenizer=tokenizer,
    # metrics are computed with the compute_metrics() function
    compute_metrics=compute_metrics,
)

# train model
trainer.train()

In [None]:
# Run inference on whatever device the model's weights currently live on, so
# inputs and weights always agree regardless of where training placed the model.
model_device = next(model.parameters()).device
# Evaluation mode turns off dropout, making the predictions deterministic.
model.eval()

print("Trained model predictions")
print("-------------------------")
# Inference only: disable autograd so no computation graph is built or kept.
with torch.no_grad():
    for text in text_list:
        # tokenize text and move it next to the model weights
        inputs = tokenizer.encode(text, return_tensors="pt").to(model_device)
        # compute logits
        logits = model(inputs).logits
        # convert logits to label: arg-max along the class axis
        predictions = torch.argmax(logits, dim=-1)

        print(text, "-", id2label[predictions.item()])
