# Query Processing (Expansion + Intent Classification with FinBERT)

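This notebook builds the query-processing module: query normalization, keyword expansion and paraphrasing, entity extraction, and intent classification. Intent is currently classified by a TF-IDF + logistic-regression baseline blended with keyword priors; fine-tuning FinBERT for this task is an optional step (section 4.A).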


## 0. Environment & Dependencies
Toggles that control whether Flan-T5, Sentence-BERT, and FinBERT-based intent classification are used, plus the names of the models to load.


In [None]:

# Toggles and model names
# Feature switches read by the later cells; set one to False to skip that model.
USE_FLAN_T5 = True  # gates Flan-T5 loading and paraphrase generation
USE_SBERT = True  # gates Sentence-BERT loading and query embeddings
USE_FINBERT_INTENT = True  # NOTE(review): not read anywhere in the visible code - confirm it is still needed

# Hugging Face model identifiers.
FINBERT_MODEL_NAME = "ProsusAI/finbert"  # referenced only by the commented-out fine-tuning cell
FLAN_T5_MODEL_NAME = "google/flan-t5-base"
SBERT_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"


## 1. Utilities & Normalization

In [None]:

#!pip install -r requirements.txt
import re, unicodedata
from typing import List, Dict, Optional, Any

def normalize(text: str) -> str:
    """Return *text* NFKC-normalized, trimmed, with whitespace runs collapsed.

    The text is trimmed after Unicode normalization, then every run of
    whitespace characters is replaced by a single space.
    """
    trimmed = unicodedata.normalize("NFKC", text).strip()
    return re.sub(r"\s+", " ", trimmed)

def simple_tokenize(text: str) -> List[str]:
    """Split *text* into tokens of letters, digits and the symbols ``& $ % . -``.

    A possessive suffix (``'s`` with a straight or curly apostrophe) attached
    to a word is removed before tokenizing, so "Apple's" yields ``"Apple"``
    instead of the two tokens ``"Apple"`` and ``"s"``. All other characters
    outside the token alphabet act as separators.

    Args:
        text: Input string, normally already passed through ``normalize``.

    Returns:
        Tokens in order of appearance, with original casing preserved.
    """
    # Drop only the possessive suffix; every other apostrophe still splits.
    without_possessive = re.sub(r"(?<=[A-Za-z0-9])['\u2019][sS]\b", "", text)
    return re.findall(r"[A-Za-z0-9&$%\.\-]+", without_possessive)

def extract_entities(text: str) -> Dict[str, List[str]]:
    """Extract quarter, year and ticker mentions from *text*.

    Tickers are taken only from the text as written: ``$SYMBOL`` (1-5 capital
    letters) or a standalone run of 2-5 capital letters. The text is not
    upper-cased first, so ordinary lower/mixed-case words are never reported
    as tickers. Capitalized common words written in all caps (e.g. "THE")
    will still match.

    Args:
        text: The query text.

    Returns:
        A dict with any of the keys ``"quarter"`` (e.g. ``"Q2 2024"``, in
        order of appearance), ``"year"`` (sorted, unique) and ``"ticker"``
        (sorted, unique). Keys with no matches are omitted.
    """
    out: Dict[str, List[str]] = {}

    # Quarter detection is case-insensitive, so match on the lowered text.
    quarters = re.findall(r"\b(q[1-4])\s*([12][0-9]{3})\b", text.lower())
    if quarters:
        out["quarter"] = [f"{period.upper()} {year}" for period, year in quarters]

    years = re.findall(r"\b(20[0-4][0-9]|19[0-9]{2})\b", text)
    if years:
        out["year"] = sorted(set(years))

    # "Q2" is never captured here: the digit after "Q" prevents a word
    # boundary, so no blanket exclusion of Q-prefixed symbols is needed.
    tickers = set(re.findall(r"\$([A-Z]{1,5})\b", text))
    tickers.update(re.findall(r"\b[A-Z]{2,5}\b", text))
    if tickers:
        out["ticker"] = sorted(tickers)

    return out

# Hand-written expansion table: lower-case trigger term -> related phrases.
# keyword_expand looks tokens up by exact key; semantic_expand uses the union
# of all values as its candidate vocabulary.
DOMAIN_SYNONYMS = {
    "risk": ["risk factor","risk factors","uncertainty","exposure","threat"],
    "cyber": ["cybersecurity","information security","infosec","data breach","security incident"],
    "performance": ["revenue","growth","margin","profit","loss","guidance","results"],
    "strategy": ["roadmap","plan","initiative","expansion","capex","restructuring","acquisition"],
    "md&a": ["management discussion","md&a","results of operations"],
}


## 2. Sentence-BERT for Semantic Expansion & Embeddings

In [None]:

# Load the Sentence-BERT model once for the notebook. `sbert` ends up as None
# when the toggle is off or when the import / model load fails, and the
# helper functions below treat None as "embeddings unavailable".
try:
    if USE_SBERT:
        from sentence_transformers import SentenceTransformer, util
        sbert = SentenceTransformer(SBERT_MODEL_NAME)
    else:
        sbert = None
except Exception as e:
    # Best effort: report the failure and continue without embeddings.
    print("Sentence-BERT not available:", e)
    sbert = None

from typing import List, Dict, Optional, Any

def sbert_embed(text: str) -> Optional[List[float]]:
    """Encode *text* with the module-level ``sbert`` model.

    Returns the embedding (requested with ``normalize_embeddings=True``) as a
    plain list, or ``None`` when no model is loaded.
    """
    if sbert is not None:
        embedding = sbert.encode([text], normalize_embeddings=True)[0]
        return embedding.tolist()
    return None

def semantic_expand(base_terms: List[str], k: int = 5) -> List[str]:
    """Return DOMAIN_SYNONYMS phrases most similar to any of *base_terms*.

    Each base term is compared against every phrase listed in
    ``DOMAIN_SYNONYMS`` using ``util.cos_sim`` on ``sbert`` embeddings, and
    its ``k`` best-scoring phrases are collected.

    Returns:
        The collected phrases, de-duplicated and sorted alphabetically; an
        empty list when no model is loaded or the vocabulary is empty.
    """
    if sbert is None:
        return []

    candidates = sorted({phrase for group in DOMAIN_SYNONYMS.values() for phrase in group})
    if not candidates:
        return []

    query_vecs = sbert.encode(base_terms, normalize_embeddings=True) if base_terms else []
    candidate_vecs = sbert.encode(candidates, normalize_embeddings=True)

    picked = set()
    for query_vec in query_vecs:
        scores = util.cos_sim(query_vec, candidate_vecs).squeeze(0).tolist()
        ranked = sorted(range(len(candidates)), key=lambda idx: scores[idx], reverse=True)
        picked.update(candidates[idx] for idx in ranked[:k])
    return sorted(picked)


## 3. Flan-T5 for Paraphrase-based Expansion

In [None]:
import torch
from typing import List
# Load Flan-T5 once for the notebook. On any failure (or when the toggle is
# off) flan_tok / flan_mdl are set to None and flan_device to "cpu", which
# t5_paraphrases_safe treats as "paraphrasing unavailable".
try:
    if USE_FLAN_T5:
        from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

        # Keep the name configured earlier in the notebook; fall back to the small
        # checkpoint only if this cell is run without that configuration cell.
        FLAN_T5_MODEL_NAME = FLAN_T5_MODEL_NAME if "FLAN_T5_MODEL_NAME" in globals() else "google/flan-t5-small"
        flan_device = "cuda" if torch.cuda.is_available() else "cpu"
        # Half precision is selected only when CUDA is available.
        flan_dtype = torch.float16 if torch.cuda.is_available() else torch.float32

        flan_tok = AutoTokenizer.from_pretrained(FLAN_T5_MODEL_NAME)
        flan_mdl = AutoModelForSeq2SeqLM.from_pretrained(
            FLAN_T5_MODEL_NAME,
            torch_dtype=flan_dtype
        ).to(flan_device)
        flan_mdl.eval()
    else:
        flan_tok = None
        flan_mdl = None
        flan_device = "cpu"
except Exception as e:
    # Best effort: report the failure and continue without paraphrasing.
    print("Flan-T5 not available:", e)
    flan_tok = None
    flan_mdl = None
    flan_device = "cpu"

def t5_paraphrases_safe(
    query: str,
    num_return: int = 5,
    max_new_tokens: int = 48,
    max_input_tokens: int = 128
) -> list[str]:
    """Generate de-duplicated paraphrases of *query* with Flan-T5.

    Sampling (not beam search) is used to get more diverse outputs. If the
    first attempt fails with an out-of-memory ``RuntimeError``, generation is
    retried once with fewer and shorter sequences; if that retry also raises
    a ``RuntimeError`` an empty list is returned.

    Args:
        query: The (normalized) user query.
        num_return: Number of sequences to sample.
        max_new_tokens: Generation length limit per sequence.
        max_input_tokens: Truncation limit for the tokenized prompt.

    Returns:
        Normalized paraphrases in generation order, without exact duplicates
        and without outputs that equal the query once case and punctuation
        are ignored. Empty when Flan-T5 is disabled or not loaded.

    Raises:
        RuntimeError: For generation failures that are not out-of-memory.
    """
    if not (
        USE_FLAN_T5
        and globals().get("flan_tok") is not None
        and globals().get("flan_mdl") is not None
    ):
        return []

    prompt = f"Paraphrase the user query in multiple diverse ways. Keep the meaning and keep it concise.\nQuery: {query}"
    x = flan_tok(prompt, return_tensors="pt", padding=True, truncation=True, max_length=max_input_tokens)
    x = {k: v.to(flan_device) for k, v in x.items()}

    def _generate(n_return: int, n_tokens: int) -> list[str]:
        # `early_stopping` is intentionally not passed: generate() reported it
        # as not valid for this sampling configuration.
        with torch.no_grad():
            out = flan_mdl.generate(
                **x,
                do_sample=True,
                top_k=50,
                top_p=0.92,
                temperature=0.9,
                num_return_sequences=n_return,
                max_new_tokens=n_tokens,
                repetition_penalty=1.1,
                no_repeat_ngram_size=3
            )
        return flan_tok.batch_decode(out, skip_special_tokens=True)

    try:
        paras = _generate(num_return, max_new_tokens)
    except RuntimeError as e:
        if "out of memory" not in str(e).lower():
            raise
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        try:
            # Lighter retry: fewer and shorter sequences.
            paras = _generate(min(3, num_return), min(32, max_new_tokens))
        except RuntimeError:
            return []

    # Drop clones of the input (case/punctuation insensitive) and duplicates.
    base = re.sub(r"\W+", " ", query).strip().lower()
    seen, uniq = set(), []
    for p in paras:
        p2 = normalize(p)
        p2_cmp = re.sub(r"\W+", " ", p2).strip().lower()
        if p2_cmp == base:
            continue
        if p2 and p2 not in seen:
            seen.add(p2)
            uniq.append(p2)
    return uniq

In [None]:
# All-caps words that must never be reported as ticker symbols.
TICKER_STOPWORDS = {
    "THE", "AND", "FOR", "WITH", "WHAT", "NEW", "RISKS", "DID",
    "IN", "OF", "ON", "BY", "LAST", "VS", "S", "DATA",
}

def extract_entities(raw_text: str) -> Dict[str, List[str]]:
    """Pull quarter, year and ticker mentions out of *raw_text*.

    Tickers are recognized only as written in the text: ``$TSLA``, ``(TSLA)``,
    ``NASDAQ: TSLA`` / ``NYSE: TSLA``, or a standalone run of 2-5 capital
    letters that is not in ``TICKER_STOPWORDS``.

    Returns:
        A dict holding whichever of ``"quarter"``, ``"year"`` and ``"ticker"``
        had matches; year and ticker lists are sorted and unique.
    """
    entities: Dict[str, List[str]] = {}

    # Quarters are matched case-insensitively via the lowered text.
    quarter_hits = re.findall(r"\b(q[1-4])\s*([12][0-9]{3})\b", raw_text.lower())
    if quarter_hits:
        entities["quarter"] = [f"{qtr.upper()} {yr}" for qtr, yr in quarter_hits]

    year_hits = re.findall(r"\b(20[0-4][0-9]|19[0-9]{2})\b", raw_text)
    if year_hits:
        entities["year"] = sorted(set(year_hits))

    # Explicitly marked symbols first, then bare all-caps words.
    marked_patterns = (
        r"\$([A-Z]{1,5})\b",
        r"\(([A-Z]{1,5})\)",
        r"\b(?:NASDAQ|NYSE)\s*:\s*([A-Z]{1,5})\b",
    )
    symbols = set()
    for pattern in marked_patterns:
        symbols.update(re.findall(pattern, raw_text))
    symbols.update(
        word for word in re.findall(r"\b[A-Z]{2,5}\b", raw_text)
        if word not in TICKER_STOPWORDS
    )
    if symbols:
        entities["ticker"] = sorted(symbols)

    return entities

# quick probe
print(extract_entities("What new cyber risks did Tesla disclose in Q2 2024?"))
# expect: {'quarter': ['Q2 2024'], 'year': ['2024']} (no ticker)


{'quarter': ['Q2 2024'], 'year': ['2024']}


## 4. Intent Classification (TF-IDF + keyword-prior baseline; FinBERT fine-tune optional)

In [None]:
from sklearn.linear_model import LogisticRegression
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.pipeline import Pipeline
from typing import Tuple

# Intent label names, in the order classify_intent uses for its output index.
INTENT_LABELS = ["risk","performance","strategy"]

# tiny seed so it actually trains; add more later
# Two example queries per intent; y_train is aligned with X_train by position.
X_train = [
    "What new risk factors were disclosed?",
    "Cybersecurity breach details for Tesla",
    "Explain Apple revenue growth and margins",
    "Compare Microsoft profit guidance last quarter",
    "Outline Nvidia expansion strategy in data centers",
    "What restructuring plan is management proposing?"
]
y_train = ["risk","risk","performance","performance","strategy","strategy"]

# Baseline classifier: uni/bi-gram TF-IDF features into logistic regression,
# fitted immediately on the seed set above.
# NOTE(review): `multi_class` is reported as deprecated in recent scikit-learn
# releases - confirm against the installed version.
intent_clf = Pipeline([
    ("tfidf", TfidfVectorizer(ngram_range=(1,2), min_df=1)),
    ("lr", LogisticRegression(max_iter=300, class_weight="balanced", multi_class="ovr"))
]).fit(X_train, y_train)

# Lower-case keyword/phrase sets, one per intent, used by classify_intent to
# build a keyword-based prior that is blended with the classifier output.
RISK_KW = {"risk","risk factor","risk factors","uncertainty","cyber","cybersecurity","breach","litigation","security"}
PERF_KW = {"revenue","growth","margin","profit","loss","guidance","results","compare","last quarter","quarterly"}
STRAT_KW= {"strategy","plan","roadmap","expansion","acquisition","restructuring","capex","data center","data centers"}

def _kw_score(t: str, kws: set[str]) -> int:
    return sum(1 for k in kws if k in t)

def classify_intent(text: str) -> Tuple[str, float]:
    """Classify *text* into one of ``INTENT_LABELS`` with a blended confidence.

    The classifier probabilities (weight 0.6) are mixed with a keyword prior
    (weight 0.4) derived from how many RISK_KW / PERF_KW / STRAT_KW entries
    occur in the lower-cased text, then renormalized.

    Args:
        text: The user query.

    Returns:
        ``(label, confidence)`` where ``confidence`` is the blended,
        renormalized score of the winning label.
    """
    tx = normalize(text)

    # predict_proba columns follow intent_clf.classes_ (sorted label order),
    # not INTENT_LABELS, so re-align them by label name before blending.
    raw_proba = intent_clf.predict_proba([tx])[0].tolist()
    proba_by_label = {str(lbl): p for lbl, p in zip(intent_clf.classes_, raw_proba)}
    proba = [proba_by_label.get(lbl, 0.0) for lbl in INTENT_LABELS]

    # The keyword sets are lower-case, so match against lower-cased text.
    low = tx.lower()
    counts = [_kw_score(low, kws) for kws in (RISK_KW, PERF_KW, STRAT_KW)]
    k_sum = max(1, sum(counts))  # avoid dividing by zero when nothing matches
    priors = [c / k_sum for c in counts]

    alpha, beta = 0.6, 0.4
    blended = [alpha * p + beta * prior for p, prior in zip(proba, priors)]
    s = sum(blended) or 1.0
    blended = [b / s for b in blended]
    idx = max(range(len(blended)), key=lambda i: blended[i])
    return INTENT_LABELS[idx], float(blended[idx])




### 4.A Optional: Fine-tune FinBERT with HF Trainer

In [None]:

# Uncomment to fine-tune FinBERT. After training, replace classify_intent with a wrapper
# that calls the trained model to get label and confidence.

# from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
# from datasets import Dataset
# import numpy as np, torch
# tokenizer = AutoTokenizer.from_pretrained(FINBERT_MODEL_NAME)
# model = AutoModelForSequenceClassification.from_pretrained(FINBERT_MODEL_NAME, num_labels=len(INTENT_LABELS))
# # Build Dataset from your dataframe df with columns 'text' and 'label'.
# # Train and evaluate, then define:
# # def finbert_predict(text: str) -> Tuple[str, float]: ...
# # classify_intent = finbert_predict


## 5. Expansion functions (lexical, semantic, paraphrase)

In [None]:
def keyword_expand(tokens: list[str]) -> list[str]:
    """Look each token up in ``DOMAIN_SYNONYMS`` and return the related phrases.

    Tokens are stripped of leading/trailing ``.`` and ``-`` and lower-cased
    before the lookup. The result keeps first-seen order without duplicates.
    """
    collected = []
    for token in tokens:
        key = token.strip(".-").lower()
        collected += DOMAIN_SYNONYMS.get(key, [])
    # dict.fromkeys de-duplicates while preserving insertion order.
    return list(dict.fromkeys(collected))

def build_keywords(tokens: list[str], expansions: list[str]) -> list[str]:
    """Merge *tokens* and *expansions* into a lower-cased, de-duplicated list.

    Each term is lower-cased first, so all-caps terms such as tickers are
    kept and differently-cased repeats collapse into one entry. Terms with
    no letter or digit (pure punctuation) are dropped.

    Args:
        tokens: Tokens from the original query.
        expansions: Additional expansion phrases.

    Returns:
        Lower-cased keywords in first-seen order, tokens before expansions.
    """
    seen = set()
    kept = []
    for term in tokens + expansions:
        lowered = term.lower()
        if lowered in seen or not re.search(r"[a-z0-9]", lowered):
            continue
        seen.add(lowered)
        kept.append(lowered)
    return kept

def expand_query(query: str) -> dict:
    """Normalize *query* and gather its tokens, expansions and paraphrases.

    Expansions come from two sources, merged in this order without
    duplicates: domain synonyms of the query tokens, then domain synonyms of
    the tokens found in the Flan-T5 paraphrases.

    Returns:
        A dict with the keys ``"normalized"``, ``"tokens"``, ``"expansions"``,
        ``"paraphrases"`` and ``"keywords"``.
    """
    norm = normalize(query)
    toks = simple_tokenize(norm)

    # Lexical expansions straight from the query tokens.
    lexical = keyword_expand(toks)

    # Diverse paraphrases, only when the toggle is on.
    if USE_FLAN_T5:
        paraphrases = t5_paraphrases_safe(norm, num_return=5, max_new_tokens=48)
    else:
        paraphrases = []

    # Unique paraphrase tokens (first-seen order) feed a second synonym lookup.
    paraphrase_vocab = list(
        dict.fromkeys(tok for para in paraphrases for tok in simple_tokenize(para))
    )
    from_paraphrases = keyword_expand(paraphrase_vocab) if paraphrase_vocab else []

    merged = list(dict.fromkeys(lexical + from_paraphrases))

    return {
        "normalized": norm,
        "tokens": toks,
        "expansions": merged,
        "paraphrases": paraphrases,
        "keywords": build_keywords(toks, merged)
    }


## 6. QueryProcessor class

In [None]:
from dataclasses import dataclass, field
from typing import List, Dict, Any, Optional

@dataclass
class QueryProcessorConfig:
    """Settings object stored on a QueryProcessor instance."""
    # Intent label names; a fresh list is created per instance.
    labels: List[str] = field(default_factory=lambda: ["risk","performance","strategy"])
    # NOTE(review): the two flags below are not read by QueryProcessor.process in
    # the visible code, which checks the module-level USE_SBERT toggle instead -
    # confirm whether they are meant to take effect.
    use_flan_t5: bool = False
    use_sbert: bool = False

class QueryProcessor:
    """Run expansion, entity extraction, intent classification and embedding
    for a single query and bundle the results in one dict."""

    def __init__(self, config: Optional[QueryProcessorConfig] = None):
        """Store *config*, creating a fresh default config when none is given.

        A ``None`` default is used instead of ``QueryProcessorConfig()`` so
        that processors do not silently share one mutable config instance.
        """
        self.config = config if config is not None else QueryProcessorConfig()

    def process(self, query: str) -> Dict[str, Any]:
        """Process *query* and return its analysis.

        Returns:
            A dict with the keys ``"normalized"``, ``"label"``,
            ``"confidence"``, ``"expansions"``, ``"paraphrases"``,
            ``"keywords"``, ``"entities"``, ``"filters"`` and ``"embedding"``
            (``None`` when ``USE_SBERT`` is off or no model is loaded).
        """
        ex = expand_query(query)
        normalized = ex["normalized"]
        ents = extract_entities(normalized)
        label, conf = classify_intent(normalized)
        emb = sbert_embed(normalized) if USE_SBERT else None
        return {
            "normalized": normalized,
            "label": label,
            "confidence": conf,
            "expansions": ex["expansions"],
            "paraphrases": ex["paraphrases"],
            "keywords": ex["keywords"],
            "entities": ents,
            # Copy the value lists too, so editing filters leaves entities intact.
            "filters": {name: list(values) for name, values in ents.items()},
            "embedding": emb
        }


## 7. Smoke tests

In [None]:
# Module-level processor instance; rag_retrieve further down also uses `qp`.
qp = QueryProcessor()

# One sample query per intent (risk, performance, strategy).
tests = [
    "What new cyber risks did Tesla disclose in Q2 2024?",
    "Compare Apple's revenue growth vs Microsoft last quarter",
    "Outline Nvidia's expansion strategy in data centers"
]

# Print a compact summary of each result for manual inspection.
for t in tests:
    out = qp.process(t)
    print("\nQ:", t)
    print("Label:", out["label"], "Conf:", round(out["confidence"], 3))
    print("Filters:", out["filters"])
    print("Top keywords:", out["keywords"][:12])
    print("Paraphrase sample:", (out["paraphrases"][:1] if out["paraphrases"] else []))
    print("Embedding:", "present" if out["embedding"] is not None else "None")


The following generation flags are not valid and may be ignored: ['early_stopping']. Set `TRANSFORMERS_VERBOSITY=info` for more details.



Q: What new cyber risks did Tesla disclose in Q2 2024?
Label: risk Conf: 0.577
Filters: {'quarter': ['Q2 2024'], 'year': ['2024']}
Top keywords: ['what', 'new', 'cyber', 'risks', 'did', 'tesla', 'disclose', 'in', 'q2', '2024', 'cybersecurity', 'information security']
Paraphrase sample: ['What new cyber risks did Tesla reveal in Q2 2024?']
Embedding: present

Q: Compare Apple's revenue growth vs Microsoft last quarter
Label: performance Conf: 0.57
Filters: {}
Top keywords: ['compare', 'apple', 's', 'revenue', 'growth', 'vs', 'microsoft', 'last', 'quarter']
Paraphrase sample: ["Apple's revenue growth was 35% higher last quarter compared to Microsoft's."]
Embedding: present

Q: Outline Nvidia's expansion strategy in data centers
Label: strategy Conf: 0.662
Filters: {}
Top keywords: ['outline', 'nvidia', 's', 'expansion', 'strategy', 'in', 'data', 'centers', 'roadmap', 'plan', 'initiative', 'capex']
Paraphrase sample: ["The user can find Nvidia's data center expansion strategy in the foll

## 8. Integration into RAG

In [None]:
# === write query_processor.py (importable) ===
from textwrap import dedent

code = dedent(r'''
# Lightweight export of your QueryProcessor so other notebooks can import it
import re, unicodedata
from dataclasses import dataclass, field
from typing import List, Dict, Any, Optional, Tuple

# ---- toggles (you can flip them later from base_RAG if needed)
USE_SBERT   = True
USE_FLAN_T5 = True

def normalize(text: str) -> str:
    text = unicodedata.normalize("NFKC", text).strip()
    text = re.sub(r"\s+", " ", text)
    return text

TOKEN_RE = re.compile(r"[A-Za-z0-9]+(?:'[A-Za-z0-9]+)?|[&$%.\-]+")
def simple_tokenize(text: str) -> List[str]:
    toks = TOKEN_RE.findall(text)
    cleaned = []
    for t in toks:
        if t.lower() == "'s": continue
        if t.endswith("'s"): t = t[:-2]
        cleaned.append(t)
    return cleaned

DOMAIN_SYNONYMS = {
    "risk": ["risk factor","risk factors","uncertainty","exposure","threat"],
    "cyber": ["cybersecurity","information security","infosec","data breach","security incident"],
    "performance": ["revenue","growth","margin","profit","loss","guidance","results"],
    "strategy": ["roadmap","plan","initiative","expansion","capex","restructuring","acquisition"],
    "md&a": ["management discussion","md&a","results of operations"],
}

def keyword_expand(tokens: List[str]) -> List[str]:
    ex = []
    for t in tokens:
        t0 = t.strip(".-").lower()
        ex.extend(DOMAIN_SYNONYMS.get(t0, []))
    seen, out = set(), []
    for w in ex:
        if w not in seen:
            seen.add(w); out.append(w)
    return out

def build_keywords(tokens: List[str], expansions: List[str]) -> List[str]:
    kept = []
    for t in tokens + expansions:
        t = t.lower()
        if not re.search(r"[a-z0-9]", t):
            continue
        if t not in kept:
            kept.append(t)
    return kept

COMPANY_TICKERS = {"tesla":"TSLA","apple":"AAPL","microsoft":"MSFT","nvidia":"NVDA"}

# --- entities (quarter/year/company/ticker) ---
try:
    import spacy
    _nlp = spacy.load("en_core_web_sm")
except Exception:
    _nlp = None

def extract_entities(raw_text: str) -> dict:
    """Extract quarter, year, company and ticker mentions from *raw_text*.

    Companies come from spaCy ``ORG`` entities (when a model is loaded) and
    from whole-word, case-insensitive matches of the names in
    ``COMPANY_TICKERS``. Tickers come from that name table and from explicit
    patterns in the text: ``$TSLA``, ``(TSLA)``, ``NASDAQ: TSLA`` / ``NYSE: TSLA``.

    Returns:
        A dict holding whichever of ``"quarter"``, ``"year"``, ``"company"``
        and ``"ticker"`` had matches; all but ``"quarter"`` are sorted and unique.
    """
    out = {}
    low = raw_text.lower()
    q = re.findall(r"\b(q[1-4])\s*([12][0-9]{3})\b", low)
    if q:
        out["quarter"] = [f"{p.upper()} {y}" for p, y in q]
    years = re.findall(r"\b(20[0-4][0-9]|19[0-9]{2})\b", raw_text)
    if years:
        out["year"] = sorted(set(years))

    companies = set()
    if _nlp is not None:
        for ent in _nlp(raw_text).ents:
            if ent.label_ == "ORG":
                companies.add(ent.text.strip())
    for name in COMPANY_TICKERS:
        # Whole-word match so that e.g. "pineapple" does not count as "apple".
        if re.search(r"\b" + re.escape(name) + r"\b", low):
            companies.add(name.title())
    if companies:
        out["company"] = sorted(companies)

    tickers = {COMPANY_TICKERS[c.lower()] for c in companies if c.lower() in COMPANY_TICKERS}
    tickers.update(re.findall(r"\$([A-Z]{1,5})\b", raw_text))
    tickers.update(re.findall(r"\(([A-Z]{1,5})\)", raw_text))
    tickers.update(re.findall(r"\b(?:NASDAQ|NYSE)\s*:\s*([A-Z]{1,5})\b", raw_text))
    tickers.discard("")
    if tickers:
        out["ticker"] = sorted(tickers)
    return out

# --- SBERT embedding (matches base_RAG Config) ---
try:
    from sentence_transformers import SentenceTransformer
    _sbert = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2") if USE_SBERT else None
except Exception:
    _sbert = None

def sbert_embed(text: str) -> Optional[List[float]]:
    if _sbert is None: return None
    v = _sbert.encode([text], normalize_embeddings=True)[0]
    return v.tolist()

# --- Flan-T5 paraphrasing (optional) ---
try:
    import torch
    from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
    _flan_device = "cuda" if USE_FLAN_T5 and torch.cuda.is_available() else "cpu"
    _flan_tok = AutoTokenizer.from_pretrained("google/flan-t5-small") if USE_FLAN_T5 else None
    _flan_mdl = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small").to(_flan_device).eval() if USE_FLAN_T5 else None
except Exception:
    _flan_tok = _flan_mdl = None
    _flan_device = "cpu"

def t5_paraphrases_safe(q: str, num_return: int = 5, max_new_tokens: int = 48) -> List[str]:
    if not (USE_FLAN_T5 and _flan_tok is not None and _flan_mdl is not None): return []
    import torch, re
    prompt = ("Rewrite the query into multiple short paraphrases without adding facts or numbers. "
              "Keep meaning; avoid speculation or meta text.\nQuery: " + q)
    x = _flan_tok(prompt, return_tensors="pt", padding=True, truncation=True, max_length=128)
    x = {k: v.to(_flan_device) for k, v in x.items()}
    with torch.no_grad():
        out = _flan_mdl.generate(
            **x, do_sample=True, top_k=50, top_p=0.92, temperature=0.9,
            num_return_sequences=num_return, max_new_tokens=max_new_tokens,
            repetition_penalty=1.1, no_repeat_ngram_size=3
        )
    paras = _flan_tok.batch_decode(out, skip_special_tokens=True)
    base = re.sub(r"\W+"," ", q).strip().lower()
    seen, kept = set(), []
    for p in paras:
        p2 = normalize(p)
        p2_cmp = re.sub(r"\W+"," ", p2).strip().lower()
        if p2_cmp == base: continue
        if p2 and p2 not in seen:
            seen.add(p2); kept.append(p2)
    return kept[:num_return]

# --- intent (hybrid)
from sklearn.linear_model import LogisticRegression
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.pipeline import Pipeline
INTENT_LABELS = ["risk","performance","strategy"]
X_train = [
    "What new risk factors were disclosed?",
    "Cybersecurity breach details for Tesla",
    "Explain Apple revenue growth and margins",
    "Compare Microsoft profit guidance last quarter",
    "Outline Nvidia expansion strategy in data centers",
    "What restructuring plan is management proposing?"
]
y_train = ["risk","risk","performance","performance","strategy","strategy"]
_intent_clf = Pipeline([
    ("tfidf", TfidfVectorizer(ngram_range=(1,2), min_df=1)),
    ("lr", LogisticRegression(max_iter=300, class_weight="balanced", multi_class="ovr"))
]).fit(X_train, y_train)

RISK_KW = {"risk","risk factor","risk factors","uncertainty","cyber","cybersecurity","breach","litigation","security"}
PERF_KW = {"revenue","growth","margin","profit","loss","guidance","results","compare","last quarter","quarterly"}
STRAT_KW= {"strategy","plan","roadmap","expansion","acquisition","restructuring","capex","data center","data centers"}

def _kw_score(t: str, kws: set[str]) -> int:
    return sum(1 for k in kws if k in t)

def classify_intent(text: str) -> Tuple[str, float]:
    """Classify *text* into one of ``INTENT_LABELS`` with a blended confidence.

    The classifier probabilities (weight 0.6) are mixed with a keyword prior
    (weight 0.4) derived from how many RISK_KW / PERF_KW / STRAT_KW entries
    occur in the lower-cased text, then renormalized.

    Returns:
        ``(label, confidence)`` where ``confidence`` is the blended,
        renormalized score of the winning label.
    """
    tx = normalize(text)

    # predict_proba columns follow _intent_clf.classes_ (sorted label order),
    # not INTENT_LABELS, so re-align them by label name before blending.
    raw_proba = _intent_clf.predict_proba([tx])[0].tolist()
    proba_by_label = {str(lbl): p for lbl, p in zip(_intent_clf.classes_, raw_proba)}
    proba = [proba_by_label.get(lbl, 0.0) for lbl in INTENT_LABELS]

    # The keyword sets are lower-case, so match against lower-cased text.
    low = tx.lower()
    counts = [_kw_score(low, kws) for kws in (RISK_KW, PERF_KW, STRAT_KW)]
    k_sum = max(1, sum(counts))  # avoid dividing by zero when nothing matches
    priors = [c / k_sum for c in counts]

    alpha, beta = 0.6, 0.4
    blended = [alpha * p + beta * prior for p, prior in zip(proba, priors)]
    s = sum(blended) or 1.0
    blended = [b / s for b in blended]
    idx = max(range(len(blended)), key=lambda i: blended[i])
    return INTENT_LABELS[idx], float(blended[idx])

def expand_query(query: str) -> dict:
    norm = normalize(query)
    toks = simple_tokenize(norm)
    lex_ex = keyword_expand(toks)
    paras = t5_paraphrases_safe(norm, num_return=5, max_new_tokens=48) if USE_FLAN_T5 else []
    para_tokens = []
    for p in paras:
        para_tokens.extend(simple_tokenize(p))
    para_tokens = list(dict.fromkeys(para_tokens))
    para_ex = keyword_expand(para_tokens) if para_tokens else []
    expansions = []
    for lst in (lex_ex, para_ex):
        for w in lst:
            if w not in expansions:
                expansions.append(w)
    return {
        "normalized": norm,
        "tokens": toks,
        "expansions": expansions,
        "paraphrases": paras,
        "keywords": build_keywords(toks, expansions)
    }

@dataclass
class QueryProcessorConfig:
    labels: List[str] = field(default_factory=lambda: ["risk","performance","strategy"])

class QueryProcessor:
    """Run expansion, entity extraction, intent classification and embedding
    for a single query and bundle the results in one dict."""

    def __init__(self, config: Optional[QueryProcessorConfig] = None):
        """Store *config*, creating a fresh default config when none is given.

        A ``None`` default is used instead of ``QueryProcessorConfig()`` so
        that processors do not silently share one mutable config instance.
        """
        self.config = config if config is not None else QueryProcessorConfig()

    def process(self, query: str) -> Dict[str, Any]:
        """Process *query* and return its analysis.

        Entities are extracted from the raw query; intent and embedding use
        the normalized text.
        """
        raw = query
        ex = expand_query(query)
        ents = extract_entities(raw)
        label, conf = classify_intent(ex["normalized"])
        emb = sbert_embed(ex["normalized"]) if USE_SBERT else None
        return {
            "normalized": ex["normalized"],
            "label": label,
            "confidence": conf,
            "expansions": ex["expansions"],
            "paraphrases": ex["paraphrases"],
            "keywords": ex["keywords"],
            "entities": ents,
            # Copy the value lists too, so editing filters leaves entities intact.
            "filters": {name: list(values) for name, values in ents.items()},
            "embedding": emb
        }
''')

with open("query_processor.py", "w", encoding="utf-8") as f:
    f.write(code)

print("✅ query_processor.py written")


✅ query_processor.py written


In [None]:

def build_metadata_filter(filters: Dict[str, List[str]]):
    """Build a Qdrant ``Filter`` from the ticker / quarter entries of *filters*.

    Only the ``"ticker"`` and ``"quarter"`` keys are used; each non-empty
    list becomes a ``MatchAny`` condition, and all conditions must hold.

    Args:
        filters: Entity dict as produced by the query processor.

    Returns:
        A ``Filter`` instance, or ``None`` when there is nothing to filter on
        or ``qdrant_client`` is not installed.
    """
    try:
        from qdrant_client.http.models import Filter, FieldCondition, MatchAny
    except ImportError:
        # Best effort: without the client library, retrieval runs unfiltered.
        return None

    must = [
        FieldCondition(key=field_name, match=MatchAny(any=filters[field_name]))
        for field_name in ("ticker", "quarter")
        if filters.get(field_name)
    ]
    return Filter(must=must) if must else None

def rag_retrieve(user_query: str, top_k: int = 20):
    """Process *user_query* and return ``(processed_query, hits)``.

    The vector and BM25 searches are still placeholders (see the commented
    calls), so ``hits`` is currently always an empty list.
    """
    q = qp.process(user_query)
    filt = build_metadata_filter(q["filters"])
    embedding, terms = q["embedding"], q["keywords"]

    # Dense retrieval placeholder; only meaningful when an embedding exists.
    vec_hits = []
    if embedding is not None:
        # vec_hits = qdrant_manager.search_vector(embedding, top_k=top_k, metadata_filter=filt)
        pass

    # Sparse retrieval placeholder.
    bm25_hits = []
    # bm25_hits = es.search(terms, top_k=top_k, filters=q["filters"])

    return q, vec_hits + bm25_hits
