In [None]:
from pathlib import Path
import sys
import traceback

import numpy as np
import sympy as sp
import torch

In [None]:
# Make the parent directory importable so the `symbolicmathematics` package
# imported below can be found when the notebook is run from its own folder.
# The path is resolved to an absolute one and appended only if it is not
# already on sys.path, so re-running this cell does not add duplicates.
module_path = str(Path('..').resolve())
if module_path not in sys.path:
    sys.path.append(module_path)
    
from symbolicmathematics.utils import load_settings, to_cuda
from symbolicmathematics.envs import build_env
from symbolicmathematics.envs.sympy_utils import simplify
from symbolicmathematics.model import build_modules

## Build environment / Reload model

In [None]:
# Run configuration, forwarded to load_settings() in the next cell.
# trained model, e.g. "wget https://dl.fbaipublicfiles.com/SymbolicMathematics/models/fwd_bwd.pth"
cpu = True  # presumably selects CPU execution instead of CUDA -- confirm in load_settings / to_cuda
model_name = 'fwd_bwd'  # name of the trained model to load

In [None]:
# Load the settings object for the chosen model. It is passed to build_env()
# and build_modules() below, and its `cpu` attribute is read when the input
# tensors are handed to to_cuda().
params = load_settings(cpu=cpu, model_name=model_name)

In [None]:
# Build the environment from the settings. The rest of the notebook uses it
# for the vocabulary (word2id / id2word / eos_index) and for converting
# between SymPy, prefix and infix representations.
env = build_env(params)
# Entry registered under 'x' in the environment's local dictionary; used below
# as the variable F is differentiated with respect to.
x = env.local_dict['x']

In [None]:
# Instantiate the model components for this environment / settings and keep
# direct handles on the two that are used below.
modules = build_modules(env, params)
encoder, decoder = modules['encoder'], modules['decoder']

## Start from a function F, compute its derivative f = F', and try to recover F from f

In [None]:
# F is the integral (antiderivative) the model will be asked to recover.
# Edit the infix string below, or swap in one of the commented alternatives,
# to try a different function. It is parsed into a SymPy expression in the
# next cell using the environment's local dictionary.
F_infix = 'x * tan(exp(x)/x)'
# F_infix = 'x * cos(x**2) * tan(x)'
# F_infix = 'cos(x**2 * exp(x * cos(x)))'
# F_infix = 'ln(cos(x + exp(x)) * sin(x**2 + 2) * exp(x) / x)'

In [None]:
# Parse the infix string into a SymPy expression, resolving names through the
# environment's local dictionary. F is the integral the model will try to
# predict; the bare expression on the last line displays it.
F = sp.sympify(F_infix, locals=env.local_dict)
F

In [None]:
# Differentiate F with respect to x. The derivative f = F' is what the model
# receives as input; the bare expression on the last line displays it.
f = sp.diff(F, x)
f

### Compute prefix representations

In [None]:
# Convert both expressions to prefix token lists and show them as
# space-separated strings.
F_prefix = env.sympy_to_prefix(F)
f_prefix = env.sympy_to_prefix(f)
print("F prefix: {}".format(' '.join(F_prefix)))
print("f prefix: {}".format(' '.join(f_prefix)))

### Encode input

In [None]:
# Build the model input: the task header tokens followed by the prefix form of
# f, cleaned by the environment, mapped to vocabulary ids, wrapped between two
# EOS markers and shaped as a single column.
x1_prefix = env.clean_prefix(['sub', 'derivative', 'f', 'x', 'x'] + f_prefix)
x1 = torch.LongTensor(
    [env.eos_index, *(env.word2id[w] for w in x1_prefix), env.eos_index]
).view(-1, 1)
# number of rows of the column tensor, i.e. the sequence length including both EOS markers
len1 = torch.LongTensor([x1.size(0)])
x1, len1 = to_cuda([x1, len1], cpu=params.cpu)

# Inference only: run the encoder without tracking gradients, then swap the
# first two axes of its output before handing it to the decoder.
with torch.no_grad():
    encoded = encoder(
        'fwd',
        x=x1,
        lengths=len1,
        causal=False,
    ).transpose(0, 1)

### Decode with beam search

In [None]:
# Beam-search decoding of the encoded input, without gradient tracking.
beam_size = 10
with torch.no_grad():
    _, _, beam = decoder.generate_beam(
        encoded,
        len1,
        beam_size=beam_size,
        length_penalty=1.0,
        early_stopping=1,
        max_len=200,
    )

# One input sequence was encoded, so exactly one beam is expected, and it
# should hold `beam_size` hypotheses.
assert len(beam) == 1
hypotheses = beam[0].hyp
assert len(hypotheses) == beam_size

### Print results

In [None]:
# Report every beam hypothesis, best score first, with a verdict:
#   OK  - differentiating the hypothesis and subtracting f simplifies to 0
#   NO  - the difference did not simplify to 0
#   INVALID PREFIX EXPRESSION - an error was raised while converting or
#         checking the hypothesis (the raw tokens are printed instead)
print(f"Input function f: {f}")
print(f"Reference function F: {F}")
print()

# Each hypothesis is a (score, token sequence) pair. The sort-key parameter is
# named `pair` rather than `x` so it is not mistaken for the SymPy symbol `x`
# used inside the loop.
for score, sent in sorted(hypotheses, key=lambda pair: pair[0], reverse=True):

    # parse decoded hypothesis
    ids = sent[1:].tolist()                  # decoded token IDs (first position skipped)
    tok = [env.id2word[wid] for wid in ids]  # convert to prefix

    try:
        hyp = env.prefix_to_infix(tok)       # convert to infix
        hyp = env.infix_to_sympy(hyp)        # convert to SymPy

        # check whether we recover f if we differentiate the hypothesis
        # note that sometimes, SymPy fails to show that hyp' - f == 0, and
        # the result is considered as invalid, although it may be correct
        res = "OK" if simplify(hyp.diff(x) - f, seconds=1) == 0 else "NO"

    # Catch `Exception`, not a bare `except:`, so that KeyboardInterrupt and
    # SystemExit (e.g. interrupting the kernel) still stop the cell instead of
    # being reported as an invalid expression.
    except Exception:
        res = "INVALID PREFIX EXPRESSION"
        hyp = tok

    # print result
    print(f"{score:.5f}  {res}  {hyp}")