In [1]:
from typing import TYPE_CHECKING


# Static-typing declarations only: TYPE_CHECKING is False at runtime, so nothing here executes.
# No cell in this notebook assigns `application_container` / `infrastructure_container`;
# these annotations only tell type checkers and IDEs what the externally provided globals are.
if TYPE_CHECKING:
    from math_rag.application.containers import ApplicationContainer
    from math_rag.infrastructure.containers import InfrastructureContainer

    # NOTE(review): presumably injected by `hooks.notebook_hook` (loaded in the next cell) — confirm.
    application_container: ApplicationContainer
    infrastructure_container: InfrastructureContainer

In [2]:
# NOTE(review): RESET is not read by any visible cell; presumably it is consumed by
# `hooks.notebook_hook` when the extension loads — confirm.
RESET = False
# Project IPython extension. The cells below use `application_container` and
# `infrastructure_container` without assigning them, so this hook presumably injects them — confirm.
%load_ext hooks.notebook_hook

2025-06-28 14:48:40,379 - INFO - datasets - config.py:54 - PyTorch version 2.6.0 available.


In [3]:
from pathlib import Path


# Google Drive path of the lecture LaTeX source parsed by this notebook.
ARTICLE_PATH = Path('ml/lectures/L07-LogisticRegression2/2024_08_10_2174b40686820b4cb591g.tex')

google_drive_repository = infrastructure_container.google_drive_repository()
math_article_parser_service = infrastructure_container.math_article_parser_service()

file_id = google_drive_repository.get_file_id(ARTICLE_PATH)

# A falsy id means the path could not be resolved; fail with a message that names the path.
if not file_id:
    raise ValueError(f'Could not resolve a Google Drive file id for {ARTICLE_PATH}')

file_content = google_drive_repository.get_file_by_id(file_id)

2025-06-28 14:48:45,982 - INFO - googleapiclient.discovery_cache - __init__.py:49 - file_cache is only supported with oauth2client<4.0.0


In [None]:
from math_rag.core.models import Index


# A freshly constructed index; its `id` is stamped on every MathExpression built further down.
index: Index = Index()

In [None]:
# Assistant instances resolved from the application container (declared in the first cell).
# KaTeX correction: a one-shot corrector plus a history-aware retrier.
katex_corrector_assistant = (
    application_container
    .katex_corrector_assistant()
)
katex_corrector_retrier_assistant = (
    application_container
    .katex_corrector_retrier_assistant()
)

# Math-expression description and relationship assistants.
math_expression_description_writer_assistant = (
    application_container
    .math_expression_description_writer_assistant()
)
math_expression_description_optimizer_assistant = (
    application_container
    .math_expression_description_optimizer_assistant()
)
math_expression_comparator_assistant = (
    application_container
    .math_expression_comparator_assistant()
)
math_expression_relationship_description_writer_assistant = (
    application_container
    .math_expression_relationship_description_writer_assistant()
)

In [None]:
from math_rag.core.models import MathArticle, MathExpression


# Wrap the downloaded bytes as an article not yet attached to a dataset or an index.
math_article = MathArticle(
    name='article',
    bytes=file_content.getvalue(),
    math_expression_dataset_id=None,
    index_id=None,
)

# parse_for_index returns a 3-tuple; only the math nodes and the template are used here.
parsed_article = math_article_parser_service.parse_for_index(math_article)
math_nodes, _, template = parsed_article

In [None]:
from math_rag.application.models.assistants.inputs import KatexCorrector as AssistantInput
from math_rag.application.models.assistants.inputs import (
    KatexCorrectorRetrier as RetrierAssistantInput,
)
from math_rag.application.models.assistants.outputs import KatexCorrector as AssistantOutput


katex_client = application_container.katex_client()

# KaTeX source per node: every leading/trailing `$` is removed (whitespace is left untouched).
katexes = [node.latex.strip('$') for node in math_nodes]

In [None]:
max_num_retries = ...
katexes: list[str] = ...


# initialize tasks, one per input, preserving original order via index
tasks = [KatexCorrectorTask(index=i, current_katex=katex) for i, katex in enumerate(katexes)]

# phase 1: initial validation and katex corrector
to_validate = [task.current_katex for task in tasks]
initial_results = await katex_client.batch_validate_many(to_validate, batch_size=50)

# collect tasks that need correction
input_id_to_task: dict[UUID, KatexCorrectorTask] = {}
input_id_to_input: dict[UUID, AssistantInput] = {}
initial_inputs: list[AssistantInput] = []
invalid_tasks: list[KatexCorrectorTask] = []

