In [34]:
import pandas as pd
import numpy as np
import time
from typing import Callable, List, Dict, Tuple
import json
import os
from portkey_ai import Portkey

# Read the Portkey key from the environment rather than hardcoding it in the notebook,
# so the credential never lands in version control or shared outputs. A missing variable
# fails here with a KeyError naming it, instead of later with an opaque auth error.
client = Portkey(api_key=os.environ["PORTKEY_API_KEY"])

def convert_to_dict(obj):
    """Recursively convert an API response object into JSON-serialisable builtins.

    Conversion rules:
    - Scalars (``None``, ``bool``, ``int``, ``float``, ``str``) are returned unchanged.
    - Dicts, lists and tuples are converted element-wise; tuples become lists.
    - Objects exposing ``__dict__`` are converted via their attribute dict.
    - Anything else falls back to ``str(obj)``.

    Args:
        obj: Value to convert, e.g. a moderation result's ``categories`` object.

    Returns:
        A structure built only from dicts, lists and JSON scalars.
    """
    # None must stay None (-> JSON null). Previously it fell through to str() and was
    # stored as the string "None".
    if obj is None or isinstance(obj, (int, float, str, bool)):
        return obj
    if isinstance(obj, dict):
        return {k: convert_to_dict(v) for k, v in obj.items()}
    # Tuples have no __dict__, so they used to be stringified; serialise them as lists.
    if isinstance(obj, (list, tuple)):
        return [convert_to_dict(v) for v in obj]
    if hasattr(obj, '__dict__'):
        return {k: convert_to_dict(v) for k, v in obj.__dict__.items()}
    return str(obj)

def openai_moderation_api_batch(texts: List[str], model: str, dataset_name: str) -> List[Dict]:
    """Moderate a batch of texts with a single API call and pair each text with its result.

    Args:
        texts: Texts sent together as the ``input`` of one ``moderations.create`` request.
        model: Moderation model name, e.g. ``"omni-moderation-latest"``.
        dataset_name: Sent as the ``_user`` metadata value on the request.

    Returns:
        One dict per input text, in input order, with keys ``text``, ``flagged``,
        ``categories`` and ``category_scores`` (the latter two converted to plain dicts).

    Raises:
        ValueError: If the API returns a different number of results than texts sent.
    """
    if not texts:
        # Nothing to moderate; skip the network round-trip.
        return []
    result = client.with_options(metadata={"_user": dataset_name}).moderations.create(input=texts, model=model)
    # zip() would silently drop unmatched texts, shifting every later prediction out of
    # alignment with its ground-truth label downstream, so fail loudly instead.
    if len(result.results) != len(texts):
        raise ValueError(
            f"Moderation API returned {len(result.results)} results for {len(texts)} texts"
        )
    return [
        {
            'text': text,
            'flagged': convert_to_dict(r.flagged),
            'categories': convert_to_dict(r.categories),
            'category_scores': convert_to_dict(r.category_scores)
        } for text, r in zip(texts, result.results)
    ]

def load_dataset(filepath: str) -> pd.DataFrame:
    """Read one benchmark dataset from CSV into a DataFrame.

    Callers read the ``text`` and ``is_toxic`` columns; this function does not validate them.
    """
    return pd.read_csv(filepath)

