In [1]:
%load_ext autoreload
%autoreload 2

import os
import sys

import nest_asyncio


# Put the parent directory first on the import path (presumably so the
# `math_rag` package imported by later cells resolves from the repo root).
parent_dir = os.path.abspath('..')
sys.path.insert(0, parent_dir)

# Patch asyncio so `await` works inside the already-running notebook event loop.
nest_asyncio.apply()

In [2]:
import logging


# Timestamp, severity, then the message itself.
LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'

logging.basicConfig(format=LOG_FORMAT, level=logging.INFO)

In [3]:
from math_rag.infrastructure.containers import InfrastructureContainer


RESET = False

infrastructure_container = InfrastructureContainer()
infrastructure_container.init_resources()

math_article_seeder = infrastructure_container.math_article_seeder()
math_expression_seeder = infrastructure_container.math_expression_seeder()
math_expression_classification_repository = (
    infrastructure_container.math_expression_classification_seeder()
)
math_article_seeder.seed(reset=RESET)
await math_expression_seeder.seed(reset=RESET)
await math_expression_classification_repository.seed(reset=RESET)

math_article_repository = infrastructure_container.math_article_repository()
math_expression_repository = infrastructure_container.math_expression_repository()
math_expression_classification_repository = (
    infrastructure_container.math_expression_classification_repository()
)
google_file_repository = infrastructure_container.google_file_repository()

katex_corrector_assistant = infrastructure_container.katex_corrector_assistant()
katex_client = infrastructure_container.katex_client()
latex_parser_service = infrastructure_container.latex_parser_service()
latex_visitor_service = infrastructure_container.latex_visitor_service()
arxiv_client = infrastructure_container.arxiv_client()

2025-05-14 19:52:59,204 - INFO - PyTorch version 2.6.0 available.
2025-05-14 19:52:59,536 - INFO - file_cache is only supported with oauth2client<4.0.0


In [None]:
from IPython.display import Math, display

### Load

In [None]:
from math_rag.application.enums.arxiv import MathCategory


# Resolve the article loader from the infrastructure container.
math_article_loader_service = infrastructure_container.math_article_loader_service()
# NOTE(review): the enum class itself and the literal 100 are passed
# positionally — presumably "all categories" and a per-run limit; confirm
# against the `load` signature.
await math_article_loader_service.load(MathCategory, 100)

### Parse

In [None]:
from pylatexenc.latexwalker import LatexMathNode


math_article_parser_service = infrastructure_container.math_article_parser_service()

# Only names ending in `.tex` are parsed; every other stored name is skipped.
file_names = [
    stored_name
    for stored_name in math_article_repository.list_names()
    if stored_name.endswith('.tex')
]

# Accumulates the math nodes of every parsed article, in file order.
math_nodes: list[LatexMathNode] = []

for name in file_names:
    math_article = math_article_repository.find_by_name(name)
    math_nodes += math_article_parser_service.parse(math_article)

In [None]:
from math_rag.application.models.assistants import KatexCorrectorAssistantInput
from math_rag.core.models import MathExpression

In [None]:
katexes = [str(math_node.latex_verbatim()).strip('$') for math_node in math_nodes]
results = await katex_client.batch_validate_many(katexes, batch_size=1000)
math_node_validation_result = list(zip(math_nodes, results))

In [None]:
math_expressions = [
    MathExpression(
        latex=str(math_node.latex_verbatim()),
        katex=str(math_node.latex_verbatim()).strip('$') if result.valid else None,
        position=math_node.pos,
        is_inline=math_node.displaytype == 'inline',
    )
    for math_node, result in math_node_validation_result
]

await math_expression_repository.batch_insert_many(math_expressions, batch_size=100)

In [None]:
inputs = [
    KatexCorrectorAssistantInput(
        katex=str(math_node.latex_verbatim()).strip('$'), error=result.error
    )
    for math_node, result in math_node_validation_result
    if not result.valid
]
batch_id = await katex_corrector_assistant.batch_assist_init(inputs)

2025-03-08 23:00:59,929 - INFO - HTTP Request: POST https://api.openai.com/v1/files "HTTP/1.1 200 OK"
2025-03-08 23:01:00,845 - INFO - HTTP Request: POST https://api.openai.com/v1/batches "HTTP/1.1 200 OK"
2025-03-08 23:01:00,855 - INFO - Batch batch_67ccbe1c87448190b76b85dd9ac2e151 created with status validating


In [None]:
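# Disabled draft: ids of the expressions whose KaTeX failed validation.
# The commented-out update at the end of the next cell refers to this list.
# invalid_math_expression_ids = [
#     math_expression.id
#     for math_expression in math_expressions
#     if math_expression.katex is None
# ]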

In [None]:
outputs = await katex_corrector_assistant.batch_assist(inputs)
print(outputs is not None)
invalid_total = 0

if outputs:
    katexes = [output.katex for output in outputs]
    results = await katex_client.batch_validate_many(katexes, batch_size=1000)
    # TODO update
    # TODO retry

    for x in results:
        if not x.valid:
            invalid_total += 1

invalid_total

# items_for_update = [
#     (id, katex)
#     for id, katex, result in zip(invalid_math_expression_ids, katexes, results)
#     if result.valid
# ]
# math_expression_repository.batch_update_katex(items_for_update)

In [None]:
# Render the valid expressions among the first 100.
for math_expression in math_expressions[:100]:
    # `katex` is None for expressions that failed validation; `Math` needs a
    # string to render, so those entries are skipped.
    if math_expression.katex is None:
        continue

    display(Math(math_expression.katex))

### Display

In [None]:
# Render the verbatim source of the first 100 parsed math nodes.
# (The loop index was never used, so `enumerate` is dropped.)
for latex_math_node in math_nodes[:100]:
    latex = latex_math_node.latex_verbatim()
    math_display_object = Math(latex)

    display(math_display_object)

### Correct

In [7]:
# Hand-written example for the corrector: `\w` is not a defined control sequence.
incorrect_katex = r'd\omega = \theta \w \omega'
# The matching KaTeX error text, kept verbatim (including its underline markers).
error = r'KaTeX parse error: Undefined control sequence: \w at position 18: …omega = \theta \̲w̲ ̲\omega'

In [None]:
from math_rag.application.models.assistants import KatexCorrectorAssistantInput


# NOTE(review): this name shadows the builtin `input` for the rest of the
# kernel session. It is kept because the following cells read `input`;
# rename it in all of them together.
input = KatexCorrectorAssistantInput(katex=incorrect_katex, error=error)

In [None]:
# Ask the corrector assistant to repair the single hand-written example.
output = await katex_corrector_assistant.assist(input)
corrected_katex = output.katex
# Show the corrected source as plain text, then its rendered form.
print(corrected_katex)
display(Math(corrected_katex))

2025-03-14 12:43:02,263 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"


d\omega = \theta \omega


<IPython.core.display.Math object>

In [None]:
import copy


inputs = [copy.deepcopy(input) for _ in range(5000)]
outputs = await katex_corrector_assistant.concurrent_assist(inputs)

### Dataset

#### Upload

In [4]:
from math_rag.application.models.datasets import (
    MathExpressionDataset,
    MathExpressionSample,
)
from math_rag.core.enums import MathExpressionLabelEnum


# Size of the synthetic dataset built below.
SAMPLE_COUNT = 10

# Synthetic samples `x + 0 = 5` … `x + 9 = 5`, all labelled as equalities.
samples = [
    MathExpressionSample(latex=f'x + {i} = 5', label=MathExpressionLabelEnum.EQUALITY)
    for i in range(SAMPLE_COUNT)
]
dataset = MathExpressionDataset(samples)

In [5]:
from math_rag.application.assistants.prompts import MATH_EXPRESSION_LABELER_PROMPT
from math_rag.application.models.datasets import (
    DatasetMetadataFile,
    DatasetSplitSettings,
)


# The labeler prompt, serialized to indented JSON and encoded as UTF-8 bytes,
# becomes the `prompt.json` metadata file.
json_str = MATH_EXPRESSION_LABELER_PROMPT.model_dump_json(indent=4)
content = json_str.encode('utf-8')
metadata_file = DatasetMetadataFile(name='prompt.json', content=content)

# 80 / 10 / 10 train / validate / test ratios with a fixed seed.
settings = DatasetSplitSettings(
    train_ratio=0.8,
    validate_ratio=0.1,
    test_ratio=0.1,
    seed=42,
)

In [None]:
# Resolve the publisher from the infrastructure container and publish the toy
# dataset together with its sample type, split settings and metadata file.
# NOTE(review): arguments are positional; their order follows the `publish`
# signature, which is not visible here — confirm before reordering.
dataset_publisher_service = infrastructure_container.dataset_publisher_service()
dataset_publisher_service.publish(
    dataset, MathExpressionSample, settings, metadata_file
)

#### Download

In [7]:
from datasets import load_dataset
from datasets.download import DownloadConfig
from decouple import config


# Hugging Face account name and access token, read through python-decouple.
HF_USERNAME = config('HF_USERNAME', default=None)
HF_TOKEN = config('HF_TOKEN', default=None)

# Without a username the repo path below would silently become
# 'None/mathexpressiondataset'; fail early with an actionable message instead.
if not HF_USERNAME:
    raise RuntimeError(
        'HF_USERNAME is not set; cannot resolve the dataset repository'
    )

download_config = DownloadConfig(
    max_retries=3,
    disable_tqdm=True,
)

# `trust_remote_code` is deliberately not passed: the repository is served as
# parquet files (see the download output below), so no remote loading script
# has to be executed on this machine.
dataset_dict = load_dataset(
    path=f'{HF_USERNAME}/mathexpressiondataset',
    split=None,  # no single split selected; 'train' is looked up by name later
    cache_dir=None,  # TODO: bind to apptainer
    download_config=download_config,
    token=HF_TOKEN,
)

README.md:   0%|          | 0.00/672 [00:00<?, ?B/s]

train-00000-of-00001.parquet:   0%|          | 0.00/2.04k [00:00<?, ?B/s]

validate-00000-of-00001.parquet:   0%|          | 0.00/1.71k [00:00<?, ?B/s]

test-00000-of-00001.parquet:   0%|          | 0.00/1.71k [00:00<?, ?B/s]

Generating train split:   0%|          | 0/8 [00:00<?, ? examples/s]

Generating validate split:   0%|          | 0/1 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/1 [00:00<?, ? examples/s]

In [8]:
from typing import cast

from datasets import ClassLabel


# The `label` feature of the train split; `cast` only informs the type checker.
train_split = dataset_dict['train']
class_label = cast(ClassLabel, train_split.features['label'])

# Label names in the order stored on the feature.
class_label.names

['equality', 'inequality', 'constant', 'variable', 'other']

In [None]:
import json

from pathlib import Path

from decouple import config
from huggingface_hub import hf_hub_download


# Hugging Face account name and access token, read through python-decouple.
HF_USERNAME = config('HF_USERNAME', default=None)
HF_TOKEN = config('HF_TOKEN', default=None)

# Without a username the repo id below would silently become
# 'None/mathexpressiondataset'; fail early with an actionable message instead.
if not HF_USERNAME:
    raise RuntimeError(
        'HF_USERNAME is not set; cannot resolve the dataset repository'
    )

# Same repository the dataset is loaded from in the download cell above.
# NOTE(review): presumes the metadata file was published alongside the
# dataset in that repository — confirm against the publisher service.
repo_id = f'{HF_USERNAME}/mathexpressiondataset'
path = hf_hub_download(
    repo_id=repo_id,
    filename=metadata_file.name,
    repo_type='dataset',
    token=HF_TOKEN,
    cache_dir=None,  # TODO bind
)

# The downloaded file is JSON; `json.loads` accepts the raw bytes directly.
content_bytes = Path(path).read_bytes()
content = json.loads(content_bytes)
content