# DistilBERT Base Code Training

In [1]:

# Install PyTorch 2.8.0 with CUDA 12.9 support
%pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu129

# Install HuggingFace Transformers + Datasets for DistilBERT training
%pip install transformers datasets accelerate
%pip install tqdm


Looking in indexes: https://download.pytorch.org/whl/cu129
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.


In [2]:
# Code Cell 2

import torch
from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification
from datasets import load_dataset
from torch.utils.data import DataLoader
from torch.optim import AdamW
from tqdm import tqdm

print("PyTorch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
print("CUDA device:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "None")


  from .autonotebook import tqdm as notebook_tqdm


PyTorch version: 2.8.0+cu129
CUDA available: True
CUDA device: NVIDIA GeForce RTX 3060 Ti


In [3]:
from datasets import load_dataset
from transformers import DistilBertTokenizerFast

dataset = load_dataset("ag_news")

print("Dataset splits:", dataset)

tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")

MAX_LENGTH = 128 

def tokenize_batch(batch):
    return tokenizer(
        batch["text"],
        padding="max_length",
        truncation=True,
        max_length=MAX_LENGTH,
    )

tokenized_dataset = dataset.map(tokenize_batch, batched=True)

tokenized_dataset.set_format(
    type="torch",
    columns=["input_ids", "attention_mask", "label"]
)

train_dataset = tokenized_dataset["train"]
test_dataset = tokenized_dataset["test"]

print("Sample tokenized batch shape:")
print("Train dataset example input_ids shape:", train_dataset[0]["input_ids"].shape)
print("Train dataset example attention_mask shape:", train_dataset[0]["attention_mask"].shape)
print("Label example:", train_dataset[0]["label"])


Dataset splits: DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 120000
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 7600
    })
})
Sample tokenized batch shape:
Train dataset example input_ids shape: torch.Size([128])
Train dataset example attention_mask shape: torch.Size([128])
Label example: tensor(2)


In [4]:
from torch.utils.data import DataLoader

BATCH_SIZE = 128
NUM_EPOCHS = 30
SUBSET_TRAIN_SIZE = 10_000

print(f"Batch size: {BATCH_SIZE}")
print(f"Planned epochs: {NUM_EPOCHS}")
print(f"Using training subset size: {SUBSET_TRAIN_SIZE}")

train_subset = train_dataset.select(range(SUBSET_TRAIN_SIZE))

NUM_WORKERS = 8

train_loader = DataLoader(
    train_subset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    pin_memory=True
)

test_loader = DataLoader(
    test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=NUM_WORKERS,
    pin_memory=True
)
print("Train batches (subset):", len(train_loader))
print("Test batches:", len(test_loader))
print(f"Dataloaders initialized with num_workers={NUM_WORKERS}, pin_memory=True")

metrics = {
    "epoch": [],
    "train_samples_per_sec": [],
    "epoch_time_sec": [],
    "train_loss": [],
    "val_accuracy": []
}
print("Metric tracking dict initialized:", list(metrics.keys()))


Batch size: 128
Planned epochs: 30
Using training subset size: 10000
Train batches (subset): 79
Test batches: 60
Dataloaders initialized with num_workers=8, pin_memory=True
Metric tracking dict initialized: ['epoch', 'train_samples_per_sec', 'epoch_time_sec', 'train_loss', 'val_accuracy']


In [5]:
# Code Cell 5

from transformers import DistilBertForSequenceClassification

# 1. Device setup
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Using device:", device)

# 2. Load DistilBERT model for 4-class classification (AG News)
num_labels = 4
model = DistilBertForSequenceClassification.from_pretrained(
    "distilbert-base-uncased",
    num_labels=num_labels
)
model.to(device)
print("Model loaded and moved to device.")

# 3. Define optimizer (AdamW is standard for transformers)
learning_rate = 5e-5
optimizer = AdamW(model.parameters(), lr=learning_rate)

print("Optimizer initialized (AdamW) with learning rate:", learning_rate)


Using device: cuda


Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Model loaded and moved to device.
Optimizer initialized (AdamW) with learning rate: 5e-05


In [6]:
# Code Cell 6

import time
import torch.nn.functional as F
model.train()

def evaluate(model, data_loader, device):
    """Simple evaluation loop to compute validation accuracy."""
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for batch in data_loader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["label"].to(device)

            outputs = model(
                input_ids=input_ids,
                attention_mask=attention_mask
            )
            logits = outputs.logits
            preds = torch.argmax(logits, dim=-1)

            correct += (preds == labels).sum().item()
            total += labels.size(0)

    model.train() 
    return correct / total if total > 0 else 0.0


