new


In [3]:
!head -1  '/content/icd_11_prompt_delimited.txt'

Disorder Name:   Disorder  of intellectual development, mild   |  Disorder Code: 6A00.0   |  Disorder Symptoms: A mild   Disorder  of intellectual development is a condition originating during the developmental period characterised by significantly below average intellectual functioning and adaptive behaviour that are approximately two to three standard deviations below the mean (approximately 0.1   2.3 percentile), based on appropriately normed, individually administered standardized tests or by comparable behavioural indicators when standardized testing is unavailable. Affected persons often exhibit difficulties in the acquisition and comprehension of complex language concepts and academic skills. Most master basic self-care, domestic, and practical activities. Persons affected by a mild   Disorder  of intellectual development can generally achieve relatively independent living and employment as adults but may require appropriate support.  Disorder  of intellectual development, mild 

In [None]:
import csv
from collections import defaultdict
import numpy as np
from sentence_transformers import SentenceTransformer
import re
from collections import Counter
import math
import pandas as pd
import os

# File paths
original_file = '/content/icd_11_prompt_delimited.txt'  # Format: Disorder Name|Disorder Code|Disorder Symptoms
user_symptoms_file = '/content/enhanced_psychology_symptoms.txt'  # Format: Symptoms only (one per line)
high_similarity_file = '/content/sapbert_high_similarity_pairs.csv'  # Output: User symptoms|predicted ICD Code|predicted Symptoms|Cosine Similarity
all_similarity_file = '/content/sapbert_all_similarity_pairs.csv'  # New: All comparisons
unmatched_log = '/content/sapbert_unmatched_icd_codes.txt'
low_similarity_log = '/content/sapbert_low_similarity_pairs.csv'
original_codes_log = '/content/sapbert_original_codes.txt'

# Debug input files
for file_path in [original_file, user_symptoms_file]:
    if not os.path.exists(file_path):
        print(f"Error: File {file_path} not found.")
        exit()
    print(f"File size of {file_path}: {os.path.getsize(file_path)} bytes")
    with open(file_path, 'r', encoding='utf-8') as f:
        print(f"First 2 lines of {file_path}:")
        for i, line in enumerate(f):
            if i < 2:
                print(line.strip())
            else:
                break

# Expanded typo/synonym correction dictionary for ICD terms
typo_corrections = {
    'behaviour': 'behavior',
    'behaviours': 'behaviors',
    'adaptative': 'adaptive',
    'intelectual': 'intellectual',
    'standerd': 'standard',
    'deviation': 'deviations',
    'cognitive': 'intellectual',  # Medical synonym
    'mental': 'intellectual',
    'adaptive behaviour': 'adaptive behavior',
    'belowaverage': 'below average',
    'standardised': 'standardized',
    'percentile': 'percentiles',
}

# Normalize ICD code: remove prefixes, uppercase, strip whitespace
def normalize_code(code):
    code = re.sub(r'^\s*Disorder Code:\s*', '', code, flags=re.IGNORECASE)
    code = re.sub(r'\s+', '', code)
    return code.upper()

# Preprocess text: lowercase, remove punctuation, correct typos/synonyms
def preprocess_text(text):
    text = text.lower()
    text = re.sub(r'[^\w\s]', '', text)
    for wrong, correct in typo_corrections.items():
        text = text.replace(wrong, correct)
    return text

# TF-IDF cosine similarity function (fallback)
def compute_tf_idf_cosine(original, user):
    docs = [preprocess_text(original), preprocess_text(user)]
    vocab = list(set(word for doc in docs for word in doc.split()))
    tf_docs = [[Counter(doc.split()).get(word, 0) / len(doc.split()) for word in vocab] for doc in docs]
    N = len(docs)
    idf = [math.log(N / (1 + sum(1 for doc in docs if word in doc))) for word in vocab]
    tfidf_docs = [[tf * idf_val for tf, idf_val in zip(tf_vector, idf)] for tf_vector in tf_docs]
    vec1, vec2 = np.array(tfidf_docs[0]), np.array(tfidf_docs[1])
    dot = np.dot(vec1, vec2)
    norm1, norm2 = np.linalg.norm(vec1), np.linalg.norm(vec2)
    return dot / (norm1 * norm2) if norm1 > 0 and norm2 > 0 else 0.0