def collect_data(api_func: Callable, model: str, texts: List[str], true_labels: List[bool], dataset_name: str,
                 batch_size: int = 20, max_samples: int = 1000) -> Tuple[List[Dict], List[float]]:
    """Send texts to ``api_func`` in fixed-size batches and time each call.

    Args:
        api_func: Called as ``api_func(batch_texts, model, dataset_name)``; expected to
            return one result dict per text in the batch.
        model: Model name forwarded to ``api_func``.
        texts: Full list of texts; only the first ``max_samples`` are sent.
        true_labels: Not used here; kept so existing callers keep working.
        dataset_name: Forwarded to ``api_func``.
        batch_size: Number of texts per API call (default 20, as before).
        max_samples: Upper bound on how many texts are processed (default 1000, as before).

    Returns:
        ``(all_results, latencies)``:
        - ``all_results`` concatenates the per-batch results in input order.
        - ``latencies`` holds the elapsed seconds of each batch call.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    limit = min(len(texts), max_samples)
    all_results = []
    latencies = []

    for start in range(0, limit, batch_size):
        # Clamp the slice so a max_samples that is not a multiple of batch_size is respected.
        batch_texts = texts[start:min(start + batch_size, limit)]
        # perf_counter is monotonic and high-resolution; time.time() can jump with clock changes.
        t0 = time.perf_counter()
        batch_results = api_func(batch_texts, model, dataset_name)
        latencies.append(time.perf_counter() - t0)
        all_results.extend(batch_results)

    return all_results, latencies

def run_data_collection(apis: List[Tuple[str, Callable, str]], datasets: List[Tuple[str, str]]) -> Dict:
    """Collect moderation results for every (API, dataset) pair.

    Args:
        apis: ``(api_name, api_func, api_model)`` triples.
        datasets: ``(dataset_name, csv_path)`` pairs. Each CSV must have ``text`` and
            ``is_toxic`` columns.

    Returns:
        ``{api_name: {dataset_name: {'results', 'latencies', 'true_labels'}}}``.
        ``true_labels`` holds every label from the CSV, which can be more than the number
        of results.
    """
    # Dataset contents do not depend on the API under test, so read each CSV once up front
    # (failing fast on a bad path) instead of re-reading it inside the API loop.
    loaded_datasets = []
    for dataset_name, dataset_path in datasets:
        df = load_dataset(dataset_path)
        loaded_datasets.append((dataset_name, df['text'].tolist(), df['is_toxic'].tolist()))

    all_data = {}
    for api_name, api_func, api_model in apis:
        api_data = {}
        for dataset_name, texts, true_labels in loaded_datasets:
            print(f"\nCollecting data for {api_name} on {dataset_name}")
            results, latencies = collect_data(api_func, api_model, texts, true_labels, dataset_name)
            api_data[dataset_name] = {
                'results': results,
                'latencies': latencies,
                # Copy so each API entry owns its own label list, as when the CSV was re-read per API.
                'true_labels': list(true_labels)
            }

        all_data[api_name] = api_data

    return all_data

def main():
    """Benchmark both moderation models on every dataset and persist the raw results.

    The nested dict returned by ``run_data_collection`` is written to
    ``moderation_data.json`` in the working directory.
    """
    # Both models go through the same batch function; only the model id differs.
    model_ids = {
        'Omni': "omni-moderation-latest",
        'Legacy': "text-moderation-latest",
    }
    apis = [(label, openai_moderation_api_batch, model_id) for label, model_id in model_ids.items()]

    dataset_paths = {
        'Jigsaw': 'datasets/toxic.csv',
        'HateXplain': 'datasets/hatexplain.csv',
        'Multilingual': 'datasets/multi.csv',
    }
    datasets = list(dataset_paths.items())

    collected_data = run_data_collection(apis, datasets)

    # Persist everything so the analysis cell can run without calling the API again.
    with open('moderation_data.json', 'w') as out_file:
        json.dump(collected_data, out_file)

    print("Data collection complete. Results saved to moderation_data.json")


if __name__ == "__main__":
    main()


Collecting data for Omni on Jigsaw

Collecting data for Omni on HateXplain

Collecting data for Omni on Multilingual

Collecting data for Legacy on Jigsaw

Collecting data for Legacy on HateXplain

Collecting data for Legacy on Multilingual
Data collection complete. Results saved to moderation_data.json


In [43]:
import json
import numpy as np
from typing import List, Dict
import random

def calculate_metrics(true_labels: List[bool], predictions: List[bool]) -> Dict:
    """Compute binary-classification metrics, treating a truthy value as the positive class.

    Labels and predictions are paired with ``zip``, so extra items in the longer list are
    ignored. Any ratio whose denominator is zero is reported as 0.

    Returns:
        Dict with ``f1_score``, ``precision``, ``recall`` and ``false_positive_rate``.
    """
    tp = fp = tn = fn = 0
    for actual, predicted in zip(true_labels, predictions):
        if predicted:
            if actual:
                tp += 1
            else:
                fp += 1
        elif actual:
            fn += 1
        else:
            tn += 1

    def _ratio(numerator, denominator):
        return numerator / denominator if denominator > 0 else 0

    precision = _ratio(tp, tp + fp)
    recall = _ratio(tp, tp + fn)
    f1 = _ratio(2 * (precision * recall), precision + recall)
    fpr = _ratio(fp, fp + tn)

    return {
        'f1_score': f1,
        'precision': precision,
        'recall': recall,
        'false_positive_rate': fpr
    }

def _per_item_latency_stats(latencies: List[float], n_items: int, batch_size: int) -> Dict:
    """Turn per-batch call latencies into mean and 95th-percentile seconds per item."""
    if not latencies:
        # np.mean([]) warns and np.percentile([]) fails; report "no data" explicitly.
        return {'avg_latency': float('nan'), 'p95_latency': float('nan')}
    per_item = []
    for batch_index, latency in enumerate(latencies):
        # Assumes consecutive batches of batch_size with only the final one possibly shorter.
        # The old code divided every batch by 20, under-reporting a short final batch.
        items = min(batch_size, n_items - batch_index * batch_size)
        # Fall back to the nominal batch size if results and latencies don't line up.
        per_item.append(latency / (items if items > 0 else batch_size))
    return {
        'avg_latency': np.mean(per_item),
        'p95_latency': np.percentile(per_item, 95),
    }


def analyze_data(collected_data: Dict, batch_size: int = 20) -> Dict:
    """Compute quality metrics and per-item latency for every (API, dataset) pair.

    Args:
        collected_data: ``{api_name: {dataset_name: {'results', 'latencies', 'true_labels'}}}``
            as loaded from the collection cell's JSON file.
        batch_size: Batch size used during collection. It converts per-batch call times into
            per-item times; the default matches the previously hard-coded value of 20.

    Returns:
        ``{api_name: {dataset_name: metrics}}``. Each ``metrics`` dict contains:
        - the keys returned by ``calculate_metrics``;
        - ``avg_latency`` and ``p95_latency``, in seconds per item (NaN when no latencies
          were recorded).
    """
    results = {}
    for api_name, api_data in collected_data.items():
        api_results = {}
        for dataset_name, dataset_data in api_data.items():
            predictions = [r['flagged'] for r in dataset_data['results']]
            # Only the first len(predictions) texts were sent, so compare against those labels.
            true_labels = dataset_data['true_labels'][:len(predictions)]

            metrics = calculate_metrics(true_labels, predictions)
            metrics.update(_per_item_latency_stats(dataset_data['latencies'], len(predictions), batch_size))

            api_results[dataset_name] = metrics

        results[api_name] = api_results

    return results

def compare_models(collected_data: Dict) -> Dict:
    """Count, per dataset, where the Omni and Legacy models disagree on ``flagged``.

    Results are paired position-by-position (``zip`` stops at the shorter list).

    Returns:
        ``{dataset_name: {...}}`` with these entries:
        - ``total_samples``: the number of Omni results;
        - ``disagreements`` and ``disagreement_rate`` (0 when there are no samples);
        - ``omni_more_sensitive`` and ``legacy_more_sensitive``: how many disagreements
          each model flagged alone;
        - ``detailed_disagreements``: the first 10 disagreement records.
    """
    comparisons = {}
    for dataset_name in collected_data['Omni'].keys():
        omni_results = collected_data['Omni'][dataset_name]['results']
        legacy_results = collected_data['Legacy'][dataset_name]['results']

        # Keep only the positions where the two models made different flag decisions.
        disagreements = [
            {
                'text': omni_item['text'],
                'omni_flagged': omni_item['flagged'],
                'legacy_flagged': legacy_item['flagged'],
                'omni_scores': omni_item['category_scores'],
                'legacy_scores': legacy_item['category_scores']
            }
            for omni_item, legacy_item in zip(omni_results, legacy_results)
            if omni_item['flagged'] != legacy_item['flagged']
        ]
        n_disagree = len(disagreements)
        # A disagreement counts for Omni when Omni's flag is truthy, otherwise for Legacy.
        omni_more_sensitive = sum(1 for d in disagreements if d['omni_flagged'])
        legacy_more_sensitive = n_disagree - omni_more_sensitive

        total = len(omni_results)
        comparisons[dataset_name] = {
            'total_samples': total,
            'disagreements': n_disagree,
            'disagreement_rate': n_disagree / total if total > 0 else 0,
            'omni_more_sensitive': omni_more_sensitive,
            'legacy_more_sensitive': legacy_more_sensitive,
            'detailed_disagreements': disagreements[:10],  # first 10 only, for brevity
        }

    return comparisons

def _comparison_record(omni_result: Dict, legacy_result: Dict, true_label) -> Dict:
    """Build one example record: the text, both models' category scores and the dataset label."""
    return {
        'text': omni_result['text'],
        'omni_scores': omni_result['category_scores'],
        'legacy_scores': legacy_result['category_scores'],
        'true_label': true_label
    }


