In [8]:
import json
import os
import uuid
from langchain_core.documents import Document
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings import SentenceTransformerEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.retrievers.multi_vector import MultiVectorRetriever
from langchain.storage import InMemoryByteStore, RedisStore

In [2]:
embedding_function = SentenceTransformerEmbeddings(model_name="BAAI/bge-m3")

  from tqdm.autonotebook import tqdm, trange


In [10]:
from langchain_community.document_loaders import JSONLoader


def metadata_func(record: dict, metadata: dict) -> dict:
    metadata["url"] = record.get("url")
    metadata["title"] = record.get("title")
    metadata["id"] = record.get("id")

    return metadata


loader = JSONLoader(
    file_path='../dataset_unique_epta.json',
    jq_schema='.data[]',
    text_content=False,
    content_key='description',
    metadata_func=metadata_func,
)

docs = loader.load()

# for doc in docs:
#     doc.page_content = f"{doc.metadata.get('title')}\n{doc.page_content}"

In [11]:
len(docs)

2880

In [21]:
with open('../save_summaries.json', 'r') as f:
    summaries = json.load(f)

In [24]:
from langchain.retrievers.contextual_compression import ContextualCompressionRetriever
from langchain_cohere import CohereRerank

# The vectorstore to use to index the child chunks
vectorstore = Chroma(collection_name="test", embedding_function=embedding_function, persist_directory='./chroma_test_summary')

base_retriever = vectorstore.as_retriever(search_kwargs={"k": 100})
compressor = CohereRerank(top_n=10, cohere_api_key='cAK7dl8djEA6qUE9y0aMsQEm3xkTyPT4cU511Y0W', model='35a0461f-d7ab-4c97-9e33-61b533b930c0-ft')
retriever = ContextualCompressionRetriever(
    base_compressor=compressor, base_retriever=base_retriever
)

In [8]:
summary_docs = [
    Document(page_content=s, metadata={id_key: doc_ids[i]})
    for i, s in enumerate(summaries)
]

In [9]:
retriever.vectorstore.add_documents(summary_docs)

['4cd343fc-3e38-44fe-bc20-97c8d5d3d12e',
 'be5bd49e-2ada-409a-bf90-df0ac5777580',
 '1fd36352-9b1d-4b53-8864-cb017a6afd6b',
 '63914a43-eab5-49cf-af07-7fe586430757',
 '77cebb1f-f7f6-4c3a-b6e7-59bd6ed33c07',
 '999ebac9-398c-4011-86b7-660fddfc6e43',
 'bc75e251-5ed5-4d59-be0c-8128d1dc9f42',
 'aaabb02d-bea6-4734-b3e8-fea1af21e153',
 'b1f39623-1795-4310-8cce-5822cf1c5aec',
 '59accd39-dcf8-4d92-85b4-02aaa8db47c8',
 '1737480e-8268-4202-a6ad-bd4e65d7d090',
 'ce4bc62e-f2dc-42b0-9ffd-dd03045ce4e6',
 '7b9ef0ec-9e11-48f8-949c-3fb1c5a90ef7',
 'be97c56c-d1b8-4229-9cd5-bcac3afc3bf1',
 '6952e4ed-2e76-4fa5-832e-5908cf0e6c2f',
 '73b42841-2c33-4dc6-8009-c53b9104cb86',
 '7bac3616-f81e-4c96-8759-a3bd120908b0',
 'fdbb78bd-95ed-4f09-9d85-c66babe8f952',
 'd31099f6-b45f-4cc8-a13c-e9ac676554e8',
 'd185109d-7807-4f42-8f9c-ea90d1f581de',
 '5d8536c1-d03e-4c11-b6d6-99d41884e080',
 '5556dd38-70c7-444d-be93-350321477679',
 '14b7c74e-6e23-4a0f-baff-0cddfd23ea17',
 'b472995d-d9bf-49cc-86d5-c7944dd05701',
 'b5cca7df-850d-

In [25]:
text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
    chunk_size=512, chunk_overlap=256
)

