In [46]:
import itertools
import os
from enum import Enum

import numpy as np

### Define LaTeX tokens

In [47]:
# All tokens that can appear in the generated LaTeX expressions
# (tokens for the natural-language side would be whole words).
latex_token = ['{', '}', '\\', '-', '+', '^', '_', '(', ')', ',', ' ',
               'frac', 'times', 'ne', 'ge', 'le',
               'int', 'lim', 'd', 'to', 'h',
               '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
               'x', 'y', 'z']

# 'frac' used to be listed twice; presumably a downstream vocabulary should see
# each token once, so guard against duplicates creeping back in.
assert len(latex_token) == len(set(latex_token)), 'duplicate LaTeX token'
# ';' is the separator in the output file, so no token may contain it.
assert not any(';' in token for token in latex_token), "';' is reserved as the separator"

# Write every token followed by ';' (the file therefore ends with a trailing ';').
os.makedirs('./data', exist_ok=True)
with open('./data/latex_token.txt', 'w', encoding='utf-8') as f:
    for token in latex_token:
        f.write(token + ';')

### Generate

In [48]:
class ParamType(Enum):
    """Which atoms a rule's parameter slot may be filled with (see ATOMS_PER_TYPE)."""

    ALL = 0    # any atom: digits or variables
    VAR = 1    # variables only
    CONST = 2  # digit constants only

In [49]:
# Single-operator rules: (latex templates, number of parameters, natural-language templates).
# Templates are filled with str.format, so "{{" / "}}" render as literal LaTeX braces.
# No parameter-type list is given, so every slot accepts any atom (digit or variable).
BASE_RULES = [
    # subtraction: explicit indices because the second phrasing names the operands in reverse
    (["{0}-{1}"], 2, ["{0} minus {1}", "subtract {1} from {0}"]),
    # powers
    (["{}^{}"], 2, ["{} to the {}", "{} to the power of {}"]),
    (["{}^2"], 1, ["{} squared", "{} to the power of two"]),
    # division written as a fraction
    (["\\frac{{{}}}{{{}}}"], 2, ["{} over {}", "{} divided by {}"]),
    # addition: positional fill gives "add a to b" for a+b, harmless since addition commutes
    (["{}+{}"], 2, ["{} plus {}", "add {} to {}"]),
    # multiplication
    (["{} \\times {}"], 2, ["{} times {}", "{} multiplied by {}", "product of {} and {}"]),
    # (in)equalities
    (["{} \\ne {}"], 2, ["{} is not equal to {}"]),
    (["{} \\ge {}"], 2, ["{} is greater or equal to {}"]),
    (["{} \\le {}"], 2, ["{} is less or equal to {}"]),
]

In [50]:
# Calculus rules over a function h. Same layout as BASE_RULES
# (latex templates, number of parameters, natural-language templates), plus a
# per-slot ParamType list restricting which atoms each slot accepts.
# In the LaTeX templates "{{" / "}}" are literal braces and "\\," is a thin space.
COMPLEX_RULES = [
    # Definite integral: slots 0-1 are the bounds (constants), slot 2 is h's argument
    # and slot 3 the integration variable; slots 2 and 3 are filled independently.
    (
        ["\\int_{{{0}}}^{{{1}}} h({2}) \\,d{3}"],
        4,
        ["integral of h ( {2} ) from {0} to {1} with respect to {3}"],
        [ParamType.CONST, ParamType.CONST, ParamType.VAR, ParamType.VAR],
    ),

    # Limit: slot 0 is the approaching variable, slot 1 the constant it approaches,
    # slot 2 h's argument.
    (
        ["\\lim_{{{0} \\to {1}}} h({2})"],
        3,
        ["limit of h ( {2} ) as {0} goes to {1}"],
        [ParamType.VAR, ParamType.CONST, ParamType.VAR],
    ),

    # Derivative: slot 0 is h's argument, slot 1 the differentiation variable.
    (
        ["\\frac{{dh({0})}}{{d{1}}}"],
        2,
        ["derivative of h ( {0} ) with respect to {1}"],
        [ParamType.VAR, ParamType.VAR],
    ),
]

In [51]:
# Atom pools: each atom is a (latex, natural-language) pair that fills one template slot.
TEX_DIGITS = np.arange(10)
NAT_DIGITS = ["zero", "one", "two", "three", "four", "five", "six", "seven", "eight", "nine"]
# Kept separate from VARS so the name is not rebound from strings to pairs.
VAR_NAMES = ["x", "y", "z"]

# zip() silently truncates to the shorter input, so make sure the digit lists line up.
assert len(TEX_DIGITS) == len(NAT_DIGITS), "every LaTeX digit needs a natural-language form"

