In [2]:
from google.colab import auth
auth.authenticate_user()


In [3]:
from google.cloud import bigquery

project_id = "ai-in-medicine-unihamburg"
client = bigquery.Client(project=project_id)

query = """
SELECT
    n.subject_id,
    n.hadm_id,
    ANY_VALUE(n.text) AS note_text,
    -- Aggregating all codes for this specific admission
    STRING_AGG(dx.icd_code, ', ') AS icd_codes,
    -- Aggregating all descriptions for this specific admission
    STRING_AGG(d.long_title, ' | ') AS icd_descriptions
FROM `physionet-data.mimiciv_note.discharge` AS n
JOIN `physionet-data.mimiciv_3_1_hosp.diagnoses_icd` AS dx
    ON n.hadm_id = dx.hadm_id
JOIN `physionet-data.mimiciv_3_1_hosp.d_icd_diagnoses` AS d
    ON dx.icd_code = d.icd_code
WHERE dx.icd_version = 10
GROUP BY n.subject_id, n.hadm_id
-- This magic line ensures only 1 admission (hadm_id) per patient (subject_id)
QUALIFY ROW_NUMBER() OVER (PARTITION BY n.subject_id ORDER BY n.hadm_id ASC) = 1
LIMIT 10000
"""

df = client.query(query).to_dataframe()

print(df.head())
print(df.tail())

   subject_id   hadm_id                                          note_text  \
0    10003502  29011269   \nName:  ___                 Unit No:   ___\n...   
1    10015931  22130791   \nName:  ___                     Unit No:   _...   
2    10045395  29383457   \nName:  ___                  Unit No:   ___\...   
3    10046234  21162300   \nName:  ___                  Unit No:   ___\...   
4    10048986  23742207   \nName:  ___                     Unit No:   _...   

                                           icd_codes  \
0  R001, I5033, F0391, I2582, E871, Z9181, I2510,...   
1  K921, I214, N170, J9601, R571, N184, G9340, D6...   
2        I771, M79A22, M79A21, G43909, J45909, G8918   
3  I97120, I472, E1122, I130, D696, Y838, Y92530,...   
4  K2211, D649, I351, K449, K219, N400, G8929, M4...   

                                    icd_descriptions  
0  Bradycardia, unspecified | Acute on chronic di...  
1  Melena | Non-ST elevation (NSTEMI) myocardial ...  
2  Stricture of artery | Nont

In [4]:
import pandas as pd
from collections import Counter


# parse ICD codes
def parse_icd_list(icd_str):
    """Convert a string like into a list of strings"""
    icd_str = icd_str.strip("[]")  # remove brackets
    codes = [code.strip() for code in icd_str.split(",")]
    return codes

df['icd_code_list'] = df['icd_codes'].apply(parse_icd_list)

# basic statistics
total_samples = len(df)
unique_patients = df['subject_id'].nunique()
note_lengths = df['note_text'].apply(lambda x: len(str(x).split()))
avg_note_len = note_lengths.mean()
min_note_len = note_lengths.min()
max_note_len = note_lengths.max()

# flatten all ICD codes and compute frequency
all_icd_codes = [code for codes in df['icd_code_list'] for code in codes]
icd_counts = Counter(all_icd_codes)
total_unique_icd = len(icd_counts)
top_10_icd = icd_counts.most_common(10)


print(f"Total samples: {total_samples}")
print(f"Total unique patients: {unique_patients}")
print(f"Average note length (words): {avg_note_len:.2f} (min={min_note_len}, max={max_note_len})")
print(f"Total unique ICD codes: {total_unique_icd}")
print(f"Top 10 most frequent ICD codes: {top_10_icd}")

# count how many notes each patient has
notes_per_patient = df.groupby('subject_id').size()



Total samples: 10000
Total unique patients: 10000
Average note length (words): 1647.43 (min=154, max=7120)
Total unique ICD codes: 7335
Top 10 most frequent ICD codes: [('I10', 3772), ('E785', 3481), ('Z87891', 2653), ('K219', 2195), ('I2510', 1609), ('F329', 1596), ('F419', 1441), ('N179', 1285), ('E119', 1143), ('E039', 1139)]


