In [None]:
import asyncio
import json
import random
from typing import cast

import pandas as pd
from openai import OpenAI
from pydantic import BaseModel
from tqdm.auto import tqdm

from amlta.app import config as app_config

# Point the app at a remote Ollama endpoint and choose the model it should use.
# NOTE(review): these overrides sit before the `amlta.app.agent` imports that
# follow, presumably because the agent reads `app_config` at import time —
# confirm before reordering the imports.
app_config.ollama_base_url = "https://gentle-vulture-knowing.ngrok-free.app/"
app_config.ollama_model = "qwen2.5:32b-instruct-q3_K_M"

from amlta.app.agent import graph
from amlta.app.agent.core import FlowQueries
from amlta.config import config
from amlta.question_generation.process import QuestionData, load_batches

In [5]:
# Load the generated question records. Later cells look records up by their
# "batch", "question_id" and "process_uuid" keys.
question_data = load_batches()

In [6]:
# Table of fine-tuning examples (TAPAS, going by the file name). Its "batch",
# "question_id" and "process_uuid" columns are used further down to find where
# the held-out portion starts.
training_df = pd.read_parquet(
    config.data_dir / "tapas-ft" / "data" / "tapas_train_batched_dfs_shuffled.parquet"
)

In [None]:
# The last 20% of `training_df` rows are treated as not trained on.
# NOTE(review): assumes fine-tuning used exactly the first 80% of this frame — confirm.
not_trained_on_data = training_df.iloc[int(len(training_df) * 0.8) :]
start_batch = not_trained_on_data["batch"].values[0]
start_question_id = int(not_trained_on_data["question_id"].values[0])
start_process_uuid = not_trained_on_data["process_uuid"].values[0]

# Find the position in `question_data` of the FIRST held-out row (the row at
# the 80% boundary), identified by its (batch, question_id, process_uuid) key.
# A default is passed to `next` so a missing match produces a readable error
# instead of a bare StopIteration.
end_tain_idx = next(
    (
        i
        for i, q in enumerate(question_data)
        if q["batch"] == start_batch
        and q["question_id"] == start_question_id
        and q["process_uuid"] == start_process_uuid
    ),
    None,
)
if end_tain_idx is None:
    raise LookupError(
        "held-out boundary row not found in question_data: "
        f"batch={start_batch!r}, question_id={start_question_id!r}, "
        f"process_uuid={start_process_uuid!r}"
    )

# Everything strictly after the boundary row is kept for evaluation; the `+ 1`
# means the boundary question itself is left out as well.
valid_data = question_data[end_tain_idx + 1 :]
len(valid_data)

721

In [8]:
# Fixed seed so the evaluation subset is reproducible on a fresh kernel.
random.seed(42)

# Draw WITHOUT replacement. `random.choices` samples with replacement, which
# can put the same question into the evaluation set more than once;
# `random.sample` guarantees distinct picks. `min` keeps the cell working if
# fewer than 50 held-out questions are available.
# NOTE: this selects a different subset than earlier runs that used `random.choices`.
questions = random.sample(valid_data, k=min(50, len(valid_data)))

In [29]:
# Structured-output schema handed to the chat completion call as
# `response_format` when a batch of questions is rewritten.
# (Documented with comments rather than a docstring, since pydantic may copy a
# class docstring into the generated schema.)
class RewrittenQuestions(BaseModel):
    # Rewritten question texts; `rewrite_batch` asserts there is exactly one
    # per input question.
    questions: list[str]

In [27]:
# Client used by `rewrite_batch`. No arguments are passed, so credentials and
# endpoint come from the client's defaults (presumably the OPENAI_API_KEY
# environment variable — verify in the deployment environment).
openai = OpenAI()

In [None]:
# System message sent with every rewrite request. The text is part of the
# model input, so any wording change alters the generated questions.
system_prompt = """
You are a helpful assistant. Your task is to rewrite a list of questions to be more realistic, natural
and human.

While rewriting the questions, please make sure to keep the meaning of the question intact -- the
question should remain answerable with the same accuracy as the original question.
""".strip()


def rewrite_batch(questions: list[QuestionData]) -> RewrittenQuestions:
    """Ask the chat model to rephrase a batch of questions.

    The "question_replaced_general" text of every question is sent as one
    numbered list (numbering starts at 0), together with `system_prompt`.

    Args:
        questions: Question records to rewrite in a single request.

    Returns:
        The parsed structured response, holding one rewritten string per
        input question.

    Raises:
        AssertionError: If the response could not be parsed into
            `RewrittenQuestions`, or if the number of rewritten questions
            differs from the number of inputs.
    """
    # Build the user message first: one "<index>. <question>" line per record.
    numbered_prompt = "\n".join(
        f"{idx}. " + item["question_replaced_general"]
        for idx, item in enumerate(questions)
    )

    completion = openai.beta.chat.completions.parse(
        model="gpt-4o-mini",
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": numbered_prompt},
        ],
        response_format=RewrittenQuestions,
        temperature=0.9,
        top_p=0.9,
    )

    parsed = completion.choices[0].message.parsed
    assert parsed is not None
    # The caller pairs inputs and outputs positionally, so the counts must match.
    assert len(parsed.questions) == len(questions)

    return parsed