DIGITS = list(zip(TEX_DIGITS, NAT_DIGITS))
# Variables are spelled the same way in LaTeX and in natural language.
VARS = list(zip(VAR_NAMES, VAR_NAMES))

# Atoms allowed in a parameter slot of each type.
ATOMS_PER_TYPE = {
    ParamType.ALL: DIGITS + VARS,
    ParamType.VAR: VARS,
    ParamType.CONST: DIGITS
}


def generate_examples(texs, num_params, naturals, param_types=None, atoms_per_type=None):
    """Expand one rule into every (latex, natural-language) example pair.

    Every parameter slot is filled with every atom allowed for its type
    (Cartesian product over slots). For each such assignment, every LaTeX
    template is paired with every natural-language template.

    Args:
        texs: LaTeX templates, filled positionally via str.format.
        num_params: number of parameter slots in the templates.
        naturals: natural-language templates, filled with the same atoms.
        param_types: optional per-slot parameter types; when None or empty,
            every slot uses ParamType.ALL.
        atoms_per_type: optional mapping from parameter type to a list of
            (latex, natural) atom pairs; defaults to the module-level
            ATOMS_PER_TYPE.

    Returns:
        List of (latex, natural) string tuples, ordered by atom assignment,
        then LaTeX template, then natural-language template.

    Raises:
        ValueError: if param_types is given and its length differs from num_params.
    """
    if atoms_per_type is None:
        atoms_per_type = ATOMS_PER_TYPE

    # None default instead of a shared mutable [] default.
    if not param_types:
        param_types = [ParamType.ALL] * num_params
    elif len(param_types) != num_params:
        raise ValueError(
            f"param_types has {len(param_types)} entries but num_params is {num_params}"
        )

    atom_pools = [atoms_per_type[typ] for typ in param_types]

    results = []
    for params in itertools.product(*atom_pools):
        # Split the chosen (latex, natural) atoms explicitly rather than via
        # zip(*params), which cannot be unpacked when a rule has no parameters.
        tex_params = tuple(tex for tex, _ in params)
        nat_params = tuple(nat for _, nat in params)

        tex_results = [template.format(*tex_params) for template in texs]
        natural_results = [template.format(*nat_params) for template in naturals]
        results.extend(itertools.product(tex_results, natural_results))

    return results

In [52]:
# Expand every BASE_RULES entry into (latex, natural) pairs, keeping rule order.
base_examples = list(
    itertools.chain.from_iterable(generate_examples(*rule) for rule in BASE_RULES)
)

In [53]:
# Expand every COMPLEX_RULES entry into (latex, natural) pairs, keeping rule order.
complex_examples = list(
    itertools.chain.from_iterable(generate_examples(*rule) for rule in COMPLEX_RULES)
)

In [54]:
# Hand-written pairs that combine several base rules in one expression;
# these are not generated automatically yet.
nested_examples = [
    (
        "x^2 \\le x+5",
        "x to the two is less or equal to x plus five",
    ),
    (
        "\\frac{x+1}{x} \\ne 1",
        "x plus one over x is not equal to one",
    ),
]

### Write
- Each output file starts with a `latex;natural language` header, followed by one pair per line. The two fields are separated by `;` because `,` is itself a LaTeX token.

In [55]:
# Write base expression pairs: a header row, then one `latex;natural` row per pair.
# ';' separates the columns and '\n' the rows, so a field containing either would
# make the file impossible to split back into pairs -- check before writing anything.
assert not any(';' in text or '\n' in text for pair in base_examples for text in pair), \
    "a base expression pair contains ';' or a newline"

with open('./data/base_expression.txt', 'w', encoding='utf-8') as f:
    f.write('latex;natural language\n')

    for latex, natural in base_examples:
        f.write(latex + ';' + natural + '\n')

In [56]:
# Write complex expression pairs: a header row, then one `latex;natural` row per pair.
# ';' separates the columns and '\n' the rows, so a field containing either would
# make the file impossible to split back into pairs -- check before writing anything.
assert not any(';' in text or '\n' in text for pair in complex_examples for text in pair), \
    "a complex expression pair contains ';' or a newline"

with open('./data/complex_expression.txt', 'w', encoding='utf-8') as f:
    f.write('latex;natural language\n')

    for latex, natural in complex_examples:
        f.write(latex + ';' + natural + '\n')

In [57]:
# Write nested expression pairs: a header row, then one `latex;natural` row per pair.
# ';' separates the columns and '\n' the rows, so a field containing either would
# make the file impossible to split back into pairs -- check before writing anything.
assert not any(';' in text or '\n' in text for pair in nested_examples for text in pair), \
    "a nested expression pair contains ';' or a newline"

with open('./data/nested_expression.txt', 'w', encoding='utf-8') as f:
    f.write('latex;natural language\n')

    for latex, natural in nested_examples:
        f.write(latex + ';' + natural + '\n')