def omni_vs_legacy_analysis(collected_data: Dict, seed=None) -> Dict:
    """Break down, per dataset, how the Omni and Legacy models differ.

    NOTE: ``false_positives`` / ``false_negatives`` compare the models to each other, not to
    the ground truth:
    - ``false_positives``: Legacy flagged and Omni did not.
    - ``false_negatives``: Omni flagged and Legacy did not.
    The ``omni_correct`` / ``legacy_correct`` buckets use the dataset label: exactly that one
    model's flag equals it.

    Args:
        collected_data: ``{api_name: {dataset_name: {'results', 'true_labels', ...}}}`` with
            ``'Omni'`` and ``'Legacy'`` entries. Labels are taken from the Omni entry.
        seed: Optional seed for choosing the (up to 3) example records per bucket. With
            ``None`` the global ``random`` module is used, exactly as before.

    Returns:
        ``{dataset_name: {...}}`` with per-bucket counts, rates over the number of Omni
        results (0 when there are none), and sampled example records.
    """
    # A private generator makes examples reproducible without touching global random state;
    # with no seed, keep using the module-level generator so a global random.seed() still applies.
    sampler = random if seed is None else random.Random(seed)
    analysis = {}
    for dataset_name in collected_data['Omni'].keys():
        omni_results = collected_data['Omni'][dataset_name]['results']
        legacy_results = collected_data['Legacy'][dataset_name]['results']
        true_labels = collected_data['Omni'][dataset_name]['true_labels'][:len(omni_results)]

        false_positives = []  # Legacy flagged, Omni didn't
        false_negatives = []  # Omni flagged, Legacy didn't
        omni_correct = []  # Omni matches true label, Legacy doesn't
        legacy_correct = []  # Legacy matches true label, Omni doesn't

        for omni_result, legacy_result, true_label in zip(omni_results, legacy_results, true_labels):
            omni_flag = omni_result['flagged']
            legacy_flag = legacy_result['flagged']

            if omni_flag and not legacy_flag:
                false_negatives.append(_comparison_record(omni_result, legacy_result, true_label))
            elif not omni_flag and legacy_flag:
                false_positives.append(_comparison_record(omni_result, legacy_result, true_label))

            if omni_flag == true_label and legacy_flag != true_label:
                omni_correct.append(_comparison_record(omni_result, legacy_result, true_label))
            elif legacy_flag == true_label and omni_flag != true_label:
                legacy_correct.append(_comparison_record(omni_result, legacy_result, true_label))

        total = len(omni_results)
        # An empty dataset previously raised ZeroDivisionError; report 0 like compare_models does.
        analysis[dataset_name] = {
            'total_samples': total,
            'false_positives': len(false_positives),
            'false_negatives': len(false_negatives),
            'omni_correct': len(omni_correct),
            'legacy_correct': len(legacy_correct),
            'false_positive_rate': len(false_positives) / total if total > 0 else 0,
            'false_negative_rate': len(false_negatives) / total if total > 0 else 0,
            'omni_correct_rate': len(omni_correct) / total if total > 0 else 0,
            'legacy_correct_rate': len(legacy_correct) / total if total > 0 else 0,
            'false_positive_examples': sampler.sample(false_positives, min(3, len(false_positives))),
            'false_negative_examples': sampler.sample(false_negatives, min(3, len(false_negatives))),
            'omni_correct_examples': sampler.sample(omni_correct, min(3, len(omni_correct))),
            'legacy_correct_examples': sampler.sample(legacy_correct, min(3, len(legacy_correct)))
        }

    return analysis

