In [1]:
%load_ext autoreload
%autoreload 2

import os
import sys

import nest_asyncio


# Prepend the parent directory to the import path (presumably where the `math_rag` package lives).
project_root = os.path.abspath('..')
sys.path.insert(0, project_root)

# Patch asyncio so the event loop Jupyter is already running can be re-entered by library code.
nest_asyncio.apply()

In [None]:
import logging


LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'

# Root logger: INFO and above, timestamped.
logging.basicConfig(format=LOG_FORMAT, level=logging.INFO)

# Optional: uncomment to quiet chatty third-party loggers.
# for noisy_logger_name, noisy_logger_level in {
#     'pylatexenc.latexwalker': logging.ERROR,
#     'httpx': logging.WARNING,
#     'openai': logging.WARNING,
# }.items():
#     logging.getLogger(noisy_logger_name).setLevel(noisy_logger_level)

In [None]:
from math_rag.infrastructure.containers import InfrastructureContainer


RESET = False

# containers
infrastructure_container = InfrastructureContainer()
infrastructure_container.init_resources()
infrastructure_container.wire(modules=[__name__])

application_container = infrastructure_container.application_container()
application_container.init_resources()
application_container.wire(modules=[__name__])

# seed
for object_seeder in infrastructure_container.object_seeders():
    object_seeder.seed(reset=RESET)

for document_seeder in infrastructure_container.document_seeders():
    await document_seeder.seed(reset=RESET)

# index
for document_indexer in infrastructure_container.document_indexers():
    await document_indexer.index(reset=RESET)

2025-06-01 17:22:08,776 - INFO - PyTorch version 2.6.0 available.


In [4]:
from uuid import UUID

from math_rag.core.enums import MathExpressionDatasetBuildStage
from math_rag.core.models import MathExpressionDataset


math_expression_dataset_builder_service = (
    application_container.math_expression_dataset_builder_service()
)
math_expression_dataset_repository = application_container.math_expression_dataset_repository()


current_dataset = MathExpressionDataset(
    # build_from_dataset_id=UUID('dd5b1e2d-dc0c-4ae9-9f24-2fd2d37f42fa'),
    # build_from_stage=MathExpressionDatasetBuildStage.LOAD_MATH_EXPRESSION_SAMPLES,
)
await math_expression_dataset_repository.insert_one(current_dataset)
await math_expression_dataset_builder_service.build(current_dataset)

2025-06-01 17:22:14,288 - INFO - Dataset 4dd82c08-d628-43e6-b4e5-1724fbde3e52 build started
2025-06-01 17:22:14,288 - INFO - Dataset 4dd82c08-d628-43e6-b4e5-1724fbde3e52 build loading math articles...
2025-06-01 17:22:14,289 - INFO - Dataset 4dd82c08-d628-43e6-b4e5-1724fbde3e52 build stage updated to MathExpressionDatasetBuildStage.LOAD_MATH_ARTICLES
2025-06-01 17:22:14,289 - INFO - Requesting page (first: True, try: 0): https://export.arxiv.org/api/query?search_query=cat%3Amath.GM&id_list=&sortBy=submittedDate&sortOrder=descending&start=0&max_results=100
2025-06-01 17:22:21,641 - INFO - Got first page: 100 of 4306 total results
2025-06-01 17:22:21,719 - INFO - HTTP Request: GET https://arxiv.org/src/2505.23238v1 "HTTP/1.1 200 OK"
2025-06-01 17:22:21,874 - INFO - MathArticleDatasetLoaderService loaded 10 math articles
2025-06-01 17:22:21,897 - INFO - MathArticleDatasetLoaderService 10 math articles in total
2025-06-01 17:22:21,898 - INFO - Dataset 4dd82c08-d628-43e6-b4e5-1724fbde3e52 b

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/3 [00:00<?, ?ba/s]

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

README.md:   0%|          | 0.00/3.12k [00:00<?, ?B/s]

2025-06-01 17:27:20,154 - INFO - MathExpressionDatasetPublisherService published 2791 math expression samples
2025-06-01 17:27:20,168 - INFO - Dataset 4dd82c08-d628-43e6-b4e5-1724fbde3e52 published
2025-06-01 17:27:20,170 - INFO - Dataset 4dd82c08-d628-43e6-b4e5-1724fbde3e52 build finished


### Index

### Display

In [None]:
from IPython.display import Math, display


MAX_DISPLAYED_MATH_NODES = 100

