In [None]:
!pip install datasets sacremoses sentencepiece

In [None]:
import re

class LaTeX_placeholder_swap:
    """
    Swap LaTeX math spans for opaque placeholders and back again.

    ``mask`` replaces every ``$...$`` (inline) or ``$$...$$`` (display) span
    with a numbered token built from ``BASE`` and remembers the original
    text; ``unmask`` puts the remembered spans back.  The instance is
    stateful: tokens keep counting up across ``mask`` calls until ``reset``
    is called, so use one instance (or one ``reset``) per text.
    """
    # Display math is tried first: with the inline alternative alone,
    # ``$$x$$`` would be read as two empty spans (``$$`` and ``$$``) and the
    # body ``x`` would be left unprotected.  DOTALL lets a span cross lines.
    _LATEX_PATTERN = re.compile(r'\$\$.+?\$\$|\$.*?\$', flags=re.DOTALL)

    # Template of a placeholder token; ``{}`` receives the span index.
    BASE = '<_LATEX_{}_>'

    def __init__(self) -> None:
        self.reset()

    def reset(self) -> None:
        """Forget every stored span and restart placeholder numbering at 0."""
        self._counter = 0
        self._exprs = []

    def _replace_with_placeholder(self, match) -> str:
        """Store the matched span (delimiters included) and return its token."""
        key = self.BASE.format(self._counter)
        self._exprs.append(match.group(0))
        self._counter += 1
        return key

    def mask(self, text: str) -> str:
        """Return *text* with each math span replaced by a placeholder token."""
        return self._LATEX_PATTERN.sub(self._replace_with_placeholder, text)

    def unmask(self, text: str) -> str:
        """Return *text* with every known placeholder replaced by its span.

        Stored spans are kept, so ``unmask`` may be called repeatedly; call
        ``reset`` to discard them.
        """
        for index, original in enumerate(self._exprs):
            text = text.replace(self.BASE.format(index), original)
        return text

In [None]:
from datasets import Dataset

