In [3]:
# Step 1: Loading imports 
from transformers import DistilBertModel, DistilBertTokenizerFast
from transformers import DistilBertForSequenceClassification, Trainer, TrainingArguments
from datasets import load_dataset
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report
from sklearn.model_selection import train_test_split
import torch
from torch.utils.data import DataLoader
import numpy as np
from tqdm import tqdm
import random
from sklearn.metrics import accuracy_score, precision_recall_fscore_support

In [7]:
# Step 2: Loading IMDb dataset
dataset = load_dataset("imdb")

# Initializing tokenizer
tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")

In [9]:
# Step 3: Tokenize the dataset
def tokenize(batch):
    return tokenizer(batch["text"], truncation=True, padding="max_length", max_length=256)

tokenized_dataset = dataset.map(tokenize, batched=True)
tokenized_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])

In [11]:
# Step 4: Creating a balanced evaluation set
test_dataset = tokenized_dataset["test"]
labels = test_dataset["label"]

# Getting indices for each class
class0_indices = [i for i, label in enumerate(labels) if label == 0]
class1_indices = [i for i, label in enumerate(labels) if label == 1]

# Shuffling and selecting 500 samples from each class
random.seed(42)
random.shuffle(class0_indices)
random.shuffle(class1_indices)

selected_indices = class0_indices[:500] + class1_indices[:500]
random.shuffle(selected_indices)

# Subset the evaluation set
balanced_eval_dataset = test_dataset.select(selected_indices)

# Diagnostic print to confirm balance
print("Balanced eval label distribution:", np.bincount(balanced_eval_dataset["label"].numpy()))

Balanced eval label distribution: [500 500]


In [13]:
# Step 5:definining compute_metrics
finetune_model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=2)

def compute_metrics(p):
    preds = torch.argmax(torch.tensor(p.predictions), dim=1)
    labels = p.label_ids
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='binary', zero_division=0)
    acc = accuracy_score(labels, preds)
    return {"accuracy": acc, "f1": f1, "precision": precision, "recall": recall}

# Step 6: Define training arguments
training_args = TrainingArguments(
    output_dir="./results_finetuned",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    weight_decay=0.01,
    logging_dir="./logs",
    logging_steps=100
)

# Step 7: Trainer setup
trainer = Trainer(
    model=finetune_model,
    args=training_args,
    train_dataset=tokenized_dataset["train"].shuffle(seed=42).select(range(5000)),  # small sample for faster training
    eval_dataset=balanced_eval_dataset,
    compute_metrics=compute_metrics,
)

# Step 8: Training and evaluation
trainer.train()
results_finetuned = trainer.evaluate()
print("Final Evaluation Results:", results_finetuned)


Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch,Training Loss,Validation Loss,Accuracy,F1,Precision,Recall
1,0.3294,0.268968,0.889,0.88999,0.882122,0.898
2,0.1775,0.343982,0.899,0.899702,0.893491,0.906
3,0.0925,0.478502,0.898,0.896552,0.909465,0.884


Final Evaluation Results: {'eval_loss': 0.4785021245479584, 'eval_accuracy': 0.898, 'eval_f1': 0.896551724137931, 'eval_precision': 0.9094650205761317, 'eval_recall': 0.884, 'eval_runtime': 192.7773, 'eval_samples_per_second': 5.187, 'eval_steps_per_second': 0.327, 'epoch': 3.0}


In [15]:
# Step 9: Loading IMDb dataset and tokenizer
dataset = load_dataset("imdb")
tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")

# Step 10: Tokenization (same as before)
def tokenize(batch):
    return tokenizer(batch["text"], truncation=True, padding="max_length", max_length=256)

tokenized_dataset = dataset.map(tokenize, batched=True)
tokenized_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])

# Step 11: Loading frozen DistilBERT model (no classification head)
bert = DistilBertModel.from_pretrained("distilbert-base-uncased")
bert.eval()  # Freeze weights

# Step 12: Extracting embeddings (e.g., mean pooled CLS token)
def extract_embeddings(dataset_split):
    dataloader = DataLoader(dataset_split, batch_size=16)
    embeddings, labels = [], []

    with torch.no_grad():
        for batch in tqdm(dataloader):
            input_ids = batch["input_ids"]
            attention_mask = batch["attention_mask"]
            outputs = bert(input_ids=input_ids, attention_mask=attention_mask)
            # Mean pooling
            mean_embeddings = outputs.last_hidden_state.mean(dim=1)
            embeddings.append(mean_embeddings.numpy())
            labels.append(batch["label"].numpy())

    return np.concatenate(embeddings), np.concatenate(labels)

# Step 13: Use smaller subset (to fit faster)
train_subset = tokenized_dataset["train"].shuffle(seed=42).select(range(5000))
test_subset = tokenized_dataset["test"].shuffle(seed=42).select(range(1000))

X_train, y_train = extract_embeddings(train_subset)
X_test, y_test = extract_embeddings(test_subset)

# Step 14: Train classifier (e.g., Logistic Regression)
clf = LogisticRegression(max_iter=1000)
clf.fit(X_train, y_train)

# Step 15: Evaluate
y_pred = clf.predict(X_test)
print(classification_report(y_test, y_pred, digits=4))

100%|██████████| 313/313 [13:15<00:00,  2.54s/it]
100%|██████████| 63/63 [02:42<00:00,  2.59s/it]


              precision    recall  f1-score   support

           0     0.8635    0.8398    0.8515       512
           1     0.8367    0.8607    0.8485       488

    accuracy                         0.8500      1000
   macro avg     0.8501    0.8502    0.8500      1000
weighted avg     0.8504    0.8500    0.8500      1000

