In [1]:
import os
import warnings

from dotenv import load_dotenv
from FlagEmbedding import BGEM3FlagModel
from langchain_core.documents import Document
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_groq import ChatGroq
from qdrant_client import QdrantClient, models

from utils.utils import convert_defaultdict, format_docs

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Populate os.environ from a local .env file so credentials (e.g. the key ChatGroq
# reads from the environment — presumably GROQ_API_KEY, confirm in .env) are not
# hard-coded in the notebook. The displayed bool is load_dotenv's return value.
load_dotenv()

True

In [3]:
# Service endpoint and model configuration. The Qdrant URL and Groq model id can be
# overridden via environment variables; the defaults reproduce the original setup.
QDRANT_URL = os.getenv("QDRANT_URL", "http://localhost:6333")
# NOTE(review): Groq periodically retires model ids; confirm this one is still served.
GROQ_MODEL = os.getenv("GROQ_MODEL", "llama3-70b-8192")
# Deliberately not overridable: query embeddings must come from the same model
# that produced the vectors stored in the Qdrant collections.
EMBEDDING_MODEL = "BAAI/bge-m3"

client = QdrantClient(QDRANT_URL)
llm = ChatGroq(model=GROQ_MODEL)
embeddings = BGEM3FlagModel(EMBEDDING_MODEL, use_fp16=True)

Fetching 30 files: 100%|██████████| 30/30 [00:00<00:00, 15363.75it/s]
  colbert_state_dict = torch.load(os.path.join(model_dir, 'colbert_linear.pt'), map_location='cpu')
  sparse_state_dict = torch.load(os.path.join(model_dir, 'sparse_linear.pt'), map_location='cpu')


In [4]:
# Grounded-QA template: retrieved passages fill {context}, the user's question fills {query}.
RAG_TEMPLATE = """Answer the question based on the provided context only. Try your best to provide the most accurate response.
<context>
{context}
</context>

Question: {query}
"""

prompt = ChatPromptTemplate.from_template(RAG_TEMPLATE)

# Pipeline: fill the template -> call the Groq chat model -> extract the reply text.
output_parser = StrOutputParser()
chain = prompt | llm | output_parser

In [5]:
def retrieve(
    query,
    embeddings,
    client,
    summary_collection="semantic_summary_vectorstore",
    original_collection="semantic_original",
    prefetch_limit=20,
    limit=10,
):
    """Hybrid-search the summary collection and return the matching original documents.

    The query is encoded once into dense, sparse (lexical weights) and ColBERT
    multi-vector representations. Each representation prefetches candidates from
    ``summary_collection``, and the candidate lists are merged with Reciprocal Rank
    Fusion. Each fused hit's point id is then looked up in ``original_collection``
    through the ``doc_id`` payload field, and the stored original text is returned.

    Args:
        query: Natural-language question to search for.
        embeddings: Encoder whose ``encode`` returns ``dense_vecs``,
            ``lexical_weights`` and ``colbert_vecs`` (a ``BGEM3FlagModel`` here).
        client: Qdrant client connected to the instance holding both collections.
        summary_collection: Collection with named vectors ``dense``, ``sparse``
            and ``colbert``.
        original_collection: Collection whose points carry ``doc_id``,
            ``page_content``, ``source`` and ``title`` payload fields.
        prefetch_limit: Candidates fetched per vector type before fusion.
        limit: Number of fused hits to resolve into documents.

    Returns:
        list[Document]: Original documents in fused-rank order. A hit with no
        matching original document is skipped with a ``warnings.warn`` instead
        of aborting the whole query with an ``IndexError``.
    """
    encoded = embeddings.encode([query], return_sparse=True, return_colbert_vecs=True)
    dense_vec = encoded['dense_vecs'][0]
    sparse_vec = models.SparseVector(**convert_defaultdict(encoded['lexical_weights'][0]))
    colbert_vecs = encoded['colbert_vecs'][0]

    result = client.query_points(
        summary_collection,
        prefetch=[
            models.Prefetch(query=dense_vec, using="dense", limit=prefetch_limit),
            models.Prefetch(query=sparse_vec, using="sparse", limit=prefetch_limit),
            models.Prefetch(query=colbert_vecs, using="colbert", limit=prefetch_limit),
        ],
        query=models.FusionQuery(fusion=models.Fusion.RRF),
        limit=limit,
    )

    relevant_docs = []
    for point in result.points:
        # scroll() returns (points, next_page_offset). Only the first match is
        # used, so request a single point rather than the default page of 10.
        matches, _ = client.scroll(
            collection_name=original_collection,
            scroll_filter=models.Filter(
                must=[
                    models.FieldCondition(
                        key="doc_id",
                        match=models.MatchValue(value=point.id),
                    )
                ]
            ),
            limit=1,
        )
        if not matches:
            # The summary and original collections are out of sync for this id;
            # drop the hit instead of failing the whole retrieval.
            warnings.warn(
                f"No document with doc_id={point.id!r} in {original_collection!r}; skipping hit."
            )
            continue

        payload = matches[0].payload
        relevant_docs.append(
            Document(
                page_content=payload['page_content'],
                metadata={
                    'source': payload['source'],
                    'doc_id': payload['doc_id'],
                    'title': payload['title'],
                },
            )
        )

    return relevant_docs

In [6]:
# Sample question used to exercise the retrieval + generation pipeline.
query = (
    "I usually sleep at 2 AM. "
    "Is it bad for my health?"
)

In [10]:
# retrieve() yields documents in fused-rank order; only the best few go into the prompt.
TOP_K_CONTEXT = 5

relevant_docs = retrieve(query, embeddings=embeddings, client=client)
top_docs = relevant_docs[:TOP_K_CONTEXT]
context = format_docs(top_docs)

In [11]:
# Run the RAG chain on the retrieved context. The answer is printed (not left as
# a bare expression) so multi-line LLM output renders with real line breaks.
response = chain.invoke({"context": context, "query": query})
print(response)