In [1]:
import pickle
import pandas as pd 
import numpy as np 
import joblib
import sqlite3
import re
from typing import TypedDict, Optional
from sklearn.preprocessing import LabelEncoder

In [2]:
MODEL_PATH="../notebooks/disease_model.pkl"

try:
    model=joblib.load(MODEL_PATH)
    print("Model loaded successfully from:", MODEL_PATH)
    print("Model expects these features (in this order):")
    print(list(model.feature_names_in_))
except FileNotFoundError:
    print(f"Model file not found at {MODEL_PATH}")
    print("Make sure you've run the first notebook and saved the model")
except AttributeError:
    print("model.feature_names_in_ not available. Inspect training notebook for feature order.")

Model loaded successfully from: ../notebooks/disease_model.pkl
Model expects these features (in this order):
['Disease', 'Fever', 'Cough', 'Fatigue', 'Difficulty Breathing', 'Age', 'Gender', 'Blood Pressure', 'Cholesterol Level']


### The formatter function 

In [3]:
def format_symptoms(symptoms: dict):
    """
    Reorders and fills missing features so input matches the model requirements.
    """
    required_features = list(model.feature_names_in_)
    formatted = {feature: symptoms.get(feature, 0) for feature in required_features}
    return pd.DataFrame([formatted])

In [4]:
sample_symptoms = {
    "Disease": 0,
    "Fever": 1,
    "Cough": 0,
    "Fatigue": 1,
    "Difficulty Breathing": 0,
    "Age": 45,
    "Gender": 1,
    "Blood Pressure": 140,
    "Cholesterol Level": 230
}

features = format_symptoms(sample_symptoms)
print("Final features passed to model:")
print(features)

prediction = model.predict(features)[0]
print("Predicted Outcome:", prediction)

Final features passed to model:
   Disease  Fever  Cough  Fatigue  Difficulty Breathing  Age  Gender  \
0        0      1      0        1                     0   45       1   

   Blood Pressure  Cholesterol Level  
0             140                230  
Predicted Outcome: 1


### Initialize database with sample data

In [5]:
def init_db(db_path="patient.db"):
    conn=sqlite3.connect(db_path)
    cursor=conn.cursor()

    cursor.execute("""
    CREATE TABLE IF NOT EXISTS patient_records(
        id INTEGER PRIMARY KEY AUTOINCREMENT,
        symptoms TEXT,
        prediction INTEGER,
        reasoning TEXT,
        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
    )
    """)
    cursor.execute("SELECT COUNT(*) FROM patient_records")
    count=cursor.fetchone()[0]
        
    
    if count == 0:
        # Add some sample records for testing
        sample_records = [
            ("Fever: 1, Cough: 0, Fatigue: 1, Age: 45", 1, "Patient shows signs of flu-like symptoms with fever and fatigue"),
            ("Fever: 0, Cough: 1, Fatigue: 0, Age: 30", 0, "Mild respiratory symptoms, likely common cold"),
            ("Fever: 1, Cough: 1, Fatigue: 1, Age: 65", 1, "Multiple symptoms present, requires medical attention")
        ]
        
        cursor.executemany(
            "INSERT INTO patient_records (symptoms, prediction, reasoning) VALUES (?, ?, ?)",
            sample_records
        )
        print("Added sample records to database")
    
    conn.commit()
    conn.close()
    print("patient_records table is ready.")

init_db("patient.db")

patient_records table is ready.


In [6]:
# !pip install langchain-community faiss-cpu sentence-transformer

In [7]:
try:
    from langchain_community.vectorstores import FAISS
    from langchain_community.embeddings import HuggingFaceEmbeddings

    conn=sqlite3.connect("patient.db")
    df=pd.read_sql_query("SELECT id, symptoms, prediction, reasoning FROM patient_records", conn)
    conn.close()

    print(f"Retrieved {len(df)} records from database")

    if len(df)>0:
        df["case_text"]=df.apply(
            lambda row: f"Symptoms: {row['symptoms']} | Prediction:{row['prediction']} | Reasoning{row['reasoning']}",
            axis=1
        )

        embeddings=HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
        vectorstore=FAISS.from_texts(df["case_text"].tolist(), embeddings)
        print(f"Vectorstore built with {len(df)} past cases")
    else:
        vectorstore=None
        print("No records found in database, vectorstore set to None")

except ImportError as e:
    print(f"Import error:{e}")
    print("Please install required packages: pip install langchain-community faiss-cpu sentence-transformer")
    vectorstore=None