# NOTE(review): `math_nodes` is not defined anywhere in this notebook, so this cell fails on a
# fresh kernel. It presumably held pylatexenc LaTeX math nodes produced by a since-removed cell.
# TODO: restore that cell above this one.
if 'math_nodes' not in globals():
    raise NameError(
        '`math_nodes` is undefined: restore the cell that extracts LaTeX math nodes '
        'before running this one'
    )

# Render the first MAX_DISPLAYED_MATH_NODES nodes as typeset math, one display output each.
display(
    *(
        Math(math_node.latex_verbatim())
        for math_node in math_nodes[:MAX_DISPLAYED_MATH_NODES]
    )
)

### Correct invalid KaTeX

In [6]:
# Sample invalid KaTeX (`\w` is not a defined control sequence) paired with the KaTeX
# parse error it produces; both are fed to the corrector assistant below.
incorrect_katex = r'd\omega = \theta \w \omega'
error = r'KaTeX parse error: Undefined control sequence: \w at position 18: …omega = \theta \̲w̲ ̲\omega'

In [7]:
from math_rag.application.models.assistants import KatexCorrectorAssistantInput


# NOTE(review): `input` shadows the builtin input(); the name is kept because later cells use it.
input = KatexCorrectorAssistantInput(
    error=error,
    katex=incorrect_katex,
)

In [None]:
katex_corrector_assistant = application_container.katex_corrector_assistant()

output = await katex_corrector_assistant.assist(input)
corrected_katex = output.katex
print(corrected_katex)
display(Math(corrected_katex))

In [None]:
import copy


inputs = [copy.deepcopy(input) for _ in range(5000)]
outputs = await katex_corrector_assistant.concurrent_assist(inputs)

### Dataset

#### Upload

In [4]:
# Aliased: importing this class under its plain name would shadow
# `math_rag.core.models.MathExpressionDataset`, which the dataset-build cell above uses,
# and (if the two classes differ) break that cell when it is re-run after this one.
from math_rag.application.models.datasets import (
    MathExpressionDataset as ApplicationMathExpressionDataset,
)
from math_rag.application.models.datasets import (
    MathExpressionSample,
)
from math_rag.core.enums import MathExpressionLabelEnum


N_TOY_SAMPLES = 10

# Toy samples `x + 0 = 5` ... `x + 9 = 5`, all labelled EQUALITY, to exercise the upload path.
samples = [
    MathExpressionSample(
        latex=f'x + {i} = 5',
        label=MathExpressionLabelEnum.EQUALITY,
    )
    for i in range(N_TOY_SAMPLES)
]
dataset = ApplicationMathExpressionDataset(samples)

In [5]:
from math_rag.application.assistants.prompts import MATH_EXPRESSION_LABELER_PROMPT
from math_rag.application.models.datasets import (
    DatasetMetadataFile,
    DatasetSplitSettings,
)


# The labeler prompt, pretty-printed as JSON and UTF-8 encoded, is shipped with the
# dataset as `prompt.json` (passed to publish() in the next cell).
json_str = MATH_EXPRESSION_LABELER_PROMPT.model_dump_json(indent=4)
content = json_str.encode('utf-8')
metadata_file = DatasetMetadataFile(name='prompt.json', content=content)

# 80/10/10 train/validate/test split with a fixed seed.
settings = DatasetSplitSettings(
    train_ratio=0.8,
    validate_ratio=0.1,
    test_ratio=0.1,
    seed=42,
)

In [6]:
dataset_publisher_service = infrastructure_container.dataset_publisher_service()

# Positional arguments: the dataset, its sample type, split settings, and the extra metadata file.
publish_args = (dataset, MathExpressionSample, settings, metadata_file)
dataset_publisher_service.publish(*publish_args)

2025-05-14 21:12:31,414 - INFO - Dataset KebabSeller/mathexpressiondataset already exists


Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

Uploading the dataset shards:   0%|          | 0/1 [00:00<?, ?it/s]

Creating parquet from Arrow format:   0%|          | 0/1 [00:00<?, ?ba/s]

#### Download

In [7]:
from datasets import load_dataset
from datasets.download import DownloadConfig
from decouple import config


HF_USERNAME = config('HF_USERNAME', default=None)
# None -> huggingface_hub falls back to its default token resolution.
HF_TOKEN = config('HF_TOKEN', default=None)

# Without this guard a missing username silently becomes the repo id 'None/mathexpressiondataset'.
if not HF_USERNAME:
    raise ValueError('HF_USERNAME is not set; cannot build the Hugging Face dataset repo id')

download_config = DownloadConfig(
    max_retries=3,
    disable_tqdm=True,
)