# Read original ICD records
original_records = {}
original_codes = set()
with open(original_file, 'r', encoding='utf-8') as f:
    reader = csv.reader(f, delimiter='|')
    for row in reader:
        if len(row) == 3:
            name, code, symptoms = [field.strip() for field in row]
            code = normalize_code(code)
            original_records[code] = {'name': name, 'symptoms': preprocess_text(symptoms), 'raw_symptoms': symptoms}
            original_codes.add(code)
        else:
            print(f"Skipping malformed original record: {row}")

# Save original codes
with open(original_codes_log, 'w', encoding='utf-8') as f:
    for code in sorted(original_codes):
        f.write(f"{code}\n")
print(f"Saved {len(original_codes)} original ICD codes to {original_codes_log}")

# Read user symptoms
user_symptoms_list = []
with open(user_symptoms_file, 'r', encoding='utf-8') as f:
    for line in f:
        symptom = line.strip()
        if symptom:
            user_symptoms_list.append(preprocess_text(symptom))
print(f"Read {len(user_symptoms_list)} user symptom records from {user_symptoms_file}")

# Initialize embedding method with SapBERT
try:
    embedding_function = SentenceTransformer('cambridgeltl/SapBERT-from-PubMedBERT-fulltext')
    print("Using SapBERT embeddings")
    use_pubmedbert = True
except Exception as e:
    print(f"Error initializing SapBERT: {e}. Falling back to TF-IDF.")
    use_pubmedbert = False

# Compute cosine similarities with progress tracking and batching
results = []
low_similarity_pairs = []
high_similarity_pairs = []
total_pairs = len(user_symptoms_list) * len(original_records)
processed_pairs = 0
batch_size = 100

# Pre-embed original symptoms for efficiency
if use_pubmedbert:
    original_embeddings = {
        code: embedding_function.encode(data['symptoms'])
        for code, data in original_records.items()
    }

# Process user symptoms in batches
for i in range(0, len(user_symptoms_list), batch_size):
    batch = user_symptoms_list[i:i + batch_size]

    # Embed batch of user symptoms
    if use_pubmedbert:
        user_embeddings = embedding_function.encode(batch)

    for j, user_symptoms in enumerate(batch):
        for original_code, original_data in original_records.items():
            try:
                if use_pubmedbert:
                    user_vector = np.array(user_embeddings[j])
                    original_vector = original_embeddings[original_code]
                    cosine_similarity = np.dot(original_vector, user_vector) / (
                        np.linalg.norm(original_vector) * np.linalg.norm(user_vector)
                    )
                else:
                    cosine_similarity = compute_tf_idf_cosine(original_data['symptoms'], user_symptoms)

                results.append([user_symptoms, original_code, original_data['raw_symptoms'], cosine_similarity])
                if cosine_similarity < 0.9:
                    low_similarity_pairs.append([user_symptoms, original_code, original_data['raw_symptoms'], f"{cosine_similarity:.4f}"])
                if cosine_similarity > 0.75:
                    high_similarity_pairs.append([user_symptoms, original_code, original_data['raw_symptoms'], f"{cosine_similarity:.4f}"])

                # Debug similarity scores
                if processed_pairs < 10 or cosine_similarity > 0.9:
                    print(f"Similarity {cosine_similarity:.4f} for user symptoms vs ICD code {original_code}")

                processed_pairs += 1
                if processed_pairs % 1000 == 0:
                    print(f"Processed {processed_pairs}/{total_pairs} pairs ({processed_pairs/total_pairs*100:.2f}%)")
            except Exception as e:
                print(f"Error processing user symptoms with ICD code {original_code}: {e}")

# Save all similarity pairs
if results:
    with open(all_similarity_file, 'w', newline='', encoding='utf-8') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(['User symptoms', 'predicted ICD Code', 'predicted Symptoms', 'Cosine Similarity'])
        writer.writerows([[row[0], row[1], row[2], f"{row[3]:.4f}"] for row in results])
    print(f"Logged {len(results)} total similarity pairs to {all_similarity_file}")

# Save high similarity pairs (> 0.75)
if high_similarity_pairs:
    with open(high_similarity_file, 'w', newline='', encoding='utf-8') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(['User symptoms', 'predicted ICD Code', 'predicted Symptoms', 'Cosine Similarity'])
        writer.writerows(high_similarity_pairs)
    print(f"Logged {len(high_similarity_pairs)} pairs with similarity > 0.75 to {high_similarity_file}")

