In [1]:
import os
import json
import pickle
import logging
import pandas as pd
import time
import torch
from dotenv import load_dotenv, find_dotenv
from transformers import AutoTokenizer, AutoModel
from langchain_community.retrievers import BM25Retriever
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.docstore.document import Document
import chromadb
from chromadb.config import Settings
from typing import List, Dict, Any, Tuple

In [3]:
from utils import count_llama_tokens

In [4]:
logger = logging.getLogger(__name__)

# ------------------
# Environment Setup
# ------------------
def setup_environment() -> None:
    env_path = find_dotenv()
    if not env_path:
        env_path = "/home/yl3427/.env"
    if not load_dotenv(env_path):
        raise Exception("Failed to load .env file")
    os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
    os.environ["CUDA_VISIBLE_DEVICES"] = "0, 1"
    logger.info("Environment setup complete for retrieval.")


# --------------------
# Data Loading
# --------------------
def load_cases(json_path: str) -> Dict[str, Any]:
    with open(json_path, "r") as f:
        cases = json.load(f)
    logger.info(f"Loaded {len(cases)} cases from {json_path}")
    return cases


# --------------------
# Text Splitting
# --------------------
def create_documents(
    cases: Dict[str, Any],
    tokenizer,
    max_length: int = 512
) -> List[Document]:
    """
    Creates BM25-compatible Documents. We do not necessarily need to chunk again
    if we only want each entire 'before_diagnosis' text as a single Document for BM25.
    But often you'll chunk for consistency with embedding. 
    """
    text_splitter = RecursiveCharacterTextSplitter.from_huggingface_tokenizer(
        separators=[
            "\n\n", "\n", r'(?<=[.?"\s])\s+', " ", ".", ","
        ],
        tokenizer=tokenizer,
        chunk_size=max_length,
        chunk_overlap=20,
        add_start_index=True,
        strip_whitespace=True,
        is_separator_regex=True
    )

    docs = []
    unique_texts = set()

    for hadm_id, data in cases.items():
        full_text = data["before_diagnosis"]
        split_docs = text_splitter.create_documents(
            texts=[full_text],
            metadatas=[{
                "hadm_id": hadm_id,
                "full_text": full_text,
                "diagnosis": data["after_diagnosis"]
            }]
        )
        for d in split_docs:
            if d.page_content not in unique_texts:
                unique_texts.add(d.page_content)
                docs.append(d)

    logger.info(f"Created {len(docs)} doc chunks for BM25 retrieval.")
    return docs

In [5]:
os.makedirs("log", exist_ok=True)

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    handlers=[
        logging.FileHandler('log/rag2.log', mode='w'),  # 파일로 저장
        logging.StreamHandler()  # 콘솔에 출력
    ]
)

setup_environment()

# 1) Paths
json_path = "/secure/shared_data/SOAP/MIMIC/full_cases_base.json"
chroma_db_path = "/secure/shared_data/rag_embedding_model/chroma_db"
model_cache_dir = "/secure/shared_data/rag_embedding_model"
model_name = "nvidia/NV-Embed-v2"

# 2) Load Cases
cases = load_cases(json_path)

# 3) Create Documents for BM25
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    trust_remote_code=True,
    cache_dir=model_cache_dir
)
# docs_processed = create_documents(cases, tokenizer, max_length=512)
chunked_doc_file = "chunked_documents.pkl"
if os.path.exists(chunked_doc_file):
    with open(chunked_doc_file, "rb") as f:
        docs_processed = pickle.load(f)
    logger.info(f"Loaded pre-chunked documents from {chunked_doc_file}")
else:
    docs_processed = create_documents(cases, tokenizer, max_length=512)
    with open(chunked_doc_file, "wb") as f:
        pickle.dump(docs_processed, f)
    logger.info(f"Created and saved {len(docs_processed)} chunked documents to {chunked_doc_file}")

# 4) Connect to existing Chroma DB
client = chromadb.PersistentClient(
    path=chroma_db_path,
    settings=Settings(allow_reset=True)
)
mimic_collection = client.get_or_create_collection(
    name="mimic_notes_full",
    metadata={"hnsw:space": "cosine"}
)

# Load embedding model for semantic retrieval
embedding_model = AutoModel.from_pretrained(
    model_name,
    trust_remote_code=True,
    cache_dir=model_cache_dir,
    device_map="auto"
)

