## Toxic Comment Classification using DistilBERT with Hyperparameter Tuning

This notebook trains a multi-label toxic comment classifier using the DistilBERT architecture, with Optuna for hyperparameter optimization. It uses multilabel stratified shuffle splits for validation, mixed-precision training, and gradient accumulation.

### Key Features
- Text preprocessing with regex patterns and contractions handling  
- Multilabel stratified data splitting  
- Custom classifier head with layer normalization  
- Optuna integration for hyperparameter search  
- Mixed-precision training with gradient accumulation  
- Model checkpointing and metadata saving  

**Author:** Jaejun Shim 

**Date:** 2024-12-23

**License:** MIT  

### Importing Libraries

In [None]:
import os
import re
import time
import joblib
import torch
import numpy as np
import pandas as pd
import optuna
from torch.utils.data import DataLoader, Dataset
from torch.optim import AdamW
from torch.amp import autocast
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score
from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit
import shutil
import zipfile

### Configurations

In [None]:
class Config:
    """Global configuration parameters."""
    # Seed for the torch/NumPy RNGs, the data splits, and the Optuna sampler.
    RANDOM_SEED = 42
    # Checkpoint name passed to AutoTokenizer / AutoModelForSequenceClassification.
    BASE_MODEL = "distilbert-base-uncased"
    # Run on CUDA when it is available, otherwise on CPU.
    DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Directory the Kaggle archives are extracted into and the CSVs are read from.
    DATA_PATH = "./data"
    # Target directory for save_pretrained() of the final model and tokenizer.
    MODEL_SAVE_PATH = "./saved_model"
    # joblib file holding the best hyperparameters found by the Optuna study.
    BEST_PARAMS_PATH = "./best_params.pkl"
    
    # Training Flags
    # HP_TUNING gates the Optuna search, FT_TUNING gates further_tune_model(),
    # and FULL_TRAINING gates the final training-and-save cell.
    HP_TUNING = False
    FT_TUNING = False
    FULL_TRAINING = True

    # Optimization Parameters
    # Number of batches whose gradients are accumulated per optimizer step.
    ACCUMULATION_STEPS = 4
    # Default epoch cap for run_early_stopping_training().
    MAX_EPOCHS = 10
    # Epochs without validation-F1 improvement tolerated before stopping early.
    PATIENCE = 2
# Seed the global torch and NumPy random number generators.
torch.manual_seed(Config.RANDOM_SEED)
np.random.seed(Config.RANDOM_SEED)
# NOTE(review): cuDNN autotuning may select different kernels between runs, so
# results may not be bit-for-bit reproducible despite the seeds above — confirm
# this trade-off is intended.
torch.backends.cudnn.benchmark = True

### Data Processing

In [None]:
# Downloading Files From Kaggle
!kaggle competitions download -c jigsaw-toxic-comment-classification-challenge -q
with zipfile.ZipFile("jigsaw-toxic-comment-classification-challenge.zip", "r") as zip_ref:
    zip_ref.extractall(Config.DATA_PATH)
with zipfile.ZipFile("data/train.csv.zip", "r") as zip_ref:
    zip_ref.extractall(Config.DATA_PATH)
for file in ['jigsaw-toxic-comment-classification-challenge.zip', './data/sample_submission.csv.zip', './data/test_labels.csv.zip',
             './data/test.csv.zip', './data/train.csv.zip']:
    if os.path.exists(file):
        os.remove(file)
    else:
        raise FileNotFoundError(f"File {file} not found!")

