### Import libraries, custom classes and functions

In [1]:
import os
import random
from pprint import pprint

from llama_index import ServiceContext, set_global_service_context, set_global_handler
from llama_index.text_splitter import SentenceSplitter

from task_dataset import PubMedQATaskDataset
from rag_utils import (
    DocumentReader, RAGEmbedding, RAGLLM, RAGIndex, RAGQueryEngine,
    extract_yes_no, evaluate, validate_rag_cfg
    )

### Set RAG configuration

In [2]:
# Single source of truth for every tunable setting of this RAG run; the cells
# below read from this dict. It is printed and passed to validate_rag_cfg in STAGE 0.
rag_cfg = {
    # Node parser config
    "chunk_size": 256,
    "chunk_overlap": 0,

    # Embedding model config
    "embed_model_type": "hf",
    "embed_model_name": "BAAI/bge-base-en-v1.5",

    # LLM config
    "llm_type": "local",
    "llm_name": "Llama-2-7b-chat-hf",
    "max_new_tokens": 256,
    # NOTE(review): with do_sample=False decoding is presumably greedy, in which case
    # temperature/top_p/top_k have no effect — confirm in RAGLLM.load_model.
    "temperature": 1.0,
    "top_p": 1.0,
    "top_k": 50,
    "do_sample": False,

    # Vector DB config
    "vector_db_type": "weaviate", # "chromadb", "weaviate"
    "vector_db_name": "Pubmed_QA",
    # Point this at your own Weaviate cluster by setting the WEAVIATE_URL environment
    # variable (no notebook edit needed); if it is unset or empty, the default below is used.
    "weaviate_url": os.environ.get("WEAVIATE_URL") or "https://vector-rag-lab-xsxuylwh.weaviate.network",

    # Retriever and query config
    "retriever_type": "vector_index",
    "retriever_similarity_top_k": 3,
    "query_mode": "hybrid", # "default", "hybrid"
    # float from 0.0 (sparse search - bm25) to 1.0 (vector search); at 1.0 the
    # "hybrid" query mode therefore weights only the vector-search side.
    "hybrid_search_alpha": 1.0,
    "response_mode": "compact",
}

## STAGE 0 - Preliminary config checks

In [3]:
# Echo the effective configuration, then hand it to validate_rag_cfg before any
# data or models are loaded in the stages below.
pprint(rag_cfg)
validate_rag_cfg(rag_cfg)

{'chunk_overlap': 0,
 'chunk_size': 256,
 'do_sample': False,
 'embed_model_name': 'BAAI/bge-base-en-v1.5',
 'embed_model_type': 'hf',
 'hybrid_search_alpha': 1.0,
 'llm_name': 'Llama-2-7b-chat-hf',
 'llm_type': 'local',
 'max_new_tokens': 256,
 'query_mode': 'hybrid',
 'response_mode': 'compact',
 'retriever_similarity_top_k': 3,
 'retriever_type': 'vector_index',
 'temperature': 1.0,
 'top_k': 50,
 'top_p': 1.0,
 'vector_db_name': 'Pubmed_QA',
 'vector_db_type': 'weaviate',
 'weaviate_url': 'https://vector-rag-lab-xsxuylwh.weaviate.network'}


## STAGE 1 - Load dataset and documents