2025-04-29 11:05:13 - INFO - Environment setup complete for retrieval.
2025-04-29 11:05:15 - INFO - Loaded 41174 cases from /secure/shared_data/SOAP/MIMIC/full_cases_base.json
2025-04-29 11:05:19 - INFO - Loaded pre-chunked documents from chunked_documents.pkl
2025-04-29 11:05:19 - INFO - Anonymized telemetry enabled. See                     https://docs.trychroma.com/telemetry for more information.
2025-04-29 11:05:19 - INFO - PyTorch version 2.6.0+cu126 available.


Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

In [None]:
# # --------------------
# # Hybrid Query (Union)
# # --------------------
# def hybrid_query(
#     cases: Dict[str, Any],
#     docs: List[Document],
#     collection,
#     embedding_model,
#     query_text: str,
#     query_prefix: str,
#     max_length: int = 512,
#     semantic_k: int = 5,
#     bm25_k: int = 5,
#     bm25_weight: float = 0.5
# ) -> List[str]:

#     logger.info("Starting hybrid retrieval...")

#     # --- Semantic Retrieval ---
#     query_embedding = embedding_model.encode(
#         [query_text],
#         instruction=query_prefix,
#         max_length=max_length
#     ).cpu().numpy().tolist()

#     semantic_results = collection.query(
#         query_embeddings=query_embedding,
#         n_results=semantic_k
#     )

  
#     results_semantic_ids = [full_id.split("_")[0] for full_id in semantic_results["ids"][0]]
#     results_semantic_distances = semantic_results["distances"][0]
#     semantic_id_set = set(results_semantic_ids)
#     logger.info(f"Semantic top-{semantic_k} hadm_ids: {results_semantic_ids}")

#     # --- BM25 Retrieval ---
#     bm25_retriever = BM25Retriever.from_documents(docs, k=bm25_k)
#     bm25_results = bm25_retriever.get_relevant_documents(query_text)
#     results_bm25_ids = [doc.metadata["hadm_id"] for doc in bm25_results]
#     bm25_id_set = set(results_bm25_ids)
#     logger.info(f"BM25 top-{bm25_k} hadm_ids: {results_bm25_ids}")

#     # --------------------------
#     # Union of both sets
#     # --------------------------
#     combined_ids = semantic_id_set | bm25_id_set
#     logger.info(f"Union of semantic & BM25 -> total {len(combined_ids)} unique hadm_ids")

#     # Weighted scoring
#     semantic_weight = 1.0 - bm25_weight
#     ids_to_score = {}

#     # Build a ranking score
#     for hadm_id in combined_ids:
#         score = 0.0

#         if hadm_id in semantic_id_set:
#             idx_sem = results_semantic_ids.index(hadm_id)
#             score += semantic_weight * (1 / (idx_sem + 1))

#         if hadm_id in bm25_id_set:
#             idx_bm25 = results_bm25_ids.index(hadm_id)
#             score += bm25_weight * (1 / (idx_bm25 + 1))

#         ids_to_score[hadm_id] = score

#     # Sort by combined score
#     sorted_ids = sorted(ids_to_score.keys(), key=lambda x: ids_to_score[x], reverse=True)

#     # Build final result list with dictionaries
#     retrieved_docs = []
#     for doc_id in sorted_ids:
#         # Determine which retriever(s) provided this doc (logging only)
#         source_list = []
#         if doc_id in semantic_id_set:
#             source_list.append("Semantic")
#         if doc_id in bm25_id_set:
#             source_list.append("BM25")

#         logger.info(
#             f"Doc hadm_id={doc_id} => from {', '.join(source_list)}; combined_score={ids_to_score[doc_id]:.4f}"
#         )

#         before = cases[doc_id]["before_diagnosis"]
#         after = cases[doc_id]["after_diagnosis"]
#         final_text = f"{before}\nDischarge Diagnosis: {after}"

#         retrieved_docs.append({
#             "text": final_text,
#             "diagnosis": after,
#             "score": ids_to_score[doc_id]
#         })

#     logger.info(f"Retrieved {len(retrieved_docs)} docs total.")
#     return retrieved_docs

