In [None]:
# Import Libraries

'''
Load all the packages we need for BERT training, data handling, and evaluation.
'''

import re
import numpy as np
import pandas as pd 
import os
import torch.nn as nn
import torch
import transformers

from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from sklearn import model_selection
from sklearn import metrics
from torch.optim import AdamW
from transformers import BertTokenizer, BertModel, get_linear_schedule_with_warmup


In [None]:
# Configuration Settings

'''
Define all our hyperparameters: sequence length, batch sizes, number of epochs, and learning rate.
'''

# --- Pretrained model / checkpoint locations ---
BERT_PATH = "bert-base-uncased"    # Identifier handed to from_pretrained() for tokenizer and model
MODEL_PATH = "bert_model.bin"      # File the best checkpoint is written to

# --- Tokenization ---
MAX_LEN = 128                      # Every comment is padded/truncated to this many tokens

# --- Optimization ---
EPOCHS = 3                         # Full passes over the training split
LEARNING_RATE = 3e-5               # Initial AdamW learning rate (decayed by the scheduler)
TRAIN_BATCH_SIZE = 16              # Samples per training step
VALID_BATCH_SIZE = 8               # Samples per evaluation step

In [None]:
# Dataset and Model Classes

'''
BERTDataset handles tokenization and data loading. 
BERTBaseUncased is our BERT model with a classification head.
'''

class BERTDataset(Dataset):
    """Map-style dataset that tokenizes one comment per access for BERT.

    Args:
        comment_text: Indexable sequence of raw comments. Each item is passed
            through ``str()`` before tokenization.
        target: Indexable sequence of labels aligned with ``comment_text``.
        tokenizer: Optional tokenizer to reuse (e.g. shared between the train
            and validation datasets). When omitted, a ``BertTokenizer`` is
            loaded from ``BERT_PATH``, exactly as before.
        max_len: Optional fixed sequence length. Defaults to ``MAX_LEN``.
    """

    def __init__(self, comment_text, target, tokenizer=None, max_len=None):
        self.comment_text = comment_text
        self.target = target

        # Only hit the tokenizer files when the caller did not supply one.
        if tokenizer is None:
            tokenizer = BertTokenizer.from_pretrained(BERT_PATH, do_lower_case=True)
        self.tokenizer = tokenizer

        self.max_len = MAX_LEN if max_len is None else max_len

    def __len__(self):
        """Return the number of samples."""
        return len(self.comment_text)

    def __getitem__(self, item):
        """Tokenize sample ``item`` and return its tensors.

        Returns:
            dict with keys ``'ids'`` (token ids), ``'mask'`` (attention mask)
            and ``'targets'`` (label as a float tensor). ``'ids'`` and
            ``'mask'`` are flattened to 1-D.
        """
        comment_text = str(self.comment_text[item])                                     # Convert to string
        comment_text = " ".join(comment_text.split())                                   # Collapse runs of whitespace

        # Calling the tokenizer directly replaces the deprecated encode_plus();
        # the keyword arguments are the same ones the old call used.
        encoded = self.tokenizer(
            comment_text,
            add_special_tokens=True,                                                    # Add CLS and SEP tokens
            max_length=self.max_len,                                                    # Pad or truncate to this length
            padding='max_length',                                                       # Pad shorter sequences
            truncation=True,                                                            # Cut off longer sequences
            return_attention_mask=True,                                                 # Mark real tokens vs padding
            return_tensors='pt',                                                        # Return PyTorch tensors
        )

        return {
            # flatten() drops the leading batch axis added by return_tensors='pt'.
            'ids': encoded['input_ids'].flatten(),
            'mask': encoded['attention_mask'].flatten(),
            'targets': torch.tensor(self.target[item], dtype=torch.float),
        }

