In [None]:
from pathlib import Path
import os
import dotenv
import simplejson as json
import polars as pl
import json_repair
from tqdm import tqdm
from more_itertools import chunked
from loguru import logger
import asyncio as asio
from justatom.tooling.stl import reuuid
from justatom.storing.dataset import API as DatasetApi

from justatom.etc.io import io_snapshot
from justatom.tooling.reqs import openai_chat
from justatom.tooling.coro import _limit_concurrency
from justatom.running.igni import IGNIRunner

In [None]:
# Load variables from a local .env file into the process environment before any helper below is used.
dotenv.load_dotenv()

In [None]:
def source_from_dataset(dataset_name_or_path, **props):
    """Load a dataset through the justatom dataset API and return it as a Polars frame.

    :param dataset_name_or_path: Dataset name or path handed to `DatasetApi.named`.
    :param props: Extra keyword arguments forwarded to the dataset's `iterator(...)`.
    :return: The iterator result itself when it already is a `pl.DataFrame`;
        otherwise a `pl.DataFrame` built from the iterated records with `pl.from_dicts`.
    """
    from justatom.storing.dataset import API as DatasetApi
    import polars as pl

    loaded = DatasetApi.named(dataset_name_or_path).iterator(**props)
    if isinstance(loaded, pl.DataFrame):
        return loaded
    # Anything else is treated as an iterable of records: materialize it, then convert.
    return pl.from_dicts(list(loaded))

In [None]:
def wrapper_for_props(d: dict, must_include_keys: list[str] = None) -> dict:
    """
    Project a document onto a subset of its keys.

    :param d: Source doc
    :param must_include_keys: Keys to copy over; `None` means "every key of `d`".
        Requested keys that are absent from `d` are skipped silently.
    :return: New dict holding only the requested keys that exist in `d`,
        in the order they were requested.
    """
    wanted_keys = d.keys() if must_include_keys is None else must_include_keys
    picked = {}
    for key in wanted_keys:
        if key in d:
            picked[key] = d[key]
    return picked

def wrapper_json_decode(pl_docs: pl.DataFrame, columns: list[str], must_include_keys: list[str] = None) -> pl.DataFrame:
    """
    Decode JSON-encoded string columns of a frame into Python values.

    :param pl_docs: Source documents as a Polars DataFrame.
    :param columns: Names of the columns whose values are JSON strings; each value
        is parsed with `json_repair.loads`.
    :param must_include_keys: Additional (non-decoded) columns to carry over unchanged.
        Defaults to `["doc_id", "doc_name"]` when omitted.
    :return: New Polars DataFrame with the decoded `columns` followed by the carried-over keys.
    """
    if must_include_keys is None:
        must_include_keys = ["doc_id", "doc_name"]
    # The carried-over values are merged after the decoded ones, so a key listed in both
    # places would replace the decoded value with the raw string; drop such keys here.
    passthrough_keys = [key for key in must_include_keys if key not in columns]

    js_answer = [
        {
            **{col: json_repair.loads(js_doc[col]) for col in columns},
            **wrapper_for_props(js_doc, must_include_keys=passthrough_keys),
        }
        for js_doc in pl_docs.to_dicts()
    ]

    return pl.from_dicts(js_answer)

In [None]:
# NOTE(review): this cell re-defines `source_from_dataset`, which an earlier cell of this
# notebook already defines with the same behavior; one of the two cells can be deleted.
def source_from_dataset(dataset_name_or_path, **props):
    """Load a dataset through the justatom dataset API and return it as a Polars frame.

    :param dataset_name_or_path: Dataset name or path handed to `DatasetApi.named`.
    :param props: Extra keyword arguments forwarded to the dataset's `iterator(...)`.
    :return: The iterator result itself when it already is a `pl.DataFrame`;
        otherwise a `pl.DataFrame` built from the iterated records with `pl.from_dicts`.
    """
    from justatom.storing.dataset import API as DatasetApi
    import polars as pl

    loaded = DatasetApi.named(dataset_name_or_path).iterator(**props)
    if isinstance(loaded, pl.DataFrame):
        return loaded
    # Anything else is treated as an iterable of records: materialize it, then convert.
    return pl.from_dicts(list(loaded))