In [None]:
import re
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MultiLabelBinarizer
from transformers import AutoTokenizer

# text Cleaning Function ---
def clean_text(text):
    """
    Clean discharge summary text for LLM processing.
    """
    # Remove placeholders
    text = re.sub(r'___+', '', text)
    # Remove newlines and collapse multiple spaces
    text = re.sub(r'\n', ' ', text)
    text = re.sub(r'\s+', ' ', text)
    return text.strip().lower()

# tokenization & Truncation
# i use a standard tokenizer to measure length and truncate to fit LLM context windows.

tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-MiniLM-L6-v2")

def preprocess_and_truncate(text, max_length=2048):
    cleaned = clean_text(text)
    # truncate based on tokens, not characters, to ensure it fits the model
    tokens = tokenizer.encode(cleaned, truncation=True, max_length=max_length, add_special_tokens=True)
    return tokenizer.decode(tokens, skip_special_tokens=True)

print("Preprocessing and truncating notes...")
df['note_text_final'] = df['note_text'].apply(lambda x: preprocess_and_truncate(x))

# multi-Label Encoding (For DNN)
# Each note has multiple ICD-10 codes. i convert these into a binary matrix.
mlb = MultiLabelBinarizer()
y = mlb.fit_transform(df['icd_code_list'])

# Data Splitting (80/10/10 Strategy)

X_train, X_temp, y_train, y_temp = train_test_split(
    df['note_text_final'], y, test_size=0.20, random_state=42
)

# Split the remaining 20% equally into Validation and Test
X_val, X_test, y_val, y_test = train_test_split(
    X_temp, y_temp, test_size=0.50, random_state=42
)

print("-" * 30)
print(f"Preprocessing Complete.")
print(f"Train set: {len(X_train)} samples")
print(f"Val set:   {len(X_val)} samples")
print(f"Test set:  {len(X_test)} samples")
print(f"Total unique ICD-10 labels: {len(mlb.classes_)}")
print("-" * 30)

In [6]:
# efficient way to get both IDs and Mask
def tokenize_function(text):
    return tokenizer(
        text,
        truncation=True,
        max_length=512,
        padding='max_length', # useful for the DNN later
        return_tensors=None
    )

# Apply once to get a dictionary
tokenized_data = df['note_text_final'].apply(tokenize_function)

# Split the dictionary into two columns
df['input_ids'] = tokenized_data.apply(lambda x: x['input_ids'])
df['attention_mask'] = tokenized_data.apply(lambda x: x['attention_mask'])

In [None]:
!pip install -U langchain langchain-community chromadb sentence-transformers



In [8]:
!pip install -qU langchain-ollama

In [9]:
import os
import json
import warnings
from dotenv import load_dotenv
from langchain_community.llms import Ollama
from langchain_community.embeddings import OllamaEmbeddings

# Setup
warnings.filterwarnings("ignore", category=DeprecationWarning)
load_dotenv('/content/llama.env') # Path in Colab


api_url = "https://dev.chat.cosy.bio/ollama"
llm_name = "llama3.1:8b"
embedding_model = "nomic-embed-text:latest"

api_key = os.getenv("COSYBIO_API_KEY")
headers = {"Authorization": f"Bearer {api_key}"}

# initialize (Using the stable community classes)
chat_llm = Ollama(
    base_url=api_url,
    model=llm_name,
    temperature=0.0,
    headers=headers,
    format="json"
)

ollama_embeddings = OllamaEmbeddings(
    base_url=api_url,
    model=embedding_model,
    headers=headers
)

print(f"Colab Setup Complete!")

Colab Setup Complete!


In [10]:
import requests

url = "https://dev.chat.cosy.bio/ollama/api/tags"
headers = {
    "Authorization": f"Bearer {api_key}"
}