In [11]:
def semantic_query( # Renamed from hybrid_query
    cases: Dict[str, Any],
    collection,
    embedding_model,
    query_text: str,
    query_prefix: str,
    max_length: int = 512,
    semantic_k: int = 5,
    # Removed parameters: docs, bm25_k, bm25_weight
) -> List[Dict[str, Any]]: # Return type remains List of Dictionaries

    logger.info(f"Starting semantic retrieval for top {semantic_k} results...")
    start_semantic_time = time.time()

    # --- Semantic Retrieval ---
    query_embedding = embedding_model.encode(
        [query_text],
        instruction=query_prefix,
        max_length=max_length
    ).cpu().numpy().tolist()

    # Query ChromaDB
    semantic_results = collection.query(
        query_embeddings=query_embedding,
        n_results=semantic_k,
        include=['metadatas', 'documents', 'distances'] # Ensure distances are included
    )
    end_semantic_time = time.time()
    logger.info(f"ChromaDB query took {end_semantic_time - start_semantic_time:.2f} seconds.")

    # --- Process Results ---
    retrieved_docs = []
    if semantic_results and semantic_results["ids"] and len(semantic_results["ids"][0]) > 0:
        # Get the IDs, distances, and potentially metadata from the results
        result_ids = semantic_results["ids"][0]
        result_distances = semantic_results["distances"][0]
        # result_metadatas = semantic_results["metadatas"][0] # Optional, if needed

        logger.info(f"Retrieved {len(result_ids)} results from semantic search.")

        # Combine IDs and distances
        results_with_distances: List[Tuple[str, float]] = list(zip(result_ids, result_distances))

        # Sort by distance (ascending - closer results have smaller distances)
        # Chroma usually returns sorted results, but explicit sorting doesn't hurt
        results_with_distances.sort(key=lambda item: item[1])

        processed_hadm_ids = set() # Keep track to avoid duplicates if chunks from the same note are returned

        for full_id, distance in results_with_distances:
            try:
                # Extract hadm_id from the full document ID (e.g., "12345_10")
                hadm_id = full_id.split("_")[0]

                # Avoid adding the same original note multiple times if multiple chunks were retrieved
                if hadm_id in processed_hadm_ids:
                    continue

                # Check if hadm_id exists in the loaded cases data
                if hadm_id in cases:
                    before = cases[hadm_id]["before_diagnosis"]
                    after = cases[hadm_id]["after_diagnosis"]
                    # Combine the text; you might adjust this based on exactly what you want returned
                    final_text = f"{before}\nDischarge Diagnosis: {after}"

                    retrieved_docs.append({
                        "full_note": final_text,
                        "semantic_distance": f"{distance:.3f}", # Format distance to 3 decimal places
                        "hadm_id": hadm_id # Include hadm_id for reference
                    })
                    processed_hadm_ids.add(hadm_id)
                    logger.info(
                        f"Added Doc hadm_id={hadm_id} => distance={distance:.4f}"
                    )
                else:
                    # This handles the previous KeyError possibility more gracefully
                    logger.warning(f"Retrieved hadm_id={hadm_id} from Chroma, but not found in loaded cases data. Skipping.")

            except KeyError as e:
                 logger.error(f"Error processing result for id {full_id}: Missing key {e} in cases dictionary for hadm_id {hadm_id}. Skipping.")
            except Exception as e:
                 logger.error(f"Unexpected error processing result for id {full_id}: {e}. Skipping.")


    else:
        logger.warning("Semantic search returned no results.")


    logger.info(f"Retrieved {len(retrieved_docs)} unique documents total based on semantic search.")
    return retrieved_docs

In [36]:
# 5) Hybrid Retrieval
max_length_chunk = 0
max_length_query = 0
query_prefix = (
    "Given the following clinical note, retrieve the most similar clinical case. "
    "The clinical note is:\n\n"
)

input_df = pd.read_csv("/home/yl3427/cylab/SOAP_MA/Input/SOAP_all_problems.csv", lineterminator="\n")
for index, row in input_df.iterrows():
    query_text = f"{row['Subjective']}\n{row['Objective']}"
    logger.info(f"Number of tokens in query text: {count_llama_tokens(query_text)}")
    if count_llama_tokens(query_text) > max_length_query:
        max_length_query = count_llama_tokens(query_text)

    retrieved = semantic_query(
        cases=cases,
        # docs=docs_processed, # Removed BM25 docs
        collection=mimic_collection,
        embedding_model=embedding_model,
        query_text=query_text,
        query_prefix=query_prefix,
        max_length=512,
        semantic_k=3 # Keep semantic_k (or adjust as needed)
    )



    # 6) Print out final docs
    logger.info("---------- FINAL RETRIEVED DOCS ----------")
    if count_llama_tokens(retrieved) > max_length_chunk:
        max_length_chunk = count_llama_tokens(json.dumps(retrieved, indent=2) + query_text)

    # for idx, doc_text in enumerate(retrieved, start=1):
    #     print(f"\n--- Number of tokens in doc {idx}: {count_llama_tokens(doc_text['full_note'])} ---")
    #     if count_llama_tokens(doc_text['full_note']) > max_length_chunk:
    #         max_length_chunk = count_llama_tokens(doc_text['full_note'])

    # # --- New: Save results ---
    # json_out = "log/final_retrieved_docs_2.json"
    # with open(json_out, "w") as jf:
    #     json.dump(retrieved, jf, indent=2)
    # logger.info(f"Final docs saved as JSON to {json_out}")

    # csv_out = "log/final_retrieved_docs_2.csv"
    # pd.DataFrame(retrieved).to_csv(csv_out, index=False)
    # logger.info(f"Final docs saved as CSV to {csv_out}")

