In [1]:
%load_ext autoreload
%autoreload 2

import os
import sys

import nest_asyncio


sys.path.insert(0, os.path.abspath('..'))
nest_asyncio.apply()

In [2]:
import logging


logging.basicConfig(
    level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s'
)

In [8]:
from math_rag.infrastructure.containers import InfrastructureContainer


infrastructure_container = InfrastructureContainer()
infrastructure_container.init_resources()

math_article_seeder = infrastructure_container.math_article_seeder()
math_expression_seeder = infrastructure_container.math_expression_seeder()
math_article_seeder.seed()
await math_expression_seeder.seed()

math_article_repository = infrastructure_container.math_article_repository()
math_expression_repository = infrastructure_container.math_expression_repository()
google_file_repository = infrastructure_container.google_file_repository()

2025-02-27 20:32:37,413 - INFO - file_cache is only supported with oauth2client<4.0.0


In [4]:
import gzip
import shutil
import tarfile

from io import BytesIO
from pathlib import Path
from zipfile import ZipFile

from pylatexenc.latexwalker import LatexMathNode


ARTICLES_PATH = '../tmp/articles'

### Extract

In [None]:
def get_gzip_original_filename(file_path):
    with open(file_path, 'rb') as f:
        if f.read(2) != b'\x1f\x8b':
            return None
        f.read(1)
        flag = f.read(1)[0]
        f.read(4)
        f.read(1)
        f.read(1)
        orig_name = None
        if flag & 0x08:
            name_bytes = bytearray()
            while True:
                b = f.read(1)
                if not b or b == b'\x00':
                    break
                name_bytes.extend(b)
            try:
                orig_name = name_bytes.decode('utf-8')
            except UnicodeDecodeError:
                orig_name = name_bytes.decode('latin1')
        return orig_name


def extract_gz(file_path, dest_folder):
    orig_name = get_gzip_original_filename(file_path)
    if not orig_name:
        orig_name = os.path.splitext(os.path.basename(file_path))[0]
    dest_path = os.path.join(dest_folder, orig_name)
    with gzip.open(file_path, 'rb') as f_in, open(dest_path, 'wb') as f_out:
        shutil.copyfileobj(f_in, f_out)


def extract_tar_gz(file_path, dest_folder):
    with tarfile.open(file_path, 'r:gz') as tar:
        tar.extractall(path=dest_folder)


def process_subdir(subdir_path):
    files = os.listdir(subdir_path)
    pdf_files = {f for f in files if f.endswith('.pdf')}
    for pdf in pdf_files:
        base_name = pdf[:-4]
        gz_name = f'arXiv-{base_name}.gz'
        tar_gz_name = f'arXiv-{base_name}.tar.gz'
        gz_file = None
        if tar_gz_name in files:
            gz_file = tar_gz_name
        elif gz_name in files:
            gz_file = gz_name
        if gz_file:
            new_dir = os.path.join(subdir_path, base_name)
            os.makedirs(new_dir, exist_ok=True)
            shutil.move(os.path.join(subdir_path, pdf), new_dir)
            shutil.move(os.path.join(subdir_path, gz_file), new_dir)
            new_gz_path = os.path.join(new_dir, gz_file)
            if gz_file.endswith('.tar.gz'):
                extract_tar_gz(new_gz_path, new_dir)
            else:
                extract_gz(new_gz_path, new_dir)


def extract_all():
    for subdir in os.listdir(ARTICLES_PATH):
        subdir_path = os.path.join(ARTICLES_PATH, subdir)
        if os.path.isdir(subdir_path):
            process_subdir(subdir_path)


extract_all()

  tar.extractall(path=dest_folder)


In [9]:
def clean():
    for root, dirs, files in os.walk(ARTICLES_PATH):
        for file in files:
            if file.endswith('.gz'):
                os.remove(os.path.join(root, file))


clean()

### Load

In [5]:
folder_name = 'articles'
file_name = 'articles_v1.zip'  # NOTE

file_id = google_file_repository.get_file_id(file_name, folder_name)
assert file_id is not None

file_bytes = google_file_repository.get_file_by_id(file_id)

In [6]:
with ZipFile(file_bytes, 'r') as zip_file:
    files = {
        name: zip_file.read(name)
        for name in zip_file.namelist()
        if not name.endswith('/') and name.endswith('.tex')
    }

In [None]:
from math_rag.core.models import MathArticle


math_articles = [MathArticle(name=name, bytes=bytes) for name, bytes in files.items()]
math_article_repository.insert_math_articles(math_articles)

In [7]:
for file_name in math_article_repository.list_math_article_names():
    print(file_name)

### Parse

In [None]:
from math_rag.infrastructure.services import LatexService


latex_service = LatexService()

In [None]:
file_names = math_article_repository.list_math_article_names()
file_names = [x for x in file_names if x.endswith('.tex')]

math_nodes: list[LatexMathNode] = []
append_math_node = lambda x: math_nodes.append(x)
callbacks = {LatexMathNode: append_math_node}

for file_name in file_names:
    file_bytes = math_article_repository.get_math_article_by_name(file_name)
    latex = latex_service.read(file_bytes)
    nodes = latex_service.parse(latex)
    latex_service.traverse(nodes, callbacks)

In [None]:
import re

from math_rag.core.enums import MathCategory
from math_rag.core.models import MathExpression


math_expressions: list[MathExpression] = []

