# Setup

In [4]:
from dotenv import load_dotenv

load_dotenv()

True

In [5]:
import os

connection_string = os.getenv('CONNECTION_STRING')
openai_api_key = os.getenv('OPENAI_API_KEY')
gemini_api_key = os.getenv('GOOGLE_API_KEY')

# Define global context

In [10]:
from llama_index import ServiceContext, set_global_service_context
from llama_index.callbacks import CallbackManager, LlamaDebugHandler
from llama_index.embeddings import GeminiEmbedding
from llama_index.llms import OpenAI

llm = OpenAI(model='gpt-3.5-turbo-1106', temperature=0.0, api_key=openai_api_key)
embed_model = GeminiEmbedding(api_key=gemini_api_key)
callback_manager = CallbackManager([LlamaDebugHandler(print_trace_on_end=True)])

service_context = ServiceContext.from_defaults(
    llm=llm, embed_model=embed_model, callback_manager=callback_manager
)

set_global_service_context(service_context)

# Connect to storage

In [7]:
from sqlalchemy import make_url
from llama_index.vector_stores.postgres import PGVectorStore

uri = make_url(connection_string)
vector_store = PGVectorStore.from_params(
    host=uri.host,
    port=str(uri.port),
    database=uri.database,
    user=uri.username,
    password=uri.password,
    embed_dim=768, # REMEMBER TO CHANGE THIS TO 1536 if using OpenAI Embedding Model
)

# Create retrievers

In [47]:
from llama_index import VectorStoreIndex

index = VectorStoreIndex.from_vector_store(vector_store=vector_store)
similarity_top_k = 6

**********
Trace: index_construction
**********


## Base retriever

In [48]:
retriever = index.as_retriever(
    similarity_top_k=similarity_top_k,
    vector_store_kwargs={
        'ivfflat_probes': 10,  # higher is better for recall, lower is better for speed. Default = 1
        'hnsw_ef_search': 300, # Specify the size of the dynamic candidate list for search. Default = 40
    },
)

## Recursive retriever

In [49]:
from llama_index.retrievers import RecursiveRetriever

recursive_retriever = RecursiveRetriever(
    'vector', retriever_dict={'vector': retriever}, verbose=True
)


# Fusion Retriver

## Generate new queries

In [50]:
from typing import List

from llama_index import PromptTemplate
from llama_index.llms import LLM

query_gen_prompt_template = """
You are a helpful assistant that generates multiple search queries based on a single input query.
Generate {num_queries} search queries, one on each line related to the following input query:
Query: {query}
Queries:
"""

# Should I using RAG to generate new queries based on those information?
# To avoid ambiguous queries?

def generate_queries(llm: LLM, query: str, num_queries: int = 4) -> List[str]:
    query_gen_prompt = PromptTemplate(query_gen_prompt_template)
    prompt = query_gen_prompt.format(num_queries=num_queries, query=query)
    response = llm.complete(prompt)
    queries = response.text.split('\n')

    return queries

In [34]:
queries = generate_queries(llm, "Hawaii conferences")

## Run queries

In [43]:
from typing import Dict, List, Tuple

from llama_index.retrievers import BaseRetriever
from llama_index.schema import NodeWithScore

def run_queries(queries: List[str], retrievers: List[BaseRetriever]) -> Dict[Tuple[str, int], List[NodeWithScore]]:
    results = []
    for query in queries:
        for retriever in retrievers:
            results.append(retriever.retrieve(query)) # This one should be promisable

    results_dict: Dict[Tuple[str, int], List[NodeWithScore]] = {}
    for i, (query, query_result) in enumerate(zip(queries, results)):
        results_dict[(query, i)] = query_result

    return results_dict

In [46]:
run_queries(queries, [retriever, recursive_retriever])

