In [1]:
from src.configs.env_config import config
from src.services.db import chroma_service
from pathlib import Path
from langchain.retrievers.multi_query import MultiQueryRetriever
from langchain.retrievers import ContextualCompressionRetriever
from langchain.retrievers.document_compressors import FlashrankRerank
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_chroma import Chroma
from pprint import pprint
from langchain_core.prompts import PromptTemplate
from src.services.utils import (
    text_splitter_recursive_char,
    create_chunk_ids,
    json_to_documents,
)
from src.services.processors import DocumentsPreprocessing
from src.services.vectorstore import ChromaStore
from src.services.retrievers import MultiQRerankedRetriever

In [2]:
# Build the client through the project helper and ping the server; the bare
# last expression makes the heartbeat value the cell output.
client = chroma_service()
client.heartbeat()

1744037033256746210

In [3]:
# Collection name shared by the client, store and retriever calls below.
collection_name = "local_collection"

In [None]:
# Start from a clean slate: drop the collection only if it already exists.
# `client.get_collection` raises when the collection is missing, so it cannot
# serve as an existence test; compare against the listed collections instead.
# `list_collections()` yields names or Collection objects depending on the
# chromadb release, hence the `getattr` fallback.
existing_names = {getattr(item, "name", item) for item in client.list_collections()}
if collection_name in existing_names:
    client.delete_collection(collection_name)

# Recreate (or fetch) the collection and display it as the cell output.
collection = client.get_or_create_collection(collection_name)
collection

In [4]:
# Directories the JSON inputs are read from. The paths are relative, so they
# resolve against the kernel's current working directory.
pdf_data_src = Path("_dev_nb/output_data/pdf_loader")
web_data_src = Path("_dev_nb/output_data/web_loader")

### classes


In [5]:
# Load the web JSON export into documents with the project helper.
json_path = web_data_src / "setics_stad_docs_clean.json"
docs = json_to_documents(filename=json_path)

# Report how many documents were loaded.
print(f"Got {len(docs)} documents")

Got 525 documents


In [6]:
# Run the project preprocessing step; the awaited call returns the chunks
# together with their ids (top-level await relies on the notebook event loop).
processor = DocumentsPreprocessing()
chunks, ids = await processor(documents=docs)

print(f"Created {len(chunks)} chunks")

Created 1145 chunks


In [9]:
# Project wrapper used below to add and replace documents in the collection.
store = ChromaStore()

In [13]:
# Fetch the existing collection (raises if it is missing) and show its size.
coll = client.get_collection(collection_name)
coll.count()

1145

In [None]:
# Chroma's `where` filter only understands comparison, inclusion and logical
# operators; `$exists` is rejected as an invalid expression. Fetch the
# metadata and keep the entries that carry a "source" key client-side.
fetched = collection.get(include=["metadatas"])
with_source = [
    (item_id, metadata)
    for item_id, metadata in zip(fetched["ids"], fetched["metadatas"] or [])
    if metadata and "source" in metadata
]
# Same "ids"/"metadatas" layout as a Chroma get-result, restricted to the
# entries that have a source.
results = {
    "ids": [item_id for item_id, _ in with_source],
    "metadatas": [metadata for _, metadata in with_source],
}


In [15]:
# Peek at the first five stored entries (ids and metadata, as shown below).
coll.get(limit=5)

