In [3]:
# Step 1: Load IMDb dataset and tokenizer
from datasets import load_dataset
from transformers import DistilBertTokenizerFast

# Load IMDb dataset
dataset = load_dataset("imdb")

# Initialize tokenizer
tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")

In [4]:
# Step 2: Tokenize the dataset
def tokenize(batch):
    return tokenizer(batch["text"], truncation=True, padding="max_length", max_length=256)

tokenized_dataset = dataset.map(tokenize, batched=True)
tokenized_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])

Map: 100%|██████████| 25000/25000 [00:06<00:00, 3964.65 examples/s]


In [13]:
# Step 3: Create a balanced evaluation set
import random
import numpy as np

test_dataset = tokenized_dataset["test"]
labels = test_dataset["label"]

# Get indices for each class
class0_indices = [i for i, label in enumerate(labels) if label == 0]
class1_indices = [i for i, label in enumerate(labels) if label == 1]

# Shuffle and select 500 samples from each class
random.seed(42)
random.shuffle(class0_indices)
random.shuffle(class1_indices)

selected_indices = class0_indices[:500] + class1_indices[:500]
random.shuffle(selected_indices)

# Subset the evaluation set
balanced_eval_dataset = test_dataset.select(selected_indices)

# Diagnostic print to confirm balance
print("Balanced eval label distribution:", np.bincount(balanced_eval_dataset["label"].numpy()))

Balanced eval label distribution: [500 500]


In [15]:
# Step 4: Load model and define compute_metrics
from transformers import DistilBertForSequenceClassification, Trainer, TrainingArguments
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
import torch

finetune_model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=2)

def compute_metrics(p):
    preds = torch.argmax(torch.tensor(p.predictions), dim=1)
    labels = p.label_ids
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='binary', zero_division=0)
    acc = accuracy_score(labels, preds)
    return {"accuracy": acc, "f1": f1, "precision": precision, "recall": recall}

# Step 5: Define training arguments
training_args = TrainingArguments(
    output_dir="./results_finetuned",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=100
)

# Step 6: Trainer setup
trainer = Trainer(
    model=finetune_model,
    args=training_args,
    train_dataset=tokenized_dataset["train"].shuffle(seed=42).select(range(5000)),  # small sample for faster training
    eval_dataset=balanced_eval_dataset,
    compute_metrics=compute_metrics,
)

# Step 7: Train and evaluate
trainer.train()
results_finetuned = trainer.evaluate()
print("Final Evaluation Results:", results_finetuned)


Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch,Training Loss,Validation Loss,Accuracy,F1,Precision,Recall
1,0.3352,0.270635,0.885,0.888023,0.865275,0.912
2,0.1669,0.344068,0.894,0.893788,0.895582,0.892
3,0.0656,0.450358,0.892,0.891566,0.895161,0.888


Final Evaluation Results: {'eval_loss': 0.45035770535469055, 'eval_accuracy': 0.892, 'eval_f1': 0.891566265060241, 'eval_precision': 0.8951612903225806, 'eval_recall': 0.888, 'eval_runtime': 200.1546, 'eval_samples_per_second': 4.996, 'eval_steps_per_second': 0.315, 'epoch': 3.0}


In [17]:
# Step 1: Imports
from transformers import DistilBertModel, DistilBertTokenizerFast
from datasets import load_dataset
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import DataLoader
import numpy as np
from tqdm import tqdm

# Step 2: Load IMDb dataset and tokenizer
dataset = load_dataset("imdb")
tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")

# Step 3: Tokenize (same as before)
def tokenize(batch):
    return tokenizer(batch["text"], truncation=True, padding="max_length", max_length=256)

tokenized_dataset = dataset.map(tokenize, batched=True)
tokenized_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])

# Step 4: Load frozen DistilBERT model (no classification head)
bert = DistilBertModel.from_pretrained("distilbert-base-uncased")
bert.eval()  # Freeze weights

# Step 5: Extract embeddings (e.g., mean pooled CLS token)
def extract_embeddings(dataset_split):
    dataloader = DataLoader(dataset_split, batch_size=16)
    embeddings, labels = [], []

    with torch.no_grad():
        for batch in tqdm(dataloader):
            input_ids = batch["input_ids"]
            attention_mask = batch["attention_mask"]
            outputs = bert(input_ids=input_ids, attention_mask=attention_mask)
            # Mean pooling
            mean_embeddings = outputs.last_hidden_state.mean(dim=1)
            embeddings.append(mean_embeddings.numpy())
            labels.append(batch["label"].numpy())

    return np.concatenate(embeddings), np.concatenate(labels)

# Step 6: Use smaller subset (to fit faster)
train_subset = tokenized_dataset["train"].shuffle(seed=42).select(range(2000))
test_subset = tokenized_dataset["test"].shuffle(seed=42).select(range(1000))

X_train, y_train = extract_embeddings(train_subset)
X_test, y_test = extract_embeddings(test_subset)

# Step 7: Train classifier (e.g., Logistic Regression)
clf = LogisticRegression(max_iter=1000)
clf.fit(X_train, y_train)

# Step 8: Evaluate
y_pred = clf.predict(X_test)
print(classification_report(y_test, y_pred, digits=4))

Map: 100%|██████████| 50000/50000 [00:12<00:00, 3995.49 examples/s]
100%|██████████| 125/125 [05:27<00:00,  2.62s/it]
100%|██████████| 63/63 [02:47<00:00,  2.66s/it]


              precision    recall  f1-score   support

           0     0.8776    0.8398    0.8583       512
           1     0.8392    0.8770    0.8577       488

    accuracy                         0.8580      1000
   macro avg     0.8584    0.8584    0.8580      1000
weighted avg     0.8588    0.8580    0.8580      1000

