In [None]:
!uv pip install evaluate

[2mUsing Python 3.12.12 environment at: /usr[0m
[2K[2mResolved [1m33 packages[0m [2min 999ms[0m[0m
[2K[2mPrepared [1m1 package[0m [2min 42ms[0m[0m
[2K[2mInstalled [1m1 package[0m [2min 12ms[0m[0m
 [32m+[39m [1mevaluate[0m[2m==0.4.6[0m


In [None]:
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    TrainingArguments,
    Trainer,
)
import numpy as np
import evaluate
import torch

# -----------------------
# 1. Load the IMDB dataset
# -----------------------
dataset = load_dataset("imdb")

# NOTE(review): the "test" split is the Trainer's evaluation set, and the
# training arguments also select the best checkpoint on it
# (load_best_model_at_end), so the reported accuracy is not measured on data
# that was held out from model selection.
train_dataset, eval_dataset = dataset["train"], dataset["test"]

# -----------------------
# 2. Load tokenizer (the model itself is instantiated in step 4)
# -----------------------
model_checkpoint = "distilbert-base-uncased"
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)

# -----------------------
# 3. Tokenization
# -----------------------
def preprocess_function(examples):
    """Tokenize the ``"text"`` entry of ``examples`` with the notebook's tokenizer.

    Sequences are truncated and padded so each one is exactly 256 tokens long.

    Args:
        examples: Mapping with a ``"text"`` entry holding the text(s) to
            tokenize (a batch when used via ``Dataset.map(..., batched=True)``).

    Returns:
        The tokenizer's output for ``examples["text"]``.
    """
    texts = examples["text"]
    encoding_options = {
        "truncation": True,
        "padding": "max_length",
        "max_length": 256,
    }
    return tokenizer(texts, **encoding_options)

# Tokenize both splits in batches.
train_dataset, eval_dataset = [
    split.map(preprocess_function, batched=True)
    for split in (train_dataset, eval_dataset)
]

# Expose only the model inputs and the label, returned as torch tensors.
model_columns = ["input_ids", "attention_mask", "label"]
for split in (train_dataset, eval_dataset):
    split.set_format(type="torch", columns=model_columns)

In [None]:
# -----------------------
# 4. Load model (binary classification)
# -----------------------
num_labels = 2

# The classification head is not stored in the checkpoint and is randomly
# initialised on load, so seed torch's RNG first to make the starting weights
# repeatable across kernel restarts.
torch.manual_seed(42)
model = AutoModelForSequenceClassification.from_pretrained(
    model_checkpoint, num_labels=num_labels
)

# -----------------------
# 5. Freeze base model parameters
# -----------------------
# Turn off gradients for every parameter under `base_model` in a single call;
# parameters that live outside it are left untouched.
model.base_model.requires_grad_(False)

# -----------------------
# 6. Parameter counting utility
# -----------------------
def count_parameters(model):
    """Tally a model's parameter elements by whether they require gradients.

    Args:
        model: Object exposing ``parameters()`` whose items provide
            ``numel()`` and ``requires_grad``.

    Returns:
        dict with the keys ``"Total"``, ``"Trainable"`` and ``"Frozen"``
        (in that order) mapped to element counts.
    """
    n_all = 0
    n_trainable = 0
    for param in model.parameters():
        size = param.numel()
        n_all += size
        if param.requires_grad:
            n_trainable += size
    return {
        "Total": n_all,
        "Trainable": n_trainable,
        "Frozen": n_all - n_trainable,
    }

# Report the breakdown, one "<category>: <count>" line per entry.
param_count = count_parameters(model)
for category, n_params in param_count.items():
    print(f"{category}: {n_params}")

# -----------------------
# 7. Define metric
# -----------------------
accuracy_metric = evaluate.load("accuracy")

def compute_metrics(eval_pred):
    """Score predictions with the loaded accuracy metric.

    Args:
        eval_pred: Unpackable pair ``(logits, labels)``. The predicted class
            for each row is the arg-max of ``logits`` along axis 1.

    Returns:
        Whatever ``accuracy_metric.compute`` returns for those predictions
        and reference labels.
    """
    logits, labels = eval_pred
    return accuracy_metric.compute(
        predictions=np.argmax(logits, axis=1),
        references=labels,
    )

# -----------------------
# 8. Training arguments
# -----------------------
training_args = TrainingArguments(
    # Output and logging
    output_dir="./results",
    report_to="none",
    logging_steps=10,
    # Optimisation
    num_train_epochs=2,
    learning_rate=5e-5,
    weight_decay=0.01,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    # Evaluate and checkpoint once per epoch, then reload the checkpoint with
    # the best accuracy when training finishes.
    eval_strategy="epoch",  # older transformers releases spell this `evaluation_strategy`
    save_strategy="epoch",
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
)

# -----------------------
# 9. Trainer setup
# -----------------------
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    # `tokenizer=` is deprecated in `Trainer.__init__` in favour of
    # `processing_class` (recent transformers releases emit a FutureWarning
    # for the old keyword).
    processing_class=tokenizer,
    compute_metrics=compute_metrics,
)

# -----------------------
# 10. Train model
# -----------------------
# Left as the cell's bare final expression so the notebook displays the value
# returned by train().
trainer.train()

Map:   0%|          | 0/25000 [00:00<?, ? examples/s]

Map:   0%|          | 0/25000 [00:00<?, ? examples/s]

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Total: 66955010
Trainable: 592130
Frozen: 66362880


Downloading builder script: 0.00B [00:00, ?B/s]

  trainer = Trainer(


Epoch,Training Loss,Validation Loss,Accuracy
1,0.3546,0.384328,0.82996
2,0.4333,0.374064,0.83432


TrainOutput(global_step=3126, training_loss=0.42818721699851947, metrics={'train_runtime': 766.9715, 'train_samples_per_second': 65.191, 'train_steps_per_second': 4.076, 'total_flos': 3311684966400000.0, 'train_loss': 0.42818721699851947, 'epoch': 2.0})