In [1]:
#Installations to be done when runnning from colab

# When running from Colab uncomment the next two lines.

!pip -q install "gradio>=4.44" "transformers>=4.43" "torch>=2.3" "shap>=0.45" \
               "pytesseract>=0.3.10" "Pillow>=9.5" "sqlalchemy>=2.0" "psycopg2-binary>=2.9" \
               "langdetect>=1.0.9" "pydantic>=2.8" "fastapi>=0.111" "uvicorn[standard]>=0.30" \
               "easyocr>=1.7.1" "opencv-python-headless>=4.10.0.84" -q
!apt-get -y install tesseract-ocr tesseract-ocr-eng tesseract-ocr-deu >/dev/null
!apt-get -y install fonts-noto-core fonts-noto-color-emoji fonts-indic >/dev/null


# The imports required to run ...

import os, io, re, json, time, threading, unicodedata, socket
from datetime import datetime
from typing import Optional, List, Dict, Any

import numpy as np
import pandas as pd
from PIL import Image

import pytesseract
import shap
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

import torch
from transformers import pipeline, AutoTokenizer

from langdetect import detect, DetectorFactory, LangDetectException
DetectorFactory.seed = 42

import sqlalchemy as sa
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, Session

import gradio as gr

# Imports for the REST API
from fastapi import FastAPI, UploadFile, File, Form, Header, Depends, HTTPException, Query
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
import uvicorn

