# 04 - Train DistilBERT Multi-label Classifier

## Goal

Fine-tune DistilBERT for multi-label classification on the concatenated title + abstract. We train with `BCEWithLogitsLoss`, monitor micro- and macro-F1, and tune the decision threshold on the validation set.


## Why Multi-label Classification?

Each paper can have multiple study characteristics:
- An **RCT** in human participants is also **Human**
- A **Systematic Review** might be a **MetaAnalysis**
- Some studies combine **Animal** + **InVitro**

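We apply **binary cross-entropy** (BCE) to each label independently. Each of the 10 output logits gets its own sigmoid, so any number of labels can be positive for the same paper. A softmax, by contrast, would force the labels to compete for a single slot.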
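To make "BCE per label" concrete, here is a minimal NumPy sketch of the math behind `torch.nn.BCEWithLogitsLoss` with mean reduction. It is an illustration, not PyTorch's actual implementation. The function name `bce_with_logits` and the 3-label toy example are assumptions for illustration. The optional `pos_weight` argument (one value per label, scaling the loss on positive targets) mirrors the parameter of the same name on the PyTorch class and is one way to realise the class-weighting idea under Recommendations.

```python
import numpy as np

def bce_with_logits(logits, targets, pos_weight=None):
    """Mean binary cross-entropy over every (example, label) cell, from raw logits.

    Illustrative sketch of the math of torch.nn.BCEWithLogitsLoss(reduction='mean').
    pos_weight (one value per label) scales the loss on positive targets.
    """
    x = np.asarray(logits, dtype=np.float64)
    y = np.asarray(targets, dtype=np.float64)
    p = np.ones(x.shape[-1]) if pos_weight is None else np.asarray(pos_weight, dtype=np.float64)
    # softplus(-x) = log(1 + exp(-x)), written so it never overflows for large |x|
    softplus_neg_x = np.maximum(-x, 0.0) + np.log1p(np.exp(-np.abs(x)))
    # -[p*y*log(sigmoid(x)) + (1-y)*log(1-sigmoid(x))], rearranged for numerical stability
    per_cell = (1.0 - y) * x + (1.0 + (p - 1.0) * y) * softplus_neg_x
    return float(per_cell.mean())

# Toy paper with 3 labels (RCT, Human, Animal): RCT and Human are both positive.
logits = np.array([[2.0, 1.5, -3.0]])
targets = np.array([[1.0, 1.0, 0.0]])
print(bce_with_logits(logits, targets))
```

Each label contributes its own sigmoid/BCE term, and the terms are simply averaged. A confident correct prediction on one label therefore does not mask a mistake on another.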


In [8]:
# === TODO (you code this) ===
import pandas as pd
import numpy as np
from pathlib import Path
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer
from datasets import Dataset


In [9]:
# === TODO (you code this) ===
# Goal: Define canonical label list (MUST match order from notebook 02).
# Hints:
# 1) 10 labels in this exact order for consistency
# 2) Store as LABELS constant and compute NUM_LABELS
# Acceptance:
# - LABELS list with 10 study-design categories
# - NUM_LABELS = 10
# - Print for verification

# Canonical label list - MUST match the order from notebook 02
# This order is critical for converting labels to binary vectors
LABELS = [
    'SystematicReview',  # 1. Systematic reviews
    'MetaAnalysis',      # 2. Meta-analyses (quantitative synthesis)
    'RCT',               # 3. Randomized Controlled Trials
    'ClinicalTrial',     # 4. Non-randomized clinical trials
    'Cohort',            # 5. Cohort studies (prospective/retrospective)
    'CaseControl',       # 6. Case-control studies
    'CaseReport',        # 7. Case reports / case series
    'InVitro',           # 8. In vitro or ex vivo laboratory studies
    'Animal',            # 9. Animal studies
    'Human'              # 10. Human subjects (not mutually exclusive)
]

NUM_LABELS = len(LABELS)

# Verification
print("✅ Canonical label list defined:")
print(f"   Number of labels: {NUM_LABELS}")
print(f"   Labels: {LABELS}")
print("\n📋 Label order (for binary vector encoding):")
for i, label in enumerate(LABELS):
    print(f"   Index {i}: {label}")


✅ Canonical label list defined:
   Number of labels: 10
   Labels: ['SystematicReview', 'MetaAnalysis', 'RCT', 'ClinicalTrial', 'Cohort', 'CaseControl', 'CaseReport', 'InVitro', 'Animal', 'Human']

📋 Label order (for binary vector encoding):
   Index 0: SystematicReview
   Index 1: MetaAnalysis
   Index 2: RCT
   Index 3: ClinicalTrial
   Index 4: Cohort
   Index 5: CaseControl
   Index 6: CaseReport
   Index 7: InVitro
   Index 8: Animal
   Index 9: Human


## Load Splits & Build HF Datasets


In [10]:
# === TODO (you code this) ===
# Goal: Load train/val/test splits and create 'text' column.
# Hints:
# 1) Load three parquet files from ../data/processed
# 2) Concatenate title + ' ' + abstract into 'text' column
# 3) Truncate to reasonable length (e.g., 2000 chars)
# Acceptance:
# - train_df, val_df, test_df loaded
# - Each has 'text' column
# - Print counts
from pathlib import Path
import pandas as pd

processed_data_path = Path('../data/processed')
train_df = pd.read_parquet(processed_data_path / 'train.parquet')
val_df = pd.read_parquet(processed_data_path / 'val.parquet')
test_df = pd.read_parquet(processed_data_path / 'test.parquet')

MAX_CHARS = 2000  # character cap; the tokenizer later truncates to 512 tokens anyway

for df in (train_df, val_df, test_df):
    # fillna('') stops a missing title or abstract from turning the whole text into NaN
    text = df['title'].fillna('') + ' ' + df['abstract'].fillna('')
    df['text'] = text.str.strip().str[:MAX_CHARS]

def acceptance_criteria():
    assert train_df['text'].notna().all(), "All texts should be non-null"
    assert val_df['text'].notna().all(), "All texts should be non-null"
    assert test_df['text'].notna().all(), "All texts should be non-null"
    assert train_df['text'].str.len().max() <= 2000, "Texts should be truncated to 2000 characters"
    assert val_df['text'].str.len().max() <= 2000, "Texts should be truncated to 2000 characters"
    assert test_df['text'].str.len().max() <= 2000, "Texts should be truncated to 2000 characters"
    assert train_df['text'].str.len().min() > 0, "Texts should not be empty"
    assert val_df['text'].str.len().min() > 0, "Texts should not be empty"
    assert test_df['text'].str.len().min() > 0, "Texts should not be empty"
    return "Acceptance criteria met"

print(acceptance_criteria())







Acceptance criteria met


## Binarize Labels

Convert each paper's list of labels into a multi-hot binary vector: 1.0 at the index of every label it has, 0.0 elsewhere.


In [11]:
# === TODO (you code this) ===
# Goal: Convert labels lists to multi-hot binary vectors.
# Hints:
# 1) Create zero vector of length NUM_LABELS
# 2) Set index to 1.0 for each label in LABELS
# 3) Apply to all three DataFrames
# Acceptance:
# - Function binarize_labels(labels_list) -> list of floats
# - New column 'label_vec' in all DataFrames
# - Vector length = NUM_LABELS

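# Precompute label -> index once instead of calling LABELS.index() for every label
LABEL_TO_IDX = {label: i for i, label in enumerate(LABELS)}

def binarize_labels(labels_list):
    """Convert a list of label names to a multi-hot vector of length NUM_LABELS.

    Labels that are not in LABELS are silently ignored.
    """
    vec = [0.0] * NUM_LABELS
    for label in labels_list:
        idx = LABEL_TO_IDX.get(label)
        if idx is not None:
            vec[idx] = 1.0
    return vec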

train_df['label_vec'] = train_df['labels'].apply(binarize_labels)
val_df['label_vec'] = val_df['labels'].apply(binarize_labels)
test_df['label_vec'] = test_df['labels'].apply(binarize_labels)

print(train_df[['text', "labels", 'label_vec']].head())
print(train_df["label_vec"].value_counts())





                                                   text   labels  \
1347  Stability of Class II Malocclusion Treatment w...  [Human]   
1348  The Dental Aesthetic Index and Its Association...  [Human]   
1349  Aboriginal Health Workers Promoting Oral Healt...  [Human]   
1350  Sleep Bruxism in Children: Etiology, Diagnosis...  [Human]   
1351  Orthodontic Extrusion vs. Surgical Extrusion t...  [Human]   

                                              label_vec  
1347  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...  
1348  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...  
1349  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...  
1350  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...  
1351  [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...  
label_vec
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0]    15441
[0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 1.0]     2636
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 1.0]     1810
[0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0]   

In [12]:
# === TODO (you code this) ===
# Goal: Create HuggingFace Dataset objects.
# Hints:
# 1) Use Dataset.from_pandas() with just 'text' and 'label_vec' columns
# 2) Create for all three splits
# Acceptance:
# - train_dataset, val_dataset, test_dataset created
# - Each contains 'text' and 'label_vec' fields
from datasets import Dataset

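def convert_to_hf_dataset(df):
    # preserve_index=False: the split DataFrames keep their original (non-0..n) index,
    # which from_pandas would otherwise store as an extra '__index_level_0__' column
    return Dataset.from_pandas(df[['text', 'label_vec']], preserve_index=False)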

train_dataset = convert_to_hf_dataset(train_df)
val_dataset = convert_to_hf_dataset(val_df)
test_dataset = convert_to_hf_dataset(test_df)

# Test
print(train_dataset[0])
print(val_dataset[0])
print(test_dataset[0])




{'text': "Stability of Class II Malocclusion Treatment with the Austro Repositioner Followed by Fixed Appliances in Brachyfacial Patients. One of the goals of functional-appliance devices is to modify the vertical growth pattern, solving several kinds of malocclusion. This study aimed to evaluate Class II malocclusion treatment's stability with Austro Repositioner, followed by fixed appliances, and assess its capacity to modify vertical dimensions in brachyfacial patients. A test group of 30 patients (16 boys and 14 girls, mean 11.9 years old) with Class II malocclusion due to mandibular retrognathism and brachyfacial pattern treated with Austro Repositioner and fixed appliance were compared to a matched untreated Class II control group of 30 patients (17 boys and 13 girls, mean age 11.7 years old). Lateral cephalograms were taken at T1 (initial records), T2 (end of treatment), and T3 (one year after treatment). Statistical comparisons were performed with a paired-sample  t -test and t

## Tokenizer & Encoding


In [13]:
# === TODO (you code this) ===
# Goal: Tokenize datasets and prepare for training.
# Hints:
# 1) Load tokenizer for 'distilbert-base-uncased'
# 2) Create tokenize function with max_length=512, truncation, padding
# 3) Map tokenize function to all datasets (batched=True)
# 4) Rename 'label_vec' → 'labels', set format to torch
# Acceptance:
# - tokenizer loaded
# - All datasets have input_ids, attention_mask, labels
# - Format set to 'torch'

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('distilbert-base-uncased')

def tokenize_and_format(dataset):
    """Tokenize to fixed 512-token inputs and drop the raw 'text' column."""
    return dataset.map(
        lambda batch: tokenizer(batch['text'], max_length=512, truncation=True, padding='max_length'),
        batched=True,
        remove_columns=['text'],  # the raw text is no longer needed once tokenized
    )

train_dataset = tokenize_and_format(train_dataset)
val_dataset = tokenize_and_format(val_dataset)
test_dataset = tokenize_and_format(test_dataset)

# Rename 'label_vec' → 'labels' (the keyword the model's forward() expects), set format to torch
train_dataset = train_dataset.rename_column('label_vec', 'labels')
val_dataset = val_dataset.rename_column('label_vec', 'labels')
test_dataset = test_dataset.rename_column('label_vec', 'labels')

train_dataset.set_format('torch', columns=['input_ids', 'attention_mask', 'labels'])
val_dataset.set_format('torch', columns=['input_ids', 'attention_mask', 'labels'])
test_dataset.set_format('torch', columns=['input_ids', 'attention_mask', 'labels'])

# Test
print(train_dataset[0])
print(val_dataset[0])
print(test_dataset[0])

print(train_dataset.format)
print(val_dataset.format)
print(test_dataset.format)

# Acceptance
# Note: Dataset has no `.columns` attribute (use column_names / format['columns']),
# and the three splits legitimately differ in size, so row counts are checked per split.
EXPECTED_COLUMNS = ['input_ids', 'attention_mask', 'labels']
for name, ds, df in [('train', train_dataset, train_df),
                     ('val', val_dataset, val_df),
                     ('test', test_dataset, test_df)]:
    assert ds.format['type'] == 'torch', f"{name}: format should be torch"
    assert set(ds.format['columns']) == set(EXPECTED_COLUMNS), f"{name}: unexpected formatted columns"
    assert set(EXPECTED_COLUMNS) <= set(ds.column_names), f"{name}: missing columns"
    assert ds.num_rows == len(df), f"{name}: row count changed"
    assert ds[0]['labels'].shape[0] == NUM_LABELS, f"{name}: label vector has wrong length"


Map:   0%|          | 0/29926 [00:00<?, ? examples/s]

Map:   0%|          | 0/16057 [00:00<?, ? examples/s]

Map:   0%|          | 0/18666 [00:00<?, ? examples/s]

{'labels': tensor([0., 0., 0., 0., 0., 0., 0., 0., 0., 1.]), 'input_ids': tensor([  101,  9211,  1997,  2465,  2462, 15451, 10085, 20464, 14499,  3949,
         2007,  1996, 16951, 16360, 19234,  2121,  2628,  2011,  4964, 22449,
         1999, 11655, 11714,  7011, 13247,  5022,  1012,  2028,  1997,  1996,
         3289,  1997,  8360,  1011, 10439, 15204,  3401,  5733,  2003,  2000,
        19933,  1996,  7471,  3930,  5418,  1010, 13729,  2195,  7957,  1997,
        15451, 10085, 20464, 14499,  1012,  2023,  2817,  6461,  2000, 16157,
         2465,  2462, 15451, 10085, 20464, 14499,  3949,  1005,  1055,  9211,
         2007, 16951, 16360, 19234,  2121,  1010,  2628,  2011,  4964, 22449,
         1010,  1998, 14358,  2049,  3977,  2000, 19933,  7471,  9646,  1999,
        11655, 11714,  7011, 13247,  5022,  1012,  1037,  3231,  2177,  1997,
         2382,  5022,  1006,  2385,  3337,  1998,  2403,  3057,  1010,  2812,
         2340,  1012,  1023,  2086,  2214,  1007,  2007,  2465,  246

## Model Initialization


In [None]:
# === TODO (you code this) ===
# Goal: Initialize DistilBERT for multi-label classification.
# Hints:
# 1) Use AutoModelForSequenceClassification.from_pretrained
# 2) Set num_labels=NUM_LABELS, problem_type='multi_label_classification'
# Acceptance:
# - model initialized from 'distilbert-base-uncased'
# - Configured for multi-label (uses BCEWithLogits loss)

# TODO: initialize model
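One way to fill this cell, following the hints, would be `AutoModelForSequenceClassification.from_pretrained('distilbert-base-uncased', num_labels=NUM_LABELS, problem_type='multi_label_classification')`. Because that call downloads weights, the sketch below instead builds a tiny, randomly initialised DistilBERT from `DistilBertConfig` with the same head settings. The layer sizes are arbitrary, test-only assumptions. The goal is to check that `problem_type='multi_label_classification'` makes the model's loss equal `BCEWithLogitsLoss` on float multi-hot labels. Passing `id2label`/`label2id` is an optional addition not in the hints; it stores the label names in the saved config.

```python
import torch
from transformers import DistilBertConfig, DistilBertForSequenceClassification

LABELS = ['SystematicReview', 'MetaAnalysis', 'RCT', 'ClinicalTrial', 'Cohort',
          'CaseControl', 'CaseReport', 'InVitro', 'Animal', 'Human']
id2label = {i: name for i, name in enumerate(LABELS)}
label2id = {name: i for i, name in id2label.items()}

# In the notebook itself (downloads pretrained weights):
# model = AutoModelForSequenceClassification.from_pretrained(
#     'distilbert-base-uncased', num_labels=len(LABELS),
#     problem_type='multi_label_classification', id2label=id2label, label2id=label2id)

# Offline check: tiny random DistilBERT (sizes are arbitrary assumptions) with the same head config
config = DistilBertConfig(
    vocab_size=100, dim=32, hidden_dim=64, n_layers=1, n_heads=2,
    num_labels=len(LABELS), problem_type='multi_label_classification',
    id2label=id2label, label2id=label2id,
)
model = DistilBertForSequenceClassification(config).eval()

torch.manual_seed(0)
input_ids = torch.randint(0, 100, (2, 16))
labels = torch.zeros(2, len(LABELS))  # float multi-hot targets, as BCE requires
labels[0, LABELS.index('RCT')] = 1.0     # paper 0: an RCT in humans
labels[0, LABELS.index('Human')] = 1.0
labels[1, LABELS.index('Animal')] = 1.0  # paper 1: an animal study

with torch.no_grad():
    out = model(input_ids=input_ids, labels=labels)

# With problem_type='multi_label_classification', the model's loss should be BCEWithLogitsLoss
expected = torch.nn.BCEWithLogitsLoss()(out.logits, labels)
print(tuple(out.logits.shape), float(out.loss), float(expected))
```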


## Metrics Function


In [None]:
# === TODO (you code this) ===
# Goal: Define metrics function for Trainer.
# Hints:
# 1) Apply sigmoid to logits, threshold at 0.5
# 2) Compute micro/macro precision, recall, F1 using sklearn
# 3) Return dict with 6 metrics
# Acceptance:
# - Function compute_metrics(eval_pred) -> dict
# - Returns precision/recall/f1 for both micro and macro

def compute_metrics(eval_pred):
    """Compute evaluation metrics for multi-label classification."""
    # TODO
    raise NotImplementedError
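A possible implementation, sketched with plain NumPy counts rather than sklearn (the hint suggests sklearn). It should agree with `sklearn.metrics.precision_recall_fscore_support(..., average='micro' / 'macro', zero_division=0)`. The sketch assumes `eval_pred` unpacks to `(logits, labels)`, which holds for a plain tuple and for the `EvalPrediction` that `Trainer` passes when inputs are not included. The key `f1_micro` matches `metric_for_best_model` in the training-arguments hint. The helper `_safe_ratio` is a hypothetical name.

```python
import numpy as np

def _safe_ratio(num, den):
    """Elementwise num/den, returning 0.0 where den == 0 (like zero_division=0)."""
    num = np.asarray(num, dtype=np.float64)
    den = np.asarray(den, dtype=np.float64)
    return np.where(den > 0, num / np.where(den > 0, den, 1.0), 0.0)

def compute_metrics(eval_pred, threshold=0.5):
    """Micro/macro precision, recall and F1 for multi-label logits."""
    logits, labels = eval_pred  # tuple or EvalPrediction (predictions, label_ids)
    probs = 1.0 / (1.0 + np.exp(-np.asarray(logits, dtype=np.float64)))  # sigmoid per label
    pred = probs >= threshold
    true = np.asarray(labels) >= 0.5  # labels arrive as 0.0/1.0 floats

    # Per-label counts, each of shape (num_labels,)
    tp = np.sum(pred & true, axis=0)
    fp = np.sum(pred & ~true, axis=0)
    fn = np.sum(~pred & true, axis=0)

    # Micro: pool counts over all labels, so frequent labels (e.g. Human) dominate
    tp_all, fp_all, fn_all = tp.sum(), fp.sum(), fn.sum()
    # Macro: score each label separately, then take the unweighted mean
    return {
        'precision_micro': float(_safe_ratio(tp_all, tp_all + fp_all)),
        'recall_micro': float(_safe_ratio(tp_all, tp_all + fn_all)),
        'f1_micro': float(_safe_ratio(2 * tp_all, 2 * tp_all + fp_all + fn_all)),
        'precision_macro': float(_safe_ratio(tp, tp + fp).mean()),
        'recall_macro': float(_safe_ratio(tp, tp + fn).mean()),
        'f1_macro': float(_safe_ratio(2 * tp, 2 * tp + fp + fn).mean()),
    }
```

Reporting both averages is the point. Micro-F1 tracks overall accuracy on the common labels, while macro-F1 drops sharply if a rare label such as `CaseControl` is never predicted.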


## Training Arguments & Trainer


In [None]:
# === TODO (you code this) ===
# Goal: Configure training parameters.
# Hints:
# 1) Set output_dir, eval_strategy, save_strategy, learning_rate
# 2) batch_size=8, epochs=3-4, warmup_ratio=0.1
# 3) load_best_model_at_end=True, metric_for_best_model='f1_micro'
# Acceptance:
# - TrainingArguments object configured
# - Will save to ../artifacts/model/checkpoints
# - Evaluates each epoch and keeps best

# TODO: create training_args


In [None]:
# === TODO (you code this) ===
# Goal: Train model and save best checkpoint.
# Hints:
# 1) Create Trainer with model, args, datasets, compute_metrics
# 2) Call trainer.train()
# 3) Save best model and tokenizer to ../artifacts/model/best
# Acceptance:
# - Training completes successfully
# - Best model saved based on micro-F1
# - Both model and tokenizer saved

# TODO: create Trainer, train, and save

# Run training (uncomment when ready)
# trainer = ...
# trainer.train()
# Save best model...
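A minimal sketch of this cell, wrapped in a hypothetical helper `train_and_save` so it only runs when called. It assumes `model`, `tokenizer`, `training_args`, the tokenized datasets and `compute_metrics` exist from the cells above. The tokenizer is saved separately with `save_pretrained` rather than passed to `Trainer`, which sidesteps the `tokenizer`/`processing_class` argument rename across `transformers` versions.

```python
from transformers import Trainer

def train_and_save(model, tokenizer, training_args, train_dataset, val_dataset,
                   compute_metrics, best_dir="../artifacts/model/best"):
    """Hypothetical helper: train, then save the best checkpoint and the tokenizer."""
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        compute_metrics=compute_metrics,
    )
    trainer.train()
    # With load_best_model_at_end=True, trainer.model now holds the best checkpoint
    trainer.save_model(best_dir)
    tokenizer.save_pretrained(best_dir)
    return trainer
```

Usage in the notebook would be `trainer = train_and_save(model, tokenizer, training_args, train_dataset, val_dataset, compute_metrics)`. The held-out test split stays untouched during training and threshold tuning.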


## Recommendations

- If rare labels underperform, try **class weights** or **focal loss** (stretch goal)
- Track **ROC/PR curves** on val to tune per-label thresholds later
- Monitor for **overfitting**: if training loss keeps falling while validation metrics stall or degrade, the model is memorizing the training set
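Per-label threshold tuning can be sketched as a simple grid search that maximizes each label's F1 on validation-set probabilities. The function name `tune_per_label_thresholds` and the default grid (0.05 to 0.95 in steps of 0.05) are assumptions. A single global threshold is the special case of using one value for every column. Tune on validation and then apply the chosen thresholds unchanged to test, so test scores are not optimistically biased.

```python
import numpy as np

def tune_per_label_thresholds(probs, labels, grid=None):
    """Pick, for each label, the grid threshold with the highest validation F1.

    probs:  (n_examples, n_labels) sigmoid outputs
    labels: (n_examples, n_labels) 0/1 targets
    Returns (thresholds, best_f1), each of shape (n_labels,). Ties keep the lowest threshold.
    """
    probs = np.asarray(probs, dtype=np.float64)
    true = np.asarray(labels) >= 0.5
    if grid is None:
        grid = np.linspace(0.05, 0.95, 19)  # assumed search grid, step 0.05
    n_labels = probs.shape[1]
    thresholds = np.zeros(n_labels)
    best_f1 = np.zeros(n_labels)
    for j in range(n_labels):
        best = -1.0
        for t in grid:
            pred = probs[:, j] >= t
            tp = np.sum(pred & true[:, j])
            fp = np.sum(pred & ~true[:, j])
            fn = np.sum(~pred & true[:, j])
            denom = 2 * tp + fp + fn
            f1 = 2 * tp / denom if denom > 0 else 0.0
            if f1 > best:  # strict '>' keeps the first (lowest) threshold on ties
                best, thresholds[j] = f1, t
        best_f1[j] = best
    return thresholds, best_f1
```

Rare labels often end up with thresholds well below 0.5, because the model's probabilities for them rarely climb that high even on true positives.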

## 🧘 Reflection Log

**What did you learn in this session?**
- 

**What challenges did you encounter?**
- 

**How will this improve Periospot AI?**
- 
