In [None]:
from pathlib import Path
import os
import simplejson as json
import polars as pl
import json_repair
from tqdm import tqdm
from more_itertools import chunked
from loguru import logger
import asyncio as asio
from justatom.tooling.stl import reuuid
from justatom.storing.dataset import API as DatasetApi

from justatom.etc.io import io_snapshot
from justatom.tooling.reqs import openai_chat
from justatom.tooling.coro import _limit_concurrency
from justatom.running.igni import IGNIRunner

In [None]:
def source_from_dataset(dataset_name_or_path, **props) -> "pl.DataFrame":
    """
    Load a dataset through `DatasetApi` and hand it back as a polars DataFrame.

    :param dataset_name_or_path: Dataset name or path, passed to `DatasetApi.named`.
    :param props: Extra keyword arguments forwarded unchanged to the dataset's
        `iterator(...)` call (e.g. `sheet_name=...`).
    :return: The iterator result itself when it already is a `pl.DataFrame`;
        otherwise a `pl.DataFrame` built with `pl.from_dicts` from the fully
        materialized records.
    """
    # `pl` and `DatasetApi` come from the notebook's import cell; re-importing
    # them inside the function only hid the notebook's real dependencies.
    loaded = DatasetApi.named(dataset_name_or_path).iterator(**props)
    if isinstance(loaded, pl.DataFrame):
        return loaded
    # Anything else is treated as an iterable of row dicts.
    return pl.from_dicts(list(loaded))

In [None]:
# Source workbook, resolved relative to the current user's home directory.
dataset_name_or_path = Path.home().joinpath("IDataset", "SEVERSTAL", "SEVERSTAL_48_split_by_line.xlsx")

In [None]:
# Read the "Chunks" and "Documents" sheets from the same source (the `sheet_name`
# keyword is forwarded to the dataset iterator), then attach the document-level
# columns to every chunk. The inner join keeps only rows whose `doc_id` is
# present in both frames.
pl_source_docs = source_from_dataset(dataset_name_or_path, sheet_name="Chunks")
pl_meta_docs = source_from_dataset(dataset_name_or_path, sheet_name="Documents")
pl_docs = pl_source_docs.join(pl_meta_docs, on="doc_id", how="inner")

In [None]:
# Preview the first rows of the chunk-level frame.
pl_source_docs.head()

In [None]:
# Names of the source columns this notebook works with.
chunk_id_col = "chunk_id"  # chunk identifier column
content_col = "chunk_text"  # column holding the chunk text
title_col = "doc_name"  # column holding the document name

In [None]:
# Log how many distinct chunks, documents and document names survived the join.
logger.info(f"There are N=[{pl_docs.select('chunk_id').unique().shape[0]}] unique chunks")
logger.info(f"There are D=[{pl_docs.select('doc_id').unique().shape[0]}] unique docs")
logger.info(f"There are L=[{pl_docs.select('doc_name').unique().shape[0]}] unique doc names")

In [None]:
def wrapper_for_props(d: dict, must_include_keys: list[str] = None) -> dict:
    """
    Build a new dict holding a selection of the entries of `d`.

    :param d: Source doc
    :param must_include_keys: Keys to copy over, in this order. Keys that are
        absent from `d` are skipped. `None` selects every key of `d`; an empty
        list selects nothing.
    :return: New dict with the selected entries (`d` itself is never returned)
    """
    if must_include_keys is None:
        wanted = d.keys()
    else:
        wanted = must_include_keys

    selected = {}
    for name in wanted:
        if name in d:
            selected[name] = d[name]
    return selected

