<a href="https://colab.research.google.com/github/Rajfekar/PythonML/blob/main/reranking_raj_bi_encoder_cross_encoder(bert).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install -q sentence-transformers faiss-cpu pandas rapidfuzz


[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/3.2 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.2/3.2 MB[0m [31m4.4 MB/s[0m eta [36m0:00:01[0m[2K   [91m━━━━━━━━━━━━━━━[0m[90m╺[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.2/3.2 MB[0m [31m17.8 MB/s[0m eta [36m0:00:01[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m3.2/3.2 MB[0m [31m36.2 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.2/3.2 MB[0m [31m28.8 MB/s[0m eta [36m0:00:00[0m
[?25h

In [None]:
import os
import re
import numpy as np
import pandas as pd
import faiss
import torch
from sentence_transformers import SentenceTransformer, CrossEncoder
from rapidfuzz import process, fuzz

device = "cuda" if torch.cuda.is_available() else "cpu"

csv_path = '/content/merged_cpt_data - merged_cpt_data.csv.csv'
df = pd.read_csv(csv_path)

df['Summary_clean'] = (
    df['Summary'].astype(str)
      .str.replace(r'(?i)^Summary\s*', '', regex=True)
      .str.replace(r'\\n', ' ', regex=True)
      .str.replace('\n', ' ', regex=False)
      .str.replace(r'\s+', ' ', regex=True)
      .str.strip()
)



df["merged_text"] = (
    (df["CPT_Code"].astype(str) + " ") +
    df["Desc"].astype(str) + " " +
    df["Category"].astype(str) + " " +
    df["Summary_clean"].astype(str)
)

summaries = df["merged_text"].astype(str).tolist()
ids = list(df['CPT_Code'].astype(str).tolist())

embed_model_name = 'all-MiniLM-L6-v2'
reranker_model_name = 'cross-encoder/ms-marco-MiniLM-L-6-v2'

embed_model = SentenceTransformer(embed_model_name, device=device)
reranker = CrossEncoder(reranker_model_name, device=device)

embeddings = embed_model.encode(summaries, batch_size=64, show_progress_bar=True, convert_to_numpy=True)
faiss.normalize_L2(embeddings)
dim = embeddings.shape[1]
index = faiss.IndexFlatIP(dim)
index.add(embeddings)
faiss.write_index(index, "mergedtext_faiss.index")
np.save("merged_summaries.npy", np.array(summaries))
np.save("merged_ids.npy", np.array(ids))

def normalize_code(s):
    if s is None:
        return None
    return re.sub(r'[\s\-]', '', str(s)).upper()

TOKEN_RE = re.compile(r'\b[A-Za-z0-9]+\b')
CONTAINS_DIGIT_RE = re.compile(r'\d')

normalized_ids = [normalize_code(x) for x in ids]
code_to_index = {nid: i for i, nid in enumerate(normalized_ids)}
code_list = list(code_to_index.keys())

FUZZY_THRESHOLD = 90
SEMANTIC_CONF_THRESHOLD = 0.80
CODE_DESC_MISMATCH_SCORE = 0.90
DEMOTION_FACTOR = 0.35

def search_and_rerank_combined(
    query: str,
    top_k: int = 200,
    rerank_k: int = 50,
    reranker_batch: int = 32,
    use_fuzzy_for_code: bool = True,
    fuzzy_threshold: int = FUZZY_THRESHOLD,
    semantic_conf_threshold: float = SEMANTIC_CONF_THRESHOLD,
    code_mismatch_score: float = CODE_DESC_MISMATCH_SCORE,
    demotion_factor: float = DEMOTION_FACTOR
):
    if not query:
        return []

    tokens = TOKEN_RE.findall(query)
    token_norms = [normalize_code(t) for t in tokens if t]

    candidate_code_token = None
    candidate_token_original = None
    for i, orig_t in enumerate(tokens):
        t = token_norms[i] if i < len(token_norms) else normalize_code(orig_t)
        if not t:
            continue
        if CONTAINS_DIGIT_RE.search(t) and 1 <= len(t) <= 12:
            candidate_code_token = t
            candidate_token_original = orig_t
            break

    if len(tokens) == 1 and candidate_code_token:
        code_norm = candidate_code_token
        if code_norm in code_to_index:
            idx = code_to_index[code_norm]
            return [(ids[idx], summaries[idx], 1.0)]
        else:
            return []

    matched_code_idx = None
    matched_code_norm = None
    matched_code_conf = 0.0
    if use_fuzzy_for_code and candidate_code_token:
        match, score, _ = process.extractOne(candidate_code_token, code_list, scorer=fuzz.ratio)
        if score >= fuzzy_threshold:
            matched_code_norm = match
            matched_code_conf = score / 100.0
            matched_code_idx = code_to_index[match]

    desc_query = query.strip()
    if candidate_code_token and len(tokens) > 1:
        if candidate_token_original:
            tmp = re.sub(r'\b' + re.escape(candidate_token_original) + r'\b', '', desc_query, flags=re.IGNORECASE)
            desc_query = re.sub(r'\s+', ' ', tmp).strip()
        else:
            tmp = re.sub(r'\b' + re.escape(candidate_code_token) + r'\b', '', desc_query, flags=re.IGNORECASE)
            desc_query = re.sub(r'\s+', ' ', tmp).strip()

    if not desc_query:
        if matched_code_idx is not None:
            if matched_code_norm == candidate_code_token:
                return [(ids[matched_code_idx], summaries[matched_code_idx], 1.0)]
            else:
                return [(ids[matched_code_idx], summaries[matched_code_idx], max(0.95, matched_code_conf))]
        return []

    q_emb = embed_model.encode([desc_query], convert_to_numpy=True)
    faiss.normalize_L2(q_emb)
    D, I = index.search(q_emb, top_k)
    idxs = I[0].tolist()
    candidates = [summaries[i] for i in idxs]
    candidate_ids = [ids[i] for i in idxs]

    pairs = [[desc_query, cand] for cand in candidates]
    scores = []
    for start in range(0, len(pairs), reranker_batch):
        batch_pairs = pairs[start:start+reranker_batch]
        batch_scores = reranker.predict(batch_pairs)
        batch_scores = torch.tensor(batch_scores, dtype=torch.float32)
        batch_probs = torch.sigmoid(batch_scores).tolist()
        scores.extend(batch_probs)

    combined = list(zip(candidate_ids, candidates, [float(s) for s in scores]))
    scores_by_id = {cid: sc for cid, _, sc in combined}

    code_id = None
    code_text = None
    if matched_code_idx is not None:
        code_id = ids[matched_code_idx]
        code_text = summaries[matched_code_idx]
    elif candidate_code_token and candidate_code_token in code_to_index:
        matched_code_idx = code_to_index[candidate_code_token]
        code_id = ids[matched_code_idx]
        code_text = summaries[matched_code_idx]

    final_list = []
    if code_id:
        semantic_score_for_code = float(scores_by_id.get(code_id, 0.0))
        if semantic_score_for_code >= semantic_conf_threshold:
            code_combined_score = 1.0
        else:
            code_combined_score = code_mismatch_score
        final_list.append((code_id, code_text, float(code_combined_score)))

    for cid, txt, sc in combined:
        if cid == code_id:
            continue
        adj_score = float(sc)
        if code_id:
            adj_score = adj_score * demotion_factor
        final_list.append((cid, txt, adj_score))

    final_list.sort(key=lambda x: x[2], reverse=True)
    seen = set()
    out = []
    for cid, txt, sc in final_list:
        if cid in seen:
            continue
        seen.add(cid)
        out.append((cid, txt, float(sc)))
        if len(out) >= rerank_k:
            break

    return out

query = "perfusion"
results = search_and_rerank_combined(query, top_k=20, rerank_k=5, reranker_batch=32)

for cid, text, score in results:
    print(f"CPT: {cid}  |  score: {score:.4f}")
    print("snippet:", text[:250].replace("\n", " "))
    print("-" * 80)


Batches:   0%|          | 0/21 [00:00<?, ?it/s]

CPT: 78597  |  score: 0.9600
snippet: 78597 Lung perfusion differential RADIOLOGY AND CERTAIN OTHER IMAGING SERVICES In this diagnostic procedure, the provider performs pulmonary perfusion, a nuclear scan test that evaluates the flow of blood within the patient’s lungs. The aim is to per
--------------------------------------------------------------------------------
CPT: 0042T  |  score: 0.9103
snippet: 0042T Ct perfusion w/contrast cbf RADIOLOGY AND CERTAIN OTHER IMAGING SERVICES The provider obtains a measurement of regional cerebral blood flow through analysis of computed tomography (CT) scans by taking sequential images of sections of the brain 
--------------------------------------------------------------------------------
CPT: 78580  |  score: 0.9004
snippet: 78580 Lung perfusion imaging RADIOLOGY AND CERTAIN OTHER IMAGING SERVICES In this diagnostic procedure, the provider performs a nuclear perfusion imaging test to evaluate the circulation of blood within the patient’s lungs

In [2]:
# build_and_save.py
import os, re
import numpy as np
import pandas as pd
import faiss
import torch
from sentence_transformers import SentenceTransformer

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
EMBED_MODEL_NAME = "all-MiniLM-L6-v2"

CSV_PATH = "/content/merged_cpt_data - merged_cpt_data.csv.csv"   # change if needed
OUT_DIR = "./data"
os.makedirs(OUT_DIR, exist_ok=True)

df = pd.read_csv(CSV_PATH)

# clean summary
df["Summary_clean"] = (
    df["Summary"].astype(str)
      .str.replace(r'(?i)^Summary\s*', '', regex=True)
      .str.replace(r'\\n', ' ', regex=True)
      .str.replace('\n', ' ', regex=False)
      .str.replace(r'\s+', ' ', regex=True)
      .str.strip()
)

# build merged text (you used merged_text earlier)
df["merged_text"] = (
    (df["CPT_Code"].astype(str) + " ") +
    df["Desc"].astype(str) + " " +
    df["Category"].astype(str) + " " +
    df["Summary_clean"].astype(str)
)

# prepare lists
ids = df["CPT_Code"].astype(str).tolist()
summaries = df["merged_text"].astype(str).tolist()

# load embed model and encode
embed_model = SentenceTransformer(EMBED_MODEL_NAME, device=DEVICE)
embeddings = embed_model.encode(summaries, batch_size=64, show_progress_bar=True, convert_to_numpy=True)

# normalize and build FAISS index (inner product on normalized vectors => cosine)
faiss.normalize_L2(embeddings)
dim = embeddings.shape[1]
index = faiss.IndexFlatIP(dim)
index.add(embeddings)

# save artifacts
np.save(os.path.join(OUT_DIR, "embeddings.npy"), embeddings)
np.save(os.path.join(OUT_DIR, "summaries.npy"), np.array(summaries))
pd.DataFrame({"CPT_Code": ids, "merged_text": summaries}).to_csv(os.path.join(OUT_DIR, "ids.csv"), index=False)
faiss.write_index(index, os.path.join(OUT_DIR, "mergedtext_faiss.index"))

print("Saved artifacts to", OUT_DIR)
print("ntotal in index:", index.ntotal)


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


Batches:   0%|          | 0/21 [00:00<?, ?it/s]

Saved artifacts to ./data
ntotal in index: 1321


In [3]:
# load_and_search.py
import os, re
import numpy as np
import pandas as pd
import faiss
import torch
from sentence_transformers import SentenceTransformer, CrossEncoder
from rapidfuzz import process, fuzz
from typing import List, Tuple, Optional

# config & paths (adjust)
OUT_DIR = "./data"
EMBED_MODEL_NAME = "all-MiniLM-L6-v2"
RERANKER_MODEL_NAME = "cross-encoder/ms-marco-MiniLM-L-6-v2"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

EMBEDDINGS_PATH = os.path.join(OUT_DIR, "embeddings.npy")
SUMMARIES_PATH = os.path.join(OUT_DIR, "summaries.npy")
IDS_CSV_PATH = os.path.join(OUT_DIR, "ids.csv")
FAISS_INDEX_PATH = os.path.join(OUT_DIR, "mergedtext_faiss.index")

# load artifacts
embeddings = np.load(EMBEDDINGS_PATH)
summaries = np.load(SUMMARIES_PATH, allow_pickle=True).tolist()
ids_df = pd.read_csv(IDS_CSV_PATH)
ids = ids_df["CPT_Code"].astype(str).tolist()

# load or recreate FAISS index
if os.path.exists(FAISS_INDEX_PATH):
    index = faiss.read_index(FAISS_INDEX_PATH)
else:
    faiss.normalize_L2(embeddings)
    dim = embeddings.shape[1]
    index = faiss.IndexFlatIP(dim)
    index.add(embeddings)

# models
embed_model = SentenceTransformer(EMBED_MODEL_NAME, device=DEVICE)
reranker = CrossEncoder(RERANKER_MODEL_NAME, device=DEVICE)

# helpers
def normalize_code(s: Optional[str]) -> Optional[str]:
    if s is None:
        return None
    return re.sub(r'[\s\-]', '', str(s)).upper()

TOKEN_RE = re.compile(r'\b[A-Za-z0-9]+\b')
CONTAINS_DIGIT_RE = re.compile(r'\d')

normalized_ids = [normalize_code(x) for x in ids]
code_to_index = {nid: i for i, nid in enumerate(normalized_ids)}
code_list = list(code_to_index.keys())

# thresholds
FUZZY_THRESHOLD = 90
SEMANTIC_CONF_THRESHOLD = 0.80
CODE_DESC_MISMATCH_SCORE = 0.90
DEMOTION_FACTOR = 0.35

def search_loaded(
    query: str,
    top_k: int = 200,
    rerank_k: int = 50,
    reranker_batch: int = 32,
    use_fuzzy_for_code: bool = True
) -> List[Tuple[str, str, float]]:
    if not query or not query.strip():
        return []

    tokens = TOKEN_RE.findall(query)
    token_norms = [normalize_code(t) for t in tokens if t]

    # detect first token containing a digit -> candidate code
    candidate_code_token = None
    candidate_token_original = None
    for i, orig_t in enumerate(tokens):
        t = token_norms[i] if i < len(token_norms) else normalize_code(orig_t)
        if not t:
            continue
        if CONTAINS_DIGIT_RE.search(t) and 1 <= len(t) <= 12:
            candidate_code_token = t
            candidate_token_original = orig_t
            break

    # fast path: code-only
    if len(tokens) == 1 and candidate_code_token:
        if candidate_code_token in code_to_index:
            idx = code_to_index[candidate_code_token]
            return [(ids[idx], summaries[idx], 1.0)]
        else:
            return []

    # fuzzy code lookup
    matched_code_idx = None
    matched_code_norm = None
    matched_code_conf = 0.0
    if use_fuzzy_for_code and candidate_code_token:
        match, score, _ = process.extractOne(candidate_code_token, code_list, scorer=fuzz.ratio)
        if score >= FUZZY_THRESHOLD:
            matched_code_norm = match
            matched_code_conf = score / 100.0
            matched_code_idx = code_to_index[match]

    # remove code token from query to form description-only query
    desc_query = query.strip()
    if candidate_code_token and len(tokens) > 1:
        if candidate_token_original:
            tmp = re.sub(r'\b' + re.escape(candidate_token_original) + r'\b', '', desc_query, flags=re.IGNORECASE)
            desc_query = re.sub(r'\s+', ' ', tmp).strip()
        else:
            tmp = re.sub(r'\b' + re.escape(candidate_code_token) + r'\b', '', desc_query, flags=re.IGNORECASE)
            desc_query = re.sub(r'\s+', ' ', tmp).strip()

    # if no desc left -> return matched code
    if not desc_query:
        if matched_code_idx is not None:
            if matched_code_norm == candidate_code_token:
                return [(ids[matched_code_idx], summaries[matched_code_idx], 1.0)]
            else:
                return [(ids[matched_code_idx], summaries[matched_code_idx], max(0.95, matched_code_conf))]
        return []

    # semantic retrieval on desc_query
    q_emb = embed_model.encode([desc_query], convert_to_numpy=True)
    faiss.normalize_L2(q_emb)
    D, I = index.search(q_emb, top_k)
    idxs = I[0].tolist()
    candidates = [summaries[i] for i in idxs]
    candidate_ids = [ids[i] for i in idxs]

    # reranker scoring
    pairs = [[desc_query, cand] for cand in candidates]
    scores = []
    for start in range(0, len(pairs), reranker_batch):
        batch_pairs = pairs[start:start+reranker_batch]
        batch_scores = reranker.predict(batch_pairs)
        batch_scores = torch.tensor(batch_scores, dtype=torch.float32)
        batch_probs = torch.sigmoid(batch_scores).tolist()
        scores.extend(batch_probs)

    combined = list(zip(candidate_ids, candidates, [float(s) for s in scores]))
    scores_by_id = {cid: sc for cid, _, sc in combined}

    # resolve code doc (from fuzzy or exact map)
    code_id = None
    code_text = None
    if matched_code_idx is not None:
        code_id = ids[matched_code_idx]
        code_text = summaries[matched_code_idx]
    elif candidate_code_token and candidate_code_token in code_to_index:
        matched_code_idx = code_to_index[candidate_code_token]
        code_id = ids[matched_code_idx]
        code_text = summaries[matched_code_idx]

    final_list = []
    if code_id:
        semantic_score_for_code = float(scores_by_id.get(code_id, 0.0))
        if semantic_score_for_code >= SEMANTIC_CONF_THRESHOLD:
            code_combined_score = 1.0
        else:
            code_combined_score = CODE_DESC_MISMATCH_SCORE
        final_list.append((code_id, code_text, float(code_combined_score)))

    for cid, txt, sc in combined:
        if cid == code_id:
            continue
        adj_score = float(sc)
        if code_id:
            adj_score = adj_score * DEMOTION_FACTOR
        final_list.append((cid, txt, adj_score))

    final_list.sort(key=lambda x: x[2], reverse=True)
    seen = set()
    out = []
    for cid, txt, sc in final_list:
        if cid in seen:
            continue
        seen.add(cid)
        out.append((cid, txt, float(sc)))
        if len(out) >= rerank_k:
            break

    return out



In [4]:
# example usage:
if __name__ == "__main__":
    q_list = [
        "42324205",
        "A9557",
        "93980 Echo transthoracic",
        "perfusion"
    ]
    for q in q_list:
        print("\nQuery:", q)
        res = search_loaded(q, top_k=20, rerank_k=5, reranker_batch=32)
        for cid, text, score in res:
            print(f"  {cid}  score={score:.4f}  snippet={text[:120].replace(chr(10),' ')}")



Query: 42324205

Query: A9557
  A9557  score=1.0000  snippet=A9557 Tc99m bicisate RADIOLOGY AND CERTAIN OTHER IMAGING SERVICES nan

Query: 93980 Echo transthoracic
  93980  score=0.9000  snippet=93980 Penile vascular study RADIOLOGY AND CERTAIN OTHER IMAGING SERVICES The provider performs a complete study of the p
  93303  score=0.3499  snippet=93303 Echo transthoracic RADIOLOGY AND CERTAIN OTHER IMAGING SERVICES The provider performs a complete transthoracic ech
  93304  score=0.3499  snippet=93304 Echo transthoracic RADIOLOGY AND CERTAIN OTHER IMAGING SERVICES The provider performs a limited or follow–up trans
  93308  score=0.3487  snippet=93308 Tte f-up or lmtd RADIOLOGY AND CERTAIN OTHER IMAGING SERVICES The provider performs a limited or follow–up transth
  76506  score=0.0589  snippet=76506 Echo exam of head RADIOLOGY AND CERTAIN OTHER IMAGING SERVICES The provider performs a noninvasive diagnostic imag

Query: perfusion
  78597  score=0.9600  snippet=78597 Lung perfusion diffe