In [None]:
def clean_text(text: str) -> str:
    """
    Clean and normalize text data using regex patterns and contractions handling.

    The steps run in a fixed order: strip a set of special characters, drop
    ``WP:<letters>`` Wikipedia shortcuts, replace e-mail addresses, URLs and
    phone-number-like digit groups with placeholder tokens, collapse
    whitespace, lowercase, and finally expand contractions.

    Args:
        text: Raw input text to be cleaned.

    Returns:
        str: Cleaned and normalized text.
    """
    # `contractions` is not imported in the notebook's import cell; import it
    # here so the final call resolves instead of raising NameError.
    import contractions

    # Remove special characters
    text = re.sub('[“”¨«»®´·º½¾¿¡§£₤‘’â€¢&]', "", text)

    # Remove Wikipedia artifacts
    text = re.sub(r"WP:[A-Z]+", "", text, flags=re.IGNORECASE)

    # Replace sensitive information
    # The TLD class is `[A-Za-z]`; the previous `[A-Z|a-z]` also accepted a
    # literal "|" because "|" is not an alternation inside a character class.
    text = re.sub(r'\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b',
                  'emailaddress', text, flags=re.IGNORECASE)
    text = re.sub(r'https?://\S+', 'websiteurl', text, flags=re.IGNORECASE)
    text = re.sub(r'\(?\d{3}\)?[\s-]?\d{3}[\s-]?\d{4}', 'phonenumber', text)

    # Normalize whitespace and fix contractions
    text = re.sub(r'\s+', ' ', text).strip().lower()
    return contractions.fix(text)

In [None]:
def load_data(file_name: str, subsample: bool = False) -> tuple:
    """
    Read a CSV from ``Config.DATA_PATH`` and return cleaned texts with labels.

    The ``comment_text`` column supplies the texts; every column from the
    third one onward is taken as a label column and cast to float32.

    Args:
        file_name: Name of the CSV file containing the data.
        subsample: When True, keep only the "train" side of a single
            multilabel-stratified split with ``test_size=0.6``.

    Returns:
        tuple: (texts, labels) as numpy arrays.
    """
    csv_path = os.path.join(Config.DATA_PATH, file_name)
    frame = pd.read_csv(csv_path)
    texts = frame['comment_text']
    labels = frame.iloc[:, 2:].values.astype(np.float32)

    if subsample:
        stratifier = MultilabelStratifiedShuffleSplit(
            n_splits=1,
            test_size=0.6,
            random_state=Config.RANDOM_SEED,
        )
        # n_splits=1, so the generator yields exactly one (keep, drop) pair.
        keep_index, _ = next(stratifier.split(texts, labels))
        texts, labels = texts[keep_index], labels[keep_index]

    # Cleaning runs after subsampling so only the retained rows are processed.
    cleaned = texts.apply(clean_text)
    return cleaned.values, labels

### PyTorch Data Pipeline

In [None]:
class ToxicCommentDataset(Dataset):
    """Map-style dataset that tokenizes one comment per lookup."""

    def __init__(self, texts: np.ndarray, labels: np.ndarray,
                 tokenizer: AutoTokenizer, max_length: int):
        """
        Store the raw samples and tokenization settings.

        Args:
            texts: Array of preprocessed text samples.
            labels: Array of multi-hot encoded labels.
            tokenizer: Pretrained tokenizer instance.
            max_length: Length every sequence is padded/truncated to.
        """
        self.texts, self.labels = texts, labels
        self.tokenizer, self.max_length = tokenizer, max_length

    def __len__(self) -> int:
        """Return the number of samples."""
        return len(self.texts)

    def __getitem__(self, idx: int) -> dict:
        """Tokenize sample ``idx`` and return its tensors together with its labels."""
        raw_text = self.texts[idx]
        target = self.labels[idx]

        tokenized = self.tokenizer(
            raw_text,
            padding='max_length',
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )

        # return_tensors="pt" adds a leading batch axis; squeeze(0) removes it
        # so the DataLoader can add its own batch dimension.
        item = {
            field: tokenized[field].squeeze(0)
            for field in ('input_ids', 'attention_mask')
        }
        item['labels'] = torch.tensor(target, dtype=torch.float)
        return item

### Model Architecture

