In [None]:
import os

import torch
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForSequenceClassification


# Directory holding the fine-tuned classifier and its tokenizer files.
# Defaults to the original development location; set the JAILBREAK_MODEL_PATH
# environment variable to load the model from somewhere else (an unset or
# empty variable falls back to the default).
MODEL_PATH = os.environ.get("JAILBREAK_MODEL_PATH") or (
    r"C:\Users\masoka\Documents\ai-jailbreak-detector\models\bert_jailbreak_detector"
)

# Loaded once at import time so every request reuses the same objects.
tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)
model.eval()  # put the model in evaluation mode; it is only used for inference

class PromptIn(BaseModel):
    """Request body accepted by the ``/classify`` endpoint."""

    # Raw prompt text to be classified.
    text: str

app = FastAPI()

# --- CORS (allow the dashboard page to call the API) ---
# Development configuration: the "*" entry lets any origin call the API, which
# makes the explicit entries above it informational only. Remove "*" (and
# "null") before exposing this service beyond a local machine.
origins = [
    "http://127.0.0.1:5500",   # VS Code Live Server (common)
    "http://localhost:5500",
    "http://127.0.0.1:8000",
    "http://localhost:8000",
    "null",                    # file:// origin
    "*",                       # dev: allow everything
]

app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    # A wildcard origin must not be combined with credentialed requests: it
    # would let any website call the API with the user's cookies/auth. The
    # endpoints in this file read no cookies or auth headers, so credentials
    # are disabled. If credentials are ever needed, list the exact origins,
    # methods and headers instead of using "*".
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/")
def root():
    """Health-check endpoint: report that the API process is up."""
    payload = {
        "status": "ok",
        "message": "Jailbreak detector API running",
    }
    return payload

@app.post("/classify")
def classify(prompt: PromptIn):
    """Classify a prompt as "benign" or "jailbreak".

    The prompt text is tokenized (truncated to 512 tokens), run through the
    module-level model, and the logits are turned into probabilities with a
    softmax.

    NOTE(review): index 0 is treated as "benign" and index 1 as "jailbreak";
    confirm this matches the label order in the model's config.

    Returns:
        A dict with:
          - "label": the class with the higher probability ("benign" on a tie),
          - "score": the probability of that winning class,
          - "benign_score" / "jailbreak_score": the per-class probabilities.
    """
    encoded = tokenizer(
        prompt.text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
    )

    # Inference only, so gradient tracking is switched off.
    with torch.no_grad():
        logits = model(**encoded).logits
        # A single prompt is sent, so take the first (only) row of the batch.
        class_probs = torch.softmax(logits, dim=1)[0]

    benign_prob = class_probs[0].item()
    jailbreak_prob = class_probs[1].item()

    # Pick the more probable class and report its probability as the score.
    if jailbreak_prob > benign_prob:
        label, score = "jailbreak", jailbreak_prob
    else:
        label, score = "benign", benign_prob

    return {
        "label": label,
        "score": score,
        "benign_score": benign_prob,
        "jailbreak_score": jailbreak_prob,
    }