In [1]:
import argparse
from utils.chunking import create_chunks_from_directory, create_summaries_from_chunks
from utils.chunking import summary_chain
from utils.utils import batch_iterator, convert_defaultdict
from qdrant_client import QdrantClient, models
from FlagEmbedding import BGEM3FlagModel

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the BGE-M3 checkpoint once; later cells call `embeddings.encode(...)`
# on it to obtain the dense, sparse and ColBERT vectors for each summary.
embeddings = BGEM3FlagModel(
    'BAAI/bge-m3',
    use_fp16=True,
)

Fetching 30 files: 100%|██████████| 30/30 [00:00<00:00, 43888.78it/s]
  colbert_state_dict = torch.load(os.path.join(model_dir, 'colbert_linear.pt'), map_location='cpu')
  sparse_state_dict = torch.load(os.path.join(model_dir, 'sparse_linear.pt'), map_location='cpu')


In [3]:
# Build the chunk list from the extracted source directory.
# Later cells read `.metadata` and `.page_content` on each chunk.
source_directory = 'extracted/TÁC HẠI'
chunks = create_chunks_from_directory(source_directory)

In [24]:
from tqdm import tqdm
from langchain_core.documents import Document

In [25]:
def _summarize_chunk(chunk):
    """Summarise one chunk and wrap the result in a ``Document``.

    The text sent to ``summary_chain`` is the chunk's title, its existing
    summary and its page content, in that order. The returned document
    carries the chain output as ``page_content`` and keeps the chunk's
    ``doc_id`` and title in its metadata.
    """
    meta = chunk.metadata
    title, summary, unique_id = meta['Title'], meta['Summary'], meta['doc_id']
    chain_input = title + '\n' + summary + '\n\n' + chunk.page_content
    return Document(
        page_content=summary_chain.invoke(chain_input),
        metadata={"doc_id": unique_id, "title": title},
    )


# Only chunks from index 250 onward are processed here.
# NOTE(review): presumably this resumes an earlier partial run - confirm before re-running.
# Results are appended one at a time so that an interrupted run keeps what it
# has produced so far.
summaries = []
for chunk in tqdm(chunks[250:]):
    summaries.append(_summarize_chunk(chunk))

100%|██████████| 322/322 [1:15:25<00:00, 14.05s/it]


In [26]:
# Embed the summaries in batches and upload each one as a single point to the
# "summary" collection, with named vectors "dense", "colbert" and "sparse".
#
# NOTE(review): `client` is not created in any cell visible in this notebook;
# it must already exist in the kernel. Define it explicitly near the imports
# so the notebook survives Restart & Run All.
batch_size = 8
failed_doc_ids = []  # doc_ids whose upload raised; kept so they can be retried

for batch in batch_iterator(summaries, batch_size):
    text = [summary_doc.page_content for summary_doc in batch]
    res = embeddings.encode(text, return_sparse=True, return_colbert_vecs=True)

    for i, summary_doc in enumerate(batch):
        doc_id = summary_doc.metadata['doc_id']
        title = summary_doc.metadata['title']
        content = summary_doc.page_content
        try:
            # The point is built inside the `try` so that a failure while
            # converting the vectors skips this document instead of aborting
            # the whole run (same best-effort scope as before).
            client.upload_points(
                collection_name="summary",
                points=[
                    models.PointStruct(
                        id=doc_id,
                        vector={
                            "dense": res['dense_vecs'][i].tolist(),
                            "colbert": res['colbert_vecs'][i].tolist(),
                            "sparse": convert_defaultdict(res['lexical_weights'][i]),
                        },
                        payload={
                            "doc_id": doc_id,
                            "title": title,
                            "content": content,
                        },
                    )
                ],
                batch_size=1,
            )
        except Exception as exc:
            # Catch `Exception`, not everything: a bare `except:` would also
            # swallow KeyboardInterrupt/SystemExit and make the loop
            # impossible to stop. Report the cause so failures can be diagnosed.
            print(f"Error when uploading - {doc_id}: {exc!r}")
            failed_doc_ids.append(doc_id)

In [27]:
# Display how many points the "summary" collection currently holds.
client.count(collection_name="summary")

CountResult(count=572)

In [23]:
# Look up points in the "original" collection whose `doc_id` payload field
# equals one specific id.
doc_id_condition = models.FieldCondition(
    key="doc_id",
    match=models.MatchValue(value='cef9d1d3-77d5-4600-ab04-6bd4404e9a64'),
)
doc_id_filter = models.Filter(must=[doc_id_condition])

filter_res = client.scroll(
    collection_name="original",
    scroll_filter=doc_id_filter,
)