In [1]:
%load_ext autoreload
%autoreload 2

import os
import sys

import nest_asyncio


sys.path.insert(0, os.path.abspath('..'))
nest_asyncio.apply()

In [2]:
import logging


LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'

# Root logger at INFO so service logs (HTTP requests, batch status) appear in cell output.
logging.basicConfig(format=LOG_FORMAT, level=logging.INFO)

In [3]:
from math_rag.infrastructure.containers import InfrastructureContainer


RESET = False

infrastructure_container = InfrastructureContainer()
infrastructure_container.init_resources()

math_article_seeder = infrastructure_container.math_article_seeder()
math_expression_seeder = infrastructure_container.math_expression_seeder()
math_expression_classification_repository = (
    infrastructure_container.math_expression_classification_seeder()
)
math_article_seeder.seed(reset=RESET)
await math_expression_seeder.seed(reset=RESET)
await math_expression_classification_repository.seed(reset=RESET)

math_article_repository = infrastructure_container.math_article_repository()
math_expression_repository = infrastructure_container.math_expression_repository()
math_expression_classification_repository = (
    infrastructure_container.math_expression_classification_repository()
)
google_file_repository = infrastructure_container.google_file_repository()

katex_corrector_assistant = infrastructure_container.katex_corrector_assistant()
katex_client = infrastructure_container.katex_client()
latex_parser_service = infrastructure_container.latex_parser_service()
latex_visitor_service = infrastructure_container.latex_visitor_service()
arxiv_client = infrastructure_container.arxiv_client()

2025-05-11 09:37:02,848 - INFO - file_cache is only supported with oauth2client<4.0.0


In [None]:
from IPython.display import Math, display

### Load

In [None]:
from math_rag.application.enums.arxiv import MathCategory


math_article_loader_service = infrastructure_container.math_article_loader_service()
await math_article_loader_service.load(MathCategory, 100)

### Parse

In [None]:
from math_rag.application.extensions import LatexMathRichNode


math_article_parser_service = infrastructure_container.math_article_parser_service()

# Only stored files ending in .tex are parsed.
file_names = [
    article_name
    for article_name in math_article_repository.list_names()
    if article_name.endswith('.tex')
]

# Concatenate the math nodes of every article, in file order, into one list.
math_nodes: list[LatexMathRichNode] = []
for name in file_names:
    article = math_article_repository.find_by_name(name)
    math_nodes += math_article_parser_service.parse(article)

In [None]:
from math_rag.application.models.assistants import KatexCorrectorAssistantInput
from math_rag.core.models import MathExpression

In [None]:
katexes = [math_node.katex for math_node in math_nodes]
results = await katex_client.batch_validate_many(katexes, batch_size=1000)
math_node_validation_result = list(zip(math_nodes, results))

In [None]:
math_expressions = [
    MathExpression(
        latex=math_node.latex,
        katex=math_node.katex if result.valid else None,
        position=math_node.pos,
        is_inline=math_node.displaytype == 'inline',
    )
    for math_node, result in math_node_validation_result
]

await math_expression_repository.batch_insert_many(math_expressions, batch_size=100)

In [None]:
inputs = [
    KatexCorrectorAssistantInput(katex=math_node.katex, error=result.error)
    for math_node, result in math_node_validation_result
    if not result.valid
]
batch_id = await katex_corrector_assistant.batch_assist_init(inputs)

2025-03-08 23:00:59,929 - INFO - HTTP Request: POST https://api.openai.com/v1/files "HTTP/1.1 200 OK"
2025-03-08 23:01:00,845 - INFO - HTTP Request: POST https://api.openai.com/v1/batches "HTTP/1.1 200 OK"
2025-03-08 23:01:00,855 - INFO - Batch batch_67ccbe1c87448190b76b85dd9ac2e151 created with status validating


In [None]:
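# Ids of expressions whose KaTeX failed validation (stored with katex=None);
# needed by the commented-out repository update in the batch-result cell.
# invalid_math_expression_ids = [
#     math_expression.id
#     for math_expression in math_expressions
#     if math_expression.katex is None
# ]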

In [None]:
batch_id = 'batch_67ccbe1c87448190b76b85dd9ac2e151'
outputs = await katex_corrector_assistant.batch_assist_result(batch_id)
print(outputs is not None)
invalid_total = 0

if outputs:
    katexes = [output.katex for output in outputs]
    results = await katex_client.batch_validate_many(katexes, batch_size=1000)
    # TODO update
    # TODO retry

    for x in results:
        if not x.valid:
            invalid_total += 1

invalid_total

# items_for_update = [
#     (id, katex)
#     for id, katex, result in zip(invalid_math_expression_ids, katexes, results)
#     if result.valid
# ]
# math_expression_repository.batch_update_katex(items_for_update)

In [None]:
# Render the stored KaTeX of the first 100 expressions.
for math_expression in math_expressions[:100]:
    # Expressions that failed validation are stored with katex=None, which
    # Math() cannot render; skip them.
    if math_expression.katex is None:
        continue
    display(Math(math_expression.katex))

### Display

In [None]:
# Render the first 100 parsed nodes from their verbatim LaTeX source.
for latex_math_node in math_nodes[:100]:
    display(Math(latex_math_node.latex_verbatim()))

### Correct

In [7]:
# Sample broken KaTeX (`\w` is not a defined control sequence) and the KaTeX
# parse error reported for it; used to build a single corrector input below.
incorrect_katex = r'd\omega = \theta \w \omega'
error = r'KaTeX parse error: Undefined control sequence: \w at position 18: …omega = \theta \̲w̲ ̲\omega'

In [None]:
from math_rag.application.models.assistants import KatexCorrectorAssistantInput


# NOTE(review): `input` shadows the builtin; kept because later cells use this name.
input = KatexCorrectorAssistantInput(error=error, katex=incorrect_katex)

In [None]:
output = await katex_corrector_assistant.assist(input)
corrected_katex = output.katex
print(corrected_katex)
display(Math(corrected_katex))

2025-03-14 12:43:02,263 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"


d\omega = \theta \omega


<IPython.core.display.Math object>

In [None]:
import copy


inputs = [copy.deepcopy(input) for _ in range(5000)]
outputs = await katex_corrector_assistant.concurrent_assist(inputs)

In [None]:
batch_id = await katex_corrector_assistant.batch_assist_init([input])

2025-03-14 11:50:03,775 - INFO - HTTP Request: POST https://api.openai.com/v1/files "HTTP/1.1 200 OK"
2025-03-14 11:50:04,129 - INFO - HTTP Request: POST https://api.openai.com/v1/batches "HTTP/1.1 200 OK"
2025-03-14 11:50:04,133 - INFO - Batch batch_67d409dc05b88190a834aa56e48ef515 created with status validating


In [None]:
batch_id = 'batch_67d409dc05b88190a834aa56e48ef515'
outputs = await katex_corrector_assistant.batch_assist_result(batch_id)

2025-03-14 12:39:38,088 - INFO - HTTP Request: GET https://api.openai.com/v1/batches/batch_67d409dc05b88190a834aa56e48ef515 "HTTP/1.1 200 OK"
2025-03-14 12:39:38,089 - INFO - Batch batch_67d409dc05b88190a834aa56e48ef515 status completed
Batch batch_67d409dc05b88190a834aa56e48ef515 requests - completed: 1, failed: 0, total: 1
2025-03-14 12:39:38,378 - INFO - HTTP Request: GET https://api.openai.com/v1/files/file-4w9g482KDizAEntZJMKm9j/content "HTTP/1.1 200 OK"
2025-03-14 12:39:38,764 - INFO - HTTP Request: GET https://api.openai.com/v1/files/file-44LDSWLLWKGGWgkpvxRdFY/content "HTTP/1.1 200 OK"
2025-03-14 12:39:39,143 - INFO - HTTP Request: DELETE https://api.openai.com/v1/files/file-4w9g482KDizAEntZJMKm9j "HTTP/1.1 200 OK"
2025-03-14 12:39:39,538 - INFO - HTTP Request: DELETE https://api.openai.com/v1/files/file-44LDSWLLWKGGWgkpvxRdFY "HTTP/1.1 200 OK"


### Dataset

In [None]:
import os

from datasets import Dataset, DatasetDict, DatasetInfo, Features, Value


SEED = 42
HUB_REPO_ID = 'your-username/your-dataset-name'

citation = """\
@misc{math-rag-dataset,
  author = {Your Name},
  title = {Demo Dataset},
  year = {2025},
  url = {https://huggingface.co/datasets/your-username/your-dataset-name}
}
"""

info = DatasetInfo(
    description='A small demo dataset with text and binary labels.',
    version='1.0.0',
    features=Features({'text': Value('string'), 'label': Value('int64')}),
    license='mit',
    citation=citation,
)

data = {
    'text': ['Hello, world!', 'Hugging Face is awesome!', "Let's upload this dataset."],
    'label': [0, 1, 0],
}

# Attach `info` to the Dataset itself so every split carries it. A DatasetDict
# is a plain dict of splits; assigning an ad-hoc `.info` attribute to it is not
# read by push_to_hub.
dataset = Dataset.from_dict(data, info=info)
dataset = dataset.shuffle(seed=SEED)

# Hold out 10% for test, then 10% of the remainder for validation.
split_test = dataset.train_test_split(test_size=0.1, seed=SEED)
train_validation_dataset = split_test['train']
test_dataset = split_test['test']

split_validation = train_validation_dataset.train_test_split(test_size=0.1, seed=SEED)
train_dataset = split_validation['train']
validation_dataset = split_validation['test']

dataset_dict = DatasetDict(
    {'train': train_dataset, 'validation': validation_dataset, 'test': test_dataset}
)
# NOTE(review): depending on the `datasets` version, push_to_hub may serialize
# only part of DatasetInfo (e.g. features/splits) into the Hub dataset card;
# confirm description/license/citation actually appear there.
dataset_dict.push_to_hub(
    HUB_REPO_ID,
    private=True,
    # Never hardcode tokens; when HF_TOKEN is unset (None), huggingface_hub
    # falls back to the locally saved login token.
    token=os.environ.get('HF_TOKEN'),
)