# The repo holds plain parquet splits, so no dataset loading script is needed.
# `trust_remote_code=True` was dropped: it would permit executing code fetched from the Hub.
dataset_dict = load_dataset(
    path=f'{HF_USERNAME}/mathexpressiondataset',
    split=None,  # None -> return every split (train / validate / test)
    download_config=download_config,
    token=HF_TOKEN,
)

README.md:   0%|          | 0.00/641 [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/1.28k [00:00<?, ?B/s]

validate-00000-of-00001.parquet:   0%|          | 0.00/1.24k [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/1.24k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/8 [00:00<?, ? examples/s]

Generating validate split:   0%|          | 0/1 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1 [00:00<?, ? examples/s]

In [8]:
from typing import cast

from datasets import ClassLabel


label_feature = dataset_dict['train'].features['label']

# cast() is for the type checker only; it performs no runtime conversion or check.
class_label = cast(ClassLabel, label_feature)
class_label.names

['equality', 'inequality', 'constant', 'variable', 'other']

### OpenAI finish reasons

In [4]:
# Fixture: one OpenAI Batch API result line whose chat completion stopped with
# finish_reason == 'length' (1024 completion tokens; presumably the request's max token limit).

# The single choice: truncated assistant content.
_length_choice = {
    'index': 0,
    'message': {
        'role': 'assistant',
        'content': 'some too long content...',
        'refusal': None,
        'annotations': [],
    },
    'logprobs': None,
    'finish_reason': 'length',
}

# Token accounting reported for the request.
_length_usage = {
    'prompt_tokens': 285,
    'completion_tokens': 1024,
    'total_tokens': 1309,
    'prompt_tokens_details': {'cached_tokens': 0, 'audio_tokens': 0},
    'completion_tokens_details': {
        'reasoning_tokens': 0,
        'audio_tokens': 0,
        'accepted_prediction_tokens': 0,
        'rejected_prediction_tokens': 0,
    },
}

# The `chat.completion` object; this is what gets fed to ChatCompletion(**...).
_length_body = {
    'id': 'chatcmpl-BbB9Kx0P0WJlR2r4I1SuRkYa2Mvn0',
    'object': 'chat.completion',
    'created': 1748200694,
    'model': 'gpt-4.1-nano-2025-04-14',
    'choices': [_length_choice],
    'usage': _length_usage,
    'service_tier': 'default',
    'system_fingerprint': 'fp_eede8f0d45',
}

response_with_length_finish_reason = {
    'id': 'batch_req_68338415d224819095e7e42c4aeac8f8',
    'custom_id': 'b9b237e5-d1c8-4348-9c90-f29e998228a6',
    'response': {
        'status_code': 200,
        'request_id': 'a12fbd18fe2366571f083c2b6e707c96',
        'body': _length_body,
    },
    'error': None,
}

In [6]:
from openai import NOT_GIVEN, ContentFilterFinishReasonError, LengthFinishReasonError
# NOTE(review): private SDK module; its location/signature may change between openai releases.
from openai.lib._parsing._completions import (
    parse_chat_completion,
)
from openai.types.chat import ChatCompletion


chat_completion = ChatCompletion(**response_with_length_finish_reason['response']['body'])

# Finish reasons whose choice content is incomplete or withheld.
UNUSABLE_FINISH_REASONS = ('length', 'content_filter')

for choice in chat_completion.choices:
    if choice.finish_reason in UNUSABLE_FINISH_REASONS:
        # TODO: decide how the batch pipeline should handle truncated / filtered choices
        print(choice.finish_reason)

# parse_chat_completion raises LengthFinishReasonError for a 'length' choice (and, per the
# openai SDK, ContentFilterFinishReasonError for 'content_filter'). Catch and report it so
# this demonstration cell does not abort Restart & Run All.
try:
    target = parse_chat_completion(
        response_format=NOT_GIVEN,
        input_tools=NOT_GIVEN,
        chat_completion=chat_completion,
    )
except (LengthFinishReasonError, ContentFilterFinishReasonError) as exc:
    target = None
    print(f'{type(exc).__name__}: {exc}')

length


LengthFinishReasonError: Could not parse response content as the length limit was reached - CompletionUsage(completion_tokens=1024, prompt_tokens=285, total_tokens=1309, completion_tokens_details=CompletionTokensDetails(accepted_prediction_tokens=0, audio_tokens=0, reasoning_tokens=0, rejected_prediction_tokens=0), prompt_tokens_details=PromptTokensDetails(audio_tokens=0, cached_tokens=0))