Two–stage ONNX inference:
  1) category  → 2) product-within-category
Assumes directory layout:

models/
├── tokenizer/                     (save_pretrained)
├── category_model/
│   ├── tokenizer/
│   ├── model_quantized.onnx
│   └── label_encoder.pkl
├── cards/
│   ├── model_quantized.onnx
│   └── label_encoder.pkl
├── deposits/
│   └── ...
└── …

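If a category has no product sub-model (its folder, model_quantized.onnx or label_encoder.pkl is missing), the product is returned as "<category>_common", where <category> is the lower-cased category name with spaces replaced by underscores.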

In [None]:
!pip install onnxruntime transformers scikit-learn numpy onnx torch

Collecting onnxruntime
  Downloading onnxruntime-1.22.0-cp312-cp312-win_amd64.whl.metadata (5.0 kB)
Collecting transformers
  Using cached transformers-4.52.4-py3-none-any.whl.metadata (38 kB)
Collecting coloredlogs (from onnxruntime)
  Using cached coloredlogs-15.0.1-py2.py3-none-any.whl.metadata (12 kB)
Collecting flatbuffers (from onnxruntime)
  Using cached flatbuffers-25.2.10-py2.py3-none-any.whl.metadata (875 bytes)
Collecting huggingface-hub<1.0,>=0.30.0 (from transformers)
  Downloading huggingface_hub-0.32.6-py3-none-any.whl.metadata (14 kB)
Collecting tokenizers<0.22,>=0.21 (from transformers)
  Using cached tokenizers-0.21.1-cp39-abi3-win_amd64.whl.metadata (6.9 kB)
Collecting safetensors>=0.4.3 (from transformers)
  Using cached safetensors-0.5.3-cp38-abi3-win_amd64.whl.metadata (3.9 kB)
Collecting humanfriendly>=9.1 (from coloredlogs->onnxruntime)
  Using cached humanfriendly-10.0-py2.py3-none-any.whl.metadata (9.2 kB)
Collecting pyreadline3 (from humanfriendly>=9.1->color

In [5]:
from pathlib import Path
import numpy as np
import joblib
import onnxruntime as ort
from transformers import AutoTokenizer


In [6]:
# Root directory holding every model artifact (an absolute path works too).
MODELS_DIR = Path("models")

# Tokenizer used together with the per-category product models.
TOK_DIR = MODELS_DIR.joinpath("tokenizer")

# Stage-1 (category) model: its own tokenizer, ONNX model file and label encoder.
CAT_DIR = MODELS_DIR.joinpath("category_model")
CAT_TOK_DIR = CAT_DIR.joinpath("tokenizer")
CAT_MODEL_PATH = CAT_DIR.joinpath("model_quantized.onnx")
CAT_LABELS = CAT_DIR.joinpath("label_encoder.pkl")

In [20]:
class TwoStageClassifier:
    """Two-stage text classifier backed by ONNX Runtime sessions.

    Stage 1 predicts a category with the category model. Stage 2 looks for a
    product model in ``MODELS_DIR/<folder>``, where ``<folder>`` is the
    lower-cased category name with spaces replaced by underscores, and uses
    it to predict a product. When that model (or its label encoder) is
    missing, the product falls back to ``"<folder>_common"``.
    """

    def __init__(self):
        """Load both tokenizers, the category session and its label encoder."""
        self.cat_tok = AutoTokenizer.from_pretrained(CAT_TOK_DIR)
        self.prod_tok = AutoTokenizer.from_pretrained(TOK_DIR)
        self.cat_sess = ort.InferenceSession(str(CAT_MODEL_PATH))
        # NOTE(security): joblib.load unpickles arbitrary objects; only load
        # label_encoder.pkl files that come from a trusted source.
        self.cat_enc = joblib.load(CAT_LABELS)
        # folder name -> (ONNX session, label encoder); filled lazily on the
        # first prediction that lands in that category.
        self.prod_cache: dict[str, tuple] = {}

    def _run(self, sess, text, tokenizer, max_len=128):
        """Tokenize *text* and return the ``logits`` output of *sess*.

        Args:
            sess: object with an ONNX Runtime style ``run(output_names, feeds)``.
            text: input passed to *tokenizer*.
            tokenizer: callable returning ``input_ids`` and ``attention_mask``
                NumPy arrays (called with ``return_tensors="np"``).
            max_len: length the input is padded / truncated to.

        Returns:
            The first (and only requested) output of the session.
        """
        enc = tokenizer(
            text,
            return_tensors="np",
            padding="max_length",
            truncation=True,
            max_length=max_len,
        )
        # Cast explicitly to int64 before feeding the session.
        feeds = {
            name: enc[name].astype(np.int64)
            for name in ("input_ids", "attention_mask")
        }
        return sess.run(["logits"], feeds)[0]

    def _product_model(self, folder):
        """Return the ``(session, encoder)`` pair for *folder*, or ``None``.

        The cache is consulted first so that the filesystem is only touched
        the first time a category is seen (or while its model is missing).
        """
        cached = self.prod_cache.get(folder)
        if cached is not None:
            return cached

        folder_path = MODELS_DIR / folder
        model_path = folder_path / "model_quantized.onnx"
        labels_path = folder_path / "label_encoder.pkl"
        # A missing folder implies missing files, so one check covers both.
        if not (model_path.exists() and labels_path.exists()):
            return None

        sess = ort.InferenceSession(str(model_path))
        # NOTE(security): pickle-based load, see the note in __init__.
        enc = joblib.load(labels_path)
        self.prod_cache[folder] = (sess, enc)
        return sess, enc

    def predict(self, text: str) -> dict:
        """Classify *text* and return ``{"category": ..., "product": ...}``.

        The product is ``"<folder>_common"`` when no product model is
        available for the predicted category.
        """
        # Stage 1: category.
        cat_logits = self._run(self.cat_sess, text, tokenizer=self.cat_tok)
        cat_id = int(np.argmax(cat_logits, axis=1)[0])
        category = self.cat_enc.inverse_transform([cat_id])[0]

        folder = category.lower().replace(" ", "_")

        # Stage 2: product within the category.
        pair = self._product_model(folder)
        if pair is None:
            return {"category": category, "product": f"{folder}_common"}

        sess, enc = pair
        prod_logits = self._run(sess, text, tokenizer=self.prod_tok)
        prod_id = int(np.argmax(prod_logits, axis=1)[0])
        product = enc.inverse_transform([prod_id])[0]

        return {"category": category, "product": product}


In [21]:
# Build the classifier once and reuse it: the constructor loads both
# tokenizers, the category ONNX session and the category label encoder.
clf = TwoStageClassifier()

https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations


In [25]:
# Sample customer queries (in Russian) used as a smoke test of both stages.
examples = [
    "как заказать карту?",  # how do I order a card?
    "можно ли открыть вклад в приложении?",  # can I open a deposit in the app?
    "хочу оплатить штраф",  # I want to pay a fine
    "Работает ли QR оплата для коммунальных услуг?",  # does QR payment work for utilities?
    "можно ли взять автокредит?",  # can I get a car loan?
    "как взять кредит под залог недвижимости"  # how to get a loan secured by real estate
]

# Print each query followed by the predicted {"category", "product"} dict.
for text in examples:
    result = clf.predict(text)
    print(f"🟢 {text}\n→ {result}\n")


🟢 как заказать карту?
→ {'category': 'Cards', 'product': 'cards_common'}

🟢 можно ли открыть вклад в приложении?
→ {'category': 'Deposits', 'product': 'deposits_common'}

🟢 хочу оплатить штраф
→ {'category': 'Other', 'product': 'operator'}

🟢 Работает ли QR оплата для коммунальных услуг?
→ {'category': 'Payments', 'product': 'payments_common'}

🟢 можно ли взять автокредит?
→ {'category': 'Avtokredit', 'product': 'avtokredit_common'}

🟢 как взять кредит под залог недвижимости
→ {'category': 'Zalogovoe', 'product': 'zalog_nedvizhimosti'}



In [None]:
# inference.py

from pathlib import Path
import argparse, joblib, numpy as np, onnxruntime as ort
from transformers import AutoTokenizer


# ----------------- FILE PATHS -----------------
# All artifacts live in a "models" folder next to this script.
MODELS_DIR = Path(__file__).resolve().parent.joinpath("models")

# Tokenizer shared by the product models.
TOK_DIR = MODELS_DIR.joinpath("tokenizer")

# Category (stage-1) model: its own tokenizer, ONNX model file and label encoder.
CAT_DIR = MODELS_DIR.joinpath("category_model")
CAT_TOK_DIR = CAT_DIR.joinpath("tokenizer")
CAT_MODEL_PATH = CAT_DIR.joinpath("model_quantized.onnx")
CAT_LABELS = CAT_DIR.joinpath("label_encoder.pkl")


class TwoStageClassifier:
    """Two-stage text classifier backed by ONNX Runtime sessions.

    Stage 1 predicts a category with the category model. Stage 2 looks for a
    product model in ``MODELS_DIR/<folder>``, where ``<folder>`` is the
    lower-cased category name with spaces replaced by underscores, and uses
    it to predict a product. When that model (or its label encoder) is
    missing, the product falls back to ``"<folder>_common"``.
    """

    def __init__(self):
        """Load both tokenizers, the category session and its label encoder."""
        # Tokenizers: one for the category model, one for the product models.
        self.cat_tok = AutoTokenizer.from_pretrained(CAT_TOK_DIR)
        self.prod_tok = AutoTokenizer.from_pretrained(TOK_DIR)

        # ONNX session + label encoder for the category stage.
        self.cat_sess = ort.InferenceSession(str(CAT_MODEL_PATH))
        # NOTE(security): joblib.load unpickles arbitrary objects; only load
        # label_encoder.pkl files that come from a trusted source.
        self.cat_enc = joblib.load(CAT_LABELS)

        # Cache: folder name -> (ONNX session, label encoder); filled lazily
        # on the first prediction that lands in that category.
        self.prod_cache: dict[str, tuple] = {}

    def _run(self, sess, text, tokenizer, max_len=128):
        """Tokenize *text* and return the ``logits`` output of *sess*.

        Args:
            sess: object with an ONNX Runtime style ``run(output_names, feeds)``.
            text: input passed to *tokenizer*.
            tokenizer: callable returning ``input_ids`` and ``attention_mask``
                NumPy arrays (called with ``return_tensors="np"``).
            max_len: length the input is padded / truncated to.

        Returns:
            The first (and only requested) output of the session.
        """
        enc = tokenizer(text, return_tensors="np", padding="max_length",
                        truncation=True, max_length=max_len)
        # The tokenizer's arrays use NumPy's platform default integer type,
        # which can be int32 (e.g. Windows with NumPy < 2). Cast explicitly so
        # the session always receives int64 inputs.
        feeds = {
            name: enc[name].astype(np.int64, copy=False)
            for name in ("input_ids", "attention_mask")
        }
        return sess.run(["logits"], feeds)[0]

    def _product_model(self, folder):
        """Return the ``(session, encoder)`` pair for *folder*, or ``None``.

        The cache is consulted first so that the filesystem is only touched
        the first time a category is seen (or while its model is missing).
        """
        cached = self.prod_cache.get(folder)
        if cached is not None:
            return cached

        folder_path = MODELS_DIR / folder
        model_path = folder_path / "model_quantized.onnx"
        labels_path = folder_path / "label_encoder.pkl"
        # A missing folder implies missing files, so one check covers both.
        if not (model_path.exists() and labels_path.exists()):
            return None

        sess = ort.InferenceSession(str(model_path))
        # NOTE(security): pickle-based load, see the note in __init__.
        enc = joblib.load(labels_path)
        self.prod_cache[folder] = (sess, enc)
        return sess, enc

    def predict(self, text: str) -> dict:
        """Classify *text* and return ``{"category": ..., "product": ...}``.

        The product is ``"<folder>_common"`` when no product model is
        available for the predicted category.
        """
        # === Stage 1: category ===
        cat_logits = self._run(self.cat_sess, text, tokenizer=self.cat_tok)
        cat_id = int(np.argmax(cat_logits, axis=1)[0])
        category = self.cat_enc.inverse_transform([cat_id])[0]

        folder = category.lower().replace(" ", "_")

        # === Stage 2: product ===
        pair = self._product_model(folder)
        if pair is None:
            return {"category": category, "product": f"{folder}_common"}

        sess, enc = pair
        prod_logits = self._run(sess, text, tokenizer=self.prod_tok)
        prod_id = int(np.argmax(prod_logits, axis=1)[0])
        product = enc.inverse_transform([prod_id])[0]

        return {"category": category, "product": product}


# ----------------- CLI entry point -----------------
def main():
    """Classify a single query from the command line or an interactive prompt.

    Positional arguments are joined with single spaces to form the query;
    when none are given the user is prompted on stdin. The query and the
    prediction are printed to stdout.
    """
    cli = argparse.ArgumentParser(description="Two-stage ONNX classifier")
    cli.add_argument("text", nargs="*", help="Input query")
    words = cli.parse_args().text

    query = " ".join(words) if words else input("Введите запрос: ")
    print(f"[INFO] Ввод: {query}")

    # The classifier is built only after the query is known and echoed.
    result = TwoStageClassifier().predict(query)

    if result:
        print("[RESULT]", result)
    else:
        print("[WARNING] Модель не вернула результат")


if __name__ == "__main__":
    main()