# DistilBERT Baseline Training on AG News

In [None]:

# Install PyTorch 2.8.0 with CUDA 12.9 support
%pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu129

# Install HuggingFace Transformers + Datasets for DistilBERT training
%pip install transformers datasets accelerate
%pip install tqdm


Looking in indexes: https://download.pytorch.org/whl/cu129
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.
Note: you may need to restart the kernel to use updated packages.


In [None]:
# Code Cell 2

import torch
from transformers import DistilBertTokenizerFast, DistilBertForSequenceClassification
from datasets import load_dataset
from torch.utils.data import DataLoader
from torch.optim import AdamW
# tqdm.auto renders a widget-based bar under Jupyter and falls back to a text bar elsewhere.
from tqdm.auto import tqdm

# Seed torch's RNGs (CPU and all CUDA devices) so the randomly initialised
# classifier head, dropout masks and DataLoader shuffling are repeatable from a
# fresh kernel. Some CUDA kernels remain nondeterministic, so runs may still
# differ slightly.
SEED = 42
torch.manual_seed(SEED)

print("PyTorch version:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
print("CUDA device:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "None")


  from .autonotebook import tqdm as notebook_tqdm


PyTorch version: 2.8.0+cu129
CUDA available: True
CUDA device: NVIDIA GeForce RTX 5060 Ti


In [None]:
from datasets import load_dataset
from transformers import DistilBertTokenizerFast

# Every example is padded or truncated to exactly this many tokens.
MAX_LENGTH = 128

dataset = load_dataset("ag_news")
print("Dataset splits:", dataset)

tokenizer = DistilBertTokenizerFast.from_pretrained("distilbert-base-uncased")


def tokenize_batch(batch):
    """Tokenize the ``text`` column of a batch into fixed-length model inputs.

    Used with ``Dataset.map(..., batched=True)``; returns the tokenizer output
    (``input_ids``, ``attention_mask``, ...) padded/truncated to MAX_LENGTH.
    """
    tokenizer_options = dict(padding="max_length", truncation=True, max_length=MAX_LENGTH)
    return tokenizer(batch["text"], **tokenizer_options)


tokenized_dataset = dataset.map(tokenize_batch, batched=True)

# Expose only the tensors the model consumes; the raw text column stays hidden.
tokenized_dataset.set_format(type="torch", columns=["input_ids", "attention_mask", "label"])

train_dataset, test_dataset = tokenized_dataset["train"], tokenized_dataset["test"]

# Sanity-check a single tokenized example.
first_example = train_dataset[0]
print("Sample tokenized batch shape:")
print("Train dataset example input_ids shape:", first_example["input_ids"].shape)
print("Train dataset example attention_mask shape:", first_example["attention_mask"].shape)
print("Label example:", first_example["label"])


Dataset splits: DatasetDict({
    train: Dataset({
        features: ['text', 'label'],
        num_rows: 120000
    })
    test: Dataset({
        features: ['text', 'label'],
        num_rows: 7600
    })
})


Map: 100%|██████████| 7600/7600 [00:00<00:00, 15664.64 examples/s]

Sample tokenized batch shape:
Train dataset example input_ids shape: torch.Size([128])
Train dataset example attention_mask shape: torch.Size([128])
Label example: tensor(2)





In [None]:
from torch.utils.data import DataLoader

BATCH_SIZE = 128
NUM_EPOCHS = 30
# Number of training rows to use (the first N rows in stored order).
# Set to None to train on the full training split.
# NOTE(review): if the split were ordered by label this prefix would be skewed;
# consider shuffling with a fixed seed before selecting.
SUBSET_TRAIN_SIZE = 10_000
NUM_WORKERS = 8
# Pinned host memory only speeds up host->GPU copies; skip it on CPU-only machines.
PIN_MEMORY = torch.cuda.is_available()

print(f"Batch size: {BATCH_SIZE}")
print(f"Planned epochs: {NUM_EPOCHS}")

if SUBSET_TRAIN_SIZE is None:
    train_subset = train_dataset
else:
    # Clamp so a subset size larger than the split does not raise an IndexError.
    train_subset = train_dataset.select(range(min(SUBSET_TRAIN_SIZE, len(train_dataset))))

print(f"Using training subset size: {len(train_subset)}")

loader_kwargs = dict(
    batch_size=BATCH_SIZE,
    num_workers=NUM_WORKERS,
    pin_memory=PIN_MEMORY,
    # Keep worker processes alive across all epochs/evaluations instead of
    # re-spawning NUM_WORKERS processes on every pass (invalid when 0 workers).
    persistent_workers=NUM_WORKERS > 0,
)

train_loader = DataLoader(train_subset, shuffle=True, **loader_kwargs)
test_loader = DataLoader(test_dataset, shuffle=False, **loader_kwargs)

print("Train batches (subset):", len(train_loader))
print("Test batches:", len(test_loader))
print(f"Dataloaders initialized with num_workers={NUM_WORKERS}, pin_memory={PIN_MEMORY}")

# Per-epoch metrics appended by the training loop.
metrics = {
    "epoch": [],
    "train_samples_per_sec": [],
    "epoch_time_sec": [],
    "train_loss": [],
    "val_accuracy": []
}
print("Metric tracking dict initialized:", list(metrics.keys()))


Batch size: 128
Planned epochs: 30
Using training subset size: 10000
Train batches (subset): 79
Test batches: 60
Dataloaders initialized with num_workers=8, pin_memory=True
Metric tracking dict initialized: ['epoch', 'train_samples_per_sec', 'epoch_time_sec', 'train_loss', 'val_accuracy']


In [None]:
# Code Cell 5

from transformers import DistilBertForSequenceClassification

# Train on the GPU when one is visible; otherwise fall back to the CPU.
cuda_visible = torch.cuda.is_available()
device = torch.device("cuda") if cuda_visible else torch.device("cpu")
print("Using device:", device)

# Pretrained DistilBERT encoder with a freshly initialised head that has one
# output per AG News class.
num_labels = 4
model = DistilBertForSequenceClassification.from_pretrained("distilbert-base-uncased", num_labels=num_labels)
model.to(device)
print("Model loaded and moved to device.")

# AdamW over every model parameter, using a single fixed learning rate.
learning_rate = 5e-5
optimizer = AdamW(model.parameters(), lr=learning_rate)
print("Optimizer initialized (AdamW) with learning rate:", learning_rate)


Using device: cuda


Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Model loaded and moved to device.
Optimizer initialized (AdamW) with learning rate: 5e-05


In [None]:
# Code Cell 6
# Defines the evaluation helper and runs the epoch loop.

import time
import torch.nn.functional as F  # NOTE(review): F is not referenced anywhere in the visible notebook
# Switch the model to training mode before the epoch loop in this cell.
model.train()

def evaluate(model, data_loader, device):
    """Compute classification accuracy of ``model`` over ``data_loader``.

    Each batch must be a mapping with ``"input_ids"``, ``"attention_mask"`` and
    ``"label"`` tensors. ``model(input_ids=..., attention_mask=...)`` must return
    an object exposing ``.logits``, as Hugging Face sequence-classification
    models do. Predictions are the arg-max over the last logits dimension.

    Args:
        model: Classifier to evaluate. It is put in eval mode for the pass and
            then restored to whichever mode (train/eval) it was in beforehand.
        data_loader: Iterable of batches, e.g. a ``DataLoader``.
        device: Device the batch tensors are moved to before the forward pass.

    Returns:
        Fraction of correctly classified examples in ``[0, 1]``, or ``0.0``
        when the loader yields no examples.
    """
    was_training = model.training
    model.eval()
    correct = 0
    total = 0

    try:
        with torch.no_grad():
            for batch in data_loader:
                # non_blocking lets pinned-memory batches copy to the GPU asynchronously.
                input_ids = batch["input_ids"].to(device, non_blocking=True)
                attention_mask = batch["attention_mask"].to(device, non_blocking=True)
                labels = batch["label"].to(device, non_blocking=True)

                logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
                preds = torch.argmax(logits, dim=-1)

                correct += (preds == labels).sum().item()
                total += labels.size(0)
    finally:
        # Restore the caller's mode instead of unconditionally forcing train
        # mode, and do it even if the forward pass raises.
        model.train(was_training)

    return correct / total if total > 0 else 0.0


# Size of the data the training loader actually iterates. train_loader wraps
# the training subset, not the full train split, so len(train_dataset) would
# overstate throughput.
total_train_samples = len(train_loader.dataset)

print("Starting training...")
for epoch in range(1, NUM_EPOCHS + 1):
    # Re-assert training mode every epoch so the loop does not depend on what
    # mode evaluate() leaves the model in.
    model.train()
    # perf_counter is monotonic and high-resolution, which suits interval timing.
    epoch_start_time = time.perf_counter()
    running_loss = 0.0
    num_batches = 0
    samples_seen = 0

    for batch in tqdm(train_loader, desc=f"Epoch {epoch}/{NUM_EPOCHS}"):
        input_ids = batch["input_ids"].to(device, non_blocking=True)
        attention_mask = batch["attention_mask"].to(device, non_blocking=True)
        labels = batch["label"].to(device, non_blocking=True)

        optimizer.zero_grad()

        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels
        )
        loss = outputs.loss
        loss.backward()
        optimizer.step()

        # loss.item() copies the value to the host, which waits for the GPU
        # work to finish, so the epoch timer reflects completed computation.
        running_loss += loss.item()
        num_batches += 1
        # Count what was actually processed, so throughput stays correct
        # regardless of subset size or a partial final batch.
        samples_seen += labels.size(0)

    epoch_time = time.perf_counter() - epoch_start_time
    samples_per_sec = samples_seen / epoch_time if epoch_time > 0 else 0.0
    avg_train_loss = running_loss / num_batches if num_batches > 0 else 0.0

    # NOTE: "val" accuracy is measured on the AG News *test* split (test_loader);
    # this notebook has no separate validation split.
    val_acc = evaluate(model, test_loader, device)

    metrics["epoch"].append(epoch)
    metrics["train_samples_per_sec"].append(samples_per_sec)
    metrics["epoch_time_sec"].append(epoch_time)
    metrics["train_loss"].append(avg_train_loss)
    metrics["val_accuracy"].append(val_acc)

    print(
        f"Epoch {epoch}/{NUM_EPOCHS} | "
        f"Time: {epoch_time:.2f}s | "
        f"Throughput: {samples_per_sec:.2f} samples/s | "
        f"Train Loss: {avg_train_loss:.4f} | "
        f"Val Acc: {val_acc:.4f}"
    )

print("Training complete.")


Starting training...


Epoch 1/30: 100%|██████████| 79/79 [00:40<00:00,  1.93it/s]


Epoch 1/30 | Time: 40.89s | Throughput: 2934.86 samples/s | Train Loss: 0.4933 | Val Acc: 0.8995


Epoch 2/30: 100%|██████████| 79/79 [00:35<00:00,  2.19it/s]


Epoch 2/30 | Time: 35.99s | Throughput: 3333.96 samples/s | Train Loss: 0.2201 | Val Acc: 0.9116


Epoch 3/30: 100%|██████████| 79/79 [00:35<00:00,  2.22it/s]


Epoch 3/30 | Time: 35.67s | Throughput: 3364.25 samples/s | Train Loss: 0.1346 | Val Acc: 0.9104


Epoch 4/30: 100%|██████████| 79/79 [00:36<00:00,  2.15it/s]


Epoch 4/30 | Time: 36.81s | Throughput: 3259.66 samples/s | Train Loss: 0.0907 | Val Acc: 0.9095


Epoch 5/30: 100%|██████████| 79/79 [00:37<00:00,  2.12it/s]


Epoch 5/30 | Time: 37.30s | Throughput: 3217.47 samples/s | Train Loss: 0.0611 | Val Acc: 0.9093


Epoch 6/30: 100%|██████████| 79/79 [00:38<00:00,  2.08it/s]


Epoch 6/30 | Time: 38.00s | Throughput: 3157.67 samples/s | Train Loss: 0.0453 | Val Acc: 0.8945


Epoch 7/30: 100%|██████████| 79/79 [00:41<00:00,  1.89it/s]


Epoch 7/30 | Time: 41.83s | Throughput: 2869.06 samples/s | Train Loss: 0.0353 | Val Acc: 0.8934


Epoch 8/30: 100%|██████████| 79/79 [00:41<00:00,  1.91it/s]


Epoch 8/30 | Time: 41.39s | Throughput: 2898.94 samples/s | Train Loss: 0.0260 | Val Acc: 0.8974


Epoch 9/30: 100%|██████████| 79/79 [00:41<00:00,  1.91it/s]


Epoch 9/30 | Time: 41.43s | Throughput: 2896.15 samples/s | Train Loss: 0.0233 | Val Acc: 0.9037


Epoch 10/30: 100%|██████████| 79/79 [00:40<00:00,  1.95it/s]


Epoch 10/30 | Time: 40.54s | Throughput: 2960.25 samples/s | Train Loss: 0.0175 | Val Acc: 0.9061


Epoch 11/30: 100%|██████████| 79/79 [00:41<00:00,  1.89it/s]


Epoch 11/30 | Time: 41.70s | Throughput: 2877.79 samples/s | Train Loss: 0.0224 | Val Acc: 0.9062


Epoch 12/30: 100%|██████████| 79/79 [00:42<00:00,  1.88it/s]


Epoch 12/30 | Time: 42.06s | Throughput: 2853.20 samples/s | Train Loss: 0.0133 | Val Acc: 0.9068


Epoch 13/30: 100%|██████████| 79/79 [00:41<00:00,  1.90it/s]


Epoch 13/30 | Time: 41.61s | Throughput: 2883.91 samples/s | Train Loss: 0.0079 | Val Acc: 0.9124


Epoch 14/30: 100%|██████████| 79/79 [00:41<00:00,  1.89it/s]


Epoch 14/30 | Time: 41.82s | Throughput: 2869.41 samples/s | Train Loss: 0.0063 | Val Acc: 0.9072


Epoch 15/30: 100%|██████████| 79/79 [00:40<00:00,  1.96it/s]


Epoch 15/30 | Time: 40.26s | Throughput: 2980.70 samples/s | Train Loss: 0.0113 | Val Acc: 0.9057


Epoch 16/30: 100%|██████████| 79/79 [00:41<00:00,  1.90it/s]


Epoch 16/30 | Time: 41.60s | Throughput: 2884.35 samples/s | Train Loss: 0.0101 | Val Acc: 0.9050


Epoch 17/30: 100%|██████████| 79/79 [00:44<00:00,  1.79it/s]


Epoch 17/30 | Time: 44.11s | Throughput: 2720.56 samples/s | Train Loss: 0.0108 | Val Acc: 0.9070


Epoch 18/30: 100%|██████████| 79/79 [00:41<00:00,  1.89it/s]


Epoch 18/30 | Time: 41.80s | Throughput: 2870.80 samples/s | Train Loss: 0.0171 | Val Acc: 0.9022


Epoch 19/30: 100%|██████████| 79/79 [00:41<00:00,  1.89it/s]


Epoch 19/30 | Time: 41.86s | Throughput: 2866.82 samples/s | Train Loss: 0.0092 | Val Acc: 0.9100


Epoch 20/30: 100%|██████████| 79/79 [00:40<00:00,  1.95it/s]


Epoch 20/30 | Time: 40.54s | Throughput: 2959.93 samples/s | Train Loss: 0.0049 | Val Acc: 0.9092


Epoch 21/30: 100%|██████████| 79/79 [00:40<00:00,  1.95it/s]


Epoch 21/30 | Time: 40.42s | Throughput: 2968.50 samples/s | Train Loss: 0.0045 | Val Acc: 0.9096


Epoch 22/30: 100%|██████████| 79/79 [00:41<00:00,  1.90it/s]


Epoch 22/30 | Time: 41.49s | Throughput: 2892.16 samples/s | Train Loss: 0.0125 | Val Acc: 0.9024


Epoch 23/30: 100%|██████████| 79/79 [00:40<00:00,  1.93it/s]


Epoch 23/30 | Time: 41.00s | Throughput: 2926.76 samples/s | Train Loss: 0.0104 | Val Acc: 0.9083


Epoch 24/30: 100%|██████████| 79/79 [00:41<00:00,  1.89it/s]


Epoch 24/30 | Time: 41.89s | Throughput: 2864.64 samples/s | Train Loss: 0.0063 | Val Acc: 0.9036


Epoch 25/30: 100%|██████████| 79/79 [00:40<00:00,  1.94it/s]


Epoch 25/30 | Time: 40.71s | Throughput: 2947.48 samples/s | Train Loss: 0.0078 | Val Acc: 0.9070


Epoch 26/30: 100%|██████████| 79/79 [00:40<00:00,  1.96it/s]


Epoch 26/30 | Time: 40.36s | Throughput: 2973.17 samples/s | Train Loss: 0.0047 | Val Acc: 0.8993


Epoch 27/30: 100%|██████████| 79/79 [00:41<00:00,  1.92it/s]


Epoch 27/30 | Time: 41.07s | Throughput: 2921.98 samples/s | Train Loss: 0.0088 | Val Acc: 0.9043


Epoch 28/30: 100%|██████████| 79/79 [00:40<00:00,  1.93it/s]


Epoch 28/30 | Time: 40.98s | Throughput: 2928.56 samples/s | Train Loss: 0.0031 | Val Acc: 0.9050


Epoch 29/30: 100%|██████████| 79/79 [00:41<00:00,  1.92it/s]


Epoch 29/30 | Time: 41.11s | Throughput: 2919.19 samples/s | Train Loss: 0.0019 | Val Acc: 0.9080


Epoch 30/30: 100%|██████████| 79/79 [00:40<00:00,  1.94it/s]


Epoch 30/30 | Time: 40.64s | Throughput: 2952.48 samples/s | Train Loss: 0.0013 | Val Acc: 0.9066
Training complete.
