### Import libraries, custom classes and functions

In [1]:
from pathlib import Path
from pprint import pprint
import sys
import os
import random

from llama_index.core import ServiceContext, set_global_service_context, set_global_handler
from llama_index.core.node_parser import SentenceSplitter

from task_dataset import PubMedQATaskDataset

sys.path.append("..")
from utils.hosting_utils import RAGLLM
from utils.rag_utils import (
    DocumentReader, RAGEmbedding, RAGQueryEngine, RagasEval, 
    extract_yes_no, evaluate, validate_rag_cfg
    )
from utils.storage_utils import RAGIndex

The chromadb package is not available on this system, skipping


In [2]:
with open(Path.home() / ".cohere_api_key", "r") as f:
    os.environ["COHERE_API_KEY"] = f.read().rstrip("\n")
with open(Path.home() / ".hfhub_api_token", "r") as f:
    os.environ["HUGGINGFACEHUB_API_TOKEN"] = f.read().rstrip("\n")
with open(Path.home() / ".openai_api_key", "r") as f:
    os.environ["OPENAI_API_KEY"] = f.read().rstrip("\n")

### Set RAG configuration

In [3]:
rag_cfg = {
    # Node parser config
    "chunk_size": 256,
    "chunk_overlap": 0,

    # Embedding model config
    "embed_model_type": "hf",
    "embed_model_name": "BAAI/bge-base-en-v1.5",

    # LLM config
    "llm_type": "local",
    "llm_name": "Llama-2-7b-chat-hf",
    "max_new_tokens": 256,
    "temperature": 1.0,
    "top_p": 1.0,
    "top_k": 50,
    "do_sample": False,

    # Vector DB config
    "vector_db_type": "weaviate", # "chromadb", "weaviate"
    "vector_db_name": "Pubmed_QA",
    # MODIFY THIS
    "weaviate_url": "https://rag-bootcamp-pubmed-qa-lsqv7od4.weaviate.network",

    # Retriever and query config
    "retriever_type": "bm25", # "vector_index", "bm25"
    "retriever_similarity_top_k": 3,
    "query_mode": "hybrid", # "default", "hybrid"
    "hybrid_search_alpha": 0.5, # float from 0.0 (sparse search - bm25) to 1.0 (vector search)
    "response_mode": "compact",
    "use_reranker": True,
    "rerank_top_k": 2,
}

### Weaviate Key

In [4]:
try:
    f = open(Path.home() / ".weaviate_api_key", "r")
    f.close()
except Exception as err:
    print(f"Could not read your Weaviate key. Please make sure this is available in plain text under your home directory in ~/.weaviate_api_key: {err}")

## STAGE 0 - Preliminary config checks

In [5]:
pprint(rag_cfg)
validate_rag_cfg(rag_cfg)

{'chunk_overlap': 0,
 'chunk_size': 256,
 'do_sample': False,
 'embed_model_name': 'BAAI/bge-base-en-v1.5',
 'embed_model_type': 'hf',
 'hybrid_search_alpha': 0.5,
 'llm_name': 'Llama-2-7b-chat-hf',
 'llm_type': 'local',
 'max_new_tokens': 256,
 'query_mode': 'hybrid',
 'rerank_top_k': 2,
 'response_mode': 'compact',
 'retriever_similarity_top_k': 3,
 'retriever_type': 'bm25',
 'temperature': 1.0,
 'top_k': 50,
 'top_p': 1.0,
 'use_reranker': True,
 'vector_db_name': 'Pubmed_QA',
 'vector_db_type': 'weaviate',
 'weaviate_url': 'https://rag-bootcamp-pubmed-qa-lsqv7od4.weaviate.network'}


## STAGE 1 - Load dataset and documents