class BERTBaseUncased(nn.Module):
    """BERT encoder followed by dropout and a single-logit linear head.

    Args:
        dropout: Dropout probability applied to the pooled output before the
            linear head. Defaults to 0.3 (the previously hard-coded value).
    """

    def __init__(self, dropout=0.3):
        super(BERTBaseUncased, self).__init__()
        self.bert = BertModel.from_pretrained(BERT_PATH)
        self.bert_drop = nn.Dropout(dropout)                                            # Dropout to reduce overfitting
        # Read the encoder width from its config instead of hard-coding 768, so
        # the head still fits if BERT_PATH points at a differently sized model.
        self.out = nn.Linear(self.bert.config.hidden_size, 1)                           # One logit: toxic or not

    def forward(self, ids, mask):
        """Return one raw logit per sequence.

        Args:
            ids: Token id tensor forwarded to the encoder.
            mask: Attention mask forwarded to the encoder as ``attention_mask``.

        Returns:
            The output of the linear head (logits; no sigmoid is applied here).
        """
        outputs = self.bert(ids, attention_mask=mask)

        # Per the Transformers docs, pooler_output is the first-token ([CLS])
        # hidden state after the model's pooling layer.
        pooled_output = outputs.pooler_output

        dropped = self.bert_drop(pooled_output)
        return self.out(dropped)

In [None]:
# Train and evaluate functions

'''
train_fn runs one epoch of training.
eval_fn evaluates the model on validation data without updating weights.
'''

def train_fn(data_loader, model, optimizer, device, scheduler):
    """Run one training epoch and return the mean batch loss.

    Args:
        data_loader: Iterable of batches; each batch is a mapping with the
            keys ``"ids"``, ``"mask"`` and ``"targets"``. It does not need to
            support ``len()``.
        model: Module called as ``model(ids=..., mask=...)`` returning logits.
        optimizer: Optimizer stepped once per batch.
        device: Device the batch tensors are moved to.
        scheduler: Learning-rate scheduler stepped once per batch.

    Returns:
        Sum of the per-batch losses divided by the number of batches processed.

    Raises:
        ZeroDivisionError: If ``data_loader`` yields no batches.
    """
    model.train()

    # The criterion holds no per-batch state, so build it once per epoch
    # instead of once per iteration.
    loss_fn = nn.BCEWithLogitsLoss()

    total_loss = 0.0
    num_batches = 0

    for bi, d in enumerate(tqdm(data_loader, desc='Training')):
        ids = d["ids"].to(device)
        mask = d["mask"].to(device)
        targets = d["targets"].to(device)

        optimizer.zero_grad()                                             # Clear gradients from the previous batch
        outputs = model(ids=ids, mask=mask)                               # Forward pass: logits

        # Reshape targets to a column so they line up with the model's
        # single-logit-per-sample output.
        loss = loss_fn(outputs, targets.view(-1, 1))

        loss.backward()                                                   # Backpropagate
        optimizer.step()                                                  # Update model weights
        scheduler.step()                                                  # Advance the LR schedule (per batch)

        total_loss += loss.item()                                         # Accumulate for the epoch average
        num_batches += 1

        if bi % 100 == 0:                                                 # Periodic progress line
            print(f"Batch {bi}, Loss: {loss.item():.4f}")

    # Divide by the batches actually seen rather than len(data_loader), so
    # iterables without __len__ work too.
    return total_loss / num_batches


def eval_fn(data_loader, model, device):
    """Score every batch in ``data_loader`` without updating the model.

    Args:
        data_loader: Iterable of batches; each batch is a mapping with the
            keys ``"ids"``, ``"mask"`` and ``"targets"``.
        model: Module called as ``model(ids=..., mask=...)`` returning logits.
        device: Device the batch tensors are moved to.

    Returns:
        ``(probabilities, labels)`` as plain Python lists, in loader order.
        Probabilities are the sigmoid of the model's logits and keep the
        logits' per-batch nesting.
    """
    model.eval()
    all_probs, all_labels = [], []

    # no_grad: evaluation only, so skip gradient bookkeeping.
    with torch.no_grad():
        for batch in tqdm(data_loader, desc='Evaluating'):
            input_ids, attention_mask, labels = (
                batch[key].to(device) for key in ("ids", "mask", "targets")
            )

            logits = model(ids=input_ids, mask=attention_mask)
            probs = torch.sigmoid(logits)                      # Logits -> probabilities in (0, 1)

            # Move to CPU and convert to nested Python lists for sklearn.
            all_labels += labels.detach().cpu().numpy().tolist()
            all_probs += probs.detach().cpu().numpy().tolist()

    return all_probs, all_labels