def _print_examples(heading: str, examples: List[Dict]) -> None:
    """Print ``heading``, then each example's text, true label and both models' category scores.

    Each example must provide ``text``, ``true_label``, ``omni_scores`` and ``legacy_scores``;
    every example block ends with an empty line.
    """
    print(heading)
    score_sections = (("    Omni scores:", 'omni_scores'), ("    Legacy scores:", 'legacy_scores'))
    for example in examples:
        print(f"    Text: {example['text']}")
        print(f"    True Label: {example['true_label']}")
        for section_title, scores_key in score_sections:
            print(section_title)
            for category, score in example[scores_key].items():
                print(f"      {category}: {score}")
        print()


def main():
    """Load ``moderation_data.json`` and print metrics, model agreement and example breakdowns."""
    with open('moderation_data.json', 'r') as f:
        collected_data = json.load(f)

    benchmark_results = analyze_data(collected_data)
    print("\nBenchmark Results:")
    print(json.dumps(benchmark_results, indent=2))

    # Where the two models' flag decisions differ, and which side flagged.
    model_comparisons = compare_models(collected_data)
    print("\nModel Comparison:")
    for dataset_name, comparison in model_comparisons.items():
        print(f"\n{dataset_name}:")
        print(f"  Total samples: {comparison['total_samples']}")
        print(f"  Disagreements: {comparison['disagreements']}")
        print(f"  Disagreement rate: {comparison['disagreement_rate']:.2%}")
        print(f"  Omni more sensitive: {comparison['omni_more_sensitive']}")
        print(f"  Legacy more sensitive: {comparison['legacy_more_sensitive']}")

    # Per-dataset breakdown of Omni/Legacy differences, with sampled examples per bucket.
    breakdown = omni_vs_legacy_analysis(collected_data)
    example_sections = (
        ("\n  False Positive Examples (Legacy flagged, Omni didn't):", 'false_positive_examples'),
        ("\n  False Negative Examples (Omni flagged, Legacy didn't):", 'false_negative_examples'),
        ("\n  Omni Correct Examples (Omni matches true label, Legacy doesn't):", 'omni_correct_examples'),
        ("\n  Legacy Correct Examples (Legacy matches true label, Omni doesn't):", 'legacy_correct_examples'),
    )
    print("\nOmni vs Legacy Analysis:")
    for dataset_name, analysis in breakdown.items():
        print(f"\n{dataset_name}:")
        print(f"  Total samples: {analysis['total_samples']}")
        print(f"  False Positives (Legacy flagged, Omni didn't): {analysis['false_positives']}")
        print(f"  False Negatives (Omni flagged, Legacy didn't): {analysis['false_negatives']}")
        print(f"  Omni correct (Omni matches true label, Legacy doesn't): {analysis['omni_correct']}")
        print(f"  Legacy correct (Legacy matches true label, Omni doesn't): {analysis['legacy_correct']}")
        print(f"  False Positive Rate: {analysis['false_positive_rate']:.2%}")
        print(f"  False Negative Rate: {analysis['false_negative_rate']:.2%}")
        print(f"  Omni Correct Rate: {analysis['omni_correct_rate']:.2%}")
        print(f"  Legacy Correct Rate: {analysis['legacy_correct_rate']:.2%}")

        for heading, examples_key in example_sections:
            _print_examples(heading, analysis[examples_key])

    note_lines = (
        "\nNote on Dataset Labels vs API Categories:",
        "The datasets used in this analysis have different labeling schemes compared to the categories checked by the moderation APIs:",
        "- HateXplain tracks: 'toxic', 'non-toxic', 'hatespeech', 'normal', 'offensive'",
        "- Jigsaw tracks: 'toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate'",
        "- Multi only tracks: 'hatespeech'",
        "The moderation APIs may check for additional categories not present in these datasets.",
        "This discrepancy may lead to cases where the APIs flag content that isn't labeled as problematic in the datasets, or vice versa.",
        "When interpreting the results, consider that differences between API predictions and dataset labels",
        "might sometimes be due to this mismatch in categorization rather than errors in the API's judgment.",
    )
    for line in note_lines:
        print(line)


if __name__ == "__main__":
    main()


Benchmark Results:
{
  "Omni": {
    "Jigsaw": {
      "f1_score": 0.738095238095238,
      "precision": 0.6458333333333334,
      "recall": 0.8611111111111112,
      "false_positive_rate": 0.05717488789237668,
      "avg_latency": 0.01820281386375427,
      "p95_latency": 0.03162048161029815
    },
    "HateXplain": {
      "f1_score": 0.7769066286528866,
      "precision": 0.6703567035670357,
      "recall": 0.923728813559322,
      "false_positive_rate": 0.6536585365853659,
      "avg_latency": 0.015415645360946657,
      "p95_latency": 0.025118092894554127
    },
    "Multilingual": {
      "f1_score": 0.3223443223443223,
      "precision": 0.44,
      "recall": 0.2543352601156069,
      "false_positive_rate": 0.06771463119709795,
      "avg_latency": 0.014957529783248902,
      "p95_latency": 0.019016566276550288
    }
  },
  "Legacy": {
    "Jigsaw": {
      "f1_score": 0.8193832599118943,
      "precision": 0.7815126050420168,
      "recall": 0.8611111111111112,
      "false_po