In [None]:
# Source workbook, resolved relative to the current user's home directory.
dataset_name_or_path = Path.home() / "IDataset" / "SEVERSTAL" / "SEVERSTAL_48_split_by_line.xlsx"

# Column names referenced by later cells.
content_col = "chunk_text"  # column aggregated into each document's "content" text
title_col = "doc_name"  # column renamed to "title" before prompting
chunk_id_col = "chunk_id"  # NOTE(review): not referenced in the visible cells, which hard-code "chunk_id"

In [None]:
# Load the "Chunks" and "Documents" sheets of the workbook as Polars frames.
pl_source_docs = source_from_dataset(dataset_name_or_path, sheet_name="Chunks")
pl_meta_docs = source_from_dataset(dataset_name_or_path, sheet_name="Documents")
# Parse the JSON-encoded columns of the document sheet, keeping doc_id and doc_name next to them.
pl_meta_docs = wrapper_json_decode(pl_meta_docs, columns=["doc_chunk_ids", "doc_sections"], must_include_keys=["doc_id", "doc_name"])

In [None]:
# Preview the first rows of the frame loaded from the "Chunks" sheet.
pl_source_docs.head()

In [None]:
# Sanity check: log how many distinct chunk ids and document ids were loaded.
logger.info(f"There are N=[{pl_source_docs.select('chunk_id').unique().shape[0]}] unique chunks")
logger.info(f"There are D=[{pl_meta_docs.select('doc_id').unique().shape[0]}] unique docs")

In [None]:
async def pipeline(
    js_docs: list[dict],
    pr_runner,
    openai_model_name: str,
    batch_size: int = 16,
    coros_size: int = 2,
    save_snapshot_every: int = 5,
    snapshot_prefix: str = None,
    snapshot_where: str = None,
    timeout: int = 512,
    must_include_keys: list[str] | None = None,
    validate_json_response: bool = False,
):
    """
    We process `js_docs` by chunks where each chunk is of size `batch_size`.
    Each chunk is processed asynchronously via parallel `coros_size` coroutines.

    :param js_docs: documents to process; each one is unpacked into `pr_runner.prompt(**doc)`.
    :param pr_runner: One of the instance `IPromptRunner` to create specific prompt
    :param openai_model_name: Model name forwarded to `openai_chat` as `model`.
    :param batch_size: Number of documents sent per batch.
    :param coros_size: Forwarded to `_limit_concurrency` as `concurrency`.
    :param save_snapshot_every: Write a snapshot of all results so far after every this many batches.
        A value of `0` or `None` disables snapshots.
    :param snapshot_prefix: Forwarded to `io_snapshot` as `snapshot_prefix`.
    :param snapshot_where: Forwarded to `io_snapshot` as `where`.
    :param timeout: Forwarded to `openai_chat` as `timeout`.
    :param must_include_keys: Keys of each document to attach to its request as `props`
        (and to pass to `pr_runner.finalize`); `None` keeps every key.
    :param validate_json_response: When true, each result is post-processed with
        `pr_runner.finalize(raw_response=result["response"], ...)`; otherwise the raw results are kept.
    :return: List with one entry per processed document, in batch order.
    """
    pipes = []
    batches_done = 0
    last_snapshot_at = 0
    # A zero/None interval would otherwise fail in the modulo below; treat it as "no snapshots".
    snapshots_enabled = bool(save_snapshot_every) and save_snapshot_every > 0

    for i, batch in tqdm(enumerate(chunked(js_docs, n=batch_size))):
        cur_result = await asio.gather(
            *_limit_concurrency(
                [
                    openai_chat(
                        pr_runner.prompt(**d),
                        timeout=timeout,
                        model=openai_model_name,
                        props=wrapper_for_props(d, must_include_keys=must_include_keys),
                    )
                    for d in batch
                ],
                concurrency=coros_size,
            )
        )
        if not validate_json_response:
            pipes.extend(cur_result)
        else:
            # NOTE: Order of execution is preserved, so results can be paired with their
            # source documents positionally; `strict=True` fails loudly on a length mismatch.
            js_answer_docs = [
                pr_runner.finalize(
                    raw_response=js_res["response"], **wrapper_for_props(js_doc, must_include_keys=must_include_keys)
                )
                for js_doc, js_res in zip(batch, cur_result, strict=True)
            ]
            pipes.extend(js_answer_docs)

        batches_done = i + 1
        if snapshots_enabled and batches_done % save_snapshot_every == 0:
            io_snapshot(pipes, where=snapshot_where, snapshot_number=str(batches_done), snapshot_prefix=snapshot_prefix)
            last_snapshot_at = batches_done

    # Batches finished after the last periodic snapshot would otherwise never reach disk.
    if snapshots_enabled and batches_done > last_snapshot_at:
        io_snapshot(pipes, where=snapshot_where, snapshot_number=str(batches_done), snapshot_prefix=snapshot_prefix)
    return pipes

