In [1]:
%load_ext autoreload
%autoreload 2

import os
import sys

import nest_asyncio


sys.path.insert(0, os.path.abspath('..'))
nest_asyncio.apply()

In [2]:
import logging


# Configure the root logger: INFO and above, as "<time> - <LEVEL> - <message>".
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

In [3]:
from math_rag.infrastructure.containers import InfrastructureContainer


# Passed as `reset=` to every seeder below; presumably True wipes and re-seeds
# existing data — TODO confirm against the seeder implementations.
RESET = False

# Build the DI container and initialise its resources before requesting any
# providers from it.
infrastructure_container = InfrastructureContainer()
infrastructure_container.init_resources()

math_article_seeder = infrastructure_container.math_article_seeder()
math_expression_seeder = infrastructure_container.math_expression_seeder()
math_expression_prediction_seeder = (
    infrastructure_container.math_expression_prediction_seeder()
)
# NOTE(review): the article seeder is called synchronously while the other two
# are awaited; this matches the synchronous article-repository calls elsewhere
# in the notebook, but confirm `seed()` is not a coroutine here.
math_article_seeder.seed(reset=RESET)
await math_expression_seeder.seed(reset=RESET)
await math_expression_prediction_seeder.seed(reset=RESET)

# Repositories used by the cells below.
math_article_repository = infrastructure_container.math_article_repository()
math_expression_repository = infrastructure_container.math_expression_repository()
math_expression_prediction_repository = (
    infrastructure_container.math_expression_prediction_repository()
)
google_file_repository = infrastructure_container.google_file_repository()

# Services: KaTeX validation/correction, LaTeX parsing/visiting, arXiv search.
katex_correction_assistant = infrastructure_container.katex_correction_assistant()
katex_validator_service = infrastructure_container.katex_validator_service()
latex_parser_service = infrastructure_container.latex_parser_service()
latex_visitor_service = infrastructure_container.latex_visitor_service()
arxiv_searcher_service = infrastructure_container.arxiv_searcher_service()

2025-03-05 15:43:08,452 - INFO - file_cache is only supported with oauth2client<4.0.0


In [4]:
from IPython.display import Math, display
from pylatexenc.latexwalker import LatexMathNode

### Download

In [None]:
from math_rag.infrastructure.services.arxiv import MathCategory
from math_rag.infrastructure.utils import GzipExtractorUtil


results = [
    result
    for category in MathCategory
    for result in arxiv_searcher_service.search(category, 4)
]
files: dict[str, bytes] = {}

for result in results:
    arxiv_id = result.entry_id.split('/')[-1]
    src = await arxiv_searcher_service.get_src(arxiv_id)
    # NOTE: we dont need pdfs at the moment
    # pdf = await arxiv_searcher_service.get_pdf(arxiv_id)

    if src is None:
        continue

    src_name, src_bytes = src

    if not src_name:
        continue

    if src_name.endswith('.tar.gz'):
        extracted_files = GzipExtractorUtil.extract_tar_gz(src_bytes)
        files.update({f'{arxiv_id}/{k}': v for k, v in extracted_files.items()})

    elif src_name.endswith('.gz'):
        extracted_bytes = GzipExtractorUtil.extract_gz(src_bytes)
        files[f'{arxiv_id}.tex'] = extracted_bytes

    else:
        raise ValueError(f'Unexpected file extension {src_name}')

### Load

In [None]:
# NOTE: deprecated
# from zipfile import ZipFile
# folder_name, name = 'articles', 'articles_v1.zip'

# file_id = google_file_repository.get_file_id(name, folder_name)
# assert file_id is not None

# file_bytes = google_file_repository.get_file_by_id(file_id)

# with ZipFile(file_bytes, 'r') as zip_file:
#     files = {
#         name: zip_file.read(name)
#         for name in zip_file.namelist()
#         if not name.endswith('/')
#     }

In [8]:
from math_rag.core.models import MathArticle


