### **Multi-label Email Classification Fine-Tuning Task on BERT Variants**

**Objective:** The goal of this task is to address the challenge of disorganized email inboxes, where emails of many different categories are mixed together. Users often have to open each email individually, which makes it hard to quickly identify its type or priority. By fine-tuning BERT variants for multi-label classification, where a single email can belong to several categories at once (for example, both Business and Reminders), we aim to categorize emails automatically, improving inbox organization and user efficiency.

**Implementation Steps:**

#### **Project Setup**

In [1]:
# Import libraries
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from datasets import load_dataset
import torch
import gc
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from transformers import TrainingArguments, Trainer
from transformers import DataCollatorWithPadding
from peft import LoraConfig, get_peft_model, TaskType

from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.metrics import f1_score, precision_score, recall_score
from sklearn.metrics import accuracy_score, precision_recall_fscore_support, hamming_loss
import evaluate

#### **Load Dataset**

In [2]:
# Load the multi-label emails dataset
dataset = load_dataset("imnim/multiclass-email-classification")

In [3]:
# Check the dataset
print("Dataset: \n", dataset)
print("="*50)
print(f"Dataset Shape: Rows-{dataset['train'].num_rows}, Columns-{len(dataset['train'].column_names)}")
print("="*50)
print("Sample Data: \n", dataset['train'][0])


Dataset: 
 DatasetDict({
    train: Dataset({
        features: ['subject', 'body', 'labels'],
        num_rows: 2105
    })
})
Dataset Shape: Rows-2105, Columns-3
Sample Data: 
 {'subject': 'Meeting Reminder: Quarterly Sales Review Tomorrow', 'body': 'Dear Team, Just a friendly reminder that our Quarterly Sales Review meeting is scheduled for tomorrow at 10:00 AM in the conference room. Please make sure to bring your sales reports and any relevant updates. Coffee and pastries will be provided. Looking forward to a productive meeting. Best regards, [Your Name]', 'labels': ['Business', 'Reminders']}
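Because this is a multi-label task, the class balance is worth checking before training; the per-label results further down show that some labels (for example Job Application) are hard for the model to learn. A minimal standard-library sketch that counts how often each label occurs and how many labels each email carries; the helper name `label_statistics` is hypothetical and not part of `datasets`:

```python
from collections import Counter

def label_statistics(label_lists):
    """Count label frequencies and labels-per-email for a list of label lists."""
    # How many emails carry each label
    label_freq = Counter(label for labels in label_lists for label in labels)
    # How many emails carry 1, 2, 3, ... labels
    cardinality = Counter(len(labels) for labels in label_lists)
    return label_freq, cardinality

# Toy input in the same shape as dataset["train"]["labels"]
toy = [["Business", "Reminders"], ["Promotions"], ["Business"]]
freq, card = label_statistics(toy)
print(freq.most_common())    # [('Business', 2), ('Reminders', 1), ('Promotions', 1)]
print(sorted(card.items()))  # [(1, 2), (2, 1)]
```

On the real data, the same helper can be called as `label_statistics(dataset["train"]["labels"])`.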


#### **Data Preprocessing**

In [4]:
# Combine subject and body of each email into a single text field
def combine_text(examples):
    examples["text"] = examples["subject"] + " " + examples["body"]
    return examples


# Apply the function to the dataset
dataset = dataset.map(combine_text)
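`dataset.map(combine_text)` calls the function once per row, so `+` joins two strings. With `batched=True`, `map` instead passes a dict of column lists, and the concatenation has to happen element-wise. A sketch of a batched variant (the name `combine_text_batched` and the toy rows are illustrative), tested on a plain dict shaped like a batch:

```python
def combine_text_batched(batch):
    """Join subject and body for every email in a batch (dict of lists)."""
    batch["text"] = [
        f"{subject} {body}" for subject, body in zip(batch["subject"], batch["body"])
    ]
    return batch

# A plain dict mimics the batch format that datasets passes when batched=True
batch = {
    "subject": ["Meeting Reminder", "Weekend Sale"],
    "body": ["See you at 10 AM.", "50% off everything."],
}
print(combine_text_batched(batch)["text"])
# ['Meeting Reminder See you at 10 AM.', 'Weekend Sale 50% off everything.']
```

It would be applied with `dataset.map(combine_text_batched, batched=True)`; batching mainly reduces per-call overhead on larger datasets.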

In [6]:
# Check the updated dataset
dataset["train"][0]

{'subject': 'Meeting Reminder: Quarterly Sales Review Tomorrow',
 'body': 'Dear Team, Just a friendly reminder that our Quarterly Sales Review meeting is scheduled for tomorrow at 10:00 AM in the conference room. Please make sure to bring your sales reports and any relevant updates. Coffee and pastries will be provided. Looking forward to a productive meeting. Best regards, [Your Name]',
 'labels': ['Business', 'Reminders'],
 'text': 'Meeting Reminder: Quarterly Sales Review Tomorrow Dear Team, Just a friendly reminder that our Quarterly Sales Review meeting is scheduled for tomorrow at 10:00 AM in the conference room. Please make sure to bring your sales reports and any relevant updates. Coffee and pastries will be provided. Looking forward to a productive meeting. Best regards, [Your Name]'}

In [12]:
# Split the dataset before further processing
split_dataset = dataset["train"].train_test_split(test_size=0.2, seed=42)
train_dataset = split_dataset["train"]
test_dataset = split_dataset["test"]

In [13]:
# Encode multi-labels using MultiLabelBinarizer
mlb = MultiLabelBinarizer()
mlb.fit(train_dataset["labels"])
num_labels = len(mlb.classes_)

# Define a function to encode the labels
def encode_labels(examples):
    # Transform the entire list of labels for each sample
    encoded = mlb.transform(examples["labels"])
    # Cast to float32: BCEWithLogitsLoss, used for multi-label classification, expects float targets
    examples["labels_encoded"] = encoded.astype(np.float32).tolist()
    return examples

# Apply the function to the dataset
train_dataset = train_dataset.map(encode_labels, batched=True)
test_dataset = test_dataset.map(encode_labels, batched=True)

In [14]:
# Check the encoded dataset
print("Original labels:", train_dataset["labels"][:2])
print("Encoded labels:", train_dataset["labels_encoded"][:2])
print("Label classes:", mlb.classes_)

Original labels: [['Business', 'Reminders'], ['Promotions']]
Encoded labels: [[1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0]]
Label classes: ['Business' 'Customer Support' 'Events & Invitations' 'Finance & Bills'
 'Job Application' 'Newsletters' 'Personal' 'Promotions' 'Reminders'
 'Travel & Bookings']
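`MultiLabelBinarizer` sorts the label names it sees during `fit` and turns each label list into a multi-hot vector with a 1.0 at every position whose label is present. A dependency-free sketch of the same mapping, reproducing the encoding printed above for `['Business', 'Reminders']`; it illustrates the idea rather than scikit-learn's implementation, and unlike the real class it simply raises `KeyError` for a label it does not know:

```python
CLASSES = sorted([
    "Business", "Customer Support", "Events & Invitations", "Finance & Bills",
    "Job Application", "Newsletters", "Personal", "Promotions", "Reminders",
    "Travel & Bookings",
])
INDEX = {label: i for i, label in enumerate(CLASSES)}

def multi_hot(labels):
    """Return a float multi-hot vector for a list of label names."""
    vector = [0.0] * len(CLASSES)
    for label in labels:
        vector[INDEX[label]] = 1.0  # KeyError for labels not in CLASSES
    return vector

print(multi_hot(["Business", "Reminders"]))
# [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0]
```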


In [15]:
# Tokenization
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

# Tokenize the combined text and attach the multi-hot label vectors
def tokenize_and_encode(examples):
    # Tokenize text
    tokenized = tokenizer(
        examples["text"], 
        truncation=True, 
        max_length=256,
        padding=False,
        return_tensors=None
    )
    
    # Encode labels
    encoded_labels = mlb.transform(examples["labels"])
    tokenized["labels"] = encoded_labels.astype(np.float32).tolist()
    
    return tokenized

# Apply to both splits
train_dataset = train_dataset.map(tokenize_and_encode, batched=True)
test_dataset = test_dataset.map(tokenize_and_encode, batched=True)

# Remove unnecessary columns
train_dataset = train_dataset.remove_columns(["subject", "body", "text", "labels_encoded"])
test_dataset = test_dataset.remove_columns(["subject", "body", "text", "labels_encoded"])

print("Rebuilt datasets successfully!")
print("Train columns:", train_dataset.column_names)

Rebuilt datasets successfully!
Train columns: ['labels', 'input_ids', 'attention_mask']


In [16]:
# Verify the structure
print("Training sample:", train_dataset[0])

Training sample: {'labels': [1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0], 'input_ids': [101, 14764, 1024, 9046, 2136, 3116, 6203, 2136, 2372, 1010, 2023, 2003, 1037, 5379, 14764, 2008, 2057, 2031, 2256, 4882, 2136, 3116, 5115, 2005, 4826, 2012, 2184, 1024, 4002, 2572, 1012, 3531, 2191, 2469, 2000, 3319, 1996, 11376, 25828, 1998, 2272, 4810, 2007, 2151, 14409, 2030, 20062, 2000, 3745, 1012, 2292, 1005, 1055, 2031, 1037, 13318, 3116, 1998, 6848, 2256, 5082, 2006, 7552, 3934, 1012, 2559, 2830, 2000, 3773, 2017, 2035, 2045, 999, 2190, 12362, 1010, 1031, 2115, 2171, 1033, 102], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}


In [17]:
# Prepare the label ids
id2label = {i: label for i, label in enumerate(mlb.classes_)}
label2id = {label: i for i, label in enumerate(mlb.classes_)}

# Check the mappings
print("ID to Label:", id2label)
print("Label to ID:", label2id)

ID to Label: {0: 'Business', 1: 'Customer Support', 2: 'Events & Invitations', 3: 'Finance & Bills', 4: 'Job Application', 5: 'Newsletters', 6: 'Personal', 7: 'Promotions', 8: 'Reminders', 9: 'Travel & Bookings'}
Label to ID: {'Business': 0, 'Customer Support': 1, 'Events & Invitations': 2, 'Finance & Bills': 3, 'Job Application': 4, 'Newsletters': 5, 'Personal': 6, 'Promotions': 7, 'Reminders': 8, 'Travel & Bookings': 9}


In [18]:
# Define the data collator for dynamic padding
data_collator = DataCollatorWithPadding(
    tokenizer=tokenizer,
    padding=True,           # Enable padding
    return_tensors="pt",    # Return PyTorch tensors
)
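Because tokenization used `padding=False`, each example keeps its own length, and `DataCollatorWithPadding` pads every batch only up to its longest sequence. A simplified pure-Python sketch of that step, assuming DistilBERT's padding id of 0 (the `padding_idx=0` in the model printout) and right-side padding; the real collator also builds tensors and handles any other keys. The token ids below are toy values:

```python
def pad_batch(features, pad_id=0):
    """Right-pad input_ids and attention_mask to the longest sequence in the batch."""
    max_len = max(len(f["input_ids"]) for f in features)
    batch = {"input_ids": [], "attention_mask": [], "labels": []}
    for f in features:
        n_pad = max_len - len(f["input_ids"])
        batch["input_ids"].append(f["input_ids"] + [pad_id] * n_pad)
        # Padding positions get mask 0 so attention ignores them
        batch["attention_mask"].append(f["attention_mask"] + [0] * n_pad)
        # Multi-hot label vectors already share one length; pass them through
        batch["labels"].append(f["labels"])
    return batch

features = [
    {"input_ids": [101, 7592, 102], "attention_mask": [1, 1, 1], "labels": [1.0, 0.0]},
    {"input_ids": [101, 102], "attention_mask": [1, 1], "labels": [0.0, 1.0]},
]
padded = pad_batch(features)
print(padded["input_ids"])       # [[101, 7592, 102], [101, 102, 0]]
print(padded["attention_mask"])  # [[1, 1, 1], [1, 1, 0]]
```

Padding per batch rather than to `max_length=256` keeps short emails cheap to process.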

#### **Model Fine-Tuning**

In [19]:
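# Release cached accelerator memory (if any) and run Python garbage collection
if torch.backends.mps.is_available():
    torch.mps.empty_cache()
elif torch.cuda.is_available():
    torch.cuda.empty_cache()
gc.collect()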

322

In [20]:
# Load pre-trained DistilBERT model for sequence classification
print("Loading DistilBERT model for fine-tuning...")
model = AutoModelForSequenceClassification.from_pretrained(
    "distilbert-base-uncased",
    num_labels=num_labels,
    id2label=id2label,
    label2id=label2id,
    problem_type="multi_label_classification"
)
print(f"Model loaded with {num_labels} labels: {list(mlb.classes_)}")

Loading DistilBERT model for fine-tuning...


Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Model loaded with 10 labels: ['Business', 'Customer Support', 'Events & Invitations', 'Finance & Bills', 'Job Application', 'Newsletters', 'Personal', 'Promotions', 'Reminders', 'Travel & Bookings']
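Setting `problem_type="multi_label_classification"` makes the Transformers classification head use `BCEWithLogitsLoss`: an independent sigmoid plus binary cross-entropy for each label, averaged over all label slots in the batch. This is also why the labels were cast to float32. A pure-Python sketch of that loss for one example, using the numerically stable form max(x, 0) - x*y + log(1 + exp(-|x|)); the function name is illustrative:

```python
import math

def bce_with_logits(logits, targets):
    """Mean binary cross-entropy over labels, computed from raw logits."""
    losses = [
        max(x, 0.0) - x * y + math.log1p(math.exp(-abs(x)))
        for x, y in zip(logits, targets)
    ]
    return sum(losses) / len(losses)

# All-zero logits mean p = 0.5 for every label, so the loss is log(2) regardless of targets
print(round(bce_with_logits([0.0, 0.0, 0.0], [1.0, 0.0, 1.0]), 4))  # 0.6931
# Confident, correct logits give a small loss
print(round(bce_with_logits([4.0, -4.0], [1.0, 0.0]), 4))  # 0.0181
```

Because each label gets its own sigmoid, probabilities do not sum to 1 across labels, which is what allows several labels per email.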


In [21]:
# Configure LoRA for parameter-efficient fine-tuning
lora_config = LoraConfig(
    task_type=TaskType.SEQ_CLS,
    inference_mode=False,
    r=16,
    lora_alpha=32,
    lora_dropout=0.1,
    target_modules=["q_lin", "k_lin", "v_lin", "out_lin"],
    bias="none"
)

# Apply LoRA to the model
model = get_peft_model(model, lora_config)
print("LoRA configuration applied to the model.")
model.print_trainable_parameters()

LoRA configuration applied to the model.
trainable params: 1,188,106 || all params: 68,149,268 || trainable%: 1.7434
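The trainable-parameter count can be checked by hand. Each adapted 768x768 projection gets LoRA matrices A (768 x r) and B (r x 768), and there are four adapted projections (`q_lin`, `k_lin`, `v_lin`, `out_lin`) in each of DistilBERT's six layers. The arithmetic suggests that the printed count matches exactly when the newly initialized `pre_classifier` and `classifier` layers (named in the earlier warning) are also fully trainable, which is consistent with PEFT keeping the classification head trainable for `TaskType.SEQ_CLS`:

```python
hidden = 768    # DistilBERT hidden size
r = 16          # LoRA rank from lora_config
n_layers = 6    # DistilBERT transformer blocks
n_targets = 4   # q_lin, k_lin, v_lin, out_lin
n_labels = 10

# Each adapted Linear(768, 768) gets A: 768 x r and B: r x 768 (LoRA A/B have no bias)
lora_params = n_layers * n_targets * (hidden * r + r * hidden)

# Classification head: pre_classifier Linear(768, 768) and classifier Linear(768, 10), with biases
head_params = (hidden * hidden + hidden) + (hidden * n_labels + n_labels)

print(f"LoRA: {lora_params:,}  head: {head_params:,}  total: {lora_params + head_params:,}")
# LoRA: 589,824  head: 598,282  total: 1,188,106
```

So roughly half of the 1.19M trainable parameters sit in the classification head rather than in the LoRA adapters.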


In [22]:
# Select the training device: MPS (Apple Silicon) first, then CUDA, then CPU
if torch.backends.mps.is_available():
    device = torch.device("mps")
    print("Using MPS (Apple Silicon GPU)")
elif torch.cuda.is_available():
    device = torch.device("cuda")
    print("Using CUDA (NVIDIA GPU)")
else:
    device = torch.device("cpu")
    print("Using CPU")

# Move model to device
model.to(device)

Using MPS (Apple Silicon GPU)


PeftModelForSequenceClassification(
  (base_model): LoraModel(
    (model): DistilBertForSequenceClassification(
      (distilbert): DistilBertModel(
        (embeddings): Embeddings(
          (word_embeddings): Embedding(30522, 768, padding_idx=0)
          (position_embeddings): Embedding(512, 768)
          (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
          (dropout): Dropout(p=0.1, inplace=False)
        )
        (transformer): Transformer(
          (layer): ModuleList(
            (0-5): 6 x TransformerBlock(
              (attention): DistilBertSdpaAttention(
                (dropout): Dropout(p=0.1, inplace=False)
                (q_lin): lora.Linear(
                  (base_layer): Linear(in_features=768, out_features=768, bias=True)
                  (lora_dropout): ModuleDict(
                    (default): Dropout(p=0.1, inplace=False)
                  )
                  (lora_A): ModuleDict(
                    (default): Linear(in_features=76

In [23]:
# Define evaluation logic for multi-label classification
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    
    # Apply sigmoid to convert logits to probabilities
    sigmoid = torch.nn.Sigmoid()
    probs = sigmoid(torch.Tensor(predictions))
    
    # Convert probabilities to binary predictions (threshold = 0.5)
    y_pred = np.zeros(probs.shape)
    y_pred[np.where(probs >= 0.5)] = 1
    
    # Ensure labels are in correct format
    y_true = labels
    
    # Calculate various multi-label metrics
    
    # Exact match: all labels must be predicted correctly
    exact_match = accuracy_score(y_true, y_pred)
    
    # Hamming loss: fraction of incorrectly predicted labels
    hamming = hamming_loss(y_true, y_pred)
    
    # Micro-averaged metrics (global across all labels)
    precision_micro, recall_micro, f1_micro, _ = precision_recall_fscore_support(
        y_true, y_pred, average='micro', zero_division=0
    )
    
    # Macro-averaged metrics (average per label)
    precision_macro, recall_macro, f1_macro, _ = precision_recall_fscore_support(
        y_true, y_pred, average='macro', zero_division=0
    )
    
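    # Note: .all(axis=1) makes this subset accuracy, identical to exact_match above.
    # Element-wise label accuracy would be np.mean(y_true == y_pred), i.e. 1 - hamming_loss.
    label_accuracy = np.mean((y_true == y_pred).all(axis=1))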
    
    return {
        'exact_match_accuracy': exact_match,
        'label_accuracy': label_accuracy,
        'hamming_loss': hamming,
        'f1_micro': f1_micro,
        'f1_macro': f1_macro,
        'precision_micro': precision_micro,
        'precision_macro': precision_macro,
        'recall_micro': recall_micro,
        'recall_macro': recall_macro,
    }
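Exact-match (subset) accuracy and Hamming loss measure different things: an email only counts toward exact match if all ten label decisions are right, whereas Hamming loss averages over individual label decisions. A pure-Python sketch on a hypothetical toy batch computing both, plus element-wise label accuracy, which for 0/1 indicator rows always equals 1 - Hamming loss:

```python
def multilabel_scores(y_true, y_pred):
    """Subset accuracy, element-wise label accuracy and Hamming loss for 0/1 rows."""
    n_rows, n_labels = len(y_true), len(y_true[0])
    total = n_rows * n_labels
    exact = sum(t == p for t, p in zip(y_true, y_pred)) / n_rows
    matches = sum(a == b for t, p in zip(y_true, y_pred) for a, b in zip(t, p))
    return {
        "exact_match": exact,
        "label_accuracy": matches / total,
        "hamming_loss": (total - matches) / total,
    }

toy_true = [[1, 0, 0, 1], [0, 0, 1, 0]]
toy_pred = [[1, 0, 0, 0], [0, 0, 1, 0]]  # one missed label in the first row
print(multilabel_scores(toy_true, toy_pred))
# {'exact_match': 0.5, 'label_accuracy': 0.875, 'hamming_loss': 0.125}
```

A single missed label halves exact match here while costing only 12.5% in Hamming loss, which is why the two numbers in the results below differ so much.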

In [24]:
# Setup the training arguments
training_args = TrainingArguments(
    # Output and logging
    output_dir="multi-label-email-classification-finetuning/distilbert-lora-multi-label",
    logging_dir="multi-label-email-classification-finetuning/logs",
    logging_steps=50,
    
    # Training parameters
    num_train_epochs=3,
    per_device_train_batch_size=4,      # Small batch size for memory efficiency
    per_device_eval_batch_size=8,       # Can be larger for evaluation
    gradient_accumulation_steps=4,       # Simulates batch size of 16 (4*4)
    
    # Learning rate and optimization
    learning_rate=5e-4,                 # Higher learning rate for LoRA
    weight_decay=0.01,
    warmup_steps=100,
    
    # Evaluation and checkpointing
    eval_strategy="steps",              # Named evaluation_strategy in older transformers releases
    eval_steps=200,
    save_strategy="steps",
    save_steps=200,
    load_best_model_at_end=True,
    metric_for_best_model="eval_f1_micro",  # Metric key as returned by evaluate()
    greater_is_better=True,
    
    # Memory optimization (especially for Apple Silicon)
    dataloader_pin_memory=False,        # Disable for MPS
    dataloader_num_workers=0,           # Best for MPS
    gradient_checkpointing=True,        # Trade speed for memory
    
    # Reproducibility
    seed=42,
    
    # Reporting
    report_to="none",  # Disable wandb/tensorboard and other logging integrations
    
    # Keep at most 2 checkpoints on disk
    save_total_limit=2,
)

print("Training arguments configured")

Training arguments configured
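It helps to work out the optimizer-step budget these arguments imply, using the 1,684 training examples reported below. With a per-device batch of 4 and 4 accumulation steps, one optimizer step consumes 16 examples. This sketch assumes a single device and that an incomplete accumulation window at the end of an epoch is dropped; depending on the Transformers version the total may be a few steps higher:

```python
import math

train_examples = 1684
per_device_batch = 4
grad_accum = 4
epochs = 3

batches_per_epoch = math.ceil(train_examples / per_device_batch)  # 421 batches
steps_per_epoch = batches_per_epoch // grad_accum                 # 105 optimizer steps
total_steps = steps_per_epoch * epochs

warmup_fraction = 100 / total_steps
eval_points = [s for s in range(1, total_steps + 1) if s % 200 == 0]
print(f"total optimizer steps: {total_steps}")        # 315
print(f"warmup fraction: {warmup_fraction:.0%}")      # 32%
print(f"evaluation/checkpoint steps: {eval_points}")  # [200]
```

With only one evaluation at step 200, `load_best_model_at_end=True` can only select that checkpoint, which is consistent with the final evaluation below reproducing the step-200 metrics. Smaller `eval_steps`/`save_steps`, or per-epoch evaluation, would let later checkpoints compete.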


In [25]:
# Create the Trainer with all components
trainer = Trainer(
    model=model,                        # LoRA-adapted model
    args=training_args,                 # Training configuration
    train_dataset=train_dataset,        # Training data
    eval_dataset=test_dataset,          # Evaluation data
    processing_class=tokenizer,         # Tokenizer for text processing
    data_collator=data_collator,        # Dynamic padding
    compute_metrics=compute_metrics,    # Evaluation metrics
)

print("Trainer initialized successfully")
print(f"Training samples: {len(train_dataset)}")
print(f"Evaluation samples: {len(test_dataset)}")

Trainer initialized successfully
Training samples: 1684
Evaluation samples: 421


In [26]:
# Display model info before training
print("\n" + "="*50)
print("MODEL INFORMATION")
print("="*50)
print(f"Total parameters: {model.num_parameters():,}")
model.print_trainable_parameters()


MODEL INFORMATION
Total parameters: 68,149,268
trainable params: 1,188,106 || all params: 68,149,268 || trainable%: 1.7434


In [27]:
# Start training
print("\n" + "="*50)
print("STARTING TRAINING")
print("="*50)

try:
    # Train the model
    training_results = trainer.train()
    
    print("Training completed successfully!")
    # training_loss is the mean loss over all training steps, not the last logged value
    print(f"Final training loss: {training_results.training_loss:.4f}")
    
except Exception as e:
    print(f"Training failed with error: {e}")
    # If memory error, suggest reducing batch size
    if "memory" in str(e).lower():
        print("Suggestion: Reduce per_device_train_batch_size to 2 or 1")


STARTING TRAINING




| Step | Training Loss | Validation Loss | Exact Match Accuracy | Label Accuracy | Hamming Loss | F1 Micro | F1 Macro | Precision Micro | Precision Macro | Recall Micro | Recall Macro | Runtime | Samples Per Second | Steps Per Second |
|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:|
| 200 | 0.2323 | 0.236745 | 0.437055 | 0.437055 | 0.095249 | 0.663308 | 0.481088 | 0.768482 | 0.647561 | 0.583456 | 0.47047 | 2.671 | 157.62 | 19.843 |




Training completed successfully!
Final training loss: 0.2934


In [28]:
# Evaluate on the test set
print("\n" + "="*50)
print("FINAL EVALUATION")
print("="*50)

eval_results = trainer.evaluate()

# Display results in a readable format
print("Multi-Label Classification Results:")
print("-" * 40)
for metric, value in eval_results.items():
    if metric.startswith('eval_'):
        metric_name = metric.replace('eval_', '').replace('_', ' ').title()
        print(f"{metric_name:<25}: {value:.4f}")

# Analyze per-label performance
print("\n" + "="*30)
print("PER-LABEL ANALYSIS")
print("="*30)

# Get predictions for detailed analysis
predictions = trainer.predict(test_dataset)
y_pred_probs = torch.sigmoid(torch.tensor(predictions.predictions))
y_pred = (y_pred_probs >= 0.5).int().numpy()
y_true = predictions.label_ids

# Calculate per-label metrics
for i, label_name in enumerate(mlb.classes_):
    label_true = y_true[:, i]
    label_pred = y_pred[:, i]
    
    # Skip if no positive samples
    if label_true.sum() == 0:
        continue
        
    # One call returns precision, recall and F1; zero_division=0 avoids
    # UndefinedMetricWarning when a label is never predicted
    precision, recall, f1, _ = precision_recall_fscore_support(
        label_true, label_pred, average='binary', zero_division=0
    )
    support = label_true.sum()
    
    print(f"{label_name:<15}: P={precision:.3f}, R={recall:.3f}, F1={f1:.3f}, Support={support}")


FINAL EVALUATION


Multi-Label Classification Results:
----------------------------------------
Loss                     : 0.2367
Exact Match Accuracy     : 0.4371
Label Accuracy           : 0.4371
Hamming Loss             : 0.0952
F1 Micro                 : 0.6633
F1 Macro                 : 0.4811
Precision Micro          : 0.7685
Precision Macro          : 0.6476
Recall Micro             : 0.5835
Recall Macro             : 0.4705
Runtime                  : 2.3773
Samples Per Second       : 177.0940
Steps Per Second         : 22.2940

PER-LABEL ANALYSIS
Business       : P=0.777, R=0.812, F1=0.794, Support=176.0
Customer Support: P=1.000, R=0.024, F1=0.047, Support=42.0
Events & Invitations: P=0.789, R=0.563, F1=0.657, Support=126.0
Finance & Bills: P=0.887, R=0.829, F1=0.857, Support=76.0
Job Application: P=0.000, R=0.000, F1=0.000, Support=30.0
Newsletters    : P=0.667, R=0.067, F1=0.121, Support=30.0
Personal       : P=0.000, R=0.000, F1=0.000, Support=59.0
Promotions     : P=0.889, R=0.696, F1=0.780,

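The per-label lines are easier to interpret when translated back into counts. For Customer Support (support 42), P=1.000 with R=0.024 fits exactly one correct positive prediction and no false positives (recall = 1/42). Newsletters (support 30) fits 2 true positives and 1 false positive. A small check of those counts, inferred from the rounded scores, using the standard binary precision/recall/F1 formulas:

```python
def prf(tp, fp, fn):
    """Binary precision, recall and F1 from confusion counts (0.0 when undefined)."""
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return round(precision, 3), round(recall, 3), round(f1, 3)

# Counts inferred from the rounded per-label scores
print(prf(tp=1, fp=0, fn=41))  # Customer Support -> (1.0, 0.024, 0.047)
print(prf(tp=2, fp=1, fn=28))  # Newsletters      -> (0.667, 0.067, 0.121)
```

Labels such as Job Application and Personal, with no true positives, score 0 on all three metrics; these weak labels are what pull F1 macro (0.4811) well below F1 micro (0.6633).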


In [29]:
# # Save the fine-tuned model (uncomment to enable)
# print("\n" + "="*30)
# print("SAVING MODEL")
# print("="*30)

# save_path = "./final-distilbert-lora-multi-label"
# trainer.save_model(save_path)
# tokenizer.save_pretrained(save_path)

# print(f"Model saved to: {save_path}")

# Create a prediction function for new email texts
def predict_email_labels(text, threshold=0.5, return_probabilities=False):
    """
    Predict labels for a new email text
    
    Args:
        text (str): Email text to classify
        threshold (float): Probability threshold for prediction
        return_probabilities (bool): Whether to return probabilities
    
    Returns:
        list: Predicted labels if return_probabilities=False
        tuple: (predicted labels, dict of label probabilities) if return_probabilities=True
    """
    # Tokenize input text
    inputs = tokenizer(
        text, 
        return_tensors="pt", 
        truncation=True, 
        max_length=256,
        padding=True
    ).to(device)
    
    # Get model predictions
    model.eval()
    with torch.no_grad():
        outputs = model(**inputs)
        # Apply sigmoid to get probabilities
        probabilities = torch.sigmoid(outputs.logits)[0]
    
    # Convert to numpy for easier handling
    probs = probabilities.cpu().numpy()
    
    # Get predicted labels based on threshold
    predicted_labels = []
    label_probs = {}
    
    for i, prob in enumerate(probs):
        label_name = mlb.classes_[i]
        label_probs[label_name] = float(prob)
        
        if prob >= threshold:
            predicted_labels.append(label_name)
    
    if return_probabilities:
        return predicted_labels, label_probs
    else:
        return predicted_labels

# Test the prediction function
print("\n" + "="*30)
print("TESTING PREDICTIONS")
print("="*30)

# Test with sample emails
test_emails = [
    "Reminder: Team meeting tomorrow at 10 AM in conference room. Please bring your reports.",
    "50% off all items this weekend! Don't miss this amazing deal. Shop now!",
    "Your account statement for this month is ready. Please review the attached document.",
    "Breaking: Major tech company announces new product launch. Stock prices surge."
]

for i, email in enumerate(test_emails, 1):
    print(f"\nTest Email {i}:")
    print(f"Text: {email[:60]}...")
    
    labels, probs = predict_email_labels(email, return_probabilities=True)
    
    print(f"Predicted Labels: {labels}")
    print("Top Probabilities:")
    sorted_probs = sorted(probs.items(), key=lambda x: x[1], reverse=True)[:3]
    for label, prob in sorted_probs:
        print(f"  {label}: {prob:.3f}")


TESTING PREDICTIONS

Test Email 1:
Text: Reminder: Team meeting tomorrow at 10 AM in conference room....
Predicted Labels: ['Business', 'Reminders']
Top Probabilities:
  Business: 0.792
  Reminders: 0.776
  Events & Invitations: 0.256

Test Email 2:
Text: 50% off all items this weekend! Don't miss this amazing deal...
Predicted Labels: ['Promotions']
Top Probabilities:
  Promotions: 0.553
  Newsletters: 0.396
  Events & Invitations: 0.128

Test Email 3:
Text: Your account statement for this month is ready. Please revie...
Predicted Labels: []
Top Probabilities:
  Travel & Bookings: 0.386
  Finance & Bills: 0.299
  Customer Support: 0.188

Test Email 4:
Text: Breaking: Major tech company announces new product launch. S...
Predicted Labels: ['Business', 'Events & Invitations']
Top Probabilities:
  Business: 0.750
  Events & Invitations: 0.627
  Newsletters: 0.447