[1;3;34mRetrieving with query id None: 1. Best Hawaii conferences 2022
[0m[1;3;38;5;200mRetrieving text node: id 64
created_at 2024-01-04 10:26:21.324
updated_at 2024-01-05 02:33:01.038
name HN
start_date 2024-01-05 00:26:12.846
end_date 2024-01-05 00:26:13.998
cme_provider HN
venue Hanoi, Vietnam
lat 21.030384
lng 105.85531
description 
instructors HN
cme_credit 
self_assessment_module 
crawl_url 
cme_course_webpage_url 
background_image 
price 
event_type On-site meeting
is_edited False
need_map_data False
status ACCEPTED
medical_field 
points 14
event_remark 
destination_id 27
[0m[1;3;34mRetrieving with query id None: 2. Hawaii conference venues
[0m[1;3;38;5;200mRetrieving text node: id 64
created_at 2024-01-04 10:26:21.324
updated_at 2024-01-05 02:33:01.038
name HN
start_date 2024-01-05 00:26:12.846
end_date 2024-01-05 00:26:13.998
cme_provider HN
venue Hanoi, Vietnam
lat 21.030384
lng 105.85531
description 
instructors HN
cme_credit 
self_assessment_module 
crawl_url 
cme_c

{('1. Best Hawaii conferences 2022',
  0): [NodeWithScore(node=TextNode(id_='7c0b09a9-e0b3-41dc-94d8-532ce7ae1344', embedding=None, metadata={}, excluded_embed_metadata_keys=[], excluded_llm_metadata_keys=[], relationships={<NodeRelationship.SOURCE: '1'>: RelatedNodeInfo(node_id='695115fa-f15d-473b-ac4d-d4eab50fde25', node_type=<ObjectType.DOCUMENT: '4'>, metadata={}, hash='5571456d8fc3dc1b96829f28768de1d98de62cf74e8cff6805554ddf8521ce7e'), <NodeRelationship.PREVIOUS: '2'>: RelatedNodeInfo(node_id='fb48f079-a169-4ea9-9108-382e6cbd3761', node_type=<ObjectType.TEXT: '1'>, metadata={}, hash='4d638fd299a52438470f6946a6b9169319c14aaee8e017650008e236757150e9'), <NodeRelationship.NEXT: '3'>: RelatedNodeInfo(node_id='44967a0d-f90d-4a02-9220-377758c68bb9', node_type=<ObjectType.TEXT: '1'>, metadata={}, hash='8b61eebcc4204893d123fc3ef92433ab455a240b9def81f867f92b4a96965172')}, hash='44136fa355b3678a1146ad16f7e8649e94fb4fc21fe77e8310c060f61caaff8a', text='id 64\ncreated_at 2024-01-04 10:26:21.324

## Fuse results

In [80]:
from typing import Dict, List, Tuple

from llama_index.schema import NodeWithScore

# k: control the impact of outlier rankings
def fuse_results(
    results_dict: Dict[Tuple[str, int], List[NodeWithScore]], similarity_top_k: int, k: float = 60.0
) -> List[NodeWithScore]:
    fused_scores: Dict[str, float] = {}
    text_to_node: Dict[str, NodeWithScore] = {}

    for node_with_scores in results_dict.values():
        for rank, node_with_score in enumerate(sorted(node_with_scores, key=lambda x: x.score or 0.0, reverse=True)):
            text = node_with_score.get_content()
            text_to_node[text] = node_with_score

            if text not in fused_scores:
                fused_scores[text] = 0.0
            fused_scores[text] = 1.0 / (rank + k)

    reranked_results = dict(sorted(fused_scores.items(), key=lambda x: x[1], reverse=True))

    reranked_nodes: List[NodeWithScore] = []
    for text, score in reranked_results.items():
        reranked_nodes.append(text_to_node[text])
        reranked_nodes[-1].score = score

    return reranked_nodes[:similarity_top_k]

## Fusion Retriever

In [57]:
from typing import List, Optional

from llama_index import QueryBundle
from llama_index.callbacks import CallbackManager
from llama_index.llms import LLM
from llama_index.retrievers import BaseRetriever
from llama_index.schema import NodeWithScore

class FusionRetriever(BaseRetriever):
    def __init__(
        self,
        llm: LLM,
        retrievers: List[BaseRetriever],
        similarity_top_k: int,
        k: float = 60.0,
        num_queries: int = 4,
        callback_manager: Optional[CallbackManager] = None,
    ) -> None:
        self._callback_manager = callback_manager
        self._llm = llm
        self._num_queries = num_queries
        self._k = k
        self._similarity_top_k = similarity_top_k
        self._retrievers = retrievers
        super().__init__(callback_manager)

    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
        queries = generate_queries(llm, query_bundle, self._num_queries)
        results = run_queries(queries, self._retrievers)
        fused_results = fuse_results(results, self._similarity_top_k, self._k)

        return fused_results


# MAIN

In [65]:
fusion_retriever = FusionRetriever(
    llm=llm,
    retrievers=[retriever, recursive_retriever],
    similarity_top_k=similarity_top_k
)

In [67]:
from llama_index import get_response_synthesizer
from llama_index.query_engine import RetrieverQueryEngine
from llama_index.response_synthesizers import ResponseMode

query_engine = RetrieverQueryEngine.from_args(
    retriever=fusion_retriever,
    response_synthesizer=get_response_synthesizer(
        callback_manager=callback_manager,
        streaming=True,
        verbose=True,
    ),
    response_mode=ResponseMode.NO_TEXT,
    streaming=True,
)

In [68]:
query = "Radiology conference in exotic places, >3 days and at least 20 points. Return in array of id only."

In [81]:
response = query_engine.query(query)

[1;3;34mRetrieving with query id None: 1. "Radiology conference exotic locations duration >3 days"
[0m[1;3;38;5;200mRetrieving text node: id 29
created_at 2024-01-04 09:38:10.401
updated_at 2024-01-05 03:34:02.707
name 2023 Top 3 Differentials in Radiology
start_date 2023-04-24 00:00:00.000
end_date 2023-04-28 00:00:00.000
cme_provider American Osteopathic College of Radiology ›, 119 East Second Street, Milan, Missouri (MO)  63556-4011, United States
venue The Wigwam, 300 East Wigwam Blvd, Litchfield Park, Arizona (AZ), United States
lat 33.495228
lng -112.355125
description The 2023 “Top 3” Conference, sponsored by the American Osteopathic College of Radiology, is intended for practicing radiologists and radiology trainees as a comprehensive and practical case-based radiology review.  All subspecialties and imaging modalities will be covered with lectures divided into the following sections: neuroimaging, musculoskeletal imaging, gastrointestinal imaging, genitourinary imaging, ult

In [82]:
print(response.get_formatted_sources())

> Source (Node id: beef315e-56e6-4efa-8876-5a42cc246500): id 29
created_at 2024-01-04 09:38:10.401
updated_at 2024-01-05 03:34:02.707
name 2023 Top 3 Diffe...

> Source (Node id: 7c0b09a9-e0b3-41dc-94d8-532ce7ae1344): id 64
created_at 2024-01-04 10:26:21.324
updated_at 2024-01-05 02:33:01.038
name HN
start_date 20...

> Source (Node id: dfa5270d-00b5-4731-a416-2622961501ce): id 59
created_at 2024-01-04 09:38:12.491
updated_at 2024-01-05 03:33:42.889
name Top 3 Differenti...

> Source (Node id: 44967a0d-f90d-4a02-9220-377758c68bb9): id 66
created_at 2024-01-04 10:33:43.029
updated_at 2024-01-05 02:33:02.300
name HN
start_date 20...

> Source (Node id: bc557a50-bafd-43bd-8957-57ea2daa8196): id 22
created_at 2024-01-04 09:38:10.271
updated_at 2024-01-05 03:34:02.671
name Advanced Imaging...

> Source (Node id: 8f349db8-7814-4a96-a593-98de715af3be): id 52
created_at 2024-01-04 09:38:12.400
updated_at 2024-01-05 03:33:42.649
name Multi-Modality S...


In [85]:
response.source_nodes

[NodeWithScore(node=TextNode(id_='beef315e-56e6-4efa-8876-5a42cc246500', embedding=None, metadata={}, excluded_embed_metadata_keys=[], excluded_llm_metadata_keys=[], relationships={<NodeRelationship.SOURCE: '1'>: RelatedNodeInfo(node_id='695115fa-f15d-473b-ac4d-d4eab50fde25', node_type=<ObjectType.DOCUMENT: '4'>, metadata={}, hash='5571456d8fc3dc1b96829f28768de1d98de62cf74e8cff6805554ddf8521ce7e'), <NodeRelationship.PREVIOUS: '2'>: RelatedNodeInfo(node_id='9452773b-7968-4eb4-8d8c-2885dab2b221', node_type=<ObjectType.TEXT: '1'>, metadata={}, hash='edf2a299ed718a9228d10dfdfb2402099d73a833076e27cf3609075efd38d910'), <NodeRelationship.NEXT: '3'>: RelatedNodeInfo(node_id='68144552-a0b9-4971-a50d-89cbef1f0290', node_type=<ObjectType.TEXT: '1'>, metadata={}, hash='ac5b891f5452785563ce3a891b8317fbc1b31067ba63c7d48c66e6c03f14323a')}, hash='44136fa355b3678a1146ad16f7e8649e94fb4fc21fe77e8310c060f61caaff8a', text='id 29\ncreated_at 2024-01-04 09:38:10.401\nupdated_at 2024-01-05 03:34:02.707\nname 

In [83]:
response.print_response_stream()

[22]