### Pytorch Finetuning ###

In [3]:
from datasets import load_dataset
from transformers import (
    AutoModelForSequenceClassification,
    DistilBertTokenizer,
    TrainingArguments,
    Trainer,
)
import torch
from sklearn.metrics import accuracy_score
import numpy as np

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

tokenizer = DistilBertTokenizer.from_pretrained(
    "distilbert-base-uncased-finetuned-sst-2-english"
)
model = AutoModelForSequenceClassification.from_pretrained(
    "distilbert-base-uncased-finetuned-sst-2-english"
).to(device)

number_of_samples = 100

dataset = load_dataset("imdb")

train_dataset = dataset["train"].shuffle(seed=42).select(range(number_of_samples))
test_dataset = dataset["test"].shuffle(seed=1337).select(range(number_of_samples))

def tokenize_function(examples):
    return tokenizer(examples["text"], padding="max_length", truncation=True)

train_dataset = train_dataset.map(tokenize_function, batched=True)
test_dataset = test_dataset.map(tokenize_function, batched=True)

train_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])
test_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])

def compute_metrics(p):
    predictions, labels = p
    predictions = np.argmax(predictions, axis=1)
    return {"accuracy": accuracy_score(labels, predictions)}

training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=3,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=10,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=test_dataset,
    compute_metrics=compute_metrics,
)

trainer.train()

results = trainer.evaluate()

print(f"Evaluation Results: {results}")
print(f"Accuracy {results['eval_accuracy']}")
print()

Map: 100%|██████████| 3/3 [00:00<00:00, 17.93 examples/s]


Step,Training Loss


Evaluation Results: {'eval_loss': 0.37706610560417175, 'eval_accuracy': 0.6666666666666666, 'eval_runtime': 2.1035, 'eval_samples_per_second': 1.426, 'eval_steps_per_second': 0.475, 'epoch': 3.0}
