# Revamp data ingest pipeline

New requirements:
1. Have a resume mechanism to avoid duplication of text
2. Have an append-topic mechanism to modify the topics of existing paragraphs.
3. Don't break production demo.
4. Directly ingest from ElasticSearch service without writing too much txt to disk.
5. Make a cron job to do this automatically.

Files to be ingested:

- /hdd/iaross/askem/criticalmaas_text

Steps:

1. Make a new Class: `Paragraph` to replace `Passage` class. `Passage` will be deprecated after the entire migration is done.
1. Create canonical `id2topics` pickle file. Hopefully it is small enough to be loaded into memory.
1. Use batch mechanism to ingest data from ElasticSearch service. e.g., 1000 paragraphs per batch.
1. Upgrade frontend to use `Paragraph` class.
1. Setup cron job to do this automatically.


In [3]:
import os
import requests
import weaviate
import pickle
from tqdm import tqdm
from pathlib import Path
from dotenv import load_dotenv

load_dotenv()

# Temporary fix to get the path right
import sys

sys.path.append("/hdd/clo36/repo/ask-xDD/askem/retriever")

from askem.retriever.base import get_schema
from askem.utils import get_ingested_ids, get_text

### Step 1: Create `Paragraph` class

In [4]:
# Connect to Weaviate; the URL and API key are read from the WEAVIATE_URL and
# WEAVIATE_APIKEY environment variables.
weaviate_client = weaviate.Client(
    url=os.getenv("WEAVIATE_URL"),
    auth_client_secret=weaviate.AuthApiKey(api_key=os.getenv("WEAVIATE_APIKEY")),
)

            Please consider upgrading to the latest version. See https://weaviate.io/developers/weaviate/client-libraries/python for details.


In [None]:
# Aggregate query with a meta count on the existing `Passage` class.
weaviate_client.query.aggregate("Passage").with_meta_count().do()

In [None]:
# Obtain the `Paragraph` schema from the project helper and display it.
paragraph_schema = get_schema("Paragraph")
paragraph_schema

In [None]:
# Create the class in Weaviate from `paragraph_schema`.
weaviate_client.schema.create_class(paragraph_schema)

In [None]:
# Aggregate query with a meta count on the `Paragraph` class.
weaviate_client.query.aggregate("Paragraph").with_meta_count().do()

### Step 2: Create id2topics pickle file

In [None]:
def invert(d: dict[str, list[str]]) -> dict[str, list[str]]:
    """Invert a ``topic -> ids`` mapping into an ``id -> topics`` mapping.

    Args:
        d: Mapping from a topic name to the list of ids that belong to it.

    Returns:
        Mapping from each id to the list of topics it appears under. Topics
        keep the iteration order of ``d`` and are listed at most once per id,
        even if an id is repeated within a topic's list.
    """
    inverted: dict[str, list[str]] = {}
    for topic, ids in d.items():
        # `doc_id` rather than `id` so the builtin is not shadowed.
        for doc_id in ids:
            topics = inverted.setdefault(doc_id, [])
            if topic not in topics:
                topics.append(topic)
    return inverted


