# **Set up Environment**

In [None]:
%%bash 
pip install -qU langchain-nomic
pip install -U langchain-chroma
pip install langchain-openai
pip install --upgrade --quiet langchain-community gpt4all
pip install --upgrade --quiet  gpt4all > /dev/null

In [None]:
!pip install -U langchain-nomic langchain_community tiktoken langchainhub chromadb langchain langgraph tavily-python gpt4all

In [None]:
%pip install --upgrade --quiet langchain-nvidia-ai-endpoints

In [None]:
!pip install langchain_openai

In [None]:
%pip install --upgrade openai

In [None]:
!pip install gradio

In [None]:
!pip install sentence-transformers

In [104]:
import os
from dotenv import load_dotenv

# 載入 .env 檔案
load_dotenv()

# 取得 API Key
api_key = os.getenv("NVIDIA_API_KEY")


In [200]:
import os
import pandas as pd
import openai
from langchain.schema import Document
from langchain.memory import ConversationBufferMemory
from typing_extensions import TypedDict
from typing import List
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import GPT4AllEmbeddings
# from langchain.embeddings import OpenAIEmbeddings
from langchain_core.prompts import ChatPromptTemplate
from langchain_openai import ChatOpenAI, OpenAIEmbeddings
from langchain_core.output_parsers import JsonOutputParser, StrOutputParser
from langchain_community.utilities import GoogleSerperAPIWrapper
from langchain.prompts import PromptTemplate
from langgraph.graph import END, StateGraph
from langchain_nvidia_ai_endpoints import ChatNVIDIA
import gradio as gr
import time
from langchain_community.embeddings import GPT4AllEmbeddings

# Initialize GPT4AllEmbeddings
model_name = "all-MiniLM-L6-v2.gguf2.f16.gguf"
gpt4all_kwargs = {'allow_download': 'True'}
GPT4AllEmbeddings(
    model_name=model_name,
    gpt4all_kwargs=gpt4all_kwargs
)


GPT4AllEmbeddings(model_name='all-MiniLM-L6-v2.gguf2.f16.gguf', n_threads=None, device='cpu', gpt4all_kwargs={'allow_download': 'True'}, client=<gpt4all.gpt4all.Embed4All object at 0x7f06951858d0>)

In [383]:
question = "Are there specific formulations of Simvastatin designed for patients with genetic risks?"

# **Set Up Agents and Tools**

## Query rewriter

In [384]:
# VectorStoreRewriter

model_id = "meta/llama-3.3-70b-instruct"

# LLM
llm = ChatNVIDIA(model=model_id, temperature=0)

prompt_vs = PromptTemplate(
    template="""
You are a vectorstore query rewriter, and you'll receive a question from the patient and a list of tags and drug names. You'll have to extract the drug name and possibly useful tags from the lists.
- Comprehend the question and extract "one drug" from the question
- Comprehend the question and select the tags from the list and store them in the string
- Return the drug: string of tags pair only, without any explanation

## drug names

Etanercept
Alteplase
Darbepoetin alfa
Goserelin
Pegfilgrastim
Asparaginase Escherichia coli
Desmopressin
Glucagon
Insulin glargine
Rasburicase
Adalimumab
Pegaspargase
Infliximab
Trastuzumab
Rituximab
Streptokinase
Filgrastim
Coagulation Factor IX (Recombinant)
Octreotide
Oxytocin
Bevacizumab
Ascorbic acid
Calcitriol
Riboflavin
Thiamine
Ergocalciferol
Folic acid
Pyridoxine
Fluvoxamine
Ramipril
Flunisolide
Lorazepam
Bortezomib
Carbidopa
Fluconazole
Oseltamivir
Erythromycin
Hydroxocobalamin
Pyrimethamine
Azithromycin
Torasemide
Citalopram
Moxifloxacin
Nevirapine
Cladribine
Mesalazine
Cabergoline
Dapsone
Phenytoin
Doxycycline
Clotrimazole
Cycloserine
Metoprolol
Lidocaine
Bleomycin
Chlorambucil
Morphine
Bupivacaine
Tenofovir disoproxil
Tranexamic acid
Chlorthalidone
Valproic acid
Acetaminophen
Gefitinib
Codeine
Piperacillin
Amitriptyline
Hydromorphone
Ethambutol
Metformin
Methadone
Olanzapine
Atenolol
Omeprazole
Pyrazinamide
Cetirizine
Tioguanine
Methylergometrine
Mefloquine
Sulfadiazine
Vinorelbine
Anidulafungin
Clozapine
Levonorgestrel
Timolol
Trihexyphenidyl
Palonosetron
Amlodipine
Carbimazole
Digoxin
Zoledronic acid
Griseofulvin
Mupirocin
Ampicillin
Phenoxymethylpenicillin
Spironolactone
Allopurinol
Ceftazidime
Trimethoprim
Gemcitabine
Entecavir
Betamethasone
Chloramphenicol
Levothyroxine
Loratadine
Quinine
Fluoxetine
Chlorpromazine
Amikacin
Lenalidomide
Cefotaxime
Zidovudine
Oxycodone
Flutamide
Haloperidol
Ritonavir
Vancomycin
Cisplatin
Albendazole
Caspofungin
Oxaliplatin
Erlotinib
Cyclophosphamide
Ciprofloxacin
Vincristine
Fluorouracil
Pyridostigmine
Propylthiouracil
Lamotrigine
Methotrexate
Carbamazepine
Vinblastine
Propranolol
Atropine
Valaciclovir
Lactulose
Voriconazole
Enalapril
Ethosuximide
Amiloride
Oxytetracycline
Thiopental
Linezolid
Ivermectin
Medroxyprogesterone acetate
Chloroquine
Ethionamide
Bisoprolol
Amodiaquine
Rifabutin
Imatinib
Fluphenazine
Testosterone
Efavirenz
Prednisone
Mebendazole
Nystatin
Magnesium sulfate
Latanoprost
Verapamil
Nilutamide
Epinephrine
Sumatriptan
Cefixime
Aprepitant
Tamoxifen
Benzyl benzoate
Losartan
Amphotericin B
Warfarin
Midazolam
Tobramycin
Fludrocortisone
Fluorescein
Daunorubicin
Furosemide
Nitrofurantoin
Naltrexone
Lamivudine
Diethylcarbamazine
Apomorphine
Paroxetine
Norethisterone
Lisinopril
Risperidone
Pentamidine
Hydrocortisone
Mannitol
Deferoxamine
Dolasetron
Clopidogrel
Tetracycline
Meropenem
Potassium chloride
Irinotecan
Methimazole
Mometasone
Clavulanic acid
Etoposide
Sulfasalazine
Gentamicin
Colistin
Indapamide
Tropicamide
Biperiden
Ribavirin
Fentanyl
Propofol
Acetazolamide
Natamycin
Fosfomycin
Diazepam
Mifepristone
Loperamide
Clofazimine
Levamisole
Dacarbazine
Terbinafine
Penicillamine
Prednisolone
Ranitidine
Tacrolimus
Terbutaline
Chlorhexidine
Emtricitabine
Chlorothiazide
Clomifene
Isosorbide dinitrate
Bumetanide
Granisetron
Ondansetron
Tinidazole
Metronidazole
Spectinomycin
Buprenorphine
Misoprostol
Salicylic acid
Salmeterol
Acetylsalicylic acid
Fexofenadine
Isoniazid
Netilmicin
Carboplatin
Methylprednisolone
Telmisartan
Methyldopa
Dactinomycin
Selenium Sulfide
Ethinylestradiol
Cyclopentolate
Formoterol
Glycopyrronium
Cytarabine
Dopamine
Azathioprine
Doxorubicin
Hydrochlorothiazide
Salbutamol
Hydroxyurea
Letrozole
Sulfamethoxazole
Mercaptopurine
Thalidomide
Melphalan
Rifampicin
Abacavir
Ibuprofen
Benzylpenicillin
Praziquantel
Amoxicillin
Fludarabine
Streptomycin
Pilocarpine
Primaquine
Oxamniquine
Flucytosine
Capecitabine
Sertraline
Miconazole
Cefuroxime
Nifedipine
Amiodarone
Diazoxide
Gliclazide
Bicalutamide
Proguanil
Carvedilol
Levofloxacin
Micafungin
Cloxacillin
Bupropion
Halothane
Ofloxacin
Itraconazole
Procarbazine
Arsenic trioxide
Kanamycin
Phenobarbital
Escitalopram
Cyclizine
Ifosfamide
Naloxone
Clindamycin
Bromocriptine
Rifapentine
Levetiracetam
Clarithromycin
Ceftriaxone
Anastrozole
Ketamine
Budesonide
Quetiapine
Enoxaparin
Paclitaxel
Metoclopramide
Dexamethasone
Levodopa
Sevoflurane
Aripiprazole
Clomipramine
Docetaxel
Ergometrine
Dasatinib
Darunavir
Paliperidone
Varenicline
Hydralazine
Carbetocin
Sulfadoxine
Insulin detemir
Cefazolin
Vecuronium
Iohexol
Calcium
Neostigmine
Tiotropium
Ciclesonide
Paromomycin
Everolimus
Cilastatin
Imipenem
Lopinavir
Tazobactam
Deferasirox
Valganciclovir
Hydroxychloroquine
Calcipotriol
Nicotinamide
Acetic acid
Glutaral
Nilotinib
Permethrin
Pretomanid
Silver sulfadiazine
Iodine
Liposomal prostaglandin E1
Sodium stibogluconate
Abiraterone
Acetylcysteine
Rivaroxaban
Eflornithine
Dapagliflozin
Apixaban
Golimumab
Nitrous oxide
Xylometazoline
Artemether
Lumefantrine
Potassium Iodide
Bendamustine
Dalteparin
Dimercaprol
Niclosamide
Raltegravir
Triptorelin
Diloxanide
Nadroparin
Deferiprone
Ulipristal
Asparaginase Erwinia chrysanthemi
Aclidinium
Enzalutamide
Bedaquiline
Certolizumab pegol
Fluticasone furoate
Canagliflozin
Afatinib
Dolutegravir
Sofosbuvir
Bisacodyl
Ledipasvir
Miltefosine
Nivolumab
Pembrolizumab
Empagliflozin
Tedizolid phosphate
Ceftolozane
Ibrutinib
Avibactam
Edoxaban
Umeclidinium
Tetracaine
Chlortetracycline
Benzoyl peroxide
Daclatasvir
Methoxy polyethylene glycol-epoetin beta
Oxygen
Protamine sulfate
Sodium chloride
Artesunate
Activated charcoal
Procaine benzylpenicillin
Zinc sulfate
Insulin degludec
Rotavirus vaccine
Yellow fever vaccine
Hepatitis A Vaccine
Typhoid Vaccine Live
Coal tar
Chloroxylenol
Calcium gluconate
Barium sulfate
Pyrantel
Dexamethasone isonicotinate
Tuberculin purified protein derivative
Velpatasvir
Hepatitis B Vaccine (Recombinant)
Delamanid
Tropisetron
Nifurtimox
Benznidazole
Vaborbactam
Triclabendazole
Fexinidazole
Plazomicin
Protionamide
BCG vaccine
Benserazide
Melarsoprol
Terizidone
Atracurium
Tacalcitol
Meglumine antimoniate
Potassium permanganate
Fluticasone
Pibrentasvir
Glecaprevir
Estradiol cypionate
Typhoid vaccine
Lithium carbonate
Hydrocortisone aceponate
Dabigatran
Polymyxin B
Cefiderocol
Pertussis vaccine
Tick-borne encephalitis vaccine (whole virus, inactivated)
Ravidasvir
Senna leaf
Maftivimab
Odesivimab
Ansuvimab
Hepatitis A vaccine (live, attenuated)
Japanese Encephalitis Vaccine, Inactivated, Adsorbed
Japanese encephalitis vaccine (live, attenuated)
Pravastatin
Lovastatin
Simvastatin
Atorvastatin
Fluvastatin
Rosuvastatin
Pitavastatin
Mevastatin
Tenivastatin
Cerivastatin

## tag list

drug
targets
target
polypeptide
go-classifiers
go-classifier
description
category
pfams
pfam
name
identifier
organism
locus
amino-acid-sequence
external-identifiers
external-identifier
resource
synonyms
synonym
signal-regions
theoretical-pi
chromosome-location
general-function
specific-function
molecular-weight
gene-name
gene-sequence
transmembrane-regions
cellular-location
references
articles
article
citation
ref-id
pubmed-id
attachments
attachment
title
url
textbooks
textbook
isbn
links
link
actions
action
id
name
known-action
carriers
carrier
reactions
reaction
enzymes
enzyme
uniprot-id
drugbank-id
right-element
left-element
snp-adverse-drug-reactions
gene-symbol
protein-name
allele
adverse-reaction
rs-id
enzymes
induction-strength
inhibition-strength
transporters
transporter
general-references
products
product
ended-marketing-on
ndc-product-code
ema-ma-number
approved
dpd-id
over-the-counter
route
ndc-id
generic
country
ema-product-code
started-marketing-on
source
strength
fda-application-number
labeller
dosage-form
patents
patent
pediatric-extension
expires
number
calculated-properties
property
value
kind
pathways
pathway
drugs
category
smpdb-id
sequences
sequence
snp-effects
effect
defining-change
dosages
dosage
form
protein-binding
drug-interactions
drug-interaction
affected-organisms
ahfs-codes
packagers
packager
salts
salt
monoisotopic-mass
cas-number
unii
average-mass
inchikey
synthesis-reference
mixtures
mixture
ingredients
prices
price
cost
unit
international-brands
international-brand
company
fda-label
clearance
external-links
external-link
classification
alternative-parent
class
substituent
superclass
direct-parent
kingdom
subclass
average-mass
toxicity
food-interactions
food-interaction
groups
group
categories
category
mesh-id
state
experimental-properties
pharmacodynamics
monoisotopic-mass
manufacturers
manufacturer
unii
cas-number
indication
atc-codes
atc-code
level
mechanism-of-action
volume-of-distribution
pdb-entries
pdb-entry
absorption
metabolism
half-life
msds
route-of-elimination

Here is the user question: {question}
""",
    input_variables=["question"],
)

