In [None]:
import torch
from datasets import load_dataset
from transformers import (
    DistilBertTokenizerFast,
    DistilBertForSequenceClassification,
    TrainingArguments,
    Trainer
)


In [None]:
# Hugging Face Hub identifiers shared by the cells below.
dataset_id = "ag_news"                # passed to `load_dataset`
model_id = "distilbert-base-uncased"  # checkpoint for both tokenizer and model


In [None]:
# Fetch every split of the corpus named by `dataset_id`.
dataset = load_dataset(dataset_id)

# Training uses the "train" split as-is.
train_dataset = dataset["train"]

# The "test" split is cut into two shards: shard 0 is used for validation
# and shard 1 is kept as a held-out test set.
val_dataset, test_dataset = (
    dataset["test"].shard(num_shards=2, index=shard_index)
    for shard_index in (0, 1)
)


In [None]:
# Load the DistilBERT fast tokenizer for the checkpoint named by `model_id`,
# so tokenization uses the same checkpoint name as the model loaded below.
tokenizer = DistilBertTokenizerFast.from_pretrained(model_id)


In [None]:
def preprocess_function(examples):
    """Tokenize the ``"text"`` entry of ``examples`` with the notebook tokenizer.

    The tokenizer is asked to truncate inputs and to pad them to
    ``max_length`` (128), so every encoded example has the same length.

    Args:
        examples: A mapping with a ``"text"`` key holding the raw text(s)
            to encode.

    Returns:
        Whatever the module-level ``tokenizer`` returns for those texts.
    """
    encode_options = {
        "padding": "max_length",
        "truncation": True,
        "max_length": 128,
    }
    raw_texts = examples["text"]
    encoded = tokenizer(raw_texts, **encode_options)
    return encoded


In [None]:
# Size the classification head from the dataset's own label set instead of a
# hard-coded class count, and register the id <-> name mappings so the saved
# config reports readable class names.
# Assumes the "label" column is a `datasets.ClassLabel` feature (exposes
# `.names`); if it is not, this fails loudly with an AttributeError.
label_names = train_dataset.features["label"].names

model = DistilBertForSequenceClassification.from_pretrained(
    model_id,
    num_labels=len(label_names),
    id2label=dict(enumerate(label_names)),
    label2id={name: idx for idx, name in enumerate(label_names)},
)


In [None]:
# Fine-tuning configuration handed to the Trainer.
training_args = TrainingArguments(
    # Output location and evaluation cadence (once per epoch).
    output_dir="./results",
    evaluation_strategy="epoch",
    # Training length and per-device batch sizes.
    num_train_epochs=3,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    # Optimiser settings.
    learning_rate=2e-5,
    weight_decay=0.01,
)


In [None]:
# The splits created earlier still hold raw text; the model needs token ids
# and attention masks. Apply `preprocess_function` to each split here
# (batched, so the tokenizer receives lists of texts) before building the
# Trainer. Previously the tokenization function was defined but never applied,
# so the Trainer was given untokenized data.
tokenized_train_dataset = train_dataset.map(preprocess_function, batched=True)
tokenized_val_dataset = val_dataset.map(preprocess_function, batched=True)

# Wire the model, the training configuration and the tokenized splits together.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train_dataset,
    eval_dataset=tokenized_val_dataset,
)


In [None]:
# Run fine-tuning with the settings in `training_args`; as the last expression
# in the cell, the call's return value is displayed by the notebook.
trainer.train()


In [None]:
# Evaluate the fine-tuned model. Called without arguments, so it uses the
# evaluation dataset that was given to the Trainer (the validation shard).
# NOTE(review): `test_dataset` is not used in any visible cell — confirm
# whether a final evaluation on the held-out shard was intended.
trainer.evaluate()


In [None]:
# Persist the fine-tuned model under ./saved_model.
# NOTE(review): the Trainer was constructed without a tokenizer, so presumably
# only the model files are written here — confirm whether the tokenizer should
# be saved alongside for later inference.
trainer.save_model("./saved_model")