class DocumentTopicFactory:
    """Build the document-id -> topics mapping from xDD article sets.

    For every set name, all article ids are downloaded from the xDD
    ``articles`` API (following its pagination links). The resulting
    ``topic -> ids`` mapping is inverted into ``id -> topics`` and pickled to
    ``tmp/id2topics.pkl``.
    """

    # Seconds to wait on the xDD API before giving up, so that an unattended
    # run cannot block forever on a stalled connection.
    REQUEST_TIMEOUT = 60

    def __init__(self, set_names: list[str]) -> None:
        """Store the xDD set names; mappings stay empty until `run` is called."""
        self.set_names = set_names

        self.id2topics: dict[str, list[str]] = {}
        self.topic2ids: dict[str, list[str]] = {}

    def run(self) -> dict[str, list[str]]:
        """Download ids for every set, invert the mapping and pickle it.

        Returns:
            The ``id -> topics`` mapping that was written to
            ``tmp/id2topics.pkl``.
        """
        for set_name in self.set_names:
            print(f"Getting ids for {set_name}")
            self.topic2ids[set_name] = self.get_ids(set_name)
            print(f"Found {len(self.topic2ids[set_name])} ids for {set_name}")

        self.id2topics = invert(self.topic2ids)

        # Write to file; create the folder first so a fresh checkout works.
        output_path = Path("tmp/id2topics.pkl")
        output_path.parent.mkdir(parents=True, exist_ok=True)
        with open(output_path, "wb") as f:
            pickle.dump(self.id2topics, f)

        return self.id2topics

    def get_ids(self, topic: str) -> list[str]:
        """Get all ids for a topic by following the API's ``next_page`` links.

        Raises:
            ValueError: If a response does not contain a ``success`` key.
        """

        next_page = f"https://xdd.wisc.edu/api/articles?set={topic}&full_results=true&fields=_gddid"
        ids = []
        # The context manager closes the progress bar even if a request fails.
        with tqdm() as progress_bar:
            while next_page:
                response = requests.get(next_page, timeout=self.REQUEST_TIMEOUT)
                data = response.json()
                ids.extend(self._parse_response(data))
                # A missing or empty `next_page` ends the pagination.
                next_page = data["success"].get("next_page")
                progress_bar.update(1)
        return ids

    def __str__(self) -> str:
        """Return one ``<topic>: n=<number of ids>`` line per downloaded topic."""
        return "\n".join(
            f"{topic}: n={len(ids)}" for topic, ids in self.topic2ids.items()
        )

    @staticmethod
    def _parse_response(data: dict) -> list[str]:
        """Get all ids from a xDD json response.

        Raises:
            ValueError: If ``data`` has no ``success`` key.
        """

        if "success" not in data:
            raise ValueError("Not a valid xDD response.")

        docs = data["success"]["data"]
        return [doc["_gddid"] for doc in docs]

In [None]:
# xDD set names; `DocumentTopicFactory.run` uses each one as a topic key.
set_names = [
    "climate-change-modeling",
    "criticalmaas",
    "dolomites",
    "geoarchive",
    "xdd-covid-19",
]
doc_topic_factory = DocumentTopicFactory(set_names)

In [None]:
# Downloads the ids of every set, inverts them to id -> topics and pickles the
# result to tmp/id2topics.pkl (one HTTP request per result page).
id2topics = doc_topic_factory.run()

### Step 3: Ingest into `Paragraph` class directly from ElasticSearch

In [5]:
class WeaviateIngester:
    """Bundle a Weaviate client with two lookups loaded from pickle files."""

    def __init__(
        self, client: weaviate.Client, id2topics_pkl: Path, ingested_pkl: Path
    ) -> None:
        """Keep the client and load both pickle files eagerly.

        Args:
            client: Weaviate client, stored as ``self.client``.
            id2topics_pkl: Pickle file loaded into ``self.id2topics``.
            ingested_pkl: Pickle file loaded into ``self.ingested``.

        Raises:
            OSError: If either file cannot be opened (e.g. it does not exist).
        """
        self.client = client

        # NOTE(review): unpickling runs arbitrary code; only point this at
        # files produced by this pipeline.
        self.id2topics = self._unpickle(id2topics_pkl)
        self.ingested = self._unpickle(ingested_pkl)

    @staticmethod
    def _unpickle(pkl_path):
        """Return the object stored in the pickle file at ``pkl_path``."""
        with open(pkl_path, "rb") as handle:
            return pickle.load(handle)


# Both pickle files must already exist: the constructor opens them for reading.
# Plain strings are passed although the constructor annotates `Path`.
ingester = WeaviateIngester(
    client=weaviate_client,
    id2topics_pkl="tmp/id2topics.pkl",
    ingested_pkl="tmp/ingested.pkl",
)

In [None]:
# Presumably the ids already present in the `Paragraph` class, for the resume
# mechanism — TODO confirm against askem.utils.get_ingested_ids.
ingested_ids = get_ingested_ids(weaviate_client, class_name="Paragraph")

In [None]:
ingested_ids

In [None]:
# Reload the id -> topics mapping from the pickle written by
# `DocumentTopicFactory.run`, then display it.
with open("tmp/id2topics.pkl", "rb") as f:
    id2topics = pickle.load(f)
id2topics

In [None]:
# Trial batch: the `batch_size` document ids starting at offset `i`, taken in
# sorted order so the slice is reproducible between runs.
i = 0
batch_size = 10
batch_ids = sorted(id2topics.keys())[i : i + batch_size]

In [None]:
from pathlib import Path

