<a href="https://colab.research.google.com/github/Rajfekar/PythonML/blob/main/reranking_raj_bi-encoder_cross_encoder(bert).ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [3]:
# Colab: install required packages
# If you want GPU use faiss-gpu (only if runtime has CUDA). Otherwise use faiss-cpu.
# For most Colab users, faiss-cpu is fine.
!pip install -q sentence-transformers faiss-cpu pandas
# If you plan to use faiss-gpu on a GPU runtime, replace faiss-cpu with faiss-gpu:
# !pip install -q sentence-transformers faiss-gpu pandas


[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m31.4/31.4 MB[0m [31m24.5 MB/s[0m eta [36m0:00:00[0m
[?25h

In [49]:
from sentence_transformers import SentenceTransformer, CrossEncoder, util
import faiss
import numpy as np
import pandas as pd
import os
import math
import torch

print("torch version:", torch.__version__)
device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)


torch version: 2.8.0+cu126
Using device: cpu


In [50]:
import pandas as pd
from IPython.display import display

csv_path = '/content/merged_cpt_data.csv'
df = pd.read_csv(csv_path)

print("Initial dataframe shape:", df.shape)
print("Columns:", df.columns.tolist())
print("\n--- Head (original) ---")
display(df.head(5))


Initial dataframe shape: (1321, 5)
Columns: ['CPT_Code', 'Desc', 'Category', 'Remark', 'Summary']

--- Head (original) ---


Unnamed: 0,CPT_Code,Desc,Category,Remark,Summary
0,86152,Cell enumeration &id,CLINICAL LABORATORY SERVICES,INCLUDE CPT codes for all clinical laboratory ...,Summary\nThe lab analyst performs a test to ev...
1,86153,Cell enumeration phys interp,CLINICAL LABORATORY SERVICES,INCLUDE CPT codes for all clinical laboratory ...,"Summary\nThe physician, specifically a patholo..."
2,86890,Autologous blood process,CLINICAL LABORATORY SERVICES,INCLUDE CPT codes for all clinical laboratory ...,"Summary\nThe provider collects, processes, and..."
3,86891,Autologous blood op salvage,CLINICAL LABORATORY SERVICES,INCLUDE CPT codes for all clinical laboratory ...,Summary\nThe clinician performs a procedure to...
4,86927,Plasma fresh frozen,CLINICAL LABORATORY SERVICES,INCLUDE CPT codes for all clinical laboratory ...,Summary\nThe lab analyst thaws a unit of froze...


In [64]:
# Install profiling library (only once per runtime)
!pip install -q ydata-profiling

from ydata_profiling import ProfileReport

# Create a profile report
profile = ProfileReport(df, title="CPT Data Profiling Report", explorative=True)

# Display in notebook
profile.to_notebook_iframe()

# Save to file (optional, to download)
profile.to_file("/content/cpt_data_profile.html")
print("Profile report saved to /content/cpt_data_profile.html")


Summarize dataset:   0%|          | 0/5 [00:00<?, ?it/s]


  0%|          | 0/7 [00:00<?, ?it/s][A
 43%|████▎     | 3/7 [00:00<00:00, 25.62it/s][A
100%|██████████| 7/7 [00:00<00:00, 11.67it/s]


Generate report structure:   0%|          | 0/1 [00:00<?, ?it/s]

Render HTML:   0%|          | 0/1 [00:00<?, ?it/s]

Export report to file:   0%|          | 0/1 [00:00<?, ?it/s]

Profile report saved to /content/cpt_data_profile.html


In [51]:
pattern_leading = r'(?i)^Summary\s*(\\n|\n)'   # leading 'Summary' + newline or literal \n
pattern_any = r'(\\n|\n)'                     # any literal \n or actual newline

mask_leading = df['Summary'].astype(str).str.contains(pattern_leading, regex=True, na=False)
mask_any = df['Summary'].astype(str).str.contains(pattern_any, regex=True, na=False)

print(f"Rows with leading 'Summary' prefix + newline: {mask_leading.sum()}")
print(f"Rows with any '\\\\n' literal or actual newline: {mask_any.sum()}")

if mask_any.sum() > 0:
    print("\n--- Example problematic rows (CPT_Code + Summary) ---")
    display(df.loc[mask_any, ['CPT_Code', 'Summary']].head(10))
else:
    print("\nNo rows found with literal '\\n' or newline in Summary.")


Rows with leading 'Summary' prefix + newline: 1065
Rows with any '\\n' literal or actual newline: 1071

--- Example problematic rows (CPT_Code + Summary) ---


  mask_leading = df['Summary'].astype(str).str.contains(pattern_leading, regex=True, na=False)
  mask_any = df['Summary'].astype(str).str.contains(pattern_any, regex=True, na=False)


Unnamed: 0,CPT_Code,Summary
0,86152,Summary\nThe lab analyst performs a test to ev...
1,86153,"Summary\nThe physician, specifically a patholo..."
2,86890,"Summary\nThe provider collects, processes, and..."
3,86891,Summary\nThe clinician performs a procedure to...
4,86927,Summary\nThe lab analyst thaws a unit of froze...
5,86930,Summary\nThe lab analyst thaws a unit of froze...
6,86931,Summary\nThe lab analyst thaws a unit of froze...
7,86932,"Summary\nThe lab analyst prepares, freezes, an..."
8,86945,Summary\nThe lab analyst performs the steps to...
9,86950,Summary\nThe lab personnel separate leukocytes...


In [52]:
df['Summary_clean'] = (
    df['Summary'].astype(str)
      .str.replace(r'(?i)^Summary\s*', '', regex=True)   # remove leading 'Summary' (case-insensitive)
      .str.replace(r'\\n', ' ', regex=True)             # replace literal backslash-n sequence
      .str.replace('\n', ' ', regex=False)              # replace actual newline characters
      .str.replace(r'\s+', ' ', regex=True)             # collapse multiple spaces
      .str.strip()
)

print("Cleaning done. Example cleaned values:")
display(df[['CPT_Code','Summary','Summary_clean']].head(10))


Cleaning done. Example cleaned values:


Unnamed: 0,CPT_Code,Summary,Summary_clean
0,86152,Summary\nThe lab analyst performs a test to ev...,The lab analyst performs a test to evaluate th...
1,86153,"Summary\nThe physician, specifically a patholo...","The physician, specifically a pathologist, int..."
2,86890,"Summary\nThe provider collects, processes, and...","The provider collects, processes, and stores a..."
3,86891,Summary\nThe clinician performs a procedure to...,"The clinician performs a procedure to collect,..."
4,86927,Summary\nThe lab analyst thaws a unit of froze...,The lab analyst thaws a unit of frozen plasma.
5,86930,Summary\nThe lab analyst thaws a unit of froze...,The lab analyst thaws a unit of frozen plasma.
6,86931,Summary\nThe lab analyst thaws a unit of froze...,The lab analyst thaws a unit of frozen blood.
7,86932,"Summary\nThe lab analyst prepares, freezes, an...","The lab analyst prepares, freezes, and thaws a..."
8,86945,Summary\nThe lab analyst performs the steps to...,The lab analyst performs the steps to treat bl...
9,86950,Summary\nThe lab personnel separate leukocytes...,The lab personnel separate leukocytes from don...


In [53]:
problem_mask = df['Summary'].astype(str).str.contains(r'(\\n|\n)', regex=True, na=False)
if problem_mask.sum() > 0:
    print(f"Showing {problem_mask.sum()} problematic rows (before -> after):")
    display(df.loc[problem_mask, ['CPT_Code','Summary','Summary_clean']].head(20))
else:
    print("No problematic rows to display (none contained \\n/newline).")


Showing 1071 problematic rows (before -> after):


  problem_mask = df['Summary'].astype(str).str.contains(r'(\\n|\n)', regex=True, na=False)


Unnamed: 0,CPT_Code,Summary,Summary_clean
0,86152,Summary\nThe lab analyst performs a test to ev...,The lab analyst performs a test to evaluate th...
1,86153,"Summary\nThe physician, specifically a patholo...","The physician, specifically a pathologist, int..."
2,86890,"Summary\nThe provider collects, processes, and...","The provider collects, processes, and stores a..."
3,86891,Summary\nThe clinician performs a procedure to...,"The clinician performs a procedure to collect,..."
4,86927,Summary\nThe lab analyst thaws a unit of froze...,The lab analyst thaws a unit of frozen plasma.
5,86930,Summary\nThe lab analyst thaws a unit of froze...,The lab analyst thaws a unit of frozen plasma.
6,86931,Summary\nThe lab analyst thaws a unit of froze...,The lab analyst thaws a unit of frozen blood.
7,86932,"Summary\nThe lab analyst prepares, freezes, an...","The lab analyst prepares, freezes, and thaws a..."
8,86945,Summary\nThe lab analyst performs the steps to...,The lab analyst performs the steps to treat bl...
9,86950,Summary\nThe lab personnel separate leukocytes...,The lab personnel separate leukocytes from don...


In [54]:
df["merged_text"] = (
    df["CPT_Code"].astype(str) + " " +
    df["Desc"].astype(str) + " " +
    df["Category"].astype(str) + " " +
    df["Remark"].astype(str) + " " +
    df["Summary_clean"].astype(str)
)

print("--- Preview merged_text ---")
display(df[["CPT_Code","Desc","Category","Remark","Summary_clean","merged_text"]].head(10))

# Prepare summaries for embeddings
summaries = df["merged_text"].astype(str).tolist()
print("Total rows loaded for embeddings:", len(summaries))



--- Preview merged_text ---


Unnamed: 0,CPT_Code,Desc,Category,Remark,Summary_clean,merged_text
0,86152,Cell enumeration &id,CLINICAL LABORATORY SERVICES,INCLUDE CPT codes for all clinical laboratory ...,The lab analyst performs a test to evaluate th...,86152 Cell enumeration &id CLINICAL LABORATORY...
1,86153,Cell enumeration phys interp,CLINICAL LABORATORY SERVICES,INCLUDE CPT codes for all clinical laboratory ...,"The physician, specifically a pathologist, int...",86153 Cell enumeration phys interp CLINICAL LA...
2,86890,Autologous blood process,CLINICAL LABORATORY SERVICES,INCLUDE CPT codes for all clinical laboratory ...,"The provider collects, processes, and stores a...",86890 Autologous blood process CLINICAL LABORA...
3,86891,Autologous blood op salvage,CLINICAL LABORATORY SERVICES,INCLUDE CPT codes for all clinical laboratory ...,"The clinician performs a procedure to collect,...",86891 Autologous blood op salvage CLINICAL LAB...
4,86927,Plasma fresh frozen,CLINICAL LABORATORY SERVICES,INCLUDE CPT codes for all clinical laboratory ...,The lab analyst thaws a unit of frozen plasma.,86927 Plasma fresh frozen CLINICAL LABORATORY ...
5,86930,Frozen blood prep,CLINICAL LABORATORY SERVICES,INCLUDE CPT codes for all clinical laboratory ...,The lab analyst thaws a unit of frozen plasma.,86930 Frozen blood prep CLINICAL LABORATORY SE...
6,86931,Frozen blood thaw,CLINICAL LABORATORY SERVICES,INCLUDE CPT codes for all clinical laboratory ...,The lab analyst thaws a unit of frozen blood.,86931 Frozen blood thaw CLINICAL LABORATORY SE...
7,86932,Frozen blood freeze/thaw,CLINICAL LABORATORY SERVICES,INCLUDE CPT codes for all clinical laboratory ...,"The lab analyst prepares, freezes, and thaws a...",86932 Frozen blood freeze/thaw CLINICAL LABORA...
8,86945,Blood product/irradiation,CLINICAL LABORATORY SERVICES,INCLUDE CPT codes for all clinical laboratory ...,The lab analyst performs the steps to treat bl...,86945 Blood product/irradiation CLINICAL LABOR...
9,86950,Leukacyte transfusion,CLINICAL LABORATORY SERVICES,INCLUDE CPT codes for all clinical laboratory ...,The lab personnel separate leukocytes from don...,86950 Leukacyte transfusion CLINICAL LABORATOR...


Total rows loaded for embeddings: 1321


In [55]:
# model can be change here as needed
embed_model_name = 'all-MiniLM-L6-v2'
reranker_model_name = 'cross-encoder/ms-marco-MiniLM-L-6-v2'

embed_model = SentenceTransformer(embed_model_name, device=device)
reranker = CrossEncoder(reranker_model_name, device=device)

print("Loaded:", embed_model_name, "and", reranker_model_name)


Loaded: all-MiniLM-L6-v2 and cross-encoder/ms-marco-MiniLM-L-6-v2


In [56]:
# Encode corpus (run once)
print("Encoding corpus...")
embeddings = embed_model.encode(summaries, batch_size=64, show_progress_bar=True, convert_to_numpy=True)

# Normalize for cosine (we'll use inner product on normalized vectors)
faiss.normalize_L2(embeddings)

dim = embeddings.shape[1]
index = faiss.IndexFlatIP(dim)
index.add(embeddings)
print("FAISS index built. ntotal =", index.ntotal)

# Save for later reuse (optional)
faiss.write_index(index, "mergedtext_faiss.index")
# np.save("merged_ids.npy", np.array(ids))
np.save("merged_summaries.npy", np.array(summaries))


Encoding corpus...


Batches:   0%|          | 0/21 [00:00<?, ?it/s]

FAISS index built. ntotal = 1321


In [57]:
import torch
from typing import List, Tuple

def search_and_rerank(
    query: str,
    top_k: int = 200,
    rerank_k: int = 50,
    reranker_batch: int = 32
) -> List[Tuple[str, str, float]]:
    """
    Returns list of tuples: (cpt_id, merged_text, score) sorted descending by score.
    Scores are mapped to the range [0, 1] using sigmoid on cross-encoder outputs.
    """
    # 1) embed query
    q_emb = embed_model.encode([query], convert_to_numpy=True)
    faiss.normalize_L2(q_emb)

    # 2) retrieve top_k from FAISS
    D, I = index.search(q_emb, top_k)   # shapes (1, top_k)
    idxs = I[0].tolist()
    candidates = [summaries[i] for i in idxs]      # merged_text only
    candidate_ids = [ids[i] for i in idxs]        # CPT_Code ids (if you removed ids, ignore this)

    # 3) prepare pairs for CrossEncoder
    pairs = [[query, cand] for cand in candidates]

    # 4) batch predict with reranker and apply sigmoid to normalize to [0,1]
    scores = []
    for start in range(0, len(pairs), reranker_batch):
        batch_pairs = pairs[start:start+reranker_batch]
        batch_scores = reranker.predict(batch_pairs)  # logits or raw scores
        # convert to tensor, apply sigmoid, convert back to python floats
        batch_scores = torch.tensor(batch_scores, dtype=torch.float32)
        batch_probs = torch.sigmoid(batch_scores).tolist()
        scores.extend(batch_probs)

    # 5) combine & sort by probability descending
    combined = list(zip(candidate_ids, candidates, [float(s) for s in scores]))
    combined.sort(key=lambda x: x[2], reverse=True)

    return combined[:rerank_k]


In [62]:
query = "Blood volume CLINICAL LABORATORY"
results = search_and_rerank(query, top_k=20, rerank_k=5, reranker_batch=32)

print("Top results for:", query)
for cid, text, score in results:
    print(f"CPT: {cid}  |  score: {score:.4f}")
    print("snippet:", text[:250].replace("\n"," "))
    print("-" * 80)


Top results for: Blood volume CLINICAL LABORATORY
CPT: 78122  |  score: 0.9999
snippet: 78122 Blood volume CLINICAL LABORATORY SERVICES INCLUDE the following CPT and HCPCS Level 2 codes for other clinical laboratory services: The provider performs a study to measure the patient’s whole blood volume, including red blood cell and plasma v
--------------------------------------------------------------------------------
CPT: 78110  |  score: 0.9990
snippet: 78110 Plasma volume single CLINICAL LABORATORY SERVICES INCLUDE the following CPT and HCPCS Level 2 codes for other clinical laboratory services: In this procedure, the provider determines the plasma volume, a blood component, by introducing a known 
--------------------------------------------------------------------------------
CPT: 78111  |  score: 0.9989
snippet: 78111 Plasma volume multiple CLINICAL LABORATORY SERVICES INCLUDE the following CPT and HCPCS Level 2 codes for other clinical laboratory services: In this procedure, the p

In [None]:
# - If you have a GPU runtime, the reranker will be much faster. Set Runtime -> Change runtime type -> GPU.
# - For 2,000 items, IndexFlatIP is fine. For much larger corpora, consider IVF+PQ or an external vector DB.
# - top_k (retrieve) = 100–300 is a good starting point. rerank_k = 10–50 for final results.
# - If memory errors occur for reranker.predict, reduce reranker_batch.
# - If you want sub-second latency across many queries, pre-load models and keep the FAISS index in memory (as above).
# - To fine-tune reranker for CPT-style text, collect some query→relevant-summary pairs and fine-tune CrossEncoder (optional).


In [None]:

# 1) Load models
embed_model = SentenceTransformer('all-MiniLM-L6-v2')     # fast embedding
reranker = CrossEncoder('cross-encoder/ms-marco-MiniLM-L-6-v2')  # cross-encoder reranker

# 2) Prepare corpus (your CPT summaries)
summaries = [...]   # list of 2000+ summary strings
ids = [...]         # parallel list of CPT codes or row ids

# 3) Encode summaries (once)
embeddings = embed_model.encode(summaries, convert_to_numpy=True, show_progress_bar=True)

# 4) Build FAISS index (cosine)
d = embeddings.shape[1]
index = faiss.IndexFlatIP(d)                      # inner product index
faiss.normalize_L2(embeddings)
index.add(embeddings)                             # add all vectors

# Function to search and rerank
def search_and_rerank(query, top_k=100, rerank_k=50):
    q_emb = embed_model.encode(query, convert_to_numpy=True)
    faiss.normalize_L2(q_emb.reshape(1,-1))
    D, I = index.search(q_emb.reshape(1,-1), top_k)   # returns top_k indices
    candidate_idxs = I[0].tolist()
    candidates = [summaries[i] for i in candidate_idxs]
    candidate_ids = [ids[i] for i in candidate_idxs]

    # Cross-encoder scoring (batching recommended)
    pairs = [[query, c] for c in candidates]
    scores = reranker.predict(pairs, batch_size=16)   # returns score per pair

    # Sort by score desc
    ranked = sorted(zip(candidate_ids, candidates, scores), key=lambda x: x[2], reverse=True)
    return ranked[:rerank_k]   # final results

# Example
results = search_and_rerank("query about CPT for knee arthroscopy", top_k=200, rerank_k=20)
for cid, txt, score in results:
    print(cid, score, txt[:120])
