
# Toxic Text Classification using DistilBERT SST-2

This notebook builds a binary text classifier that flags toxic content. It starts from a DistilBERT model already fine-tuned on the SST-2 sentiment dataset and fine-tunes it further on the Jigsaw Toxic Comment Classification dataset.


## 1. Setup and Dependencies

Install the required packages, import the libraries, and mount Google Drive to store model checkpoints and logs.

In [None]:
# Install required packages
!pip install transformers datasets torch codecarbon pandas numpy matplotlib seaborn


In [None]:
# Import required libraries
import torch
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    Trainer,
    TrainingArguments
)
from datasets import load_dataset
import numpy as np
import pandas as pd
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from codecarbon import EmissionsTracker
import matplotlib.pyplot as plt
import seaborn as sns
from google.colab import drive

# Mount Google Drive for saving model checkpoints
drive.mount('/content/drive')

# Create project directories
!mkdir -p '/content/drive/MyDrive/toxic_classifier_sst/models'
!mkdir -p '/content/drive/MyDrive/toxic_classifier_sst/logs'

Mounted at /content/drive


## 2. Data Loading and Preprocessing

Load the Jigsaw dataset, collapse its six toxicity labels into a single binary label, and split it into training, validation, and test sets.
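The label collapse performed by `preprocess_data` below can be checked on a tiny hand-made batch. This sketch assumes the six toxicity columns hold 0/1 integers, as in the Jigsaw data; it reproduces the loop-based rule and shows an equivalent vectorized NumPy version. The toy comments and flags are made up for illustration.

```python
import numpy as np

TOXICITY_TYPES = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']

def binary_labels_loop(batch):
    """Loop version: 1 if any toxicity flag is set, else 0."""
    return [
        1 if any(batch[col][i] == 1 for col in TOXICITY_TYPES) else 0
        for i in range(len(batch['comment_text']))
    ]

def binary_labels_vectorized(batch):
    """Vectorized version: stack the flags into a (6, n) array and OR down each column."""
    flags = np.stack([np.asarray(batch[col]) for col in TOXICITY_TYPES])
    return (flags == 1).any(axis=0).astype(int).tolist()

# Hypothetical toy batch: one clean comment, one flagged only as 'insult',
# and one flagged as both 'toxic' and 'obscene'.
toy_batch = {
    'comment_text': ["thanks for the edit", "you are an idiot", "(profanity)"],
    'toxic':         [0, 0, 1],
    'severe_toxic':  [0, 0, 0],
    'obscene':       [0, 0, 1],
    'threat':        [0, 0, 0],
    'insult':        [0, 1, 0],
    'identity_hate': [0, 0, 0],
}

print(binary_labels_loop(toy_batch))        # [0, 1, 1]
print(binary_labels_vectorized(toy_batch))  # [0, 1, 1]
```

A comment counts as toxic as soon as any one of the six flags is set, so the finer-grained categories (threat, insult, and so on) are no longer distinguishable after this step.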

In [None]:
# Load the dataset
dataset = load_dataset("thesofakillers/jigsaw-toxic-comment-classification-challenge")

def preprocess_data(examples):
    """Convert multi-label toxic classification to binary labels.

    Args:
        examples: Dataset examples containing comment_text and toxicity labels

    Returns:
        dict: Processed examples with text and binary labels
    """
    # Define toxicity types
    toxicity_types = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']

    # Create binary labels (1 for toxic, 0 for safe)
    labels = []
    for i in range(len(examples['comment_text'])):
        is_toxic = any(examples[col][i] == 1 for col in toxicity_types)
        labels.append(1 if is_toxic else 0)

    return {
        'text': examples['comment_text'],
        'label': labels
    }

# Apply preprocessing
dataset = dataset.map(preprocess_data, batched=True)

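# Split dataset: hold out 20% for test, then 10% of the remainder for validation
# (a fixed seed makes the splits reproducible across runs)
train_test = dataset['train'].train_test_split(test_size=0.2, seed=42)
train_val = train_test['train'].train_test_split(test_size=0.1, seed=42)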

train_dataset = train_val['train']
val_dataset = train_val['test']
test_dataset = train_test['test']

print(f"Training samples: {len(train_dataset)}")
print(f"Validation samples: {len(val_dataset)}")
print(f"Test samples: {len(test_dataset)}")

Training samples: 114890
Validation samples: 12766
Test samples: 31915
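The printed sizes follow directly from the two chained splits. A worked check, assuming `train_test_split` rounds a fractional test share up (`ceil`) and gives the remainder to the train side, as scikit-learn's splitter does. The starting count of 159,571 rows is simply the sum of the three printed sizes.

```python
import math

def split_sizes(n_samples, test_size):
    """Return (n_train, n_test) for a float test_size, rounding the test share up."""
    n_test = math.ceil(test_size * n_samples)
    return n_samples - n_test, n_test

total = 114890 + 12766 + 31915                 # rows in dataset['train'], from the printed sizes
remaining, n_test = split_sizes(total, 0.2)    # first split: 80/20
n_train, n_val = split_sizes(remaining, 0.1)   # second split: 90/10 of the remaining 80%

print(total, n_train, n_val, n_test)  # 159571 114890 12766 31915
```

The effective proportions are therefore about 72% train, 8% validation, and 20% test.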


## 3. Model Setup and Training

Load the SST-2 DistilBERT checkpoint, tokenize the three splits, and configure the training parameters.

In [None]:
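# Initialize tokenizer and model
model_name = "distilbert/distilbert-base-uncased-finetuned-sst-2-english"
tokenizer = AutoTokenizer.from_pretrained(model_name)
# The checkpoint ships with sentiment labels (NEGATIVE/POSITIVE); relabel the
# two output classes so the saved config matches the toxicity task
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    num_labels=2,
    id2label={0: "safe", 1: "unsafe"},
    label2id={"safe": 0, "unsafe": 1},
)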

def tokenize_function(examples):
    """Tokenize the input texts.

    Args:
        examples: Dataset examples containing text field

    Returns:
        dict: Tokenized inputs
    """
    return tokenizer(
        examples['text'],
        padding='max_length',
        truncation=True,
        max_length=128
    )

# Tokenize datasets
train_dataset = train_dataset.map(tokenize_function, batched=True)
val_dataset = val_dataset.map(tokenize_function, batched=True)
test_dataset = test_dataset.map(tokenize_function, batched=True)

def compute_metrics(pred):
    """Calculate evaluation metrics.

    Args:
        pred: Prediction object containing predictions and label_ids

    Returns:
        dict: Computed metrics
    """
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    precision, recall, f1, _ = precision_recall_fscore_support(labels, preds, average='binary')
    acc = accuracy_score(labels, preds)
    return {
        'accuracy': acc,
        'f1': f1,
        'precision': precision,
        'recall': recall
    }

# Set training arguments (CO2 is tracked by the EmissionsTracker in the next cell)
training_args = TrainingArguments(
    output_dir='/content/drive/MyDrive/toxic_classifier_sst/models',
    do_eval=True,
    per_device_train_batch_size=16,  # Smaller batch size for Colab
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    learning_rate=2e-5,
    weight_decay=0.01,
    logging_dir='/content/drive/MyDrive/toxic_classifier_sst/logs',
    logging_steps=100,
    save_total_limit=3,
    save_steps=500,  # must be a multiple of eval_steps when load_best_model_at_end=True
    eval_steps=500,
    eval_strategy='steps',
    load_best_model_at_end=True,
    metric_for_best_model='f1',  # select checkpoints by F1 rather than accuracy
    fp16=True,  # Mixed precision training (requires a CUDA GPU)
    gradient_accumulation_steps=2,  # Effective batch size of 32 per device
    # Log to W&B only; listing "codecarbon" here would start a second emissions
    # tracker alongside the explicit one in the next cell
    report_to=["wandb"],
)

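To see what `compute_metrics` reports, here is the same argmax-then-binary-metrics logic worked by hand on made-up logits. This NumPy-only sketch (no scikit-learn) assumes class 1 is the toxic, positive class, which is what `average='binary'` scores; the logits and labels are hypothetical.

```python
import numpy as np

def binary_metrics(logits, labels):
    """Accuracy, precision, recall and F1 for class 1, computed from raw logits."""
    preds = np.asarray(logits).argmax(-1)      # same reduction as compute_metrics
    labels = np.asarray(labels)
    tp = int(((preds == 1) & (labels == 1)).sum())
    fp = int(((preds == 1) & (labels == 0)).sum())
    fn = int(((preds == 0) & (labels == 1)).sum())
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    accuracy = float((preds == labels).mean())
    return {'accuracy': accuracy, 'f1': f1, 'precision': precision, 'recall': recall}

# Hypothetical logits for 5 comments; columns are [safe, toxic].
logits = [[2.0, -1.0], [0.1, 0.3], [-1.5, 2.5], [1.0, 0.2], [-0.4, 0.9]]
labels = [0, 0, 1, 1, 1]

# preds = [0, 1, 1, 0, 1]: TP = 2, FP = 1, FN = 1
# -> precision = recall = F1 = 2/3, accuracy = 3/5 = 0.6
print(binary_metrics(logits, labels))
```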

In [None]:
# Initialize trainer
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    compute_metrics=compute_metrics
)

# Start CO2 emissions tracker
tracker = EmissionsTracker(project_name="toxic_classifier_sst")
tracker.start()

# Train the model
trainer.train()

# Stop tracking emissions
emissions = tracker.stop()
print(f"Total CO2 emissions: {emissions} kg")

[codecarbon INFO @ 15:59:44] [setup] RAM Tracking...
[codecarbon INFO @ 15:59:44] [setup] CPU Tracking...
 Linux OS detected: Please ensure RAPL files exist at /sys/class/powercap/intel-rapl/subsystem to measure CPU

[codecarbon INFO @ 15:59:45] CPU Model on constant consumption mode: Intel(R) Xeon(R) CPU @ 2.00GHz
[codecarbon INFO @ 15:59:45] [setup] GPU Tracking...
[codecarbon INFO @ 15:59:45] Tracking Nvidia GPU via pynvml
[codecarbon INFO @ 15:59:45] The below tracking methods have been set up:
                RAM Tracking Method: RAM power estimation model
                CPU Tracking Method: global constant
                GPU Tracking Method: pynvml
            
[codecarbon INFO @ 15:59:45] >>> Tracker's metadata:
[codecarbon INFO @ 15:59:45]   Platform system: Linux-6.1.123+-x86_64-with-glibc2.35
[codecarbon INFO @ 15:59:45]   Python version: 3.11.13
[codecarbon INFO @ 15:59:45]   CodeCarbon version: 3.0.2
[codecarbon INFO @ 15:59:45]   Available RAM : 12.674 GB


[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize?ref=models
wandb: Paste an API key from your profile and hit enter:[codecarbon INFO @ 16:00:02] Energy consumed for RAM : 0.000042 kWh. RAM Power : 10.0 W
[codecarbon INFO @ 16:00:02] Delta energy consumed for CPU with constant : 0.000177 kWh, power : 42.5 W
[codecarbon INFO @ 16:00:02] Energy consumed for All CPU : 0.000177 kWh
[codecarbon INFO @ 16:00:02] Energy consumed for all GPUs : 0.000126 kWh. Total GPU Power : 30.167394973285788 W
[codecarbon INFO @ 16:00:02] 0.000345 kWh of electricity used since the beginning.


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mnarendiran100[0m ([33mnarendiran100-self-learning-skills[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


| Step | Training Loss | Validation Loss | Accuracy | F1 | Precision | Recall |
|---|---|---|---|---|---|---|
| 500 | 0.126 | 0.092884 | 0.966238 | 0.827944 | 0.857025 | 0.800772 |
| 1000 | 0.1009 | 0.086161 | 0.969137 | 0.834175 | 0.916744 | 0.765251 |
| 1500 | 0.1074 | 0.093642 | 0.965768 | 0.810577 | 0.923913 | 0.722008 |
| 2000 | 0.0854 | 0.085545 | 0.967335 | 0.844693 | 0.815827 | 0.875676 |
| 2500 | 0.0952 | 0.081892 | 0.968353 | 0.844015 | 0.844015 | 0.844015 |
| 3000 | 0.0942 | 0.090487 | 0.96381 | 0.832609 | 0.7843 | 0.887259 |
| 3500 | 0.0926 | 0.078649 | 0.968275 | 0.846533 | 0.831101 | 0.862548 |
| 4000 | 0.079 | 0.087887 | 0.971643 | 0.848282 | 0.927589 | 0.781467 |
| 4500 | 0.0669 | 0.0885 | 0.970077 | 0.848533 | 0.872046 | 0.826255 |
| 5000 | 0.0599 | 0.089769 | 0.971252 | 0.855796 | 0.8712 | 0.840927 |


[codecarbon INFO @ 16:00:17] Energy consumed for RAM : 0.000083 kWh. RAM Power : 10.0 W
[codecarbon INFO @ 16:00:17] Delta energy consumed for CPU with constant : 0.000177 kWh, power : 42.5 W
[codecarbon INFO @ 16:00:17] Energy consumed for All CPU : 0.000354 kWh
[codecarbon INFO @ 16:00:17] Energy consumed for all GPUs : 0.000274 kWh. Total GPU Power : 35.49956555307421 W
[codecarbon INFO @ 16:00:17] 0.000711 kWh of electricity used since the beginning.
[codecarbon INFO @ 16:00:17] Energy consumed for RAM : 0.000083 kWh. RAM Power : 10.0 W
[codecarbon INFO @ 16:00:17] Delta energy consumed for CPU with constant : 0.000177 kWh, power : 42.5 W
[codecarbon INFO @ 16:00:17] Energy consumed for All CPU : 0.000354 kWh
[codecarbon INFO @ 16:00:17] Energy consumed for all GPUs : 0.000279 kWh. Total GPU Power : 37.00024895180937 W
[codecarbon INFO @ 16:00:17] 0.000717 kWh of electricity used since the beginning.

Total CO2 emissions: 0.021124745670366058 kg
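The `Step` column in the training log counts optimizer updates, not examples. A rough conversion, assuming a single GPU and that the Trainer floor-divides the number of per-device batches by `gradient_accumulation_steps` (the exact rounding has varied across `transformers` versions, so treat the totals as approximate):

```python
import math

n_train = 114890          # training samples printed in section 2
per_device_batch = 16     # per_device_train_batch_size
grad_accum = 2            # gradient_accumulation_steps
n_gpus = 1                # assumption: a single Colab GPU
epochs = 3                # num_train_epochs
eval_steps = 500          # eval_steps

effective_batch = per_device_batch * grad_accum * n_gpus
batches_per_epoch = math.ceil(n_train / per_device_batch)
steps_per_epoch = batches_per_epoch // grad_accum
total_steps = steps_per_epoch * epochs

print(effective_batch, batches_per_epoch, steps_per_epoch, total_steps)  # 32 7181 3590 10770
print("evaluations:", total_steps // eval_steps)                         # evaluations: 21
```

Under these assumptions, step 5000 (the last row of the table above) falls partway through the second epoch, so the rows shown cover only part of the three-epoch run (the test results below report `epoch: 3.0000`).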


## 4. Model Evaluation and Analysis

In [None]:
# Evaluate on test set
test_results = trainer.evaluate(test_dataset)
print("\nTest Results:")
for key, value in test_results.items():
    print(f"{key}: {value:.4f}")

# Save model and tokenizer
save_path = '/content/drive/MyDrive/toxic_classifier_sst/models/final_model'
trainer.save_model(save_path)
tokenizer.save_pretrained(save_path)


Test Results:
eval_loss: 0.1010
eval_accuracy: 0.9668
eval_f1: 0.8315
eval_precision: 0.8476
eval_recall: 0.8161
eval_runtime: 33.6918
eval_samples_per_second: 947.2630
eval_steps_per_second: 59.2130
epoch: 3.0000


('/content/drive/MyDrive/toxic_classifier_sst/models/final_model/tokenizer_config.json',
 '/content/drive/MyDrive/toxic_classifier_sst/models/final_model/special_tokens_map.json',
 '/content/drive/MyDrive/toxic_classifier_sst/models/final_model/vocab.txt',
 '/content/drive/MyDrive/toxic_classifier_sst/models/final_model/added_tokens.json',
 '/content/drive/MyDrive/toxic_classifier_sst/models/final_model/tokenizer.json')
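On the test set, accuracy (0.9668) sits well above F1 (0.8315). A gap like this is expected when toxic comments are the minority class, and it is a good reason to select checkpoints by F1 rather than accuracy. A minimal illustration on hypothetical labels with a 90/10 safe/toxic split (an assumed ratio, not one measured on Jigsaw):

```python
def accuracy_and_f1(preds, labels):
    """Accuracy and F1 for the positive (toxic = 1) class."""
    tp = sum(p == 1 and y == 1 for p, y in zip(preds, labels))
    fp = sum(p == 1 and y == 0 for p, y in zip(preds, labels))
    fn = sum(p == 0 and y == 1 for p, y in zip(preds, labels))
    accuracy = sum(p == y for p, y in zip(preds, labels)) / len(labels)
    f1 = 2 * tp / (2 * tp + fp + fn) if tp else 0.0
    return accuracy, f1

# Hypothetical labels: 90 safe (0) comments followed by 10 toxic (1) ones.
labels = [0] * 90 + [1] * 10

# A useless classifier that always answers "safe".
always_safe = [0] * 100
print(accuracy_and_f1(always_safe, labels))  # (0.9, 0.0)

# A classifier with 2 false positives, 8 true positives and 2 misses.
decent = [0] * 88 + [1] * 2 + [1] * 8 + [0] * 2
print(accuracy_and_f1(decent, labels))       # (0.96, 0.8)
```

Accuracy barely separates the two classifiers, while F1 ignores the easy true negatives and exposes the difference.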

## 5. Inference Examples and Integration Demo

In [None]:
class TextGuardrail:
    """Text classification model for detecting toxic content."""

    def __init__(self, model_path):
        """Initialize the model.

        Args:
            model_path: Path to the saved model
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.model = AutoModelForSequenceClassification.from_pretrained(model_path)
        self.model.eval()

    def check_text(self, text):
        """Check if input text is toxic.

        Args:
            text: Input text to classify

        Returns:
            dict: Classification results
        """
        inputs = self.tokenizer(
            text,
            truncation=True,
            max_length=128,
            return_tensors="pt"
        )

        with torch.no_grad():
            outputs = self.model(**inputs)
            predictions = torch.nn.functional.softmax(outputs.logits, dim=-1)

        label = "unsafe" if predictions[0][1] > 0.5 else "safe"
        confidence = float(predictions[0][1] if label == "unsafe" else predictions[0][0])

        return {
            "text": text,
            "label": label,
            "confidence": f"{confidence:.2%}"
        }

# Initialize guardrail with the model saved in section 4
guardrail = TextGuardrail("/content/drive/MyDrive/toxic_classifier_sst/models/final_model")

# Test examples
test_texts = [
    "Hello, how are you today?",
    "I really enjoyed the movie!",
    "You're so stupid and worthless",
    "Let's work together to solve this problem",
    "I hate you"
]

# Test the model
for text in test_texts:
    result = guardrail.check_text(text)
    print(f"\nInput: {result['text']}")
    print(f"Classification: {result['label']}")
    print(f"Confidence: {result['confidence']}")


Input: Hello, how are you today?
Classification: safe
Confidence: 99.94%

Input: I really enjoyed the movie!
Classification: safe
Confidence: 99.96%

Input: You're so stupid and worthless
Classification: unsafe
Confidence: 99.96%

Input: Let's work together to solve this problem
Classification: safe
Confidence: 99.98%

Input: I hate you
Classification: unsafe
Confidence: 99.94%
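`check_text` turns the two logits into probabilities with a softmax, labels the text "unsafe" when the toxic-class probability exceeds 0.5, and reports the probability of whichever label won. The same rule in plain NumPy, applied to made-up logits, shows that the reported confidence can never drop below 50% and that an exact tie resolves to "safe":

```python
import numpy as np

def decide(logits):
    """Mirror TextGuardrail.check_text's decision rule for one [safe, toxic] logit pair."""
    logits = np.asarray(logits, dtype=float)
    exp = np.exp(logits - logits.max())   # subtract the max for numerical stability
    probs = exp / exp.sum()               # softmax over the two classes
    label = "unsafe" if probs[1] > 0.5 else "safe"
    confidence = probs[1] if label == "unsafe" else probs[0]
    return label, float(confidence)

print(decide([3.0, -4.0]))   # confidently safe
print(decide([-3.5, 4.0]))   # confidently unsafe
print(decide([0.0, 0.0]))    # ('safe', 0.5): a tie is not > 0.5
```

Because the confidence is the winning class's probability, values like 99.94% above say how far the model is from the 0.5 boundary, not how likely the label is to be correct on new data.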


## Real-time Integration Example

Here's how the model could be integrated into a real-time system:

In [None]:
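def process_prompt(prompt, guardrail, threshold=0.5):
    """Process a prompt through the toxicity filter.

    Args:
        prompt: User input prompt
        guardrail: TextGuardrail instance
        threshold: Probability of the "unsafe" class above which the prompt is rejected

    Returns:
        tuple: (bool, str) - (is_safe, message)
    """
    result = guardrail.check_text(prompt)
    # check_text reports the winning label's probability as a string such as "99.94%",
    # so convert it back to the probability of the "unsafe" class
    confidence = float(result['confidence'].strip('%')) / 100
    unsafe_prob = confidence if result['label'] == 'unsafe' else 1 - confidence

    # Comparing the unsafe probability (not the label) lets thresholds below 0.5 take effect
    if unsafe_prob > threshold:
        return False, "This prompt contains potentially harmful content and cannot be processed."
    return True, "Prompt is safe for processing."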

# Example usage in a prompt processing pipeline
def prompt_pipeline(user_input):
    """Example prompt processing pipeline.

    Args:
        user_input: Raw user input

    Returns:
        str: Response message
    """
    # Check prompt safety
    is_safe, message = process_prompt(user_input, guardrail)

    if not is_safe:
        return message

    # If safe, continue with normal processing
    return "Processing your safe prompt: " + user_input

# Test the pipeline
test_prompts = [
    "Write a poem about spring flowers",
    "You're all worthless and should die",
    "Help me solve this math problem",
    "I will find you and hurt you",
]

for prompt in test_prompts:
    print(f"\nInput: {prompt}")
    print(f"Result: {prompt_pipeline(prompt)}")


Input: Write a poem about spring flowers
Result: Processing your safe prompt: Write a poem about spring flowers

Input: You're all worthless and should die
Result: This prompt contains potentially harmful content and cannot be processed.

Input: Help me solve this math problem
Result: Processing your safe prompt: Help me solve this math problem

Input: I will find you and hurt you
Result: This prompt contains potentially harmful content and cannot be processed.
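Because the guardrail yields one toxic-class probability per prompt, the rejection threshold is the main knob for trading missed toxic prompts against false alarms. A sketch with hypothetical probabilities (not outputs of the trained model) showing how lowering the threshold blocks more prompts:

```python
def blocked(p_unsafe_by_prompt, threshold):
    """Return the prompts whose toxic-class probability exceeds the threshold."""
    return [prompt for prompt, p in p_unsafe_by_prompt.items() if p > threshold]

# Hypothetical probabilities of the 'unsafe' class for four prompts.
scores = {
    "Write a poem about spring flowers": 0.01,
    "This code is garbage, fix it": 0.35,
    "You're all worthless": 0.97,
    "I will find you": 0.62,
}

for threshold in (0.9, 0.5, 0.3):
    print(threshold, blocked(scores, threshold))
```

A lower threshold catches more borderline prompts (higher recall) at the cost of more false positives (lower precision); the precision and recall reported in section 4 hold only for the default 0.5 cut-off and would need to be re-measured for any other value.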


## Potential Extensions

1. Real-time System Integration:
   - Deploy as a REST API using FastAPI/Flask
   - Implement caching for frequent prompts
   - Add rate limiting and request queuing

2. Performance Improvements:
   - Model quantization for faster inference
   - Batch processing for multiple prompts
   - GPU acceleration in production

3. Enhanced Features:
   - Confidence threshold adjustment
   - Multi-language support
   - Toxicity category detection
   - Feedback loop for continuous improvement

4. Monitoring and Maintenance:
   - Track false positives/negatives
   - Monitor model drift
   - Regular retraining with new data
   - A/B testing for improvements
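Two of the performance items above, caching frequent prompts and batching multiple prompts, can be combined without any serving framework. In this sketch `fake_score_batch` is a hypothetical stand-in for one batched forward pass (a real version might tokenize the list of prompts with `padding=True` and call the model once); it simply flags the word "hate" so the example runs on its own. `CachedBatchScorer` is likewise a hypothetical helper, not part of any library: it sends only unseen prompts to the model in a single batch and evicts the least recently used entries once the cache is full.

```python
from collections import OrderedDict

class CachedBatchScorer:
    """Scores prompts in batches and caches results with LRU eviction."""

    def __init__(self, score_batch, max_size=10_000):
        self.score_batch = score_batch
        self.max_size = max_size
        self.cache = OrderedDict()

    def score(self, prompts):
        # Deduplicate while preserving order, keeping only prompts not yet cached.
        unseen = [p for p in dict.fromkeys(prompts) if p not in self.cache]
        if unseen:
            # One batched model call for all new prompts.
            for p, s in zip(unseen, self.score_batch(unseen)):
                self.cache[p] = s
        results = []
        for p in prompts:
            self.cache.move_to_end(p)  # mark as recently used
            results.append(self.cache[p])
        # Evict least recently used entries beyond max_size.
        while len(self.cache) > self.max_size:
            self.cache.popitem(last=False)
        return results

calls = []  # records every simulated model invocation

def fake_score_batch(prompts):
    """Hypothetical stand-in for one batched forward pass of the classifier."""
    calls.append(list(prompts))
    return [0.99 if "hate" in p.lower() else 0.01 for p in prompts]

scorer = CachedBatchScorer(fake_score_batch, max_size=100)
print(scorer.score(["hello", "I hate you", "hello"]))  # [0.01, 0.99, 0.01]
print(scorer.score(["hello", "new prompt"]))           # [0.01, 0.01]
print(calls)  # [['hello', 'I hate you'], ['new prompt']]
```

The repeated "hello" is scored once, and the second call only sends "new prompt" to the model.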