In [1]:
from langchain.vectorstores import Chroma
from langchain.embeddings.sentence_transformer import SentenceTransformerEmbeddings
from langchain.prompts import PromptTemplate
from langchain.llms import Ollama
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.chains import RetrievalQA
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split


In [2]:
VECTOR_DB_PATH = "../../data/vectorDB/disease_context_chromaDB"
SENTENCE_EMBEDDING_MODEL = "all-MiniLM-L6-v2"
DISEASE_FILE = "../../data/disease_with_TREATS_edge.csv"

TEST_SIZE = 400

In [3]:
disease_df = pd.read_csv(DISEASE_FILE)
disease_array = disease_df.diseases.unique()
disease_array

array(['thyroid gland papillary carcinoma', 'glycogen storage disease II',
       'follicular lymphoma', ..., 'SADDAN', 'systolic heart failure',
       'glycogen storage disease I'], dtype=object)

In [4]:
diseae_validation, diseae_test = train_test_split(disease_array, test_size=TEST_SIZE, random_state=42)
diseae_validation.shape

(1225,)

In [13]:
fetch_k_list = [30, 40]
lambda_mult_list = [0.3, 0.5, 0.75]
k_list = [1, 2, 3]


In [5]:
embedding_function = SentenceTransformerEmbeddings(model_name=SENTENCE_EMBEDDING_MODEL)

vectorstore = Chroma(persist_directory=VECTOR_DB_PATH, 
                     embedding_function=embedding_function)


  from .autonotebook import tqdm as notebook_tqdm


In [42]:
fetch_k = fetch_k_list[1]
lambda_mult = lambda_mult_list[1]


question = "What compound treats parkinson's disease?"
search_result = vectorstore.similarity_search_with_score(question, 
                                                         search_type="mmr", 
                                                         search_kwargs={"fetch_k": fetch_k, "lambda_mult":0.2, "k":2})

# search_result = vectorstore.similarity_search_with_relevance_scores(question, 
#                                                          search_type="similarity_score_threshold", 
#                                                          search_kwargs={"score_threshold":0.4, "k":2})

search_result = vectorstore.similarity_search_with_score(question, 
                                                         search_type="similarity_score_threshold", 
                                                         search_kwargs={"score_threshold":0.4, "k":2})




print(search_result[1])

(Document(page_content="(1) 'phase' is 2 (2) 'sources' is ['ChEMBL'] (3) 'source' is ['ChEMBL', 'DrugCentral'] \nCompound Talipexole TREATS Parkinson's disease. Attributes of this relationship are:\n(1) 'source' is ['DrugCentral'] \nCompound (alphaS)-alpha-Cyclohexyl-alpha-phenyl-1-pyrrolidine-1-propanol TREATS Parkinson's disease. Attributes of this relationship are:\n(1) 'source' is ['DrugCentral'] \nCompound Tricyclamol cation TREATS Parkinson's disease. Attributes of this relationship are:\n(1) 'source' is ['DrugCentral'] \nCompound Adamantan-1-amine sulfate TREATS Parkinson's disease. Attributes of this relationship are:\n(1) 'source' is ['DrugCentral'] \nCompound Cycrimine TREATS Parkinson's disease. Attributes of this relationship are:\n(1) 'source' is ['DrugCentral'] \nCompound N-Propargyl-1(S)-aminoindan TREATS Parkinson's disease. Attributes of this relationship are:\n(1) 'source' is ['DrugCentral'] \nCompound 5-bromo-N-[1-hydroxy-8-oxo-4,7-di(propan-2-yl)-3-oxa-6,9-diazatric

In [47]:
vectorstore.as_retriever(search_type="mmr", search_kwargs={"fetch_k": 40, "lambda_mult":0.2, "k":2})

vectorstore.similarity_search_with_score(question, search_type="mmr")


[(Document(page_content="(1) 'phase' is 3 (2) 'sources' is ['ChEMBL'] \nCompound Memantine TREATS Parkinson's disease. Attributes of this relationship are:\n(1) 'phase' is 3 (2) 'sources' is ['ChEMBL'] \nCompound Orphenadrine TREATS Parkinson's disease. Attributes of this relationship are:\n(1) 'phase' is 4 (2) 'sources' is ['ChEMBL'] \nCompound Pergolide TREATS Parkinson's disease. Attributes of this relationship are:\n(1) 'phase' is 4 (2) 'sources' is ['ChEMBL'] \nCompound (2-(4-((2-Chloro-4,4-difluoro-spiro(5H-thieno(2,3-C)pyran-7,4'-piperidine)-1'-yl)methyl)-3-methyl-pyrazol-1-yl)-3-pyridyl)methanol TREATS Parkinson's disease. Attributes of this relationship are:\n(1) 'phase' is 2 (2) 'sources' is ['ChEMBL'] \nCompound Mevidalen TREATS Parkinson's disease. Attributes of this relationship are:\n(1) 'phase' is 1 (2) 'sources' is ['ChEMBL'] \nCompound Perampanel TREATS Parkinson's disease. Attributes of this relationship are:\n(1) 'phase' is 3 (2) 'sources' is ['ChEMBL'] \nCompound Ca

In [27]:
dir(vectorstore)


['_Chroma__query_collection',
 '_LANGCHAIN_DEFAULT_COLLECTION_NAME',
 '__abstractmethods__',
 '__annotations__',
 '__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__slots__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_abc_impl',
 '_client',
 '_client_settings',
 '_collection',
 '_cosine_relevance_score_fn',
 '_embedding_function',
 '_euclidean_relevance_score_fn',
 '_get_retriever_tags',
 '_max_inner_product_relevance_score_fn',
 '_persist_directory',
 '_select_relevance_score_fn',
 '_similarity_search_with_relevance_scores',
 'aadd_documents',
 'aadd_texts',
 'add_documents',
 'add_texts',
 'afrom_documents',
 'afrom_texts',
 'amax_marginal_relevance_search',
 'amax_marginal_relevance_search_by_vector',
 'as

In [None]:
# Configuring RAG

template = """Use the following pieces of context to answer the question at the end and also to return the provenance. 
If you don't know the answer, just say that you don't know, don't try to make up an answer.   
{context}
Question: {question}
Helpful Answer:"""

QA_CHAIN_PROMPT = PromptTemplate(
    input_variables=["context", "question"],
    template=template,
)

llm = Ollama(base_url="http://localhost:11434",
             model="llama2:13b",
             temperature=0.01,
             verbose=True,
             callback_manager=CallbackManager([StreamingStdOutCallbackHandler()]))


qa_chain = RetrievalQA.from_chain_type(
    llm,
    retriever=vectorstore.as_retriever(search_type="mmr", search_kwargs={"fetch_k": 30, "lambda_mult":0.5, "k":1}),
    chain_type_kwargs={"prompt": QA_CHAIN_PROMPT},
    return_source_documents=True
)