# Wrap every downloaded file as a MathArticle and store them in one call.
math_articles = [
    MathArticle(name=file_name, bytes=file_bytes)
    for file_name, file_bytes in files.items()
]
math_article_repository.insert_math_articles(math_articles)

In [10]:
# Quick sanity check: show the first ten stored article names, one per line.
first_article_names = math_article_repository.list_math_article_names()[:10]
for article_name in first_article_names:
    print(article_name)

2502.15966v2.tex
2502.18319v1.tex
2502.19069v1/article_FGT_Hermite_original.tex
2502.19069v1/image1.png
2502.19069v1/image2.png
2502.19681v1.tex
2502.20618v1.tex
2502.20642v1.tex
2502.20887v1.tex
2502.21015v1.tex


### Parse

In [None]:
from math_rag.infrastructure.utils import FileReaderUtil


# Only `.tex` sources are parsed; other stored files (e.g. images) are skipped.
file_names = [
    article_name
    for article_name in math_article_repository.list_math_article_names()
    if article_name.endswith('.tex')
]

# Filled by the `append_math_node` visitor callback.
math_nodes: list[LatexMathNode] = []


def append_math_node(math_node: LatexMathNode):
    """Visitor callback: add ``math_node`` to the global ``math_nodes`` list.

    Nodes whose verbatim LaTeX contains ``tikz`` or is 1000 or more characters
    long are ignored.
    """
    verbatim = math_node.latex_verbatim()

    if 'tikz' in verbatim or len(verbatim) >= 1000:
        return

    math_nodes.append(math_node)


# Parse every LaTeX article and collect its math nodes through `append_math_node`.
# The callback mapping does not depend on the article, so build it only once.
callbacks = {LatexMathNode: append_math_node}

for name in file_names:
    math_article = math_article_repository.get_math_article_by_name(name)
    latex = FileReaderUtil.read(math_article.bytes)
    nodes = latex_parser_service.parse(latex)

    latex_visitor_service.visit(nodes, callbacks)

In [None]:
from math_rag.core.models import MathExpression


math_expressions: list[MathExpression] = []
batch_size = 100

for i in range(0, len(math_nodes), batch_size):
    math_node_batch = math_nodes[i : i + batch_size]
    latex_batch = [str(math_node.latex_verbatim()) for math_node in math_node_batch]
    katex_batch = [latex.strip('$') for latex in latex_batch]
    results = await katex_validator_service.validate_many(katex_batch)

    for math_node, katex, result in zip(math_node_batch, katex_batch, results):
        if not result.valid:
            try:
                katex = await katex_correction_assistant.correct(katex, result.error)
            except Exception as e:
                katex = None
                print(e)

        math_expression = MathExpression(
            latex=latex,
            katex=katex,
            position=math_node.pos,
            is_inline=math_node.displaytype == 'inline',
        )
        math_expressions.append(math_expression)

In [None]:
# Persist every expression built by the validation cell in a single call.
# TODO insert in batches in previous loop
await math_expression_repository.insert_math_expressions(math_expressions)

In [13]:
# Preview the first 100 expressions as rendered math. Expressions whose KaTeX
# correction failed were stored with `katex=None`. `Math(None)` cannot be
# rendered (its LaTeX repr calls `.strip` on the data), so skip them.
for math_expression in math_expressions[:100]:
    if math_expression.katex is None:
        continue

    display(Math(math_expression.katex))

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

<IPython.core.display.Math object>

### Display

In [None]:
# Render the first 100 raw parsed math nodes for a quick visual check. No loop
# index is needed, and the module-level `latex` is deliberately not reassigned.
for latex_math_node in math_nodes[:100]:
    display(Math(latex_math_node.latex_verbatim()))

### Classify

In [None]:
from decouple import config

from math_rag.infrastructure.inference import LLM


# OpenAI endpoint and key, resolved through python-decouple's `config`.
OPENAI_BASE_URL, OPENAI_API_KEY = config('OPENAI_BASE_URL'), config('OPENAI_API_KEY')

