In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig, pipeline

model_name = "deepseek-ai/deepseek-llm-7b-chat"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="auto")
model.generation_config = GenerationConfig.from_pretrained(model_name)
model.generation_config.pad_token_id = model.generation_config.eos_token_id

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [38]:
from langchain.agents import AgentExecutor, create_tool_calling_agent, Tool
from duckduckgo_search import DDGS
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.messages import AIMessage
from langchain_google_genai import ChatGoogleGenerativeAI
import re
import os
from dotenv import load_dotenv

In [3]:
from langchain.vectorstores import FAISS
from langchain.prompts import PromptTemplate
from langchain.llms import HuggingFacePipeline
from FlagEmbedding import FlagAutoModel

embed_model = FlagAutoModel.from_finetuned('BAAI/bge-base-en-v1.5',
                                      query_instruction_for_retrieval="Represent this sentence for searching relevant passages:",
                                      use_fp16=True)

In [4]:
def initalize_agent():
    load_dotenv("secrets.env")
    my_key = os.getenv("GEMINI_KEY")
    
    os.environ["GOOGLE_API_KEY"] = my_key
    llm = ChatGoogleGenerativeAI(model="gemini-1.5-flash", temperature=0.7)
    TRUSTED_SITES = [
        "cdc.gov", 
        "nih.gov",
        "drugs.com",
        "medlineplus.gov",
        "hopkinsmedicine.org",
        "who.int",
        "fda.gov",
        "health.harvard.edu"
    ]
    def medical_search(query: str, max_results: int = 5) -> str:
        filter_domains = " OR ".join([f"site:{domain}" for domain in TRUSTED_SITES])
        full_query = f"{query} {filter_domains}"
    
        results_text = []
        with DDGS() as ddgs:
            for r in ddgs.text(full_query, max_results=max_results):
                title = r.get("title", "")
                snippet = r.get("body", "")
                url = r.get("href", "")
                results_text.append(f"Title: {title}\nSnippet: {snippet}\nSource: {url}")
    
        return "\n\n".join(results_text) if results_text else "No trusted medical sources found."
        
    tools = [
        Tool(
            name="TrustedMedicalSearch",
            func=medical_search,
            description="Searches medical information from trusted sources like CDC, NIH, Mayo Clinic, and more."
        )]
    
    prompt = ChatPromptTemplate.from_messages(
        [("system", (
            "You are a medical assistant. You have access to a specialized web search tool "
            "that ONLY returns information from a predefined list of trustworthy sources, "
            "including reputable general, academic, and medical websites. "
            "Use this tool when you need to find current or specific information. "
            "When you provide an answer based on information from the search tool, "
            "If the tool returns 'No information found from trusted sources', state that clearly and do not invent information."
            )),
            ("user", "{input}"),
            AIMessage(content="Okay, I will use the trusted search tool if necessary to find the answer from reliable sources."),
            ("placeholder", "{agent_scratchpad}"),
        ]
    )
    agent = create_tool_calling_agent(llm, tools, prompt)
    agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True, handle_parsing_errors=True, return_intermediate_steps=True)
    return agent, agent_executor

In [5]:
def search_vectordb(user_query):
    torch.cuda.empty_cache()
    query = user_query
    embed_query = embed_model.encode([query])
    if "pregnant" in query.lower():
        category = 'specific population usage'
    elif "uses" in query.lower():
        category = 'general'
    elif "clinical" in query.lower():
        category = 'clinical data'
    elif "dose" in query.lower():
        category = 'dosage and administration'
    else:
        category = 'general'
    results = vector_store.similarity_search_by_vector(
        embed_query[0], k=1, filter={"category": category})
    cleaned_context = [doc.page_content for doc in results]
    return query, cleaned_context

In [6]:
def generate_deep_seek(query, cleaned_context):
  torch.cuda.empty_cache()
  prompt_template = """
  You are an expert medical assistant explaining medicine to someone without any prior knowledge.

  Use ***ONLY*** the provided context to answer the user's question. Respond only with exact information found in the provided text.
  If the context does not contain the answer, say "The context does not provide enough information."

  # Context: {context}

  # Question: {question}

  Answer:
  """
  prompt = PromptTemplate(template=prompt_template, input_variables=['context', 'question'])
  llm = HuggingFacePipeline(
      pipeline=pipeline("text-generation", model=model, tokenizer=tokenizer, max_new_tokens=300)
  )
  qa_chain = prompt | llm
  response = qa_chain.invoke({'context' : cleaned_context, 'question' : query})

  def extract_answer(text):
      # Ensure "Answer: " exists in the string
      if "Answer:" in text:
          return text.split("Answer:", 1)[1]

  result = extract_answer(response)
  return result, cleaned_context

