In [3]:
from langchain_weaviate import WeaviateVectorStore
import weaviate
from langchain_openai import OpenAIEmbeddings
weaviate_client = weaviate.connect_to_local()
embeddings = OpenAIEmbeddings()
store = WeaviateVectorStore(
    client=weaviate_client,
    index_name="AxelleMedicalDocs",
    text_key="text",
    embedding=embeddings,
)

In [6]:
store.similarity_search("what is breast cancer?")[0].metadata

{'medical_condition': 'Breast Cancer',
 'app_link': '{"Segments": [{"Value": "0B890000-5643-0050-8818-08DC7E9CAFE4", "SegmentType": "MedicalCondition"}, {"Value": "0B890000-5643-0050-4C6C-08DC7E9CB010", "SegmentType": "TheScience"}, {"Value": "OverviewIndex", "SegmentType": "PageIndex"}]}',
 'references': ['https://www.cdc.gov/cancer/breast/basic_info/what-is-breast-cancer.htm',
  'https://www.mayoclinic.org/diseases-conditions/breast-cancer/symptoms-causes/syc-20352470'],
 'doc_id': UUID('43c61ab6-3b9b-5c54-805c-d284e86d4409')}

In [4]:
import sys
from pathlib import Path

# Add the project root to Python path
project_root = Path(__file__).parent.parent.parent if '__file__' in globals() else Path.cwd().parent.parent
sys.path.insert(0, str(project_root))

from graphs.axelle_ai.src.config import gpt_oss_20b, gpt_oss_120b, llama3p3_70b_versatile

print('oss20b: ' + (str)(gpt_oss_20b.invoke("What is malaria?").response_metadata['token_usage']['total_tokens']))
print('oss120b: ' + (str)(gpt_oss_120b.invoke("What is malaria?").response_metadata['token_usage']['total_tokens']))
print('llama3p3_70b_versatile: ' + (str)(llama3p3_70b_versatile.invoke("What is malaria?").response_metadata['token_usage']['total_tokens']))

oss20b: 591
oss120b: 1716
oss120b: 1716
llama3p3_70b_versatile: 497
llama3p3_70b_versatile: 497


In [5]:
SYSTEM_PROMPT = """You are an expert health assistant.

You have access to two tools:

- get_date: use this to get the current date
- internet_search: use this to search the internet for information you need
- get_user_location: use to get user's current location

If a user ask a question and you don't have enough context to answer the question, use the internet_search tool, prioritize recent information. get_date tool will help you know know the current date"""

from tavily import TavilyClient
import os
tavily_client = TavilyClient(api_key=os.environ["TAVILY_API_KEY"])

In [6]:
from dataclasses import dataclass
from langchain.tools import tool, ToolRuntime
from typing import Literal, Optional, Dict

# ---------- Trusted medical domains ----------
TRUSTED_DOMAINS = [
    "who.int", "cdc.gov", "nih.gov", "ncbi.nlm.nih.gov",
    "ema.europa.eu", "ema.eu", "fda.gov",
    "nice.org.uk", "bmj.com", "nejm.org", "thelancet.com",
    "mayoclinic.org", "clevelandclinic.org", "uptodate.com",
    "unicef.org", "unfpa.org"
]

@tool
def get_date():
    """Get the current date."""
    from datetime import datetime
    return datetime.now().strftime("%Y-%m-%d")

@tool
# Search tool to use to do research
def internet_search(
    query: str,
    max_results: int = 5,
    topic: Literal["general", "news", "finance" ] = "general",
    include_raw_content: bool = False,
):
    """Run a web search"""
    search_docs = tavily_client.search(
        query,
        max_results=max_results,
        include_raw_content=include_raw_content,
        topic=topic,
        #add trusted domains
        #trusted_domains=TRUSTED_DOMAINS,
    )
    return search_docs

@dataclass
class Context:
    """Custom runtime context schema."""
    user_id: str

@tool
def get_user_location(runtime: ToolRuntime[Context]) -> str:
    """Retrieve user information based on user ID."""
    user_id = runtime.context.user_id
    return "Ghana" if user_id == "1" else "Europe"

In [7]:
from dataclasses import dataclass

# We use a dataclass here, but Pydantic models are also supported.
@dataclass
class ResponseFormat:
    """Response schema for the agent."""
    # A punny response (always required)
    response: str
    # Any interesting information about the weather if available
    follow_up_questions: list[str] | None = None

In [8]:
from langgraph.checkpoint.memory import InMemorySaver

checkpointer = InMemorySaver()



In [10]:
from langchain.agents import create_agent
agent = create_agent(
    model=llama3p3_70b_versatile,
    system_prompt=SYSTEM_PROMPT,
    tools=[get_date, internet_search, get_user_location],
    context_schema=Context,
    #response_format=ResponseFormat,
    checkpointer=checkpointer
)

# `thread_id` is a unique identifier for a given conversation.
config = {
    "configurable": {"thread_id": "1"},
    "metadata": {"user_id": "user_123"}
    }

response = agent.invoke(
    {"messages": [{"role": "user", "content": "what is OCD"}]},
    config=config,
    context=Context(user_id="1")
)

print(response)

# Note that we can continue the conversation using the same `thread_id`.
response = agent.invoke(
    {"messages": [{"role": "user", "content": "thank you!"}]},
    config=config,
    context=Context(user_id="1")
)

print(response)