# Folder (relative to the working directory) that receives the per-document
# .txt files of a batch.
ingest_tmp_folder = Path("tmp/ingest")


def write_batch_to_file(batch_ids: list[str], folder: Path) -> None:
    """Write the text of every document in a batch to ``<folder>/<docid>.txt``.

    Args:
        batch_ids: Document ids whose text is fetched with ``get_text``.
        folder: Destination folder; created (with parents) if missing. A
            plain string path is accepted as well.
    """

    folder = Path(folder)
    folder.mkdir(parents=True, exist_ok=True)

    for docid in batch_ids:
        text = get_text(docid)
        # Explicit UTF-8 so the output does not depend on the platform's
        # default encoding.
        (folder / f"{docid}.txt").write_text(text, encoding="utf-8")


# Materialize the trial batch as one .txt file per document id.
write_batch_to_file(batch_ids, ingest_tmp_folder)

In [None]:
from askem.preprocessing import HaystackPreprocessor

# Project preprocessor; its `run` method is called once per input file in the
# ingest loop.
preprocessor = HaystackPreprocessor()

In [None]:
# NOTE: `Path.glob` returns a one-shot generator. Re-run this cell before
# re-running the ingest loop, otherwise the loop sees no files.
input_files = ingest_tmp_folder.glob("**/*.txt")

In [None]:
from tqdm import tqdm

# Ingest every .txt file of the batch into the `Paragraph` class through the
# client's batch interface.
doc_type = "paragraph"
class_name = "Paragraph"
weaviate_client.batch.configure(batch_size=5, dynamic=True)

with weaviate_client.batch as batch:
    # document level loop (each .txt file)
    for input_file in input_files:
        # Files are named "<docid>.txt", so the stem is the document id.
        docid = input_file.stem
        # Raises KeyError for a file whose id is not in `id2topics`
        # (e.g. a leftover file from an earlier run).
        topics = id2topics[docid]
        docs = preprocessor.run(input_file=input_file, topics=topics, doc_type=doc_type)

        # paragraph level loop (each paragraph)
        for doc in docs:
            batch.add_data_object(data_object=doc, class_name=class_name)

In [None]:
# Re-check the `Paragraph` meta count after the ingest loop.
weaviate_client.query.aggregate("Paragraph").with_meta_count().do()

In [None]:
# Same five xDD set names as the list passed to `DocumentTopicFactory` earlier
# in this notebook.
set_names = [
    "climate-change-modeling",
    "criticalmaas",
    "dolomites",
    "geoarchive",
    "xdd-covid-19",
]

In [None]:
# Peek at up to 5 `Paragraph` objects, requesting only the listed properties.
weaviate_client.query.get(
    "Paragraph", ["paper_id", "topic_list", "hashed_text"]
).with_limit(5).do()

In [None]:
# weaviate_client.query.get("Paragraph", ["paper_id", "topic_list", "hashed_text"]).with_where(
#     {
#         "path": "topic_list",
#         "operator": "ContainsAny",
#         "valueText": ["geoarchive"],
#     }
# ).with_limit(5).do()

### Preparations

Dump current id and topic to a file

In [None]:
from dotenv import load_dotenv
import os
from tqdm import tqdm
import weaviate
import hashlib

load_dotenv()

In [None]:
# Weaviate client for the deduplication section, configured from the same
# WEAVIATE_URL / WEAVIATE_APIKEY environment variables.
auth = weaviate.auth.AuthApiKey(os.getenv("WEAVIATE_APIKEY"))
client = weaviate.Client(os.getenv("WEAVIATE_URL"), auth)

In [None]:
# Check backup status
client.backup.get_create_status(
    backup_id="pre_duduplication",
    backend="filesystem",
)

In [None]:
def get_batch_with_cursor(
    client, class_name, class_properties, batch_size, cursor=None
):
    """Run one `Get` query for up to ``batch_size`` objects of ``class_name``.

    The query requests ``class_properties`` plus the additional ``id`` field.
    When ``cursor`` is given it is passed to ``with_after`` before the query
    is executed; with ``cursor=None`` the query is executed without it.

    Returns:
        Whatever the query's ``do()`` call returns.
    """
    builder = client.query.get(class_name, class_properties)
    builder = builder.with_additional(["id"])
    builder = builder.with_limit(batch_size)

    # No cursor: execute the query as built so far.
    if cursor is None:
        return builder.do()

    return builder.with_after(cursor).do()