query_rewriter_vs = prompt_vs | llm | StrOutputParser()


# Start the timer
start_time = time.time()
query_vs = query_rewriter_vs.invoke({"question": question})
end_time = time.time()

# Calculate elapsed time
elapsed_time = end_time - start_time
print(f"The code took {elapsed_time:.2f} seconds to execute.")
print(query_vs)

The code took 0.83 seconds to execute.
Simvastatin: genetic-risks, formulations, drug-interactions, pharmacodynamics, mechanism-of-action


In [385]:
# GraphRAGRewriter

model_id = "meta/llama-3.3-70b-instruct"

# LLM
llm = ChatNVIDIA(model=model_id, temperature=0)

prompt_graph = PromptTemplate(
    template=
    """you are a graph retriever, and you all be given a user question. please determine which drugs are related to this question, and use the following template to query the graph.
cypher template:
MATCH (target)-[r]-(neighbor)
WHERE target.id = 'drug name 1' AND neighbor.id = 'drug name 2' 
RETURN neighbor, r
LIMIT 1;

drug list:
Etanercept, Alteplase, Darbepoetin alfa, Goserelin, Pegfilgrastim, Asparaginase Escherichia coli, Desmopressin, Glucagon, Insulin glargine, Rasburicase, Adalimumab, Pegaspargase, Infliximab, Trastuzumab, Rituximab, Streptokinase, Filgrastim, Coagulation Factor IX (Recombinant), Octreotide, Oxytocin, Bevacizumab, Ascorbic acid, Calcitriol, Riboflavin, Thiamine, Ergocalciferol, Folic acid, Pyridoxine, Fluvoxamine, Ramipril, Flunisolide, Lorazepam, Bortezomib, Carbidopa, Fluconazole, Oseltamivir, Erythromycin, Hydroxocobalamin, Pyrimethamine, Azithromycin, Torasemide, Citalopram, Moxifloxacin, Nevirapine, Cladribine, Mesalazine, Cabergoline, Dapsone, Phenytoin, Doxycycline, Clotrimazole, Cycloserine, Metoprolol, Lidocaine, Bleomycin, Chlorambucil, Morphine, Bupivacaine, Tenofovir disoproxil, Tranexamic acid, Chlorthalidone, Valproic acid, Acetaminophen, Gefitinib, Codeine, Piperacillin, Amitriptyline, Hydromorphone, Ethambutol, Metformin, Methadone, Olanzapine, Atenolol, Omeprazole, Pyrazinamide, Cetirizine, Tioguanine, Methylergometrine, Mefloquine, Sulfadiazine, Vinorelbine, Anidulafungin, Clozapine, Levonorgestrel, Timolol, Trihexyphenidyl, Palonosetron, Amlodipine, Carbimazole, Digoxin, Zoledronic acid, Griseofulvin, Mupirocin, Ampicillin, Phenoxymethylpenicillin, Spironolactone, Allopurinol, Ceftazidime, Trimethoprim, Gemcitabine, Entecavir, Betamethasone, Chloramphenicol, Levothyroxine, Loratadine, Quinine, Fluoxetine, Chlorpromazine, Amikacin, Lenalidomide, Cefotaxime, Zidovudine, Oxycodone, Flutamide, Haloperidol, Ritonavir, Vancomycin, Cisplatin, Albendazole, Caspofungin, Oxaliplatin, Erlotinib, Cyclophosphamide, Ciprofloxacin, Vincristine, Fluorouracil, Pyridostigmine, Propylthiouracil, Lamotrigine, Methotrexate, Carbamazepine, Vinblastine, Propranolol, Atropine, Valaciclovir, Lactulose, Voriconazole, Enalapril, Ethosuximide, Amiloride, Oxytetracycline, Thiopental, Linezolid, Ivermectin, Medroxyprogesterone acetate, Chloroquine, Ethionamide, Bisoprolol, Amodiaquine, Rifabutin, Imatinib, Fluphenazine, Testosterone, Efavirenz, Prednisone, Mebendazole, Nystatin, Magnesium sulfate, Latanoprost, Verapamil, Nilutamide, Epinephrine, Sumatriptan, Cefixime, Aprepitant, Tamoxifen, Benzyl benzoate, Losartan, Amphotericin B, Warfarin, Midazolam, Tobramycin, Fludrocortisone, Fluorescein, Daunorubicin, Furosemide, Nitrofurantoin, Naltrexone, Lamivudine, Diethylcarbamazine, Apomorphine, Paroxetine, Norethisterone, Lisinopril, Risperidone, Pentamidine, Hydrocortisone, Mannitol, Deferoxamine, Dolasetron, Clopidogrel, Tetracycline, Meropenem, Potassium chloride, Irinotecan, Methimazole, Mometasone, Clavulanic acid, Etoposide, Sulfasalazine, Gentamicin, Colistin, Indapamide, Tropicamide, Biperiden, Ribavirin, Fentanyl, Propofol, Acetazolamide, Natamycin, Fosfomycin, Diazepam, Mifepristone, Loperamide, Clofazimine, Levamisole, Dacarbazine, Terbinafine, Penicillamine, Prednisolone, Ranitidine, Tacrolimus, Terbutaline, Chlorhexidine, Emtricitabine, Chlorothiazide, Clomifene, Isosorbide dinitrate, Bumetanide, Granisetron, Ondansetron, Tinidazole, Metronidazole, Spectinomycin, Buprenorphine, Misoprostol, Salicylic acid, Salmeterol, Acetylsalicylic acid, Fexofenadine, Isoniazid, Netilmicin, Carboplatin, Methylprednisolone, Telmisartan, Methyldopa, Dactinomycin, Selenium Sulfide, Ethinylestradiol, Cyclopentolate, Formoterol, Glycopyrronium, Cytarabine, Dopamine, Azathioprine, Doxorubicin, Hydrochlorothiazide, Salbutamol, Hydroxyurea, Letrozole, Sulfamethoxazole, Mercaptopurine, Thalidomide, Melphalan, Rifampicin, Abacavir, Ibuprofen, Benzylpenicillin, Praziquantel, Amoxicillin, Fludarabine, Streptomycin, Pilocarpine, Primaquine, Oxamniquine, Flucytosine, Capecitabine, Sertraline, Miconazole, Cefuroxime, Nifedipine, Amiodarone, Diazoxide, Gliclazide, Bicalutamide, Proguanil, Carvedilol, Levofloxacin, Micafungin, Cloxacillin, Bupropion, Halothane, Ofloxacin, Itraconazole, Procarbazine, Arsenic trioxide, Kanamycin, Phenobarbital, Escitalopram, Cyclizine, Ifosfamide, Naloxone, Clindamycin, Bromocriptine, Rifapentine, Levetiracetam, Clarithromycin, Ceftriaxone, Anastrozole, Ketamine, Budesonide, Quetiapine, Enoxaparin, Paclitaxel, Metoclopramide, Dexamethasone, Levodopa, Sevoflurane, Aripiprazole, Clomipramine, Docetaxel, Ergometrine, Dasatinib, Darunavir, Paliperidone, Varenicline, Hydralazine, Carbetocin, Sulfadoxine, Insulin detemir, Cefazolin, Vecuronium, Iohexol, Calcium, Neostigmine, Tiotropium, Ciclesonide, Paromomycin, Everolimus, Cilastatin, Imipenem, Lopinavir, Tazobactam, Deferasirox, Valganciclovir, Hydroxychloroquine, Calcipotriol, Nicotinamide, Acetic acid, Glutaral, Nilotinib, Permethrin, Pretomanid, Silver sulfadiazine, Iodine, Liposomal prostaglandin E1, Sodium stibogluconate, Abiraterone, Acetylcysteine, Rivaroxaban, Eflornithine, Dapagliflozin, Apixaban, Golimumab, Nitrous oxide, Xylometazoline, Artemether, Lumefantrine, Potassium Iodide, Bendamustine, Dalteparin, Dimercaprol, Niclosamide, Raltegravir, Triptorelin, Diloxanide, Nadroparin, Deferiprone, Ulipristal, Asparaginase Erwinia chrysanthemi, Aclidinium, Enzalutamide, Bedaquiline, Certolizumab pegol, Fluticasone furoate, Canagliflozin, Afatinib, Dolutegravir, Sofosbuvir, Bisacodyl, Ledipasvir, Miltefosine, Nivolumab, Pembrolizumab, Empagliflozin, Tedizolid phosphate, Ceftolozane, Ibrutinib, Avibactam, Edoxaban, Umeclidinium, Tetracaine, Chlortetracycline, Benzoyl peroxide, Daclatasvir, Methoxy polyethylene glycol-epoetin beta, Oxygen, Protamine sulfate, Sodium chloride, Artesunate, Activated charcoal, Procaine benzylpenicillin, Zinc sulfate, Insulin degludec, Rotavirus vaccine, Yellow fever vaccine, Hepatitis A Vaccine, Typhoid Vaccine Live, Coal tar, Chloroxylenol, Calcium gluconate, Barium sulfate, Pyrantel, Dexamethasone isonicotinate, Tuberculin purified protein derivative, Velpatasvir, Hepatitis B Vaccine (Recombinant), Delamanid, Tropisetron, Nifurtimox, Benznidazole, Vaborbactam, Triclabendazole, Fexinidazole, Plazomicin, Protionamide, BCG vaccine, Benserazide, Melarsoprol, Terizidone, Atracurium, Tacalcitol, Meglumine antimoniate, Potassium permanganate, Fluticasone, Pibrentasvir, Glecaprevir, Estradiol cypionate, Typhoid vaccine, Lithium carbonate, Hydrocortisone aceponate, Dabigatran, Polymyxin B, Cefiderocol, Pertussis vaccine, "Tick-borne encephalitis vaccine (whole virus, inactivated)", Ravidasvir, Senna leaf, Maftivimab, Odesivimab, Ansuvimab, "Hepatitis A vaccine (live, attenuated)", "Japanese Encephalitis Vaccine, Inactivated, Adsorbed", "Japanese encephalitis vaccine (live, attenuated)"


here are the question from user: {question}

NOTICE THAT GIVE ME THE CYPHER QUERY ONLY. IF THERE ARE NO QUERY MATCH, RETRUN "" WITHOUT ANY EXPLANATION.PLEASE DO NOT USE THE DRUGS THAT NOT IN THE LIST.

IF THERE IS NOT ANY MATCH DRUG IN LIST, RETRUN  .

DO NOT ADD ```.

    """,
    input_variables=["question"],
    
)