In [None]:
# Preview the chunk frame again (same call as an earlier cell) before joining it with the documents.
pl_source_docs.head()

In [None]:
# Preview the document frame after its JSON columns were decoded.
pl_meta_docs.head()

In [None]:
def wrapper_for_paragraphs(paragraphs: list[str]):
    """
    Render paragraphs as one numbered block of text.

    :param paragraphs: Paragraph texts, in order.
    :return: Newline-joined lines of the form `Paragraph <index>: <text>`, with the index starting at 0.
    """
    numbered_lines = []
    for pos, text in enumerate(paragraphs):
        numbered_lines.append(f"Paragraph {pos}: {text}")
    return "\n".join(numbered_lines)

In [None]:
# Build one row per document: expand each document into its chunk ids, attach the matching
# chunk rows, then collapse back to a single row per (doc_id, doc_name).
pl_docs = (
    pl_meta_docs
    .explode("doc_chunk_ids")
    .join(
        pl_source_docs,
        how="left",
        left_on=["doc_id", "doc_chunk_ids"],  # doc_id + the current chunk_id
        right_on=["doc_id", "chunk_id"],
    )
    # `groupby` is the deprecated spelling; `group_by` is the current Polars name.
    .group_by(["doc_id", "doc_name"])
    .agg([
        pl.col("doc_chunk_ids").flatten().alias("doc_chunk_ids"),
        # NOTE(review): `doc_sections` is repeated on every exploded row, so flattening it here
        # presumably repeats the sections once per chunk — confirm that is intended.
        pl.col("doc_sections").flatten().alias("doc_sections"),
        # Each group's chunk texts are rendered into a single numbered "content" string.
        pl.col(content_col).explode().map_elements(lambda paragraphs: wrapper_for_paragraphs(paragraphs)).alias("content"),
    ])
)

In [None]:
# Log how many per-document groups the aggregation produced.
logger.info(f"There are total G=[{len(pl_docs)}] groups of chunks")

In [None]:
# System prompt handed to the prompt runner. It is a plain string: the text contains no
# placeholders, so no f-string prefix is needed.
system_prompt = """
Generate synthetic data from short questions that users can ask chatbots or customer support in Russian.

The goal is to anticipate possible user questions, ensuring that they are clearly worded, appropriate to the context, and can be easily answered based on the context.

# Stages of creating synthetic data
1. Carefully study the document submitted to you and its context.
2. Formulate brief questions related to the document with multiple paragraphs and make sure that there are clear answers in the text.
3. Make sure that each question is unique, and use different formulations to ensure diversity.
4. Formulate answers strictly based on the content (paragraphs) of the document.

(Optional: for real-world examples, more complex documents may be required, as well as various pairs of quality tests. Use PLACEHOLDERS for real texts and quality control.)
"""

In [None]:
from justatom.running.prompt import QUERIESWithSourcesPromptRunner

