### 🧠 Context Memory and Retrieval (Embeddings + FAISS)
##### This notebook adds memory and retrieval capabilities to your autonomous agent system.
##### Features:
- Session memory: Track past outputs
- Embedding-based retrieval: Fetch relevant context dynamically
- Optional persistence

In [1]:
# 📦 Imports
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
from typing import List, Dict
import uuid
import os
import pickle

In [2]:
# ✅ Configurable Variables
EMBEDDING_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"
EMBEDDING_DIM = 384  # Depends on the model used
FAISS_STORE_PATH = "/Users/sunnyraj/code_files/git_repos/MedAgenticSage/data/memory_store"

# 🔍 Load embedding model
embedding_model = SentenceTransformer(EMBEDDING_MODEL_NAME)

# 🧠 Initialize FAISS index and memory
index = faiss.IndexFlatL2(embedding_model.get_sentence_embedding_dimension())
documents: List[str] = []
metadata: List[Dict] = []


In [24]:
# 📥 Add to memory + FAISS index

def add_to_memory(text, source: str, tags: Dict = {}):
    if hasattr(text, "content"):  # Check if it's a LangChain Message
        text = text.content
    vector = embedding_model.encode([text])
    index.add(np.array(vector).astype("float32"))
    documents.append(text)
    metadata.append({"source": source, "tags": tags})


# 🔎 Retrieve similar context
def retrieve_context(query: str, k: int = 3) -> List[Dict]:
    query_vec = embedding_model.encode([query])
    D, I = index.search(np.array(query_vec).astype("float32"), k)
    results = []
    for i in I[0]:
        if i < len(documents):
            results.append({
                "text": documents[i],
                "metadata": metadata[i]
            })
    return results

# 💾 Save / Load memory

def save_memory(path=FAISS_STORE_PATH):
    os.makedirs(path, exist_ok=True)
    faiss.write_index(index, os.path.join(path, "index.faiss"))
    with open(os.path.join(path, "meta.pkl"), "wb") as f:
        pickle.dump((documents, metadata), f)


def load_memory(path=FAISS_STORE_PATH):
    global index, documents, metadata
    index = faiss.read_index(os.path.join(path, "index.faiss"))
    with open(os.path.join(path, "meta.pkl"), "rb") as f:
        documents, metadata = pickle.load(f)

In [25]:
from configs import models, env
from backend.llm.api import load_llm_langchain

config_loaded = {"model_config": models, "env": env}
llm = load_llm_langchain(source='groq', model_name='LLaMA-3', config=config_loaded)

[LLM Loader] Successfully initialized model 'llama-3.1-8b-instant' from 'groq'.


In [32]:
# ✨ Agent Node Wrappers with Memory Integration
from backend.agents.api import (
    run_symptom_checker,
    run_ehr_summarizer,
    run_literature_qa,
    run_drug_interactions,
    run_treatment_plan,
)

def log_keys(state, node_name):
    print(f"🔑 Keys in state at {node_name}: {list(state.keys())}")


def symptom_node(state):
    log_keys(state, "symptom_node")
    symptoms = state.get("symptoms")
    diagnosis = run_symptom_checker(symptoms, llm)
    state["diagnosis"] = diagnosis
    add_to_memory(diagnosis, source="diagnosis")
    return state

def ehr_node(state):
    log_keys(state, "ehr_node")
    notes = state.get("ehr_text")
    summary = run_ehr_summarizer(notes, llm)
    state["summary"] = summary
    add_to_memory(summary, source="ehr_summary")
    return state

def literature_node(state):
    log_keys(state, "literature_node")
    question = state.get("question")
    retrieved = retrieve_context(question)
    context = "\n".join([r["text"] for r in retrieved])
    answer = run_literature_qa(question, llm, context)
    state["literature_answer"] = answer
    add_to_memory(answer, source="literature_qa")
    return state


