# 04 - Train DistilBERT Multi-label Classifier

## Goal

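Train DistilBERT for multi-label classification on the concatenated title and abstract of each paper. We use a per-label sigmoid with binary cross-entropy (`BCEWithLogitsLoss`), monitor micro- and macro-averaged F1, and tune the decision threshold on the validation set.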


## Why Multi-label Classification?

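Each paper can carry several study characteristics at once:
- An **RCT** is typically also **Human**
- A **Systematic Review** might also be a **MetaAnalysis**
- Some studies combine **Animal** + **InVitro**

Because the labels are not mutually exclusive, we score each one as an independent yes/no decision: a sigmoid on each label's logit with **binary cross-entropy** (BCE), rather than a softmax that forces the probabilities to sum to 1.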
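
Writing the loss out directly shows that each label is scored on its own. The sketch below is a plain NumPy version of what `BCEWithLogitsLoss` computes with its default mean reduction (average over all example-label pairs), using the numerically stable log-sigmoid form. The optional `pos_weight` argument follows PyTorch's convention of scaling only the positive term per label, which is one way to apply the class weighting mentioned under Recommendations. `bce_with_logits` and `pos_weight_from_labels` are illustrative names, not part of any library.

```python
import numpy as np

def bce_with_logits(logits, targets, pos_weight=None):
    """Mean binary cross-entropy over every (example, label) pair.

    Each label has its own sigmoid, so labels never compete with each other
    (unlike softmax). Log-probabilities use logaddexp so large |logits|
    do not overflow.
    """
    x = np.asarray(logits, dtype=np.float64)
    y = np.asarray(targets, dtype=np.float64)
    log_p = -np.logaddexp(0.0, -x)      # log(sigmoid(x))
    log_not_p = -np.logaddexp(0.0, x)   # log(1 - sigmoid(x))
    # pos_weight (one value per label) scales only the positive term, as in PyTorch.
    w = 1.0 if pos_weight is None else np.asarray(pos_weight, dtype=np.float64)
    loss = -(w * y * log_p + (1.0 - y) * log_not_p)
    return float(loss.mean())

def pos_weight_from_labels(label_matrix):
    """Per-label negatives/positives ratio: a common heuristic for rare labels."""
    y = np.asarray(label_matrix, dtype=np.float64)
    pos = y.sum(axis=0)
    neg = y.shape[0] - pos
    return neg / np.clip(pos, 1.0, None)  # clip avoids dividing by zero for absent labels

# Two papers, three hypothetical labels: each column is its own binary problem.
logits = np.array([[2.0, -1.0, 0.0],
                   [-0.5, 1.5, 3.0]])
targets = np.array([[1, 0, 1],
                    [0, 1, 1]])
print(bce_with_logits(logits, targets))
print(pos_weight_from_labels([[1, 0], [1, 0], [0, 1], [1, 0]]))
```

A ratio of negatives to positives up-weights rare labels strongly; in practice it is often capped or softened, so treat it as a starting point rather than a tuned value.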


In [None]:
# === TODO (you code this) ===
# Goal: Import libraries for transformer training.
# Hints:
# 1) pandas, numpy, pathlib.Path, torch
# 2) From transformers: AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer
# 3) From datasets: Dataset
# 4) From sklearn.metrics: precision_recall_fscore_support (used by compute_metrics)
# Acceptance:
# - All imports succeed


In [None]:
# === TODO (you code this) ===
# Goal: Define canonical label list (MUST match order from notebook 02).
# Hints:
# 1) 10 labels in this exact order for consistency
# 2) Store as LABELS constant and compute NUM_LABELS
# Acceptance:
# - LABELS list with 10 study-design categories
# - NUM_LABELS = 10
# - Print for verification

# TODO: define LABELS and NUM_LABELS


## Load Splits & Build HF Datasets
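
A minimal sketch of the text-column step, assuming the splits have `title` and `abstract` columns (as the cell below implies) and that either field may be missing. It runs on a small in-memory DataFrame because the parquet file names are not spelled out here; `add_text_column` is an illustrative helper, and the 2000-character cap is only the example value from the hint.

```python
import pandas as pd

def add_text_column(df: pd.DataFrame, max_chars: int = 2000) -> pd.DataFrame:
    """Return a copy of df with 'text' = title + ' ' + abstract, truncated to max_chars."""
    out = df.copy()
    title = out["title"].fillna("").astype(str)
    abstract = out["abstract"].fillna("").astype(str)
    # strip() removes the stray separator when one of the two fields is empty
    out["text"] = (title + " " + abstract).str.strip().str.slice(0, max_chars)
    return out

demo = pd.DataFrame({
    "title": ["Implant survival", None],
    "abstract": ["A 5-year RCT.", "Rat model."],
})
demo = add_text_column(demo, max_chars=20)
print(demo["text"].tolist())
```

Character truncation is only a cheap pre-filter; the tokenizer's `max_length=512` truncation still decides what the model actually sees.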


In [None]:
# === TODO (you code this) ===
# Goal: Load train/val/test splits and create 'text' column.
# Hints:
# 1) Load three parquet files from ../data/processed
# 2) Concatenate title + ' ' + abstract into a 'text' column (fill missing values with '' first)
# 3) Truncate to a reasonable length (e.g., 2000 chars)
# Acceptance:
# - train_df, val_df, test_df loaded
# - Each has 'text' column
# - Print counts

# TODO: load splits and create text column


## Binarize Labels

Convert each paper's list of labels into a multi-hot binary vector ordered by `LABELS`.
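
A minimal sketch of the conversion, written as a standalone function that takes the label order as an argument (the notebook's `binarize_labels` stub reads the global `LABELS` instead). `to_multi_hot` is an illustrative name, and `DEMO_LABELS` is a hypothetical four-label subset; the real ten-label order comes from notebook 02. Raising on unknown labels is a design choice here: it surfaces label-name or label-order drift between notebooks instead of silently dropping a label.

```python
from typing import Iterable, List, Sequence

def to_multi_hot(labels_list: Iterable[str], label_order: Sequence[str]) -> List[float]:
    """Position i is 1.0 if label_order[i] appears in labels_list, else 0.0."""
    index = {name: i for i, name in enumerate(label_order)}
    vec = [0.0] * len(label_order)
    for name in labels_list:  # any iterable of label names works (list, NumPy array, ...)
        if name not in index:
            raise ValueError(f"Unknown label {name!r}; check the label list from notebook 02")
        vec[index[name]] = 1.0
    return vec

DEMO_LABELS = ["RCT", "Human", "Animal", "InVitro"]  # hypothetical subset, not the real order
print(to_multi_hot(["Human", "RCT"], DEMO_LABELS))
print(to_multi_hot([], DEMO_LABELS))
```

In the notebook, the same logic would be applied row-wise with `Series.apply` on the column holding each paper's label list.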


In [None]:
# === TODO (you code this) ===
# Goal: Convert label lists to multi-hot binary vectors.
# Hints:
# 1) Create zero vector of length NUM_LABELS
# 2) For each label in the paper's list, set position LABELS.index(label) to 1.0
# 3) Apply to all three DataFrames
# Acceptance:
# - Function binarize_labels(labels_list) -> list of floats
# - New column 'label_vec' in all DataFrames
# - Vector length = NUM_LABELS

def binarize_labels(labels_list):
    """Convert list of labels to multi-hot vector."""
    # TODO
    raise NotImplementedError

# TODO: create label_vec columns


In [None]:
# === TODO (you code this) ===
# Goal: Create HuggingFace Dataset objects.
# Hints:
# 1) Use Dataset.from_pandas() with just 'text' and 'label_vec' columns
# 2) Create for all three splits
# Acceptance:
# - train_dataset, val_dataset, test_dataset created
# - Each contains 'text' and 'label_vec' fields

# TODO: convert to HF Dataset objects


## Tokenizer & Encoding


In [None]:
# === TODO (you code this) ===
# Goal: Tokenize datasets and prepare for training.
# Hints:
# 1) Load tokenizer for 'distilbert-base-uncased'
# 2) Create tokenize function with max_length=512, truncation, padding
# 3) Map tokenize function to all datasets (batched=True)
# 4) Rename 'label_vec' → 'labels', set format to torch
# Acceptance:
# - tokenizer loaded
# - All datasets have input_ids, attention_mask, labels
# - Format set to 'torch'

# TODO: tokenize and format datasets


## Model Initialization


In [None]:
# === TODO (you code this) ===
# Goal: Initialize DistilBERT for multi-label classification.
# Hints:
# 1) Use AutoModelForSequenceClassification.from_pretrained
# 2) Set num_labels=NUM_LABELS, problem_type='multi_label_classification'
# Acceptance:
# - model initialized from 'distilbert-base-uncased'
# - Configured for multi-label (uses BCEWithLogits loss)

# TODO: initialize model


## Metrics Function
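
A NumPy-only version of the same six numbers shows what the metrics cell has to produce. The sketch below applies a sigmoid, thresholds at 0.5, and averages precision/recall/F1 in two ways: micro pools true and false positives across all labels (so frequent labels dominate), while macro averages the per-label scores (every label counts equally, so rare labels matter more). It is intended to match scikit-learn's `precision_recall_fscore_support` with `average='micro'`/`'macro'` and `zero_division=0`, but that equivalence is an assumption worth checking against sklearn. `multilabel_prf` and `compute_metrics_sketch` are illustrative names; the key `f1_micro` is chosen to match `metric_for_best_model`.

```python
import numpy as np

def _safe_div(num, den):
    """Elementwise num/den, returning 0.0 wherever den == 0."""
    num = np.asarray(num, dtype=np.float64)
    den = np.asarray(den, dtype=np.float64)
    out = np.zeros_like(num)
    np.divide(num, den, out=out, where=den > 0)
    return out

def multilabel_prf(logits, labels, threshold=0.5):
    """Micro- and macro-averaged precision/recall/F1 for multi-hot targets."""
    x = np.asarray(logits, dtype=np.float64)
    probs = np.exp(-np.logaddexp(0.0, -x))          # numerically stable sigmoid
    preds = (probs >= threshold).astype(np.int64)
    y = np.asarray(labels).astype(np.int64)

    # Per-label counts (one entry per column/label)
    tp = (preds * y).sum(axis=0)
    fp = (preds * (1 - y)).sum(axis=0)
    fn = ((1 - preds) * y).sum(axis=0)

    # Micro: pool counts over all labels, then compute the ratios once
    p_mi = _safe_div(tp.sum(), tp.sum() + fp.sum())
    r_mi = _safe_div(tp.sum(), tp.sum() + fn.sum())
    f_mi = _safe_div(2 * p_mi * r_mi, p_mi + r_mi)

    # Macro: per-label ratios first, then an unweighted mean over labels
    p_la = _safe_div(tp, tp + fp)
    r_la = _safe_div(tp, tp + fn)
    f_la = _safe_div(2 * p_la * r_la, p_la + r_la)

    return {
        "precision_micro": float(p_mi), "recall_micro": float(r_mi), "f1_micro": float(f_mi),
        "precision_macro": float(p_la.mean()), "recall_macro": float(r_la.mean()),
        "f1_macro": float(f_la.mean()),
    }

def compute_metrics_sketch(eval_pred):
    """Trainer-style wrapper: eval_pred unpacks to (logits, label_ids)."""
    logits, labels = eval_pred
    return multilabel_prf(logits, labels)

demo_logits = np.array([[2.0, -1.0, 0.5], [-2.0, 3.0, -0.5]])
demo_labels = np.array([[1.0, 0.0, 1.0], [0.0, 1.0, 1.0]])
print(compute_metrics_sketch((demo_logits, demo_labels)))
```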


In [None]:
# === TODO (you code this) ===
# Goal: Define metrics function for Trainer.
# Hints:
# 1) Apply sigmoid to logits, threshold at 0.5
# 2) Compute micro/macro precision, recall, F1 using sklearn
# 3) Return dict with 6 metrics
# Acceptance:
# - Function compute_metrics(eval_pred) -> dict
# - Returns precision/recall/f1 for both micro and macro

def compute_metrics(eval_pred):
    """Compute evaluation metrics for multi-label classification."""
    # TODO
    raise NotImplementedError


## Training Arguments & Trainer


In [None]:
# === TODO (you code this) ===
# Goal: Configure training parameters.
# Hints:
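# 1) Set output_dir, eval_strategy and save_strategy (same value, e.g. 'epoch'), learning_rate
#    (eval_strategy is named evaluation_strategy in older transformers releases)
# 2) per_device_train_batch_size=8, num_train_epochs=3-4, warmup_ratio=0.1
# 3) load_best_model_at_end=True, metric_for_best_model='f1_micro' (must match a key returned by compute_metrics)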
# Acceptance:
# - TrainingArguments object configured
# - Will save to ../artifacts/model/checkpoints
# - Evaluates each epoch and keeps best

# TODO: create training_args


In [None]:
# === TODO (you code this) ===
# Goal: Train model and save best checkpoint.
# Hints:
# 1) Create Trainer with model, args, datasets, compute_metrics
# 2) Call trainer.train()
# 3) Save best model and tokenizer to ../artifacts/model/best
# Acceptance:
# - Training completes successfully
# - Best model saved based on micro-F1
# - Both model and tokenizer saved

# TODO: create Trainer, train, and save

# Run training (uncomment when ready)
# trainer = ...
# trainer.train()
# Save best model...


## Recommendations

- If rare labels underperform, try **class weights** (e.g. a per-label `pos_weight` in BCE) or **focal loss** (stretch goal)
- Track **ROC/PR curves** on the validation set to tune per-label thresholds later
- Monitor for **overfitting:** validation metrics should not degrade while training loss keeps improving
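
The per-label threshold idea can be sketched as a grid search on validation-set probabilities (for example, logits from `trainer.predict(val_dataset).predictions` passed through a sigmoid). For each label it scans a fixed grid and keeps the threshold with the best validation F1; labels with no positives in the validation set keep the default 0.5. This simple grid search is chosen for illustration and is not a prescribed method; `tune_thresholds` and the 0.05 grid spacing are assumptions. Fix the thresholds on validation first and only then apply them to the test split; tuning on test would make the test scores optimistic.

```python
import numpy as np

def tune_thresholds(val_probs, val_labels, grid=None):
    """For each label independently, pick the threshold that maximizes validation F1."""
    probs = np.asarray(val_probs, dtype=np.float64)
    y = np.asarray(val_labels).astype(bool)
    if grid is None:
        grid = np.round(np.arange(0.05, 0.96, 0.05), 2)  # 0.05, 0.10, ..., 0.95
    n_labels = probs.shape[1]
    best_t = np.full(n_labels, 0.5)   # fallback when no threshold beats F1 = 0
    best_f1 = np.zeros(n_labels)
    for j in range(n_labels):
        for t in grid:
            pred = probs[:, j] >= t
            tp = np.sum(pred & y[:, j])
            fp = np.sum(pred & ~y[:, j])
            fn = np.sum(~pred & y[:, j])
            denom = 2 * tp + fp + fn
            f1 = 2 * tp / denom if denom > 0 else 0.0
            if f1 > best_f1[j]:        # strict '>' keeps the lowest threshold on ties
                best_f1[j], best_t[j] = f1, t
    return best_t, best_f1

val_probs = np.array([[0.90, 0.80],
                      [0.60, 0.60],
                      [0.20, 0.75],
                      [0.10, 0.10]])
val_labels = np.array([[1, 1],
                       [1, 0],
                       [0, 1],
                       [0, 0]])
thresholds, f1s = tune_thresholds(val_probs, val_labels)
preds = val_probs >= thresholds  # broadcasting applies one threshold per column
print(thresholds, f1s)
```

With only a handful of validation positives per label, the chosen thresholds can be noisy; a coarser grid or falling back to 0.5 for very rare labels are reasonable safeguards.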

## 🧘 Reflection Log

**What did you learn in this session?**
- 

**What challenges did you encounter?**
- 

**How will this improve Periospot AI?**
- 