In [14]:
# LLM wrapper for the configured endpoint; `get_completion` below uses its
# `.client` attribute directly.
llm = LLM(model='gpt-4o', base_url=OPENAI_BASE_URL, api_key=OPENAI_API_KEY)

In [5]:
# Hand-picked sample expression for prompt experiments (an alternative was
# `math_nodes[13].latex_verbatim()`).
math_expr = r'$\mu(x)=\frac{1}{2^n}$'

In [None]:
prompt = f"""
You are a mathematical expression classifier.
Given a mathematical expression, classify it in one of 4 given classes:
- constant
- variable
- formula
- other

Return a class only!

Mathematical expression:
{math_expr}

Class:
"""

In [7]:
from openai import NOT_GIVEN
from openai.types.chat import ChatCompletion


def get_prompt(math_expr: str) -> str:
    """Build the open-ended prompt asking for a one-word STRUCTURE class of ``math_expr``.

    The text begins with a newline, embeds ``math_expr`` on its own line and
    ends with the ``Class:`` cue followed by a newline.
    """
    lines = [
        '',
        'You are a mathematical expression classifier.',
        (
            'Given a mathematical expression, classify it in a single class '
            'regarding STRUCTURE of the expression.'
        ),
        'Class must be a single word.',
        '',
        'Return a class only!',
        '',
        'Mathematical expression:',
        format(math_expr),  # same conversion an f-string placeholder applies
        '',
        'Class:',
        '',
    ]
    return '\n'.join(lines)


# When True, request a JSON-object response format instead of free text.
use_json = False


async def get_completion(prompt: str) -> ChatCompletion:
    """Send ``prompt`` as one user message to gpt-4o through ``llm.client``.

    Uses temperature 0.0 and requests log-probs with the top 5 alternatives per
    token. The JSON response format is used only when the module-level
    ``use_json`` is True at call time.
    """
    messages = [{'role': 'user', 'content': prompt}]
    response_format = {'type': 'json_object'} if use_json else NOT_GIVEN

    return await llm.client.chat.completions.create(
        model='gpt-4o',
        messages=messages,
        response_format=response_format,
        logprobs=True,
        temperature=0.0,
        top_logprobs=5,
    )

In [8]:
# Fetch the expressions to classify. The meaning of the argument (32) is not
# visible here; it is presumably a count/limit — TODO confirm.
math_expressions_by_category = (
    await math_expression_repository.get_math_expressions_by_category(32)
)

In [None]:
from math_rag.core.models import MathExpressionClassification


# Filled by the classification loop: raw API responses and the predictions
# built from them, appended pairwise so indices line up.
completions: list[ChatCompletion] = []
predictions: list[MathExpressionClassification] = []

from time import sleep


for math_expression in math_expressions_by_category:
    prompt = get_prompt(math_expression.latex)
    completion = await get_completion(prompt)

    prediction = MathExpressionClassification(
        math_expression_id=math_expression.id,
        value=completion.choices[0].message.content,
    )
    completions.append(completion)
    predictions.append(prediction)

    # await math_expression_prediction_repository.insert_math_expression_predictions([prediction])
    sleep(1)

In [None]:
# Persist all collected predictions in a single call.
await math_expression_prediction_repository.insert_math_expression_classifications(
    predictions
)

In [15]:
import json


# `../tmp` may not exist yet (e.g. in a fresh checkout); create it before writing.
os.makedirs('../tmp', exist_ok=True)

# Dump every raw API response, converted to plain dicts, for offline inspection.
with open('../tmp/completions.json', 'w') as json_file:
    json.dump([x.to_dict() for x in completions], json_file, indent=4)

In [None]:
from uuid import UUID


class UUIDEncoder(json.JSONEncoder):
    """JSON encoder that writes ``uuid.UUID`` values as their canonical string."""

    def default(self, obj):
        # Anything that is not a UUID goes to the base class, which raises TypeError.
        if not isinstance(obj, UUID):
            return super().default(obj)

        return str(obj)