[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/981.5 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━[0m[90m╺[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m256.0/981.5 kB[0m [31m7.3 MB/s[0m eta [36m0:00:01[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m972.8/981.5 kB[0m [31m15.0 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m981.5/981.5 kB[0m [31m12.4 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.0/3.0 MB[0m [31m39.7 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.9/2.9 MB[0m [31m43.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m510.8/510.8 kB[0m [31m23.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m4.7/4.7 MB[0m [31

In [2]:
# -- coding: utf-8 --
"""
spam_app.py — Multilingual text and Image spam detector
This module provides:
- OCR (Tesseract primary, EasyOCR fallback) with image preprocessing
- Zero-shot XLM-R teacher model and optional smaller student model
- Rule-based scoring with patterns tuned to the SMS spam dataset
- SHAP artifacts (bar, force, token highlight)
- FastAPI endpoints (/predict, /log, /flag)
- Gradio UI for interactive use (Predict / History / Benchmark)
"""

import os
import io
import re
import json
import time
import threading
import unicodedata
import socket
import logging
from datetime import datetime, timezone
from typing import Optional, List, Dict, Any
from functools import lru_cache

import numpy as np
import pandas as pd
from PIL import Image, ImageOps, ImageFilter, ImageEnhance

import pytesseract
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt

import torch
# Commented next line as it might not be needed - check back- refer note3
torch.set_num_threads(int(os.environ.get("TORCH_NUM_THREADS", "1")))
from transformers import pipeline, AutoTokenizer, AutoModelForSequenceClassification

from langdetect import detect, DetectorFactory, LangDetectException
DetectorFactory.seed = 42

import sqlalchemy as sa
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, Session


from fastapi import FastAPI, UploadFile, File, Form, Header, Depends, HTTPException, Query
from fastapi.responses import JSONResponse
from fastapi.middleware.cors import CORSMiddleware
import uvicorn

import gradio as gr
import shap



# Adding Basic logging and configrtn

logging.basicConfig(level=os.environ.get("LOG_LEVEL", "INFO"))
log = logging.getLogger("spam_app")

# Configuration - will go as a file if running as standalone app

DATABASE_URL = os.environ.get("DATABASE_URL", "sqlite:///./spam_records.db")
API_KEY = os.environ.get("API_KEY", "devkey")
RATE_LIMIT_PER_MIN = int(os.environ.get("RATE_LIMIT_PER_MIN", "60"))
OCR_LANGS_DEFAULT = os.environ.get("OCR_LANGS", "eng+deu")
DEVICE = 0 if torch.cuda.is_available() else -1
os.environ["TOKENIZERS_PARALLELISM"] = "false"

BASE_THRESHOLD = float(os.environ.get("SPAM_THRESHOLD", "0.35"))
MODEL_WEIGHT   = float(os.environ.get("MODEL_WEIGHT",   "0.60"))
RULE_WEIGHT    = float(os.environ.get("RULE_WEIGHT",    "0.40"))
FORCE_RULE_CUTOFF = float(os.environ.get("FORCE_RULE_CUTOFF", "0.58"))
MAX_SEQ_LEN = int(os.environ.get("MAX_SEQ_LEN", "256"))

# dirs
ARTIFACT_DIR = os.path.abspath("./shap_artifacts"); os.makedirs(ARTIFACT_DIR, exist_ok=True)
##os.makedirs(ARTIFACT_DIR, exist_ok=True)
RULEPLOT_DIR = os.path.abspath("./rule_plots");     os.makedirs(RULEPLOT_DIR, exist_ok=True)
##os.makedirs(RULEPLOT_DIR, exist_ok=True)-note5


#Text normalisation part - adding helpers

ZERO_WIDTH_RE = re.compile(r"[\u200B-\u200D\uFEFF]")
URL_RE        = re.compile(r"https?://\S+|www\.\S+")
PHONE_RE      = re.compile(r"(?<!\d)(\+?\d[\d\-\s]{6,}\d)")
_MULTI_SPACE  = re.compile(r"\s+")

_OCR_URL_LIKE = re.compile(r"(?i)(?:[oh0]+://|h\s*[xt]{1,2}\s*[xt]?\s*p\s*s?\s*[:：]\s*/\s*/)")
_SPACED_WWW   = re.compile(r"(?i)\bw\s*w\s*w\s*\.")
_SPACED_WAP   = re.compile(r"(?i)\bw\s*a\s*p\s*\.")
_ODD_DOTS     = re.compile(r"\s*[\.\u2024\u2027\u2219·]\s*")
_PCT_NUM      = re.compile(r"%\s*([0-9]+)")
REPEAT_RE     = re.compile(r'(.)\1{2,}')

 #Check for common OCR-specific noise like spaced 'www', unusal dots, percent signs.

def _normalize_ocr_noise(t: str) -> str:
    if not t: return t
    t = _OCR_URL_LIKE.sub(" http://", t)
    t = _SPACED_WWW.sub("www.", t)
    t = _SPACED_WAP.sub("wap.", t)
    t = _ODD_DOTS.sub(".", t)
    t = _PCT_NUM.sub(r"\1", t)
    return t


# Preprocess text
def preprocess_text(text: str) -> str:
    if not text: return ""
    t = unicodedata.normalize("NFKC", text)
    t = _normalize_ocr_noise(t)
    t = ZERO_WIDTH_RE.sub("", t)
    t = URL_RE.sub(" <URL> ", t)
    t = PHONE_RE.sub(" <PHONE> ", t)
    t = REPEAT_RE.sub(r'\1\1', t)
    t = _MULTI_SPACE.sub(" ", t)
    return t.strip()

def detect_language_safe(text: str) -> str:
    try:
        return detect(text)
    except LangDetectException:
        return "unknown"

# Optional easy OCR installation
try:
    import easyocr
    HAVE_EASYOCR = True
except Exception:
    HAVE_EASYOCR = False
    #easyocr = None # Explicitly set to None if import fails


def _easyocr_langs_from_tesseract(lang_str: str) -> List[str]:
    parts = [p.strip() for p in (lang_str or "eng").split("+")]
    m = {"eng":"en","hin":"hi","tam":"ta","tel":"te","kan":"kn","mal":"ml","mar":"mr","guj":"gu",
         "ben":"bn","pan":"pa","urd":"ur","ori":"or","asm":"as","nep":"ne","san":"sa",
         "fra":"fr","deu":"de","spa":"es"}
    return [m.get(p, p) for p in parts]


# Revisit this - working for now - but might need to improve-check note 8
def ocr_image_bytes(img_bytes: bytes, ocr_langs: str = OCR_LANGS_DEFAULT) -> str:
    try:
        image = Image.open(io.BytesIO(img_bytes)).convert("RGB")
        text = pytesseract.image_to_string(image, lang=ocr_langs or OCR_LANGS_DEFAULT).strip()
        if text: return text
    except Exception:
        pass
    if HAVE_EASYOCR and easyocr is not None: # Added check for easyocr being None
        try:
            import numpy as np
            langs = _easyocr_langs_from_tesseract(ocr_langs or OCR_LANGS_DEFAULT)
            reader = easyocr.Reader(langs, gpu=torch.cuda.is_available())
            img = np.array(Image.open(io.BytesIO(img_bytes)).convert("RGB"))
            res = reader.readtext(img, detail=0, paragraph=True)
            return "\n".join([r.strip() for r in res if r.strip()])
        except Exception:
            return ""
    return ""

# Definition of model with a zero shot teacher and an optional student

ZS_MODEL = os.environ.get("ZS_MODEL", "joeddav/xlm-roberta-large-xnli")
pipe_kwargs = {"device": DEVICE}
if DEVICE >= 0:
    pipe_kwargs["model_kwargs"] = {"torch_dtype": torch.float16}
tokenizer = AutoTokenizer.from_pretrained(ZS_MODEL, use_fast=True)
zspipe = pipeline("zero-shot-classification", model=ZS_MODEL, tokenizer=tokenizer, **pipe_kwargs)
CANDIDATE_LABELS = ["spam", "not spam"]
HYPOTHESIS_TEMPLATE = "This text is {}."

USE_STUDENT   = os.environ.get("USE_STUDENT", "true").lower() == "true"
STUDENT_MODEL = os.environ.get("STUDENT_MODEL", "huawei-noah/TinyBERT_General_4L_312D")
STUDENT_STATE = os.environ.get("STUDENT_STATE", "models/student/tinybert.pt")
student_tok = None
student_model = None
if USE_STUDENT:
    try:
        student_tok = AutoTokenizer.from_pretrained(STUDENT_MODEL, use_fast=True)
        student_model = AutoModelForSequenceClassification.from_pretrained(STUDENT_MODEL, num_labels=2)
        if STUDENT_STATE and os.path.exists(STUDENT_STATE):
            sd = torch.load(STUDENT_STATE, map_location="cpu")
            student_model.load_state_dict(sd, strict=False)
        student_model.eval()
        if DEVICE >= 0: student_model.to(f"cuda:{DEVICE}")
    except Exception as e:
        log.warning("student model unavailable: %s", e)
        #print(" student unavailable:", e) # modified look at note 10
        student_tok = None; student_model = None


# Incase there is a json, then we can use it for autothresholding
# not useful if running from colab

AUTO_THRESH = {"teacher": None, "student": None}
try:
    if os.path.exists("reports/threshold.json"):
        with open("reports/threshold.json", "r", encoding="utf-8") as f:
            data = json.load(f)
        AUTO_THRESH["student"] = float(data.get("student")) if "student" in data else None
        AUTO_THRESH["teacher"] = float(data.get("teacher")) if "teacher" in data else None
except Exception as e:
    print(" threshold.json:", e)
# check note 11
def _dynamic_threshold(rule_score: float, which: str = "teacher") -> float:
    base = AUTO_THRESH.get(which) or BASE_THRESHOLD
    return max(0.28, base - 0.07 * max(0.0, rule_score - 0.5))

# Lang Patterns
# Added rule patterns based on the study of common sms words appearing usually.
# Keep revisiting and refine this lot.

_LANG_PATTERNS = {
    "en": [
        r"\b(congratulations?|congrats|you\s+won|you['’']ve\s+won|reward|prize|jackpot|lottery)\b",
        r"\b(urgent|limited\s*offer|act\s*now|winner|selected)\b",
        r"\b(claim|redeem|verify|activate)\b",
        r"\bfree\s+(?:entry|membership|msg|gift|bonus|trial)\b",
        r"\b(?:txt|text)\s+(?:the\s+word\s+)?[A-Z]{3,}(?:\s+to)?\b",
        r"\b(?:wap|www)\b",
        r"\bpo\s*box\b|\bltd\b|\bco\b",
        r"\bstop\s+(?:to\s+)?unsubscribe\b|\bstop\b\s*\d{4,6}\b",
        r"\bNo[: ]\s*\d{4,6}\b|\b\d{5}\b",
    ],
    # added German - note 12
    "de": [
        r"glückwunsch|gewonnen|preis|jackpot|lotterie|gewinner",
        r"jetzt\s+handeln|dringend|angebot|klicken\s+sie",
        r"beanspruchen|einlösen|verifizieren|aktivieren|klicken",
        r"kostenlos|gratis|gewinnspiel|teilnehmen",
    ],
    "hi": [r"बधाई|इनाम|जीता|दावा|रिवार्ड|क्लेम|क्लिक|पुरस्कार|लॉटरी"],
    "ta": [r"வாழ்த்துகள்|பரிசு|வென்றுள்ளீர்கள்|க்ளெய்ம்|கிளிக்|லாட்டரி"],
    "es": [r"felicitaciones|has\s+ganado|premio|reclamar|haz\s+clic|loter[íi]a"],
    "mr": [r"अभिनंदन|बक्षीस|जिंकले|दावा|क्लिक|लॉटरी"],
}
# second installment of additional language patterns
_LANG_PATTERNS["en"] += [
        r"\b(congratulations?|congrats|you\s+won|you['']ve\s+won|reward|prize|jackpot|lottery)\b",
        r"\b(urgent|limited\s*offer|act\s*now|winner|selected)\b",
        r"\b(claim|redeem|verify|activate)\b",
        r"\bfree\s+(?:entry|membership|msg|gift|bonus|trial)\b",
        r"\b(?:txt|text)\s+(?:the\s+word\s+)?[A-Z]{3,}(?:\s+to)?\b",
        r"\b(?:wap|www)\b",
        r"\bpo\s*box\b|\bltd\b|\bco\b",
        r"\bstop\s+(?:to\s+)?unsubscribe\b|\bstop\b\s*\d{4,6}\b",
        r"\bNo[: ]\s*\d{4,6}\b|\b\d{5}\b",
        # Enhanced patterns for better spam detection
        r"\b(?:dating|sex|horny|laid|shag|sexy|admirer)\b",
        r"\b(?:quiz|competition|contest|play\s+now|answer\s+\d+\s+easy)\b",
        r"\b(?:upgrade|latest|newest|colour\s+camera|bluetooth)\b",
        r"\b(?:ringtone|polyphonic|tone|content|download)\b",
        r"\b(?:subscription|charged|monthly|weekly|per\s+week)\b",
        r"\b(?:valentine|christmas|xmas|special\s+day)\b",
        r"\b(?:secret\s+admirer|someone\s+has\s+contacted)\b",
        r"\b(?:age\s+verify|18\+|16\+|over\s+\d+)\b",
        r"\b(?:customer\s+service|helpline|freephone)\b",
        r"\b(?:mobile\s+update|phone\s+upgrade|camera\s+phone)\b",
]

_GENERIC_PATTERNS = [
      r"<URL>",
    r"(bit\.ly|tinyurl\.com|t\.co|is\.gd|ow\.ly|goo\.gl|rb\.gy|dbuk\.net|wap\.|club4mobiles\.com|regalportfolio\.co\.uk|100percent-real\.com|areyouunique\.co\.uk|sextextuk\.com)",
    r"₹|₹\s*\d|\bRs\.?\s*\d|£\s*\d|€\s*\d|\$\s*\d",
    r"\b(win|winner|gift|bonus|promo|cash|credit)\b",
    r"\b(subscription|membership|renew)\b",
    r"\b(free\s*(?:gift|bonus|msg|entry|trial))\b",
    r"\b(comp|competition)\s+(?:to\s+)?win\b",
    r"\bterms?\s*&?\s*conditions\b|\bT&Cs?\b|\bT\s*&\s*C\b",
    # Enhanced generic patterns
    r"\b(?:dating|sex|horny|laid|shag|sexy|admirer|fancy)\b",
    r"\b(?:quiz|competition|contest|play\s+now|answer\s+\d+\s+easy)\b",
    r"\b(?:upgrade|latest|newest|colour\s+camera|bluetooth|mobile)\b",
    r"\b(?:ringtone|polyphonic|tone|content|download|service)\b",
    r"\b(?:subscription|charged|monthly|weekly|per\s+week|gbp|pounds?)\b",
    r"\b(?:valentine|christmas|xmas|special\s+day|holiday)\b",
    r"\b(?:secret\s+admirer|someone\s+has\s+contacted|contacted\s+our)\b",
    r"\b(?:age\s+verify|18\+|16\+|over\s+\d+)\b",
    r"\b(?:customer\s+service|helpline|freephone|call\s+now)\b",
    r"\b(?:mobile\s+update|phone\s+upgrade|camera\s+phone|handset)\b",
    r"\b(?:voucher|vouchers|cd\s+voucher|gift\s+voucher)\b",
    r"\b(?:box|pobox|po\s*box|suite|lands\s+row)\b",
    r"\b(?:double\s+mins|unlimited\s+text|network\s+mins)\b",
    r"\b(?:half\s+price|special\s+offer|limited\s+time)\b",
    r"\b(?:opt\s*out|optout|2optout|call2optout)\b",
]
_HAM_SAFETY = [
  r"\bmeeting|agenda|minutes\b",
    r"\bcall\b\s*(?:at|@|\d)",
    r"\bteam|project|deadline|doc|attachment\b",
    # Next iteration of ham safety messages
    r"\b(?:thanks|thank\s+you|appreciate|grateful)\b",
    r"\b(?:sorry|apologize|apology|excuse)\b",
    r"\b(?:please|pls|kindly|would\s+you)\b",
    r"\b(?:help|assist|support|guidance)\b",
    r"\b(?:question|ask|wondering|curious)\b",
    r"\b(?:work|job|office|colleague|boss)\b",
    r"\b(?:family|mom|dad|brother|sister|parent)\b",
    r"\b(?:friend|buddy|pal|mate)\b",
    r"\b(?:love|care|miss|thinking\s+of\s+you)\b",
    r"\b(?:birthday|anniversary|celebration|party)\b",
    r"\b(?:dinner|lunch|breakfast|food|eat)\b",
    r"\b(?:movie|film|cinema|theater|show)\b",
    r"\b(?:book|read|reading|study|learn)\b",
    r"\b(?:travel|trip|vacation|holiday|visit)\b",
    r"\b(?:health|doctor|hospital|medical|sick)\b",
    r"\b(?:weather|rain|sunny|cold|hot)\b",
    r"\b(?:time|clock|hour|minute|schedule)\b",
    r"\b(?:home|house|apartment|room|bedroom)\b",
    r"\b(?:car|drive|driving|road|traffic)\b",
    r"\b(?:school|university|college|class|student)\b",
]
# hard overrides (premium SMS + your failing cases)
_HARD_SPAM_PATTERNS = [

    r"\bfreemsg\b",
    r"\bstd\s*(?:network\s*)?chg?s?\b",
    r"£\s*\d+(?:\.\d{2})?\s*to\s*(?:rcv|receive)\b",
    r"\b(ringtone|polyphonic|txt\s+win|claim\s+now)\b",
    r"\bwin(?:ner)?[:!]\b",
    r"\bvisit\s+www\.win-\d{4,}\.co\.uk\b",
    r"\byou\s+are\s+subscribed\s+to\s+the\s+best\s+mobile\s+content\s+service\b",
    r"\b£\s*\d+(?:\.\d{2})?\s+per\s+(?:ten\s+days|week|day|msg)\b",
    r"\bhelpline\b\s*\d{5,}",
    r"\bdivorce\s+barbie\b",
    r"\bcomes?\s+with\s+all\s+of\s+ken['']s\s+stuff\b",
    # additional Enhanced hard spam patterns
    r"\b(?:want\s+2\s+get\s+laid|horny|shag|sextextuk)\b",
    r"\b(?:dating\s+service|secret\s+admirer|someone\s+has\s+contacted)\b",
    r"\b(?:quiz|competition|contest|play\s+now|answer\s+\d+\s+easy)\b",
    r"\b(?:upgrade|latest|newest|colour\s+camera|bluetooth|mobile)\b",
    r"\b(?:ringtone\s+service|polyphonic|tone|content|download)\b",
    r"\b(?:subscription|charged|monthly|weekly|per\s+week)\b",
    r"\b(?:valentine|christmas|xmas|special\s+day|holiday)\b",
    r"\b(?:age\s+verify|18\+|16\+|over\s+\d+)\b",
    r"\b(?:customer\s+service|helpline|freephone|call\s+now)\b",
    r"\b(?:mobile\s+update|phone\s+upgrade|camera\s+phone|handset)\b",
    r"\b(?:voucher|vouchers|cd\s+voucher|gift\s+voucher)\b",
    r"\b(?:double\s+mins|unlimited\s+text|network\s+mins)\b",
    r"\b(?:half\s+price|special\s+offer|limited\s+time)\b",
    r"\b(?:opt\s*out|optout|2optout|call2optout)\b",
    r"\b(?:box|pobox|po\s*box|suite|lands\s+row)\b",
    r"\b(?:gbp|pounds?|pence|ppm|p\/msg)\b",
    r"\b(?:free\s+entry|free\s+membership|free\s+gift|free\s+bonus)\b",
    r"\b(?:guaranteed|prize|jackpot|lottery|winner)\b",
    r"\b(?:urgent|limited\s+offer|act\s+now|selected)\b",
    r"\b(?:claim|redeem|verify|activate|call\s+now)\b",
]

def _hard_override(pre_text: str) -> bool:
    """
    Return True when the text matches hard spam pattern.
    """
    t = (pre_text or "").lower()
    return any(re.search(p, t, flags=re.I) for p in _HARD_SPAM_PATTERNS)

def _rule_based_spam_score(pre_text: str, lang_hint: str) -> float:
    if not pre_text: return 0.0
    text = pre_text.lower()
    ham_hits = sum(1 for pat in _HAM_SAFETY if re.search(pat, text, flags=re.I))
    ham_discount = 0.15 * ham_hits  # Increased discount for ham safety
    hits = 0.0; max_hits = 1e-9

    # Language-specific patterns
    for pat in _LANG_PATTERNS.get(lang_hint, []):
        max_hits += 1; hits += 1 if re.search(pat, text, flags=re.I) else 0

    # Generic patterns
    for pat in _GENERIC_PATTERNS:
        max_hits += 1; hits += 1 if re.search(pat, text, flags=re.I) else 0

    # URL patterns
    url_count = text.count("<url>")
    hits += min(2, url_count) * 0.6; max_hits += 1.0

    # Enhanced scoring for specific spam indicators
    if len(text.split()) <= 32 and re.search(r"(claim|click|क्लिक|கிளிக்|haz\s+clic|redeem|verify)", text, re.I):
        hits += 1.2; max_hits += 1.2

    # Money patterns
    if re.search(r"(jackpot|lottery|£\s*\d{2,}|€\s*\d{2,}|\$\s*\d{2,})", text, re.I):
        hits += 1.0; max_hits += 1.0

    # Premium SMS patterns
    if re.search(r"\b(freemsg|std\s*chgs?|£\s*\d+(?:\.\d{2})?\s+to\s*(?:rcv|receive))\b", text, re.I):
        hits += 2.0; max_hits += 2.0

    # Dating/sexual content patterns
    if re.search(r"\b(?:dating|sex|horny|laid|shag|sexy|admirer|fancy)\b", text, re.I):
        hits += 1.5; max_hits += 1.5

    # Quiz/competition patterns
    if re.search(r"\b(?:quiz|competition|contest|play\s+now|answer\s+\d+\s+easy)\b", text, re.I):
        hits += 1.3; max_hits += 1.3

    # Mobile upgrade patterns
    if re.search(r"\b(?:upgrade|latest|newest|colour\s+camera|bluetooth|mobile)\b", text, re.I):
        hits += 1.1; max_hits += 1.1

    # Ringtone/content service patterns
    if re.search(r"\b(?:ringtone|polyphonic|tone|content|download|service)\b", text, re.I):
        hits += 1.2; max_hits += 1.2

    # Subscription patterns
    if re.search(r"\b(?:subscription|charged|monthly|weekly|per\s+week|gbp|pounds?)\b", text, re.I):
        hits += 1.4; max_hits += 1.4

    # Special day patterns
    if re.search(r"\b(?:valentine|christmas|xmas|special\s+day|holiday)\b", text, re.I):
        hits += 1.0; max_hits += 1.0

    # Secret admirer patterns
    if re.search(r"\b(?:secret\s+admirer|someone\s+has\s+contacted|contacted\s+our)\b", text, re.I):
        hits += 1.6; max_hits += 1.6

    # Age verification patterns
    if re.search(r"\b(?:age\s+verify|18\+|16\+|over\s+\d+)\b", text, re.I):
        hits += 1.3; max_hits += 1.3

    # Customer service patterns
    if re.search(r"\b(?:customer\s+service|helpline|freephone|call\s+now)\b", text, re.I):
        hits += 1.1; max_hits += 1.1

    # Voucher patterns
    if re.search(r"\b(?:voucher|vouchers|cd\s+voucher|gift\s+voucher)\b", text, re.I):
        hits += 1.2; max_hits += 1.2

    # Phone number patterns (UK mobile numbers)
    if re.search(r"\b(?:07\d{9}|08\d{8,9}|\+44\d{10,11})\b", text):
        hits += 0.8; max_hits += 0.8

    # Premium rate number patterns
    if re.search(r"\b(?:09\d{8,9}|087\d{7,8})\b", text):
        hits += 1.5; max_hits += 1.5

    score = max(0.0, min(1.0, hits / max_hits))
    return float(max(0.0, score - ham_discount))


# Here incorporate rules for Inference

@lru_cache(maxsize=2048)
def _cached_zs(pre: str) -> tuple:
    outs = zspipe([pre], candidate_labels=CANDIDATE_LABELS, multi_label=False,
                  hypothesis_template=HYPOTHESIS_TEMPLATE, truncation=True)
    o = outs[0] if isinstance(outs, list) else outs
    labels = o["labels"]; scores = o["scores"]
    return (float(scores[labels.index("spam")]), float(scores[labels.index("not spam")]))

def predict_proba_zs(texts: List[str]) -> np.ndarray:
    if isinstance(texts, str): texts = [texts]
    probs = []
    for t in texts:
        tokens = t.split()
        if len(tokens) > MAX_SEQ_LEN: t = " ".join(tokens[:MAX_SEQ_LEN])
        p_spam, p_not = _cached_zs(t)
        probs.append([p_spam, p_not])
    return np.array(probs, dtype=np.float32)

@torch.inference_mode()
def predict_proba_student(texts: List[str], max_len: int = 128) -> Optional[np.ndarray]:
    if not (student_tok and student_model): return None
    if isinstance(texts, str): texts = [texts]
    enc = student_tok(texts, padding=True, truncation=True, max_length=max_len, return_tensors="pt")
    if DEVICE >= 0: enc = {k: v.to(f"cuda:{DEVICE}") for k,v in enc.items()}
    out = student_model(**enc)
    logits = out.logits.detach().cpu().numpy()
    e = np.exp(logits - logits.max(axis=1, keepdims=True))
    return e / e.sum(axis=1, keepdims=True)

def classify_text(raw_text: str) -> Dict[str, Any]:
    pre = preprocess_text(raw_text)
    lang = detect_language_safe(raw_text)

    if _hard_override(pre):
        rule_spam = _rule_based_spam_score(pre, lang_hint=lang)
        model_spam = 0.80
        spam_prob = float(min(1.0, MODEL_WEIGHT * model_spam + RULE_WEIGHT * rule_spam))
        return {
            "label": "spam",
            "spam_prob": spam_prob,
            "not_spam_prob": float(1.0 - spam_prob),
            "preprocessed": pre,
            "lang_hint": lang,
            "rule_spam": rule_spam,
            "model_spam": model_spam,
            "threshold": _dynamic_threshold(rule_spam, which="teacher"),
            "which_model_for_th": "teacher",
        }

    p_zs = predict_proba_zs([pre])[0]; model_spam = float(p_zs[0]); which = "teacher"
    if lang in {"en"} and student_model is not None:
        which = "student"
        p_student = predict_proba_student([pre])[0]
        model_spam = float((p_zs[0] + p_student[0]) / 2.0)

    rule_spam = _rule_based_spam_score(pre, lang_hint=lang)
    spam_prob = MODEL_WEIGHT * model_spam + RULE_WEIGHT * rule_spam
    spam_prob = float(min(1.0, max(0.0, spam_prob)))
    not_spam_prob = float(1.0 - spam_prob)
    th = _dynamic_threshold(rule_spam, which=which)
    label = "spam" if (rule_spam >= FORCE_RULE_CUTOFF or spam_prob >= th) else "not spam"

    return {
        "label": label,
        "spam_prob": spam_prob,
        "not_spam_prob": not_spam_prob,
        "preprocessed": pre,
        "lang_hint": lang,
        "rule_spam": rule_spam,
        "model_spam": model_spam,
        "threshold": th,
        "which_model_for_th": which,
    }


# Rule evidence plotting - output shows the counts of rule categorise for a visual evidence.

def _rule_hits_breakdown(pre_text: str, lang_hint: str) -> Dict[str, int]:
    text = (pre_text or "").lower()
    def any_hit(pats): return sum(1 for pat in pats if re.search(pat, text, flags=re.I))
    return {
        "lang_rules": any_hit(_LANG_PATTERNS.get(lang_hint, [])),
        "generic_rules": any_hit(_GENERIC_PATTERNS),
        "shortcode_STOP": 1 if re.search(r"\bstop\b\s*(?:to\s*(?:end|stop|opt\s*out))?\b", text, re.I) else 0,
        "helpline": 1 if re.search(r"\b(08\d{7,}|\+?44\d{9,})\b", text) else 0,
        "money/£": 1 if re.search(r"£\s*\d+(?:\.\d{2})?", text) else 0,
        "urls": min(2, text.count("<url>")),
    }

def make_rule_evidence_plot(pre_text: str, lang_hint: str, rec_id: int) -> str:
    hits = _rule_hits_breakdown(pre_text, lang_hint)
    keys = list(hits.keys()); vals = [hits[k] for k in keys]
    fig, ax = plt.subplots(figsize=(7.8, 3.2))
    ax.bar(range(len(keys)), vals)
    ax.set_xticks(range(len(keys))); ax.set_xticklabels(keys, rotation=20)
    ax.set_ylabel("hits"); ax.set_title("Rule Evidence")
    fig.tight_layout()
    out = os.path.join(RULEPLOT_DIR, f"rules_{rec_id}.png")
    fig.savefig(out, dpi=160, bbox_inches="tight"); plt.close(fig)
    return out


# Generate the SHAP artifacts

def make_shap_artifacts(sample_text: str, rec_id: int, class_index: int = 0):
    base = f"rec_{rec_id}"
    bar_png   = os.path.join(ARTIFACT_DIR, f"{base}_bar.png")
    text_html = os.path.join(ARTIFACT_DIR, f"{base}_text.html")
    force_png = os.path.join(ARTIFACT_DIR, f"{base}_force.png")

    def _as_text_list(x):
        if isinstance(x, str): return [x]
        if isinstance(x, (list, tuple)): return [str(t) for t in x]
        try: return [str(t) for t in np.array(x, dtype=object).ravel().tolist()]
        except Exception: return [str(x)]

    def f(pred_texts):
        rows = []
        for t in _as_text_list(pred_texts):
            rows.append(predict_proba_zs([t])[0])
        return np.asarray(rows, dtype=np.float32)

    tokens, shap_vals, base_value = None, None, None
    try:
        masker = shap.maskers.Text(tokenizer=r"\s+")
        explainer = shap.Explainer(f, masker, algorithm="partition", output_names=["spam","not_spam"])
        exp = explainer([sample_text], batch_size=1)
        ex = exp[0, :, class_index]
        tokens = list(ex.data)
        shap_vals = np.array(ex.values, dtype=float)
        base_value = float(np.array(ex.base_values).reshape(-1)[0])
    except Exception:
        tokens = sample_text.split()
        shap_vals = np.zeros(len(tokens), dtype=float)
        base_value = float(predict_proba_zs([sample_text])[0, class_index])

    k = min(20, len(tokens))
    order = np.argsort(np.abs(shap_vals))[::-1][:k]
    plt.figure(figsize=(6.8, max(2.5, 0.28*k)))
    vals = [shap_vals[i] for i in order][::-1]
    labs = [tokens[i]     for i in order][::-1]
    plt.barh(range(len(vals)), vals, tick_label=labs)
    plt.xlabel("SHAP value (impact)")
    plt.tight_layout(); plt.savefig(bar_png, dpi=200, bbox_inches="tight"); plt.close()

    max_abs = float(np.max(np.abs(shap_vals))) if len(shap_vals) else 1.0
    def span(tok, v):
        a = 0.15 + 0.85 * (abs(v) / (max_abs + 1e-12))
        bg = f"rgba(214,39,40,{a:.3f})" if v >= 0 else f"rgba(31,119,180,{a:.3f})"
        safe = (tok or "").replace("&","&amp;").replace("<","&lt;").replace(">","&gt;")
        return f"<span style='background:{bg}; padding:2px 4px; margin:1px; border-radius:4px; display:inline-block'>{safe}&nbsp;</span>"
    with open(text_html, "w", encoding="utf-8") as fhtml:
        fhtml.write("<div style='font-family:system-ui,Arial,sans-serif;line-height:1.9;font-size:16px;'>"
                    + "".join([span(t, v) for t, v in zip(tokens, shap_vals)]) + "</div>")

    try:
        shap.force_plot(base_value, shap_vals, tokens, matplotlib=True, show=False)
        plt.gcf().savefig(force_png, dpi=200, bbox_inches="tight"); plt.close(plt.gcf())
    except Exception:
        fig, ax = plt.subplots(figsize=(7.2, 1.8))
        csum = np.cumsum(sorted(shap_vals))
        ax.plot(csum); ax.set_title("SHAP force (fallback)")
        fig.tight_layout(); fig.savefig(force_png, dpi=200, bbox_inches="tight"); plt.close(fig)

    return {
        "bar_png": bar_png,
        "text_html": text_html,
        "force_png": force_png,
        "top_tokens_json": json.dumps(
            [{"token": tokens[i], "shap": float(shap_vals[i])} for i in order], ensure_ascii=False
        ),
    }

# Database model definition along with helpers
# revisit

class Base(DeclarativeBase): pass

class SMSRecord(Base):
    __tablename__ = "sms_records"
    id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
    created_at: Mapped[str] = mapped_column(sa.String, default=lambda: datetime.now(timezone.utc).isoformat())
    api_key: Mapped[Optional[str]] = mapped_column(sa.String, nullable=True)
    is_image: Mapped[bool] = mapped_column(sa.Boolean, default=False)
    raw_text: Mapped[Optional[str]] = mapped_column(sa.Text, nullable=True)
    ocr_text: Mapped[Optional[str]] = mapped_column(sa.Text, nullable=True)
    preprocessed: Mapped[Optional[str]] = mapped_column(sa.Text, nullable=True)
    lang: Mapped[Optional[str]] = mapped_column(sa.String, nullable=True)
    label: Mapped[str] = mapped_column(sa.String)
    spam_prob: Mapped[float] = mapped_column(sa.Float)
    not_spam_prob: Mapped[float] = mapped_column(sa.Float)
    shap_top_tokens: Mapped[Optional[str]] = mapped_column(sa.Text, nullable=True)
    shap_bar_path: Mapped[Optional[str]] = mapped_column(sa.String, nullable=True)
    shap_text_path: Mapped[Optional[str]] = mapped_column(sa.String, nullable=True)
    shap_force_path: Mapped[Optional[str]] = mapped_column(sa.String, nullable=True)
    rule_png_path: Mapped[Optional[str]] = mapped_column(sa.String, nullable=True)

class FlagRecord(Base):
    __tablename__ = "flags"
    id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
    created_at: Mapped[str] = mapped_column(sa.String, default=lambda: datetime.now(timezone.utc).isoformat())
    record_id: Mapped[int] = mapped_column(sa.Integer)
    flag_type: Mapped[str] = mapped_column(sa.String)
    note: Mapped[Optional[str]] = mapped_column(sa.Text, nullable=True)
    api_key: Mapped[Optional[str]] = mapped_column(sa.String, nullable=True)

is_sqlite = DATABASE_URL.startswith("sqlite")
engine = sa.create_engine(DATABASE_URL, future=True, pool_pre_ping=True,
                          connect_args={"check_same_thread": False} if is_sqlite else {})
Base.metadata.create_all(engine)

def _ensure_schema():
    if not is_sqlite: return
    with engine.begin() as conn:
        cols = [row[1] for row in conn.exec_driver_sql("PRAGMA table_info('sms_records')").fetchall()]
        if "rule_png_path" not in cols:
            conn.exec_driver_sql("ALTER TABLE sms_records ADD COLUMN rule_png_path VARCHAR")
_ensure_schema()

def _persist_record(data: Dict[str, Any]) -> int:
    with Session(engine) as sess:
        rec = SMSRecord(**data); sess.add(rec); sess.commit(); sess.refresh(rec)
        return rec.id

def _serialize_record(rec: SMSRecord) -> Dict[str, Any]:
    return {
        "id": rec.id, "created_at": rec.created_at, "is_image": rec.is_image, "lang": rec.lang,
        "raw_text": rec.raw_text, "ocr_text": rec.ocr_text, "preprocessed": rec.preprocessed,
        "label": rec.label, "spam_prob": rec.spam_prob, "not_spam_prob": rec.not_spam_prob,
        "shap_top_tokens": json.loads(rec.shap_top_tokens or "[]"),
        "shap_bar_path": rec.shap_bar_path, "shap_text_path": rec.shap_text_path,
        "shap_force_path": rec.shap_force_path, "rule_png_path": rec.rule_png_path,
    }


# Implement the rate limiter - if rate limit is exceeded return
# status code of 429
_rate_state: Dict[str, Dict[str, Any]] = {}
def rate_limit_guard(api_key: str):
    now = time.time()
    st = _rate_state.get(api_key)
    if st is None or now - st["window_start"] > 60.0:
        _rate_state[api_key] = {"window_start": now, "count": 1}; return
    if st["count"] >= RATE_LIMIT_PER_MIN:
        raise HTTPException(status_code=429, detail="Rate limit exceeded")
    st["count"] += 1

def require_api_key(x_api_key: Optional[str] = Header(default=None)):
    if x_api_key != API_KEY:
        raise HTTPException(status_code=401, detail="Invalid or missing API key.")
    rate_limit_guard(x_api_key); return x_api_key


# Check for freeport

def find_free_port(preferred: int = 8000, span: int = 100) -> int:
    for p in [preferred] + list(range(preferred + 1, preferred + 1 + span)):
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            try:
                s.bind(("0.0.0.0", p)); return p
            except OSError: continue
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("0.0.0.0", 0)); return s.getsockname()[1]

API_PORT = int(os.environ.get("API_PORT", "8000"))
api = FastAPI(title="Spam Detector API", version="1.5.0")
api.add_middleware(CORSMiddleware,
    allow_origins=[""], allow_credentials=True, allow_methods=[""], allow_headers=["*"])

@api.post("/predict")
async def api_predict(
    x_api_key: str = Depends(require_api_key),
    text: Optional[str] = Form(default=None),
    ocr_langs: Optional[str] = Form(default=OCR_LANGS_DEFAULT),
    image: Optional[UploadFile] = File(default=None),
):
    if not text and not image:
        raise HTTPException(status_code=400, detail="Provide 'text' or 'image'.")
    is_image = image is not None
    raw_text, ocr_text = text, None

    if image:
        img_bytes = await image.read()
        ocr_text = ocr_image_bytes(img_bytes, ocr_langs=ocr_langs or OCR_LANGS_DEFAULT)
        raw_text = (raw_text or "") + ("\n" if raw_text else "") + (ocr_text or "")

    if not (raw_text or "").strip():
        raise HTTPException(status_code=400, detail="No text to classify (empty after OCR).")

    cls = classify_text(raw_text)
    rec_id = _persist_record({
        "api_key": x_api_key, "is_image": bool(is_image),
        "raw_text": text, "ocr_text": ocr_text, "preprocessed": cls["preprocessed"],
        "lang": cls["lang_hint"], "label": cls["label"],
        "spam_prob": cls["spam_prob"], "not_spam_prob": cls["not_spam_prob"],
    })

    # plots — synchronous to guarantee paths are valid
    rule_png = make_rule_evidence_plot(cls["preprocessed"], cls["lang_hint"], rec_id)
    shap_paths = make_shap_artifacts(cls["preprocessed"], rec_id, class_index=0)

    with Session(engine) as sess:
        rec = sess.get(SMSRecord, rec_id)
        rec.rule_png_path = rule_png
        rec.shap_top_tokens = shap_paths["top_tokens_json"]
        rec.shap_bar_path = shap_paths["bar_png"]
        rec.shap_text_path = shap_paths["text_html"]
        rec.shap_force_path = shap_paths["force_png"]
        sess.commit()

    return JSONResponse({
        "record_id": rec_id, "lang": cls["lang_hint"], "label": cls["label"],
        "spam_prob": cls["spam_prob"], "not_spam_prob": cls["not_spam_prob"],
        "ocr_text": ocr_text, "rule_png": rule_png,
        "shap": shap_paths,
        "threshold": cls["threshold"],
        "model_spam": cls["model_spam"],
        "rule_spam": cls["rule_spam"],
        "which_model_for_threshold": cls["which_model_for_th"],
    })

@api.get("/log")
def api_log(x_api_key: str = Depends(require_api_key),
            limit: int = Query(default=100, ge=1, le=1000),
            search: Optional[str] = Query(default=None)):
    with Session(engine) as sess:
        q = sa.select(SMSRecord).where((SMSRecord.api_key == x_api_key) | (SMSRecord.api_key.is_(None)))\
                                .order_by(SMSRecord.id.desc()).limit(limit)
        rows = sess.execute(q).scalars().all()
    items = []
    for r in rows:
        row = _serialize_record(r)
        short = (row["raw_text"] or "") + ("\n" + (row["ocr_text"] or "")) if row["ocr_text"] else (row["raw_text"] or "")
        short = (short or "").strip().replace("\n", " ")
        if len(short) > 120: short = short[:117] + "..."
        items.append({
            "id": row["id"], "time": row["created_at"], "img?": r.is_image,
            "lang": row["lang"], "pred": row["label"], "spam_prob": f"{row['spam_prob']:.3f}", "text": short,
        })
    df = pd.DataFrame(items, columns=["id","time","img?","lang","pred","spam_prob","text"])
    if search_str:
        m = str(search_str).lower()
        df = df[df.apply(lambda s: any(m in str(x).lower() for x in s.values), axis=1)]
    return {"items": items}

@api.post("/flag")
def api_flag(x_api_key: str = Depends(require_api_key),
             record_id: int = Form(...),
             flag_type: str = Form(...),
             note: Optional[str] = Form(default=None)):
    if flag_type not in {"false_positive","false_negative","other"}:
        raise HTTPException(status_code=400, detail="flag_type must be false_positive/false_negative/other")
    with Session(engine) as sess:
        f = FlagRecord(record_id=record_id, flag_type=flag_type, note=note, api_key=x_api_key)
        sess.add(f); sess.commit(); sess.refresh(f)
    return {"status":"ok","flag_id":f.id}


# Launch API in a thread

def _run_api(port: int):
    uvicorn.run(api, host="0.0.0.0", port=port, log_level="info")

api_thread = threading.Thread(target=_run_api, args=(API_PORT,), daemon=True)
api_thread.start()

# Gradio UI

def ui_predict(sms_text, img, ocr_langs):
    is_image = img is not None; ocr_text = None
    if is_image:
        pil = Image.fromarray(img)
        buf = io.BytesIO(); pil.convert("RGB").save(buf, format="PNG")
        ocr_text = ocr_image_bytes(buf.getvalue(), ocr_langs=ocr_langs or OCR_LANGS_DEFAULT)
        sms_text = (sms_text or "") + ("\n" if sms_text else "") + (ocr_text or "")
    if not (sms_text or "").strip():
        return None, None, "Please provide SMS text or an image.", None, "", None, None

    cls = classify_text(sms_text)
    rec_id = _persist_record({
        "api_key": None, "is_image": bool(is_image),
        "raw_text": None if ocr_text else sms_text, "ocr_text": ocr_text,
        "preprocessed": cls["preprocessed"], "lang": cls["lang_hint"], "label": cls["label"],
        "spam_prob": cls["spam_prob"], "not_spam_prob": cls["not_spam_prob"],
    })

    rule_png = make_rule_evidence_plot(cls["preprocessed"], cls["lang_hint"], rec_id)
    paths = make_shap_artifacts(cls["preprocessed"], rec_id, class_index=0)

    with Session(engine) as sess:
        rec = sess.get(SMSRecord, rec_id)
        rec.rule_png_path = rule_png
        rec.shap_top_tokens = paths["top_tokens_json"]
        rec.shap_bar_path = paths["bar_png"]
        rec.shap_text_path = paths["text_html"]
        rec.shap_force_path = paths["force_png"]
        sess.commit()

    text_html_str = open(paths["text_html"], "r", encoding="utf-8").read()
    pred_str = (
        f"*Prediction:* {cls['label']}  \n"
        f"*spam:* {cls['spam_prob']:.3f} | *not_spam:* {cls['not_spam_prob']:.3f} "
        f"| *lang:* {cls['lang_hint']} | *threshold:* {cls['threshold']:.2f} ({cls['which_model_for_th']})  \n"
        f"model_spam={cls['model_spam']:.3f}, rule_spam={cls['rule_spam']:.3f}"
    )
    return (img if is_image else None), (ocr_text or ""), pred_str, paths["bar_png"], text_html_str, paths["force_png"], rule_png

def ui_load_history(search_str):
    with Session(engine) as sess:
        q = sa.select(SMSRecord).order_by(SMSRecord.id.desc()).limit(300)
        rows = sess.execute(q).scalars().all()
    items = []
    for r in rows:
        row = _serialize_record(r)
        short = (row["raw_text"] or "") + ("\n" + (row["ocr_text"] or "")) if row["ocr_text"] else (row["raw_text"] or "")
        short = (short or "").strip().replace("\n", " ")
        if len(short) > 120: short = short[:117] + "..."
        items.append({
            "id": row["id"], "time": row["created_at"], "img?": r.is_image,
            "lang": row["lang"], "pred": row["label"], "spam_prob": f"{row['spam_prob']:.3f}", "text": short,
        })
    df = pd.DataFrame(items, columns=["id","time","img?","lang","pred","spam_prob","text"])
    if search_str:
        m = str(search_str).lower()
        df = df[df.apply(lambda s: any(m in str(x).lower() for x in s.values), axis=1)]
    return df

def ui_benchmark():
    samples = [
        ("en", "Congratulations! You won a prize. Click <URL> to claim now."),
        ("en", "Team catch-up at 5 PM. Agenda attached."),
        ("en", 'Did you hear about the new "Divorce Barbie"? It comes with all of Ken\'s stuff!'),
        ("en", "FreeMsg Hey! std chgs to send, £1.50 to rcv. Text WIN to 80085."),
        ("en", "You are subscribed to the best mobile content service in the UK for £3 per ten days until you send STOP to 83435. Helpline 08706091795."),
        ("hi", "बधाई हो! आपने इनाम जीता है। दावा करने के लिए <URL> पर क्लिक करें।"),
        ("es", "¡Felicitaciones! Has ganado un premio. Haz clic en <URL> para reclamarlo."),
        ("ta", "வாழ்த்துகள்! நீங்கள் பரிசு வென்றுள்ளீர்கள். <URL> கிளிக் செய்யவும்."),
        ("mr", "अभिनंदन! तुम्ही बक्षीस जिंकले आहे. दावा करण्यासाठी <URL> वर क्लिक करा."),
    ]
    rows = []
    for lang_hint, txt in samples:
        res = classify_text(txt)
        rows.append({
            "lang": lang_hint, "text": txt, "label": res["label"],
            "spam_prob": float(res["spam_prob"]), "not_spam_prob": float(res["not_spam_prob"]),
        })
    return pd.DataFrame(rows, columns=["lang","text","label","spam_prob","not_spam_prob"])

with gr.Blocks(title="📬 Multilingual SMS/Image Spam Detector") as demo:
    gr.Markdown("### Multilingual SMS/Image Spam Detector\n- XLM-R zero-shot + optional student\n- OCR (Tesseract → EasyOCR)\n- SHAP (bar + force) and Rule Evidence (always shown)\n- Calibrated ensemble & hard overrides (incl. Divorce Barbie)\n- Searchable history")
    with gr.Tabs():
        with gr.Tab("Predict"):
            with gr.Row():
                sms_in = gr.Textbox(label="SMS Text", lines=4, placeholder="Paste SMS here...")
                img_in = gr.Image(label="Drag/Drop Image (optional)", type="numpy", sources=["upload","clipboard"])
            ocr_langs_in = gr.Textbox(label="OCR Languages (Tesseract codes)", value=OCR_LANGS_DEFAULT)
            run_btn = gr.Button("Predict", variant="primary")
            with gr.Row():
                img_out = gr.Image(label="Image Preview", type="numpy")
                ocr_out = gr.Textbox(label="OCR Text", lines=5)
            pred_out = gr.Markdown(label="Prediction")
            bar_out = gr.Image(label="SHAP Summary (Bar)", type="filepath")
            text_html_out = gr.HTML(label="SHAP Text Highlight")
            force_png_out = gr.Image(label="SHAP Force Plot", type="filepath")
            rule_png_out = gr.Image(label="Rule Evidence", type="filepath")
            run_btn.click(
                ui_predict,
                inputs=[sms_in, img_in, ocr_langs_in],
                outputs=[img_out, ocr_out, pred_out, bar_out, text_html_out, force_png_out, rule_png_out]
            )
        with gr.Tab("History"):
            search_in = gr.Textbox(label="Search", placeholder="text / lang / label ...")
            table_out = gr.Dataframe(label="Recent records", interactive=False)
            gr.Button("Load").click(ui_load_history, inputs=[search_in], outputs=[table_out])
        with gr.Tab("Benchmark"):
            bench_btn = gr.Button("Run multilingual benchmark")
            bench_table = gr.Dataframe(label="Benchmark", interactive=False)
            bench_btn.click(ui_benchmark, inputs=None, outputs=[bench_table])

gr.close_all()
SHARE = os.environ.get("GRADIO_SHARE", "true").lower() == "true"
IN_COLAB = "COLAB_RELEASE_TAG" in os.environ
ret = demo.queue().launch(
    share=(SHARE or IN_COLAB),
    server_name="0.0.0.0" if (SHARE or IN_COLAB) else None,
    server_port=None,
    prevent_thread_lock=True,
    debug=False
)

try:
    url_to_show = getattr(ret, "share_url", None) or getattr(ret, "local_url", None)
except Exception:
    if isinstance(ret, tuple) and len(ret) >= 2:
        _, url_to_show, *rest = ret
    else:
        url_to_show = None

print("\n*****************")
print("The UI is running")
print(" Gradio URL:", url_to_show or "See Gradio output above")
print("**************")
print(f"API base (local): http://127.0.0.1:{API_PORT}")
print("   Use header: X-API-Key:", API_KEY)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


tokenizer_config.json:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/734 [00:00<?, ?B/s]

sentencepiece.bpe.model:   0%|          | 0.00/5.07M [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/150 [00:00<?, ?B/s]

`torch_dtype` is deprecated! Use `dtype` instead!
`torch_dtype` is deprecated! Use `dtype` instead!


model.safetensors:   0%|          | 0.00/2.24G [00:00<?, ?B/s]

Some weights of the model checkpoint at joeddav/xlm-roberta-large-xnli were not used when initializing XLMRobertaForSequenceClassification: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
- This IS expected if you are initializing XLMRobertaForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing XLMRobertaForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Device set to use cuda:0


config.json:   0%|          | 0.00/409 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

pytorch_model.bin:   0%|          | 0.00/62.7M [00:00<?, ?B/s]

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at huawei-noah/TinyBERT_General_4L_312D and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


model.safetensors:   0%|          | 0.00/62.7M [00:00<?, ?B/s]

Colab notebook detected. To show errors in colab notebook, set debug=True in launch()
* Running on public URL: https://566a3c639a2881cd1a.gradio.live

This share link expires in 1 week. For free permanent hosting and GPU upgrades, run `gradio deploy` from the terminal in the working directory to deploy to Hugging Face Spaces (https://huggingface.co/spaces)



*****************
The UI is running
 Gradio URL: See Gradio output above
**************
API base (local): http://127.0.0.1:8000
   Use header: X-API-Key: devkey
