In [1]:
%load_ext autoreload
%autoreload 2

import os
import sys

import nest_asyncio


# Put the parent directory at the front of the module search path.
# NOTE(review): presumably this is what lets the `math_rag` imports in later
# cells resolve when the notebook lives in a subdirectory — confirm.
sys.path.insert(0, os.path.abspath('..'))
# Patch asyncio so an already-running event loop can be re-entered.
# NOTE(review): presumably required by the async services used below when they
# run inside the notebook's own event loop — confirm.
nest_asyncio.apply()

In [2]:
import logging


# Root logger: INFO and above, formatted as "<timestamp> - <level> - <message>".
LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(format=LOG_FORMAT, level=logging.INFO)

In [3]:
from math_rag.infrastructure.containers import InfrastructureContainer


RESET = False

infrastructure_container = InfrastructureContainer()
infrastructure_container.init_resources()

math_article_seeder = infrastructure_container.math_article_seeder()
math_expression_seeder = infrastructure_container.math_expression_seeder()
math_expression_classification_repository = (
    infrastructure_container.math_expression_classification_seeder()
)
math_article_seeder.seed(reset=RESET)
await math_expression_seeder.seed(reset=RESET)
await math_expression_classification_repository.seed(reset=RESET)

math_article_repository = infrastructure_container.math_article_repository()
math_expression_repository = infrastructure_container.math_expression_repository()
math_expression_classification_repository = (
    infrastructure_container.math_expression_classification_repository()
)
google_file_repository = infrastructure_container.google_file_repository()

kc_assistant = infrastructure_container.kc_assistant()
katex_validator_service = infrastructure_container.katex_validator_service()
latex_parser_service = infrastructure_container.latex_parser_service()
latex_visitor_service = infrastructure_container.latex_visitor_service()
arxiv_searcher_service = infrastructure_container.arxiv_searcher_service()

2025-03-13 15:13:13,817 - INFO - file_cache is only supported with oauth2client<4.0.0


In [9]:
from IPython.display import Math, display
from pylatexenc.latexwalker import LatexMathNode

### Download

In [None]:
from math_rag.infrastructure.services.arxiv import MathCategory
from math_rag.infrastructure.utils import GzipExtractorUtil


results = [
    result
    for category in MathCategory
    for result in arxiv_searcher_service.search(category, 4)
]
files: dict[str, bytes] = {}

for result in results:
    arxiv_id = result.entry_id.split('/')[-1]
    src = await arxiv_searcher_service.get_src(arxiv_id)
    # NOTE: we dont need pdfs at the moment
    # pdf = await arxiv_searcher_service.get_pdf(arxiv_id)

    if src is None:
        continue

    src_name, src_bytes = src

    if not src_name or src_name.endswith('.pdf'):
        continue

    if src_name.endswith('.tar.gz'):
        extracted_files = GzipExtractorUtil.extract_tar_gz(src_bytes)
        files.update({f'{arxiv_id}/{k}': v for k, v in extracted_files.items()})

    elif src_name.endswith('.gz'):
        extracted_bytes = GzipExtractorUtil.extract_gz(src_bytes)
        files[f'{arxiv_id}.tex'] = extracted_bytes

    else:
        raise ValueError(f'Unexpected file extension {src_name}')

### Load

In [None]:
# NOTE: deprecated
# from zipfile import ZipFile
# folder_name, name = 'articles', 'articles_v1.zip'

# file_id = google_file_repository.get_file_id(name, folder_name)
# assert file_id is not None

# file_bytes = google_file_repository.get_file_by_id(file_id)

# with ZipFile(file_bytes, 'r') as zip_file:
#     files = {
#         name: zip_file.read(name)
#         for name in zip_file.namelist()
#         if not name.endswith('/')
#     }

In [6]:
from math_rag.core.models import MathArticle

In [11]:
# Wrap every downloaded file in a MathArticle and insert them in one call.
math_articles = [
    MathArticle(name=file_name, bytes=file_bytes)
    for file_name, file_bytes in files.items()
]
math_article_repository.insert_many(math_articles)

In [12]:
# Print the first ten stored article names, one per line.
preview_names = math_article_repository.list_names()[:10]