In [None]:
def initialize_model(params: dict) -> AutoModelForSequenceClassification:
    """
    Build the DistilBERT classifier with a partially frozen encoder and a custom head.

    Args:
        params: Dictionary of hyperparameters; reads ``dropout`` and
            ``hidden_units``.

    Returns:
        AutoModelForSequenceClassification: Configured model instance on
        ``Config.DEVICE``.
    """
    drop_p = params['dropout']
    model = AutoModelForSequenceClassification.from_pretrained(
        Config.BASE_MODEL,
        num_labels=6,  # Number of toxicity classes
        dropout=drop_p,
        attention_dropout=drop_p,
    )
    model = model.to(Config.DEVICE)

    # Freeze the whole encoder first, then re-enable gradients for its last
    # two transformer layers only.
    model.distilbert.requires_grad_(False)
    for block in model.distilbert.transformer.layer[-2:]:
        block.requires_grad_(True)

    # Replace the stock classifier with a small MLP head that uses layer
    # normalization.
    hidden = params['hidden_units']
    head_layers = [
        torch.nn.Dropout(drop_p),
        torch.nn.Linear(model.config.hidden_size, hidden),
        torch.nn.ReLU(),
        torch.nn.LayerNorm(hidden),
        torch.nn.Dropout(drop_p),
        torch.nn.Linear(hidden, 6),
    ]
    model.classifier = torch.nn.Sequential(*head_layers).to(Config.DEVICE)

    return model

### Training Utilities

In [None]:
def evaluate_model(model: AutoModelForSequenceClassification,
                  dataloader: DataLoader) -> float:
    """
    Score the model on a dataloader using weighted F1 at a 0.5 threshold.

    The model is switched to eval mode and left in that mode on return.

    Args:
        model: Model to evaluate.
        dataloader: Validation dataloader.

    Returns:
        float: Weighted F1 score.
    """
    model.eval()
    probabilities, targets = [], []

    with torch.no_grad():
        for batch in dataloader:
            # Everything except the labels is a model input.
            features = {
                name: tensor.to(Config.DEVICE)
                for name, tensor in batch.items()
                if name != 'labels'
            }
            batch_targets = batch['labels'].cpu().numpy()

            logits = model(**features).logits
            batch_probs = torch.sigmoid(logits).cpu().numpy()
            probabilities.extend(batch_probs)
            targets.extend(batch_targets)

    # Binarize the per-label sigmoid probabilities at 0.5.
    binary_preds = (np.array(probabilities) > 0.5).astype(int)
    return f1_score(targets, binary_preds, average='weighted')

In [None]:
def train_epoch(model: torch.nn.Module,
                loader: DataLoader,
                optimizer: torch.optim.Optimizer,
                scaler: torch.amp.GradScaler) -> float:
    """
    Perform single epoch of training with gradient accumulation

    Gradients are accumulated over ``Config.ACCUMULATION_STEPS`` batches
    before each optimizer step. The final, possibly shorter, group of batches
    also triggers a step so its gradients are neither dropped nor carried into
    the next epoch.

    Args:
        model: Model to train
        loader: Training data loader
        optimizer: Configured optimizer
        scaler: Gradient scaler for mixed precision

    Returns:
        float: Average training loss for the epoch (0.0 for an empty loader)
    """
    model.train()
    # Mixed precision is only meaningful on CUDA; on CPU autocast is disabled.
    use_amp = Config.DEVICE.type == 'cuda'
    criterion = torch.nn.BCEWithLogitsLoss()
    num_batches = len(loader)
    total_loss = 0.0

    # Start from clean gradients regardless of what the caller did before.
    optimizer.zero_grad()

    for step, batch in enumerate(loader):
        # Move data to device
        inputs = {k: v.to(Config.DEVICE, non_blocking=True)
                  for k, v in batch.items() if k != 'labels'}
        labels = batch['labels'].to(Config.DEVICE, non_blocking=True)

        # Mixed precision forward pass
        with torch.amp.autocast(device_type=Config.DEVICE.type,
                                dtype=torch.float16,
                                enabled=use_amp):
            outputs = model(**inputs)
            # Scale down so the accumulated gradient matches one larger batch.
            loss = criterion(outputs.logits, labels) / Config.ACCUMULATION_STEPS

        # Gradient accumulation
        scaler.scale(loss).backward()

        # Weight update every accumulation_steps, and once more on the last
        # batch to flush a trailing partial group.
        is_boundary = (step + 1) % Config.ACCUMULATION_STEPS == 0
        is_last_batch = step + 1 == num_batches
        if is_boundary or is_last_batch:
            # Unscale before clipping so the norm is measured on true gradients.
            scaler.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()

        # Undo the accumulation scaling so the reported loss is per batch.
        total_loss += loss.item() * Config.ACCUMULATION_STEPS

    return total_loss / num_batches if num_batches else 0.0

