In [None]:
import os
import logging
from dotenv import load_dotenv
import weaviate
from askem.ingest_docs import append_terms
from askem.retriever.base import get_v2_schema

# Load variables from a .env file (if present) into the process environment.
load_dotenv()
# Emit INFO-level log records; patch_all reports each paper at INFO.
logging.basicConfig(level=logging.INFO)

# Weaviate connection settings come from the environment. os.getenv returns
# None for an unset variable, so a missing setting only surfaces later, when
# the client is created.
WEAVIATE_URL = os.getenv("WEAVIATE_URL")
WEAVIATE_APIKEY = os.getenv("WEAVIATE_APIKEY")
# Echo only the URL (never the API key) to confirm which instance is targeted.
print(WEAVIATE_URL)

Create client

In [None]:
# Connect to the Weaviate instance at WEAVIATE_URL, authenticating by API key.
client = weaviate.Client(
    url=WEAVIATE_URL, auth_client_secret=weaviate.AuthApiKey(api_key=WEAVIATE_APIKEY)
)

# Fetch the current schema; the next cell reads the existing property names from it.
schema = client.schema.get()
# Last expression of the cell: shows the aggregate meta count for "Passage".
client.query.aggregate("Passage").with_meta_count().do()

In [None]:
# Select the "Passage" class by name rather than by position: index 0 is only
# correct while "Passage" happens to be the first class in the schema.
# Raises StopIteration if the instance has no class named "Passage".
passage_class_v1 = next(c for c in schema["classes"] if c["class"] == "Passage")
# Names of the properties that exist before the v2 migration.
properties_v1 = [x["name"] for x in passage_class_v1["properties"]]
print(properties_v1)

Update the class to the v2 schema

In [None]:
# Load the v2 schema definition and keep only the property definitions whose
# name is not already listed in properties_v1.
schema_v2 = get_v2_schema()
new_properties = [p for p in schema_v2["properties"] if p["name"] not in properties_v1]
# Show just the names of the properties that are about to be added.
print([p["name"] for p in new_properties])

In [None]:
# Add each missing property definition to the "Passage" class, one call per property.
for p in new_properties:
    client.schema.property.create("Passage", p)

Check new properties

In [None]:
# Re-fetch the schema so it reflects the properties that were just created.
schema = client.schema.get()
# Select the "Passage" class by name rather than by position: index 0 is only
# correct while "Passage" happens to be the first class in the schema.
# Raises StopIteration if the instance has no class named "Passage".
passage_class_v2 = next(c for c in schema["classes"] if c["class"] == "Passage")
all_properties_v2 = [x["name"] for x in passage_class_v2["properties"]]
print(all_properties_v2)

## Make a function to patch a paper

In [None]:
def get_paper(
    client,
    class_name: str,
    paper_id: str,
    extra_properties: list = None,
    limit: int = 10000,
) -> list:
    """Get a list of paragraphs for a given paper.

    Args:
        client: Weaviate client; only its ``query.get`` builder is used.
        class_name: Name of the Weaviate class to query.
        paper_id: Value the ``paper_id`` property must equal.
        extra_properties: Property names to return in addition to
            ``text_content``. The passed list is not modified.
        limit: Maximum number of paragraphs to request. Pass ``None`` to send
            no limit and fall back to the server-side default.

    Returns:
        The list stored under ``data.Get.<class_name>`` in the response; each
        item carries the requested properties plus ``_additional.id``.

    Raises:
        RuntimeError: If the response contains a non-empty ``errors`` entry.
    """

    where_filter = {"path": ["paper_id"], "operator": "Equal", "valueString": paper_id}

    # Build a fresh list so the caller's extra_properties is never mutated.
    properties = ["text_content", *(extra_properties or [])]

    query = (
        client.query.get(class_name, properties)
        .with_additional(["id"])
        .with_where(where_filter)
    )

    # Without an explicit limit the server applies its default query limit,
    # which would silently truncate papers with many paragraphs.
    # NOTE(review): 10000 is meant to match Weaviate's default maximum result
    # size; lower it if the server is configured with a smaller maximum.
    if limit is not None:
        query = query.with_limit(limit)

    response = query.do()

    # Surface GraphQL errors explicitly instead of failing later with an
    # unrelated KeyError or a None result.
    if response.get("errors"):
        raise RuntimeError(
            f"Weaviate query for paper {paper_id!r} failed: {response['errors']}"
        )

    return response["data"]["Get"][class_name]