for article_name in preview_names:
    print(article_name)

2502.15966v2.tex
2502.19681v1.tex
2502.20642v1.tex
2502.21222v1/THK_80_Kepler_laws_-_revised.tex
2502.21222v1/fig1.pdf
2502.21222v1/fig2.pdf
2502.21222v1/fig3.pdf
2503.01541v1/jams-l.cls
2503.01541v1/software-reviewing.bbl
2503.01541v1/software-reviewing.tex


### Parse

In [None]:
from copy import deepcopy

from math_rag.infrastructure.utils import FileReaderUtil


class LatexMathNodePlus(LatexMathNode):
    """A `LatexMathNode` that carries two extra string attributes.

    The class only declares the annotations; `append_math_node` assigns both
    attributes on a copy of the node it receives.
    """

    # Source text of the node, as returned by `latex_verbatim()`.
    latex: str
    # Text derived from `latex` by `append_math_node`; this is what gets sent
    # to the KaTeX validator.
    katex: str


# Only '.tex' sources are parsed; other stored files are ignored.
file_names = [
    article_name
    for article_name in math_article_repository.list_names()
    if article_name.endswith('.tex')
]

# Filled by `append_math_node` while the parsed documents are visited.
math_nodes: list[LatexMathNodePlus] = []


def _strip_math_delimiters(latex: str) -> str:
    """Return `latex` without its surrounding math-mode delimiters.

    Exactly one matching pair is removed: `$$...$$`, `$...$`, `\\[...\\]` or
    `\\(...\\)`. Removing a single pair rather than every leading and trailing
    `$` keeps an escaped dollar at the end of the content intact:
    `$a\\$$` becomes `a\\$`, not `a\\`.

    Text that matches none of the known pairs falls back to stripping `$`
    characters from both ends.
    """
    # '$$' has to be tried before '$', otherwise display math loses only one '$'.
    delimiter_pairs = (('$$', '$$'), ('$', '$'), (r'\[', r'\]'), (r'\(', r'\)'))

    for opening, closing in delimiter_pairs:
        long_enough = len(latex) >= len(opening) + len(closing)

        if long_enough and latex.startswith(opening) and latex.endswith(closing):
            return latex[len(opening) : len(latex) - len(closing)]

    return latex.strip('$')


def append_math_node(math_node: LatexMathNode):
    """Record `math_node` in the module-level `math_nodes` list.

    Nodes whose verbatim LaTeX contains 'tikz' or is 1000 characters or longer
    are ignored. Accepted nodes are deep-copied, re-typed as
    `LatexMathNodePlus`, and given two attributes:

    - `latex`: the verbatim source, delimiters included;
    - `katex`: the same text with its math delimiters removed.

    Args:
        math_node: the math node to record.
    """
    latex = str(math_node.latex_verbatim())

    if 'tikz' in latex or len(latex) >= 1000:
        return

    # Copy before re-typing so the caller's node is left untouched.
    math_node_plus: LatexMathNodePlus = deepcopy(math_node)
    math_node_plus.__class__ = LatexMathNodePlus
    math_node_plus.latex = latex
    math_node_plus.katex = _strip_math_delimiters(latex)
    math_nodes.append(math_node_plus)


# Read, parse and visit each '.tex' article. The visitor is given
# `append_math_node` as the callback registered for `LatexMathNode`.
for name in file_names:
    math_article = math_article_repository.find_by_name(name)
    nodes = latex_parser_service.parse(FileReaderUtil.read(math_article.bytes))

    latex_visitor_service.visit(nodes, {LatexMathNode: append_math_node})

In [5]:
from math_rag.application.models.assistants import KCAssistantInput
from math_rag.core.models import MathExpression

In [None]:
katexes = [math_node.katex for math_node in math_nodes]
results = await katex_validator_service.batch_validate_many(katexes, 1000)
math_node_validation_result = list(zip(math_nodes, results))

In [None]:
math_expressions = [
    MathExpression(
        latex=math_node.latex,
        katex=math_node.katex if result.valid else None,
        position=math_node.pos,
        is_inline=math_node.displaytype == 'inline',
    )
    for math_node, result in math_node_validation_result
]