In [None]:
# Text preprocessing

'''
Removes HTML tags, URLs, and extra whitespace from comments.
'''


def clean_text(text):
    """Strip HTML-like tags and URLs from ``text`` and normalize whitespace.

    The patterns are applied in order: first anything of the form ``<...>``
    (non-greedy, not spanning newlines), then tokens starting with ``http`` or
    ``www``. Finally all whitespace runs collapse to single spaces and
    leading/trailing whitespace is removed.
    """
    for pattern in (r'<.*?>', r'http\S+|www\S+'):
        text = re.sub(pattern, '', text)

    return ' '.join(text.split())

In [None]:
# Training model

'''
Loads data, preprocesses it, fine-tunes BERT for EPOCHS epochs, and saves the checkpoint with the best validation accuracy.
'''


def main():
    """Fine-tune the BERT classifier end to end and save the best checkpoint.

    Steps: load the CSV, clean and subsample the comments, make a stratified
    90/10 train/validation split, train for ``EPOCHS`` epochs, evaluate after
    each epoch, and write the checkpoint with the highest validation accuracy
    to ``MODEL_PATH``.

    NOTE(review): the checkpoint is selected on accuracy. If the positive
    class is rare in this data, F1 or AUC may be a better criterion — confirm.
    """
    seed = 42
    # Seed torch so weight init of the head, dropout and DataLoader shuffling
    # are repeatable across Restart & Run All.
    torch.manual_seed(seed)

    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"\nUsing device: {device}")

    if torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name(0)}")

    df = pd.read_csv("/kaggle/input/jigsaw-toxic-comment-classification-challenge/train.csv.zip", usecols=["comment_text", "toxic"]) # Load data from Kaggle

    print(f"Before cleaning: {len(df)}")

    # Preprocessing
    df = df.dropna(subset=['comment_text', 'toxic'])                 # Remove rows with missing data
    df = df[df['comment_text'].str.strip() != '']                    # Remove rows that are empty/whitespace
    df = df[df['comment_text'].str.len() < 5000]                     # Remove rows with super long comments
    # assign() returns a new frame, avoiding pandas' chained-assignment
    # warning that writing a column into a filtered slice triggers.
    df = df.assign(comment_text=df['comment_text'].apply(clean_text))
    df = df[df['comment_text'].str.strip() != '']                    # Remove rows that became empty after cleaning

    print(f"After cleaning: {len(df)}")

    # Cap the working set at 75k rows; min() keeps sample() from raising when
    # fewer rows survive cleaning.
    df = df.sample(n=min(75000, len(df)), random_state=seed)
    print(f"Using {len(df)} samples")

    df_train, df_valid = model_selection.train_test_split(           # Train/validation split, 90% train, 10% validation
        df,
        test_size=0.1,
        random_state=seed,
        stratify=df.toxic.values
    )

    print(f"\nTraining samples: {len(df_train)}")
    print(f"Validation samples: {len(df_valid)}")

    # Create datasets
    train_dataset = BERTDataset(
        comment_text=df_train.comment_text.values,
        target=df_train.toxic.values
    )

    valid_dataset = BERTDataset(
        comment_text=df_valid.comment_text.values,
        target=df_valid.toxic.values
    )

    # Create data loaders
    train_data_loader = DataLoader(
        train_dataset,
        batch_size=TRAIN_BATCH_SIZE,
        shuffle=True,
        num_workers=0                                             # Set to 0 for Kaggle compatibility
    )

    valid_data_loader = DataLoader(
        valid_dataset,
        batch_size=VALID_BATCH_SIZE,
        shuffle=False,
        num_workers=0
    )

    model = BERTBaseUncased()                                    # Create model
    model.to(device)

    print(f"Total parameters: {sum(p.numel() for p in model.parameters()):,}")

    # Optimizer setup: no weight decay on biases and LayerNorm parameters.
    param_optimizer = list(model.named_parameters())
    no_decay = ["bias", "LayerNorm.bias", "LayerNorm.weight"]
    optimizer_parameters = [
        {"params": [p for n, p in param_optimizer if not any(nd in n for nd in no_decay)], "weight_decay": 0.001},
        {"params": [p for n, p in param_optimizer if any(nd in n for nd in no_decay)], "weight_decay": 0.0},
    ]

    # train_fn steps the scheduler once per batch, so the schedule length is
    # the real number of batches per epoch (including a final partial batch)
    # times the epoch count. Dividing the row count by the batch size and
    # truncating undercounts whenever the last batch is partial.
    num_train_steps = len(train_data_loader) * EPOCHS

    optimizer = AdamW(optimizer_parameters, lr=LEARNING_RATE)         # Using AdamW as it works well with transformers

    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=0,
        num_training_steps=num_train_steps
    )

    # Training loop
    best_accuracy = 0

    for epoch in range(EPOCHS):

        print(f"\nEpoch {epoch + 1}/{EPOCHS}")

        train_loss = train_fn(train_data_loader, model, optimizer, device, scheduler) # Train for one epoch
        print(f"Average training loss: {train_loss:.4f}")

        outputs, targets = eval_fn(valid_data_loader, model, device)                  # Evaluate on validation set

        # Flatten to 1-D: eval_fn returns one [probability] list per sample.
        probabilities = np.asarray(outputs, dtype=float).reshape(-1)
        labels = np.asarray(targets).reshape(-1).astype(int)
        predictions = (probabilities >= 0.5).astype(int)                              # >= .5 counts as toxic

        # Threshold-based metrics use the hard predictions. float() keeps the
        # saved checkpoint free of NumPy scalar objects.
        accuracy = float(metrics.accuracy_score(labels, predictions))
        precision = float(metrics.precision_score(labels, predictions, zero_division=0))
        recall = float(metrics.recall_score(labels, predictions, zero_division=0))
        f1 = float(metrics.f1_score(labels, predictions, zero_division=0))
        # ROC-AUC is a ranking metric and must see the continuous scores;
        # feeding it thresholded 0/1 predictions collapses the curve to a
        # single operating point.
        auc = float(metrics.roc_auc_score(labels, probabilities))

        print(f"\nValidation Metrics:")
        print(f"  Accuracy:  {accuracy*100:.2f}%")
        print(f"  Precision: {precision*100:.2f}%")
        print(f"  Recall:    {recall*100:.2f}%")
        print(f"  F1-Score:  {f1:.4f}")
        print(f"  AUC:       {auc:.4f}")

        # Save the best model
        if accuracy > best_accuracy:
            torch.save({
                'model_state_dict': model.state_dict(),
                'config': {
                    'max_len': MAX_LEN,
                    'bert_path': BERT_PATH
                },
                'metrics': {
                    'accuracy': accuracy,
                    'precision': precision,
                    'recall': recall,
                    'f1_score': f1,
                    'auc': auc
                }
            }, MODEL_PATH)
            print(f"Best model (Accuracy: {accuracy*100:.2f}%)")
            best_accuracy = accuracy

    print(f"Best Accuracy: {best_accuracy*100:.2f}%")
    print(f"Model saved to: {MODEL_PATH}")

# Entry point: run the full training pipeline when this code is executed
# directly (a notebook cell or `python file.py`), but not when it is imported.
if __name__ == "__main__":
    main()