In [15]:
'''Load Documents and store within a ChromaDB vector DB following the MulitHop RAG example'''

import chromadb
import importlib
JSONReader = importlib.import_module('submodules.MultiHop-RAG.util').JSONReader
from llama_index.vector_stores.chroma import ChromaVectorStore
from llama_index.storage import StorageContext
from llama_index.text_splitter import SentenceSplitter
from llama_index.extractors import BaseExtractor
from llama_index.ingestion import IngestionPipeline
from llama_index.embeddings import HuggingFaceEmbedding
from llama_index.llms import  OpenAI
from llama_index import set_global_service_context, PromptHelper, ServiceContext, VectorStoreIndex
from typing import List, Dict


In [2]:
# initialize client, setting path to save data
db = chromadb.PersistentClient(path="./chroma_db")

# create collection
chroma_collection = db.get_or_create_collection("quickstart")

# assign chroma as the vector_store to the context
vector_store = ChromaVectorStore(chroma_collection=chroma_collection)
storage_context = StorageContext.from_defaults(vector_store=vector_store)


In [3]:
# read data from corpus
reader = JSONReader()
data = reader.load_data('submodules/MultiHop_RAG/dataset/corpus.json')

In [4]:
class CustomExtractor(BaseExtractor):
    async def aextract(self, nodes) -> List[Dict]:
        metadata_list = [
            {
                "title": (
                    node.metadata["title"]
                ),
                "source": (
                    node.metadata["source"]
                ),      
                "published_at": (
                    node.metadata["published_at"]
                )
            }
            for node in nodes
        ]
        return metadata_list

In [5]:
# Parse inputs
text_splitter = SentenceSplitter(chunk_size=256)

transformations = [text_splitter,CustomExtractor()] 
pipeline = IngestionPipeline(transformations=transformations)
nodes = await pipeline.arun(documents=data)

In [6]:
# Create Index
embed_model = HuggingFaceEmbedding(model_name='BAAI/llm-embedder', trust_remote_code=True)
llm = OpenAI(model='gpt-3.5-turbo-1106', temperature=0, max_tokens=2048)
prompt_helper = PromptHelper(
    context_window=2048,
    num_output=256,
    chunk_overlap_ratio=0.1,
    chunk_size_limit=None,
)
service_context = ServiceContext.from_defaults(
    llm=llm,
    embed_model=embed_model,
    text_splitter=text_splitter,
    prompt_helper=prompt_helper,
)
set_global_service_context(service_context)

index = VectorStoreIndex(nodes, show_progress=True, storage_context=storage_context)
print('Finish Indexing...')

  from .autonotebook import tqdm as notebook_tqdm
Generating embeddings: 100%|██████████| 2048/2048 [00:10<00:00, 199.82it/s]
Generating embeddings: 100%|██████████| 2048/2048 [00:10<00:00, 204.28it/s]
Generating embeddings: 100%|██████████| 2048/2048 [00:10<00:00, 202.44it/s]
Generating embeddings: 100%|██████████| 2048/2048 [00:10<00:00, 204.51it/s]
Generating embeddings: 100%|██████████| 2048/2048 [00:09<00:00, 205.99it/s]
Generating embeddings: 100%|██████████| 2048/2048 [00:10<00:00, 199.56it/s]
Generating embeddings: 100%|██████████| 2048/2048 [00:10<00:00, 204.06it/s]
Generating embeddings: 100%|██████████| 2048/2048 [00:09<00:00, 206.72it/s]
Generating embeddings: 100%|██████████| 2048/2048 [00:09<00:00, 207.75it/s]
Generating embeddings: 100%|██████████| 2048/2048 [00:10<00:00, 201.89it/s]
Generating embeddings: 100%|██████████| 2048/2048 [00:09<00:00, 206.21it/s]
Generating embeddings: 100%|██████████| 2048/2048 [00:10<00:00, 204.78it/s]
Generating embeddings: 100%|██████████

Finish Indexing...


In [11]:
# Retrieving from DB
query_engine = index.as_retriever()

In [13]:
query = "What is the capital of France?"
query_engine.retrieve(query)

[NodeWithScore(node=TextNode(id_='77f88542-6f3b-4856-a18a-648565c86c80', embedding=None, metadata={'title': 'There’s something going on with AI startups in France', 'published_at': '2023-11-09T14:51:44+00:00', 'source': 'TechCrunch'}, excluded_embed_metadata_keys=[], excluded_llm_metadata_keys=[], relationships={<NodeRelationship.SOURCE: '1'>: RelatedNodeInfo(node_id='135ad82d-3025-47e1-b10e-29ed3415bf7a', node_type=<ObjectType.DOCUMENT: '4'>, metadata={'title': 'There’s something going on with AI startups in France', 'published_at': '2023-11-09T14:51:44+00:00', 'source': 'TechCrunch'}, hash='254916260b14101b1935eb1ba88b838d7cecc57ed62f5ce8b480f7879beac393'), <NodeRelationship.PREVIOUS: '2'>: RelatedNodeInfo(node_id='f4e88771-6568-4864-b5a3-a9e572057bbc', node_type=<ObjectType.TEXT: '1'>, metadata={'title': 'There’s something going on with AI startups in France', 'published_at': '2023-11-09T14:51:44+00:00', 'source': 'TechCrunch'}, hash='fe9e0da725559fa0db39ad42275b9cf782a297eb2bddba5c

## TODO: Implement GraphRAG with NebulaGraph locally
https://docs.llamaindex.ai/en/stable/examples/query_engine/knowledge_graph_rag_query_engine.html#graph-rag