In [26]:
sub_docs = []
for i, doc in enumerate(docs):
    _id = doc_ids[i]
    _sub_docs = text_splitter.split_documents([doc])
    for _doc in _sub_docs:
        _doc.metadata[id_key] = _id
    sub_docs.extend(_sub_docs)

NameError: name 'doc_ids' is not defined

In [14]:
sub_docs = []
for i, doc in enumerate(docs):
    _sub_docs = text_splitter.split_documents([doc])
    sub_docs.extend(_sub_docs)

In [15]:
len(sub_docs)

8825

In [16]:
vectorstore.add_documents(sub_docs[:5000])
vectorstore.add_documents(sub_docs[5000:])

['4a67d822-df87-4c95-8801-f9aa3a6383e5',
 'f6975760-36e2-4936-83d3-0c14c3c9beb5',
 'a2b56790-2d29-4601-9a12-3668ba1a0ada',
 '27b0e3be-40c4-45fe-bbef-0926e7d582e7',
 'd78a7d0e-c164-4e9f-95d3-71108d7f6db5',
 '7f08577a-5a2d-4cd0-8de4-6fabc635547e',
 'a7b0e855-b06b-431e-b960-1276f1c8f930',
 '9bc11031-6470-4423-9bbb-44bb1a55ddd7',
 'c3dbe4c0-31fd-40d1-a1ff-7a7acb61ab1a',
 'ff26694e-5ec5-42b1-8bf1-d17d0d99a051',
 '8a18d60a-9cee-4103-8832-f22b5a5405cc',
 '53c63692-c353-4031-b4c9-1d6214011d69',
 'ef61b3b6-c6b2-4cf6-9220-abd15d18368f',
 '292e8b10-94a5-4089-9560-b1379f096a1d',
 '974dd70f-e632-4f02-86d5-44564c15af48',
 '5dec6309-c098-4987-904f-7829b4540865',
 'd245f407-51bd-4b6b-808c-4bf839e4cdcc',
 '1f0abca4-fadf-4156-be1e-203c68f6850a',
 '7792dd77-6128-4979-b1b4-20b9b975a245',
 'e63c0300-f060-48cc-939f-d3ac985eb07f',
 'f87eb4c6-4849-45bb-bb46-28adb095b476',
 '37c5cde4-c5b2-4b57-bd4f-c1725f09cc28',
 '3a135032-4a05-483a-899b-7d85accc298f',
 'b4ab21be-51f5-433c-9575-471beca8c3e9',
 'e54a7b73-8b24-

In [11]:
retriever.docstore.mset(list(zip(doc_ids, docs)))

In [12]:
def check_docs_correctness(true_context, _predicted_contexts):
    for predicted_context in _predicted_contexts:
        if true_context in predicted_context:
            return True
    return False

In [13]:
retriever.search_kwargs = {"k": 20}

In [None]:
import pandas as pd

df = pd.read_json('./test_retriever_big_id.json')
correct_in_top5 = 0
correct_present = 0
total_tests = 0

for index, row in df.iterrows():
    context = row["id"]
    questions = row["questions"]
    total_tests += len(questions)

    for question in questions:
        _q = question["question"]
        predicted_contexts_data = retriever.invoke(_q)
        predicted_contexts = [doc.metadata.get('id') for doc in predicted_contexts_data]
        
        if context in predicted_contexts:
            correct_present += 1
            
            if context in predicted_contexts[:7]:
                correct_in_top5 += 1


top5_percentage = correct_in_top5 / total_tests * 100
present_percentage = correct_present / total_tests * 100

print(f"Верный ответ присутствует в топ-10 в {present_percentage:.2f}% тестов")
print(f"Верный ответ в топ-5 в {top5_percentage:.2f}% тестов")