In [31]:
import torch
import gradio as gr
import time
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import torch.nn.functional as F

# ===== CONFIG =====
# Fine-tuned checkpoints that together form the ensemble.
MODEL_DIRS = [
    r"C:\Users\Admin\Desktop\web\outputs\checkpoint-211",
    r"C:\Users\Admin\Desktop\web\outputs\checkpoint-1880",
]

# Directory the shared tokenizer is loaded from.
BASE_TOKENIZER_DIR = r"C:\Users\Admin\Desktop\web\outputs\checkpoint-211"

# Human-readable class names. The order is assumed to match the classifier's
# output indices (0 = non-toxic, 1 = toxic) -- verify against the checkpoints.
LABEL_NAMES = ["non-toxic", "toxic"]
THRESHOLD = 0.5  # toxic probability at or above this value is labelled "toxic"
MAX_LEN = 128  # maximum number of tokens kept after truncation

# Prefer the GPU when one is visible to torch; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# ===== LOAD TOKENIZER (ONCE) =====
# NOTE(review): loading this tokenizer printed a transformers warning about an
# "incorrect regex pattern" that suggests `fix_mistral_regex=True`. That flag is
# aimed at Mistral tokenizers; confirm it actually applies to this DistilBERT
# checkpoint before enabling it.
tokenizer = AutoTokenizer.from_pretrained(BASE_TOKENIZER_DIR)

# ===== LOAD MODELS =====
# Each checkpoint is loaded once at start-up, moved to `device`, and switched to
# eval mode (disables dropout); `predict` reuses these instances for every call.
models = []
for ckpt in MODEL_DIRS:
    model = AutoModelForSequenceClassification.from_pretrained(ckpt).to(device).eval()
    models.append(model)

# ===== PREDICT FUNCTION =====
def _synchronize(dev):
    """Wait for all queued CUDA work on *dev* to finish; no-op for non-CUDA devices.

    CUDA kernels are launched asynchronously, so a wall-clock timer around a
    forward pass without a sync only measures kernel-launch overhead.
    """
    if dev.type == "cuda":
        torch.cuda.synchronize(dev)


def _class_probs(model, inputs):
    """Run *model* on tokenized *inputs*; return softmax probabilities of batch item 0."""
    with torch.no_grad():
        logits = model(**inputs).logits
    return F.softmax(logits, dim=-1)[0]


def predict(text):
    """Classify *text* as toxic / non-toxic with the checkpoint ensemble.

    Args:
        text: Raw input from the Gradio textbox; may be None or blank.

    Returns:
        A tuple ``(label, scores, elapsed)``:
          * label  -- "toxic" if the ensemble-averaged toxic probability is
            >= THRESHOLD, else "non-toxic" (a prompt message for blank input);
          * scores -- dict mapping each name in LABEL_NAMES to its averaged
            probability ({} for blank input);
          * elapsed -- wall-clock time of the FIRST model's forward pass only
            (tokenization and the other ensemble members are excluded),
            formatted like "12.34 ms" ("0 ms" for blank input).
    """
    if not text or not text.strip():
        return "Please enter some text", {}, "0 ms"

    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=MAX_LEN,
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}

    # Time only the first model. Sync before starting so pending host->device
    # copies are not charged to it, and before stopping so the queued GPU
    # kernels are actually included in the measurement.
    _synchronize(device)
    start_time = time.perf_counter()
    single_probs = _class_probs(models[0], inputs)
    _synchronize(device)
    elapsed_ms = (time.perf_counter() - start_time) * 1000

    # Remaining ensemble members run outside the timed region.
    all_probs = [single_probs] + [_class_probs(m, inputs) for m in models[1:]]
    avg_probs = torch.stack(all_probs).mean(dim=0).cpu().tolist()

    result = {name: float(p) for name, p in zip(LABEL_NAMES, avg_probs)}
    toxic_idx = LABEL_NAMES.index("toxic")
    pred_label = "toxic" if avg_probs[toxic_idx] >= THRESHOLD else "non-toxic"

    return pred_label, result, f"{elapsed_ms:.2f} ms"

# ===== GRADIO UI =====
# Gradio front-end: one textbox in, three outputs matching the
# (label, scores, elapsed) tuple returned by predict().
demo = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(lines=3, placeholder="Enter text here..."),
    outputs=[
        gr.Label(label="Prediction"),
        gr.JSON(label="Confidence scores"),
        # predict() times only the first checkpoint's forward pass, so say so.
        gr.Textbox(label="Inference time (single model)"),
    ],
    title="Toxicity Detection Demo (Ensemble)",
    # Derive the checkpoint count instead of hard-coding it so the text stays
    # correct when MODEL_DIRS changes.
    description=(
        f"DistilBERT multilingual ensemble ({len(models)} checkpoints) "
        "with single-model inference time"
    ),
)

demo.launch()


The tokenizer you are loading from 'C:\Users\Admin\Desktop\web\outputs\checkpoint-211' with an incorrect regex pattern: https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503/discussions/84#69121093e8b480e709447d5e. This will lead to incorrect tokenization. You should set the `fix_mistral_regex=True` flag when loading this tokenizer to fix this issue.


* Running on local URL:  http://127.0.0.1:7871
* To create a public link, set `share=True` in `launch()`.