except Exception as e:
    print(f"Error creating vectorstore:{e}")
    vectorstore=None

Retrieved 7 records from database


  embeddings=HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")


Vectorstore built with 7 past cases


### Define retrieval function

In [8]:
def retrieve_node(state):
    """Retrieve similar cases from the vectorstore"""
    if vectorstore is None:
        state["retrieved_cases"]=[]
        return state

    query=str(state["symptoms"])

    try:
        results=vectorstore.similarity_search(query, k=3)
        state["retrieved_cases"]=[r.page_content for r in results]

    except Exception as e:
        print(f"Error during retrieval:{e}")
        state["retrieved_cases"]=[]

    return state    

In [9]:
class PatientState(TypedDict):
    symptoms: dict
    prediction: Optional[int]
    reasoning: Optional[str]
    retrieved_cases: Optional[list]

In [10]:
def predict_disease(state: PatientState) -> dict:
    """
    Takes symptoms, formats them, predicts disease, and adds prediction to the state.
    """
    symptoms = state["symptoms"]
    features = format_symptoms(symptoms)
    
    print("Features passed to the model:", features.to_dict(orient="records")[0])
    
    prediction = model.predict(features)[0]

    state["prediction"] = int(prediction)
    return state

In [11]:
try:
    from langchain_ollama import OllamaLLM
    
    llm = OllamaLLM(
        model="mistral",
        streaming=False
    )
    ollama_available = True
    print("Ollama LLM initialized successfully")
except Exception as e:
    print(f"Ollama not available: {e}")
    print("Using fallback reasoning function")
    ollama_available = False

def reasoning_node(state: dict):
    """Generate reasoning for the prediction"""
    symptoms = state["symptoms"]
    prediction = state["prediction"]
    retrieved_text = "\n".join(state.get("retrieved_cases", []))
    
    if ollama_available:
        # Use LLM for reasoning
        prompt = f"""
        You are a medical assistant.

        Here are some similar past cases:
        {retrieved_text}

        Now, based on the symptoms: {symptoms} 
        and the prediction: {prediction},

        Write a professional, short, readable explanation. 
        - Use point-wise format.
        - Include a clear conclusion at the end.
        - Keep sentences concise and natural.
        - Add relevant emojis for readability.
        """
        
        try:
            reasoning_text = llm.invoke(prompt)
            # Clean up the text
            reasoning_text = re.sub(r'\b(\w+)\s+\1\b', r'\1', reasoning_text)
            reasoning_text = re.sub(r'\s+([,.;:])', r'\1', reasoning_text)
            reasoning_text = re.sub(r'\n+', '\n', reasoning_text).strip()
        except Exception as e:
            print(f"Error with LLM reasoning: {e}")
            reasoning_text = generate_fallback_reasoning(symptoms, prediction)
    else:
        # Fallback reasoning
        reasoning_text = generate_fallback_reasoning(symptoms, prediction)
    
    state["reasoning"] = reasoning_text
    return state

def generate_fallback_reasoning(symptoms, prediction):
    """Generate basic reasoning when LLM is not available"""
    outcome = "Positive" if prediction == 1 else "Negative"
    
    # Extract key symptoms
    key_symptoms = []
    if symptoms.get("Fever", 0) == 1:
        key_symptoms.append("fever")
    if symptoms.get("Cough", 0) == 1:
        key_symptoms.append("cough")
    if symptoms.get("Fatigue", 0) == 1:
        key_symptoms.append("fatigue")
    if symptoms.get("Difficulty Breathing", 0) == 1:
        key_symptoms.append("difficulty breathing")
    
    reasoning = f"""
    Analysis Summary:
    • Patient Age: {symptoms.get('Age', 'N/A')} years
    • Key Symptoms Present: {', '.join(key_symptoms) if key_symptoms else 'None reported'}
    • Prediction: {outcome}
    
    Reasoning:
    • Based on the symptom profile and patient demographics
    • The model indicates a {outcome.lower()} outcome
    • {"Multiple symptoms suggest medical attention may be needed" if prediction == 1 else "Symptoms appear mild based on current assessment"}
    
    Note: This is an automated assessment. Please consult healthcare professionals for proper medical advice.
    """
    
    return reasoning

Ollama LLM initialized successfully