{'ids': ['topology-0-c5a70271',
  'topology-1-2acf1e65',
  'topology-2-1cf93343',
  'topology-3-d267b2c6',
  'endpoint-support-context-menu-4-1aca362e'],
 'embeddings': None,
 'metadatas': [{'title': 'Topology - Setics Sttar Advanced Designer  |  User Manual - Version 2.3',
   'description': 'The Topology tab allows you to specify how Setics Sttar Advanced Designer should interpret and model the support entities: Infrastructure tab - support...',
   'source': 'https://docs.setics-sttar.com/advanced-designer-user-manual/2.3/en/topic/topology',
   'id': 'topology-0-c5a70271',
   'language': 'en'},
  {'title': 'Topology - Setics Sttar Advanced Designer  |  User Manual - Version 2.3',
   'language': 'en',
   'id': 'topology-1-2acf1e65',
   'source': 'https://docs.setics-sttar.com/advanced-designer-user-manual/2.3/en/topic/topology',
   'description': 'The Topology tab allows you to specify how Setics Sttar Advanced Designer should interpret and model the support entities: Infrastructure ta

In [11]:
# Inspect the store's source tracker for this collection.
# NOTE(review): this calls an underscore-prefixed (private) method of the
# store; acceptable for debugging, but prefer a public accessor if one exists.
sources = await store._get_source_tracker(collection_name=collection_name)
sources

set()

In [None]:
count = await store.add_documents(
    documents=chunks, ids=ids, collection_name=collection_name
)

print(f"Added {count} documents")

In [None]:
# Collection size after the insert above.
client.get_collection(collection_name).count()

In [None]:
# Replace the stored documents with the same chunks; the awaited call returns
# three values, unpacked here into added / replaced / updated counters.
added_count, docs_replaced, sources_updated = await store.replace_documents(
    documents=chunks, ids=ids, collection_name=collection_name
)

print(f"Added {added_count} documents")
print(f"Replaced {docs_replaced} documents")
# NOTE(review): the variable is named `sources_updated` but the label says
# "documents" - confirm which unit replace_documents actually returns.
print(f"Updated {sources_updated} documents")

In [None]:
# Collection size after the replace above.
client.get_collection(collection_name).count()

In [None]:
# Question passed to the retriever in the next cell.
query = "What is the purpose of the Advanced Designer?"

In [None]:
# Run the project retriever against the collection. Later cells iterate over
# `results` and read `.metadata` on each item.
retriever = MultiQRerankedRetriever()
results = await retriever(query=query, collection_name=collection_name)

In [None]:
# Show each hit's metadata followed by a separator line. The dict is handed
# to pprint directly: wrapping it in an f-string would only print the quoted
# repr of one long string, with no pretty-printing at all.
for result in results:
    print("Metadata:")
    pprint(result.metadata)
    print("-" * 80)

In [None]:
# Print the first hit in full; indexing fails if the retriever returned nothing.
print(results[0])

### prototyping


In [None]:
# Chat model used by the multi-query retriever built further down:
# temperature 0, responses capped at 1000 tokens, key read from project config.
llm = ChatOpenAI(
    model="gpt-4o-mini",
    temperature=0,
    api_key=config.OPENAI_API_KEY,
    max_tokens=1000,
)

In [None]:
# Load the PDF JSON export and show how many documents it yields.
pdf_json_path = pdf_data_src / "xplore_pdf_3_clean.json"
pdf_docs = json_to_documents(filename=pdf_json_path)
len(pdf_docs)

In [None]:
# Load the first web JSON export (the same file as `json_path` above) and
# show the document count.
web_json_path = web_data_src / "setics_stad_docs_clean.json"
web_docs = json_to_documents(filename=web_json_path)
len(web_docs)

In [None]:
# Load the second web JSON export and show the document count.
web_json_path_2 = web_data_src / "setics_stpl_docs_clean.json"
web_docs_2 = json_to_documents(filename=web_json_path_2)
len(web_docs_2)

In [None]:
# Load the image-document JSON export and show the document count.
img_json_path = web_data_src / "setics_stad_img_docs.json"
img_docs = json_to_documents(filename=img_json_path)
len(img_docs)

In [None]:
# for i, doc in enumerate(pdf_docs):
#     print(f"Doc {i}: length {len(doc.page_content)}")

In [None]:
# Split the PDF documents with the project splitter and show the chunk count.
pdf_chunks = text_splitter_recursive_char(pdf_docs)
len(pdf_chunks)

In [None]:
# for i, doc in enumerate(web_docs):
#     print(f"Doc {i}: length {len(doc.page_content)}")

In [None]:
# Split the first web document set and show the chunk count.
web_chunks = text_splitter_recursive_char(web_docs)
len(web_chunks)

In [None]:
# Split the second web document set and show the chunk count.
web_chunks_2 = text_splitter_recursive_char(web_docs_2)
len(web_chunks_2)

In [None]:
# for i, doc in enumerate(web_chunks):
#     print(f"Doc {i}: length {len(doc.page_content)}")

In [None]:
# Derive ids for every chunk list; the image documents already carry an id in
# their metadata, so those are read directly.
pdf_chunks_ids = create_chunk_ids(pdf_chunks)
web_chunks_ids = create_chunk_ids(web_chunks)
web_chunks_2_ids = create_chunk_ids(web_chunks_2)

img_ids = [image_doc.metadata["id"] for image_doc in img_docs]

# Preview the first two ids of each list, one list per line.
id_previews = (pdf_chunks_ids[:2], web_chunks_ids[:2], web_chunks_2_ids[:2], img_ids[:2])
print(*id_previews, sep="\n")

In [None]:
# pprint(web_chunks[10].metadata)

In [None]:
# Embedding model handed to the vector store below; key read from project config.
openai_embedding = OpenAIEmbeddings(
    model="text-embedding-3-large", openai_api_key=config.OPENAI_API_KEY
)

In [None]:
# LangChain wrapper over the same client and collection, embedding with the
# OpenAI model configured above.
vector_store = Chroma(
    client=client,
    collection_name=collection.name,
    embedding_function=openai_embedding,
)

In [None]:
# Pair every document list with its matching id list, then insert them one
# batch at a time.
documents_with_ids = [
    (web_chunks, web_chunks_ids),
    (pdf_chunks, pdf_chunks_ids),
    (web_chunks_2, web_chunks_2_ids),
    (img_docs, img_ids),
]

# The loop variables are deliberately NOT named `docs`/`ids`: those names hold
# the documents and chunk ids loaded earlier in the notebook, and rebinding
# them here would silently change what the earlier add/replace cells operate
# on when they are re-run.
for batch_docs, batch_ids in documents_with_ids:
    vector_store.add_documents(documents=batch_docs, ids=batch_ids)

In [None]:
# Collection size after the bulk insert above.
collection.count()

In [None]:
# retriever = vector_store.as_retriever(
#     search_type="mmr",
#     # search_type="similarity_score_threshold",
#     # search_kwargs={"k": 3, "score_threshold": 0.5},
#     search_kwargs={"k": 3},
# )

In [None]:
# retriever = MultiQueryRetriever.from_llm(retriever=vector_store.as_retriever(), llm=llm)

# retriever = SelfQueryRetriever.from_llm(
#     llm=llm,
#     vectorstore=vector_store,
# )

In [None]:
# wrapping base retriever with FlashRank compressor

# create MultiQueryRetriever
# Retrieval pipeline: vector search -> multi-query retriever -> FlashRank rerank.

# Reranker configured with top_n=3.
compressor = FlashrankRerank(top_n=3)

# Vector-store retriever (k=10) wrapped by the LLM-backed multi-query retriever.
base_retriever = vector_store.as_retriever(search_kwargs={"k": 10})
multi_query_retriever = MultiQueryRetriever.from_llm(llm=llm, retriever=base_retriever)

# Final retriever: the multi-query results are passed through the reranker.
retriever = ContextualCompressionRetriever(
    base_retriever=multi_query_retriever,
    base_compressor=compressor,
)

In [None]:
# query = "What is the installation requirement for flower pot?"
# query = "What can you tell me about Setics Sttar?"
# query = "In Sttar, how to add a new infrastructure layer?"
# query = "In Sttar, how to manually split some lines in the interface, in the map view?"
# query = " In sttar, how can we manage the support properties, for the reusable infrastructure?"
# Active query (the alternatives above are kept commented out). The string is
# sent verbatim, including its leading space.
query = " What is the differences between the advanced designer and the planner?"

In [None]:
# # Set logging for the queries
# import logging

# logging.basicConfig()
# logging.getLogger("langchain.retrievers.multi_query").setLevel(logging.INFO)

In [None]:
# Run the reranked multi-query retriever and display the returned documents.
# This rebinds `results`, which the chatbot cells below read.
results = retriever.invoke(query)
results

In [None]:
# for result in results:
#     print(result.page_content)
#     print("\n\n===\n\n")

### chatbot


In [None]:
# Rebinds `llm` for answer generation: same model and temperature as the
# earlier definition, but without the max_tokens cap (left commented out).
llm = ChatOpenAI(
    model="gpt-4o-mini",
    temperature=0,
    api_key=config.OPENAI_API_KEY,
    # max_tokens=1000,
)

In [None]:
# Prompt for the chat model; {context} and {question} are filled in below.
template = """You are an assistant for question-answering tasks.
Use the following pieces of retrieved context to answer the question.
If you don't know the answer, just say that you don't know.
Use three sentences maximum and keep the answer concise.
Context: {context}
Question: {question}
Answer:"""

prompt = PromptTemplate.from_template(template)

# Concatenate the retrieved passages, separated by blank lines.
passages = [doc.page_content for doc in results]
docs_content = "\n\n".join(passages)

# Fill the prompt and ask the model for the answer.
messages = prompt.invoke({"context": docs_content, "question": query})
response = llm.invoke(messages)

In [None]:
# Show the model's answer text.
pprint(response.content)

In [None]:
# Look up the first retrieved document in the collection by the id stored in
# its metadata, including text, metadata and the embedding vector.
collection.get(
    ids=results[0].metadata["id"], include=["documents", "metadatas", "embeddings"]
)