In [None]:
def get_hash(text):
    """Return the hexadecimal SHA-256 digest of ``text``.

    The string is encoded with ``str.encode()``'s default (UTF-8) first.
    """
    hasher = hashlib.sha256()
    hasher.update(text.encode())
    return hasher.hexdigest()

In [None]:
# Get number of documents in the `Passage` class (used as the progress-bar
# total below). The class name is spelled exactly as in the result key so the
# query and the lookup agree.

metadata = client.query.aggregate("Passage").with_meta_count().do()
n = metadata["data"]["Aggregate"]["Passage"][0]["meta"]["count"]
n

Dump topic to a file

In [None]:
# Walk the whole `Passage` class page by page and collect, for every paper_id,
# the distinct `topic` values found on its objects.
cursor = None
class_name = "Passage"
id2topic = {}

# The context manager closes the progress bar even if a request fails.
with tqdm(total=n) as pbar:
    while True:
        # From the SOURCE instance, get the next group of objects
        results = get_batch_with_cursor(
            client,
            class_name,
            class_properties=["paper_id", "topic"],
            batch_size=1024,
            cursor=cursor,
        )

        # A batch of objects; if empty, we're finished
        objects = results["data"]["Get"][class_name]
        if not objects:
            break

        # `obj`, not `object`: a module-level loop variable named `object`
        # would shadow the builtin for every later cell.
        for obj in objects:
            paper_id = obj["paper_id"]
            topic = obj["topic"]

            if paper_id not in id2topic:
                id2topic[paper_id] = [topic]
            elif topic not in id2topic[paper_id]:
                id2topic[paper_id].append(topic)

        # Update the cursor to the id of the last retrieved object
        cursor = objects[-1]["_additional"]["id"]
        pbar.update(len(objects))

In [None]:
import pickle
import datetime

# today date in YYMMDD format
# Local date as a two-digit year, month and day (e.g. "240131"); used as the
# suffix of the dump file name.
today = datetime.datetime.now().strftime("%y%m%d")
today

In [None]:
# Snapshot the paper_id -> topics mapping into the current working directory
# before any object is modified or deleted.
with open(f"topic_dump_{today}.pkl", "wb") as f:
    pickle.dump(id2topic, f)

Deduplicate with text hash

In [None]:
# Deduplicate the `Passage` class by text hash: the first object seen with a
# given hash is kept and updated with `topic_list` / `text_hash`; every later
# object with the same hash is deleted.
class_name = "Passage"
existing_hash = set()  # hashes of all texts kept so far (held in memory)
batch_size = 32
cursor = None
deleted = 0

# The context manager closes the progress bar even if a request fails.
with tqdm(total=n) as pbar:
    while True:
        # From the SOURCE instance, get the next group of objects
        results = get_batch_with_cursor(
            client,
            class_name,
            class_properties=["paper_id", "text_content"],
            batch_size=batch_size,
            cursor=cursor,
        )

        # A batch of objects; if empty, we're finished
        objects = results["data"]["Get"][class_name]
        if not objects:
            break

        # `obj`, not `object`: a module-level loop variable named `object`
        # would shadow the builtin for every later cell.
        for obj in objects:
            uuid = obj["_additional"]["id"]
            paper_id = obj["paper_id"]
            text = obj["text_content"]
            hashed_text = get_hash(text)

            if hashed_text not in existing_hash:
                print(f"Updating object: {uuid}")
                existing_hash.add(hashed_text)
                client.data_object.update(
                    uuid=uuid,
                    class_name=class_name,
                    data_object={
                        "topic_list": id2topic[paper_id],
                        "text_hash": hashed_text,
                    },
                )
            else:
                # Delete the duplicated object; a failed delete is reported
                # and skipped so one bad object does not stop the run.
                print(f"Deleting object: {uuid}")
                try:
                    client.data_object.delete(uuid, class_name)
                    deleted += 1
                except weaviate.exceptions.UnexpectedStatusCodeException as e:
                    print(e)

        # Update the cursor to the id of the last retrieved object
        # NOTE(review): that object may just have been deleted above — confirm
        # that Weaviate's `after` cursor accepts an id that no longer exists.
        cursor = objects[-1]["_additional"]["id"]
        pbar.update(len(objects))

# Report the counter that is otherwise only visible by inspecting `deleted`.
print(f"Deleted {deleted} duplicated objects")