query_rewriter = prompt_graph | llm | StrOutputParser()

graph_query = query_rewriter.invoke({"question": question})

print(graph_query)

MATCH (target)-[r]-(neighbor)
WHERE target.id = 'Simvastatin' AND neighbor.id = 'Atorvastatin' 
RETURN neighbor, r
LIMIT 1;


In [386]:
# SQLRAGRewriter

# LLM
llm = ChatNVIDIA(model=model_id, temperature=0)

prompt_sql = PromptTemplate(
    template=f"""
You are a pharmacist and data engineer, and the following are the cols and explanations in sqlite3 database. DO NOT MODIFIY THE COLUMN NAMES!

Table : drug
Columns:
drugid: Unique identifier for the drug within the CPIC database.
name: Name of the drug, typically the generic name.
pharmgkbid: Reference to the PharmGKB ID, a database identifier for pharmacogenomics information.
rxnormid: RxNorm identifier, a standardized nomenclature for clinical drugs by the National Library of Medicine.
drugbankid: Identifier for the drug in DrugBank, a bioinformatics and cheminformatics resource.
atcid: Anatomical Therapeutic Chemical (ATC) classification system code for the drug.
umlscui: Concept Unique Identifier from the Unified Medical Language System (UMLS).
flowchart: Link or reference to a CPIC guideline flowchart associated with the drug.
version: Version of the data or record for tracking updates.
guidelineid: Identifier for the CPIC guideline(s) associated with the drug.


Table : pair
Columns:
pairid: Unique identifier for the gene-drug pair within the CPIC database.
genesymbol: Gene symbol involved in the pharmacogenetic relationship (e.g., CYP2C19).
drugid: Reference to the associated drug from the drug table.
guidelineid: Identifier for the associated CPIC guideline.
usedforrecommendation: Boolean or flag indicating if the pair is used for specific clinical recommendations.
version: Version of the pair's record for tracking updates.
cpiclevel: CPIC level of evidence for the gene-drug pair (e.g., A, B, C).
pgkbcalevel: PharmGKB clinical annotation level of evidence for the gene-drug pair.
pgxtesting: Details or links about pharmacogenetic testing availability or methods.
citations: References to literature or data sources supporting the gene-drug pair information.
removed: Boolean or flag indicating whether the pair was removed from CPIC guidelines.
removeddate: Date the pair was removed from the guidelines.
removedreason: Reason for removal, such as updated evidence or redundancy.


Table : gene
Columns:
symbol: Standard symbol for the gene (e.g., CYP2D6).
chr: Chromosome where the gene is located.
genesequenceid: Identifier for the gene sequence, often referencing databases like GenBank.
proteinsequenceid: Identifier for the protein sequence produced by the gene.
chromosequenceid: Identifier for the chromosomal sequence where the gene resides.
mrnasequenceid: Identifier for the mRNA sequence of the gene.
hgncid: HGNC (HUGO Gene Nomenclature Committee) ID for the gene.
ncbiid: Identifier for the gene in NCBI’s Gene database.
ensemblid: Ensembl database identifier for the gene.
pharmgkbid: PharmGKB ID for the gene.
frequencymethods: Methods used to determine allele or phenotype frequencies.
lookupmethod: Methodology for identifying the gene in clinical or research settings.
version: Version of the gene's record for tracking updates.
notesondiplotype: Notes or annotations on the gene's diplotype (combination of two haplotypes).
url: Link to more information about the gene.
functionmethods: Methods used to assess gene or protein function.
notesonallelenaming: Notes or annotations on how alleles for the gene are named.
includephenotypefrequencies: Boolean or flag indicating if phenotype frequencies are included for the gene.
includediplotypefrequencies: Boolean or flag indicating if diplotype frequencies are included for the gene.


Table : allele
Columns:
id: Unique identifier for the allele.
version: Version of the allele record.
genesymbol: Symbol for the associated gene (e.g., CYP2D6).
name: Name of the allele.
functionalstatus: Functional status of the allele (e.g., normal, decreased, or increased function).
clinicalfunctionalstatus: Clinical interpretation of the allele's functional status.
clinicalfunctionalsubstrate: Specific substrate relevant to the allele's clinical functional status.
activityvalue: Activity score associated with the allele.
definitionid: Identifier linking to the allele definition.
citations: References or sources supporting the allele data.
strength: Strength of evidence for the allele data.
functioncomments: Comments or notes about the allele’s function.
findings: Observed findings related to the allele.
frequency: Population frequency of the allele.
inferredfrequency: Inferred frequency based on available data



Table : allele_definition
Columns:
id: Unique identifier for the allele definition.
version: Version of the allele definition record.
genesymbol: Symbol for the associated gene.
name: Name of the allele definition.
pharmvarid: Identifier in the PharmVar database.
matchesreferencesequence: Indicates whether the allele matches the reference sequence.
structuralvariation: Details about structural variations in the allele.
allele_frequency
alleleid: Identifier for the associated allele.
population: Population for which the frequency is reported.
frequency: Reported frequency of the allele in the population.
label: Label or description for the frequency data.
version: Version of the allele frequency record.
allele_location_value
alleledefinitionid: Identifier for the associated allele definition.
locationid: Identifier for the genomic location.
variantallele: Details about the variant allele at the location.
version: Version of the location value record.


Table : gene_result
Columns:
id: Unique identifier for the gene result record.
genesymbol: Symbol for the associated gene.
result: Reported result for the gene (e.g., genotype or phenotype).
activityscore: Activity score for the gene result.
ehrpriority: Priority level for integration into Electronic Health Records (EHR).
consultationtext: Text for clinical consultation based on the gene result.
version: Version of the gene result record.
frequency: Frequency of the result in the population.


Table : gene_result_diplotype
Columns:
id: Unique identifier for the gene result diplotype record.
functionphenotypeid: Identifier for the associated functional phenotype.
diplotype: Combination of haplotypes for a gene.
diplotypekey: Key used to identify the diplotype.
frequency: Frequency of the diplotype in the population.



Table : guideline
Columns:
id: Unique identifier for the guideline.
version: Version of the guideline record.
name: Name of the guideline.
url: Link to the guideline document.
pharmgkbid: PharmGKB identifier for the guideline.
genes: List of genes associated with the guideline.
notesonusage: Notes or comments on the guideline's usage.


Table : population
Columns:
id: Unique identifier for the population record.
publicationid: Identifier for the associated publication.
ethnicity: Ethnic group of the population.
population: Name of the population group.
populationinfo: Additional information about the population.
subjecttype: Type of subjects included in the study.
subjectcount: Number of subjects in the population.
version: Version of the population record.


Table : recommendation
Columns:
id: Unique identifier for the recommendation.
guidelineid: Identifier for the associated guideline.
drugid: Identifier for the drug associated with the recommendation.
implications: Clinical implications of the recommendation.
drugrecommendation: Specific drug recommendation.
classification: Classification of the recommendation.
phenotypes: Phenotypes relevant to the recommendation.
activityscore: Activity score associated with the recommendation.
allelestatus: Status of alleles related to the recommendation.
lookupkey: Key used for lookup in related databases.
population: Population for which the recommendation applies.
comments: Additional comments on the recommendation.
version: Version of the recommendation record.
dosinginformation: Specific dosing information.
alternatedrugavailable: Indicates if alternate drugs are available.
otherprescribingguidance: Additional guidance for prescribing.



Now, you will be given a user question. Based on the tables and columns provided above, please write an SQLite3 query to retrieve and select all the relative column(s) from the tables to answer the question. All drug names are in lowercase.
Before you generate the SQL, ENSURE THAT EACH COLUMN YOU USE IS IN THE CORRECT TABLE AND NOT FROM ANOTHER. DO NOT MODIFY THE COLUMN NAMES! THERE IS NO COLUMN NAMED 'cpilevel'; THE CORRECT COLUMN NAME IS 'cpiclevel'. DO NOT ADD ```sql``` IN THE SQL QUERY

notice that please give me the sql query ONLY first, then give the explanation for each column you select in format like table_name.column_name: explanation. Seperate these two part by '|', and do not provide any other text. DO NOT USE 'NOT NULL' TO SELECT ROWS
here are the question from user: {question}
""",
    input_variables=["question"],
)