In [None]:
def run_early_stopping_training(model: torch.nn.Module,
                               train_loader: DataLoader,
                               val_loader: DataLoader,
                               optimizer: torch.optim.Optimizer,
                               max_epochs: int = Config.MAX_EPOCHS) -> dict:
    """
    Complete training loop with early stopping.

    Training stops once the validation F1 has failed to improve for
    ``Config.PATIENCE`` consecutive epochs, or after ``max_epochs`` epochs.

    Args:
        model: Initialized model.
        train_loader: Training data loader.
        val_loader: Validation data loader.
        optimizer: Configured optimizer.
        max_epochs: Maximum number of training epochs.

    Returns:
        dict: Training results containing:
            - best_f1: Best validation F1 score.
            - best_epoch: Epoch number (1-based) of best performance; 0 if no
              epoch scored above 0.0.
            - model_weights: Snapshot of the best model state dict, or None if
              no epoch scored above 0.0.
            - patience_counter: Epochs since the last improvement.
    """
    # Gradient scaling is only needed for float16 autocast on CUDA.
    scaler = torch.amp.GradScaler(enabled=Config.DEVICE.type == 'cuda')
    best_results = {
        'best_f1': 0.0,
        'best_epoch': 0,
        'model_weights': None,
        'patience_counter': 0
    }

    for epoch in range(max_epochs):
        # Training phase
        train_loss = train_epoch(model, train_loader, optimizer, scaler)

        # Validation phase
        current_f1 = evaluate_model(model, val_loader)
        print(f"Epoch {epoch+1}/{max_epochs} | "
              f"Train Loss: {train_loss:.4f} | "
              f"Val F1: {current_f1:.4f}")

        # Early stopping check
        if current_f1 > best_results['best_f1']:
            # state_dict() tensors alias the live parameters, and dict.copy()
            # is shallow, so clone each tensor to get a real snapshot that
            # later epochs cannot overwrite.
            snapshot = {name: tensor.detach().clone()
                        for name, tensor in model.state_dict().items()}
            best_results.update({
                'best_f1': current_f1,
                'best_epoch': epoch + 1,
                'model_weights': snapshot,
                'patience_counter': 0
            })
        else:
            best_results['patience_counter'] += 1
            if best_results['patience_counter'] >= Config.PATIENCE:
                print(f"Early stopping triggered at epoch {epoch+1}")
                break

    return best_results

### Hyperparameter Tuning