class LaTeXify:
    """
    Rewrite plain-text maths in a dataset's ``text`` column as inline LaTeX.

    Calling an instance on a dataset maps ``_process_batch`` over it in
    batches.  For each text the pipeline is:

    1. hide LaTeX that is already present behind placeholders
       (``LaTeX_placeholder_swap``, which must be defined in this file);
    2. rewrite sums, limits, integrals, elementary functions, Greek letter
       names, arithmetic expressions and number/letter products as
       ``$...$`` snippets;
    3. wrap remaining standalone numbers in ``$...$``;
    4. restore the hidden LaTeX.

    Every rewrite is applied only to text that is outside ``$...$``, so a
    later step never edits (or nests dollars inside) math produced by an
    earlier one.  A stray, unpaired ``$`` in the input shifts what counts as
    "outside"; such input is not handled specially.
    """
    # A number with an optional decimal part, shared by the patterns below so
    # that "3.14" is treated as one number instead of "3" "." "14".
    _NUM         = r"\d+(?:\.\d+)?"

    # ``(?<!\\)`` keeps names that are already LaTeX commands (``\sin``,
    # ``\alpha``, ``\lim``) from being wrapped a second time.
    FUNC_EXPR    = re.compile(r"(?<!\\)\b(exp|sin|cos|tan|ln|log)(?:\s+|\()(?P<arg>[A-Za-z0-9]+)\)?", re.IGNORECASE)
    GREEK        = re.compile(r"(?<!\\)\b(alpha|beta|gamma|delta|epsilon|zeta|eta|theta|iota|kappa|lambda|mu|nu|xi|omicron|pi|rho|sigma|tau|upsilon|phi|chi|psi|omega)\b", re.IGNORECASE)
    # Matches ``lim_{...}``, ``lim{...}`` (group 1) or ``lim_x`` (group 2).
    # The lookbehind and the mandatory ``_``/``{`` stop ordinary words such
    # as "limit" or "sublime" from matching and swallowing the rest of the text.
    LIMIT        = re.compile(r"(?<![\\A-Za-z])lim(?:_?\{([^}]+)\}|_([A-Za-z0-9]+))", re.IGNORECASE)
    SUM_EXPR     = re.compile(r"(?<!\\)sum_\{([^}]+)\}\^(?:\{([^}]+)\}|([A-Za-z0-9]))", re.IGNORECASE)
    INTEGRAL     = re.compile(r"(?<!\\)\bintegral\b", re.IGNORECASE)
    EXPR         = re.compile(rf"({_NUM}(?:\s*[\*×·⋅\+\-\=÷]\s*{_NUM})+)")
    # A standalone number.  ``_`` is excluded on both sides so the index
    # inside a placeholder such as ``<_LATEX_0_>`` is never wrapped (that
    # would make the placeholder impossible to restore).
    NUMBER       = re.compile(rf"(?<![A-Za-z0-9_])(?<!\d\.)({_NUM})(?![A-Za-z0-9_]|\.\d)")

    # (pattern, LaTeX operator) pairs for products of a number and a letter,
    # or of two letters, applied in this order.
    _MIXED_PRODUCTS = (
        (re.compile(rf"({_NUM})\s*[·⋅]\s*([A-Za-z])"), r"\cdot"),
        (re.compile(rf"([A-Za-z])\s*[·⋅]\s*({_NUM})"), r"\cdot"),
        (re.compile(rf"([A-Za-z])\s*[\*×]\s*({_NUM})"), r"\times"),
        (re.compile(rf"({_NUM})\s*[\*×]\s*([A-Za-z])"), r"\times"),
        (re.compile(r"([A-Za-z])\s*[\*×]\s*([A-Za-z])"), r"\times"),
    )

    # Operator characters that EXPR can match, mapped to spaced LaTeX forms.
    _OPERATORS = str.maketrans({
        '*': ' \\times ',
        '×': ' \\times ',
        '·': ' \\cdot ',
        '⋅': ' \\cdot ',
        '+': ' + ',
        '-': ' - ',
        '=': ' = ',
        '÷': ' \\div ',
    })

    def __call__(self, ds: "Dataset") -> "Dataset":
        """Return *ds* with its ``text`` column LaTeXified.

        The annotations are strings so that defining the class does not
        require the ``Dataset`` name to be resolvable.
        """
        return ds.map(self._process_batch, batched=True)

    @staticmethod
    def _process_batch(batch: dict) -> dict:
        """Convert every entry of ``batch['text']``; return ``{'text': [...]}``."""
        return {'text': [LaTeXify._latexify_text(text) for text in batch['text']]}

    @staticmethod
    def _latexify_text(text: str) -> str:
        """Run the whole rewrite pipeline on a single string."""
        cls = LaTeXify
        # mask/unmask are instance methods holding per-text state, so each
        # text gets its own swapper.
        swap = LaTeX_placeholder_swap()
        t = swap.mask(text)

        # Constructs that capture a braced argument run first: once a Greek
        # name inside the braces had been wrapped in ``$...$`` the braces
        # could no longer be matched as one span.
        steps = (
            (cls.SUM_EXPR, cls._sum_repl),
            (cls.LIMIT, cls._limit_repl),
            (cls.INTEGRAL, lambda m: "$\\int$"),
            (cls.FUNC_EXPR, cls._func_repl),
            (cls.GREEK, lambda m: f"$\\{m.group(1).lower()}$"),
            (cls.EXPR, cls._expr_repl),
        )
        for pattern, repl in steps:
            t = cls._sub_outside_math(pattern, repl, t)

        for pattern, op in cls._MIXED_PRODUCTS:
            t = cls._sub_outside_math(
                pattern,
                lambda m, op=op: f"${m.group(1)} {op} {m.group(2)}$",
                t,
            )

        t = cls._sub_outside_math(cls.NUMBER, lambda m: f"${m.group(1)}$", t)
        return swap.unmask(t)

    @staticmethod
    def _sub_outside_math(pattern: "re.Pattern", repl, text: str) -> str:
        """Apply ``pattern.sub(repl, ...)`` only to the parts of *text* outside ``$...$``."""
        segments = text.split('$')
        # After splitting on '$', even-indexed segments lie outside math.
        segments[::2] = [pattern.sub(repl, segment) for segment in segments[::2]]
        return '$'.join(segments)

    @staticmethod
    def _greek_commands(fragment: str) -> str:
        """Turn Greek letter names inside a captured argument into commands (no ``$``)."""
        return LaTeXify.GREEK.sub(lambda m: "\\" + m.group(1).lower(), fragment)

    @staticmethod
    def _sum_repl(m: "re.Match") -> str:
        """Build ``$\\sum_{lower}^{upper}$`` from a SUM_EXPR match."""
        lower = LaTeXify._greek_commands(m.group(1))
        upper = LaTeXify._greek_commands(m.group(2) or m.group(3))
        return f"$\\sum_{{{lower}}}^{{{upper}}}$"

    @staticmethod
    def _limit_repl(m: "re.Match") -> str:
        """Build ``$\\lim_{subscript}$`` from a LIMIT match."""
        subscript = LaTeXify._greek_commands(m.group(1) or m.group(2))
        return f"$\\lim_{{{subscript}}}$"

    @staticmethod
    def _func_repl(m: "re.Match") -> str:
        """Build ``$\\func(arg)$`` from a FUNC_EXPR match."""
        arg = LaTeXify._greek_commands(m.group('arg'))
        return f"$\\{m.group(1).lower()}({arg})$"

    @staticmethod
    def _expr_repl(m: "re.Match") -> str:
        """Build a spaced ``$a \\times b + c$`` snippet from an EXPR match."""
        spaced = m.group(1).translate(LaTeXify._OPERATORS)
        # Collapse the runs of whitespace introduced by the padded operators.
        normalised = ' '.join(spaced.split())
        return f"${normalised}$"