for task, result in zip(tasks, initial_results):
    if not result.valid:
        if result.error is None:
            raise ValueError()

        task.error = result.error
        ci = AssistantInput(katex=task.current_katex, error=task.error)
        initial_inputs.append(ci)
        input_id_to_task[ci.id] = task
        input_id_to_input[ci.id] = ci
        invalid_tasks.append(task)

    else:
        task.num_retries = 0

# run the corrector exactly once on all initially invalid katexes
if initial_inputs:
    outputs = await katex_corrector_assistant.concurrent_assist(initial_inputs)

    for output in outputs:
        task = input_id_to_task[output.input_id]
        input = input_id_to_input[output.input_id]
        task.history.append((input, output))
        task.current_katex = output.katex
        task.num_retries = 1

# phase 2: retry loop with validation and katex retrier corrector
while invalid_tasks:
    # re-validate what's still invalid
    to_validate = [task.current_katex for task in invalid_tasks]
    results = await katex_client.batch_validate_many(to_validate, batch_size=50)

    next_invalid_tasks: list[KatexCorrectorTask] = []
    retrier_inputs: list[RetrierAssistantInput] = []
    retrier_input_id_to_task: dict[UUID, KatexCorrectorTask] = {}
    retrier_input_id_to_input: dict[UUID, RetrierAssistantInput] = {}

    # build retrier inputs for those still failing
    for task, result in zip(invalid_tasks, results):
        if result.valid:
            continue

        if result.error is None:
            raise ValueError()

        task.error = result.error

        # assemble every previous (input, output) plus the new one with None output
        history_pairs = [(ci, co) for (ci, co) in task.history]
        ci = AssistantInput(katex=task.current_katex, error=task.error)
        history_pairs.append((ci, None))
        ri = RetrierAssistantInput(pairs=history_pairs)
        retrier_inputs.append(ri)
        retrier_input_id_to_task[ri.id] = task
        retrier_input_id_to_input[ri.id] = ri

        # mark for next round if retries remain
        next_invalid_tasks.append(task)

    if not retrier_inputs:
        break

    # call retrier in one batch
    retrier_outputs = await katex_corrector_retrier_assistant.concurrent_assist(retrier_inputs)

    # apply retrier outputs, bump retries, check max_num_retries
    invalid_tasks: list[KatexCorrectorTask] = []

    for output in retrier_outputs:
        task = retrier_input_id_to_task[output.input_id]
        input = retrier_input_id_to_input[output.input_id]
        task.history.append((input, output))
        task.current_katex = output.katex
        task.num_retries += 1

        if task.num_retries > max_num_retries:
            raise RuntimeError(f'Max retries reached for KaTeX at index {task.index}')

        invalid_tasks.append(task)

# return corrected katexes in the original order
tasks.sort(key=lambda task: task.index)

# NOTE: result
[task.current_katex for task in tasks]

In [None]:
# One MathExpression per parsed node, pairing each node with its corrected KaTeX.
# zip() would silently truncate on a length mismatch, so check the alignment explicitly.
if len(final_katexes) != len(math_nodes):
    raise ValueError(
        f'Expected {len(math_nodes)} corrected KaTeX strings, got {len(final_katexes)}'
    )

math_expressions = [
    MathExpression(
        math_article_id=math_article.id,
        math_expression_dataset_id=None,
        index_id=index.id,
        latex=node.latex,
        katex=katex,
        position=node.position,
        is_inline=node.is_inline,
    )
    for node, katex in zip(math_nodes, final_katexes)
]
# Position in `math_nodes` -> KaTeX, passed to TemplateFormatterUtil.format below.
# NOTE(review): assumes template placeholders are numbered in `math_nodes` order — confirm.
index_to_katex = dict(enumerate(math_expression.katex for math_expression in math_expressions))

In [None]:
from math_rag.infrastructure.utils import TemplateContextChunkerUtil, TemplateFormatterUtil


# Size limit handed to the chunker (units are whatever TemplateContextChunkerUtil uses).
MAX_CONTEXT_SIZE = 1000

# Split the article template into bounded pieces, then render each piece with `index_to_katex`,
# keeping the wrappers (omit_wrapper=False).
template_chunks = TemplateContextChunkerUtil.chunk(template, max_context_size=MAX_CONTEXT_SIZE)
chunks = [
    TemplateFormatterUtil.format(template_chunk, index_to_katex, omit_wrapper=False)
    for template_chunk in template_chunks
]

In [None]:
from math_rag.application.models.assistants.inputs import (
    MathExpressionDescriptionWriter as AssistantInput,
)


input = AssistantInput(katex=..., context=...)

output = await math_expression_description_writer_assistant.assist(input)