In [None]:
async def pipeline(
    js_docs: list[dict],
    pr_runner,
    openai_model_name: str,
    batch_size: int = 16,
    coros_size: int = 2,
    save_snapshot_every: int = 5,
    snapshot_prefix: str = None,
    snapshot_where: str = None,
    timeout: int = 512,
    must_include_keys: list[str] | None = None,
    validate_json_response: bool = False,
):
    """
    We process `js_docs` by chunks where each chunk is of size `batch_size`.
    Each chunk is processed asynchronously via parallel `coros_size` coroutines.

    :param js_docs: documents to process
    :param pr_runner: One of the instance `IPromptRunner` to create specific prompt
    """
    pipes = []

    for i, batch in tqdm(enumerate(chunked(js_docs, n=batch_size))):
        _batch = batch
        cur_result = await asio.gather(
            *_limit_concurrency(
                [
                    openai_chat(
                        pr_runner.prompt(**d),
                        timeout=timeout,
                        model=openai_model_name,
                        props=wrapper_for_props(d, must_include_keys=must_include_keys),
                    )
                    for d in _batch
                ],
                concurrency=coros_size,
            )
        )
        if not validate_json_response:
            pipes.extend(cur_result)
        else:
            # NOTE: Order of execution is preserved.
            js_answer_docs = [
                pr_runner.finalize(
                    raw_response=js_res["response"], **wrapper_for_props(js_doc, must_include_keys=must_include_keys)
                )
                for js_doc, js_res in zip(batch, cur_result, strict=True)
            ]
            pipes.extend(js_answer_docs)

        if (i + 1) % save_snapshot_every == 0:
            io_snapshot(pipes, where=snapshot_where, snapshot_number=str(i + 1), snapshot_prefix=snapshot_prefix)
    return pipes

In [None]:
# Displays the configured title column name.
# NOTE(review): looks like a leftover debugging cell — consider removing.
title_col

In [None]:
# Preview the first rows of the chunk-level frame.
# NOTE(review): same expression as an earlier preview cell — likely removable.
pl_source_docs.head()

In [None]:
# Preview the first rows of the document-level metadata frame.
pl_meta_docs.head()

In [None]:
# System prompt for the query-generation runner. It is a plain string on
# purpose: nothing is interpolated, so an f-string prefix would only make
# literal braces (e.g. a JSON example added later) break the cell.
system_prompt = """
Generate synthetic data from short questions that users can ask chatbots or customer support in Russian.

The goal is to anticipate possible user questions, ensuring that they are clearly worded, appropriate to the context, and can be easily answered based on the context.

# Stages of creating synthetic data
1. Carefully study the document submitted to you and its context.
2. Formulate brief questions related to the document and make sure that there are clear answers in the text.
3. Make sure that each question is unique, and use different formulations to ensure diversity.
4. Formulate answers strictly based on the content of the document.
5. Write the question text in Russian only and make it suitable for chatbot or user support scenarios.

(Optional: for real-world examples, more complex documents may be required, as well as various pairs of quality tests. Use PLACEHOLDERS for real texts and quality control.)
"""

In [None]:
from justatom.running.prompt import QUERIESPropmtRunner

In [None]:
# Settings for the generation run; each value is passed to `pipeline(...)` or
# to the prompt runner further down.
openai_model_name = "gpt-4o-mini"
batch_size = 4
coros_size = 2
save_snapshot_every = 1
# The docs are renamed to "content" / "title" before they are converted to
# dicts, so the keys to keep must use the post-rename names. The original
# column names would never match and both fields would be silently dropped
# from the props, leaving only "chunk_id".
must_include_keys = ["chunk_id", "content", "title"]
snapshot_prefix = "SEVERSTAL|QUERIES"
snapshot_where = "outputs"
source_language = "Russian"
timeout = 512

In [None]:
# Build the prompt runner from the system prompt (surrounding whitespace
# stripped) and the configured source language.
pr_runner = QUERIESPropmtRunner(
    system_prompt=system_prompt.strip(),
    source_language=source_language
)

In [None]:
# Rename the source columns to "content" / "title" and turn the frame into a
# list of row dicts. Only columns that are still present are renamed, so
# re-running this cell after `pl_docs` has already been rebound is a no-op
# instead of an error.
pl_docs = pl_docs.rename(
    {
        old: new
        for old, new in {content_col: "content", title_col: "title"}.items()
        if old in pl_docs.columns
    }
)
assert {"content", "title"} <= set(pl_docs.columns), "expected `content` and `title` columns after renaming"
js_docs = pl_docs.to_dicts()

In [None]:
# Dry run on the first document to inspect what the runner prepares —
# presumably a sanity check before launching the full generation.
# NOTE(review): `_prepare` is a private method of the runner — confirm there is
# no public equivalent to call here.
pr_runner._prepare(
    **js_docs[0]
)

In [None]:
response = await pipeline(js_docs,openai_model_name=openai_model_name, pr_runner=pr_runner, batch_size=batch_size, coros_size=coros_size, save_snapshot_every=save_snapshot_every, must_include_keys=must_include_keys, snapshot_prefix=snapshot_prefix, snapshot_where=snapshot_where, timeout=timeout)