# Transformer Lite: DistilBERT with Multi-Output Heads

This notebook fine-tunes a **DistilBERT** model to perform **multi-level classification**. Each output head predicts a specific level in the label hierarchy.


In [None]:
!pip install -q transformers torch datasets accelerate scikit-learn

## 📥 1. Load & Preprocess the Dataset

* Use `fetch_20newsgroups()` to get the 20 NewsGroups dataset.
* Extract hierarchical labels: `category`, `subcategory`, and `subsubcategory`.


In [85]:
from sklearn.datasets import fetch_20newsgroups
import pandas as pd

# Load 20 Newsgroups (for demo, since Yahoo data not directly available)
newsgroups = fetch_20newsgroups(subset='train', remove=('headers', 'footers', 'quotes'))

df = pd.DataFrame({'text': newsgroups.data, 'label': newsgroups.target})
df['category'] = [newsgroups.target_names[i].split('.')[0] if '.' in newsgroups.target_names[i] else newsgroups.target_names[i] for i in df['label']]
df['subcategory'] = [newsgroups.target_names[i] for i in df['label']]
df['subsubcategory'] = df['subcategory'].apply(lambda x: x.split('.')[-1] if '.' in x else 'none')

print(df[['category', 'subcategory', 'subsubcategory']].drop_duplicates())
print(f"Dataset size: {len(df)}")
df.head()

   category               subcategory subsubcategory
0       rec                 rec.autos          autos
1      comp     comp.sys.mac.hardware       hardware
3      comp             comp.graphics       graphics
4       sci                 sci.space          space
5      talk        talk.politics.guns           guns
6       sci                   sci.med            med
7      comp  comp.sys.ibm.pc.hardware       hardware
8      comp   comp.os.ms-windows.misc           misc
10      rec           rec.motorcycles    motorcycles
11     talk        talk.religion.misc           misc
14     misc              misc.forsale        forsale
15      alt               alt.atheism        atheism
18      sci           sci.electronics    electronics
19     comp            comp.windows.x              x
21      rec          rec.sport.hockey         hockey
27      rec        rec.sport.baseball       baseball
28      soc    soc.religion.christian      christian
33     talk     talk.politics.mideast        m

Unnamed: 0,text,label,category,subcategory,subsubcategory
0,I was wondering if anyone out there could enli...,7,rec,rec.autos,autos
1,A fair number of brave souls who upgraded thei...,4,comp,comp.sys.mac.hardware,hardware
2,"well folks, my mac plus finally gave up the gh...",4,comp,comp.sys.mac.hardware,hardware
3,\nDo you have Weitek's address/phone number? ...,1,comp,comp.graphics,graphics
4,"From article <C5owCB.n3p@world.std.com>, by to...",14,sci,sci.space,space


## 🧠 2. Tokenization

* Use `DistilBertTokenizerFast` to tokenize the text into BERT-compatible input format.
* Truncate to 512 tokens with padding.

In [86]:
from transformers import DistilBertTokenizerFast

tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')

def tokenize(batch):
    return tokenizer(batch['text'], padding=True, truncation=True, max_length=512)

# Tokenize all data
tokens = tokenizer(list(df['text']), padding=True, truncation=True, max_length=512, return_tensors="pt")

## 🏷️ 3. Encode Labels

* Encode hierarchy into numerical labels using `LabelEncoder`:

  * Level 1: `category`
  * Level 2: `subcategory`
  * Level 3: `subsubcategory`

In [87]:
from sklearn.preprocessing import LabelEncoder

le_category = LabelEncoder()
le_subcategory = LabelEncoder()
le_subsubcategory = LabelEncoder()

df['cat_label'] = le_category.fit_transform(df['category'])
df['subcat_label'] = le_subcategory.fit_transform(df['subcategory'])
df['subsubcat_label'] = le_subsubcategory.fit_transform(df['subsubcategory'])

print("Number of classes at each level:")
print(f"Category: {len(le_category.classes_)}")
print(f"Subcategory: {len(le_subcategory.classes_)}")
print(f"Subsubcategory: {len(le_subsubcategory.classes_)}")

Number of classes at each level:
Category: 7
Subcategory: 20
Subsubcategory: 17


## 🏗️ 4. Model Definition: `DistilBERTMultiHead`

* Base: `DistilBertModel`
* Three classification heads for multi-level output.
* Each head predicts one level of the label hierarchy.

In [None]:
import torch
import torch.nn as nn
from transformers import DistilBertModel

class DistilBERTMultiHead(nn.Module):
    def __init__(self, num_classes_level1, num_classes_level2, num_classes_level3):
        super().__init__()
        self.distilbert = DistilBertModel.from_pretrained('distilbert-base-uncased')
        hidden_size = self.distilbert.config.hidden_size
        self.classifier_level1 = nn.Linear(hidden_size, num_classes_level1)
        self.classifier_level2 = nn.Linear(hidden_size, num_classes_level2)
        self.classifier_level3 = nn.Linear(hidden_size, num_classes_level3)

    def forward(self, input_ids, attention_mask):
        outputs = self.distilbert(input_ids=input_ids, attention_mask=attention_mask)
        cls_output = outputs.last_hidden_state[:, 0, :]  # CLS token representation
        logits1 = self.classifier_level1(cls_output)
        logits2 = self.classifier_level2(cls_output)
        logits3 = self.classifier_level3(cls_output)
        return logits1, logits2, logits3