{'messages': [HumanMessage(content='what is OCD', additional_kwargs={}, response_metadata={}, id='359d5489-11c1-4820-bdf3-425c457b8a2d'), HumanMessage(content='what is OCD', additional_kwargs={}, response_metadata={}, id='43649bce-29df-41fc-b0e3-41adc93787bd'), AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'scw9q8q9d', 'function': {'arguments': '{"include_raw_content":false,"max_results":5,"query":"OCD definition","topic":"general"}', 'name': 'internet_search'}, 'type': 'function'}]}, response_metadata={'token_usage': {'completion_tokens': 36, 'prompt_tokens': 437, 'total_tokens': 473, 'completion_time': 0.106933264, 'completion_tokens_details': None, 'prompt_time': 0.063058801, 'prompt_tokens_details': None, 'queue_time': 0.008392739, 'total_time': 0.169992065}, 'model_name': 'llama-3.3-70b-versatile', 'system_fingerprint': 'fp_bebe2dd4fb', 'service_tier': 'on_demand', 'finish_reason': 'tool_calls', 'logprobs': None, 'model_provider': 'groq'}, id='lc_run--af29e804-f84

In [None]:
response = agent.invoke(
    {"messages": [{"role": "user", "content": "where am I located?"}]},
    config=config,
    context=Context(user_id="1")
)

In [None]:
from langchain.messages import HumanMessage, AIMessage

#reverse order
state = agent.get_state(config)
for msg in state.values['messages']:
    if msg.type == "human":
        print(f"User: {msg.content}")
    elif msg.type == "ai":
        if msg.content:
            print(f"AI: {msg.content}")
    elif msg.type == "tool":
        print(f"Tool: {msg.content}")

In [11]:
import uuid
from langgraph_sdk import get_client

client = get_client(url="http://localhost:2024")


# Or create a thread using the client
threads = await client.threads.search(
    metadata={"user_id": "1"}
)
if threads:
    thread_id = threads[0]["thread_id"]
else:
    thread = await client.threads.create(metadata={"user_id": "1"})
    thread_id = thread["thread_id"]

async for chunk in client.runs.stream(
    thread_id=thread_id,
    assistant_id="agent",
    input={"messages": [{"role": "user", "content": "What is OCD?"}]},
    context=Context(user_id="1")
):
    print(chunk.data)

{'run_id': 'f3a1803b-47f6-4ccf-b923-e5c372a92d94', 'attempt': 1}
{'messages': [{'content': 'What is OCD?', 'additional_kwargs': {}, 'response_metadata': {}, 'type': 'human', 'name': None, 'id': '36b28db6-953b-4f79-a404-a089e452a0f7'}, {'content': 'Obsessive-Compulsive Disorder (OCD) is a mental health condition characterized by persistent, unwanted thoughts (obsessions) and repetitive behaviors or mental acts (compulsions) that an individual feels compelled to perform. Here are some key points about OCD:\n\n1. **Obsessions**: These are intrusive and unwanted thoughts, images, or urges that cause significant anxiety or distress. Common obsessions include fears of contamination, harm, or the need for things to be symmetrical or in a particular order.\n\n2. **Compulsions**: To alleviate the anxiety brought on by obsessions, individuals with OCD engage in compulsive behaviors. These might include excessive cleaning, checking things repeatedly, counting, or arranging items in a specific way

In [None]:
threads = await client.threads.search()

print((threads[0]))

In [None]:
threads = await client.threads.search(
    metadata={"user_id": "1"}
)

print((threads))

In [None]:
thread = await client.threads.get(
    
    thread_id=threads[0]['thread_id']
)

print(thread)

In [None]:
threads[0]['thread_id']

In [3]:
# from graphs.axelle_ai.src.utils import update_thread_title, generate_title

# await update_thread_title(
#     thread_id='8daa627e-c4cc-47d4-a401-71feed8d8a01',
#     messages=[{"role": "user", "content": "What is OCD?"}]
# )
from langgraph_sdk import get_client

client = get_client(url="http://localhost:2024")
await client.threads.update(thread_id='8daa627e-c4cc-47d4-a401-71feed8d8a01', metadata={"title": "OCD Info"})

{'thread_id': '8daa627e-c4cc-47d4-a401-71feed8d8a01',
 'status': 'idle',
 'metadata': {'owner': 'anonymous',
  'title': 'OCD Info',
  'user_id': '1',
  'graph_id': 'agent',
  'thread_name': '',
  'assistant_id': 'fe096781-5601-53d2-b2f6-0d3403f7e9ca'},
 'user_id': 'anonymous',
 'created_at': '2025-12-11T15:47:44.135447Z'}

In [None]:
from graphs.axelle_ai.src.utils import update_thread_title, create_thread, get_thread
from langgraph_sdk.schema import Thread

async def chat_with_auto_title(user_id: str, message: str, thread_id: str|None = None):
    print(f"Chatting with user {user_id} in thread {thread_id}...")
    if not thread_id:
        thread : Thread = await create_thread(user_id)
        thread_id = thread["thread_id"]
        print(f"Created new thread with ID {thread_id}")
    else:
        thread: Thread = await get_thread(thread_id)
    # Run the agent with proper message format
    input_data = {"messages": [{"role": "user", "content": message}]}
    
    print(f"Input data: {input_data}")
    # Check metadata directly from run result
    if not thread.get("metadata", {}).get("title"):
        print("No title found, generating...")
        # Generate and update title (await it!)
        await update_thread_title(thread_id, input_data["messages"])
    async for chunk in client.runs.stream(
        thread_id=thread_id,
        assistant_id="agent",
        input=input_data,
        context=Context(user_id="1"),
        stream_mode='messages'
    ):
        if type(chunk.data) == list and chunk.data[0].get("type") == "ai":
            print(chunk.data[0].get("content"))
    
    # return run

result = await chat_with_auto_title("1", "What is OCD?")

In [None]:
thread = await client.threads.search()
thread = await client.threads.get(thread_id=thread[0]['thread_id'])

print(thread.get('metadata').get('title'))