# `../tmp` may not exist yet (e.g. in a fresh checkout); create it before writing.
os.makedirs('../tmp', exist_ok=True)

# Dump the predictions; UUID ids are serialized as strings by UUIDEncoder.
with open('../tmp/predictions.json', 'w') as json_file:
    json.dump(
        [x.model_dump() for x in predictions], json_file, indent=4, cls=UUIDEncoder
    )

In [None]:
# Re-run one completion for whatever `prompt` the last executed cell assigned,
# and show the returned class text. `completion` is inspected by the next cell.
completion = await get_completion(prompt)

completion.choices[0].message.content

'Function'

In [None]:
import numpy as np


# For every generated token: print the top candidate tokens with probabilities
# (exp of the log-prob), then the chosen token, its log-prob and probability.
for token_logprob in completion.choices[0].logprobs.content:
    for candidate in token_logprob.top_logprobs:
        print(f'"{candidate.token}": {np.exp(candidate.logprob)}')

    chosen_probability = np.exp(token_logprob.logprob)
    print('------')
    print(token_logprob.token)
    print(token_logprob.logprob)
    print(chosen_probability)

"formula": 0.9999996871837189
" formula": 9.931194312156244e-08
"Formula": 7.734421907141565e-08
"_formula": 9.237449661970594e-09
"form": 2.061153622438558e-09
------
formula
-3.1281633e-07
0.9999996871837189


In [None]:
# TODO
# - write a description for each class
# - decide how the set of classes should be determined
# - check whether class names need to be a single token

### Analyze

In [5]:
from enum import Enum


class MathExpressionCategory(str, Enum):
    """Candidate category labels for math expressions.

    The ``str`` mixin makes each member compare equal to its plain-string value
    (e.g. ``MathExpressionCategory.EQUALITY == 'equality'``).

    NOTE(review): not referenced by any other cell here, and the labels differ
    from the earlier prompt's classes (constant/variable/formula/other).
    """

    EQUALITY = 'equality'
    INEQUALITY = 'inequality'
    CONSTANT = 'constant'
    VARIABLE = 'variable'
    OTHER = 'other'

In [None]:
# Reload stored classifications for analysis (overwrites the in-memory
# `predictions`). The argument 1024 is presumably a limit — TODO confirm.
predictions = (
    await math_expression_prediction_repository.get_math_expression_classifications(
        1024
    )
)

In [11]:
expressions = [
    await math_expression_repository.get_math_expression_by_id(
        prediction.math_expression_id
    )
    for prediction in predictions
]

In [None]:
for expression, prediction in zip(expressions, predictions):
    try:
        result = await katex_validator_service.validate(expression.latex.strip('$'))
        math_display_object = Math(expression.latex)

        print(prediction.value)
        print(expression.latex)
        display(math_display_object)
        print(result.valid, result.error)
        print('-----')

    except Exception as e:
        print(f'skipping {expression.id}')
        print(e)

### Correct

In [37]:
# Sample KaTeX with an undefined control sequence (`\w`), plus a KaTeX parse
# error message for it, used to exercise the correction assistant below.
incorrect_katex = r'd\omega = \theta \w \omega'
error = r'KaTeX parse error: Undefined control sequence: \w at position 18: …omega = \theta \̲w̲ ̲\omega'

In [38]:
# Ask the assistant to repair the KaTeX given its parse error, then show the
# raw corrected string and its rendering.
corrected_katex = await katex_correction_assistant.correct(incorrect_katex, error)
print(corrected_katex)
display(Math(corrected_katex))

2025-03-03 19:50:53,083 - INFO - HTTP Request: POST http://localhost:3000/validate "HTTP/1.1 200 OK"
2025-03-03 19:50:54,043 - INFO - HTTP Request: POST https://api.openai.com/v1/chat/completions "HTTP/1.1 200 OK"
2025-03-03 19:50:54,053 - INFO - HTTP Request: POST http://localhost:3000/validate "HTTP/1.1 200 OK"


d\omega = \theta \omega


<IPython.core.display.Math object>