# Multilabel Text Classification with BERT

This notebook demonstrates fine-tuning a BERT transformer model for multilabel classification.

Dataset: Jigsaw Toxic Comments, where each comment is annotated with multiple toxicity-related categories.

The model is fine-tuned with:

- Pretrained BERT base uncased
- A classification head with sigmoid output
- Binary cross-entropy loss for multilabel learning
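As a minimal sketch of why a sigmoid output suits multilabel learning, the snippet below compares sigmoid and softmax on one set of made-up logits (the values are illustrative and do not come from the model). Sigmoid scores each label independently, so several labels can pass the threshold at once, while softmax makes the labels compete.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def softmax(x):
    shifted = x - np.max(x)  # subtract the max for numerical stability
    e = np.exp(shifted)
    return e / e.sum()

# Hypothetical logits for one comment over three labels.
logits = np.array([2.0, 1.5, -3.0])

multilabel_probs = sigmoid(logits)   # each label scored independently
multiclass_probs = softmax(logits)   # labels compete; probabilities sum to 1

# With a 0.5 threshold, sigmoid can switch on several labels at once.
predicted = (multilabel_probs >= 0.5).astype(int)
print(predicted)  # [1 1 0]
```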

In [2]:
import warnings
warnings.filterwarnings('ignore')

In [None]:
!pip install datasets transformers torch scikit-learn --quiet

## 1. Dataset Loading & Exploration

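We load the Jigsaw Toxic Comments dataset. The commented-out cells below show the original preprocessing; the run itself uses a preprocessed copy from the Hugging Face Hub.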


In [34]:
from datasets import load_dataset
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score
from transformers import BertModel, BertTokenizerFast, Trainer, TrainingArguments
from transformers.modeling_outputs import SequenceClassifierOutput
import torch
import torch.nn as nn

In [41]:
# # Use the "civil_comments" multilabel toxicity dataset (smaller, public)
# dataset = load_dataset("civil_comments")

# print(dataset['train'].column_names)
# # ['text', 'toxicity', 'severe_toxicity', 'obscene', 'threat', 'insult', 'identity_attack', 'sexual_explicit']

# # Prepare labels
# label_cols = ['toxicity', 'severe_toxicity', 'obscene', 'threat', 'insult', 'identity_attack', 'sexual_explicit']

['text', 'toxicity', 'severe_toxicity', 'obscene', 'threat', 'insult', 'identity_attack', 'sexual_explicit']


In [None]:
# def process_labels_batch(batch):
#     label_cols_for_conversion = [col for col in label_cols if col != 'text']

#     batch['labels'] = [
#         [float(batch[col][i]) for col in label_cols_for_conversion]
#         for i in range(len(batch[label_cols_for_conversion[0]]))
#     ]
#     return batch

# # Apply transformation
# dataset = dataset.map(process_labels_batch, batched=True)

## 2. Text Tokenization & Encoding

Tokenization is performed using Hugging Face's `BertTokenizerFast`.

Input sequences are padded or truncated to a fixed maximum length of 128 tokens.

Attention masks mark real tokens with 1 and padding tokens with 0, so the model can ignore the padding.
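The padding and masking step can be sketched without the tokenizer itself. The helper below, `pad_and_mask`, is hypothetical and deliberately simplified: it truncates by cutting the tail, whereas a real BERT tokenizer reserves room for the special tokens when truncating. The token ids are illustrative.

```python
def pad_and_mask(token_ids, max_length, pad_id=0):
    """Truncate or right-pad one sequence and build its attention mask."""
    kept = token_ids[:max_length]
    n_pad = max_length - len(kept)
    input_ids = kept + [pad_id] * n_pad
    # 1 marks a real token, 0 marks padding.
    attention_mask = [1] * len(kept) + [0] * n_pad
    return input_ids, attention_mask

# Illustrative ids for a short, already-tokenized sentence.
short_ids = [101, 7592, 2088, 102]

ids, mask = pad_and_mask(short_ids, max_length=6)
print(ids)   # [101, 7592, 2088, 102, 0, 0]
print(mask)  # [1, 1, 1, 1, 0, 0]

# A sequence longer than max_length is cut, and its mask is all ones.
long_ids, long_mask = pad_and_mask(list(range(1, 11)), max_length=6)
print(long_ids)   # [1, 2, 3, 4, 5, 6]
print(long_mask)  # [1, 1, 1, 1, 1, 1]
```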

In [None]:
# # Load tokenizer
# tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')

# def tokenize(batch):
#     return tokenizer(batch['text'], padding='max_length', truncation=True, max_length=128)

# dataset = dataset.map(tokenize, batched=True)

# # Set format for PyTorch
# dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])

In [None]:
# Load preprocessed dataset from Hugging Face Hub
dataset = load_dataset("Koushim/processed-jigsaw-toxic-comments")

# Define the label columns
label_cols = [
    'toxicity', 'severe_toxicity', 'obscene',
    'threat', 'insult', 'identity_attack', 'sexual_explicit'
]

# No further preprocessing needed — already tokenized and formatted
dataset.set_format(type='torch', columns=['input_ids', 'attention_mask', 'labels'])


In [36]:
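# Compute pos_weight: the negative-to-positive ratio for each label
# Note: label_cols is currently unused; the labels come from the 'labels' column
def compute_pos_weights(dataset, label_cols):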
    labels = np.array(dataset['train']['labels'])  # Shape: (N, 7)
    num_samples = labels.shape[0]
    pos_counts = np.sum(labels, axis=0)
    neg_counts = num_samples - pos_counts
    pos_weights = neg_counts / (pos_counts + 1e-5)  # Avoid divide-by-zero
    return torch.tensor(pos_weights, dtype=torch.float)

pos_weight = compute_pos_weights(dataset, label_cols)
pos_weight

tensor([  8.6964, 217.1675,  71.0950, 106.3975,  11.3190,  43.1823, 150.3692])
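Each weight is the number of negatives divided by the number of positives for that label, so a value near 217 means roughly 217 negative examples per positive one. The same arithmetic on a tiny, made-up label matrix (not taken from the dataset) looks like this:

```python
import numpy as np

# Hypothetical label matrix: 6 comments, 2 labels.
labels = np.array([
    [1, 0],
    [0, 0],
    [1, 0],
    [0, 1],
    [0, 0],
    [0, 0],
], dtype=float)

pos_counts = labels.sum(axis=0)              # positives per label: 2 and 1
neg_counts = labels.shape[0] - pos_counts    # negatives per label: 4 and 5
pos_weights = neg_counts / (pos_counts + 1e-5)

print(np.round(pos_weights, 2))  # [2. 5.]
```

One caveat, stated as an assumption: if the stored labels are fractional scores rather than 0/1 values (the 0.5 threshold applied to labels during evaluation suggests they may be), the column sums measure total score mass rather than a count of positives. Binarizing the labels before summing would give counts instead.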

## 3. Model Architecture & Fine-Tuning

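We load pretrained BERT and add a linear classification head on top of the pooled output, with dropout in between. The head returns raw logits; the sigmoid is applied inside the loss during training and explicitly when computing metrics.

The model is fine-tuned end-to-end using `BCEWithLogitsLoss`, weighted per label by `pos_weight`.

Training uses the `Trainer` defaults: the AdamW optimizer with a linearly decaying learning rate.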
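To make the role of `pos_weight` concrete, here is a NumPy sketch of the quantity `BCEWithLogitsLoss` computes with its default mean reduction. The function `bce_with_logits` is a hypothetical re-implementation for illustration, not the PyTorch code, and the inputs are made up.

```python
import numpy as np

def bce_with_logits(logits, targets, pos_weight):
    """Mean of -[w * y * log(sigmoid(x)) + (1 - y) * log(1 - sigmoid(x))]."""
    # logaddexp(0, z) = log(1 + exp(z)), computed without overflow.
    neg_log_sigmoid = np.logaddexp(0.0, -logits)      # -log(sigmoid(x))
    neg_log_one_minus = np.logaddexp(0.0, logits)     # -log(1 - sigmoid(x))
    per_element = (pos_weight * targets * neg_log_sigmoid
                   + (1.0 - targets) * neg_log_one_minus)
    return per_element.mean()

# One comment, two labels: the model is confidently negative on both,
# but the first label is actually positive.
logits = np.array([[-2.0, -2.0]])
targets = np.array([[1.0, 0.0]])

unweighted = bce_with_logits(logits, targets, np.array([1.0, 1.0]))
weighted = bce_with_logits(logits, targets, np.array([10.0, 10.0]))
print(unweighted, weighted)
```

The weight only scales the term for positive targets, so a missed positive costs more while the penalty for false positives is unchanged. With very large weights this trades precision for recall.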

In [37]:
class CustomBertForMultiLabel(nn.Module):
    def __init__(self, base_model_name, num_labels, pos_weight):
        super().__init__()
        self.bert = BertModel.from_pretrained(base_model_name)
        self.dropout = nn.Dropout(0.3)
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_labels)
        self.loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weight)

    def forward(self, input_ids, attention_mask, labels=None):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = self.dropout(outputs.pooler_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            # BCEWithLogitsLoss expects float targets
            loss = self.loss_fn(logits, labels.float())

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
        )

# Initialize model
model = CustomBertForMultiLabel('bert-base-uncased', num_labels=len(label_cols), pos_weight=pos_weight)

## 4. Training & Evaluation

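The model is trained on a 10,000-example subset of the Jigsaw toxic comments data and evaluated on 2,000 validation examples, using these multilabel metrics at a 0.5 threshold:

- Accuracy (exact match across all labels)
- F1-score, precision, and recall (macro-averaged over labels)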
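Two details of these metrics are easy to miss. On multilabel input, scikit-learn's `accuracy_score` is subset accuracy: a sample counts as correct only if every label matches. Macro averaging gives each label equal weight regardless of how rare it is. The sketch below reproduces both in plain NumPy on made-up predictions; `multilabel_metrics` is a hypothetical helper, not part of the notebook.

```python
import numpy as np

def multilabel_metrics(y_true, y_pred):
    """Subset accuracy plus macro-averaged precision, recall and F1."""
    tp = np.sum((y_true == 1) & (y_pred == 1), axis=0).astype(float)
    fp = np.sum((y_true == 0) & (y_pred == 1), axis=0).astype(float)
    fn = np.sum((y_true == 1) & (y_pred == 0), axis=0).astype(float)

    def safe_div(num, den):
        # Return 0 where the denominator is 0, mirroring zero_division=0.
        return np.divide(num, den, out=np.zeros_like(num), where=den > 0)

    precision = safe_div(tp, tp + fp)
    recall = safe_div(tp, tp + fn)
    f1 = safe_div(2 * tp, 2 * tp + fp + fn)
    return {
        "accuracy": float(np.mean(np.all(y_true == y_pred, axis=1))),
        "precision": float(precision.mean()),
        "recall": float(recall.mean()),
        "f1": float(f1.mean()),
    }

# Hypothetical data: 4 comments, 3 labels; the third label never occurs.
y_true = np.array([[1, 0, 0], [0, 0, 0], [1, 1, 0], [0, 0, 0]])
y_pred = np.array([[1, 1, 0], [0, 0, 0], [1, 1, 0], [1, 0, 0]])

metrics = multilabel_metrics(y_true, y_pred)
print(metrics)
```

Here half the rows match exactly, so accuracy is 0.5, while the label with no positives contributes zeros to every macro average. This is how a reasonable-looking accuracy can sit next to a low macro F1.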

In [38]:
def compute_metrics(pred):
    preds = pred.predictions
    labels = pred.label_ids

    sigmoid_preds = 1 / (1 + np.exp(-preds))
    pred_labels = (sigmoid_preds >= 0.5).astype(int)
    labels = (labels >= 0.5).astype(int)

    precision = precision_score(labels, pred_labels, average='macro', zero_division=0)
    recall = recall_score(labels, pred_labels, average='macro', zero_division=0)
    f1 = f1_score(labels, pred_labels, average='macro', zero_division=0)
    acc = accuracy_score(labels, pred_labels)

    return {
        'accuracy': acc,
        'f1': f1,
        'precision': precision,
        'recall': recall
    }

In [41]:
training_args = TrainingArguments(
    output_dir='./results',
    num_train_epochs=2,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=64,
    eval_strategy="epoch",
    save_strategy="epoch",
    logging_dir='./logs',
    logging_steps=1000,  # exceeds the 626 total steps of this run, so Training Loss shows "No log"
    report_to="none",
    disable_tqdm=False,
    load_best_model_at_end=True,
    metric_for_best_model="f1",
    greater_is_better=True,
)

# Trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset['train'].shuffle(seed=42).select(range(10000)),  # limit for speed
    eval_dataset=dataset['validation'].select(range(2000)),
    compute_metrics=compute_metrics,
)

In [42]:
trainer.train()

| Epoch | Training Loss | Validation Loss | Accuracy | F1 | Precision | Recall |
|---|---|---|---|---|---|---|
| 1 | No log | 0.988396 | 0.683 | 0.141831 | 0.08775 | 0.795741 |
| 2 | No log | 1.025657 | 0.6855 | 0.137669 | 0.082662 | 0.832361 |


TrainOutput(global_step=626, training_loss=0.9331124887679713, metrics={'train_runtime': 322.6315, 'train_samples_per_second': 61.99, 'train_steps_per_second': 1.94, 'total_flos': 0.0, 'train_loss': 0.9331124887679713, 'epoch': 2.0})
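The reported `global_step=626` and the `No log` entries can be checked with simple arithmetic. The sketch assumes the run used two devices, which the notebook does not state; it is inferred from the step count, since a single device with batch size 16 would give 1,250 steps.

```python
import math

train_samples = 10_000   # from .select(range(10000))
per_device_batch = 16    # per_device_train_batch_size
epochs = 2               # num_train_epochs
num_devices = 2          # assumption: not stated in the notebook

steps_per_epoch = math.ceil(train_samples / (per_device_batch * num_devices))
total_steps = steps_per_epoch * epochs
print(steps_per_epoch, total_steps)  # 313 626

# The training loss is logged every `logging_steps` optimizer steps.
logging_steps = 1000
print(total_steps >= logging_steps)  # False
```

Because training ends before step 1,000, no training loss is ever logged. A smaller `logging_steps` value, such as 100, would populate that column.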

In [43]:
# Evaluate
results = trainer.evaluate()
print(results)

{'eval_loss': 0.9883957505226135, 'eval_accuracy': 0.683, 'eval_f1': 0.14183121169917884, 'eval_precision': 0.08774965760469693, 'eval_recall': 0.7957405184295939, 'eval_runtime': 8.7431, 'eval_samples_per_second': 228.753, 'eval_steps_per_second': 1.83, 'epoch': 2.0}


## 5. Conclusion

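Transformer-based models like BERT are a strong starting point for multilabel classification, but this short run (10,000 training examples, 2 epochs) does not yet show it: the best checkpoint reaches a macro F1 of 0.14, with high recall (0.80) and low precision (0.09). That pattern is consistent with the large `pos_weight` values pushing the model toward positive predictions.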

Future improvements could include:

- Larger models (RoBERTa, DeBERTa)
- Data augmentation
- Ensemble techniques