r = requests.get(url, headers=headers)
print(r.status_code)
print(r.text)


200
{"models":[{"name":"nomic-embed-text-v2-moe:latest","model":"nomic-embed-text-v2-moe:latest","modified_at":"2025-12-16T09:50:31.214529493Z","size":957680763,"digest":"ff9c2f10ef5e3722623a1b396e1e04efc27a93112c83e9b7b7b9ca1d05620965","details":{"parent_model":"","format":"gguf","family":"nomic-bert-moe","families":["nomic-bert-moe"],"parameter_size":"475.29M","quantization_level":"F16"},"connection_type":"local","urls":[0]},{"name":"nemotron-3-nano:latest","model":"nemotron-3-nano:latest","modified_at":"2025-12-16T09:48:18.540581906Z","size":24271934866,"digest":"b725f11174073334edd0c2ff396b8d4e66d7dab22a5a63717ccad5a08a270cf1","details":{"parent_model":"","format":"gguf","family":"nemotron_h_moe","families":["nemotron_h_moe"],"parameter_size":"31.6B","quantization_level":"Q4_K_M"},"connection_type":"local","urls":[0]},{"name":"nemotron-3-nano:30b","model":"nemotron-3-nano:30b","modified_at":"2025-12-16T09:48:08.170742323Z","size":24271934866,"digest":"b725f11174073334edd0c2ff396b8d

In [11]:
print("--- Colab Sanity Check ---")
try:
    # Test LLM
    prompt = "What is an iPhone? Respond with a JSON object: {'definition': 'short string'}"
    res = chat_llm.invoke(prompt)
    print(" LLM Response:", json.loads(res))

    # Test Embeddings
    vec = ollama_embeddings.embed_query("Colab test")
    print(f" Embeddings Working! (Length: {len(vec)})")
except Exception as e:
    print(f"Connection Failed: {e}")

--- Colab Sanity Check ---
 LLM Response: {'definition': 'A line of smartphones designed, developed, and marketed by Apple Inc.'}
 Embeddings Working! (Length: 768)


In [12]:


from langchain_ollama import OllamaEmbeddings

# Use client_kwargs to pass headers to the underlying httpx client
ollama_embeddings = OllamaEmbeddings(
    base_url="https://dev.chat.cosy.bio/ollama",
    model="nomic-embed-text:latest",
    client_kwargs={
        "headers": {
            "Authorization": f"Bearer {api_key}"
        }
    }
)

# Test the connection
try:
    vector = ollama_embeddings.embed_query("Test clinical note")
    print(f"Success! Vector length: {len(vector)}")
except Exception as e:
    print(f"Connection failed: {e}")

Success! Vector length: 768


In [13]:
from langchain_ollama import OllamaEmbeddings

# overwrite the old embeddings variable with the working one
embeddings = ollama_embeddings

In [14]:

# embedding a clinical note for later DNN
text = "Patient presenting with acute myocardial infarction..."
query_vector = embeddings.embed_query(text)
print(f"Vector Dimension: {len(query_vector)}") # Should be 768

Vector Dimension: 768


In [15]:
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

# 1. Define three sentences: two similar, one different
clinical_notes = [
    "Patient presenting with acute myocardial infarction.", # Target
    "Emergency admission for a heart attack.",            # Semantically similar
    "Patient scheduled for routine skin checkup."          # Semantically different
]

# 2. Generate embeddings
vectors = embeddings.embed_documents(clinical_notes)

# 3. Calculate Cosine Similarity
# Similarity between 0 and 1 (1 = identical)
sim_heart_attack = cosine_similarity([vectors[0]], [vectors[1]])[0][0]
sim_skin_check = cosine_similarity([vectors[0]], [vectors[2]])[0][0]

print(f"--- Embedding Sanity Check ---")
print(f"Similarity (Infarction vs Heart Attack): {sim_heart_attack:.4f}")
print(f"Similarity (Infarction vs Skin Check):   {sim_skin_check:.4f}")

# 4. Logic Check
if sim_heart_attack > sim_skin_check:
    print("\n PASS: The model understands that 'Infarction' is closer to 'Heart Attack'.")
else:
    print("\n FAIL: The model is not distinguishing clinical meanings correctly.")

--- Embedding Sanity Check ---
Similarity (Infarction vs Heart Attack): 0.7115
Similarity (Infarction vs Skin Check):   0.4698

 PASS: The model understands that 'Infarction' is closer to 'Heart Attack'.


In [16]:
!pip install -qU langchain-chroma

In [17]:
import pandas as pd
import re

# 1. Extract the pairs from aggregated BigQuery strings
knowledge_base = []
for _, row in df.iterrows():
    codes = [c.strip() for c in str(row['icd_codes']).split(',')]
    titles = [t.strip() for t in str(row['icd_descriptions']).split('|')]
    for code, title in zip(codes, titles):
        knowledge_base.append({'icd_code': code, 'icd_description': title})

# 2. Create the lookup table and remove duplicates
df_icd_lookup = pd.DataFrame(knowledge_base).drop_duplicates(subset=['icd_code'])

# 3. CRITICAL: Clean the "Junk" (brackets and quotes)
df_icd_lookup['icd_code'] = (
    df_icd_lookup['icd_code']
    .astype(str)
    .str.replace(r"[\[\]']", "", regex=True)
    .str.strip()
)

print(f"Cleaned {len(df_icd_lookup)} unique codes. Ready for indexing.")
df_icd_lookup.head()

Cleaned 7335 unique codes. Ready for indexing.


Unnamed: 0,icd_code,icd_description
0,R001,"Bradycardia, unspecified"
1,I5033,Acute on chronic diastolic (congestive) heart ...
2,F0391,"Unspecified dementia, unspecified severity, wi..."
3,I2582,Chronic total occlusion of coronary artery
4,E871,Hypo-osmolality and hyponatremia


In [18]:
from langchain_core.documents import Document
from langchain_chroma import Chroma
import shutil
import os


# Convert to LangChain Documents
icd_documents_list = [
    Document(page_content=row['icd_description'], metadata={"icd_code": row['icd_code']})
    for _, row in df_icd_lookup.iterrows()
]

# Build the store
vector_db = Chroma.from_documents(
    documents=icd_documents_list,
    embedding=embeddings,
    persist_directory= '/content/chroma_db_v2'
)

print("Success! Your Vector Store is now built and saved.")

Success! Your Vector Store is now built and saved.


In [19]:
# 1. Define a clinical test query
test_query = "The patient has a history of high blood pressure and chronic kidney disease."

# 2. Search the vector store
# k=3 retrieves the top 3 most similar matches
results = vector_db.similarity_search_with_score(test_query, k=3)

print(f"--- Sanity Check Results for: '{test_query}' ---\n")

for i, (doc, score) in enumerate(results):
    # Retrieve metadata we saved earlier
    icd_code = doc.metadata.get('icd_code', 'MISSING')
    description = doc.page_content

    # In Chroma, lower scores = higher similarity (it's measuring distance)
    print(f"Rank {i+1} (Distance: {score:.4f})")
    print(f"  ICD-10 Code: {icd_code}")
    print(f"  Description: {description}")
    print("-" * 50)

# 3. Final logic check
if results[0][1] < 1.0 and results[0][0].metadata.get('icd_code') != 'MISSING':
    print("\n PASS: The vector store is returning relevant codes with metadata.")
else:
    print("\n FAIL: Check if metadata keys match or if the embeddings are empty.")

--- Sanity Check Results for: 'The patient has a history of high blood pressure and chronic kidney disease.' ---

Rank 1 (Distance: 0.4874)
  ICD-10 Code: D631
  Description: Anemia in chronic kidney disease
--------------------------------------------------
Rank 2 (Distance: 0.4937)
  ICD-10 Code: I130
  Description: Hypertensive heart and chronic kidney disease with heart failure and stage 1 through stage 4 chronic kidney disease, or unspecified chronic kidney disease
--------------------------------------------------
Rank 3 (Distance: 0.5110)
  ICD-10 Code: I129
  Description: Hypertensive chronic kidney disease with stage 1 through stage 4 chronic kidney disease, or unspecified chronic kidney disease
--------------------------------------------------

 PASS: The vector store is returning relevant codes with metadata.


In [20]:
def medical_retriever(query, k=10):
    """
    Finds the most relevant ICD-10 codes and removes duplicates
    to provide the cleanest context for the LLM.
    """
    # 1. Search the vector store
    results = vector_db.similarity_search(query, k=k)

    # 2. De-duplicate: Ensure the LLM doesn't see the same code twice
    seen_codes = set()
    context_chunks = []

    for doc in results:
        code = doc.metadata.get('icd_code')
        description = doc.page_content

        if code not in seen_codes:
            context_chunks.append(f"Code {code}: {description}")
            seen_codes.add(code)

    # 3. Join into a single string for the prompt
    # We take the top 5 unique ones to keep the prompt concise
    formatted_context = "\n".join(context_chunks[:5])
    return formatted_context

# Test the engine
clinical_snippet = "Patient has chronic renal failure and persistent hypertension."
context_for_llm = medical_retriever(clinical_snippet)

print("--- Context Ready for LLM ---")
print(context_for_llm)

--- Context Ready for LLM ---
Code N990: Postprocedural (acute) (chronic) kidney failure
Code D631: Anemia in chronic kidney disease
Code N19: Unspecified kidney failure
Code I130: Hypertensive heart and chronic kidney disease with heart failure and stage 1 through stage 4 chronic kidney disease, or unspecified chronic kidney disease
Code N179: Acute kidney failure, unspecified


In [21]:
# 1. Pick a sample note (Index 100 is usually a good, mid-sized note)
sample_index = 100
sample_note = df.loc[sample_index, "note_text"]

# 2. Run my cleaned retriever
retrieved_context = retrieved_context = medical_retriever(sample_note[:2000], k=10)

# 3. Print the results for manual inspection
print(f"--- [SANITY CHECK] Testing Note Index: {sample_index} ---")
print("\n--- RETRIEVED ICD-10 CANDIDATES ---")
print(retrieved_context)

print("\n" + "="*50)
print("--- ACTUAL CLINICAL NOTE (First 2000 chars) ---")
print(sample_note[:2000])

--- [SANITY CHECK] Testing Note Index: 100 ---

--- RETRIEVED ICD-10 CANDIDATES ---
Code R5082: Postprocedural fever
Code M041: Periodic fever syndromes
Code R502: Drug induced fever
Code I69992: Facial weakness following unspecified cerebrovascular disease
Code J09X1: Influenza due to identified novel influenza A virus with pneumonia

--- ACTUAL CLINICAL NOTE (First 2000 chars) ---
 
Name:  ___                   Unit No:   ___
 
Admission Date:  ___              Discharge Date:   ___
 
Date of Birth:  ___             Sex:   F
 
Service: MEDICINE
 
Allergies: 
heparin / Coumadin
 
Attending: ___
 
Chief Complaint:
fever
 
Major Surgical or Invasive Procedure:
___ PICC line placement
 
History of Present Illness:
Ms. ___ is a ___ lady with glioblastoma who is
admitted from the ED with fever and found to have gram negative
bacteremia.

Patient reports that she awoke with chills early on the morning
of admission. The chills progressed and when she attempted to 
get
out of her recliner, sh

In [None]:
# from google.colab import drive
# import shutil

# drive.mount('/content/drive')

# # Copy the database to your permanent Google Drive
# shutil.copytree('/content/chroma_db_v2', '/content/drive/MyDrive/my_icd10_db_v3.0')
# print("Vector DB safely backed up to Google Drive!")

# Zero-Shot LLM Classification

To implement the Zero-Shot LLM Classification, i will prompt Llama  to act as a coder using only its internal knowledge.



In [23]:
def run_zero_shot_json(note_text):
    # Constructing the prompt for JSON output
    prompt = f"""
    You are an expert medical coder. Analyze the clinical note provided.

    ### Task:
    1. Identify the primary and secondary diagnoses.
    2. Map them to the most accurate ICD-10 codes.
    3. Use chain-of-thought reasoning to explain your choices.

    ### Clinical Note:
    {note_text[:2000]}

    ### Output Requirements (JSON only):
    Return a JSON object with exactly these two keys:
    "reasoning": "your step-by-step clinical analysis",
    "icd_10_codes": ["CODE1", "CODE2", "CODE3"]
    """

    # Using your initialized chat_llm
    response = chat_llm.invoke(prompt)

    # Since format="json" is enabled, we can load it directly
    return json.loads(response)

# Run the test
print("--- RUNNING ZERO-SHOT JSON CLASSIFICATION ---")
zero_shot_json_output = run_zero_shot_json(sample_note)

print(json.dumps(zero_shot_json_output, indent=4))

--- RUNNING ZERO-SHOT JSON CLASSIFICATION ---
{
    "reasoning": "The patient, a  female with glioblastoma, presents with fever, weakness, and bacteremia. The patient's symptoms and lab results indicate a severe infection. The patient's history of glioblastoma and the presence of gram-negative bacteremia suggest that the infection is likely related to the cancer or its treatment. The patient's fever, chills, and weakness are consistent with sepsis, a life-threatening condition that requires immediate attention. The patient's lab results, including the elevated white blood cell count and lactate level, support this diagnosis. The patient's history of recent PICC line placement and the presence of bacteremia in the blood cultures suggest that the infection is likely related to the PICC line. Therefore, the primary diagnosis is sepsis due to a central line-associated bloodstream infection (CLABSI). The secondary diagnosis is glioblastoma, the underlying cancer that may be contributing to 

# The RAG Classification Script

This function will:

Retrieve: Use your medical_retriever to get the top candidates.

Contextualize: Insert those descriptions into a structured prompt.

Classify: Let Llama 3.1 act as the judge to decide which candidates actually apply.

In [24]:
def run_rag_classification(note_text):
    # 1. Get the 'Shortlist' of codes from your Vector Store
    # We use the first 2000 chars of the note to find the codes
    retrieved_context = medical_retriever(note_text[:2000], k=10)

    # 2. Build the RAG Prompt
    prompt = f"""
    You are an expert medical coder. You are provided with a clinical note and a list of POTENTIAL ICD-10 codes retrieved from a knowledge base.

    ### TASK:
    - Review the Clinical Note.
    - Review the Potential ICD-10 Codes (Candidate List).
    - Determine which codes from the Candidate List are supported by the Clinical Note.
    - If a code is not in the candidate list but you are 100% certain it applies based on the note, you may add it.

    ### CANDIDATE LIST (from Knowledge Base):
    {retrieved_context}

    ### CLINICAL NOTE:
    {note_text[:2000]}

    ### Output Requirements (JSON only):
    Return a JSON object with:
    "reasoning": "Explain why you chose specific codes from the candidate list based on clinical evidence.",
    "predicted_codes": ["CODE1", "CODE2"]
    """

    # 3. Invoke your Chat LLM
    response = chat_llm.invoke(prompt)
    return json.loads(response)

# Run the RAG test
print("--- RUNNING RAG CLASSIFICATION ---")
rag_output = run_rag_classification(sample_note)

print(json.dumps(rag_output, indent=4))

--- RUNNING RAG CLASSIFICATION ---
{
    "reasoning": "The patient presents with fever, which is a key symptom. The clinical note mentions that the patient has gram-negative bacteremia, which is a clear indication of an infection. The patient also reports weakness, which could be related to the infection or the underlying glioblastoma. The presence of a PICC line placement is also noted, which could be a potential source of infection. Based on these findings, the following codes are supported by the clinical note:",
    "predicted_codes": [
        "I69992",
        "R5082",
        "M041"
    ]
}