## 📦 5. Dataset and Dataloader

* Define a custom `HierarchicalDataset`.
* Wrap tokenized inputs and labels into a PyTorch `DataLoader`.

In [112]:
from torch.utils.data import DataLoader, Dataset
import torch.optim as optim

class HierarchicalDataset(Dataset):
    def __init__(self, encodings, labels1, labels2, labels3):
        self.encodings = encodings
        self.labels1 = labels1
        self.labels2 = labels2
        self.labels3 = labels3

    def __len__(self):
        return len(self.labels1)

    def __getitem__(self, idx):
        item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
        item['labels1'] = torch.tensor(self.labels1[idx])
        item['labels2'] = torch.tensor(self.labels2[idx])
        item['labels3'] = torch.tensor(self.labels3[idx])
        return item

dataset = HierarchicalDataset(tokens, df['cat_label'].values, df['subcat_label'].values, df['subsubcat_label'].values)
dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

In [113]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = DistilBERTMultiHead(
    num_classes_level1=len(le_category.classes_),
    num_classes_level2=len(le_subcategory.classes_),
    num_classes_level3=len(le_subsubcategory.classes_)
)
model.to(device)

optimizer = optim.Adam(model.parameters(), lr=2e-5)
criterion = nn.CrossEntropyLoss()

## 🔧 6. Training Loop

* Train model over 3 epochs using `Adam` optimizer and `CrossEntropyLoss` for each level.
* Sum of three losses is used for backpropagation.

In [114]:
from tqdm import tqdm

model.train()
epochs = 3

for epoch in range(epochs):
    total_loss = 0
    for batch in tqdm(dataloader):
        optimizer.zero_grad()
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels1 = batch['labels1'].to(device)
        labels2 = batch['labels2'].to(device)
        labels3 = batch['labels3'].to(device)

        logits1, logits2, logits3 = model(input_ids, attention_mask)
        loss1 = criterion(logits1, labels1)
        loss2 = criterion(logits2, labels2)
        loss3 = criterion(logits3, labels3)

        loss = loss1 + loss2 + loss3
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    print(f"Epoch {epoch+1} loss: {total_loss/len(dataloader)}")


  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}
100%|██████████| 708/708 [09:04<00:00,  1.30it/s]


Epoch 1 loss: 3.3445821265547964


100%|██████████| 708/708 [09:03<00:00,  1.30it/s]


Epoch 2 loss: 1.8690194942183413


100%|██████████| 708/708 [09:03<00:00,  1.30it/s]

Epoch 3 loss: 1.277032246002875





## 📊 7. Evaluation

* Compute accuracy at each level (category, subcategory, subsubcategory) on the training set.

In [115]:
model.eval()
correct_level1, correct_level2, correct_level3 = 0, 0, 0
total = 0

with torch.no_grad():
    for batch in dataloader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels1 = batch['labels1'].to(device)
        labels2 = batch['labels2'].to(device)
        labels3 = batch['labels3'].to(device)

        logits1, logits2, logits3 = model(input_ids, attention_mask)
        preds1 = torch.argmax(logits1, dim=1)
        preds2 = torch.argmax(logits2, dim=1)
        preds3 = torch.argmax(logits3, dim=1)

        correct_level1 += (preds1 == labels1).sum().item()
        correct_level2 += (preds2 == labels2).sum().item()
        correct_level3 += (preds3 == labels3).sum().item()
        total += labels1.size(0)

print(f"Accuracy level 1: {correct_level1/total:.4f}")
print(f"Accuracy level 2: {correct_level2/total:.4f}")
print(f"Accuracy level 3: {correct_level3/total:.4f}")


  item = {key: torch.tensor(val[idx]) for key, val in self.encodings.items()}


Accuracy level 1: 0.9513
Accuracy level 2: 0.8889
Accuracy level 3: 0.9213


## 🔮 8. Inference

* Predict hierarchy for a new input text.
* Decode predictions using inverse transformation of the label encoders.

In [93]:
def predict(text):
    model.eval()
    inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512).to(device)
    with torch.no_grad():
        logits1, logits2, logits3 = model(inputs['input_ids'], inputs['attention_mask'])
        pred1 = le_category.inverse_transform([torch.argmax(logits1, dim=1).item()])[0]
        pred2 = le_subcategory.inverse_transform([torch.argmax(logits2, dim=1).item()])[0]
        pred3 = le_subsubcategory.inverse_transform([torch.argmax(logits3, dim=1).item()])[0]
    return pred1, pred2, pred3

sample_text = "How do I install Python packages on Windows?"
print("Predicted hierarchy:", predict(sample_text))


Predicted hierarchy: ('comp', 'comp.os.ms-windows.misc', 'misc')