In [17]:
def self_eval(query, response):
    torch.cuda.empty_cache()
    gem_llm = ChatGoogleGenerativeAI(model="gemini-1.5-flash", temperature=0.7)
    prompt = PromptTemplate.from_template("""
    Given the following question and answer, determine if the answer is sufficient. An answer is sufficient if it provides ample information related to the question.
    
    An answer is insufficient if it doesn't contain relevant information. Additionally, phrases like "the provided context does not contain" or "the context does not provide" are common for insufficient answers
    
    Question: {question}
    Answer: {answer}
    
    Respond only with "Sufficient" or "Insufficient".
    """)
    
    question = query
    answer = response
    
    sufficiency_chain = prompt | gem_llm
    
    sufficiency_result = sufficiency_chain.invoke({
        "question": question,
        "answer": answer
    })
    return sufficiency_result.content

In [26]:
def generate_with_agent(user_query, agent_executor):
    query, cleaned_context = search_vectordb(user_query)
    response, context = generate_deep_seek(query, cleaned_context)
    sufficiency = self_eval(query, response)
    if 'Insufficient' in sufficiency:
        agent_response = agent_executor.invoke({"input": user_query})
        summary = agent_response['output']
        scraped_data = agent_response['intermediate_steps']
        return summary, scraped_data
    else:
        return response, context

def generate(user_query):
    query, cleaned_context = search_vectordb(user_query)
    response, context = generate_deep_seek(query, cleaned_context)
    return response, context

In [9]:
vector_store = FAISS.load_local("my_vector_store", embeddings=embed_model, allow_dangerous_deserialization=True)

`embedding_function` is expected to be an Embeddings object, support for passing in a function will soon be removed.


In [10]:
agent, agent_executor = initalize_agent()

In [39]:
result, context = generate_with_agent("What are the side effects of ozempic?", agent_executor)

Device set to use cuda:0




[1m> Entering new AgentExecutor chain...[0m
[32;1m[1;3m
Invoking: `TrustedMedicalSearch` with `side effects of Ozempic`


[0m

DuckDuckGoSearchException: https://html.duckduckgo.com/html 202 Ratelimit

In [29]:
result



In [30]:
context

[(ToolAgentAction(tool='TrustedMedicalSearch', tool_input='ozempic', log='\nInvoking: `TrustedMedicalSearch` with `ozempic`\n\n\n', message_log=[AIMessageChunk(content='', additional_kwargs={'function_call': {'name': 'TrustedMedicalSearch', 'arguments': '{"__arg1": "ozempic"}'}}, response_metadata={'finish_reason': 'STOP', 'safety_ratings': []}, id='run--3e19f339-fbc8-4649-819c-4a691e78e172', tool_calls=[{'name': 'TrustedMedicalSearch', 'args': {'__arg1': 'ozempic'}, 'id': 'a8e6cc85-8f67-47bb-9e73-2e293d9c731c', 'type': 'tool_call'}], usage_metadata={'input_tokens': 138, 'output_tokens': 9, 'total_tokens': 147, 'input_token_details': {'cache_read': 0}}, tool_call_chunks=[{'name': 'TrustedMedicalSearch', 'args': '{"__arg1": "ozempic"}', 'id': 'a8e6cc85-8f67-47bb-9e73-2e293d9c731c', 'index': None, 'type': 'tool_call_chunk'}])], tool_call_id='a8e6cc85-8f67-47bb-9e73-2e293d9c731c'),

In [22]:
result, context = generate("What is ozempic?")
result

Device set to use cuda:0


'\n  \n  The provided context does not contain information about "Ozempic." It is possible that Ozempic is not mentioned in the context or the context does not provide enough information about it.'

In [23]:
context

['Summary of AMOXICILLIN | Uses: Adults and Pediatric Patients Upper Respiratory Tract Infections of the Ear, Nose, and Throat: Amoxicillin for Oral Suspension, USP is indicated in the treatment of infections due to susceptible (ONLY β-lactamase–negative) isolates of Streptococcus species. (α-and β-hemolytic isolates only), Streptococcus pneumoniae , Staphylococcus spp., or Haemophilus influenzae . Infections of the Genitourinary Tract: Amoxicillin for Oral Suspension, USP is indicated in the treatment of infections due to susceptible (ONLY β-lactamase–negative) isolates of Escherichia coli , Proteus mirabilis , or Enterococcus faecalis . Infections of the Skin and Skin Structure: Amoxicillin for Oral Suspension, USP is indicated in the treatment of infections due to susceptible (ONLY β-lactamase–negative) isolates of Streptococcus spp. (α-and β-hemolytic isolates only), Staphylococcus spp., or E. coli . Infections of the Lower Respiratory Tract: Amoxicillin for Oral Suspension, USP is