query_rewriter = prompt_sql | llm | StrOutputParser()

sql_query = query_rewriter.invoke({"question": question})
sql_query = sql_query.replace("\n", "").replace("cpilevel", "cpiclevel")
print(sql_query)

SELECT T1.name, T2.genesymbol, T2.pgkbcalevel, T2.pgxtesting, T3.symbol, T3.functionmethods, T3.notesondiplotype FROM drug AS T1 INNER JOIN pair AS T2 ON T1.drugid = T2.drugid INNER JOIN gene AS T3 ON T2.genesymbol = T3.symbol WHERE T1.name = 'simvastatin' AND T2.usedforrecommendation = 1 | drug.name: Name of the drug, pair.genesymbol: Gene symbol involved in the pharmacogenetic relationship, pair.pgkbcalevel: PharmGKB clinical annotation level of evidence for the gene-drug pair, pair.pgxtesting: Details or links about pharmacogenetic testing availability or methods, gene.symbol: Standard symbol for the gene, gene.functionmethods: Methods used to assess gene or protein function, gene.notesondiplotype: Notes or annotations on the gene's diplotype (combination of two haplotypes)


# Need 3 RAG systems here ->

## Vectorstore

In [387]:
import pandas as pd
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.vectorstores import Chroma
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.docstore.document import Document
from langchain_openai import OpenAIEmbeddings

# Loading Embedding Function
model_name = "all-MiniLM-L6-v2.gguf2.f16.gguf"
gpt4all_kwargs = {'allow_download': 'True'}
GPT4AllEmbeddings(
    model_name=model_name,
    gpt4all_kwargs=gpt4all_kwargs,
)

# Loading Vectorstore（Persisted data used.）

persist_directory="./Trial_chroma_langchain"
collection_name = "Trial_v1"

vectorstore = Chroma(
    collection_name=collection_name,
    embedding_function=GPT4AllEmbeddings(
    model_name=model_name,
    gpt4all_kwargs=gpt4all_kwargs,
    ),
    persist_directory=persist_directory
)


In [388]:
def extract_drug_and_tag(input_string):
    try:
        # Split the input string by the colon
        drug, tag = input_string.split(":", 1)
        return {"drug": drug.strip(), "tag": tag.strip()}
    except ValueError:
        # Handle cases where the input is not in the expected format
        return {"error": "Input must be in the format 'A:B'"}

q_vs = extract_drug_and_tag(query_vs)
 # Output: {'drug': 'Simvastatin', 'tag': 'Relevance'}

# 現在可直接使用 vectorstore 的檢索功能
retriever = vectorstore.as_retriever(search_kwargs={"k": 10})


tag = q_vs['tag']
# tag = "lipitor"
# tag = "articles"
drug = q_vs['drug']


if len(drug) == 0:
    query = f"{tag} :"
    results = retriever.invoke(query)
    # print(len(results))
    print(results[0])
else:
    query = f"\"{tag}: information\""
    results = vectorstore.similarity_search(query, filter={"drug_name": f"{drug}"}, k=10)


## GraphRAG

In [None]:
%%bash
export JAVA_HOME=../GraphRAG/java/jdk-21.0.5
../GraphRAG/opt/neo4j-community-5.26.0/bin/neo4j-admin dbms set-initial-password password
../GraphRAG/opt/neo4j-community-5.26.0/bin/neo4j start

In [None]:
%pip install neo4j

In [389]:
from neo4j import GraphDatabase

uri = "bolt://localhost:7687"
username = "neo4j"
password = "password"

driver = GraphDatabase.driver(uri, auth=(username, password))

def query_database(query):
    with driver.session() as session:
        result = session.run(query)
        res = []
        for record in result:
            res.append(record)
        return res

if (len(graph_query) != 0):
    graph_result = query_database(graph_query)
driver.close()

if (len(graph_query) != 0 and len(graph_result) != 0):
    graph_result = graph_result[0]["r"]["description"]
    print(graph_result)
else:
    graph_result = ""

Simvastatin may decrease the excretion rate of Atorvastatin which could result in a higher serum level.


## TabularRAG

In [390]:
import sqlite3

conn = sqlite3.connect('drug.db')

cursor = conn.cursor()

cursor.execute(sql_query.split("|")[0])
sql_result = cursor.fetchall()
print(sql_result)

cursor.close()
conn.close()

[('simvastatin', 'SLCO1B1', '1A', None, 'SLCO1B1', None, None)]


## Fuse info from 3 different sources

In [391]:
### Retrieval Filter

from langchain.prompts import PromptTemplate
from langchain_nvidia_ai_endpoints import ChatNVIDIA
from langchain_core.output_parsers import JsonOutputParser

model_id = 'meta/llama-3.1-405b-instruct' #"meta/llama3.3-70b-instruct"

# LLM
llm = ChatNVIDIA(model=model_id, temperature=0)

prompt = PromptTemplate(
    template="""<|begin_of_text|><|start_header_id|>system<|end_header_id|> You are a grader assessing relevance
    of a retrieved document to a user question. If the document contains keywords related to the user question,
    grade it as relevant. It does not need to be a stringent test. The goal is to filter out erroneous retrievals. \n
    Give a binary score 'yes' or 'no' score to indicate whether the document is relevant to the question. \n
    Provide the binary score as a JSON with a single key 'score' and no premable or explanation.
     <|eot_id|><|start_header_id|>user<|end_header_id|>
    Here is the retrieved document: \n\n {document} \n\n
    Here is the user question: {question} \n <|eot_id|><|start_header_id|>assistant<|end_header_id|>
    """,
    input_variables=["question", "document"],
)

retrieval_grader = prompt | llm | JsonOutputParser()

docs = retriever.invoke(question)

vectorstore_info = ""

for drug in results:
    vectorstore_info += drug.page_content

    
from langchain_core.documents import Document

document_Tab = Document(
    page_content=str(sql_result).replace("\'", "") + str(sql_query.split("|", 1)[1]).replace("\'", ""),
    metadata={"source": "TabularRAG"}
)

document_vs = Document(
    page_content=vectorstore_info,
    metadata={"source": "vectorstore"}
)

docs = [document_Tab, document_vs]

print(retrieval_grader.invoke({"question": question, "document": docs}))

{'score': 'yes'}


In [392]:
### Retrieval Filter

from langchain.prompts import PromptTemplate
from langchain_nvidia_ai_endpoints import ChatNVIDIA
from langchain_core.output_parsers import JsonOutputParser

model_id = 'meta/llama-3.1-405b-instruct' #"meta/llama3.3-70b-instruct"

# LLM
llm = ChatNVIDIA(model=model_id, temperature=0)

prompt = PromptTemplate(
    template="""<|begin_of_text|><|start_header_id|>system<|end_header_id|> You are a grader assessing relevance
    of retrieved documents to a user question. If a piece of document contains keywords related to the user question,
    keep it. If it does not contain keywords related to the user question, disregard it. The goal is to filter out erroneous retrievals. \n
    Provide the remaining documents as a strict JSON object with a single key 'filtered docs' and no premable or explanation.
     <|eot_id|><|start_header_id|>user<|end_header_id|>
    Here are the retrieved documents: \n\n {documents} \n\n
    Here is the user question: {question} \n <|eot_id|><|start_header_id|>assistant<|end_header_id|>
    """,
    input_variables=["question", "documents"],
)


retrieval_filter = prompt | llm | JsonOutputParser()
documents = docs

filtered_retrieval = retrieval_filter.invoke({"question": question, "documents": documents})
print(retrieval_filter)