def drug_node(state):
    log_keys(state, "drug_node")
    meds = state.get("medications")

    # print("🔍 DEBUG - Raw meds:", meds, "| Type:", type(meds))

    # Defensive conversion
    if isinstance(meds, str):
        meds = [m.strip() for m in meds.split(",") if m.strip()]
    elif not isinstance(meds, list):
        print("⚠️ Unexpected medications type. Defaulting to empty list.")
        meds = []

    patient_data = state.get("diagnosis", "")
    report = run_drug_interactions(meds, llm, patient_data)
    state["interaction_report"] = report
    add_to_memory(report, source="drug_checker")
    return state



def treatment_node(state):
    log_keys(state, "treatment_node")
    profile = state.get("patient_profile")
    plan = run_treatment_plan(profile, llm)
    state["treatment_plan"] = plan
    add_to_memory(plan, source="treatment_plan")
    return state

In [37]:
from langgraph.graph import StateGraph, END
from typing import TypedDict, Optional, List

# Step 1: Define the schema
class AgentState(TypedDict):
    symptoms: Optional[str]
    ehr_text: Optional[str]
    medications: Optional[List[str]]
    question: Optional[str]
    diagnosis: Optional[str]
    summary: Optional[str]
    answer: Optional[str]
    plan: Optional[str]
    drug_warnings: Optional[str]
    patient_profile: Optional[dict]

graph = StateGraph(state_schema=AgentState)

# Add nodes
graph.add_node("symptom_checker", symptom_node)
graph.add_node("ehr_summarizer", ehr_node)
graph.add_node("literature_qa", literature_node)
graph.add_node("drug_checker", drug_node)
graph.add_node("treatment_planner", treatment_node)

# Set flow
graph.set_entry_point("symptom_checker")
graph.add_edge("symptom_checker", "ehr_summarizer")
graph.add_edge("ehr_summarizer", "literature_qa")
graph.add_edge("literature_qa", "drug_checker")
graph.add_edge("drug_checker", "treatment_planner")
graph.add_edge("treatment_planner", END)

# ## 🚀 Run Graph with Enriched Inputs

app_graph = graph.compile()

In [None]:
initial_state = {
    "symptoms": "shortness of breath, chest pain, fatigue",
    "ehr_text": "Patient has a history of hypertension and presents with elevated troponins...",
    "question": "What is the current recommendation for NSTEMI management in elderly patients?",
    "medications": ["Aspirin", "Warfarin"],
    "patient_profile": {
        "diagnosis": "NSTEMI",
        "age": 72,
        "sex": "Male",
        "comorbidities": ["Hypertension", "Atrial Fibrillation"]
    }
}

final_state = app_graph.invoke(initial_state)

for key, value in final_state.items():
    print(f"\n📌 {key.upper()}:")
    if isinstance(value, str):
        print(value)
    elif hasattr(value, "content"):
        print(value.content)
    elif isinstance(value, list):
        # Print each item or summarize the list
        for i, item in enumerate(value, 1):
            print(f"  {i}. {item}")
    else:
        # fallback, print repr
        print(repr(value))
    print()


🔑 Keys in state at symptom_node: ['symptoms', 'ehr_text', 'medications', 'question', 'patient_profile']
🔑 Keys in state at ehr_node: ['symptoms', 'ehr_text', 'medications', 'question', 'diagnosis', 'patient_profile']
🔑 Keys in state at literature_node: ['symptoms', 'ehr_text', 'medications', 'question', 'diagnosis', 'summary', 'patient_profile']
🔑 Keys in state at drug_node: ['symptoms', 'ehr_text', 'medications', 'question', 'diagnosis', 'summary', 'patient_profile']
🔑 Keys in state at treatment_node: ['symptoms', 'ehr_text', 'medications', 'question', 'diagnosis', 'summary', 'patient_profile']

📌 SYMPTOMS:
shortness of breath, chest pain, fatigue


📌 EHR_TEXT:
Patient has a history of hypertension and presents with elevated troponins...


📌 MEDICATIONS:
  1. Aspirin
  2. Warfarin


📌 QUESTION:
What is the current recommendation for NSTEMI management in elderly patients?


📌 DIAGNOSIS:
**Ranked List of Likely Diagnoses:**

1. **Acute Coronary Syndrome (ACS)**: High likelihood due to c