In [None]:
import weaviate
from tqdm import tqdm
import os
from dotenv import load_dotenv

# Load variables from a local .env file (if present) into the process
# environment so the os.getenv(...) lookups below can find them.
load_dotenv()


def get_batch_with_cursor(
    client, class_name, class_properties, batch_size, cursor=None
):
    """Run a single paged ``Get`` query and return its raw result.

    The query asks for ``class_properties`` of ``class_name`` plus the
    ``id`` additional field, capped at ``batch_size`` objects. When
    ``cursor`` is given, it is passed to ``with_after`` so the page starts
    after that object; when it is ``None`` no cursor is applied.

    Args:
        client: Object exposing ``client.query.get(...)`` (a Weaviate client
            in this notebook).
        class_name: Name of the class to read from.
        class_properties: Properties to request for each object.
        batch_size: Maximum number of objects to request.
        cursor: Value forwarded to ``with_after``, or ``None`` for the
            first page.

    Returns:
        Whatever the query builder's ``do()`` call returns.
    """
    builder = client.query.get(class_name, class_properties)
    builder = builder.with_additional(["id"])
    builder = builder.with_limit(batch_size)

    # First page: there is nothing to resume from, so run the query as is.
    if cursor is None:
        return builder.do()

    return builder.with_after(cursor).do()

In [None]:
# Client for the Weaviate instance that will be scanned. The endpoint URL and
# API key are read from the WEAVIATE_URL and WEAVIATE_APIKEY environment
# variables; os.getenv returns None for a variable that is not set.
source_client = weaviate.Client(
    url=os.getenv("WEAVIATE_URL"),
    auth_client_secret=weaviate.AuthApiKey(api_key=os.getenv("WEAVIATE_APIKEY")),
)


def count_docs(
    client: weaviate.Client,
    class_name: str = "Passage",
    batch_size: int = 5000,
) -> dict:
    """Count paragraphs and distinct papers per topic for one class.

    Pages through every object of ``class_name`` with
    ``get_batch_with_cursor`` and, for each value of the ``topic`` property,
    tallies how many objects carry it and how many distinct ``paper_id``
    values appear among them.

    Args:
        client: Weaviate client used for both the total-count aggregate and
            the paged scan.
        class_name: Name of the class to scan.
        batch_size: Number of objects requested per page.

    Returns:
        A dict with two entries, each mapping topic -> count:
        ``"n_paragraphs"`` (objects per topic) and ``"n_papers"``
        (distinct ``paper_id`` values per topic).
    """
    # Total object count, used only to size the progress bar.
    aggregate_response = (
        client.query.aggregate(class_name).with_meta_count().do()
    )
    total_objects = aggregate_response["data"]["Aggregate"][class_name][0][
        "meta"
    ]["count"]

    paper_ids = {}
    count_paragraphs = {}
    cursor = None

    with tqdm(total=total_objects) as progress_bar:
        while True:
            # Query through the client that was passed in, not a notebook
            # global, so the count and the scan hit the same instance.
            batch = get_batch_with_cursor(
                client,
                class_name,
                ["topic", "paper_id"],
                batch_size,
                cursor=cursor,
            )

            objects_list = batch["data"]["Get"][class_name]
            if len(objects_list) == 0:
                # An empty page means the cursor has passed the last object.
                break

            for obj in objects_list:
                topic = obj["topic"]

                # Count paragraphs
                count_paragraphs[topic] = count_paragraphs.get(topic, 0) + 1

                # Store paper ids as a set so duplicates collapse
                paper_ids.setdefault(topic, set()).add(obj["paper_id"])

            # Resume the next page after the last object seen on this one.
            cursor = objects_list[-1]["_additional"]["id"]

            # Advance by what was actually received; the final page is
            # usually shorter than batch_size.
            progress_bar.update(len(objects_list))

    return {
        "n_paragraphs": count_paragraphs,
        "n_papers": {k: len(v) for k, v in paper_ids.items()},
    }

In [None]:
# Scan the source instance with the default class name and batch size; as the
# cell's last expression, the returned dict is shown as the cell output.
count_docs(source_client)