first=PromptTemplate(input_variables=['documents', 'question'], input_types={}, partial_variables={}, template="<|begin_of_text|><|start_header_id|>system<|end_header_id|> You are a grader assessing relevance\n    of retrieved documents to a user question. If a piece of document contains keywords related to the user question,\n    keep it. If it does not contain keywords related to the user question, disregard it. The goal is to filter out erroneous retrievals. \n\n    Provide the remaining documents as a strict JSON object with a single key 'filtered docs' and no premable or explanation.\n     <|eot_id|><|start_header_id|>user<|end_header_id|>\n    Here are the retrieved documents: \n\n {documents} \n\n\n    Here is the user question: {question} \n <|eot_id|><|start_header_id|>assistant<|end_header_id|>\n    ") middle=[ChatNVIDIA(base_url='https://integrate.api.nvidia.com/v1', model='meta/llama-3.1-405b-instruct', temperature=0.0)] last=JsonOutputParser()


## Nvidia Re-ranker

In [393]:
temp = list()

for f in filtered_retrieval['filtered docs']:
    t = Document(page_content=f["page_content"])
    temp.append(t)


In [394]:
### Nvidia ReRanker
from langchain_nvidia_ai_endpoints import NVIDIARerank
ranker = NVIDIARerank(model= "nvidia/llama-3.2-nv-rerankqa-1b-v1", truncate="END") #model="nvidia/nv-rerankqa-mistral-4b-v3"
#api_key="$API_KEY_REQUIRED_IF_EXECUTING_OUTSIDE_NGC"

all_docs = temp # filtered vectorstore + graph + sql

ranker.top_n = 5
docs = ranker.compress_documents(query=query, documents=all_docs)
docs ### it should still contain the tags for vec, kg and sql

[Document(metadata={'relevance_score': -1.888671875}, page_content='For patients with the SLCO1B1*5 genotype, a maximum daily dose of 20mg of simvastatin is recommended to avoid adverse effects from the increased exposure to the drug, such as muscle pain and risk of rhabdomyolysis.[F4658]'),
 Document(metadata={'relevance_score': -2.841796875}, page_content='Following an oral dose of 14C-labeled simvastatin in man, 13% of the dose was excreted in urine and 60% in feces.[F4655,F4658]')]

# Answer the questions

In [395]:
model_id = "meta/llama-3.3-70b-instruct"

# LLM
llm = ChatNVIDIA(model=model_id, temperature=0)

prompt = PromptTemplate(
    template="""<|begin_of_text|><|start_header_id|>system<|end_header_id|>
You are a top professional medical doctor at Stanford. You have previously obtained pieces of medication context from a vectorstore, a knowledge graph, and a SQL database. Your goal is to answer medication-related questions from patients as accurately and factually as possible.
If you don't have sufficient information or knowledge to answer, respond with: "I don't have enough information to answer this question."

Below are suggested reasoning steps to use internally before you present your final answer. You must not reveal these steps to the user or mention that you have a reasoning process. These steps are for your own chain-of-thought:

1. **Identify Key Information**: Identify the key medication(s) or condition(s) the patient is asking about.

2. **Identify Types of Information Needed**: Determine the type of information requested: side effects, dosage, drug interactions, indication, mechanism of action, route of elimination, toxicity, food interactions, or adverse drug reactions.

3. **Assess Data Sources**: Consider which data sources (vectorstore (vc), knowledge graph (kg), SQL database (sql)) would be most relevant for the query at hand, even if you won't actually retrieve the data.
   (a) Consult vectorstore data for general medication background.
   (b) If the question involves how one drug relates to another drug (e.g., drug interactions, not food interactions), check the knowledge graph data.
   (c) If the question involves standardized data (e.g., drug to gene relationship information), check the SQL database.

4. **Formulate Steps for Information Gathering**:
   - **Drug Interactions**: Outline steps to check for known drug interactions, contraindications, or safety precautions.
   - **Dosage Information**: Detail the process to verify the recommended dosages, considering patient factors like age, weight, or existing conditions.
   - **Side Effects**: List steps to gather known side effects, their prevalence, and severity.
   - **Indication**: Identify the approved uses of the medication.
   - **Mechanism of Action**: Describe how the drug works at a molecular or biochemical level.
   - **Route of Elimination**: Determine the primary routes by which the drug is removed from the body.
   - **Toxicity**: Assess any known toxic effects or overdose symptoms.
   - **Food Interactions**: Check for any interactions between the medication and food or dietary components.
   - **Adverse Drug Reactions**: Gather information on adverse reactions reported with the drug.

5. **Synthesize Information**: Integrate information from all sources, ensuring consistency and accuracy.

6. **Provide Answer or Disclaimer**:
   - If sufficient data can be synthesized and confident with the information, provide a direct, evidence-based answer.
   - If there's not enough information available or not confident, state so explicitly.

Return your answer as a strict JSON object with a single key "answer."

Example:
For a question like: "Can I take ibuprofen with aspirin?"

Internal reasoning:

 - Identify that the query is about drug interactions between ibuprofen and aspirin.
 - Consider checking a knowledge graph for known drug interactions.
 - Look for common side effects or contraindications when these drugs are combined.
 - Evaluate if there are any specific patient conditions or warnings to consider.
 - Synthesize the information to determine if it's safe to take ibuprofen with aspirin.

A possible JSON might be:
{{
    "answer: 'It's generally safe to take ibuprofen with aspirin, but monitor for increased risk of bleeding or stomach irritation. However, always consult with a healthcare provider for your specific case.'"
}}

<|eot_id|><|start_header_id|>user<|end_header_id|>
Question: {question}
Context: {context}
<|eot_id|><|start_header_id|>assistant<|end_header_id|>
""",
    input_variables=["question", "context"],
)

rag_chain = prompt | llm | JsonOutputParser()

generation = rag_chain.invoke({"context": docs, "question": question})

In [396]:
question

'Are there specific formulations of Simvastatin designed for patients with genetic risks?'

In [397]:
generation

{'answer': 'There are no specific formulations of Simvastatin designed for patients with genetic risks, but dosing recommendations vary based on genetic factors. For patients with the SLCO1B1*5 genotype, a maximum daily dose of 20mg of simvastatin is recommended to minimize the risk of adverse effects such as muscle pain and rhabdomyolysis.'}

Implement these as a control flow in LangGraph.

In [None]:
from typing_extensions import TypedDict
from typing import List
import json

### State


class GraphState(TypedDict):
    """
    Represents the state of our graph.

    Attributes:
        question: question
        generation: LLM generation
        web_search: whether to add search
        documents: list of documents
    """

    question: str
    generation: str
    web_search: str
    three_queries: List[str]
    documents: List[str]


from langchain.schema import Document

### Nodes

def query_rewriter(state):
    """
    Retrieve documents from vectorstore

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): New key added to state, documents, that contains retrieved documents
    """
    print("---RETRIEVE---")
    question = state["question"]

    # Retrieval
    three_queries = query_rewriter.invoke(question)
    return {"three_queries": three_queries, "question": question}
    # return {"documents": documents, "question": question}

def retrieve(state):
    """
    Retrieve documents from vectorstore

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): New key added to state, documents, that contains retrieved documents
    """
    print("---RETRIEVE---")
    question = state["question"]
    three_queries = state["three_queries"]

    # Retrieval
    vc = retriever.invoke(three_queries[0])
    kg = retriever.invoke(three_queries[1])
    sql = retriever.invoke(three_queries[2])
    documents = {"vc": vc, "kg": kg, "sql": sql}

    return {"documents": documents, "question": question}


def grade_documents(state):
    """
    Determines whether the retrieved documents are relevant to the question
    If any document is not relevant, we will set a flag to run web search

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): Filtered out irrelevant documents and updated web_search state
    """

    print("---CHECK DOCUMENT RELEVANCE TO QUESTION---")
    question = state["question"]
    documents = state["documents"]

    # # Score each doc
    # filtered_docs = []
    # web_search = "No"
    # for t in documents:
    #   for d in t:
    #     score = retrieval_grader.invoke(
    #         {"question": question, "document": d.page_content}
    #     )
    #     grade = score["score"]
    #     # Document relevant
    #     if grade.lower() == "yes":
    #         print("---GRADE: DOCUMENT RELEVANT---")
    #         filtered_docs.append(d)
    #     # Document not relevant
    #     else:
    #         print("---GRADE: DOCUMENT NOT RELEVANT---")
    #         # We do not include the document in filtered_docs
    #         # We set a flag to indicate that we want to run web search
    #         # web_search = "Yes"
    #         # continue
    # return {"documents": filtered_docs, "question": question}


    for section, doc_list in documents.items():
        # Initialize the list for the current key in filtered_docs
        filtered_docs[section] = []
        print(f"Processing section: {section} with {len(doc_list)} documents.")

        for doc in doc_list:
            # Determine the content to pass to the retrieval grader
            # Attempt to extract 'page_content'; if unavailable, use a string representation of the entire document
            document_content = doc.get("page_content", None)
            if document_content is None:
                # Option 1: Use the entire document serialized as a JSON string
                import json
                document_content = json.dumps(doc)
                print(f"---INFO: 'page_content' not found for Document ID {doc.get('doc_id', 'N/A')}. Using serialized document content.")

            # Invoke the retrieval grader with the question and document content
            score = retrieval_grader.invoke(
                {"question": question, "document": document_content}
            )
            grade = score.get("score", "").lower()

            # Check if the grade is 'yes' indicating relevance
            if grade == "yes":
                print(f"---GRADE: Document ID {doc.get('doc_id', 'N/A')} is RELEVANT---")
                filtered_docs[section].append(doc)
            else:
                print(f"---GRADE: Document ID {doc.get('doc_id', 'N/A')} is NOT RELEVANT---")
    return {
        "documents": filtered_docs,
        "question": question,
    }

def rerank (state):
    """
    Rerank documents from vectorstore, kg, and sql

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): New key added to state, documents, that contains retrieved documents
    """
    print("---rerank---")
    question = state["question"]
    documents = state["documents"]

    # Rerank
    # all_docs = documents # filtered vectorstore + graph + sql
    # docs = ranker.compress_documents(query=query, documents=all_docs)

    ranker.top_n = 3
    reranked_docs = {}

    # Iterate through each category in documents
    for category, doc_list in documents.items():
        print(f"\nReranking category: '{category}'")

        if not doc_list:
            print(f"---INFO: No documents found in category '{category}'. Skipping reranking.")
            reranked_docs[category] = []
            continue

        try:
            # Rerank documents within the current category
            # Assuming 'ranker.compress_documents' accepts 'query' and 'documents' as arguments
            # and returns a list of documents sorted by relevance.
            reranked_list = ranker.compress_documents(query=question, documents=doc_list)

            # Assign the reranked list directly without limiting to top_n
            reranked_docs[category] = reranked_list

            print(f"---INFO: Reranked {len(reranked_list)} documents for category '{category}'.")

        except Exception as e:
            print(f"---ERROR: Failed to rerank documents in category '{category}'. Error: {e}")
            # Optionally, retain the original documents in case of failure
            reranked_docs[category] = doc_list  # or set to [] if preferred

    return {"documents": reranked_docs, "question": question}