In [None]:
# Run configuration for the generation pipeline.
# openai_model_name = "gpt-4o-mini"
openai_model_name = "o3-mini"
batch_size = 4  # documents per batch
coros_size = 2  # forwarded to the concurrency limiter inside `pipeline`
save_snapshot_every = 1  # snapshot after every batch
# NOTE(review): the docs later passed to `pipeline` carry the keys doc_id / content /
# doc_chunk_ids / title, so of the three names below only "doc_id" can match — confirm intent.
must_include_keys = ["doc_id", content_col, title_col]
snapshot_prefix = "SEVERSTAL|QueriesWithSources"
snapshot_where = "outputs"
source_language="Russian"
timeout  = 512  # forwarded to `openai_chat`; units are not visible from this notebook

In [None]:
# Build the prompt runner from the configured source language and the system prompt defined above.
pr_runner = QUERIESWithSourcesPromptRunner(
    source_language=source_language,
    system_prompt=system_prompt
)

In [None]:
# Display the runner's `system_prompt` attribute to check what it actually holds.
pr_runner.system_prompt

In [None]:
# Keep only the columns passed on to the prompt runner, expose the document name under the
# key "title", and convert the frame into a list of per-document dicts.
pl_view_docs = pl_docs.select(["doc_id", "content", "doc_chunk_ids", "doc_name"])
pl_view_docs = pl_view_docs.rename({title_col: "title"})
js_view_docs = pl_view_docs.to_dicts()

In [None]:
# Inspect what the runner prepares for the first document.
# NOTE(review): `_prepare` is underscore-prefixed (a private helper by convention); it is called here for inspection only.
pr_runner._prepare(
    **js_view_docs[0]
)

In [None]:
# Show the prompt the runner builds for the first document (the same call `pipeline` makes per document).
pr_runner.prompt(**js_view_docs[0])

In [None]:
response = await pipeline(js_view_docs,openai_model_name=openai_model_name, pr_runner=pr_runner, batch_size=batch_size, coros_size=coros_size, save_snapshot_every=save_snapshot_every, must_include_keys=must_include_keys, snapshot_prefix=snapshot_prefix, snapshot_where=snapshot_where, timeout=timeout)

In [None]:
# Preview the per-document frame that the responses will be joined with.
pl_docs.head()

#### `Merge` results

After the annotation, take a snapshot file from the `outputs` directory, where the `LLM` synthetic annotation results were saved periodically — either the latest one or whichever one we want to check. We load it from there and join it back to the source documents (with their `chunk_id`s and paragraphs) on `doc_id`.

In [None]:
# Load one saved snapshot from ./outputs (relative to the current working directory).
# NOTE(review): the file name is hard-coded; it presumably corresponds to `snapshot_prefix`
# plus a snapshot number — adjust it to the snapshot you want to inspect.
pl_response = source_from_dataset(Path(os.getcwd()) / "outputs" / "SEVERSTAL|QueriesWithSources3.json")
js_response = pl_response.to_dicts()

In [None]:
# Parse each stored raw response into a dict and keep its doc_id next to the parsed fields.
# NOTE(review): only the first 6 responses are processed (`[:6]`) — confirm this limit is intended.
js_response_struct = [{
    **json_repair.loads(js_doc['response']),
    **wrapper_for_props(js_doc, must_include_keys=["doc_id"])
    } for js_doc in js_response[:6]
]

In [None]:
# Spot-check one parsed response.
js_response_struct[1]

In [None]:
# Turn the parsed responses into a Polars frame so they can be joined with the documents.
pl_response_struct = pl.from_dicts(js_response_struct)

In [None]:
# Preview the parsed responses.
pl_response_struct.head()

In [None]:
# Attach each document's grouped fields to its parsed response; the inner join keeps only
# doc_ids present on both sides.
pl_final_response = pl_response_struct.join(pl_docs, on="doc_id", how="inner")

In [None]:
# Show the model name that is embedded in the output file name below.
openai_model_name

In [None]:
# One row per item of the "answer" column, renamed to "QA", written to an Excel file named after the model.
# NOTE(review): assumes every parsed response contains a list-valued "answer" field — confirm against the prompt runner.
pl_final_response.explode("answer").rename({"answer": "QA"}).write_excel(f"SEVERSTAL.{openai_model_name}.QueriesWithMultiSources.xlsx")