In [12]:
try:
    from langgraph.graph import StateGraph
    
    # Build the graph
    builder = StateGraph(PatientState)
    builder.add_node("predict", predict_disease)
    builder.add_node("retrieve", retrieve_node)
    builder.add_node("reasoning", reasoning_node)
    
    # Add edges
    builder.add_edge("predict", "retrieve")
    builder.add_edge("retrieve", "reasoning")
    
    # Set entry point
    builder.set_entry_point("predict")
    app = builder.compile()
    
    print("LangGraph workflow created successfully")
    
except ImportError as e:
    print(f"LangGraph not available: {e}")
    print("Please install: pip install langgraph")
    app = None

LangGraph workflow created successfully


In [13]:
if app is not None:
    sample_state: PatientState = {
        "symptoms": {
            "Disease": 0,
            "Fever": 1,
            "Cough": 0,
            "Fatigue": 1,
            "Difficulty Breathing": 0,
            "Age": 45,
            "Gender": 1,
            "Blood Pressure": 140,
            "Cholesterol Level": 230
        },
        "prediction": None,
        "reasoning": None,
        "retrieved_cases": None
    }
    
    try:
        output = app.invoke(sample_state)
        print("\n" + "="*50)
        print("WORKFLOW RESULTS")
        print("="*50)
        print("Symptoms:", output["symptoms"])
        print("Prediction:", output["prediction"])
        print("Retrieved Cases:", len(output.get("retrieved_cases", [])))
        print("\nReasoning:")
        print(output["reasoning"])
    except Exception as e:
        print(f"Error running workflow: {e}")
else:

    print("\n" + "="*50)
    print("MANUAL TESTING (without LangGraph)")
    print("="*50)
    
    # Test each component manually
    test_state = {
        "symptoms": sample_symptoms,
        "prediction": None,
        "reasoning": None,
        "retrieved_cases": None
    }
    
    # Step 1: Predict
    test_state = predict_disease(test_state)
    print("After prediction:", test_state["prediction"])
    
    # Step 2: Retrieve
    test_state = retrieve_node(test_state)
    print("Retrieved cases:", len(test_state.get("retrieved_cases", [])))
    
    # Step 3: Reasoning
    test_state = reasoning_node(test_state)
    print("\nFinal Reasoning:")
    print(test_state["reasoning"])

Features passed to the model: {'Disease': 0, 'Fever': 1, 'Cough': 0, 'Fatigue': 1, 'Difficulty Breathing': 0, 'Age': 45, 'Gender': 1, 'Blood Pressure': 140, 'Cholesterol Level': 230}

WORKFLOW RESULTS
Symptoms: {'Disease': 0, 'Fever': 1, 'Cough': 0, 'Fatigue': 1, 'Difficulty Breathing': 0, 'Age': 45, 'Gender': 1, 'Blood Pressure': 140, 'Cholesterol Level': 230}
Prediction: 1
Retrieved Cases: 3

Reasoning:
**Medical Assessment 🏥**
• **Fever**: Present (1), indicating potential underlying condition 🔥
• **Cough**: Absent (0), reducing likelihood of respiratory infection 🤧
• **Fatigue**: Reported (1), suggesting possible underlying cause requiring medical attention 😴
• **Difficulty Breathing**: Not observed (0), reassuring but not eliminating other possibilities 👅
• **Age**: 45, within normal range for this profile 👵
• **Gender**: Female (1), typical for this demographic 🚺
• **Blood Pressure**: 140 mmHg, slightly elevated but not critically high 💧
• **Cholesterol Level**: 230 mg/dL, border

In [14]:
def save_prediction_to_db(symptoms_dict, prediction, reasoning, db_path="patient.db"):
    """Save the prediction results to database"""
    conn = sqlite3.connect(db_path)
    cursor = conn.cursor()
    
    symptoms_str = str(symptoms_dict)
    
    cursor.execute(
        "INSERT INTO patient_records (symptoms, prediction, reasoning) VALUES (?, ?, ?)",
        (symptoms_str, prediction, reasoning)
    )
    
    conn.commit()
    record_id = cursor.lastrowid
    conn.close()
    
    print(f"Record saved to database with ID: {record_id}")
    return record_id

# Test saving the prediction
if 'output' in locals() and output.get("prediction") is not None:
    save_prediction_to_db(
        output["symptoms"], 
        output["prediction"], 
        output["reasoning"]
    )
else:
    print("No prediction to save")

print("\nNotebook execution completed!")

Record saved to database with ID: 8

Notebook execution completed!
