In [None]:
import os
import logging
from dotenv import load_dotenv
import weaviate
from tqdm import tqdm

from askem.ingest_docs import append_terms
from askem.terms_extractor import MoreThanOneCapStrategy, get_blacklist

# Load settings via python-dotenv before reading them from the environment.
load_dotenv()

# Keep the notebook output quiet: only ERROR and above are logged.
logging.basicConfig(level=logging.ERROR)

# Connection settings; either value is None when the variable is not set.
WEAVIATE_URL = os.environ.get("WEAVIATE_URL")
WEAVIATE_APIKEY = os.environ.get("WEAVIATE_APIKEY")

# Show which instance this notebook is pointed at (the API key is not printed).
print(WEAVIATE_URL)

Create client

In [None]:
# Connect to the Weaviate instance using API-key authentication.
client = weaviate.Client(
    url=WEAVIATE_URL,
    auth_client_secret=weaviate.AuthApiKey(api_key=WEAVIATE_APIKEY),
)

schema = client.schema.get()

# Sanity check: the cell output is the aggregate meta count for "Passage".
(
    client.query
    .aggregate("Passage")
    .with_meta_count()
    .do()
)

Check properties

In [None]:
# Re-fetch the schema and list the property names of its first class.
# NOTE(review): index 0 is presumably the "Passage" class -- confirm if the
# instance holds more than one class.
schema = client.schema.get()
all_properties = [
    prop["name"]
    for prop in schema["classes"][0]["properties"]
]
print(all_properties)

## Define a function to patch a single record

In [None]:
def patch_doc(client, class_name: str, doc: dict) -> None:
    """Re-extract terms for one record and push the result to Weaviate.

    The old ``article_terms_0`` .. ``article_terms_9`` and
    ``paragraph_terms_0`` .. ``paragraph_terms_2`` fields are reset to
    ``None``, ``append_terms`` is run with a fresh ``MoreThanOneCapStrategy``
    extractor, and the resulting properties are sent with
    ``client.data_object.update``.

    Args:
        client: Weaviate client; only ``client.data_object.update`` is used.
        class_name: Name of the Weaviate class the record belongs to.
        doc: Record dict as returned by a ``Get`` query. It must contain
            ``"topic"``, ``"text_content"`` and ``"_additional"`` (with an
            ``"id"`` entry). This function works on a shallow copy, so the
            top-level keys of the caller's dict are not modified.

    Raises:
        KeyError: If one of the required keys is missing.
    """

    extractor = MoreThanOneCapStrategy(
        min_length=3, min_occurrence=1, top_k=3, blacklist=get_blacklist(doc["topic"])
    )

    # Shallow copy: blanking fields and popping keys below must not leak back
    # into the object the caller is still holding.
    record = dict(doc)

    # Strip old terms
    for i in range(10):
        record[f"article_terms_{i}"] = None
    for i in range(3):
        record[f"paragraph_terms_{i}"] = None

    # Add new terms
    patched = append_terms([record], extractor)[0]

    # The object id addresses the update; it is not part of the payload.
    uuid = patched.pop("_additional")["id"]
    # The text itself is left out of the payload; only the other properties
    # are sent.
    patched.pop("text_content")

    # Update the data object
    client.data_object.update(uuid=uuid, class_name=class_name, data_object=patched)

## Patch all records

In [None]:
def get_batch_with_cursor(client, class_name, class_properties, batch_size, cursor):
    """Run one ``Get`` query for ``class_name`` and return the raw response.

    The query requests ``class_properties`` plus the additional ``id`` field
    and is limited to ``batch_size`` objects. When ``cursor`` is not ``None``
    it is passed to ``with_after`` before the query is executed; otherwise the
    query is executed as-is.
    """
    page_query = client.query.get(class_name, class_properties)
    page_query = page_query.with_additional(["id"])
    page_query = page_query.with_limit(batch_size)

    # No cursor yet: run the query without an "after" clause.
    if cursor is None:
        return page_query.do()

    return page_query.with_after(cursor).do()


def patch_all(
    client,
    batch_size: int = 5000,
    class_name: str = "Passage",
    total_batches: int = 1080,
) -> None:
    """Walk every record of ``class_name`` in batches and patch its terms.

    Batches are fetched with ``get_batch_with_cursor``; the id of the last
    object in a batch becomes the cursor for the next request. Iteration stops
    at the first empty batch. Records whose ``article_terms_0`` is ``None``
    are skipped; all others are passed to ``patch_doc``.

    Args:
        client: Weaviate client used for both querying and updating.
        batch_size: Number of objects requested per query.
        class_name: Weaviate class to iterate over.
        total_batches: Expected number of batches, used only as the progress
            bar total. The default keeps the previously hard-coded value
            (presumably sized for the full collection at the default
            ``batch_size`` -- adjust when either changes).
    """

    properties = ["paper_id", "topic", "text_content", "article_terms_0"]
    cursor = None

    # The context manager closes the progress bar even when a query or a
    # patch raises part-way through the run.
    with tqdm(total=total_batches) as progress_bar:
        while True:
            results = get_batch_with_cursor(
                client, class_name, properties, batch_size, cursor
            )
            objects = results["data"]["Get"][class_name]

            # Stop if there are no more results
            if not objects:
                break

            # Take the cursor before patching anything in this batch.
            cursor = objects[-1]["_additional"]["id"]

            for obj in objects:
                # Nothing to strip or replace on records without old terms.
                if obj["article_terms_0"] is None:
                    continue
                patch_doc(client, class_name, obj)

            progress_bar.update(1)

In [None]:
# Patch every record, using patch_all's default batch size and class name.
patch_all(client=client)
# NOTE: a full run took about 25 hours.

Check that all records are patched (a patched record has `article_terms_0` set to None)


In [None]:
# Scan the whole "Passage" class and collect every record that still has a
# non-None article_terms_0, i.e. one that has not been patched.
cursor = None
unpatched = []

# The context manager closes the progress bar even if a query raises mid-scan.
with tqdm(total=1080) as progress_bar:
    while True:
        results = get_batch_with_cursor(
            client, "Passage", ["paper_id", "article_terms_0"], 5000, cursor
        )
        objects = results["data"]["Get"]["Passage"]

        # Stop if there are no more results
        if not objects:
            break

        # The id of the last object is the cursor for the next batch.
        cursor = objects[-1]["_additional"]["id"]

        for obj in objects:
            if obj["article_terms_0"] is not None:
                unpatched.append(obj)

        progress_bar.update(1)