In [2]:
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from datasets import load_dataset
import pandas as pd
from tqdm import tqdm
import numpy as np

# Model and tokenizer from Hugging Face
MODEL_NAME = "qualifire/prompt-injection-sentinel"
# Prefer Apple's Metal backend when it is available, otherwise run on CPU.
DEVICE = "mps" if torch.backends.mps.is_available() else "cpu"
# Half precision is only requested on the accelerator. On the CPU fallback the
# weights stay in float32: float16 CPU kernels are not guaranteed to exist in
# every PyTorch build, and a failure there would be hidden by the error
# handling in analyse_control (which turns exceptions into a 0 prediction).
MODEL_DTYPE = torch.float32 if DEVICE == "cpu" else torch.float16

print(f"Loading control model: {MODEL_NAME} on {DEVICE}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, torch_dtype=MODEL_DTYPE)
model.to(DEVICE)
# Switch to inference mode (disables dropout and similar training-only layers).
model.eval()

# Datasets
# Each entry describes one benchmark to evaluate:
#   name      - dataset identifier passed to load_dataset
#   split     - which split to load
#   text_cols - candidate column names holding the prompt, tried in order
#   label_col - column holding the ground-truth label
DATASETS = [
    dict(
        name="qualifire/prompt-injections-benchmark",
        split="test",
        text_cols=["prompt", "text"],
        label_col="label",
    ),
]

# Upper bound on the number of examples evaluated per dataset.
MAX_SAMPLES = 500

def pick_text(example, candidates):
    """Return the first non-empty text field of *example*.

    The names in *candidates* are tried in order; a name is skipped when it
    is absent from *example* or its value is falsy (e.g. "" or None).

    Raises:
        ValueError: if none of the candidate names yields a usable value.
    """
    usable_values = (
        example[name]
        for name in candidates
        if name in example and example[name]
    )
    # Every value produced above is truthy, so None safely means "nothing found".
    chosen = next(usable_values, None)
    if chosen is None:
        raise ValueError("No valid text column found")
    return chosen

# Parse label to 0 (safe) or 1 (malicious)
_MALICIOUS_LABELS = frozenset(
    {"jailbreak", "malicious", "unsafe", "attack", "injection", "1"}
)
_SAFE_LABELS = frozenset({"benign", "safe", "legit", "0"})


def parse_label(value):
    """Convert a raw dataset label into 0 (safe) or 1 (malicious).

    Integers (including bools) are returned as plain ints, unchanged in
    value. Strings are matched case-insensitively, ignoring surrounding
    whitespace, against the known label names. Anything else is passed
    through ``int()``.

    NOTE: values that cannot be interpreted at all fall back to 0 (safe).
    This keeps the evaluation running on unexpected labels, but such rows
    are then scored as if they were benign.
    """
    if isinstance(value, int):
        return int(value)
    if isinstance(value, str):
        # Strip so that labels such as " Jailbreak " or "benign\n" still match.
        normalized = value.strip().lower()
        if normalized in _MALICIOUS_LABELS:
            return 1
        if normalized in _SAFE_LABELS:
            return 0
    # Default fallback: try a numeric conversion. Only conversion failures are
    # caught, so KeyboardInterrupt and SystemExit still propagate.
    try:
        return int(value)
    except (TypeError, ValueError, OverflowError):
        return 0

# Run inference
def analyse_control(prompt):
    """Classify *prompt* with the control model and return the predicted class index.

    The prompt is tokenized (truncated to 512 tokens), moved to DEVICE and
    run through the model without gradient tracking; the index of the
    highest-probability class is returned as a Python int.

    Any exception raised along the way is printed and reported as class 0.
    NOTE(review): callers compare the result against 0 = safe / 1 = malicious;
    presumably that matches the checkpoint's label ids - confirm against
    model.config.id2label.
    """
    try:
        encoded = tokenizer(
            prompt,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
        )
        encoded = encoded.to(DEVICE)

        with torch.no_grad():
            output = model(**encoded)

        class_probs = torch.softmax(output.logits, dim=-1)
        return torch.argmax(class_probs, dim=-1).item()
    except Exception as err:
        print(f"Error inferencing: {err}")
        return 0

# One result row per evaluated example, collected across all datasets.
all_rows = []

for cfg in DATASETS:
    print(f"\nRunning {cfg['name']} on control model")
    ds = load_dataset(cfg["name"], split=cfg["split"])

    # Shuffle with a fixed seed, then keep at most MAX_SAMPLES examples.
    sample_size = min(len(ds), MAX_SAMPLES)
    ds = ds.shuffle(seed=42).select(range(sample_size))

    for ex in tqdm(ds):
        text = pick_text(ex, cfg["text_cols"])

        # Ground truth from the dataset, normalised by parse_label.
        raw_label = ex[cfg["label_col"]]
        true_label = parse_label(raw_label)

        # Prediction from the control model.
        pred_label = analyse_control(text)

        row = dict(
            dataset=cfg["name"],
            true_label=true_label,
            pred_label=pred_label,
        )
        all_rows.append(row)

# Convert to DataFrame and flag the rows where the prediction matches the label
df_control = pd.DataFrame(all_rows)
df_control["correct"] = df_control["true_label"].eq(df_control["pred_label"])

# Accuracy per dataset: the mean of the boolean "correct" column within each group
per_dataset_accuracy = df_control.groupby("dataset")["correct"].mean()
accuracy_df_control = per_dataset_accuracy.reset_index().rename(
    columns={"correct": "accuracy"}
)

print("\nControl Model Accuracy per dataset:")
print(accuracy_df_control)
print("Overall control accuracy:", df_control["correct"].mean())

Loading control model: qualifire/prompt-injection-sentinel on mps...

Running qualifire/prompt-injections-benchmark on control model


100%|██████████| 500/500 [00:27<00:00, 18.39it/s]


Control Model Accuracy per dataset:
                                 dataset  accuracy
0  qualifire/prompt-injections-benchmark     0.978
Overall control accuracy: 0.978