In [None]:
def objective(trial: optuna.Trial) -> float:
    """
    Optuna objective function for hyperparameter optimization

    Trains a freshly initialized model on a subsample of the training data
    with the trial's suggested hyperparameters and returns its validation F1.
    Relies on the module-level ``tokenizer``.

    Args:
        trial: Optuna trial used to sample hyperparameters.

    Returns:
        float: Weighted F1 score on the held-out 20% split.
    """
    start_time = time.time()

    # Load and split data
    # NOTE(review): the split seed varies with trial.number, so each trial is
    # scored on a different validation split — confirm this is intended.
    texts, labels = load_data("train.csv", subsample=True)
    train_texts, val_texts, train_labels, val_labels = train_test_split(
        texts, labels,
        test_size=0.2,
        random_state=Config.RANDOM_SEED + trial.number
    )

    # Suggest hyperparameters
    params = {
        'max_length': trial.suggest_categorical('max_length', [128, 256]),
        'lr': trial.suggest_float('lr', 1e-5, 5e-5, log=True),
        'batch_size': trial.suggest_categorical('batch_size', [32, 64]),
        'num_epochs': trial.suggest_int('num_epochs', 2, 3),
        'dropout': trial.suggest_float('dropout', 0.1, 0.4),
        'hidden_units': trial.suggest_categorical('hidden_units', [256, 512]),
    }

    # Create data loaders
    train_dataset = ToxicCommentDataset(train_texts, train_labels,
                                        tokenizer, params['max_length'])
    train_loader = DataLoader(train_dataset,
                              batch_size=params['batch_size'],
                              shuffle=True,
                              num_workers=4)

    # Initialize model and optimizer
    model = initialize_model(params)
    optimizer = torch.optim.AdamW(model.parameters(), lr=params['lr'])
    # Mixed precision (and therefore gradient scaling) only applies on CUDA.
    use_amp = Config.DEVICE.type == 'cuda'
    scaler = torch.amp.GradScaler(enabled=use_amp)
    criterion = torch.nn.BCEWithLogitsLoss()
    num_batches = len(train_loader)

    # Training loop
    model.train()
    for _ in range(params['num_epochs']):
        optimizer.zero_grad()

        for step, batch in enumerate(train_loader):
            inputs = {k: v.to(Config.DEVICE) for k, v in batch.items()
                      if k != 'labels'}
            # Separate name so the full label array above is not shadowed.
            batch_labels = batch['labels'].to(Config.DEVICE)

            with torch.amp.autocast(device_type=Config.DEVICE.type,
                                    dtype=torch.float16,
                                    enabled=use_amp):
                outputs = model(**inputs)
                loss = criterion(outputs.logits, batch_labels)
                loss = loss / Config.ACCUMULATION_STEPS

            scaler.scale(loss).backward()

            # Step on every accumulation boundary, and on the last batch so a
            # trailing partial group is applied instead of being discarded.
            is_boundary = (step + 1) % Config.ACCUMULATION_STEPS == 0
            if is_boundary or step + 1 == num_batches:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

    # Validation and cleanup
    val_dataset = ToxicCommentDataset(val_texts, val_labels,
                                      tokenizer, params['max_length'])
    val_loader = DataLoader(val_dataset,
                            batch_size=params['batch_size']*2,
                            shuffle=False)

    val_f1 = evaluate_model(model, val_loader)
    # Drop every object that holds GPU tensors (the optimizer state included)
    # so empty_cache() can actually release memory before the next trial.
    del model, optimizer, scaler, train_loader, val_loader
    torch.cuda.empty_cache()

    print(f"Trial {trial.number} completed in {(time.time()-start_time)/60:.1f}m")
    return val_f1

#### Further Tuning

In [None]:
def further_tune_model(best_params: dict) -> dict:
    """
    Given the best hyperparameters from the Optuna study, further tune the
    model to find the best epoch number.

    Trains on a multilabel-stratified 80/20 split of the full training data
    with early stopping. The split is seeded with ``Config.RANDOM_SEED`` so
    repeated runs use the same validation set.

    Args:
        best_params: Optimized hyperparameters from Optuna study

    Returns:
        dict: Training results containing:
            - best_f1: Best validation F1 score.
            - best_epoch: Epoch number of best performance.
            - model_weights: Best model weights.
            - patience_counter: Number of epochs without improvement.
    """
    # Data preparation
    texts, labels = load_data("train.csv")
    splitter = MultilabelStratifiedShuffleSplit(
        n_splits=1,
        test_size=0.2,
        random_state=Config.RANDOM_SEED,
    )
    # n_splits=1, so the generator yields exactly one (train, val) index pair.
    train_idx, val_idx = next(splitter.split(texts, labels))

    tokenizer = AutoTokenizer.from_pretrained(Config.BASE_MODEL)
    train_dataset = ToxicCommentDataset(texts[train_idx], labels[train_idx],
                                        tokenizer, best_params['max_length'])
    val_dataset = ToxicCommentDataset(texts[val_idx], labels[val_idx],
                                      tokenizer, best_params['max_length'])

    # Initialize training components
    model = initialize_model(best_params)
    optimizer = torch.optim.AdamW(model.parameters(), lr=best_params['lr'])
    train_loader = DataLoader(train_dataset,
                              batch_size=best_params['batch_size'],
                              shuffle=True,
                              num_workers=4)
    val_loader = DataLoader(val_dataset,
                            batch_size=best_params['batch_size']*2,
                            shuffle=False)

    # Run training with early stopping
    training_results = run_early_stopping_training(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        optimizer=optimizer
    )

    return training_results

### Main Execution

#### Hyperparameter Tuning

In [None]:
# Shared tokenizer; created unconditionally because later cells and
# objective() read this module-level name.
tokenizer = AutoTokenizer.from_pretrained(Config.BASE_MODEL)

