In [1]:
import gc
import os

import torch
from FlagEmbedding import BGEM3FlagModel
from qdrant_client import QdrantClient, models

from semantic_encoder import BGEM3FlagEmbedEncoder
from utils.semantic_chunking import create_semantic_chunks_from_directory_with_overlap, reformat_semantic_chunks_with_overlap
from utils.utils import batch_iterator, convert_defaultdict

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Qdrant endpoint. Defaults to a local instance; set the QDRANT_URL environment
# variable to point the notebook at another server without editing code.
QDRANT_URL = os.environ.get("QDRANT_URL", "http://localhost:6333")
client = QdrantClient(QDRANT_URL)

In [3]:
# Create the hybrid-search collection on the first run only; when it already
# exists this cell leaves it untouched. Each point carries three named vectors:
#   - "dense":   a single 1024-d vector, cosine distance
#   - "colbert": a multivector (list of 1024-d vectors) scored with MaxSim
#   - "sparse":  a sparse vector with default Qdrant sparse params
# NOTE(review): 1024 presumably matches the BGE-M3 output dimension — confirm.
_collection = "semantic_vectorstore"
_dense_params = models.VectorParams(size=1024, distance=models.Distance.COSINE)
_colbert_params = models.VectorParams(
    size=1024,
    distance=models.Distance.COSINE,
    multivector_config=models.MultiVectorConfig(
        comparator=models.MultiVectorComparator.MAX_SIM,
    ),
)

if not client.collection_exists(collection_name=_collection):
    client.create_collection(
        collection_name=_collection,
        vectors_config={"dense": _dense_params, "colbert": _colbert_params},
        sparse_vectors_config={"sparse": models.SparseVectorParams()},
    )

In [4]:
# Chunking-time encoder (project-local wrapper from `semantic_encoder`).
# It is passed as `encoder=` to the semantic chunker in the next cell; the name
# `embeddings` is later reassigned to a BGEM3FlagModel for indexing.
# NOTE(review): presumably wraps BGE-M3 (per the class name) — confirm.
embeddings = BGEM3FlagEmbedEncoder()

Fetching 30 files: 100%|██████████| 30/30 [00:00<00:00, 80867.04it/s]
  colbert_state_dict = torch.load(os.path.join(model_dir, 'colbert_linear.pt'), map_location='cpu')
  sparse_state_dict = torch.load(os.path.join(model_dir, 'sparse_linear.pt'), map_location='cpu')


In [5]:
# Split every document under the source directory into semantic chunks, using
# `embeddings` as the encoder. `overlap=2` is forwarded to the chunker; its exact
# meaning is defined in utils.semantic_chunking.
source_dir = 'extracted/TÁC HẠI'
chunks = create_semantic_chunks_from_directory_with_overlap(
    dir=source_dir,
    encoder=embeddings,
    overlap=2,
)

[32m2024-09-25 16:43:22 INFO semantic_chunkers.utils.logger Single document exceeds the maximum token limit of 700. Splitting to sentences before semantically merging.[0m
100%|██████████| 7/7 [00:12<00:00,  1.72s/it]
[32m2024-09-25 16:43:34 INFO semantic_chunkers.utils.logger Single document exceeds the maximum token limit of 700. Splitting to sentences before semantically merging.[0m
100%|██████████| 10/10 [00:16<00:00,  1.67s/it]
[32m2024-09-25 16:43:51 INFO semantic_chunkers.utils.logger Single document exceeds the maximum token limit of 700. Splitting to sentences before semantically merging.[0m
100%|██████████| 8/8 [00:17<00:00,  2.19s/it]
[32m2024-09-25 16:44:08 INFO semantic_chunkers.utils.logger Single document exceeds the maximum token limit of 700. Splitting to sentences before semantically merging.[0m
100%|██████████| 20/20 [00:38<00:00,  1.90s/it]
[32m2024-09-25 16:44:47 INFO semantic_chunkers.utils.logger Single document exceeds the maximum token limit of 700. Spl

In [6]:
# Sanity check: stop here, before loading the indexing model and uploading,
# if chunking produced nothing. The count is then displayed as the cell output.
assert len(chunks) > 0, "No chunks were produced - check the source directory"
len(chunks)

1607

In [7]:
# Switch from the chunking encoder to the FlagEmbedding BGE-M3 model used for
# indexing. Drop the old object and collect any reference cycles *before*
# loading. If nothing else still references the old encoder, its memory is
# released first, instead of both models being held at the same time.
embeddings = None
gc.collect()
torch.cuda.empty_cache()
embeddings = BGEM3FlagModel('BAAI/bge-m3', use_fp16=True)

Fetching 30 files: 100%|██████████| 30/30 [00:00<00:00, 12078.05it/s]
  colbert_state_dict = torch.load(os.path.join(model_dir, 'colbert_linear.pt'), map_location='cpu')
  sparse_state_dict = torch.load(os.path.join(model_dir, 'sparse_linear.pt'), map_location='cpu')


In [None]:
# Encode chunks in batches (dense, lexical/sparse and ColBERT outputs from
# `embeddings.encode`) and upload each chunk as one Qdrant point.
#
# Points are uploaded one per request (batch_size=1). NOTE(review): presumably
# deliberate — each ColBERT multivector is a list of 1024-d vectors, so
# multi-point requests can get very large; confirm before batching uploads.
batch_size = 8
failed_doc_ids = []
uploaded_count = 0

for batch in batch_iterator(chunks, batch_size):
    # Prefix each chunk with its document title and summary so the embedding
    # also carries document-level context.
    texts = [
        chunk.metadata['Title'] + '\n' + chunk.metadata['Summary'] + '\n\n' + 'Chunk of Text:\n' + chunk.page_content
        for chunk in batch
    ]
    res = embeddings.encode(texts, return_sparse=True, return_colbert_vecs=True)

    for i, chunk in enumerate(batch):
        # Metadata lookups stay outside the try: a missing key is a data bug and
        # should stop the cell rather than be logged as an upload failure.
        doc_id = chunk.metadata['doc_id']
        payload = {
            "doc_id": doc_id,
            "title": chunk.metadata['Title'],
            "summary": chunk.metadata['Summary'],
            "source": chunk.metadata['source'],
            "content": chunk.page_content,
        }
        try:
            client.upload_points(
                "semantic_vectorstore",
                points=[
                    models.PointStruct(
                        id=doc_id,
                        vector={
                            "dense": res['dense_vecs'][i].tolist(),
                            "colbert": res['colbert_vecs'][i].tolist(),
                            "sparse": convert_defaultdict(res['lexical_weights'][i]),
                        },
                        payload=payload,
                    )
                ],
                batch_size=1,
            )
        except Exception as exc:
            # Best-effort: log the cause, remember the id for a retry, and move on.
            # Catching Exception (not a bare except) lets KeyboardInterrupt stop the cell.
            print(f"Error when uploading - {doc_id}: {exc!r}")
            failed_doc_ids.append(doc_id)
        else:
            uploaded_count += 1

print(f"Sent {uploaded_count} points, {len(failed_doc_ids)} failed")
failed_doc_ids


In [90]:
# Release the indexing model and return cached CUDA memory to the device.
# gc.collect() reclaims objects kept alive only by reference cycles, so that
# empty_cache() can actually hand their blocks back (no-op without CUDA).
embeddings = None
gc.collect()
torch.cuda.empty_cache()