def generate(state):
    """
    Generate answer using RAG on retrieved documents

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): New key added to state, generation, that contains LLM generation
    """
    print("---GENERATE---")
    question = state["question"]
    documents = state["documents"]

    # RAG generation
    generation = rag_chain.invoke({"context": documents, "question": question})
    # return {"documents": documents, "question": question, "generation": generation}
    return {"documents": documents, "question": question, "generation": generation}





# def web_search(state):
#     """
#     Web search based based on the question

#     Args:
#         state (dict): The current graph state

#     Returns:
#         state (dict): Appended web results to documents
#     """

#     print("---WEB SEARCH---")
#     question = state["question"]
#     documents = state["documents"]

#     # Web search
#     docs = web_search_tool.invoke({"query": question})
#     web_results = "\n".join([d["content"] for d in docs])
#     web_results = Document(page_content=web_results)
#     if documents is not None:
#         documents.append(web_results)
#     else:
#         documents = [web_results]
#     return {"documents": documents, "question": question}


### Conditional edge


# def route_question(state):
#     """
#     Route question to web search or RAG.

#     Args:
#         state (dict): The current graph state

#     Returns:
#         str: Next node to call
#     """

#     print("---ROUTE QUESTION---")
#     question = state["question"]
#     print(question)
#     source = question_router.invoke({"question": question})
#     print(source)
#     print(source["datasource"])
#     if source["datasource"] == "web_search":
#         print("---ROUTE QUESTION TO WEB SEARCH---")
#         return "websearch"
#     elif source["datasource"] == "vectorstore":
#         print("---ROUTE QUESTION TO RAG---")
#         return "vectorstore"


# def decide_to_generate(state):
#     """
#     Determines whether to generate an answer, or add web search

#     Args:
#         state (dict): The current graph state

#     Returns:
#         str: Binary decision for next node to call
#     """

#     print("---ASSESS GRADED DOCUMENTS---")
#     question = state["question"]
#     web_search = state["web_search"]
#     filtered_documents = state["documents"]

#     if web_search == "Yes":
#         # All documents have been filtered check_relevance
#         # We will re-generate a new query
#         print(
#             "---DECISION: ALL DOCUMENTS ARE NOT RELEVANT TO QUESTION, INCLUDE WEB SEARCH---"
#         )
#         return "websearch"
#     else:
#         # We have relevant documents, so generate answer
#         print("---DECISION: GENERATE---")
#         return "generate"


### Conditional edge


# def grade_generation_v_documents_and_question(state):
#     """
#     Determines whether the generation is grounded in the document and answers question.

#     Args:
#         state (dict): The current graph state

#     Returns:
#         str: Decision for next node to call
#     """

#     print("---CHECK HALLUCINATIONS---")
#     question = state["question"]
#     documents = state["documents"]
#     generation = state["generation"]

#     score = hallucination_grader.invoke(
#         {"documents": documents, "generation": generation}
#     )
#     grade = score["score"]

#     # Check hallucination
#     if grade == "yes":
#         print("---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---")
#         # Check question-answering
#         print("---GRADE GENERATION vs QUESTION---")
#         score = answer_grader.invoke({"question": question, "generation": generation})
#         grade = score["score"]
#         if grade == "yes":
#             print("---DECISION: GENERATION ADDRESSES QUESTION---")
#             return "useful"
#         else:
#             print("---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
#             return "not useful"
#     else:
#         print("---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---")
#         return "not supported"


from langgraph.graph import START, END, StateGraph

workflow = StateGraph(GraphState)

# Define the nodes
# workflow.add_node("websearch", web_search)  # web search
workflow.add_node("query_rewriter", query_rewriter)  # query rewriter
workflow.add_node("retrieve", retrieve)  # retrieve
workflow.add_node("rerank", rerank)
workflow.add_node("grade_documents", grade_documents)  # grade documents
workflow.add_node("generate", generate)  # generate

In [None]:
# from typing_extensions import TypedDict
# from typing import List, Annotated
# import operator
# # from langgraph.graph.message import add_messages

# ### State

# class GraphState(TypedDict):
#     """
#     Represents the state of our graph.

#     Attributes:
#         question: question
#         generation: LLM generation
#         web_search: whether to add search
#         documents: list of documents
#     """
#     question : str
#     generation : str
#     web_search : str
#     # documents : List[str]
#     documents : Annotated[List[str], operator.add]

# # Define initial_state
# initial_state = GraphState(
#     question="",
#     generation="",
#     web_search="No",
#     documents=[],
#     #memory=memory,
# )

# from langchain.schema import Document

# ### Nodes

# def retrieve(state):
#     """
#     Retrieve documents from vectorstore

#     Args:
#         state (dict): The current graph state

#     Returns:
#         state (dict): New key added to state, documents, that contains retrieved documents
#     """
#     print("---RETRIEVE---")
#     question = state["question"]

#     # Retrieval
#     documents = retriever.invoke(question)
#     return {"documents": documents, "question": question}

# def generate(state):
#     """
#     Generate answer using RAG on retrieved documents

#     Args:
#         state (dict): The current graph state

#     Returns:
#         state (dict): New key added to state, generation, that contains LLM generation
#     """
#     print("---GENERATE---")
#     question = state["question"]
#     documents = state["documents"][-4:]

#     # RAG generation
#     generation = rag_chain.invoke({"context": documents, "question": question})
#     return {"documents": documents, "question": question, "generation": generation}

# def grade_documents(state):
#     """
#     Determines whether the retrieved documents are relevant to the question
#     If any document is not relevant, we will set a flag to run web search

#     Args:
#         state (dict): The current graph state

#     Returns:
#         state (dict): Filtered out irrelevant documents and updated web_search state
#     """

#     print("---CHECK DOCUMENT RELEVANCE TO QUESTION---")
#     question = state["question"]
#     documents = state["documents"][-4:]

#     # Score each doc
#     filtered_docs = []
#     web_search = "No"
#     for d in documents[:4]: #for d in documents:
#         score = retrieval_grader.invoke({"question": question, "document": d.page_content})
#         grade = score['score']
#         print(d)
#         # Document relevant
#         if grade.lower() == "yes":
#             print("---GRADE: DOCUMENT RELEVANT---")
#             filtered_docs.append(d)
#         # Document not relevant
#         else:
#             print("---GRADE: DOCUMENT NOT RELEVANT---")
#             # We do not include the document in filtered_docs
#             # We set a flag to indicate that we want to run web search
#             # web_search = "Yes"
#             continue
#     if not filtered_docs:
#       web_search = "Yes"
#     return {"documents": filtered_docs, "question": question, "web_search": web_search}

# def web_search(state):
#     """
#     Web search based based on the question

#     Args:
#         state (dict): The current graph state

#     Returns:
#         state (dict): Appended web results to documents
#     """

#     print("---WEB SEARCH---")
#     question = state["question"]
#     documents = state["documents"][-4:]

#     # Web search
#     docs = web_search_tool.run(({question}))#({"query": question})
#     # docs = web_search_tool.invoke({"query": question})


#     # Handle the docs as a list

#     web_results = "\n".join([d for d in docs.split('...')])
#     web_results = Document(page_content=web_results)

#     # web_results = "\n".join([d["content"] for d in docs])
#     # web_results = Document(page_content=web_results)

#     if documents is not None:
#         documents.append(web_results)
#     else:
#         documents = [web_results]


#     print(web_results)
#     print(documents) ### test
#     return {"documents": documents, "question": question}

# ### Conditional edge

# def route_question(state):
#     """
#     Route question to web search or RAG.

#     Args:
#         state (dict): The current graph state

#     Returns:
#         str: Next node to call
#     """

#     print("---ROUTE QUESTION---")
#     question = state["question"]
#     print(question)
#     source = question_router.invoke({"question": question})
#     print(source)
#     print(source['datasource'])
#     if source['datasource'] == 'web_search':
#         print("---ROUTE QUESTION TO WEB SEARCH---")
#         return "websearch"
#     elif source['datasource'] == 'vectorstore':
#         print("---ROUTE QUESTION TO RAG---")
#         return "vectorstore"

# def decide_to_generate(state):
#     """
#     Determines whether to generate an answer, or add web search

#     Args:
#         state (dict): The current graph state

#     Returns:
#         str: Binary decision for next node to call
#     """

#     print("---ASSESS GRADED DOCUMENTS---")
#     question = state["question"]
#     web_search = state["web_search"]
#     filtered_documents = state["documents"][-4:] #

#     if web_search == "Yes":
#         # All documents have been filtered check_relevance
#         # We will re-generate a new query
#         print("---DECISION: ALL DOCUMENTS ARE NOT RELEVANT TO QUESTION, INCLUDE WEB SEARCH---")
#         return "websearch"
#     else:
#         # We have relevant documents, so generate answer
#         print("---DECISION: GENERATE---")
#         return "generate"

# ### Conditional edge

# def grade_generation_v_documents_and_question(state):
#     """
#     Determines whether the generation is grounded in the document and answers question.

#     Args:
#         state (dict): The current graph state

#     Returns:
#         str: Decision for next node to call
#     """

#     print("---CHECK HALLUCINATIONS---")
#     question = state["question"]
#     documents = state["documents"][-4:]
#     generation = state["generation"] #

#     score = hallucination_grader.invoke({"documents": documents, "generation": generation})
#     grade = score['score']

