In [1]:
from utils.semantic_chunking import reformat_semantic_chunks_with_overlap, create_semantic_chunks_from_directory_with_overlap
from utils.semantic_chunking import create_summaries_from_chunks, summary_chain
from utils.utils import batch_iterator, convert_defaultdict

from semantic_encoder import BGEM3FlagEmbedEncoder
from FlagEmbedding import BGEM3FlagModel
from qdrant_client import QdrantClient, models

from langchain_core.documents import Document


from tqdm import tqdm
import time

import torch

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
client = QdrantClient("http://localhost:6333")

In [54]:
#Creating vectorstore (if not exist)
if not client.collection_exists(collection_name="semantic_summary_vectorstore"):
    client.create_collection(
        "semantic_summary_vectorstore",
        vectors_config={
            "dense": models.VectorParams(
                size=1024,
                distance=models.Distance.COSINE
            ),
            "colbert": models.VectorParams(
                size=1024,
                distance=models.Distance.COSINE,
                multivector_config=models.MultiVectorConfig(
                    comparator=models.MultiVectorComparator.MAX_SIM,
                )
            ),
        },
        sparse_vectors_config={
            "sparse": models.SparseVectorParams()
        }
    )

if not client.collection_exists(collection_name="semantic_original"):
    client.create_collection("semantic_original", vectors_config={})


In [3]:
encoder = BGEM3FlagEmbedEncoder()
chunks = create_semantic_chunks_from_directory_with_overlap('extracted/TÁC HẠI', encoder=encoder, min_split_tokens=300, max_split_tokens=850, window_size=20, overlap=0)

encoder = None
torch.cuda.empty_cache()

[32m2024-09-27 11:18:24 INFO semantic_chunkers.utils.logger Single document exceeds the maximum token limit of 850. Splitting to sentences before semantically merging.[0m
100%|██████████| 7/7 [00:11<00:00,  1.70s/it]
[32m2024-09-27 11:18:36 INFO semantic_chunkers.utils.logger Single document exceeds the maximum token limit of 850. Splitting to sentences before semantically merging.[0m
100%|██████████| 10/10 [00:17<00:00,  1.72s/it]
[32m2024-09-27 11:18:54 INFO semantic_chunkers.utils.logger Single document exceeds the maximum token limit of 850. Splitting to sentences before semantically merging.[0m
100%|██████████| 8/8 [00:15<00:00,  1.88s/it]
[32m2024-09-27 11:19:09 INFO semantic_chunkers.utils.logger Single document exceeds the maximum token limit of 850. Splitting to sentences before semantically merging.[0m
100%|██████████| 20/20 [00:35<00:00,  1.78s/it]
[32m2024-09-27 11:19:45 INFO semantic_chunkers.utils.logger Single document exceeds the maximum token limit of 850. Spl

In [13]:
#summaries = []

In [47]:
for chunk in tqdm(chunks[861:]):
    title, summary, unique_id = chunk.metadata['Title'], chunk.metadata['Summary'], chunk.metadata['doc_id']
    res = title + '\n' + summary + '\n\nChunk of Text:\n' + chunk.page_content

    chunk_summary = summary_chain.invoke(res)
    chunk_summary_document = Document(page_content=chunk_summary, metadata={"doc_id": unique_id, "title": title})
    summaries.append(chunk_summary_document)

    time.sleep(1.5)

100%|██████████| 204/204 [10:30<00:00,  3.09s/it]


In [6]:
embeddings = BGEM3FlagModel('BAAI/bge-m3', use_fp16=True)

Fetching 30 files: 100%|██████████| 30/30 [00:00<00:00, 10698.85it/s]
  colbert_state_dict = torch.load(os.path.join(model_dir, 'colbert_linear.pt'), map_location='cpu')
  sparse_state_dict = torch.load(os.path.join(model_dir, 'sparse_linear.pt'), map_location='cpu')


In [56]:
len(chunks)

1065

In [57]:
chunks[0].metadata

{'source': 'extracted/TÁC HẠI/Journal of Sleep Research - 2022 - Felső - Total sleep deprivation decreases saliva ghrelin levels in adolescents.md',
 'Title': 'Title: Total sleep deprivation decreases saliva ghrelin levels in adolescents',
 'Summary': 'Summary: This study investigates the impact of a single night of total sleep deprivation on fasting saliva ghrelin levels in adolescents. It finds that total sleep deprivation significantly blunts the increase in total-ghrelin concentration that typically occurs overnight, particularly in adolescents with overweight or obesity. The research highlights the physiological implications of sleep deprivation on ghrelin levels, suggesting a need for further studies to explore these effects in greater detail.',
 'doc_id': 'a2e4104a-6748-409f-af66-31b794895732'}

In [58]:
for chunk in chunks:
    id = chunk.metadata['doc_id']
    try:
        client.upload_points(
            "semantic_original",
            points = [
                models.PointStruct(
                    id = id,
                    vector = {},
                    payload = {
                        "doc_id": id,
                        "source": chunk.metadata['source'],
                        "title": chunk.metadata["Title"],
                        "page_content": chunk.page_content
                    }
                )
            ],
            batch_size=1
        )
    except:
        print(f"Error when uploading - {id}")
        continue

In [61]:
batch_size = 16
for batch in batch_iterator(summaries, batch_size):
    text = [summary.page_content for summary in batch]
    res = embeddings.encode(text, return_sparse=True, return_colbert_vecs=True, batch_size=16)

    for i, _ in enumerate(batch):
        doc_id = batch[i].metadata['doc_id']
        title = batch[i].metadata['title']
        content = batch[i].page_content
        try:
            client.upload_points(
                "semantic_summary_vectorstore",
                points = [
                    models.PointStruct(
                        id = doc_id,
                        vector = {
                            "dense": res['dense_vecs'][i].tolist(),
                            "colbert": res['colbert_vecs'][i].tolist(),
                            "sparse": convert_defaultdict(res['lexical_weights'][i])
                        },
                        payload = {
                            "doc_id": doc_id,
                            "title": title,
                            "content": content,
                        }

                    )
                ],
                batch_size=1
            )
        except:
            print(f"Error when uploading - {doc_id}")
            continue