# Optuna search: maximize the value returned by objective(). The study is
# stored in a local SQLite file and reopened if it already exists.
if Config.HP_TUNING:
    study = optuna.create_study(
        study_name="toxic_comment_study",
        storage="sqlite:///optuna.db",
        load_if_exists=True,
        direction="maximize",
        sampler=optuna.samplers.TPESampler(seed=Config.RANDOM_SEED),
    )
    study.optimize(objective, n_trials=200)
    # Persist the winning hyperparameters for the cells that follow.
    joblib.dump(study.best_params, Config.BEST_PARAMS_PATH)

#### Further Tuning

In [None]:
# Optional stage: reload the saved best hyperparameters and run early-stopping
# training on a train/validation split; the result dict (including
# 'best_epoch') is kept in the module-level name `training_results`.
if Config.FT_TUNING:
    best_params = joblib.load(Config.BEST_PARAMS_PATH)
    training_results = further_tune_model(best_params)

#### Final Model Training

In [None]:
# Final stage: train on the full training set with the saved best
# hyperparameters (no validation split), then save the model and tokenizer.
if Config.FULL_TRAINING:
    best_params = joblib.load(Config.BEST_PARAMS_PATH)
    texts, labels = load_data("train.csv")

    # Load full data
    tokenizer = AutoTokenizer.from_pretrained(Config.BASE_MODEL)
    full_dataset = ToxicCommentDataset(texts, labels, tokenizer, best_params['max_length'])

    # Initialize model and training components
    final_model = initialize_model(best_params)
    optimizer = torch.optim.AdamW(final_model.parameters(), lr=best_params['lr'])
    full_loader = DataLoader(full_dataset,
                             batch_size=best_params['batch_size'],
                             shuffle=True)
    # Mixed precision (and therefore gradient scaling) only applies on CUDA.
    use_amp = Config.DEVICE.type == "cuda"
    scaler = torch.amp.GradScaler(enabled=use_amp)
    criterion = torch.nn.BCEWithLogitsLoss()
    num_batches = len(full_loader)

    # Epoch budget: prefer the early-stopping result from the further-tuning
    # cell. That cell is skipped when Config.FT_TUNING is False, so fall back
    # to the epoch count stored with the tuned hyperparameters.
    try:
        num_epochs = training_results['best_epoch']
    except NameError:
        num_epochs = best_params.get('num_epochs', Config.MAX_EPOCHS)

    # Final training loop
    final_model.train()
    for epoch in range(num_epochs):
        optimizer.zero_grad()
        for step, batch in enumerate(full_loader):
            inputs = {k: v.to(Config.DEVICE, non_blocking=True)
                      for k, v in batch.items() if k != 'labels'}
            # Separate name so the full `labels` array is not overwritten.
            batch_labels = batch['labels'].to(Config.DEVICE, non_blocking=True)

            # autocast expects the device type as a string, not a torch.device.
            with autocast(device_type=Config.DEVICE.type,
                          dtype=torch.float16,
                          enabled=use_amp):
                outputs = final_model(**inputs)
                loss = criterion(outputs.logits, batch_labels)
                loss = loss / Config.ACCUMULATION_STEPS

            scaler.scale(loss).backward()

            # Step on every accumulation boundary, and on the last batch so a
            # trailing partial group is applied instead of being discarded.
            is_boundary = (step + 1) % Config.ACCUMULATION_STEPS == 0
            if is_boundary or step + 1 == num_batches:
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(final_model.parameters(), 1.0)
                scaler.step(optimizer)
                scaler.update()
                optimizer.zero_grad()

    # Save final model
    # NOTE(review): the classifier head was replaced in initialize_model(), so
    # a plain from_pretrained() on this directory presumably will not rebuild
    # the same head — verify the loading path and keep best_params alongside.
    final_model.save_pretrained(Config.MODEL_SAVE_PATH)
    tokenizer.save_pretrained(Config.MODEL_SAVE_PATH)

In [None]:
# Delete the extracted training CSV once training is finished; a missing file
# is reported as an error.
file = './data/train.csv'
if not os.path.exists(file):
    raise FileNotFoundError(f"File {file} not found!")
os.remove(file)