#     # Check hallucination
#     if grade == "yes":
#         print("---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---")
#         # Check question-answering
#         print("---GRADE GENERATION vs QUESTION---")
#         score = answer_grader.invoke({"question": question,"generation": generation})
#         grade = score['score']
#         if grade == "yes":
#             print("---DECISION: GENERATION ADDRESSES QUESTION---")
#             return "useful"
#         else:
#             print("---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
#             return "not useful"
#     else:
#         print("---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---")
#         return "not supported"

# from langgraph.graph import END, StateGraph
# workflow = StateGraph(GraphState)

# # Define the nodes
# workflow.add_node("websearch", web_search) # web search
# workflow.add_node("retrieve", retrieve) # retrieve
# workflow.add_node("grade_documents", grade_documents) # grade documents
# workflow.add_node("generate", generate) # generate

### Graph Build

In [None]:
# Build graph

# Set the entry point to the first node
workflow.add_edge(START, "query_rewriter")

# Define the sequential flow between nodes
workflow.add_edge("query_rewriter", "retrieve")
workflow.add_edge("retrieve", "rerank")
workflow.add_edge("rerank", "grade_documents")
workflow.add_edge("grade_documents", "generate")

# Define the end of the workflow
workflow.add_edge("generate", END)

In [None]:
# # Build graph
# workflow.set_conditional_entry_point(
#     route_question,
#     {
#         "websearch": "websearch",
#         "vectorstore": "retrieve",
#     },
# )

# workflow.add_edge("retrieve", "grade_documents")
# workflow.add_conditional_edges(
#     "grade_documents",
#     decide_to_generate,
#     {
#         "websearch": "websearch",
#         "generate": "generate",
#     },
# )
# workflow.add_edge("websearch", "generate")
# workflow.add_conditional_edges(
#     "generate",
#     grade_generation_v_documents_and_question,
#     {
#         "not supported": "generate",
#         "useful": END,
#         "not useful": "websearch",
#     },
# )

Trace:

https://smith.langchain.com/public/8d449b67-6bc4-4ecf-9153-759cd21df24f/r

# **Adding Memory to the Agent**

In [None]:
# The checkpointer lets the graph persist its state
from langgraph.checkpoint.sqlite import SqliteSaver
memory = SqliteSaver.from_conn_string(":memory:")

In [None]:
import uuid
_printed = set()
thread_id = str(uuid.uuid4())

config = {
    "configurable": {
        # Checkpoints are accessed by thread_id
        "thread_id": "0", #thread_id,
    }
}

# **Single Cell Agent Flow**

In [None]:
# import os
# # import streamlit as st
# import pandas as pd
# import openai
# from langchain.schema import Document
# from langchain.memory import ConversationBufferMemory
# from typing_extensions import TypedDict
# from typing import List
# from langchain.text_splitter import RecursiveCharacterTextSplitter
# from langchain_community.vectorstores import Chroma
# from langchain_community.embeddings import GPT4AllEmbeddings
# from langchain_core.prompts import ChatPromptTemplate
# from langchain_openai import ChatOpenAI
# from langchain_core.output_parsers import JsonOutputParser, StrOutputParser
# from langchain_community.utilities import GoogleSerperAPIWrapper
# from langchain.prompts import PromptTemplate
# from langgraph.graph import END, StateGraph
# import gradio as gr

# # Define tools
# os.environ['SERPER_API_KEY'] = "4bf2ebebf278aeef645193a533a3b94d7b011e17"
# serper_api_key = os.getenv('SERPER_API_KEY')
# web_search_tool = GoogleSerperAPIWrapper(api_key=serper_api_key, k=10)

# # Initialize ChatOpenAI
# openai.api_key = "sk-proj-M8297PbB5Nekd64YyBlpT3BlbkFJA5xRy149A4ceEcmke54P"
# os.environ["OPENAI_API_KEY"] = openai.api_key
# llm = ChatOpenAI(model="gpt-4o", temperature=0.5)

# # Initialize GPT4AllEmbeddings
# model_name = "all-MiniLM-L6-v2.gguf2.f16.gguf"
# gpt4all_kwargs = {'allow_download': 'True'}
# embeddings = GPT4AllEmbeddings(
#     model_name=model_name,
#     gpt4all_kwargs=gpt4all_kwargs
# )

# # Load the CSV file
# csv_file_path = "./UniFi_Help_Articles_Formatted.csv"
# df = pd.read_csv(csv_file_path, encoding='ISO-8859-1')

# # Convert text data to LangChain Document format
# docs_list = [Document(page_content=text) for text in df['body'].tolist()]

# # Split the text into smaller chunks
# text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
#     chunk_size=700, chunk_overlap=100
# )
# doc_splits = text_splitter.split_documents(docs_list)

# # Add to vectorDB
# vectorstore = Chroma.from_documents(
#     documents=doc_splits,
#     collection_name="rag-chroma",
#     embedding=GPT4AllEmbeddings(model_name=model_name, gpt4all_kwargs=gpt4all_kwargs),
# )
# retriever = vectorstore.as_retriever()


# ### Define tools ###

# ### Retrieval Grader

# from langchain.prompts import PromptTemplate
# from langchain_core.output_parsers import JsonOutputParser


# # LLM GPT-4o
# llm = ChatOpenAI(model="gpt-4o", temperature=0)


# prompt = PromptTemplate(
#     template="""system
#     You are a grader assessing relevance of a retrieved document to a user question. Please note that every user question is about Ubiquiti's *UniFi* consoles or devices or software service even if not explicitly specified.
#     If the document contains certain words / keywords related to the user question or the same words appeared in the user question, make sure to grade it as relevant (yes).
#     It does not need to be a stringent test. The goal is to accurately filter out erroneous retrievals and keep the correct retrievals. Give a binary 'yes' or 'no' score to indicate whether the document is relevant to the question or not.
#     Provide the binary score as a JSON with a single key 'score' and no preamble or explanation.

#     user
#     Here is the retrieved document: {document}
#     Here is the user question: {question}

#     assistant""",
#     input_variables=["question", "document"],
# )

# retrieval_grader = prompt | llm | JsonOutputParser()

# ### Generate

# from langchain.prompts import PromptTemplate
# from langchain import hub
# from langchain_core.output_parsers import StrOutputParser


# # Prompt
# prompt = PromptTemplate(
#     template="""system
#     You are an assistant for question-answering tasks. Use the following pieces of retrieved context to carefully answer the question by staying grounded in the facts/context and being relevant.
#     Please note that every user question is about Ubiquiti's *UniFi* consoles or devices or software service even if not explicitly specified.
#     Check the urls in the context and make sure they do not contain any parenthesis (Remove the parenthesis and any url link inside them when you generate answers).
#     If you don't know the answer or don't see relevant keywords in the context, just say that you don't know. Use ten sentences maximum and keep the answer concise.

#     user
#     Question: {question}
#     Context: {context}
#     Answer:

#     assistant""",
#     input_variables=["question", "context"],
# )

# # Post-processing
# def format_docs(docs):
#     return "\n\n".join(doc.page_content for doc in docs)

# # Chain
# rag_chain = prompt | llm | StrOutputParser()

# ### Hallucination Grader

# # Prompt
# prompt = PromptTemplate(
#     template="""system
#     You are a good grader assessing whether an answer is grounded in / supported by a set of facts / documents fetched.
#     Please note that every user question is about Ubiquiti's *UniFi* consoles or devices or software service even if not explicitly specified.
#     Give a binary 'yes' or 'no' score to indicate whether the answer is grounded in / supported by the documents / facts provided or not.
#     Look at all the documents / facts carefully and understand everything about the answer and the documents.
#     Provide the binary score as a JSON with a single key 'score' and no preamble or explanation.

#     user
#     Here are the facts:
#     -------
#     {documents}
#     -------
#     Here is the answer: {generation}

#     assistant""",
#     input_variables=["generation", "documents"],
# )

# hallucination_grader = prompt | llm | JsonOutputParser()

# ### Answer Grader

# # Prompt
# prompt = PromptTemplate(
#     template="""system
#     You are a grader assessing whether an answer is useful to resolve a question. Please note that every user question is about Ubiquiti's *UniFi* consoles or devices or software service even if not explicitly specified.
#     Give a binary 'yes' or 'no' to indicate whether the answer is useful to resolve a question. Provide the binary score as a JSON with a single key 'score' and no preamble or explanation.

#     user
#     Here is the answer:
#     -------
#     {generation}
#     -------
#     Here is the question: {question}

#     assistant""",
#     input_variables=["generation", "question"],
# )

# answer_grader = prompt | llm | JsonOutputParser()

# ### Router

# ### Original Prompt for GPT-4o
# from langchain.prompts import PromptTemplate
# from langchain_openai import ChatOpenAI
# from langchain_core.output_parsers import JsonOutputParser

# ### Original Prompt for GPT-4o
# prompt = PromptTemplate(
#     template="""system
#     You are an expert at routing a user question to a vectorstore or web search. Use the vectorstore for questions on UniFi products / consoles / sevices, etc. Pay special attention to devices like Dream Machine.
#     You do not need to be stringent with the keywords in the question related to these topics.Otherwise, use web search. Give a binary choice 'web_search' or 'vectorstore' based on the question. Return the a JSON with a single key 'datasource' and no preamble or explanation.

#     user
#     Question to route: {question}

#     assistant""",
#     input_variables=["question"],
# )

# question_router = prompt | llm | JsonOutputParser()


# ### Google Search
# import os
# from langchain.schema import Document
# from langchain_community.utilities import GoogleSerperAPIWrapper

# os.environ['SERPER_API_KEY'] = "4bf2ebebf278aeef645193a533a3b94d7b011e17"
# # Initialize the Serper search tool with your API key
# serper_api_key = os.getenv('SERPER_API_KEY')
# web_search_tool = GoogleSerperAPIWrapper(api_key=serper_api_key, k=10)

# from typing_extensions import TypedDict
# from typing import List

# ### State

# class GraphState(TypedDict):
#     """
#     Represents the state of our graph.

#     Attributes:
#         question: question
#         generation: LLM generation
#         web_search: whether to add search
#         documents: list of documents
#     """
#     question : str
#     generation : str
#     web_search : str
#     documents : List[str]

# # Define initial_state
# initial_state = GraphState(
#     question="",
#     generation="",
#     web_search="No",
#     documents=[],
#     #memory=memory,
# )

# from langchain.schema import Document

# ### Nodes