In [39]:
# A `QuestionData` record extended with the model-rewritten question text.
class RewrittenQuestionData(QuestionData):
    # Rewritten wording of the question, as returned by `rewrite_batch`.
    rewritten: str

In [None]:
# Raw structured responses, one per request, kept for later inspection.
rewritten_questions_resps = []
# Original question records, each merged with its rewritten text.
rewritten_questions: list[RewrittenQuestionData] = []

batch_size = 10
for start in tqdm(range(0, len(questions), batch_size)):
    batch = questions[start : start + batch_size]
    resp = rewrite_batch(batch)
    rewritten_questions_resps.append(resp)

    # Pair inputs with outputs by position; `rewrite_batch` has already
    # asserted that both sides have the same length.
    rewritten_questions.extend(
        cast(RewrittenQuestionData, original | {"rewritten": new_text})
        for original, new_text in zip(batch, resp.questions)
    )

  0%|          | 0/5 [00:00<?, ?it/s]

In [41]:
# Take the callable stored under `__wrapped__` on `graph.rewrite_flows_query`,
# i.e. the function behind its decorator.
# NOTE(review): presumably done so it can be awaited directly from the
# notebook, outside the graph runtime — confirm.
unwrapped_task = graph.rewrite_flows_query.__wrapped__  # type: ignore


async def generate_query(question: RewrittenQuestionData):
    """Run the flows-query rewriting task on a question's rewritten text.

    Args:
        question: Question record; only its "rewritten" entry is read.

    Returns:
        Whatever `unwrapped_task` returns for that text (the calling cell
        collects these results as `FlowQueries`).
    """
    # This function draws no random numbers itself, so it no longer reseeds
    # the global RNG: the former `random.seed(42)` belonged to a removed
    # random pick between question variants and, run once per concurrent
    # call, only reset the shared RNG state as a hidden side effect.
    return await unwrapped_task(question["rewritten"])

In [None]:
generated: list[FlowQueries] = []

batch_size = 10
for i in tqdm(range(0, len(rewritten_questions), batch_size)):
    batch = rewritten_questions[i : i + batch_size]
    generated.extend(await asyncio.gather(*[generate_query(q) for q in batch]))

  0%|          | 0/5 [00:00<?, ?it/s]

Failed to parse FlowQueries from completion {"queries": [{"justification": "The user is asking for specific information about the amount of waste produced, which directly relates to a type of flow (waste flow) associated with a particular process.", "query": "What are the amounts of waste flows produced during the bleaching of pulp?"}], "join_type": "union"}. Got: 1 validation error for FlowQueries
queries.0.query
  Value error, query must end with 'the process?' [type=value_error, input_value='What are the amounts of ... the bleaching of pulp?', input_type=str]
    For further information visit https://errors.pydantic.dev/2.10/v/value_error
For troubleshooting, visit: https://python.langchain.com/docs/troubleshooting/errors/OUTPUT_PARSING_FAILURE 
Failed to parse FlowQueries from completion {"queries": [{"justification": "The user wants to know the total volume of resources used for a specific type of process (coal-based electricity generation). The query should focus on inputs as it 

In [48]:
# Sanity check: number of generated results collected so far.
len(generated)

50

In [49]:
# Destination of the evaluation set, written as JSON Lines under the project's data directory.
out_path = config.data_dir / "generated" / "tapas-eval-questions.jsonl"

In [53]:
# `zip` stops at the shorter input, so a length mismatch would silently drop
# records; fail loudly instead.
assert len(rewritten_questions) == len(generated)

# Create the target directory if needed, so the write cannot fail on a
# missing folder after the expensive generation cells have already run.
out_path.parent.mkdir(parents=True, exist_ok=True)

# One JSON object per line. Only the first query of each generated result is
# stored. The encoding is stated explicitly instead of relying on the
# platform default. As the last expression, the cell still displays the
# number of characters written.
out_path.write_text(
    "\n".join(
        json.dumps({"question_data": q, "rewritten_flows_query": g.queries[0].query})
        for q, g in zip(rewritten_questions, generated)
    ),
    encoding="utf-8",
)

55519