<a href="https://colab.research.google.com/github/Narendiran100/text-gaurdrail-model/blob/master/toxic_classifier_sst.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Toxic Text Classification using DistilBERT SST-2

This notebook builds a binary text classifier that flags toxic content. It starts from a DistilBERT model already fine-tuned on the SST-2 sentiment dataset and fine-tunes it further on the Jigsaw Toxic Comment Classification dataset. A comment counts as toxic if it carries any of Jigsaw's six toxicity labels.


## 1. Setup and Dependencies

Install required packages and configure the environment

In [None]:
# Install required packages
!pip install transformers datasets torch codecarbon pandas numpy matplotlib seaborn

In [None]:
# Import required libraries
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    Trainer,
    TrainingArguments
)
from datasets import load_dataset
import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from codecarbon import EmissionsTracker
import matplotlib.pyplot as plt
import seaborn as sns
from google.colab import drive

# Mount Google Drive for saving model checkpoints, so files written under
# /content/drive persist beyond the Colab runtime
drive.mount('/content/drive')

# Create project directories (IPython shell magic; `mkdir -p` does not fail
# if the directories already exist, so this cell is safe to re-run)
!mkdir -p '/content/drive/MyDrive/toxic_classifier_sst/models'
!mkdir -p '/content/drive/MyDrive/toxic_classifier_sst/logs'

## 2. Data Loading and Preprocessing

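Load the Jigsaw dataset and collapse its six toxicity labels into a single binary label (1 = toxic, 0 = safe). Then split the original training data into training, validation, and test sets (72% / 8% / 20%).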

In [None]:
# Download the Jigsaw toxic-comment data from the Hugging Face Hub
# (a DatasetDict keyed by split name).
JIGSAW_DATASET_ID = "thesofakillers/jigsaw-toxic-comment-classification-challenge"
dataset = load_dataset(JIGSAW_DATASET_ID)

def preprocess_data(examples):
    """Collapse the six Jigsaw toxicity columns into one binary label.

    Written for ``Dataset.map(..., batched=True)``: every value in
    ``examples`` is a list holding one entry per comment in the batch.

    Args:
        examples: Batch mapping with ``comment_text`` plus the ``toxic``,
            ``severe_toxic``, ``obscene``, ``threat``, ``insult`` and
            ``identity_hate`` columns.

    Returns:
        dict: ``text`` (the comment strings, unchanged) and ``label``
        (1 when any toxicity column equals 1 for that comment, else 0).
    """
    toxicity_types = ('toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate')

    # zip(*columns) yields one tuple of flags per comment, replacing manual
    # index bookkeeping across six parallel lists. Compare with `== 1` rather
    # than truthiness so sentinel values (presumably -1 for unscored rows in the
    # Kaggle test labels — confirm against the dataset) are not counted as toxic.
    per_comment_flags = zip(*(examples[col] for col in toxicity_types))
    labels = [int(any(flag == 1 for flag in flags)) for flags in per_comment_flags]

    return {
        'text': examples['comment_text'],
        'label': labels,
    }

# Apply preprocessing (adds the binary 'label' and 'text' columns to every split)
dataset = dataset.map(preprocess_data, batched=True)

# Split the original training data into train/validation/test (72% / 8% / 20%).
# A fixed seed makes the split reproducible: re-running the notebook (e.g. after
# a Colab disconnect) trains and evaluates on the same rows, so metrics from
# different runs are comparable.
SPLIT_SEED = 42
train_test = dataset['train'].train_test_split(test_size=0.2, seed=SPLIT_SEED)
train_val = train_test['train'].train_test_split(test_size=0.1, seed=SPLIT_SEED)

train_dataset = train_val['train']
val_dataset = train_val['test']
test_dataset = train_test['test']

print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")
print(f"Test samples: {len(test_dataset)}")

## 3. Model Setup and Training

Load the SST-2 checkpoint and its tokenizer, tokenize all data splits, and configure the training parameters.

In [None]:
# Initialize tokenizer and model.
# The SST-2 checkpoint already has a 2-way classification head, but its config
# names the classes NEGATIVE/POSITIVE (sentiment). Relabel them so the
# fine-tuned model's saved config reports the classes produced by
# preprocess_data: 0 = safe, 1 = unsafe.
model_name = "distilbert/distilbert-base-uncased-finetuned-sst-2-english"
id2label = {0: "safe", 1: "unsafe"}
label2id = {label: idx for idx, label in id2label.items()}

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    num_labels=2,
    id2label=id2label,
    label2id=label2id,
)

def tokenize_function(examples):
    """Encode a batch of texts into fixed-length model inputs.

    Each text is truncated or padded to exactly 128 tokens, so every row
    produced by ``Dataset.map`` has the same length.

    Args:
        examples: Batch mapping whose ``text`` entry is a list of strings.

    Returns:
        The tokenizer's batch encoding (dict-like; for this DistilBERT
        tokenizer presumably ``input_ids`` and ``attention_mask``).
    """
    encode_kwargs = dict(padding='max_length', truncation=True, max_length=128)
    return tokenizer(examples['text'], **encode_kwargs)

# Tokenize every split in batches; the existing columns (text, label, ...)
# are kept alongside the new token columns.
train_dataset, val_dataset, test_dataset = (
    split.map(tokenize_function, batched=True)
    for split in (train_dataset, val_dataset, test_dataset)
)

def compute_metrics(pred):
    """Compute accuracy and binary precision/recall/F1 for the Trainer.

    Class 1 (toxic / "unsafe") is the positive class.

    Args:
        pred: Prediction object with ``predictions`` (logits, or a tuple
            whose first element is the logits) and ``label_ids``.

    Returns:
        dict: ``accuracy``, ``f1``, ``precision`` and ``recall``.
    """
    labels = pred.label_ids
    logits = pred.predictions
    # Models that return extra outputs hand the Trainer a tuple; the logits
    # come first.
    if isinstance(logits, tuple):
        logits = logits[0]
    preds = np.asarray(logits).argmax(-1)

    # zero_division=0: if the model predicts no positives (likely early in
    # training on imbalanced data), precision scores 0 without emitting an
    # UndefinedMetricWarning at every evaluation.
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, preds, average='binary', zero_division=0
    )
    acc = accuracy_score(labels, preds)
    return {
        'accuracy': acc,
        'f1': f1,
        'precision': precision,
        'recall': recall
    }

# Set training arguments.
# CO2 is measured by the explicit EmissionsTracker wrapped around
# trainer.train() in the next cell. Enabling the "codecarbon" integration here
# as well would start a second, overlapping tracker. "wandb" is not installed
# by the setup cell and needs an API-key login. So only TensorBoard logging
# is enabled; it writes to logging_dir.
training_args = TrainingArguments(
    output_dir='/content/drive/MyDrive/toxic_classifier_sst/models',
    do_eval=True,
    per_device_train_batch_size=16,  # Smaller batch size for Colab
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    learning_rate=2e-5,
    weight_decay=0.01,
    logging_dir='/content/drive/MyDrive/toxic_classifier_sst/logs',
    logging_steps=100,
    save_total_limit=3,
    save_steps=500,  # Must be a multiple of eval_steps for load_best_model_at_end
    eval_steps=500,
    eval_strategy='steps',
    load_best_model_at_end=True,
    metric_for_best_model='f1',  # 'f1' key from compute_metrics
    fp16=torch.cuda.is_available(),  # Mixed precision needs a CUDA GPU; use fp32 otherwise
    gradient_accumulation_steps=2,  # Effective batch size 16 * 2 = 32 per device
    report_to=["tensorboard"],
)

In [None]:
# Initialize trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics
)

# Start CO2 emissions tracker (covers the training run only)
tracker = EmissionsTracker(project_name="toxic_classifier_sst")
tracker.start()

try:
    # Train the model
    trainer.train()
finally:
    # Stop tracking even if training raises or is interrupted (e.g. a Colab
    # keyboard interrupt), so the tracker is never left running.
    emissions = tracker.stop()

print(f"Total CO2 emissions: {emissions} kg")

## 4. Model Evaluation and Analysis

In [None]:
# Evaluate on the held-out test set. metric_key_prefix="test" reports the
# metrics as test_* instead of the default eval_*, so they cannot be confused
# with the validation metrics logged during training.
test_results = trainer.evaluate(test_dataset, metric_key_prefix="test")
print("\nTest Results:")
for key, value in test_results.items():
    print(f"{key}: {value:.4f}")

# Save model and tokenizer. With load_best_model_at_end=True, the trainer now
# holds the checkpoint with the best validation F1.
save_path = '/content/drive/MyDrive/toxic_classifier_sst/models/final_model'
trainer.save_model(save_path)
tokenizer.save_pretrained(save_path)

## 5. Inference Examples and Integration Demo

In [None]:
class TextGuardrail:
    """Binary safe/unsafe text classifier backed by a fine-tuned checkpoint.

    Output index 1 is treated as "unsafe" and index 0 as "safe", matching the
    labels used during training.
    """

    def __init__(self, model_path):
        """Load the tokenizer and model and switch the model to eval mode.

        Args:
            model_path: Path to the saved model (anything accepted by
                ``from_pretrained``).
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_path)
        self.model.eval()

    def check_text(self, text, threshold=0.5):
        """Classify a single text as "safe" or "unsafe".

        Args:
            text: Input text to classify; truncated to 128 tokens, the same
                length used during training.
            threshold: The text is labelled "unsafe" when the unsafe-class
                probability is strictly greater than this value.

        Returns:
            dict with keys:
                text: the input text, unchanged.
                label: "unsafe" or "safe".
                confidence: probability of the returned label, formatted as a
                    percentage string (e.g. "97.31%").
                unsafe_score: unrounded unsafe-class probability (float in
                    [0, 1]) for callers that need a numeric value.
        """
        inputs = self.tokenizer(
            text,
            truncation=True,
            max_length=128,
            return_tensors="pt"
        )
        # Keep the inputs on the model's device in case the model was moved
        # (e.g. model.to("cuda")) after construction.
        inputs = inputs.to(self.model.device)

        with torch.no_grad():
            logits = self.model(**inputs).logits
        probs = torch.softmax(logits, dim=-1)[0]

        unsafe_score = float(probs[1])
        label = "unsafe" if unsafe_score > threshold else "safe"
        confidence = unsafe_score if label == "unsafe" else float(probs[0])

        return {
            "text": text,
            "label": label,
            "confidence": f"{confidence:.2%}",
            "unsafe_score": unsafe_score,
        }

# Initialize the guardrail from the model saved in the evaluation section.
# This path must match the `save_path` used there ('.../models/final_model');
# the previous '.../final_model_sst' was never written by this notebook.
guardrail = TextGuardrail("/content/drive/MyDrive/toxic_classifier_sst/models/final_model")

# Test examples: a mix of benign and abusive texts
test_texts = [
    "Hello, how are you today?",
    "I really enjoyed the movie!",
    "You're so stupid and worthless",
    "Let's work together to solve this problem",
    "I hate you"
]

# Run each example through the guardrail and show its verdict
for text in test_texts:
    result = guardrail.check_text(text)
    print(f"\nInput: {result['text']}")
    print(f"Classification: {result['label']}")
    print(f"Confidence: {result['confidence']}")

## Real-time Integration Example

Here's how the model could be integrated into a real-time system:

In [None]:
def process_prompt(prompt, guardrail, threshold=0.5):
    """Decide whether a prompt may pass the toxicity filter.

    Args:
        prompt: User input prompt.
        guardrail: Object with a ``check_text(text)`` method (e.g. a
            TextGuardrail) returning a dict with ``label`` ("safe"/"unsafe")
            and ``confidence`` (the returned label's probability as a
            percentage string such as "97.31%").
        threshold: The prompt is rejected when its estimated probability of
            being unsafe exceeds this value. Values below 0.5 make the filter
            stricter; values above 0.5 make it more lenient.

    Returns:
        tuple: (bool, str) - (is_safe, message)
    """
    result = guardrail.check_text(prompt)
    confidence = float(result['confidence'].rstrip('%')) / 100

    # 'confidence' is the probability of the *returned* label, so convert it
    # to the unsafe-class probability before comparing. Checking the threshold
    # only for 'unsafe' labels would ignore any threshold below 0.5, because
    # such texts come back labelled 'safe'.
    if result['label'] == 'unsafe':
        unsafe_probability = confidence
    else:
        unsafe_probability = 1.0 - confidence

    if unsafe_probability > threshold:
        return False, "This prompt contains potentially harmful content and cannot be processed."
    return True, "Prompt is safe for processing."

# Example usage in a prompt processing pipeline
def prompt_pipeline(user_input):
    """Gate user input through the toxicity filter, then hand it on.

    Uses the module-level ``guardrail`` with process_prompt's default
    threshold.

    Args:
        user_input: Raw user input.

    Returns:
        str: The rejection message for unsafe input; otherwise a confirmation
        that echoes the input.
    """
    allowed, rejection_message = process_prompt(user_input, guardrail)
    if allowed:
        # Placeholder for the real downstream handling of a safe prompt
        return "Processing your safe prompt: " + user_input
    return rejection_message

# Test the pipeline
test_prompts = [
    "Write a poem about spring flowers",
    "You're all worthless and should die",
    "Help me solve this math problem",
    "I will find you and hurt you",
]

for prompt in test_prompts:
    print(f"\nInput: {prompt}")
    print(f"Result: {prompt_pipeline(prompt)}")

## Potential Extensions

1. Real-time System Integration:
   - Deploy as a REST API using FastAPI/Flask
   - Implement caching for frequent prompts
   - Add rate limiting and request queuing

2. Performance Improvements:
   - Model quantization for faster inference
   - Batch processing for multiple prompts
   - GPU acceleration in production

3. Enhanced Features:
   - Confidence threshold adjustment
   - Multi-language support
   - Toxicity category detection
   - Feedback loop for continuous improvement

4. Monitoring and Maintenance:
   - Track false positives/negatives
   - Monitor model drift
   - Regular retraining with new data
   - A/B testing for improvements