#### 1. Load PubMed QA dataset
PubMedQA ([GitHub](https://github.com/pubmedqa/pubmedqa)) is a biomedical question-answering dataset. Each instance consists of a question, a context (extracted from PubMed abstracts), a long answer and a yes/no/maybe answer. This notebook uses the test split of the [bigbio/pubmed_qa](https://huggingface.co/datasets/bigbio/pubmed_qa) Hugging Face dataset.

**The context of each instance is stored as a separate text file** (referred to as a document), so that the task matches a standard RAG use case.

In [4]:
# Load the PubMed QA evaluation samples (questions, contexts and reference answers).
print('Loading PubMed QA data ...')
pubmed_data = PubMedQATaskDataset('bigbio/pubmed_qa')
print("Loaded data size: {}".format(len(pubmed_data)))
# One-off step, kept disabled. NOTE(review): presumably this writes each sample's
# context out as a document file for the knowledge base read below — confirm.
# pubmed_data.mock_knowledge_base(output_dir='./data', one_file_per_sample=True)

Loading PubMed QA data ...


Preparing data: 100%|██████████| 500/500 [00:00<00:00, 1373.71it/s]

Loaded data size: 500





#### 2. Load documents

In [5]:
# Read the knowledge-base documents from disk.
print('Loading documents ...')
reader = DocumentReader(input_dir="./data/pubmed_doc")
docs = reader.load_data()
print('No. of documents loaded: {}'.format(len(docs)))

Loading documents ...
No. of documents loaded: 500


## STAGE 2 - Load the node parser, embedding model and LLM, and set the service context

#### 1. Load node parser to split documents into smaller chunks

In [6]:
print('Loading node parser ...')
# Sentence-aware splitter; chunk size and overlap both come from rag_cfg.
node_parser = SentenceSplitter(
    chunk_size=rag_cfg['chunk_size'],
    chunk_overlap=rag_cfg['chunk_overlap'],
)
# Nodes are not materialised here: the parser is handed to the service context
# below and is presumably applied when the index is built.

Loading node parser ...


[nltk_data] Downloading package punkt to /tmp/llama_index...
[nltk_data]   Unzipping tokenizers/punkt.zip.


#### 2. Load embedding model

In [7]:
# Embedding backend and checkpoint are both taken from rag_cfg.
embed_model = RAGEmbedding(
    model_type=rag_cfg['embed_model_type'],
    model_name=rag_cfg['embed_model_name'],
).load_model()

Loading hf embedding model ...


config.json:   0%|          | 0.00/777 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/438M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/366 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/711k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

#### 3. Load LLM for generation

In [8]:
# load_model receives the entire rag_cfg as keyword arguments.
# NOTE(review): presumably it picks out the generation settings it needs
# (max_new_tokens, do_sample, ...) and ignores the rest — confirm in rag_utils.
llm = RAGLLM(
    rag_cfg['llm_type'],
    rag_cfg['llm_name'],
).load_model(**rag_cfg)

Loading local LLM model ...


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

#### 4. Use the service context to set the node parser, embedding model, LLM, etc.

In [9]:
# Bundle the node parser, embedding model and LLM into a single service context.
service_context = ServiceContext.from_defaults(
    llm=llm,
    embed_model=embed_model,
    node_parser=node_parser,
)
# Register it globally so it need not be passed to every component; this also
# applies to the classes used from rag_utils.py.
set_global_service_context(service_context)

## STAGE 3 - Create the index using the configured vector store

In [10]:
# Create the index in the configured vector store, or load it if it already
# exists (as the "Loading index from ..." output shows). weaviate_url is always
# passed through, whatever vector_db_type is set to.
index = (
    RAGIndex(db_type=rag_cfg['vector_db_type'], db_name=rag_cfg['vector_db_name'])
    .create_index(docs, weaviate_url=rag_cfg["weaviate_url"])
)

Loading index from ./.weaviate_index_store/ ...


## STAGE 4 - Build query engine

Now build a query engine from a *retriever* and a *response_synthesizer*.

In [12]:
# Wire a retriever over the index to a response synthesizer; all retrieval and
# response settings come from rag_cfg.
query_engine = RAGQueryEngine(
    retriever_type=rag_cfg['retriever_type'],
    vector_index=index,
    llm_model_name=rag_cfg['llm_name'],
).create(
    similarity_top_k=rag_cfg['retriever_similarity_top_k'],
    response_mode=rag_cfg['response_mode'],
    query_mode=rag_cfg["query_mode"],
    hybrid_search_alpha=rag_cfg["hybrid_search_alpha"],
)

## STAGE 5 - Finally, query the model!

In [13]:
# Seed the global RNG so that the same sample is picked on every run.
random.seed(41)
sample_idx = random.randint(0, len(pubmed_data) - 1)  # inclusive upper bound
sample_elm = pubmed_data[sample_idx]
pprint(sample_elm)

{'answer': ['yes'],
 'context': 'There is controversy surrounding the optimal management of the '
            'testicular remnant associated with the vanishing testes syndrome. '
            'Some urologists advocate the need for surgical exploration, '
            'whereas others believe this is unnecessary. These differing '
            'opinions are based on the variable reports of viable germ cell '
            'elements found within the testicular remnants. To better '
            'understand the pathology associated with this syndrome and the '
            'need for surgical management, we reviewed our experience '
            'regarding the incidence of viable germ cell elements within the '
            'testicular remnant. An institutional review board-approved, '
            'retrospective review was performed of all consecutive patients '
            'undergoing exploration for a nonpalpable testis at Eastern '
            'Virginia Medical School and Geisinger Medical Center

In [14]:
# Ask the sampled question through the full RAG pipeline.
query = sample_elm['question']
response = query_engine.query(query)

# Show the model output next to the ground truth for this sample.
print('QUERY: {}'.format(query))
print('RESPONSE: {}'.format(response))
print('YES/NO: {}'.format(extract_yes_no(response.response)))
print('GT ANSWER: {}'.format(sample_elm["answer"]))
print('GT LONG ANSWER: {}'.format(sample_elm["long_answer"]))

QUERY: Histologic evaluation of the testicular remnant associated with the vanishing testes syndrome: is surgical management necessary?
RESPONSE:  Based on the context information provided, I would answer the query as follows:
Yes.
The study suggests that histologic evaluation of the testicular remnant associated with the vanishing testes syndrome may be necessary for several reasons:
1. Variable reports of viable germ cell elements: The study found that 14% of the testicular remnants had viable germ cell elements, which suggests that histologic evaluation may be necessary to determine the presence of these elements and the potential for fertility.
2. Controversy surrounding management: The study highlights the controversy among urologists regarding the need for surgical management of the testicular remnant, with some advocating for exploration and others believing it to be unnecessary. Histologic evaluation may help to resolve this controversy by providing a more accurate assessment o

#### Run evaluation on all samples

In [15]:
# Evaluate the query engine on every sample of the dataset. This is likely slow
# (500 samples through a local LLM), so it is disabled by default; set
# RUN_FULL_EVAL = True to reproduce the recorded results in the next cell.
RUN_FULL_EVAL = False

if RUN_FULL_EVAL:
    result_dict = evaluate(pubmed_data, query_engine)
    output_dict = {
        "num_samples": len(pubmed_data),
        "config": rag_cfg,
        "result": result_dict,
    }
    pprint(output_dict)

In [16]:
# Recorded output of a full evaluation run (num_samples 500) with the rag_cfg
# shown at the top of this notebook:
# {'config': {'chunk_overlap': 0,
#         'chunk_size': 256,
#         'do_sample': False,
#         'embed_model_name': 'BAAI/bge-base-en-v1.5',
#         'embed_model_type': 'hf',
#         'hybrid_search_alpha': 1.0,
#         'llm_name': 'Llama-2-7b-chat-hf',
#         'llm_type': 'local',
#         'max_new_tokens': 256,
#         'query_mode': 'hybrid',
#         'response_mode': 'compact',
#         'retriever_similarity_top_k': 3,
#         'retriever_type': 'vector_index',
#         'temperature': 1.0,
#         'top_k': 50,
#         'top_p': 1.0,
#         'vector_db_name': 'Pubmed_QA',
#         'vector_db_type': 'weaviate',
#         'weaviate_url': 'https://vector-rag-lab-xsxuylwh.weaviate.network'},
# 'num_samples': 500,
# 'result': {'acc': 0.666, 'retriever_acc': 0.994}}
# NOTE(review): 'acc' presumably scores the extracted yes/no answer against the
# ground-truth label, and 'retriever_acc' whether retrieval found the source
# document — confirm against rag_utils.evaluate.
# The same result was also recorded with {'chunk_overlap': 32, 'chunk_size': 128}.