In [4]:
import joblib
import json
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification
import re

# -------- 1) Load detector --------
detector_pipeline = joblib.load("ood_detector/detector_pipeline.joblib")

# -------- 2) Load MARBERT classifier --------
model_dir = "marbret_intent_classifier"
tokenizer = AutoTokenizer.from_pretrained(model_dir)
model = AutoModelForSequenceClassification.from_pretrained(model_dir)

with open(f"{model_dir}/label_mapping.json", "r", encoding="utf-8") as f:
    mapping = json.load(f)
label_to_id = mapping["label_to_id"]
id_to_label = {int(k): v for k, v in mapping["id_to_label"].items()}  # keys may be strings

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()




BertForSequenceClassification(
  (bert): BertModel(
    (embeddings): BertEmbeddings(
      (word_embeddings): Embedding(100000, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): BertEncoder(
      (layer): ModuleList(
        (0-11): 12 x BertLayer(
          (attention): BertAttention(
            (self): BertSdpaSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): BertSelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), eps=1

In [5]:
def clean_text(text):
    if not isinstance(text, str):
        return ""
    
    # Lowercase Latin characters
    text = text.lower()
    
    # Remove URLS
    text = re.sub(r'http\S+|www\S+|https\S+', '', text, flags=re.MULTILINE)
    
    # Remove Emojis and special chars (keep basic punctuation)
    text = re.sub(r'[^\w\s\u0600-\u06FF]', ' ', text)
    
    # Arabic Normalization
    text = re.sub("[إأآا]", "ا", text)
    text = re.sub("ى", "ي", text)
    text = re.sub("ؤ", "ء", text)
    text = re.sub("ئ", "ء", text)
    text = re.sub("ة", "ه", text)
    text = re.sub("گ", "ك", text)
    
    # Remove Tashkeel
    tashkeel = re.compile(r'[\u064B-\u0652]')
    text = re.sub(tashkeel, "", text)
    
    # Remove longation
    text = re.sub(r'(.)\1+', r'\1\1', text)
    
    # Remove multiple spaces
    text = re.sub(r'\s+', ' ', text).strip()
    
    return text


def is_out_of_domain(text, threshold=None):
    cleaned = clean_text(text)
    if hasattr(detector_pipeline, "predict_proba") and threshold is not None:
        probs = detector_pipeline.predict_proba([cleaned])[0]
        classes = list(detector_pipeline.classes_)
        idx_ood = classes.index("out_of_domain")
        p_ood = probs[idx_ood]
        return p_ood >= threshold
    else:
        return detector_pipeline.predict([cleaned])[0] == "out_of_domain"


def classify_intent(text, ood_threshold=0.5):
    # 1) Out-of-domain detection
    if is_out_of_domain(text, threshold=ood_threshold):
        return "Out of context"

    # 2) In-domain: MARBERT intent
    encoding = tokenizer(
        text,
        add_special_tokens=True,
        max_length=128,
        truncation=True,
        padding="max_length",
        return_tensors="pt"
    )
    encoding = {k: v.to(device) for k, v in encoding.items()}

    with torch.no_grad():
        outputs = model(**encoding)
        logits = outputs.logits
        pred_id = torch.argmax(logits, dim=-1).item()

    return id_to_label[pred_id]


### usage


In [6]:
texts = [
    "حاب نبرا من لادروك عاونوني",
    "ما هي أعراض انسحاب الكوكايين؟",
    "راني حاب نشرب قاع الدوا لي عندي باش نرقد وما نوضش",
    "كيفاش راهي حالة الطقس في وهران؟"
]

for t in texts:
    print(t, "->", classify_intent(t, ood_threshold=0.5))


حاب نبرا من لادروك عاونوني -> Looking for support
ما هي أعراض انسحاب الكوكايين؟ -> Exact fact
راني حاب نشرب قاع الدوا لي عندي باش نرقد وما نوضش -> Harm
كيفاش راهي حالة الطقس في وهران؟ -> Exact fact