2025-04-29 11:45:38 - INFO - Number of tokens in query text: 1418
2025-04-29 11:45:38 - INFO - Starting semantic retrieval for top 3 results...
  'input_ids': torch.tensor(batch_dict.get('input_ids').to(batch_dict.get('input_ids')).long()),
  self.gen = func(*args, **kwds)
2025-04-29 11:45:39 - INFO - ChromaDB query took 0.20 seconds.
2025-04-29 11:45:39 - INFO - Retrieved 3 results from semantic search.
2025-04-29 11:45:39 - INFO - Added Doc hadm_id=186751 => distance=0.3024
2025-04-29 11:45:39 - INFO - Added Doc hadm_id=104752 => distance=0.3055
2025-04-29 11:45:39 - INFO - Added Doc hadm_id=103986 => distance=0.3094
2025-04-29 11:45:39 - INFO - Retrieved 3 unique documents total based on semantic search.
2025-04-29 11:45:39 - INFO - ---------- FINAL RETRIEVED DOCS ----------
2025-04-29 11:45:39 - INFO - Number of tokens in query text: 700
2025-04-29 11:45:39 - INFO - Starting semantic retrieval for top 3 results...
2025-04-29 11:45:39 - INFO - ChromaDB query took 0.20 seconds.
2025-

In [38]:
max_length_chunk, max_length_query

(24170, 2050)

In [16]:
retrieved
count_llama_tokens(retrieved[1])

6912

In [35]:
print(json.dumps(retrieved, indent=2))

[
  {
    "full_note": "admission date:  [**2112-9-9**]              discharge date:   [**2112-9-14**]\n\n\nservice: med\n\nallergies:\npenicillins\n\nattending:[**first name3 (lf) 4052**]\nchief complaint:\ntachycardia\n\nmajor surgical or invasive procedure:\nnone\n\nhistory of present illness:\n[**age over 90 **] yo female w/ htn, dm, dementia, recent admit to st. [**female first name (un) **]\n[**date range (1) 19038**] for pna (completed course of azithromycin [**9-1**] and 10\nday course of levoquin [**9-8**]), no known cad, presents from\ncoolridge nh today after 2 weeks of unexplained tachycardia in\n130-140s (for which her lopressor dosing had been titrated). she\nhad also developed arf over the past several weeks (cr. 1.3-->\n2.9). she denies cp, sob, abd pain, dysuria, diarrhea, fever,\nchills, nausea, vomitting though has had relatively low po\nintake.\nin ed: found to be hypotensive/febrile after 3 liters ns\nvs:  t 100.6, hr 126, bp 88/60 (map 65), rr 18-20 o2 sat 100%\n3

In [None]:
input_df = pd.read_csv("/home/yl3427/cylab/SOAP_MA/Input/SOAP_3_problems_mini.csv", lineterminator="\n")
row = input_df.iloc[5]
note  = f"{row['Subjective']}\n{row['Objective']}"
note

In [None]:
row.Summary

In [None]:
chroma_db_path = "/secure/shared_data/rag_embedding_model/chroma_db"

client = chromadb.PersistentClient(
    path=chroma_db_path,
    settings=Settings(allow_reset=True)
)

In [None]:
client.list_collections()

In [None]:
mimic_collection = client.get_or_create_collection(
    name="mimic_notes_full",
    metadata={"hnsw:space": "cosine"}
)

In [None]:
mimic_collection.count()  # Check the number of documents in the collection

In [None]:
mimic_collection.peek(1)

In [None]:
mimic_collection.peek(2)['metadatas']