def patch_paper(client, class_name: str, paper_id: str) -> None:
    """Append terms to every paragraph of one paper and write them back.

    The paper's paragraphs are fetched with ``get_paper``, passed through
    ``append_terms``, and each resulting object is then updated under its own
    id with whatever fields remain once ``_additional`` and ``text_content``
    have been removed.
    """

    enriched = append_terms(get_paper(client, class_name, paper_id))

    for passage in enriched:
        # The object id arrives under "_additional"; pull it out so it is
        # used as the update target and not sent as a property.
        object_id = passage.pop("_additional")["id"]
        # text_content is dropped from the update payload.
        passage.pop("text_content")
        client.data_object.update(
            data_object=passage, class_name=class_name, uuid=object_id
        )

In [None]:
# Fetch a single object (text_content and paper_id) to obtain one paper id
# that can be used for a trial patch.
client.query.get("Passage", ["text_content", "paper_id"]).with_limit(1).do()

In [None]:
# Trial run: patch a single paper before patching the whole class.
patch_paper(client, "Passage", "616ea0a767467f7269d4a7e4")

In [None]:
# Re-fetch the paper that was just patched, requesting every property listed
# in all_properties_v2, to inspect the result of the patch.
get_paper(
    client=client,
    class_name="Passage",
    paper_id="616ea0a767467f7269d4a7e4",
    extra_properties=all_properties_v2,
)

Tested: the patch works on one paper.

## Patch all records

In [None]:
def get_batch_with_cursor(client, class_name, class_properties, batch_size, cursor):
    """Fetch one page of objects from a class.

    Requests ``class_properties`` plus each object's id, capped at
    ``batch_size`` objects. When ``cursor`` is not ``None`` it is passed to
    ``with_after`` so the page continues from that point; otherwise the query
    is sent without a cursor. Returns the raw result of ``do()``.
    """
    request = client.query.get(class_name, class_properties)
    request = request.with_additional(["id"])
    request = request.with_limit(batch_size)

    # Only attach the cursor when the caller supplied one.
    if cursor is not None:
        request = request.with_after(cursor)

    return request.do()


def patch_all(client, batch_size: int = 5000, class_name: str = "Passage") -> None:
    """Append terms to all records.

    Pages through ``class_name`` in batches of ``batch_size`` objects via
    ``get_batch_with_cursor``, collects the distinct ``paper_id`` values of
    each batch, and calls ``patch_paper`` once for every paper id that has
    not been handled in an earlier batch. Stops at the first empty batch.
    """

    seen_paper_ids = set()
    cursor = None

    while True:
        response = get_batch_with_cursor(
            client, class_name, ["paper_id"], batch_size, cursor
        )
        batch = response["data"]["Get"][class_name]

        # An empty batch means there is nothing left to page through.
        if not batch:
            return

        # The id of the last object in this batch is the cursor for the next one.
        cursor = batch[-1]["_additional"]["id"]

        # Distinct paper ids present in this batch.
        batch_paper_ids = {record["paper_id"] for record in batch}

        for paper_id in batch_paper_ids:
            # A paper whose paragraphs span several batches is patched only once.
            if paper_id in seen_paper_ids:
                continue
            logging.info(f"Processing paper {paper_id}")
            patch_paper(client, class_name, paper_id)
            seen_paper_ids.add(paper_id)

In [None]:
# Patch every record, using patch_all's default batch size and class name.
patch_all(client=client)
# Note: the full run took about 25 hours.