# def retrieve(state):
#     """
#     Retrieve documents from vectorstore

#     Args:
#         state (dict): The current graph state

#     Returns:
#         state (dict): New key added to state, documents, that contains retrieved documents
#     """
#     print("---RETRIEVE---")
#     question = state["question"]

#     # Retrieval
#     documents = retriever.invoke(question)
#     return {"documents": documents, "question": question}

# def generate(state):
#     """
#     Generate answer using RAG on retrieved documents

#     Args:
#         state (dict): The current graph state

#     Returns:
#         state (dict): New key added to state, generation, that contains LLM generation
#     """
#     print("---GENERATE---")
#     question = state["question"]
#     documents = state["documents"]

#     # RAG generation
#     generation = rag_chain.invoke({"context": documents, "question": question})
#     return {"documents": documents, "question": question, "generation": generation}

# def grade_documents(state):
#     """
#     Determines whether the retrieved documents are relevant to the question
#     If any document is not relevant, we will set a flag to run web search

#     Args:
#         state (dict): The current graph state

#     Returns:
#         state (dict): Filtered out irrelevant documents and updated web_search state
#     """

#     print("---CHECK DOCUMENT RELEVANCE TO QUESTION---")
#     question = state["question"]
#     documents = state["documents"]

#     # Score each doc
#     filtered_docs = []
#     web_search = "No"
#     for d in documents:
#         score = retrieval_grader.invoke({"question": question, "document": d.page_content})
#         grade = score['score']
#         print(d)
#         # Document relevant
#         if grade.lower() == "yes":
#             print("---GRADE: DOCUMENT RELEVANT---")
#             filtered_docs.append(d)
#         # Document not relevant
#         else:
#             print("---GRADE: DOCUMENT NOT RELEVANT---")
#             # We do not include the document in filtered_docs
#             # We set a flag to indicate that we want to run web search
#             # web_search = "Yes"
#             continue
#     if not filtered_docs:
#       web_search = "Yes"
#     return {"documents": filtered_docs, "question": question, "web_search": web_search}

# def web_search(state):
#     """
#     Web search based based on the question

#     Args:
#         state (dict): The current graph state

#     Returns:
#         state (dict): Appended web results to documents
#     """

#     print("---WEB SEARCH---")
#     question = state["question"]
#     documents = state["documents"]

#     # Web search
#     docs = web_search_tool.run(({question}))#({"query": question})
#     # docs = web_search_tool.invoke({"query": question})


#     # Handle the docs as a list

#     web_results = "\n".join([d for d in docs.split('...')])
#     web_results = Document(page_content=web_results)

#     # web_results = "\n".join([d["content"] for d in docs])
#     # web_results = Document(page_content=web_results)

#     if documents is not None:
#         documents.append(web_results)
#     else:
#         documents = [web_results]


#     print(web_results)
#     print(documents) ### test
#     return {"documents": documents, "question": question}

# ### Conditional edge

# def route_question(state):
#     """
#     Route question to web search or RAG.

#     Args:
#         state (dict): The current graph state

#     Returns:
#         str: Next node to call
#     """

#     print("---ROUTE QUESTION---")
#     question = state["question"]
#     print(question)
#     source = question_router.invoke({"question": question})
#     print(source)
#     print(source['datasource'])
#     if source['datasource'] == 'web_search':
#         print("---ROUTE QUESTION TO WEB SEARCH---")
#         return "websearch"
#     elif source['datasource'] == 'vectorstore':
#         print("---ROUTE QUESTION TO RAG---")
#         return "vectorstore"

# def decide_to_generate(state):
#     """
#     Determines whether to generate an answer, or add web search

#     Args:
#         state (dict): The current graph state

#     Returns:
#         str: Binary decision for next node to call
#     """

#     print("---ASSESS GRADED DOCUMENTS---")
#     question = state["question"]
#     web_search = state["web_search"]
#     filtered_documents = state["documents"]

#     if web_search == "Yes":
#         # All documents have been filtered check_relevance
#         # We will re-generate a new query
#         print("---DECISION: ALL DOCUMENTS ARE NOT RELEVANT TO QUESTION, INCLUDE WEB SEARCH---")
#         return "websearch"
#     else:
#         # We have relevant documents, so generate answer
#         print("---DECISION: GENERATE---")
#         return "generate"

# ### Conditional edge

# def grade_generation_v_documents_and_question(state):
#     """
#     Determines whether the generation is grounded in the document and answers question.

#     Args:
#         state (dict): The current graph state

#     Returns:
#         str: Decision for next node to call
#     """

#     print("---CHECK HALLUCINATIONS---")
#     question = state["question"]
#     documents = state["documents"]
#     generation = state["generation"]

#     score = hallucination_grader.invoke({"documents": documents, "generation": generation})
#     grade = score['score']

#     # Check hallucination
#     if grade == "yes":
#         print("---DECISION: GENERATION IS GROUNDED IN DOCUMENTS---")
#         # Check question-answering
#         print("---GRADE GENERATION vs QUESTION---")
#         score = answer_grader.invoke({"question": question,"generation": generation})
#         grade = score['score']
#         if grade == "yes":
#             print("---DECISION: GENERATION ADDRESSES QUESTION---")
#             return "useful"
#         else:
#             print("---DECISION: GENERATION DOES NOT ADDRESS QUESTION---")
#             return "not useful"
#     else:
#         print("---DECISION: GENERATION IS NOT GROUNDED IN DOCUMENTS, RE-TRY---")
#         return "not supported"

# from langgraph.graph import END, StateGraph
# workflow = StateGraph(GraphState)

# # Define the nodes
# workflow.add_node("websearch", web_search) # web search
# workflow.add_node("retrieve", retrieve) # retrieve
# workflow.add_node("grade_documents", grade_documents) # grade documents
# workflow.add_node("generate", generate) # generatae

# workflow.set_conditional_entry_point(
#     route_question,
#     {
#         "websearch": "websearch",
#         "vectorstore": "retrieve",
#     },
# )

# workflow.add_edge("retrieve", "grade_documents")
# workflow.add_conditional_edges(
#     "grade_documents",
#     decide_to_generate,
#     {
#         "websearch": "websearch",
#         "generate": "generate",
#     },
# )
# workflow.add_edge("websearch", "generate")
# workflow.add_conditional_edges(
#     "generate",
#     grade_generation_v_documents_and_question,
#     {
#         "not supported": "generate",
#         "useful": END,
#         "not useful": "websearch",
#     },
# )

# # # The checkpointer lets the graph persist its state
# # from langgraph.checkpoint.sqlite import SqliteSaver
# memory = SqliteSaver.from_conn_string(":memory:")

# config = {
#     "configurable": {
#         # Checkpoints are accessed by thread_id
#         "thread_id": "1", #thread_id,
#     }
# }

# app = workflow.compile(checkpointer=memory)




# # os.environ["OPENAI_API_KEY"] = "sk-..."  # Replace with your key

# # llm = ChatOpenAI(temperature=0.5, model='gpt-4o')

# # def predict(message):
# #     # history_langchain_format = []
# #     # for human, ai in history:
# #     #     history_langchain_format.append(HumanMessage(content=human))
# #     #     history_langchain_format.append(AIMessage(content=ai))
# #     # history_langchain_format.append(HumanMessage(content=message))

# #     inputs = {
# #         "question": message,
# #     }
# #     for output in app.stream(inputs, config, stream_mode="values"):
# #       results = output

# #     gpt_response = results["generation"]#llm(history_langchain_format)
# #     return gpt_response #gpt_response.content

# # gr.ChatInterface(predict).launch(debug=True)

# **Running the chatbot with gradio**

In [None]:
# from langgraph.checkpoint.aiosqlite import AsyncSqliteSaver

# memory = AsyncSqliteSaver.from_conn_string(":memory:")
# abot = Agent(model, [tool], system=prompt, checkpointer=memory)

In [None]:
# messages = [HumanMessage(content="What is the weather in SF?")]
# thread = {"configurable": {"thread_id": "4"}}
# async for event in abot.graph.astream_events({"messages": messages}, thread, version="v1"):
#     kind = event["event"]
#     if kind == "on_chat_model_stream":
#         content = event["data"]["chunk"].content
#         if content:
#             # Empty content in the context of OpenAI means
#             # that the model is asking for a tool to be invoked.
#             # So we only print non-empty content
#             print(content, end="|")

In [None]:
app = workflow.compile()

def predict(message):

    inputs = {
        "question": message,
    }

    results = None
    try:
        for output in app.stream(inputs, config, stream_mode="values"):
            results = output
    except Exception as e:
        print(f"Error encountered: {e}")
        results = {"generation": ["Sorry, something went wrong. Please try again later."]}

    gpt_response = ["Sorry, I am not able to answer this question."]
    if results.get("generation"):
        gpt_response = results["generation"]
    return gpt_response

predict("")

# gr.ChatInterface(predict).launch(share=True, debug=True)

## gradio implementation

In [None]:
app = workflow.compile(checkpointer=memory)

def predict(message, history):

    inputs = {
        "question": message,
    }

    results = None
    try:
        for output in app.stream(inputs, config, stream_mode="values"):
            results = output
    except Exception as e:
        print(f"Error encountered: {e}")
        results = {"generation": ["Sorry, something went wrong. Please try again later."]}

    gpt_response = ["Sorry, I am not able to answer this question."]
    if results.get("generation"):
        gpt_response = results["generation"]
    return gpt_response

gr.ChatInterface(predict).launch(share=True, debug=True)

In [None]:
app = workflow.compile(checkpointer=memory)

def predict(message, history):
    # history_langchain_format = []
    # for human, ai in history:
    #     history_langchain_format.append(HumanMessage(content=human))
    #     history_langchain_format.append(AIMessage(content=ai))
    # history_langchain_format.append(HumanMessage(content=message)

    inputs = {
        "question": message,
    }
    for output in app.stream(inputs, config, stream_mode="values"):
      results = output

    gpt_response = ["Sorry, I am not able to answer this question."]
    if results["generation"]:
        gpt_response = results["generation"]#llm(history_langchain_format)
    return gpt_response #gpt_response.content

gr.ChatInterface(predict).launch(share=True, debug=True) #(share=True)