total_train_samples = len(train_dataset)

print("Starting training...")
for epoch in range(1, NUM_EPOCHS + 1):
    epoch_start_time = time.time()
    running_loss = 0.0
    num_batches = 0

    for batch in tqdm(train_loader, desc=f"Epoch {epoch}/{NUM_EPOCHS}"):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["label"].to(device)

        optimizer.zero_grad()

        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels
        )
        loss = outputs.loss
        loss.backward()
        optimizer.step()

        running_loss += loss.item()
        num_batches += 1

    epoch_end_time = time.time()
    epoch_time = epoch_end_time - epoch_start_time
    samples_per_sec = total_train_samples / epoch_time if epoch_time > 0 else 0.0
    avg_train_loss = running_loss / num_batches if num_batches > 0 else 0.0

    val_acc = evaluate(model, test_loader, device)

    metrics["epoch"].append(epoch)
    metrics["train_samples_per_sec"].append(samples_per_sec)
    metrics["epoch_time_sec"].append(epoch_time)
    metrics["train_loss"].append(avg_train_loss)
    metrics["val_accuracy"].append(val_acc)

    print(
        f"Epoch {epoch}/{NUM_EPOCHS} | "
        f"Time: {epoch_time:.2f}s | "
        f"Throughput: {samples_per_sec:.2f} samples/s | "
        f"Train Loss: {avg_train_loss:.4f} | "
        f"Val Acc: {val_acc:.4f}"
    )

print("Training complete.")


Starting training...


Epoch 1/30: 100%|██████████| 79/79 [00:44<00:00,  1.77it/s]


Epoch 1/30 | Time: 44.70s | Throughput: 2684.67 samples/s | Train Loss: 0.4901 | Val Acc: 0.9041


Epoch 2/30: 100%|██████████| 79/79 [00:44<00:00,  1.79it/s]


Epoch 2/30 | Time: 44.15s | Throughput: 2718.14 samples/s | Train Loss: 0.2073 | Val Acc: 0.8999


Epoch 3/30: 100%|██████████| 79/79 [00:44<00:00,  1.78it/s]


Epoch 3/30 | Time: 44.34s | Throughput: 2706.43 samples/s | Train Loss: 0.1277 | Val Acc: 0.9067


Epoch 4/30: 100%|██████████| 79/79 [00:44<00:00,  1.78it/s]


Epoch 4/30 | Time: 44.39s | Throughput: 2703.52 samples/s | Train Loss: 0.0863 | Val Acc: 0.9057


Epoch 5/30: 100%|██████████| 79/79 [00:44<00:00,  1.78it/s]


Epoch 5/30 | Time: 44.48s | Throughput: 2698.11 samples/s | Train Loss: 0.0611 | Val Acc: 0.9124


Epoch 6/30: 100%|██████████| 79/79 [00:44<00:00,  1.78it/s]


Epoch 6/30 | Time: 44.49s | Throughput: 2697.44 samples/s | Train Loss: 0.0381 | Val Acc: 0.9061


Epoch 7/30: 100%|██████████| 79/79 [00:44<00:00,  1.77it/s]


Epoch 7/30 | Time: 44.53s | Throughput: 2694.53 samples/s | Train Loss: 0.0336 | Val Acc: 0.9064


Epoch 8/30: 100%|██████████| 79/79 [00:44<00:00,  1.77it/s]


Epoch 8/30 | Time: 44.54s | Throughput: 2694.32 samples/s | Train Loss: 0.0283 | Val Acc: 0.9068


Epoch 9/30: 100%|██████████| 79/79 [00:44<00:00,  1.78it/s]


Epoch 9/30 | Time: 44.49s | Throughput: 2697.30 samples/s | Train Loss: 0.0196 | Val Acc: 0.9062


Epoch 10/30: 100%|██████████| 79/79 [00:44<00:00,  1.78it/s]


Epoch 10/30 | Time: 44.44s | Throughput: 2700.15 samples/s | Train Loss: 0.0138 | Val Acc: 0.9116


Epoch 11/30: 100%|██████████| 79/79 [00:44<00:00,  1.78it/s]


Epoch 11/30 | Time: 44.48s | Throughput: 2697.85 samples/s | Train Loss: 0.0159 | Val Acc: 0.9093


Epoch 12/30: 100%|██████████| 79/79 [00:44<00:00,  1.78it/s]


Epoch 12/30 | Time: 44.50s | Throughput: 2696.88 samples/s | Train Loss: 0.0134 | Val Acc: 0.9025


