In [3]:
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from datasets import load_dataset
import pandas as pd
from tqdm import tqdm
import numpy as np

# Model and tokenizer from Hugging Face
MODEL_NAME = "qualifire/prompt-injection-sentinel"  # Hugging Face repo id of the classifier under test
DEVICE = "mps" if torch.backends.mps.is_available() else "cpu"
# float16 saves memory on the MPS (Apple GPU) backend, but half precision on
# CPU may be slow or unsupported for some kernels, so use float32 there.
DTYPE = torch.float16 if DEVICE == "mps" else torch.float32

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_NAME, torch_dtype=DTYPE)
model.to(DEVICE)
model.eval()  # inference mode: disables dropout and similar training-only behaviour

# Evaluation datasets. Each entry describes one Hugging Face dataset:
#   name      - dataset id passed to load_dataset
#   split     - split to evaluate on
#   text_cols - candidate text columns, tried in this order by pick_text
#   label_col - ground-truth label column (cast with int() by the eval loop)
# NOTE(review): the recorded run reports JasperLS/prompt-injections and
# deepset/prompt-injections with the same row count (116) and identical
# accuracy; they may be copies of one dataset, which would double-count it in
# the pooled overall accuracy. Confirm before relying on the combined figure.
DATASETS = [
    dict(
        name="geekyrakshit/prompt-injection-dataset",
        split="test",
        text_cols=["prompt", "text"],
        label_col="label",
    ),
    dict(
        name="JasperLS/prompt-injections",
        split="test",
        text_cols=["text", "prompt"],
        label_col="label",
    ),
    dict(
        name="deepset/prompt-injections",
        split="test",
        text_cols=["text", "prompt"],
        label_col="label",
    ),
]

# Upper bound on rows evaluated per dataset (taken after a seeded shuffle).
MAX_SAMPLES = 500

def pick_text(example, candidates):
    """Return the value of the first usable text column in ``example``.

    ``candidates`` is an ordered priority list of column names. A column is
    usable when it is present in ``example`` and its value is truthy, so
    missing columns and empty/``None`` values are skipped.

    Raises:
        ValueError: when no candidate column is usable.
    """
    chosen = next(
        (example[col] for col in candidates if col in example and example[col]),
        None,
    )
    # Only truthy values are yielded, so None here means "nothing usable".
    if chosen is None:
        raise ValueError("No valid text column found")
    return chosen

# Run inference
def analyse_control(prompt):
    """Classify a single prompt with the module-level ``model``/``tokenizer``.

    Args:
        prompt: text to classify. Inputs longer than 512 tokens are truncated,
            so the model never sees their tail.

    Returns:
        ``(pred, probs)``: ``pred`` is the argmax class index as a Python int;
        ``probs`` is a float32 NumPy array of softmax probabilities over the
        model's output classes.
        NOTE(review): callers compare ``pred`` directly with dataset labels,
        which assumes the model's class indices match the dataset encoding —
        confirm against ``model.config.id2label``.
    """
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
    ).to(DEVICE)

    with torch.no_grad():  # inference only; skip autograd bookkeeping
        logits = model(**inputs).logits
    # The model may be loaded in float16; upcast before softmax so the
    # probabilities are computed and returned in float32 (avoids half
    # precision rounding and CPU kernels lacking Half support).
    probs = torch.softmax(logits.float(), dim=-1)
    pred = torch.argmax(probs, dim=-1).item()
    return pred, probs.squeeze().cpu().numpy()

all_rows = []

for cfg in DATASETS:
    print(f"\nRunning {cfg['name']} on control Prompt Guard")
    ds = load_dataset(cfg["name"], split=cfg["split"])
    # Deterministic subsample: fixed-seed shuffle, then keep at most MAX_SAMPLES rows.
    ds = ds.shuffle(seed=42).select(range(min(len(ds), MAX_SAMPLES)))

    for ex in tqdm(ds):
        text = pick_text(ex, cfg["text_cols"])
        true_label = int(ex[cfg["label_col"]])
        # NOTE(review): the predicted class index is compared directly with the
        # dataset label; this assumes both use the same encoding (e.g. 1 =
        # injection) — confirm against model.config.id2label.
        pred_label, _ = analyse_control(text)

        all_rows.append({
            "dataset": cfg["name"],
            "true_label": true_label,
            "pred_label": pred_label,
        })

# Convert to DataFrame
df_control = pd.DataFrame(all_rows)
df_control["correct"] = df_control["true_label"] == df_control["pred_label"]

# Accuracy per dataset
accuracy_df_control = (
    df_control.groupby("dataset")["correct"]
    .mean()
    .reset_index()
    .rename(columns={"correct": "accuracy"})
)

print("\ncontrol Prompt Guard Accuracy per dataset:")
print(accuracy_df_control)
# Pooled (micro) accuracy: every evaluated row counts once, so datasets that
# contribute more rows (up to MAX_SAMPLES each) dominate this figure.
print("Overall control accuracy:", df_control["correct"].mean())
# Macro accuracy: unweighted mean of the per-dataset accuracies, so each
# dataset contributes equally regardless of how many rows it has.
print("Macro-averaged control accuracy:", accuracy_df_control["accuracy"].mean())



Running geekyrakshit/prompt-injection-dataset on control Prompt Guard


100%|██████████| 500/500 [00:12<00:00, 38.78it/s]



Running JasperLS/prompt-injections on control Prompt Guard


100%|██████████| 116/116 [00:02<00:00, 56.05it/s]



Running deepset/prompt-injections on control Prompt Guard


100%|██████████| 116/116 [00:02<00:00, 57.45it/s]


control Prompt Guard Accuracy per dataset:
                                 dataset  accuracy
0             JasperLS/prompt-injections   0.87069
1              deepset/prompt-injections   0.87069
2  geekyrakshit/prompt-injection-dataset   0.67600
Overall control accuracy: 0.7377049180327869