# Save low similarity pairs
if low_similarity_pairs:
    with open(low_similarity_log, 'w', newline='', encoding='utf-8') as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow(['User symptoms', 'predicted ICD Code', 'predicted Symptoms', 'Cosine Similarity'])
        writer.writerows(low_similarity_pairs)
    print(f"Logged {len(low_similarity_pairs)} pairs with similarity < 0.9 to {low_similarity_log}")

# Log unmatched ICD codes
matched_codes = set([row[1] for row in high_similarity_pairs])
unmatched_codes = [code for code in original_codes if code not in matched_codes]
if unmatched_codes:
    with open(unmatched_log, 'w', encoding='utf-8') as f:
        for code in unmatched_codes:
            f.write(f"{code}\n")
    print(f"Logged {len(unmatched_codes)} unmatched ICD codes to {unmatched_log}")

# Validation summary
total_original = len(original_records)
total_user_symptoms = len(user_symptoms_list)
matched_codes_count = len(matched_codes)
average_matches_per_symptom = len(high_similarity_pairs) / total_user_symptoms if total_user_symptoms > 0 else 0

# Compute average similarity per user symptom set
similarities_per_symptom = defaultdict(list)
for result in results:
    similarities_per_symptom[result[0]].append(float(result[3]))
average_similarities = {symptom: np.mean(sims) for symptom, sims in similarities_per_symptom.items()}

print("\nValidation Summary:")
print(f"Total original ICD records: {total_original}")
print(f"Total user symptom records: {total_user_symptoms}")
print(f"Matched ICD codes: {matched_codes_count} ({matched_codes_count / total_original * 100:.2f}%)")
print(f"Average matches per user symptom set: {average_matches_per_symptom:.2f}")
print(f"High similarity pairs (> 0.75): {len(high_similarity_pairs)}")
print(f"Total similarity pairs: {len(results)}")
print("\nAverage Cosine Similarity per User Symptom Set:")
for symptom, avg_sim in sorted(average_similarities.items(), key=lambda x: x[1], reverse=True):
    print(f"User symptoms: {symptom[:50]}...: {avg_sim:.4f}")

# Overall average similarity
if results:
    overall_avg = np.mean([float(result[3]) for result in results])
    print(f"\nOverall Average Cosine Similarity: {overall_avg:.4f}")
else:
    print("No results computed. Check embedding errors or input data.")

File size of /content/icd_11_prompt_delimited.txt: 1280665 bytes
First 2 lines of /content/icd_11_prompt_delimited.txt:
Disorder Name:   Disorder  of intellectual development, mild   |  Disorder Code: 6A00.0   |  Disorder Symptoms: A mild   Disorder  of intellectual development is a condition originating during the developmental period characterised by significantly below average intellectual functioning and adaptive behaviour that are approximately two to three standard deviations below the mean (approximately 0.1   2.3 percentile), based on appropriately normed, individually administered standardized tests or by comparable behavioural indicators when standardized testing is unavailable. Affected persons often exhibit difficulties in the acquisition and comprehension of complex language concepts and academic skills. Most master basic self-care, domestic, and practical activities. Persons affected by a mild   Disorder  of intellectual development can generally achieve relatively indepe