Epoch 13/30: 100%|██████████| 79/79 [00:44<00:00,  1.78it/s]


Epoch 13/30 | Time: 44.43s | Throughput: 2700.95 samples/s | Train Loss: 0.0284 | Val Acc: 0.8989


Epoch 14/30: 100%|██████████| 79/79 [00:45<00:00,  1.73it/s]


Epoch 14/30 | Time: 45.76s | Throughput: 2622.54 samples/s | Train Loss: 0.0203 | Val Acc: 0.9050


Epoch 15/30: 100%|██████████| 79/79 [00:46<00:00,  1.72it/s]


Epoch 15/30 | Time: 46.02s | Throughput: 2607.50 samples/s | Train Loss: 0.0126 | Val Acc: 0.9053


Epoch 16/30: 100%|██████████| 79/79 [00:45<00:00,  1.75it/s]


Epoch 16/30 | Time: 45.11s | Throughput: 2660.16 samples/s | Train Loss: 0.0110 | Val Acc: 0.9079


Epoch 17/30: 100%|██████████| 79/79 [00:45<00:00,  1.73it/s]


Epoch 17/30 | Time: 45.61s | Throughput: 2631.03 samples/s | Train Loss: 0.0069 | Val Acc: 0.9012


Epoch 18/30: 100%|██████████| 79/79 [00:45<00:00,  1.75it/s]


Epoch 18/30 | Time: 45.11s | Throughput: 2660.39 samples/s | Train Loss: 0.0071 | Val Acc: 0.9009


Epoch 19/30: 100%|██████████| 79/79 [00:45<00:00,  1.74it/s]


Epoch 19/30 | Time: 45.48s | Throughput: 2638.24 samples/s | Train Loss: 0.0070 | Val Acc: 0.9038


Epoch 20/30: 100%|██████████| 79/79 [00:46<00:00,  1.68it/s]


Epoch 20/30 | Time: 46.98s | Throughput: 2554.34 samples/s | Train Loss: 0.0054 | Val Acc: 0.9062


Epoch 21/30: 100%|██████████| 79/79 [00:47<00:00,  1.65it/s]


Epoch 21/30 | Time: 47.90s | Throughput: 2505.15 samples/s | Train Loss: 0.0041 | Val Acc: 0.8999


Epoch 22/30: 100%|██████████| 79/79 [00:47<00:00,  1.67it/s]


Epoch 22/30 | Time: 47.40s | Throughput: 2531.65 samples/s | Train Loss: 0.0097 | Val Acc: 0.8916


Epoch 23/30: 100%|██████████| 79/79 [00:46<00:00,  1.68it/s]


Epoch 23/30 | Time: 46.96s | Throughput: 2555.29 samples/s | Train Loss: 0.0147 | Val Acc: 0.9028


Epoch 24/30: 100%|██████████| 79/79 [00:47<00:00,  1.68it/s]


Epoch 24/30 | Time: 47.09s | Throughput: 2548.51 samples/s | Train Loss: 0.0065 | Val Acc: 0.9020


Epoch 25/30: 100%|██████████| 79/79 [00:47<00:00,  1.68it/s]


Epoch 25/30 | Time: 47.12s | Throughput: 2546.79 samples/s | Train Loss: 0.0034 | Val Acc: 0.9028


Epoch 26/30: 100%|██████████| 79/79 [00:46<00:00,  1.68it/s]


Epoch 26/30 | Time: 46.99s | Throughput: 2553.62 samples/s | Train Loss: 0.0039 | Val Acc: 0.9051


Epoch 27/30: 100%|██████████| 79/79 [00:47<00:00,  1.67it/s]


Epoch 27/30 | Time: 47.18s | Throughput: 2543.63 samples/s | Train Loss: 0.0037 | Val Acc: 0.9057


Epoch 28/30: 100%|██████████| 79/79 [00:46<00:00,  1.68it/s]


Epoch 28/30 | Time: 46.91s | Throughput: 2558.23 samples/s | Train Loss: 0.0019 | Val Acc: 0.9061


Epoch 29/30: 100%|██████████| 79/79 [00:47<00:00,  1.68it/s]


Epoch 29/30 | Time: 47.16s | Throughput: 2544.31 samples/s | Train Loss: 0.0014 | Val Acc: 0.9062


Epoch 30/30: 100%|██████████| 79/79 [00:47<00:00,  1.68it/s]


Epoch 30/30 | Time: 47.11s | Throughput: 2547.49 samples/s | Train Loss: 0.0048 | Val Acc: 0.9045
Training complete.