await math_expression_repository.batch_insert_many(math_expressions, batch_size=100)

In [None]:
inputs = [
    KCAssistantInput(katex=math_node.katex, error=result.error)
    for math_node, result in math_node_validation_result
    if not result.valid
]
batch_id = await kc_assistant.batch_assist_init(inputs)

2025-03-08 23:00:59,929 - INFO - HTTP Request: POST https://api.openai.com/v1/files "HTTP/1.1 200 OK"
2025-03-08 23:01:00,845 - INFO - HTTP Request: POST https://api.openai.com/v1/batches "HTTP/1.1 200 OK"
2025-03-08 23:01:00,855 - INFO - Batch batch_67ccbe1c87448190b76b85dd9ac2e151 created with status validating


In [None]:
# invalid_math_expression_ids = [
#     math_expression.id
#     for math_expression in math_expressions
#     if math_expression.katex == None
# ]

In [None]:
batch_id = 'batch_67ccbe1c87448190b76b85dd9ac2e151'
outputs = await kc_assistant.batch_assist_result(batch_id)
print(outputs is not None)
invalid_total = 0

if outputs:
    katexes = [output.katex for output in outputs]
    results = await katex_validator_service.batch_validate_many(katexes, 1000)
    # TODO update
    # TODO retry

    for x in results:
        if not x.valid:
            invalid_total += 1

invalid_total

# items_for_update = [
#     (id, katex)
#     for id, katex, result in zip(invalid_math_expression_ids, katexes, results)
#     if result.valid
# ]
# math_expression_repository.batch_update_katex(items_for_update)

2025-03-09 11:36:00,581 - INFO - HTTP Request: GET https://api.openai.com/v1/batches/batch_67ccbe1c87448190b76b85dd9ac2e151 "HTTP/1.1 200 OK"
2025-03-09 11:36:00,585 - INFO - Batch batch_67ccbe1c87448190b76b85dd9ac2e151 status completed
Batch batch_67ccbe1c87448190b76b85dd9ac2e151 requests - completed: 27224, failed: 0, total: 27224
2025-03-09 11:36:01,197 - INFO - HTTP Request: GET https://api.openai.com/v1/files/file-XnHttiJAHPF8S5CNT1BoqR/content "HTTP/1.1 200 OK"


ValidationError: 1 validation error for KCAssistantOutput
  Invalid JSON: expected value at line 1 column 1 [type=json_invalid, input_value='\\frac{1}{(\\sin(\\frac{...\frac{(2k-1)\\pi}{2^n})', input_type=str]
    For further information visit https://errors.pydantic.dev/2.10/v/json_invalid

In [13]:
# Render the stored KaTeX of the first 100 expressions.
# `katex` is None for expressions that failed validation (set where
# `math_expressions` is built). `Math` has no text to render for None, so
# those entries are skipped.
for math_expression in math_expressions[:100]:
    if math_expression.katex is None:
        continue

    math_display_object = Math(math_expression.katex)

    display(math_display_object)

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

### Display

In [None]:
# Render the verbatim LaTeX of the first 100 collected math nodes.
# The loop index was never used, so the `enumerate` wrapper is dropped.
for latex_math_node in math_nodes[:100]:
    latex = latex_math_node.latex_verbatim()
    math_display_object = Math(latex)

    display(math_display_object)

### Correct

In [13]:
# Hand-picked sample for the single-expression correction cell below:
# a KaTeX string containing the undefined control sequence `\w`, ...
incorrect_katex = r'd\omega = \theta \w \omega'
# ... and the parse-error message reported for it.
error = r'KaTeX parse error: Undefined control sequence: \w at position 18: …omega = \theta \̲w̲ ̲\omega'

In [14]:
from math_rag.application.models.assistants import KCAssistantInput


input = KCAssistantInput(katex=incorrect_katex, error=error)
output = await kc_assistant.assist(input)
corrected_katex = output.katex
print(corrected_katex)
display(Math(corrected_katex))

2025-03-13 15:35:12,960 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"


d\omega = \theta \omega


<IPython.core.display.Math object>