Error while fetching `HF_TOKEN` secret value from your vault: 'Requesting secret HF_TOKEN timed out. Secrets can only be fetched when running from the Colab UI.'.
You are not authenticated with the Hugging Face Hub in this notebook.
If the error persists, please let us know by opening an issue on GitHub (https://github.com/huggingface/huggingface_hub/issues/new).


config.json:   0%|          | 0.00/462 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/438M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/198 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

Using SapBERT embeddings
Similarity 0.6022 for user symptoms vs ICD code 6A00.0
Similarity 0.5894 for user symptoms vs ICD code 6A00.1
Similarity 0.5947 for user symptoms vs ICD code 6A00.2
Similarity 0.5990 for user symptoms vs ICD code 6A00.3
Similarity 0.5310 for user symptoms vs ICD code 6A00.4
Similarity 0.5612 for user symptoms vs ICD code 6A01.0
Similarity 0.6280 for user symptoms vs ICD code 6A01.1
Similarity 0.6081 for user symptoms vs ICD code 6A01.Y
Similarity 0.6081 for user symptoms vs ICD code 6A01.Z
Similarity 0.4390 for user symptoms vs ICD code 6A02.0
Processed 1000/20240 pairs (4.94%)
Processed 2000/20240 pairs (9.88%)
Processed 3000/20240 pairs (14.82%)
Processed 4000/20240 pairs (19.76%)
Processed 5000/20240 pairs (24.70%)
Processed 6000/20240 pairs (29.64%)
Processed 7000/20240 pairs (34.58%)
Processed 8000/20240 pairs (39.53%)
Processed 9000/20240 pairs (44.47%)
Processed 10000/20240 pairs (49.41%)
Processed 11000/20240 pairs (54.35%)
Processed 12000/20240 pairs (

In [None]:
# Step 1: Install required libraries
!pip install langchain langchain-community langchain-huggingface langchain-ollama faiss-cpu transformers torch ollama langgraph -q

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.7/43.7 kB[0m [31m3.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m2.5/2.5 MB[0m [31m57.3 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m31.4/31.4 MB[0m [31m21.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m363.4/363.4 MB[0m [31m1.5 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m13.8/13.8 MB[0m [31m113.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m24.6/24.6 MB[0m [31m96.1 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m883.7/883.7 kB[0m [31m55.9 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m664.8/664.8 MB[0m [31m2.2 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━

In [None]:
#cell 1
# Install required Python libraries (remove langchain-hub)
!pip install langchain langchain-community faiss-cpu ollama flask langgraph -q

# Install Ollama in Colab
!curl -fsSL https://ollama.com/install.sh | sh

# Start Ollama server in the background with subprocess for better control
import subprocess
import time
import os

# Start Ollama server and check if it's running
try:
    subprocess.Popen(["ollama", "serve"], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    print("Starting Ollama server...")
    time.sleep(5)  # Wait for server to initialize
except Exception as e:
    print(f"Error starting Ollama server: {e}")
    raise

# Verify server is running by checking API availability
import requests
try:
    response = requests.get("http://127.0.0.1:11434")
    if response.status_code == 200:
        print("Ollama server is running at 127.0.0.1:11434")
    else:
        print("Ollama server not responding")
        raise Exception("Ollama server failed to start")
except requests.ConnectionError:
    print("Error: Could not connect to Ollama server. Ensure 'ollama serve' is running.")
    raise

# Pull gemma2:27b model (takes time, ~27GB)
try:
    subprocess.run(["ollama", "pull", "gemma2:9b"], check=True)
    print("Successfully pulled gemma2:9b model")
except subprocess.CalledProcessError as e:
    print(f"Error pulling model: {e}")
    raise

# Verify Ollama is running and model is available
try:
    result = subprocess.run(["ollama", "list"], capture_output=True, text=True)
    print(result.stdout)
except subprocess.CalledProcessError as e:
    print(f"Error listing models: {e}")
    raise

>>> Installing ollama to /usr/local
>>> Downloading Linux amd64 bundle
######################################################################## 100.0%
>>> Creating ollama user...
>>> Adding ollama user to video group...
>>> Adding current user to ollama group...
>>> Creating ollama systemd service...
>>> The Ollama API is now available at 127.0.0.1:11434.
>>> Install complete. Run "ollama" from the command line.
Starting Ollama server...
Ollama server is running at 127.0.0.1:11434
Successfully pulled gemma2:9b model
NAME         ID              SIZE      MODIFIED               
gemma2:9b    ff02c3702f32    5.4 GB    Less than a second ago    



In [None]:
import csv
import os
import numpy as np
import pandas as pd
from langchain_core.documents import Document
from langchain_community.vectorstores import FAISS
from langchain_huggingface import HuggingFaceEmbeddings # Use HuggingFaceEmbeddings
from langchain_ollama import OllamaEmbeddings
from langchain_core.prompts import ChatPromptTemplate
from langchain_community.llms import Ollama
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
# from sentence_transformers import SentenceTransformer # Remove direct import to avoid conflicts

# Configuration
INPUT_ICD_PATH = '/content/icd_11_prompt_delimited.txt'
ALL_SIMILARITY_PATH = '/content/sapbert_all_similarity_pairs.csv'
TOP5_SIMILARITY_PATH = '/content/sapbert_top5_similarity_pairs.csv'
VECTORESTORE_BASE = '/content/vectorstore'
MODEL = "gemma2:9b"  # For LLM only
SAFE_MODEL = MODEL.replace(":", "")
# Use HuggingFaceEmbeddings compatible model name
EMBEDDING_MODEL_NAME = "cambridgeltl/SapBERT-from-PubMedBERT-fulltext"
PERSIST_DIR = os.path.join(VECTORESTORE_BASE, f"faiss_index-{EMBEDDING_MODEL_NAME.replace('/', '_').replace('-', '_')}-top5")
BATCH_SIZE = 50
TOP_K = 5  # Number of top matches to retrieve

# Step 1: Load original ICD data to get disorder names
def load_icd_data(file_path):
    icd_data = {}
    with open(file_path, 'r', encoding='utf-8') as f:
        reader = csv.reader(f, delimiter='|')
        for row in reader:
            if len(row) == 3:
                name, code, symptoms = [field.strip() for field in row]
                code = normalize_code(code)
                icd_data[code] = {'name': name, 'symptoms': symptoms}
            else:
                print(f"Skipping malformed ICD record: {row}")
    return icd_data

def normalize_code(code):
    code = re.sub(r'^\s*Disorder Code:\s*', '', code, flags=re.IGNORECASE)
    code = re.sub(r'\s+', '', code)
    return code.upper()

# Step 2: Filter top 5 similarity pairs per user symptom set
def filter_top5_similarity_pairs(input_file, output_file, top_k=5):
    if os.path.exists(output_file):
        print(f"Top 5 similarity pairs file already exists at {output_file}. Skipping filtering.")
        return
    try:
        df = pd.read_csv(input_file)
        df['Cosine Similarity'] = df['Cosine Similarity'].astype(float)
        # Group by User symptoms and select top k matches
        df_top_k = df.groupby('User symptoms').apply(
            lambda x: x.nlargest(top_k, 'Cosine Similarity')
        ).reset_index(drop=True)
        df_top_k.to_csv(output_file, index=False)
        print(f"Saved {len(df_top_k)} top {top_k} similarity pairs to {output_file}")
    except Exception as e:
        print(f"Error filtering top 5 similarity pairs: {e}")
        raise

# Step 3: Build FAISS vectorstore with SapBERT embeddings of user symptoms
def build_sapbert_vectorstore():
    if os.path.exists(PERSIST_DIR):
        print(f"Vectorstore already exists at {PERSIST_DIR}. Skipping build.")
        return

    print("Building SapBERT-based vectorstore for top 5 user symptoms...")

    # Filter top 5 similarity pairs
    filter_top5_similarity_pairs(ALL_SIMILARITY_PATH, TOP5_SIMILARITY_PATH, top_k=TOP_K)

    # Load ICD data for metadata
    icd_data = load_icd_data(INPUT_ICD_PATH)

    # Load top 5 similarity pairs
    try:
        df_top_k = pd.read_csv(TOP5_SIMILARITY_PATH)
    except Exception as e:
        print(f"Error loading top 5 similarity pairs: {e}")
        raise

    # Initialize HuggingFaceEmbeddings with SapBERT model - This is the embedding function to use with FAISS
    try:
        embedding_function = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME, cache_folder="/content/hf_cache")
        print(f"Using HuggingFaceEmbeddings with model: {EMBEDDING_MODEL_NAME}")
    except Exception as e:
        print(f"Error initializing HuggingFaceEmbeddings: {e}")
        raise

    documents = []

    # Create documents
    for _, row in df_top_k.iterrows():
        user_symptoms = row['User symptoms']
        predicted_code = row['predicted ICD Code']
        predicted_symptoms = row['predicted Symptoms']
        similarity = row['Cosine Similarity']

        # Get disorder name from original ICD data
        disorder_name = icd_data.get(predicted_code, {}).get('name', 'Unknown')

        # Create document with metadata
        full_text = f"Disorder Name: {disorder_name} | Disorder Code: {predicted_code} | Disorder Symptoms: {predicted_symptoms} | User Symptoms: {user_symptoms} | Cosine Similarity: {similarity:.4f}"
        doc = Document(
            page_content=full_text,
            metadata={
                'user_symptoms': user_symptoms,
                'predicted_code': predicted_code,
                'disorder_name': disorder_name,
                'cosine_similarity': similarity
            }
        )
        documents.append(doc)

    if not documents:
        print("Error: No valid records found to build vectorstore.")
        return

    print(f"Embedding {len(documents)} records...")
    try:
        # Create FAISS vectorstore using HuggingFaceEmbeddings
        # Pass the embedding_function (HuggingFaceEmbeddings) directly to FAISS.from_documents
        vectorstore = FAISS.from_documents(documents, embedding_function)
        vectorstore.save_local(PERSIST_DIR)
        print(f"Vectorstore built and saved at {PERSIST_DIR}")
    except Exception as e:
        print(f"Error building vectorstore: {e}")
        raise

# Step 4: Load vectorstore and set up RAG pipeline
def setup_rag_pipeline():
    try:
        # Load vectorstore using HuggingFaceEmbeddings
        # Pass the embedding_function (HuggingFaceEmbeddings) directly to FAISS.load_local
        embedding_function = HuggingFaceEmbeddings(model_name=EMBEDDING_MODEL_NAME, cache_folder="/content/hf_cache")
        vectorstore = FAISS.load_local(
            PERSIST_DIR,
            embedding_function,
            allow_dangerous_deserialization=True
        )
        retriever = vectorstore.as_retriever(search_kwargs={"k": TOP_K})
    except Exception as e:
        print(f"Error loading vectorstore: {e}")
        raise

    # Initialize LLM
    try:
        llm = Ollama(model=MODEL)
    except Exception as e:
        print(f"Error initializing LLM: {e}")
        raise

    # RAG prompt
    rag_prompt = ChatPromptTemplate.from_messages([
        ("system", """You are a medical diagnosis assistant. Use the provided ICD-11 context to identify the most likely diagnosis. Provide the disorder name, code, and explain why the symptoms match. If the context is insufficient, state so. Include a disclaimer that you are not a medical professional and a comprehensive evaluation by a qualified healthcare professional is needed."""),
        ("human", "Context: {context}\n\nQuestion: {question}")
    ])

    # RAG chain
    rag_chain = (
        {"context": lambda x: "\n\n".join([d.page_content for d in retriever.invoke(x["question"])]), "question": RunnablePassthrough()}
        | rag_prompt
        | llm
        | StrOutputParser()
    )
    return rag_chain

# Step 5: Test RAG pipeline
def test_rag():
    build_sapbert_vectorstore()
    rag_chain = setup_rag_pipeline()

    query = "Engagement in repetitive patterns of behavior or restricted interests; Intellectual functioning below age-expected norms; Participates in routine household tasks; Delayed academic performance relative to peers; Hypersensitivity or hyposensitivity to sensory stimuli; Verbalizes thoughts aloud to self; Rigid insistence on specific routines involving phone use; Limited reciprocal social interaction; Challenges adjusting to changes in routine or environment; Monotonous or atypical speech patterns; Exceptional visual or detailed memory recall."
    try:
        response = rag_chain.invoke({"question": query})
        print(f"Query: {query}\nResponse: {response}")
    except Exception as e:
        print(f"Error in RAG: {e}")

# Run the pipeline
if __name__ == "__main__":
    test_rag()

Building SapBERT-based vectorstore for top 5 user symptoms...


  df_top_k = df.groupby('User symptoms').apply(


Saved 275 top 5 similarity pairs to /content/sapbert_top5_similarity_pairs.csv




config.json:   0%|          | 0.00/462 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/438M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/198 [00:00<?, ?B/s]

vocab.txt: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]

Using HuggingFaceEmbeddings with model: cambridgeltl/SapBERT-from-PubMedBERT-fulltext
Embedding 275 records...




Vectorstore built and saved at /content/vectorstore/faiss_index-cambridgeltl_SapBERT_from_PubMedBERT_fulltext-top5


  llm = Ollama(model=MODEL)


Query: Engagement in repetitive patterns of behavior or restricted interests; Intellectual functioning below age-expected norms; Participates in routine household tasks; Delayed academic performance relative to peers; Hypersensitivity or hyposensitivity to sensory stimuli; Verbalizes thoughts aloud to self; Rigid insistence on specific routines involving phone use; Limited reciprocal social interaction; Challenges adjusting to changes in routine or environment; Monotonous or atypical speech patterns; Exceptional visual or detailed memory recall.
Response: Based on the symptoms provided, the most likely diagnosis is **Autism Spectrum Disorder (ASD)**. 

Here's why:

* **Repetitive patterns of behavior or restricted interests:** This is a hallmark symptom of ASD, often manifesting as repetitive movements, insistence on sameness, and intense fixations on specific topics or objects.
* **Intellectual functioning below age-expected norms:** While not all individuals with ASD have intellectua

In [None]:
import csv
import os
import numpy as np
import pandas as pd
from langchain_core.documents import Document
from langchain_community.vectorstores import FAISS
from langchain_core.prompts import ChatPromptTemplate
from langchain_ollama import OllamaLLM
from langchain_core.output_parsers import StrOutputParser
from langchain_core.runnables import RunnablePassthrough
from langchain_huggingface import HuggingFaceEmbeddings

# Configuration
INPUT_ICD_PATH = '/content/icd_11_prompt_delimited.txt'
TOP5_SIMILARITY_PATH = '/content/sapbert_top5_similarity_pairs.csv'
VECTORESTORE_BASE = '/content/vectorstore'
OUTPUT_DIAGNOSES_PATH = '/content/rag_diagnoses_all_symptoms.csv'
MODEL = "gemma2:9b"  # For LLM only
SAFE_MODEL = MODEL.replace(":", "")
PERSIST_DIR = os.path.join(VECTORESTORE_BASE, f"faiss_index-sapbert-top5")
BATCH_SIZE = 50
TOP_K = 5  # Number of top matches to retrieve

# Step 1: Load original ICD data to get disorder names
def load_icd_data(file_path):
    icd_data = {}
    with open(file_path, 'r', encoding='utf-8') as f:
        reader = csv.reader(f, delimiter='|')
        for row in reader:
            if len(row) == 3:
                name, code, symptoms = [field.strip() for field in row]
                code = normalize_code(code)
                icd_data[code] = {'name': name, 'symptoms': symptoms}
            else:
                print(f"Skipping malformed ICD record: {row}")
    return icd_data

def normalize_code(code):
    code = re.sub(r'^\s*Disorder Code:\s*', '', code, flags=re.IGNORECASE)
    code = re.sub(r'\s+', '', code)
    return code.upper()

# Step 2: Load unique symptoms from top 5 similarity pairs
def load_unique_symptoms(file_path):
    try:
        df = pd.read_csv(file_path)
        unique_symptoms = df['User symptoms'].unique()
        print(f"Loaded {len(unique_symptoms)} unique symptom sets from {file_path}")
        return unique_symptoms
    except Exception as e:
        print(f"Error loading top 5 similarity pairs: {e}")
        raise

# Step 3: Build FAISS vectorstore with SapBERT embeddings
def build_sapbert_vectorstore():
    if os.path.exists(PERSIST_DIR):
        print(f"Vectorstore already exists at {PERSIST_DIR}. Skipping build.")
        return

    print("Building SapBERT-based vectorstore for top 5 user symptoms...")

    # Load ICD data for metadata
    icd_data = load_icd_data(INPUT_ICD_PATH)

    # Load top 5 similarity pairs
    try:
        df_top_k = pd.read_csv(TOP5_SIMILARITY_PATH)
    except Exception as e:
        print(f"Error loading top 5 similarity pairs: {e}")
        raise

    # Initialize SapBERT embeddings
    try:
        embedding_function = HuggingFaceEmbeddings(
            model_name='cambridgeltl/SapBERT-from-PubMedBERT-fulltext',
            model_kwargs={'device': 'cuda' if torch.cuda.is_available() else 'cpu'}
        )
    except Exception as e:
        print(f"Error initializing SapBERT embeddings: {e}")
        raise

    documents = []

    # Create documents
    for _, row in df_top_k.iterrows():
        user_symptoms = row['User symptoms']
        predicted_code = row['predicted ICD Code']
        predicted_symptoms = row['predicted Symptoms']
        similarity = row['Cosine Similarity']

        # Get disorder name from original ICD data
        disorder_name = icd_data.get(predicted_code, {}).get('name', 'Unknown')

        # Create document with metadata
        full_text = f"Disorder Name: {disorder_name} | Disorder Code: {predicted_code} | Disorder Symptoms: {predicted_symptoms} | User Symptoms: {user_symptoms} | Cosine Similarity: {similarity:.4f}"
        doc = Document(
            page_content=full_text,
            metadata={
                'user_symptoms': user_symptoms,
                'predicted_code': predicted_code,
                'disorder_name': disorder_name,
                'cosine_similarity': float(similarity)
            }
        )
        documents.append(doc)

    if not documents:
        print("Error: No valid records found to build vectorstore.")
        return

    print(f"Embedding {len(documents)} records...")
    try:
        # Create FAISS vectorstore
        vectorstore = FAISS.from_documents(documents, embedding_function)
        vectorstore.save_local(PERSIST_DIR)
        print(f"Vectorstore built and saved at {PERSIST_DIR}")
    except Exception as e:
        print(f"Error building vectorstore: {e}")
        raise

# Step 4: Load vectorstore and set up RAG pipeline
def setup_rag_pipeline():
    try:
        embedding_function = HuggingFaceEmbeddings(
            model_name='cambridgeltl/SapBERT-from-PubMedBERT-fulltext',
            model_kwargs={'device': 'cuda' if torch.cuda.is_available() else 'cpu'}
        )
        vectorstore = FAISS.load_local(
            PERSIST_DIR,
            embedding_function,
            allow_dangerous_deserialization=True
        )
        retriever = vectorstore.as_retriever(search_kwargs={"k": TOP_K})
    except Exception as e:
        print(f"Error loading vectorstore: {e}")
        raise

    # Initialize LLM
    try:
        llm = OllamaLLM(model=MODEL)
    except Exception as e:
        print(f"Error initializing LLM: {e}")
        raise

    # RAG prompt
    rag_prompt = ChatPromptTemplate.from_messages([
        ("system", """You are a medical diagnosis assistant. Use the provided ICD-11 context to identify the most likely diagnosis. Provide the disorder name, code, and explain why the symptoms match. If the context is insufficient, state so. Include a disclaimer that you are not a medical professional and a comprehensive evaluation by a qualified healthcare professional is needed."""),
        ("human", "Context: {context}\n\nQuestion: {question}")
    ])

    # RAG chain
    rag_chain = (
        {"context": lambda x: "\n\n".join([d.page_content for d in retriever.invoke(x["question"])]), "question": RunnablePassthrough()}
        | rag_prompt
        | llm
        | StrOutputParser()
    )
    return rag_chain

# Step 5: Process all symptoms and save diagnoses
def process_all_symptoms():
    build_sapbert_vectorstore()
    rag_chain = setup_rag_pipeline()

    # Load unique symptoms
    unique_symptoms = load_unique_symptoms(TOP5_SIMILARITY_PATH)

    # Process each symptom set and collect diagnoses
    diagnoses = []
    for query in unique_symptoms:
        try:
            response = rag_chain.invoke({"question": query})
            diagnoses.append([query, response])
            print(f"Query: {query}\nResponse: {response}\n")
        except Exception as e:
            print(f"Error processing query '{query}': {e}")
            diagnoses.append([query, f"Error: {e}"])

    # Save diagnoses to CSV
    try:
        with open(OUTPUT_DIAGNOSES_PATH, 'w', newline='', encoding='utf-8') as csvfile:
            writer = csv.writer(csvfile)
            writer.writerow(['Query', 'Response'])
            writer.writerows(diagnoses)
        print(f"Saved {len(diagnoses)} diagnoses to {OUTPUT_DIAGNOSES_PATH}")
    except Exception as e:
        print(f"Error saving diagnoses: {e}")

# Run the pipeline
if __name__ == "__main__":
    import torch  # Import torch for device check
    process_all_symptoms()

Building SapBERT-based vectorstore for top 5 user symptoms...




Embedding 275 records...




Vectorstore built and saved at /content/vectorstore/faiss_index-sapbert-top5
Loaded 55 unique symptom sets from /content/sapbert_top5_similarity_pairs.csv
Query: 1 scholastic backwardness cant understand money concepts behavioral concerns in social situations needs help after defecation phrasal speech need based expressive language stubbornness bites his hand if wish is not fulfilled lacks attention and concentration low sitting tolerance cant eat rice with his hands difficulty with learning new skills poor shortterm memory delayed developintellectual milestones hyperactive tendencies easily distracted poor impulse control trouble following multistep instructions limited problemsolving abilities
Response: Based on the provided symptoms, the most likely diagnosis is **Other specified developmental speech or language disorder (6A01.Y)**.  

Here's why:

* **Persistent Speech and Language Difficulties:** The symptoms indicate challenges with expressive language ("need-based expressive lan

In [None]:
# Step 6:Optional Start Ollama server and verify
def start_ollama_server():
    try:
        subprocess.Popen(["ollama", "serve"], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        print("Starting Ollama server...")
        time.sleep(10)
        response = requests.get("http://127.0.0.1:11434")
        if response.status_code == 200:
            print("Ollama server is running at 127.0.0.1:11434")
        else:
            raise Exception("Ollama server failed to start")
    except requests.ConnectionError:
        print("Error: Could not connect to Ollama server.")
        raise
    except Exception as e:
        print(f"Error starting Ollama server: {e}")
        raise

start_ollama_server()

# Step 5: Pull MedLlama2 model if not already available
try:
    result = subprocess.run(["ollama", "list"], capture_output=True, text=True)
    if "MedLlama2" not in result.stdout:
        print("Pulling MedLlama2 model...")
        subprocess.run(["ollama", "pull", "MedLlama2"], check=True)
        print("Successfully pulled MedLlama2 model")
except subprocess.CalledProcessError as e:
    print(f"Error checking or pulling model: {e}")
    raise