for math_node in math_nodes:
    category_slug = re.search(r'articles_v\d+/([^/]+)/', file_name).group(1)
    category = MathCategory.from_str(category_slug)
    math_expression = MathExpression(
        latex=math_node.latex_verbatim(),
        position=math_node.pos,
        is_inline=math_node.displaytype == 'inline',
        math_category=category,
    )
    math_expressions.append(math_expression)

In [None]:
await math_expression_repository.insert_math_expressions(math_expressions)

  await document_repo.create_collection(collection_name)


### Display

In [140]:
def fix_latex(latex_str: str):
    fixed = latex_str.replace('\\[', '$$').replace('\\]', '$$')
    fixed = fixed.replace('\\EE', '\\mathbb{E}')
    fixed = fixed.replace('\\II', '\\mathbb{I}')
    fixed = fixed.replace('\\Var', '\\mathrm{Var}')
    fixed = fixed.replace('\\HH', '\\mathbb{H}')
    fixed = fixed.replace('\\AND', '\\wedge')
    fixed = fixed.replace('\\OR', '\\vee')
    fixed = fixed.replace('\\mathbbm{1}', '\\mathbf{1}')
    fixed = fixed.replace('\\Maj', '\\mathrm{Maj}')
    fixed = fixed.replace('\\sgn', '\\operatorname{sgn}')
    fixed = fixed.replace('\\Tribus', '\\mathrm{Tribus}')
    fixed = fixed.replace('\\linebreak', '\\text{ }')
    fixed = fixed.replace('\\Prob', '\\mathbb{P}')
    fixed = fixed.replace('\\WW', '\\mathcal{W}')

    return fixed

In [None]:
from IPython.display import Math, display


for i, latex_math_node in enumerate(math_nodes[:100]):
    latex = latex_math_node.latex_verbatim()
    latex_fixed = fix_latex(latex)
    math_display_object = Math(latex_fixed)

    # print(i)
    # print(math_display_object._repr_latex_())

    display(math_display_object)

### Classify

In [2]:
from decouple import config

from math_rag.infrastructure.inference.llms import LLM


OPENAI_BASE_URL = config('OPENAI_BASE_URL')
OPENAI_API_KEY = config('OPENAI_API_KEY')

In [3]:
llm = LLM(model='gpt-4o-mini', base_url=OPENAI_BASE_URL, api_key=OPENAI_API_KEY)

In [5]:
# math_expr = math_nodes[13].latex_verbatim()  # 13
math_expr = '$\\mu(x)=\\frac{1}{2^n}$'

In [8]:
from openai import NOT_GIVEN


prompt = f"""
You are a mathematical expression classifier.
Given a mathematical expression, classify it in one of 4 given classes:
- constant
- variable
- formula
- other

Return a class only!

Mathematical expression:
{math_expr}

Class:
"""

prompt = f"""
You are a mathematical expression classifier.
Given a mathematical expression, classify it in a single class regarding STRUCTURE of the expression.
Class must be a single word.

Return a class only!

Mathematical expression:
{math_expr}

Class:
"""

use_json = False
completion = await llm.client.chat.completions.create(
    model=llm.model,
    messages=[{'role': 'user', 'content': prompt}],
    response_format={'type': 'json_object'} if use_json else NOT_GIVEN,
    logprobs=True,
    temperature=0.0,
    top_logprobs=5,
)
print(completion)

ChatCompletion(id='chatcmpl-B5b8A1KGrSNzsq45Ay2kJRr09oukR', choices=[Choice(finish_reason='stop', index=0, logprobs=ChoiceLogprobs(content=[ChatCompletionTokenLogprob(token='Function', bytes=[70, 117, 110, 99, 116, 105, 111, 110], logprob=-0.4679870307445526, top_logprobs=[TopLogprob(token='Function', bytes=[70, 117, 110, 99, 116, 105, 111, 110], logprob=-0.4679870307445526), TopLogprob(token='Equation', bytes=[69, 113, 117, 97, 116, 105, 111, 110], logprob=-1.217987060546875), TopLogprob(token='function', bytes=[102, 117, 110, 99, 116, 105, 111, 110], logprob=-2.592987060546875), TopLogprob(token='equ', bytes=[101, 113, 117], logprob=-6.342987060546875), TopLogprob(token='Formula', bytes=[70, 111, 114, 109, 117, 108, 97], logprob=-7.342987060546875)])], refusal=None), message=ChatCompletionMessage(content='Function', refusal=None, role='assistant', audio=None, function_call=None, tool_calls=None))], created=1740674070, model='gpt-4o-mini-2024-07-18', object='chat.completion', service_

In [9]:
completion.choices[0].message.content

'Function'

In [57]:
import numpy as np


for x in completion.choices[0].logprobs.content:
    for y in x.top_logprobs:
        print(f'"{y.token}": {np.exp(y.logprob)}')

    print('------')
    print(x.token)
    print(x.logprob)
    print(np.exp(x.logprob))

"formula": 0.9999996871837189
" formula": 9.931194312156244e-08
"Formula": 7.734421907141565e-08
"_formula": 9.237449661970594e-09
"form": 2.061153622438558e-09
------
formula
-3.1281633e-07
0.9999996871837189


In [None]:
# TODO
# - description for each class
# - how to determine classes?
# - do names need to take a single token?

In [None]:
from enum import Enum


class MathExpressionCategory(str, Enum):
    CONSTANT = 'constant'
    VARIABLE = 'variable'
    EQUALITY = 'equality'
    INEQUALITY = 'inequality'
    OTHER = 'other'