#### 1. Load PubMed QA dataset
PubMedQA ([github](https://github.com/pubmedqa/pubmedqa)) is a biomedical question answering dataset. Each instance consists of a question, a context (extracted from PubMed abstracts), a long answer and a yes/no/maybe answer. We make use of the test split of [this](https://huggingface.co/datasets/bigbio/pubmed_qa) huggingface dataset for this notebook.

**The context for each instance is stored as a text file** (referred to as documents), to align the task as a standard RAG use-case.

In [6]:
print('Loading PubMed QA data ...')
pubmed_data = PubMedQATaskDataset('bigbio/pubmed_qa')
print(f"Loaded data size: {len(pubmed_data)}")
# pubmed_data.mock_knowledge_base(output_dir='./data', one_file_per_sample=True)

Loading PubMed QA data ...


Preparing data: 100%|██████████| 500/500 [00:00<00:00, 1400.04it/s]

Loaded data size: 500





#### 2. Load documents

In [7]:
print('Loading documents ...')
reader = DocumentReader(input_dir="./data/pubmed_doc")
docs = reader.load_data()
print(f'No. of documents loaded: {len(docs)}')

Loading documents ...
No. of documents loaded: 500


## STAGE 2 - Load node parser, embedding, LLM and set service context

#### 1. Load node parser to split documents into smaller chunks

In [8]:
print('Loading node parser ...')
node_parser = SentenceSplitter(chunk_size=rag_cfg['chunk_size'], chunk_overlap=rag_cfg['chunk_overlap'])
# nodes = node_parser.get_nodes_from_documents(docs)

Loading node parser ...


#### 2. Load embedding model

In [9]:
embed_model = RAGEmbedding(model_type=rag_cfg['embed_model_type'], model_name=rag_cfg['embed_model_name']).load_model()

Loading hf embedding model ...


#### 3. Load LLM for generation

In [10]:
llm = RAGLLM(rag_cfg['llm_type'], rag_cfg['llm_name']).load_model(**rag_cfg)

Loading local LLM model ...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

#### 4. Use service context to set the node parser, embedding model, LLM, etc.

In [11]:
service_context = ServiceContext.from_defaults(
    node_parser=node_parser,
    embed_model=embed_model,
    llm=llm,
)
# Set it globally to avoid passing it to every class, this sets it even for rag_utils.py
set_global_service_context(service_context)

  service_context = ServiceContext.from_defaults(


## STAGE 3 - Create index using the appropriate vector store

In [12]:
index = RAGIndex(db_type=rag_cfg['vector_db_type'], db_name=rag_cfg['vector_db_name'])\
    .create_index(docs, weaviate_url=rag_cfg["weaviate_url"])

Loading index from ./.weaviate_index_store/ ...


            Please consider upgrading to the latest version. See https://weaviate.io/developers/weaviate/client-libraries/python for details.


## STAGE 4 - Build query engine

Now build a query engine using *retriever* and *response_synthesizer*  
[Weaviate hybrid search](https://weaviate.io/blog/hybrid-search-explained)

In [13]:
query_engine_args = {
    "similarity_top_k": rag_cfg['retriever_similarity_top_k'], 
    "response_mode": rag_cfg['response_mode'],
}

if (rag_cfg["retriever_type"] == "vector_index") and (rag_cfg["vector_db_type"] == "weaviate"):
    query_engine_args.update({
        "query_mode": rag_cfg["query_mode"], 
        "hybrid_search_alpha": rag_cfg["hybrid_search_alpha"]
    })
elif rag_cfg["retriever_type"] == "bm25":
    nodes = service_context.node_parser.get_nodes_from_documents(docs)
    tokenizer = service_context.embed_model._tokenizer
    query_engine_args.update({"nodes": nodes, "tokenizer": tokenizer})
    
if rag_cfg["use_reranker"]:
    query_engine_args.update({"use_reranker": True, "rerank_top_k": rag_cfg["rerank_top_k"]})

In [14]:
query_engine = RAGQueryEngine(
    retriever_type=rag_cfg['retriever_type'], vector_index=index, llm_model_name=rag_cfg['llm_name']).create(**query_engine_args)

## STAGE 5 - Finally query the model!

In [15]:
random.seed(237)
sample_idx = random.randint(0, len(pubmed_data)-1)
sample_elm = pubmed_data[sample_idx]
pprint(sample_elm)

{'answer': ['no'],
 'context': 'Human immunodeficiency virus (HIV)-infected patients have '
            'generally been excluded from transplantation. Recent advances in '
            'the management and prognosis of these patients suggest that this '
            'policy should be reevaluated. To explore the current views of '
            'U.S. transplant centers toward transplanting asymptomatic '
            'HIV-infected patients with end-stage renal disease, a written '
            'survey was mailed to the directors of transplantation at all 248 '
            'renal transplant centers in the United States. All 148 responding '
            'centers said they require HIV testing of prospective kidney '
            'recipients, and 84% of these centers would not transplant an '
            'individual who refuses HIV testing. The vast majority of '
            'responding centers would not transplant a kidney from a cadaveric '
            '(88%) or a living donor (91%) into an asymp

In [16]:
query = sample_elm['question']

response = query_engine.query(query)

print(f'QUERY: {query}')
print(f'RESPONSE: {response}')
print(f'YES/NO: {extract_yes_no(response.response)}')
print(f'GT ANSWER: {sample_elm["answer"]}')
print(f'GT LONG ANSWER: {sample_elm["long_answer"]}')

QUERY: Should all human immunodeficiency virus-infected patients with end-stage renal disease be excluded from transplantation?
RESPONSE:  Based on the context information provided, I cannot make a direct answer to your query as it is not directly related to the study or the discharge coordinator's role. The study focused on the effectiveness of a discharge coordinator in improving the quality of discharge planning for medical ward patients, and did not investigate the eligibility of patients for transplantation. Therefore, I cannot provide an answer to your query.
However, I can tell you that the decision to exclude patients from transplantation is a complex one that depends on various factors, including the patient's medical history, current health status, and the availability of organs for transplantation. It is important to consult with a medical professional to determine the most appropriate course of action for each individual patient.
In summary, I cannot answer your query based

In [17]:
retrieved_nodes = query_engine.retriever.retrieve(query)

## Ragas evaluation

In [18]:
eval_data = {
    "question": [query],
    "answer": [response.response],
    "contexts": [[node.text for node in retrieved_nodes]],
    "ground_truths": [[sample_elm['long_answer']]],
    }
eval_obj = RagasEval(metrics=["faithfulness", "relevancy", "recall", "precision"])
eval_result = eval_obj.evaluate(eval_data)
print(eval_result)

passing column names as 'ground_truths' is deprecated and will be removed in the next version, please use 'ground_truth' instead. Note that `ground_truth` should be of type string and not Sequence[string] like `ground_truths`


Evaluating:   0%|          | 0/4 [00:00<?, ?it/s]

/fs01/home/odige/rag_bootcamp/envs/rag_pubmed_qa_env/lib/python3.10/site-packages/pydantic/main.py:1024: PydanticDeprecatedSince20: The `dict` method is deprecated; use `model_dump` instead. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.6/migration/
/fs01/home/odige/rag_bootcamp/envs/rag_pubmed_qa_env/lib/python3.10/site-packages/pydantic/main.py:1024: PydanticDeprecatedSince20: The `dict` method is deprecated; use `model_dump` instead. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.6/migration/
/fs01/home/odige/rag_bootcamp/envs/rag_pubmed_qa_env/lib/python3.10/site-packages/pydantic/main.py:1024: PydanticDeprecatedSince20: The `dict` method is deprecated; use `model_dump` instead. Deprecated in Pydantic V2.0 to be removed in V3.0. See Pydantic V2 Migration Guide at https://errors.pydantic.dev/2.6/migration/
/fs01/home/odige/rag_bootcamp/envs/rag_p

{'faithfulness': 0.5000, 'answer_relevancy': 0.0000, 'context_recall': 0.0000, 'context_precision': 0.0000}


  user_id = json.load(open(uuid_filepath))["userid"]


#### [WIP] Run evaluation on all samples