We use Colab for free GPU training. For accelerated training make sure you are using the GPU harware accelerator: Runtime ->  Change runtime type \\
To check you are running a GPU runtime, run the following:

In [1]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find("failed") >= 0:
  print("Not connected to a gpu runtime")
else:
  print("Connected to a gpu runtime:")
  print(gpu_info)

Connected to a gpu runtime:
Tue Mar 23 21:29:17 2021       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 460.56       Driver Version: 460.32.03    CUDA Version: 11.2     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  Tesla P100-PCIE...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   34C    P0    27W / 250W |      0MiB / 16280MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------

## Setup

In [44]:
import numpy as np
import tensorflow as tf
from tqdm.notebook import tqdm
import time
from IPython import display
import pickle
from tensorflow.keras import backend as K
from sklearn.model_selection import train_test_split


class AttrDict(dict):
    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        self.__dict__ = self


params = AttrDict({

    # dataset creation parameters
    'dataset_path': '',
    'dataset_size': 3000,
    'num_parallel_processes': 15,
    'env_name': 'char_sp',
    'int_base': 10,
    'balanced': False,
    'positive': True,
    'precision': 10,
    'n_variables': 2,
    'n_coefficients': 0,
    'leaf_probs': '0.75,0,0.25,0',
    'max_len': 510,
    'max_int': 5,
    'max_ops': 15,
    'max_ops_G': 15,
    'clean_prefix_expr': True,
    'rewrite_functions': '',
    'tasks': 'prim_fwd',

    # used operations, followed by genreation weight
    'operators': 'add:1,sub:1,mul:1,pow2:1,pow3:1', 
    #'operators': 'add:10,sub:3,mul:10,div:5,sqrt:4,pow2:4,pow3:2,pow4:1,pow5:1,ln:4,exp:4,sin:4,cos:4,tan:4,asin:1,acos:1,atan:1,sinh:1,cosh:1,tanh:1,asinh:1,acosh:1,atanh:1',

    # model hyperparameters
    'num_layers': 4,
    'd_model': 512,
    'dff': 512,
    'num_heads': 8,
    'dropout_rate': 0.1,

    # model training hyperparameters
    'learning_rate': 1e-4,
    'batch_size': 32,
    'model_name': 'lime',
    'epochs': 300,
    'model_path': 'model.h5'
    })

## Equation Generation Environment Class

This cell contains the utility functions for generaiting random expressions. This is modified version of the code used for [	arXiv:1912.01412](https://arxiv.org/abs/1912.01412) for generating random integrals and is built with Sympy (see [SymbolicMathematics](https://github.com/facebookresearch/SymbolicMathematics) ). 

In [49]:
import sympy as sp
from sympy.parsing.sympy_parser import parse_expr
from sympy.core.cache import clear_cache
from sympy.integrals.risch import NonElementaryIntegral
from sympy.calculus.util import AccumBounds
import os
import io
import re
import sys
import math
import itertools
from collections import OrderedDict
import numpy as np
import numexpr as ne


SPECIAL_WORDS = ['<s>', '</s>', '<pad>', '(', ')']
SPECIAL_WORDS = SPECIAL_WORDS + [f'<SPECIAL_{i}>' for i in range(len(SPECIAL_WORDS), 10)]


INTEGRAL_FUNC = {sp.erf, sp.erfc, sp.erfi, sp.erfinv, sp.erfcinv, sp.expint, sp.Ei, sp.li, sp.Li, sp.Si, sp.Ci, sp.Shi, sp.Chi, sp.fresnelc, sp.fresnels}
EXP_OPERATORS = {'exp', 'sinh', 'cosh'}
EVAL_SYMBOLS = {'x', 'y', 'z', 'a0', 'a1', 'a2', 'a3', 'a4', 'a5', 'a6', 'a7', 'a8', 'a9'}
EVAL_VALUES = [0.01, 0.1, 0.3, 0.5, 0.7, 0.9, 1.1, 2.1, 3.1]
EVAL_VALUES = EVAL_VALUES + [-x for x in EVAL_VALUES]

TEST_ZERO_VALUES = [0.1, 0.9, 1.1, 1.9]
TEST_ZERO_VALUES = [-x for x in TEST_ZERO_VALUES] + TEST_ZERO_VALUES
ZERO_THRESHOLD = 1e-13


class ValueErrorExpression(Exception):
    pass


class UnknownSymPyOperator(Exception):
    pass


class InvalidPrefixExpression(Exception):

    def __init__(self, data):
        self.data = data

    def __str__(self):
        return repr(self.data)


def count_nested_exp(s):
    """
    Return the maximum number of nested exponential functions in an infix expression.
    """
    stack = []
    count = 0
    max_count = 0
    for v in re.findall('[+-/*//()]|[a-zA-Z0-9]+', s):
        if v == '(':
            stack.append(v)
        elif v == ')':
            while True:
                x = stack.pop()
                if x in EXP_OPERATORS:
                    count -= 1
                if x == '(':
                    break
        else:
            stack.append(v)
            if v in EXP_OPERATORS:
                count += 1
                max_count = max(max_count, count)
    assert len(stack) == 0
    return max_count


def is_valid_expr(s):
    """
    Check that we are able to evaluate an expression (and that it will not blow in SymPy evaluation).
    """
    s = s.replace('Derivative(f(x),x)', '1')
    s = s.replace('Derivative(1,x)', '1')
    s = s.replace('(E)', '(exp(1))')
    s = s.replace('(I)', '(1)')
    s = s.replace('(pi)', '(1)')
    s = re.sub(r'(?<![a-z])(f|g|h|Abs|sign|ln|sin|cos|tan|sec|csc|cot|asin|acos|atan|asec|acsc|acot|tanh|sech|csch|coth|asinh|acosh|atanh|asech|acoth|acsch)\(', '(', s)
    count = count_nested_exp(s)
    if count >= 4:
        return False
    for v in EVAL_VALUES:
        try:
            local_dict = {s: (v + 1e-4 * i) for i, s in enumerate(EVAL_SYMBOLS)}
            value = ne.evaluate(s, local_dict=local_dict).item()
            if not (math.isnan(value) or math.isinf(value)):
                return True
        except (FloatingPointError, ZeroDivisionError, TypeError, MemoryError):
            continue
    return False


def eval_test_zero(eq):
    """
    Evaluate an equation by replacing all its free symbols with random values.
    """
    variables = eq.free_symbols
    assert len(variables) <= 3
    outputs = []
    for values in itertools.product(*[TEST_ZERO_VALUES for _ in range(len(variables))]):
        _eq = eq.subs(zip(variables, values)).doit()
        outputs.append(float(sp.Abs(_eq.evalf())))
    return outputs


def has_inf_nan(*args):
    """
    Detect whether some expressions contain a NaN / Infinity symbol.
    """
    for f in args:
        if f.has(sp.nan) or f.has(sp.oo) or f.has(-sp.oo) or f.has(sp.zoo):
            return True
    return False


class CharSPEnvironment(object):

    # https://docs.sympy.org/latest/modules/functions/elementary.html#real-root

    SYMPY_OPERATORS = {
        # Elementary functions
        sp.Add: 'add',
        sp.Mul: 'mul',
        sp.Pow: 'pow',
        sp.exp: 'exp',
        sp.log: 'ln',
        sp.Abs: 'abs',
        sp.sign: 'sign',
        # Trigonometric Functions
        sp.sin: 'sin',
        sp.cos: 'cos',
        sp.tan: 'tan',
        sp.cot: 'cot',
        sp.sec: 'sec',
        sp.csc: 'csc',
        # Trigonometric Inverses
        sp.asin: 'asin',
        sp.acos: 'acos',
        sp.atan: 'atan',
        sp.acot: 'acot',
        sp.asec: 'asec',
        sp.acsc: 'acsc',
        # Hyperbolic Functions
        sp.sinh: 'sinh',
        sp.cosh: 'cosh',
        sp.tanh: 'tanh',
        sp.coth: 'coth',
        sp.sech: 'sech',
        sp.csch: 'csch',
        # Hyperbolic Inverses
        sp.asinh: 'asinh',
        sp.acosh: 'acosh',
        sp.atanh: 'atanh',
        sp.acoth: 'acoth',
        sp.asech: 'asech',
        sp.acsch: 'acsch',
        # Derivative
        sp.Derivative: 'derivative',
    }

    OPERATORS = {
        # Elementary functions
        'add': 2,
        'sub': 2,
        'mul': 2,
        'div': 2,
        'pow': 2,
        'rac': 2,
        'inv': 1,
        'pow2': 1,
        'pow3': 1,
        'pow4': 1,
        'pow5': 1,
        'sqrt': 1,
        'exp': 1,
        'ln': 1,
        'abs': 1,
        'sign': 1,
        # Trigonometric Functions
        'sin': 1,
        'cos': 1,
        'tan': 1,
        'cot': 1,
        'sec': 1,
        'csc': 1,
        # Trigonometric Inverses
        'asin': 1,
        'acos': 1,
        'atan': 1,
        'acot': 1,
        'asec': 1,
        'acsc': 1,
        # Hyperbolic Functions
        'sinh': 1,
        'cosh': 1,
        'tanh': 1,
        'coth': 1,
        'sech': 1,
        'csch': 1,
        # Hyperbolic Inverses
        'asinh': 1,
        'acosh': 1,
        'atanh': 1,
        'acoth': 1,
        'asech': 1,
        'acsch': 1,
        # Derivative
        'derivative': 2,
        # custom functions
        'f': 1,
        'g': 2,
        'h': 3,
    }

    def __init__(self, params):

        self.max_int = params.max_int
        self.max_ops = params.max_ops
        self.max_ops_G = params.max_ops_G
        self.int_base = params.int_base
        self.balanced = params.balanced
        self.positive = params.positive
        self.precision = params.precision
        self.n_variables = params.n_variables
        self.n_coefficients = params.n_coefficients
        self.max_len = params.max_len
        self.clean_prefix_expr = params.clean_prefix_expr
        assert self.max_int >= 1
        assert abs(self.int_base) >= 2
        assert self.precision >= 2

        # parse operators with their weights
        self.operators = sorted(list(self.OPERATORS.keys()))
        ops = params.operators.split(',')
        ops = sorted([x.split(':') for x in ops])
        assert len(ops) >= 1 and all(o in self.OPERATORS for o, _ in ops)
        self.all_ops = [o for o, _ in ops]
        self.una_ops = [o for o, _ in ops if self.OPERATORS[o] == 1]
        self.bin_ops = [o for o, _ in ops if self.OPERATORS[o] == 2]

        self.all_ops_probs = np.array([float(w) for _, w in ops]).astype(np.float64)
        self.una_ops_probs = np.array([float(w) for o, w in ops if self.OPERATORS[o] == 1]).astype(np.float64)
        self.bin_ops_probs = np.array([float(w) for o, w in ops if self.OPERATORS[o] == 2]).astype(np.float64)
        self.all_ops_probs = self.all_ops_probs / self.all_ops_probs.sum()
        self.una_ops_probs = self.una_ops_probs / self.una_ops_probs.sum()
        self.bin_ops_probs = self.bin_ops_probs / self.bin_ops_probs.sum()

        assert len(self.all_ops) == len(set(self.all_ops)) >= 1
        assert set(self.all_ops).issubset(set(self.operators))
        assert len(self.all_ops) == len(self.una_ops) + len(self.bin_ops)

        # symbols / elements
        self.constants = ['pi', 'E']
        self.variables = OrderedDict({
            'x': sp.Symbol('x', real=True, nonzero=True),  # , positive=True
            'y': sp.Symbol('y', real=True, nonzero=True),  # , positive=True
            'z': sp.Symbol('z', real=True, nonzero=True),  # , positive=True
            't': sp.Symbol('t', real=True, nonzero=True),  # , positive=True
        })
        self.coefficients = OrderedDict({
            f'a{i}': sp.Symbol(f'a{i}', real=True)
            for i in range(10)
        })
        self.functions = OrderedDict({
            'f': sp.Function('f', real=True, nonzero=True),
            'g': sp.Function('g', real=True, nonzero=True),
            'h': sp.Function('h', real=True, nonzero=True),
        })
        self.symbols = ['I', 'INT+', 'INT-', 'INT', 'FLOAT', '-', '.', '10^', 'Y', "Y'", "Y''"]
        if self.balanced:
            assert self.int_base > 2
            max_digit = (self.int_base + 1) // 2
            self.elements = [str(i) for i in range(max_digit - abs(self.int_base), max_digit)]
        else:
            self.elements = [str(i) for i in range(abs(self.int_base))]
        assert 1 <= self.n_variables <= len(self.variables)
        assert 0 <= self.n_coefficients <= len(self.coefficients)
        assert all(k in self.OPERATORS for k in self.functions.keys())
        assert all(v in self.OPERATORS for v in self.SYMPY_OPERATORS.values())

        # SymPy elements
        self.local_dict = {}
        for k, v in list(self.variables.items()) + list(self.coefficients.items()) + list(self.functions.items()):
            assert k not in self.local_dict
            self.local_dict[k] = v

        # vocabulary
        self.words = SPECIAL_WORDS + self.constants + list(self.variables.keys()) + list(self.coefficients.keys()) + self.operators + self.symbols + self.elements
        self.id2word = {i + 1: s for i, s in enumerate(self.words)}
        self.word2id = {s: i for i, s in self.id2word.items()}
        assert len(self.words) == len(set(self.words))

        # number of words / indices
        self.n_words = params.n_words = len(self.words)
        self.eos_index = params.eos_index = 0
        self.pad_index = params.pad_index = 1

        # leaf probabilities
        s = [float(x) for x in params.leaf_probs.split(',')]
        assert len(s) == 4 and all(x >= 0 for x in s)
        self.leaf_probs = np.array(s).astype(np.float64)
        self.leaf_probs = self.leaf_probs / self.leaf_probs.sum()
        assert self.leaf_probs[0] > 0
        assert (self.leaf_probs[1] == 0) == (self.n_coefficients == 0)

        # possible leaves
        self.n_leaves = self.n_variables + self.n_coefficients
        if self.leaf_probs[2] > 0:
            self.n_leaves += self.max_int * (1 if self.positive else 2)
        if self.leaf_probs[3] > 0:
            self.n_leaves += len(self.constants)

        # generation parameters
        self.nl = 1  # self.n_leaves
        self.p1 = 1  # len(self.una_ops)
        self.p2 = 1  # len(self.bin_ops)

        # initialize distribution for binary and unary-binary trees
        self.bin_dist = self.generate_bin_dist(params.max_ops)
        self.ubi_dist = self.generate_ubi_dist(params.max_ops)

        # rewrite expressions
        self.rewrite_functions = [x for x in params.rewrite_functions.split(',') if x != '']
        assert len(self.rewrite_functions) == len(set(self.rewrite_functions))
        assert all(x in ['expand', 'factor', 'expand_log', 'logcombine', 'powsimp', 'simplify'] for x in self.rewrite_functions)

    def generate_bin_dist(self, max_ops):
        """
        `max_ops`: maximum number of operators
        Enumerate the number of possible binary trees that can be generated from empty nodes.
        D[e][n] represents the number of different binary trees with n nodes that
        can be generated from e empty nodes, using the following recursion:
            D(0, n) = 0
            D(1, n) = C_n (n-th Catalan number)
            D(e, n) = D(e - 1, n + 1) - D(e - 2, n + 1)
        """
        # initialize Catalan numbers
        catalans = [1]
        for i in range(1, 2 * max_ops + 1):
            catalans.append((4 * i - 2) * catalans[i - 1] // (i + 1))

        # enumerate possible trees
        D = []
        for e in range(max_ops + 1):  # number of empty nodes
            s = []
            for n in range(2 * max_ops - e + 1):  # number of operators
                if e == 0:
                    s.append(0)
                elif e == 1:
                    s.append(catalans[n])
                else:
                    s.append(D[e - 1][n + 1] - D[e - 2][n + 1])
            D.append(s)
        return D

    def generate_ubi_dist(self, max_ops):
        """
        `max_ops`: maximum number of operators
        Enumerate the number of possible unary-binary trees that can be generated from empty nodes.
        D[e][n] represents the number of different binary trees with n nodes that
        can be generated from e empty nodes, using the following recursion:
            D(0, n) = 0
            D(e, 0) = L ** e
            D(e, n) = L * D(e - 1, n) + p_1 * D(e, n - 1) + p_2 * D(e + 1, n - 1)
        """
        # enumerate possible trees
        # first generate the tranposed version of D, then transpose it
        D = []
        D.append([0] + ([self.nl ** i for i in range(1, 2 * max_ops + 1)]))
        for n in range(1, 2 * max_ops + 1):  # number of operators
            s = [0]
            for e in range(1, 2 * max_ops - n + 1):  # number of empty nodes
                s.append(self.nl * s[e - 1] + self.p1 * D[n - 1][e] + self.p2 * D[n - 1][e + 1])
            D.append(s)
        assert all(len(D[i]) >= len(D[i + 1]) for i in range(len(D) - 1))
        D = [[D[j][i] for j in range(len(D)) if i < len(D[j])] for i in range(max(len(x) for x in D))]
        return D

    def write_int(self, val):
        """
        Convert a decimal integer to a representation in the given base.
        The base can be negative.
        In balanced bases (positive), digits range from -(base-1)//2 to (base-1)//2
        """
        base = self.int_base
        balanced = self.balanced
        res = []
        max_digit = abs(base)
        if balanced:
            max_digit = (base - 1) // 2
        else:
            if base > 0:
                neg = val < 0
                val = -val if neg else val
        while True:
            rem = val % base
            val = val // base
            if rem < 0 or rem > max_digit:
                rem -= base
                val += 1
            res.append(str(rem))
            if val == 0:
                break
        if base < 0 or balanced:
            res.append('INT')
        else:
            res.append('INT-' if neg else 'INT+')
        return res[::-1]

    def parse_int(self, lst):
        """
        Parse a list that starts with an integer.
        Return the integer value, and the position it ends in the list.
        """
        base = self.int_base
        balanced = self.balanced
        val = 0
        if not (balanced and lst[0] == 'INT' or base >= 2 and lst[0] in ['INT+', 'INT-'] or base <= -2 and lst[0] == 'INT'):
            raise InvalidPrefixExpression(f"Invalid integer in prefix expression")
        i = 0
        for x in lst[1:]:
            if not (x.isdigit() or x[0] == '-' and x[1:].isdigit()):
                break
            val = val * base + int(x)
            i += 1
        if base > 0 and lst[0] == 'INT-':
            val = -val
        return val, i + 1

    def sample_next_pos_ubi(self, nb_empty, nb_ops, rng):
        """
        Sample the position of the next node (unary-binary case).
        Sample a position in {0, ..., `nb_empty` - 1}, along with an arity.
        """
        assert nb_empty > 0
        assert nb_ops > 0
        probs = []
        for i in range(nb_empty):
            probs.append((self.nl ** i) * self.p1 * self.ubi_dist[nb_empty - i][nb_ops - 1])
        for i in range(nb_empty):
            probs.append((self.nl ** i) * self.p2 * self.ubi_dist[nb_empty - i + 1][nb_ops - 1])
        probs = [p / self.ubi_dist[nb_empty][nb_ops] for p in probs]
        probs = np.array(probs, dtype=np.float64)
        e = rng.choice(2 * nb_empty, p=probs)
        arity = 1 if e < nb_empty else 2
        e = e % nb_empty
        return e, arity

    def get_leaf(self, max_int, rng):
        """
        Generate a leaf.
        """
        self.leaf_probs
        leaf_type = rng.choice(4, p=self.leaf_probs)
        if leaf_type == 0:
            return [list(self.variables.keys())[rng.randint(self.n_variables)]]
        elif leaf_type == 1:
            return [list(self.coefficients.keys())[rng.randint(self.n_coefficients)]]
        elif leaf_type == 2:
            c = rng.randint(1, max_int + 1)
            c = c if (self.positive or rng.randint(2) == 0) else -c
            return self.write_int(c)
        else:
            return [self.constants[rng.randint(len(self.constants))]]

    def _generate_expr(self, nb_total_ops, max_int, rng, require_x=False, require_y=False, require_z=False):
        """
        Create a tree with exactly `nb_total_ops` operators.
        """
        stack = [None]
        nb_empty = 1  # number of empty nodes
        l_leaves = 0  # left leaves - None states reserved for leaves
        t_leaves = 1  # total number of leaves (just used for sanity check)

        # create tree
        for nb_ops in range(nb_total_ops, 0, -1):

            # next operator, arity and position
            skipped, arity = self.sample_next_pos_ubi(nb_empty, nb_ops, rng)
            if arity == 1:
                op = rng.choice(self.una_ops, p=self.una_ops_probs)
            else:
                op = rng.choice(self.bin_ops, p=self.bin_ops_probs)

            nb_empty += self.OPERATORS[op] - 1 - skipped  # created empty nodes - skipped future leaves
            t_leaves += self.OPERATORS[op] - 1            # update number of total leaves
            l_leaves += skipped                           # update number of left leaves

            # update tree
            pos = [i for i, v in enumerate(stack) if v is None][l_leaves]
            stack = stack[:pos] + [op] + [None for _ in range(self.OPERATORS[op])] + stack[pos + 1:]

        # sanity check
        assert len([1 for v in stack if v in self.all_ops]) == nb_total_ops
        assert len([1 for v in stack if v is None]) == t_leaves

        # create leaves
        # optionally add variables x, y, z if possible
        assert not require_z or require_y
        assert not require_y or require_x
        leaves = [self.get_leaf(max_int, rng) for _ in range(t_leaves)]
        if require_z and t_leaves >= 2:
            leaves[1] = ['z']
        if require_y:
            leaves[0] = ['y']
        if require_x and not any(len(leaf) == 1 and leaf[0] == 'x' for leaf in leaves):
            leaves[-1] = ['x']
        rng.shuffle(leaves)

        # insert leaves into tree
        for pos in range(len(stack) - 1, -1, -1):
            if stack[pos] is None:
                stack = stack[:pos] + leaves.pop() + stack[pos + 1:]

        return stack

    def write_infix(self, token, args):
        """
        Infix representation.
        Convert prefix expressions to a format that SymPy can parse.
        """
        if token == 'add':
            return f'({args[0]})+({args[1]})'
        elif token == 'sub':
            return f'({args[0]})-({args[1]})'
        elif token == 'mul':
            return f'({args[0]})*({args[1]})'
        elif token == 'div':
            return f'({args[0]})/({args[1]})'
        elif token == 'pow':
            return f'({args[0]})**({args[1]})'
        elif token == 'rac':
            return f'({args[0]})**(1/({args[1]}))'
        elif token == 'abs':
            return f'Abs({args[0]})'
        elif token == 'inv':
            return f'1/({args[0]})'
        elif token == 'pow2':
            return f'({args[0]})**2'
        elif token == 'pow3':
            return f'({args[0]})**3'
        elif token == 'pow4':
            return f'({args[0]})**4'
        elif token == 'pow5':
            return f'({args[0]})**5'
        elif token in ['sign', 'sqrt', 'exp', 'ln', 'sin', 'cos', 'tan', 'cot', 'sec', 'csc', 'asin', 'acos', 'atan', 'acot', 'asec', 'acsc', 'sinh', 'cosh', 'tanh', 'coth', 'sech', 'csch', 'asinh', 'acosh', 'atanh', 'acoth', 'asech', 'acsch']:
            return f'{token}({args[0]})'
        elif token == 'derivative':
            return f'Derivative({args[0]},{args[1]})'
        elif token == 'f':
            return f'f({args[0]})'
        elif token == 'g':
            return f'g({args[0]},{args[1]})'
        elif token == 'h':
            return f'h({args[0]},{args[1]},{args[2]})'
        elif token.startswith('INT'):
            return f'{token[-1]}{args[0]}'
        else:
            return token
        raise InvalidPrefixExpression(f"Unknown token in prefix expression: {token}, with arguments {args}")

    def _prefix_to_infix(self, expr):
        """
        Parse an expression in prefix mode, and output it in either:
          - infix mode (returns human readable string)
          - develop mode (returns a dictionary with the simplified expression)
        """
        if len(expr) == 0:
            raise InvalidPrefixExpression("Empty prefix list.")
        t = expr[0]
        if t in self.operators:
            args = []
            l1 = expr[1:]
            for _ in range(self.OPERATORS[t]):
                i1, l1 = self._prefix_to_infix(l1)
                args.append(i1)
            return self.write_infix(t, args), l1
        elif t in self.variables or t in self.coefficients or t in self.constants or t == 'I':
            return t, expr[1:]
        else:
            val, i = self.parse_int(expr)
            return str(val), expr[i:]

    def prefix_to_infix(self, expr):
        """
        Prefix to infix conversion.
        """
        p, r = self._prefix_to_infix(expr)
        if len(r) > 0:
            raise InvalidPrefixExpression(f"Incorrect prefix expression \"{expr}\". \"{r}\" was not parsed.")
        return f'({p})'

    def rewrite_sympy_expr(self, expr):
        """
        Rewrite a SymPy expression.
        """
        expr_rw = expr
        for f in self.rewrite_functions:
            if f == 'expand':
                expr_rw = sp.expand(expr_rw)
            elif f == 'factor':
                expr_rw = sp.factor(expr_rw)
            elif f == 'expand_log':
                expr_rw = sp.expand_log(expr_rw, force=True)
            elif f == 'logcombine':
                expr_rw = sp.logcombine(expr_rw, force=True)
            elif f == 'powsimp':
                expr_rw = sp.powsimp(expr_rw, force=True)
            elif f == 'simplify':
                expr_rw = simplify(expr_rw, seconds=1)
        return expr_rw

    def infix_to_sympy(self, infix, no_rewrite=False, 
                       evaluate=False, force=False):
        """
        Convert an infix expression to SymPy.
        """
        if force:
          return parse_expr(infix, evaluate=evaluate, local_dict=self.local_dict)

        if not is_valid_expr(infix):
            raise ValueErrorExpression
        #expr = parse_expr(infix, evaluate=True, local_dict=self.local_dict)
        expr = parse_expr(infix, evaluate=evaluate, local_dict=self.local_dict)

        if expr.has(sp.I) or expr.has(AccumBounds):
            raise ValueErrorExpression
        if not no_rewrite:
            expr = self.rewrite_sympy_expr(expr)
        return expr

    def _sympy_to_prefix(self, op, expr):
        """
        Parse a SymPy expression given an initial root operator.
        """
        n_args = len(expr.args)

        # derivative operator
        if op == 'derivative':
            assert n_args >= 2
            assert all(len(arg) == 2 and str(arg[0]) in self.variables and int(arg[1]) >= 1 for arg in expr.args[1:]), expr.args
            parse_list = self.sympy_to_prefix(expr.args[0])
            for var, degree in expr.args[1:]:
                parse_list = ['derivative' for _ in range(int(degree))] + parse_list + [str(var) for _ in range(int(degree))]
            return parse_list

        assert (op == 'add' or op == 'mul') and (n_args >= 2) or (op != 'add' and op != 'mul') and (1 <= n_args <= 2)

        # square root
        if op == 'pow' and isinstance(expr.args[1], sp.Rational) and expr.args[1].p == 1 and expr.args[1].q == 2:
            return ['sqrt'] + self.sympy_to_prefix(expr.args[0])

        # parse children
        parse_list = []
        for i in range(n_args):
            if i == 0 or i < n_args - 1:
                parse_list.append(op)
            parse_list += self.sympy_to_prefix(expr.args[i])

        return parse_list

    def sympy_to_prefix(self, expr):
        """
        Convert a SymPy expression to a prefix one.
        """
        if isinstance(expr, sp.Symbol):
            return [str(expr)]
        elif isinstance(expr, sp.Integer):
            return self.write_int(int(str(expr)))
        elif isinstance(expr, sp.Rational):
            return ['div'] + self.write_int(int(expr.p)) + self.write_int(int(expr.q))
        elif expr == sp.E:
            return ['E']
        elif expr == sp.pi:
            return ['pi']
        elif expr == sp.I:
            return ['I']
        # SymPy operator
        for op_type, op_name in self.SYMPY_OPERATORS.items():
            if isinstance(expr, op_type):
                return self._sympy_to_prefix(op_name, expr)
        # environment function
        for func_name, func in self.functions.items():
            if isinstance(expr, func):
                return self._sympy_to_prefix(func_name, expr)
        # unknown operator
        raise UnknownSymPyOperator(f"Unknown SymPy operator: {expr}")

    def extract_non_constant_subtree(self, expr):
        return extract_non_constant_subtree(expr, self.variables.values())

    def gen_expr(self, rng):
        """
        Generate pairs of (function, derivative) or (function, primitive).
        Start by generating a random function f, and use SymPy to compute f'.
        """
        x = self.variables['x']
        if rng.randint(40) == 0:
            nb_ops = rng.randint(0, 4)
        else:
            nb_ops = rng.randint(4, self.max_ops + 1)

        try:
          # generate an expression and rewrite it,
          # avoid issues in 0 and convert to SymPy
          F_expr = self._generate_expr(nb_ops, self.max_int, rng)
          infix = self.prefix_to_infix(F_expr)
          F = self.infix_to_sympy(infix, evaluate=False)

          # skip constant expressions
          if x not in F.free_symbols:
              return None

          # remove additive constant, re-index coefficients
          f = sp.simplify(F)

          # skip invalid expressions
          if has_inf_nan(f, F):
              return None

          # convert back to prefix
          f_prefix = self.sympy_to_prefix(f)
          F1_prefix = self.sympy_to_prefix(F)
          F2_prefix = self.sympy_to_prefix(sp.expand(f))
          F3_prefix = self.sympy_to_prefix(sp.expand(F))

          return f_prefix, (F1_prefix, F2_prefix, F3_prefix)

        except RuntimeWarning:
          return None
        except TimeoutError:
            return None
        except (ValueErrorExpression, UnknownSymPyOperator, OverflowError, TypeError):
            return None
        except Exception as e:
            return None

In [46]:
DATASET_ENV = CharSPEnvironment(params)


def sympy_to_sequences(expr, env):
  prefix = env.sympy_to_prefix(expr)
  return [env.word2id[w] for w in prefix]


def sequences_to_sympy(x, env):
  x = [env.id2word[w] for w in x]
  x = env.prefix_to_infix(x)
  return env.infix_to_sympy(x)


def prefix_to_sympy(x, env):
  x = env.prefix_to_infix(x)
  return env.infix_to_sympy(x)

## Visualising some random expressions

The the function 'DATASET_ENV.gen_expr' generates random equations in prefix notation (see https://en.wikipedia.org/wiki/Polish_notation for more details).

In [17]:
irreducible_eq, reducible_eqs = DATASET_ENV.gen_expr(np.random)

prefix_to_sympy(irreducible_eq, DATASET_ENV)

x*y + (x + 4)**2

In [19]:
prefix_to_sympy(reducible_eqs[0], DATASET_ENV)

x*y + (x + 4)**2

In [21]:
prefix_to_sympy(reducible_eqs[1], DATASET_ENV)

x**2 + x*y + 8*x + 16

In [22]:
prefix_to_sympy(reducible_eqs[2], DATASET_ENV)

x**2 + x*y + 8*x + 16

# Dataset creation

Here we use python's multiprocessing module to parallelise dataset createtion. \\
The following creates two files of the tokenize equations 'x.pkl' and 'y.pkl'. Where 'y.pkl' contains the simplified versions of the equations found in 'x.pkl'.


In [None]:
from multiprocessing import Process, Manager
from IPython import display
import pickle


_samples_per_process = int(params.dataset_size//params.num_parallel_processes)


def sampling_func(inp, tar):
  np.random.seed(None)
  i = 0
  while i < _samples_per_process:
    x = None
    while x is None:
      x = DATASET_ENV.gen_expr(np.random)
    y, x = x
    y = [DATASET_ENV.word2id[w] for w in y]
    i += 1
    for u in x:
      if len(u) > 1:
        inp.append([DATASET_ENV.word2id[w] for w in u])
        tar.append(y)


print('Creating dataset ...')
start = time.time()

manager = Manager()
inp = manager.list() 
tar = manager.list() 

processes = []
for _ in range(params.num_parallel_processes):
  p = Process(target=sampling_func, args=(inp, tar)) 
  p.start()
  processes.append(p)
for p in processes:
  p.join()

inp = list(inp)
tar = list(tar)

display.clear_output()

print('Saving')
with open(params.dataset_path + 'x.pkl', 'wb') as f:
  pickle.dump(inp, f)
with open(params.dataset_path + 'y.pkl', 'wb') as f:
  pickle.dump(tar, f)

clear_cache()
manager.shutdown()
print('Complete')

Creating dataset ...


## Model Setup
This cell defines the keras architecture of the standard transformer. 
Here we use a variant of the original transformer architecture known as ReZero (see [arXiv:2003.04887](https://arxiv.org/abs/2003.04887) ) for faster training convergence. 

In [29]:
def create_padding_mask(seq):
  seq = tf.cast(tf.math.equal(seq, 0), tf.float32)

  # add extra dimensions to add the padding
  # to the attention logits.
  return seq[:, tf.newaxis, tf.newaxis, :]  # (batch_size, 1, 1, seq_len)


def create_look_ahead_mask(size):
  mask = 1 - tf.linalg.band_part(tf.ones((size, size)), -1, 0)
  return mask  # (seq_len, seq_len)


def create_masks(inp, tar):
  # Encoder padding mask
  enc_padding_mask = create_padding_mask(inp)

  # Used in the 2nd attention block in the decoder.
  # This padding mask is used to mask the encoder outputs.
  dec_padding_mask = create_padding_mask(inp)

  # Used in the 1st attention block in the decoder.
  # It is used to pad and mask future tokens in the input received by 
  # the decoder.
  look_ahead_mask = create_look_ahead_mask(tf.shape(tar)[1])
  dec_target_padding_mask = create_padding_mask(tar)
  combined_mask = tf.maximum(dec_target_padding_mask, look_ahead_mask)

  return enc_padding_mask, combined_mask, dec_padding_mask


def get_angles(pos, i, d_model):
  angle_rates = 1 / np.power(10000, (2 * (i//2)) / np.float32(d_model))
  return pos * angle_rates


def positional_encoding(position, d_model):
  angle_rads = get_angles(np.arange(position)[:, np.newaxis],
                          np.arange(d_model)[np.newaxis, :],
                          d_model)

  # apply sin to even indices in the array; 2i
  angle_rads[:, 0::2] = np.sin(angle_rads[:, 0::2])

  # apply cos to odd indices in the array; 2i+1
  angle_rads[:, 1::2] = np.cos(angle_rads[:, 1::2])

  pos_encoding = angle_rads[np.newaxis, ...]

  return tf.cast(pos_encoding, dtype=tf.float32)


def scaled_dot_product_attention(q, k, v, mask):
  """Calculate the attention weights.
  q, k, v must have matching leading dimensions.
  k, v must have matching penultimate dimension, i.e.: seq_len_k = seq_len_v.
  The mask has different shapes depending on its type(padding or look ahead) 
  but it must be broadcastable for addition.

  Args:
    q: query shape == (..., seq_len_q, depth)
    k: key shape == (..., seq_len_k, depth)
    v: value shape == (..., seq_len_v, depth_v)
    mask: Float tensor with shape broadcastable 
          to (..., seq_len_q, seq_len_k). Defaults to None.

  Returns:
    output, attention_weights
  """

  matmul_qk = tf.matmul(q, k, transpose_b=True)  # (..., seq_len_q, seq_len_k)

  # scale matmul_qk
  dk = tf.cast(tf.shape(k)[-1], tf.float32)
  scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)

  # add the mask to the scaled tensor.
  if mask is not None:
    scaled_attention_logits += (mask * -1e9)  

  # softmax is normalized on the last axis (seq_len_k) so that the scores
  # add up to 1.
  attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)  # (..., seq_len_q, seq_len_k)

  output = tf.matmul(attention_weights, v)  # (..., seq_len_q, depth_v)

  return output, attention_weights


class MultiHeadAttention(tf.keras.layers.Layer):
  def __init__(self, d_model, num_heads):
    super(MultiHeadAttention, self).__init__()
    self.num_heads = num_heads
    self.d_model = d_model

    assert d_model % self.num_heads == 0

    self.depth = d_model // self.num_heads

    self.wq = tf.keras.layers.Dense(d_model)
    self.wk = tf.keras.layers.Dense(d_model)
    self.wv = tf.keras.layers.Dense(d_model)

    self.dense = tf.keras.layers.Dense(d_model)

  def split_heads(self, x, batch_size):
    """Split the last dimension into (num_heads, depth).
    Transpose the result such that the shape is (batch_size, num_heads, seq_len, depth)
    """
    x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
    return tf.transpose(x, perm=[0, 2, 1, 3])

  def call(self, v, k, q, mask):
    batch_size = tf.shape(q)[0]

    q = self.wq(q)  # (batch_size, seq_len, d_model)
    k = self.wk(k)  # (batch_size, seq_len, d_model)
    v = self.wv(v)  # (batch_size, seq_len, d_model)

    q = self.split_heads(q, batch_size)  # (batch_size, num_heads, seq_len_q, depth)
    k = self.split_heads(k, batch_size)  # (batch_size, num_heads, seq_len_k, depth)
    v = self.split_heads(v, batch_size)  # (batch_size, num_heads, seq_len_v, depth)

    # scaled_attention.shape == (batch_size, num_heads, seq_len_q, depth)
    # attention_weights.shape == (batch_size, num_heads, seq_len_q, seq_len_k)
    scaled_attention, attention_weights = scaled_dot_product_attention(
        q, k, v, mask)

    scaled_attention = tf.transpose(scaled_attention, perm=[0, 2, 1, 3])  # (batch_size, seq_len_q, num_heads, depth)

    concat_attention = tf.reshape(scaled_attention, 
                                  (batch_size, -1, self.d_model))  # (batch_size, seq_len_q, d_model)

    output = self.dense(concat_attention)  # (batch_size, seq_len_q, d_model)

    return output, attention_weights


def point_wise_feed_forward_network(d_model, dff):
  return tf.keras.Sequential([
      tf.keras.layers.Dense(dff, activation='relu'),  # (batch_size, seq_len, dff)
      tf.keras.layers.Dense(d_model)  # (batch_size, seq_len, d_model)
  ])


class ReZero(tf.keras.layers.Layer):
    def __init__(self, name):
        super(ReZero, self).__init__(name=name)
        a_init = tf.zeros_initializer()
        self.alpha = tf.Variable(name=self.name + '-alpha',
            initial_value=a_init(shape=(1,), dtype="float32"), trainable=True
        )

    def call(self, inputs):
        return self.alpha * inputs


class EncoderLayer(tf.keras.layers.Layer):
  def __init__(self, d_model, num_heads, dff, rate=0.1):
    super(EncoderLayer, self).__init__()

    self.mha = MultiHeadAttention(d_model, num_heads)
    self.ffn = point_wise_feed_forward_network(d_model, dff)

    #self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
    #self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
    self.rz1 = ReZero(self.name + 'rz1')
    self.rz2 = ReZero(self.name + 'rz2')

    self.dropout1 = tf.keras.layers.Dropout(rate)
    self.dropout2 = tf.keras.layers.Dropout(rate)

  def call(self, x, training, mask):

    attn_output, _ = self.mha(x, x, x, mask)  # (batch_size, input_seq_len, d_model)
    attn_output = self.dropout1(attn_output, training=training)
    out1 = x + self.rz1(attn_output)
    #out1 = self.layernorm1(x + attn_output)  # (batch_size, input_seq_len, d_model)

    ffn_output = self.ffn(out1)  # (batch_size, input_seq_len, d_model)
    ffn_output = self.dropout2(ffn_output, training=training)
    out2 = out1 + self.rz2(ffn_output)
    #out2 = self.layernorm2(out1 + ffn_output)  # (batch_size, input_seq_len, d_model)

    return out2


class DecoderLayer(tf.keras.layers.Layer):
  def __init__(self, d_model, num_heads, dff, rate=0.1):
    super(DecoderLayer, self).__init__()

    self.mha1 = MultiHeadAttention(d_model, num_heads)
    self.mha2 = MultiHeadAttention(d_model, num_heads)

    self.ffn = point_wise_feed_forward_network(d_model, dff)

    #self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
    #self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
    #self.layernorm3 = tf.keras.layers.LayerNormalization(epsilon=1e-6)
    self.rz1 = ReZero(self.name + 'rz1')
    self.rz2 = ReZero(self.name + 'rz2')
    self.rz3 = ReZero(self.name + 'rz3')

    self.dropout1 = tf.keras.layers.Dropout(rate)
    self.dropout2 = tf.keras.layers.Dropout(rate)
    self.dropout3 = tf.keras.layers.Dropout(rate)


  def call(self, x, enc_output, training, 
           look_ahead_mask, padding_mask):
    # enc_output.shape == (batch_size, input_seq_len, d_model)

    attn1, attn_weights_block1 = self.mha1(x, x, x, look_ahead_mask)  # (batch_size, target_seq_len, d_model)
    attn1 = self.dropout1(attn1, training=training)
    out1 = x + self.rz1(attn1)
    #out1 = self.layernorm1(attn1 + x)

    attn2, attn_weights_block2 = self.mha2(
        enc_output, enc_output, out1, padding_mask)  # (batch_size, target_seq_len, d_model)
    attn2 = self.dropout2(attn2, training=training)
    out2 = out1 + self.rz2(attn2)
    #out2 = self.layernorm2(attn2 + out1)  # (batch_size, target_seq_len, d_model)

    ffn_output = self.ffn(out2)  # (batch_size, target_seq_len, d_model)
    ffn_output = self.dropout3(ffn_output, training=training)
    out3 = out2 + self.rz3(ffn_output)
    #out3 = self.layernorm3(ffn_output + out2)  # (batch_size, target_seq_len, d_model)

    return out3, attn_weights_block1, attn_weights_block2


class Encoder(tf.keras.layers.Layer):
  def __init__(self, num_layers, d_model, num_heads, dff,
               maximum_position_encoding, rate=0.1):
    super(Encoder, self).__init__()

    self.d_model = d_model
    self.num_layers = num_layers

    self.pos_encoding = positional_encoding(maximum_position_encoding, 
                                            self.d_model)

    self.enc_layers = [EncoderLayer(d_model, num_heads, dff, rate) 
                       for _ in range(num_layers)]

    self.dropout = tf.keras.layers.Dropout(rate)

  def call(self, x, training, mask):

    seq_len = tf.shape(x)[1]

    #x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))
    x += self.pos_encoding[:, :seq_len, :]

    x = self.dropout(x, training=training)

    for i in range(self.num_layers):
      x = self.enc_layers[i](x, training, mask)

    return x  # (batch_size, input_seq_len, d_model)


class Decoder(tf.keras.layers.Layer):
  def __init__(self, num_layers, d_model, num_heads, dff,
               maximum_position_encoding, rate=0.1):
    super(Decoder, self).__init__()

    self.d_model = d_model
    self.num_layers = num_layers

    self.pos_encoding = positional_encoding(maximum_position_encoding, d_model)

    self.dec_layers = [DecoderLayer(d_model, num_heads, dff, rate) 
                       for _ in range(num_layers)]
    self.dropout = tf.keras.layers.Dropout(rate)

  def call(self, x, enc_output, training, 
           look_ahead_mask, padding_mask):

    seq_len = tf.shape(x)[1]
    attention_weights = {}

    #x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))
    x += self.pos_encoding[:, :seq_len, :]

    x = self.dropout(x, training=training)

    for i in range(self.num_layers):
      x, block1, block2 = self.dec_layers[i](x, enc_output, training,
                                             look_ahead_mask, padding_mask)

      attention_weights['decoder_layer{}_block1'.format(i+1)] = block1
      attention_weights['decoder_layer{}_block2'.format(i+1)] = block2

    # x.shape == (batch_size, target_seq_len, d_model)
    return x, attention_weights


class Transformer(tf.keras.Model):
  def __init__(self, num_layers, d_model, num_heads, dff, input_vocab_size, 
               target_vocab_size, pe_input, pe_target, rate=0.1):
    super(Transformer, self).__init__()

    self.tok_embedding = tf.keras.layers.Embedding(input_vocab_size, d_model)
    self.dec_embedding = tf.keras.layers.Embedding(target_vocab_size, d_model)

    self.tokenizer = Encoder(num_layers, d_model, num_heads, dff,
                             pe_input, rate)
    self.decoder = Decoder(num_layers, d_model, num_heads, dff,
                           pe_target, rate)
    self.final_layer = tf.keras.layers.Dense(target_vocab_size)

  def call(self, inp, tar, training, enc_padding_mask, 
           look_ahead_mask, dec_padding_mask):
    
    inp = self.tok_embedding(inp)
    enc_output = self.tokenizer(inp, training, enc_padding_mask)  # (batch_size, inp_seq_len, d_model)

    # dec_output.shape == (batch_size, tar_seq_len, d_model)
    tar = self.dec_embedding(tar)
    dec_output, attention_weights = self.decoder(
        tar, enc_output, training, look_ahead_mask, dec_padding_mask)

    final_output = self.final_layer(dec_output)  # (batch_size, tar_seq_len, target_vocab_size)

    return final_output, attention_weights

# Load the dataset
Load the datasets of equations. You can create your own or load a pre-computed one from here and load it into the 

In [23]:
SYMBOL_VOCAB_SIZE = DATASET_ENV.n_words
START_TOKEN = SYMBOL_VOCAB_SIZE + 1
END_TOKEN = SYMBOL_VOCAB_SIZE + 2
VOCAB_SIZE = END_TOKEN + 1

In [24]:
class SimplifyDataset:
  def __init__(self, path, inp_data_len=128, tar_data_len=128):
    self.path = path
    self.inp_data_len = inp_data_len
    self.tar_data_len = tar_data_len
    self.num_train_examples = None
    self.num_test_examples = None

  def preprocess_tokens(self, w, maxlen, wrap_start_end=False):
    if wrap_start_end:
      w = [[START_TOKEN] + x + [END_TOKEN] for x in w]

    w = tf.keras.preprocessing.sequence.pad_sequences(w, maxlen=maxlen,
                                                      padding='post')
    return w

  def load_dataset(self, num_examples=None):
    with open(self.path + 'x.pkl', 'rb') as f:
      src = pickle.load(f)
    with open(self.path + 'y.pkl', 'rb') as f:
      tar = pickle.load(f)

    if num_examples is not None:
      src, tar = src[:num_examples], tar[:num_examples]

    return src, tar

  def call(self, buffer_size, batch_size, num_examples=None):
    input, target = self.load_dataset(num_examples)
    input = self.preprocess_tokens(input, self.inp_data_len)
    target = self.preprocess_tokens(target, self.tar_data_len, True)

    input_train, input_val, target_train, target_val = train_test_split(input, target, test_size=0.1)

    self.num_train_examples = len(input_train)
    self.num_test_examples = len(input_val)

    train_dataset = tf.data.Dataset.from_tensor_slices(
        (input_train, target_train))
    train_dataset = train_dataset.shuffle(buffer_size)
    train_dataset = train_dataset.batch(batch_size, drop_remainder=True)
    train_dataset = train_dataset.repeat()

    val_dataset = tf.data.Dataset.from_tensor_slices((input_val, target_val))
    val_dataset = val_dataset.batch(batch_size, drop_remainder=True).repeat()

    return train_dataset, val_dataset


dataset = SimplifyDataset(params.dataset_path)
train_dataset, val_dataset = dataset.call(32000, params.batch_size)

In [None]:
train_dataset

In [64]:
dataset.num_train_examples, dataset.num_test_examples

(253, 29)

In [None]:
VOCAB_SIZE

## Model Setup
This cell contains the functions that builds the transformer model in Keras. 
Here we use a variant of the original transformer architecture known as ReZero (see [arXiv:2003.04887](https://arxiv.org/abs/2003.04887) ) for faster training convergence. 

## Model Design

In [41]:
def create_padding_mask(seq):
  seq = tf.cast(tf.math.equal(seq, 0), tf.float32)

  # add extra dimensions to add the padding
  # to the attention logits.
  return seq[:, tf.newaxis, tf.newaxis, :]  # (batch_size, 1, 1, seq_len)


def create_look_ahead_mask(size):
  mask = 1 - tf.linalg.band_part(tf.ones((size, size)), -1, 0)
  return mask  # (seq_len, seq_len)


def create_masks(inp, tar):
  # Encoder padding mask
  enc_padding_mask = create_padding_mask(inp)

  # Used in the 2nd attention block in the decoder.
  # This padding mask is used to mask the encoder outputs.
  dec_padding_mask = create_padding_mask(inp)

  # Used in the 1st attention block in the decoder.
  # It is used to pad and mask future tokens in the input received by 
  # the decoder.
  look_ahead_mask = create_look_ahead_mask(tf.shape(tar)[1])
  dec_target_padding_mask = create_padding_mask(tar)
  combined_mask = tf.maximum(dec_target_padding_mask, look_ahead_mask)

  return enc_padding_mask, combined_mask, dec_padding_mask


def get_angles(pos, i, d_model):
  angle_rates = 1 / np.power(10000, (2 * (i//2)) / np.float32(d_model))
  return pos * angle_rates


def positional_encoding(position, d_model):
  angle_rads = get_angles(np.arange(position)[:, np.newaxis],
                          np.arange(d_model)[np.newaxis, :],
                          d_model)

  # apply sin to even indices in the array; 2i
  angle_rads[:, 0::2] = np.sin(angle_rads[:, 0::2])

  # apply cos to odd indices in the array; 2i+1
  angle_rads[:, 1::2] = np.cos(angle_rads[:, 1::2])

  pos_encoding = angle_rads[np.newaxis, ...]

  return tf.cast(pos_encoding, dtype=tf.float32)


def scaled_dot_product_attention(q, k, v, mask):
  """Calculate the attention weights.
  q, k, v must have matching leading dimensions.
  k, v must have matching penultimate dimension, i.e.: seq_len_k = seq_len_v.
  The mask has different shapes depending on its type(padding or look ahead) 
  but it must be broadcastable for addition.

  Args:
    q: query shape == (..., seq_len_q, depth)
    k: key shape == (..., seq_len_k, depth)
    v: value shape == (..., seq_len_v, depth_v)
    mask: Float tensor with shape broadcastable 
          to (..., seq_len_q, seq_len_k). Defaults to None.

  Returns:
    output, attention_weights
  """

  matmul_qk = tf.matmul(q, k, transpose_b=True)  # (..., seq_len_q, seq_len_k)

  # scale matmul_qk
  dk = tf.cast(tf.shape(k)[-1], tf.float32)
  scaled_attention_logits = matmul_qk / tf.math.sqrt(dk)

  # add the mask to the scaled tensor.
  if mask is not None:
    scaled_attention_logits += (mask * -1e9)  

  # softmax is normalized on the last axis (seq_len_k) so that the scores
  # add up to 1.
  attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1)  # (..., seq_len_q, seq_len_k)

  output = tf.matmul(attention_weights, v)  # (..., seq_len_q, depth_v)

  return output, attention_weights


class MultiHeadAttention(tf.keras.layers.Layer):
  def __init__(self, d_model, num_heads):
    super(MultiHeadAttention, self).__init__()
    self.num_heads = num_heads
    self.d_model = d_model

    assert d_model % self.num_heads == 0

    self.depth = d_model // self.num_heads

    self.wq = tf.keras.layers.Dense(d_model)
    self.wk = tf.keras.layers.Dense(d_model)
    self.wv = tf.keras.layers.Dense(d_model)

    self.dense = tf.keras.layers.Dense(d_model)

  def split_heads(self, x, batch_size):
    """Split the last dimension into (num_heads, depth).
    Transpose the result such that the shape is (batch_size, num_heads, seq_len, depth)
    """
    x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth))
    return tf.transpose(x, perm=[0, 2, 1, 3])

  def call(self, v, k, q, mask):
    batch_size = tf.shape(q)[0]

    q = self.wq(q)  # (batch_size, seq_len, d_model)
    k = self.wk(k)  # (batch_size, seq_len, d_model)
    v = self.wv(v)  # (batch_size, seq_len, d_model)

    q = self.split_heads(q, batch_size)  # (batch_size, num_heads, seq_len_q, depth)
    k = self.split_heads(k, batch_size)  # (batch_size, num_heads, seq_len_k, depth)
    v = self.split_heads(v, batch_size)  # (batch_size, num_heads, seq_len_v, depth)

    # scaled_attention.shape == (batch_size, num_heads, seq_len_q, depth)
    # attention_weights.shape == (batch_size, num_heads, seq_len_q, seq_len_k)
    scaled_attention, attention_weights = scaled_dot_product_attention(
        q, k, v, mask)

    scaled_attention = tf.transpose(scaled_attention, perm=[0, 2, 1, 3])  # (batch_size, seq_len_q, num_heads, depth)

    concat_attention = tf.reshape(scaled_attention, 
                                  (batch_size, -1, self.d_model))  # (batch_size, seq_len_q, d_model)

    output = self.dense(concat_attention)  # (batch_size, seq_len_q, d_model)

    return output, attention_weights


def point_wise_feed_forward_network(d_model, dff):
  return tf.keras.Sequential([
      tf.keras.layers.Dense(dff, activation='relu'),  # (batch_size, seq_len, dff)
      tf.keras.layers.Dense(d_model)  # (batch_size, seq_len, d_model)
  ])


class ReZero(tf.keras.layers.Layer):
    def __init__(self, name):
        super(ReZero, self).__init__(name=name)
        a_init = tf.zeros_initializer()
        self.alpha = tf.Variable(name=self.name + '-alpha',
            initial_value=a_init(shape=(1,), dtype="float32"), trainable=True
        )

    def call(self, inputs):
        return self.alpha * inputs


class EncoderLayer(tf.keras.layers.Layer):
  def __init__(self, d_model, num_heads, dff, rate=0.1):
    super(EncoderLayer, self).__init__()

    self.mha = MultiHeadAttention(d_model, num_heads)
    self.ffn = point_wise_feed_forward_network(d_model, dff)

    self.rz1 = ReZero(self.name + 'rz1')
    self.rz2 = ReZero(self.name + 'rz2')

    self.dropout1 = tf.keras.layers.Dropout(rate)
    self.dropout2 = tf.keras.layers.Dropout(rate)

  def call(self, x, training, mask):

    attn_output, _ = self.mha(x, x, x, mask)  # (batch_size, input_seq_len, d_model)
    attn_output = self.dropout1(attn_output, training=training)
    out1 = x + self.rz1(attn_output)
    #out1 = self.layernorm1(x + attn_output)  # (batch_size, input_seq_len, d_model)

    ffn_output = self.ffn(out1)  # (batch_size, input_seq_len, d_model)
    ffn_output = self.dropout2(ffn_output, training=training)
    out2 = out1 + self.rz2(ffn_output)
    #out2 = self.layernorm2(out1 + ffn_output)  # (batch_size, input_seq_len, d_model)

    return out2


class DecoderLayer(tf.keras.layers.Layer):
  def __init__(self, d_model, num_heads, dff, rate=0.1):
    super(DecoderLayer, self).__init__()

    self.mha1 = MultiHeadAttention(d_model, num_heads)
    self.mha2 = MultiHeadAttention(d_model, num_heads)

    self.ffn = point_wise_feed_forward_network(d_model, dff)

    self.rz1 = ReZero(self.name + 'rz1')
    self.rz2 = ReZero(self.name + 'rz2')
    self.rz3 = ReZero(self.name + 'rz3')

    self.dropout1 = tf.keras.layers.Dropout(rate)
    self.dropout2 = tf.keras.layers.Dropout(rate)
    self.dropout3 = tf.keras.layers.Dropout(rate)


  def call(self, x, enc_output, training, 
           look_ahead_mask, padding_mask):
    # enc_output.shape == (batch_size, input_seq_len, d_model)

    attn1, attn_weights_block1 = self.mha1(x, x, x, look_ahead_mask)  # (batch_size, target_seq_len, d_model)
    attn1 = self.dropout1(attn1, training=training)
    out1 = x + self.rz1(attn1)

    attn2, attn_weights_block2 = self.mha2(
        enc_output, enc_output, out1, padding_mask)  # (batch_size, target_seq_len, d_model)
    attn2 = self.dropout2(attn2, training=training)
    out2 = out1 + self.rz2(attn2)

    ffn_output = self.ffn(out2)  # (batch_size, target_seq_len, d_model)
    ffn_output = self.dropout3(ffn_output, training=training)
    out3 = out2 + self.rz3(ffn_output)

    return out3, attn_weights_block1, attn_weights_block2


class Encoder(tf.keras.layers.Layer):
  def __init__(self, num_layers, d_model, num_heads, dff,
               maximum_position_encoding, rate=0.1):
    super(Encoder, self).__init__()

    self.d_model = d_model
    self.num_layers = num_layers

    self.pos_encoding = positional_encoding(maximum_position_encoding, 
                                            self.d_model)

    self.enc_layers = [EncoderLayer(d_model, num_heads, dff, rate) 
                       for _ in range(num_layers)]

    self.dropout = tf.keras.layers.Dropout(rate)

  def call(self, x, training, mask):

    seq_len = tf.shape(x)[1]
    x += self.pos_encoding[:, :seq_len, :]

    x = self.dropout(x, training=training)

    for i in range(self.num_layers):
      x = self.enc_layers[i](x, training, mask)

    return x  # (batch_size, input_seq_len, d_model)


class Decoder(tf.keras.layers.Layer):
  def __init__(self, num_layers, d_model, num_heads, dff,
               maximum_position_encoding, rate=0.1):
    super(Decoder, self).__init__()

    self.d_model = d_model
    self.num_layers = num_layers

    self.pos_encoding = positional_encoding(maximum_position_encoding, d_model)

    self.dec_layers = [DecoderLayer(d_model, num_heads, dff, rate) 
                       for _ in range(num_layers)]
    self.dropout = tf.keras.layers.Dropout(rate)

  def call(self, x, enc_output, training, 
           look_ahead_mask, padding_mask):

    seq_len = tf.shape(x)[1]
    attention_weights = {}

    #x *= tf.math.sqrt(tf.cast(self.d_model, tf.float32))
    x += self.pos_encoding[:, :seq_len, :]

    x = self.dropout(x, training=training)

    for i in range(self.num_layers):
      x, block1, block2 = self.dec_layers[i](x, enc_output, training,
                                             look_ahead_mask, padding_mask)

      attention_weights['decoder_layer{}_block1'.format(i+1)] = block1
      attention_weights['decoder_layer{}_block2'.format(i+1)] = block2

    # x.shape == (batch_size, target_seq_len, d_model)
    return x, attention_weights


class Transformer(tf.keras.Model):
  def __init__(self, num_layers, d_model, num_heads, dff, input_vocab_size, 
               target_vocab_size, pe_input, pe_target, rate=0.1):
    super(Transformer, self).__init__()

    self.tok_embedding = tf.keras.layers.Embedding(input_vocab_size, d_model)
    self.dec_embedding = tf.keras.layers.Embedding(target_vocab_size, d_model)

    self.tokenizer = Encoder(num_layers, d_model, num_heads, dff,
                             pe_input, rate)
    self.decoder = Decoder(num_layers, d_model, num_heads, dff,
                           pe_target, rate)
    self.final_layer = tf.keras.layers.Dense(target_vocab_size)

  def call(self, inp, tar, training, enc_padding_mask, 
           look_ahead_mask, dec_padding_mask):
    
    inp = self.tok_embedding(inp)
    enc_output = self.tokenizer(inp, training, enc_padding_mask)  # (batch_size, inp_seq_len, d_model)

    # dec_output.shape == (batch_size, tar_seq_len, d_model)
    tar = self.dec_embedding(tar)
    dec_output, attention_weights = self.decoder(
        tar, enc_output, training, look_ahead_mask, dec_padding_mask)

    final_output = self.final_layer(dec_output)  # (batch_size, tar_seq_len, target_vocab_size)

    return final_output, attention_weights

## Model Creation

In [29]:
K.clear_session()

transformer = Transformer(
    num_layers=params.num_layers,
    d_model=params.d_model,
    num_heads=params.num_heads,
    dff=params.dff,
    input_vocab_size=VOCAB_SIZE,
    target_vocab_size=VOCAB_SIZE, 
    pe_input=params.max_len,
    pe_target=params.max_len,
    rate=params.dropout_rate)


learning_rate = 1e-4
optimizer = tf.keras.optimizers.Adam(learning_rate, beta_1=0.9, beta_2=0.98, 
                                     epsilon=1e-9)

loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction='none')

def loss_function(real, pred):
  mask = tf.math.logical_not(tf.math.equal(real, 0))
  loss_ = loss_object(real, pred)

  mask = tf.cast(mask, dtype=loss_.dtype)
  loss_ *= mask

  return tf.reduce_sum(loss_)/tf.reduce_sum(mask)


def accuracy_function(real, pred):
  accuracies = tf.equal(tf.cast(real, tf.int64), tf.argmax(pred, axis=2))

  mask = tf.math.logical_not(tf.math.equal(real, 0))
  accuracies = tf.math.logical_and(mask, accuracies)

  accuracies = tf.cast(accuracies, dtype=tf.float32)
  mask = tf.cast(mask, dtype=tf.float32)
  return tf.reduce_sum(accuracies)/tf.reduce_sum(mask)


train_loss = tf.keras.metrics.Mean(name='train_loss')
val_loss = tf.keras.metrics.Mean(name='val_loss')

train_accuracy = tf.keras.metrics.Mean(name='train_accuracy')
val_accuracy = tf.keras.metrics.Mean(name='val_accuracy')


train_step_signature = [
    tf.TensorSpec(shape=(None, dataset.inp_data_len), dtype=tf.int32),
    tf.TensorSpec(shape=(None, dataset.tar_data_len), dtype=tf.int32),
]

@tf.function(input_signature=train_step_signature)
def train_step(inp, tar):
  tar_inp = tar[:, :-1]
  tar_real = tar[:, 1:]

  enc_padding_mask, combined_mask, dec_padding_mask = create_masks(inp, tar_inp)

  with tf.GradientTape() as tape:
    predictions, _ = transformer(inp, tar_inp, 
                                 True, 
                                 enc_padding_mask, 
                                 combined_mask, 
                                 dec_padding_mask)
    loss = loss_function(tar_real, predictions)

  gradients = tape.gradient(loss, transformer.trainable_variables)    
  optimizer.apply_gradients(zip(gradients, transformer.trainable_variables))

  train_loss(loss)
  train_accuracy(accuracy_function(tar_real, predictions))


@tf.function(input_signature=train_step_signature)
def val_step(inp, tar):
  tar_inp = tar[:, :-1]
  tar_real = tar[:, 1:]

  enc_padding_mask, combined_mask, dec_padding_mask = create_masks(inp, tar_inp)

  predictions, _ = transformer(inp, tar_inp, 
                                True, 
                                enc_padding_mask, 
                                combined_mask, 
                                dec_padding_mask)
  loss = loss_function(tar_real, predictions)

  val_loss(loss)
  val_accuracy(accuracy_function(tar_real, predictions))

## Training Loop

Automatically saves the model weights that gives the smallest loss on the validation set.

In [37]:
steps_per_epoch = dataset.num_train_examples//params.batch_size
steps_per_val = dataset.num_test_examples//params.batch_size
best_val = np.inf


for epoch in range(params.epochs):
  start = time.time()
  
  train_loss.reset_states()
  train_accuracy.reset_states()

  print(f'Ep: {epoch}, Current Best: {best_val}')
  for (i, (inp, targ)) in enumerate(train_dataset.take(steps_per_epoch)):
    train_step(inp, targ)

    if i % 200 == 0:
      print('Iter {} Loss {:.4f} Train Acc {:.4f}'.format(
            i, train_loss.result(), train_accuracy.result()))
      
  print('Epoch {:.4f} Train Loss {:.4f} Train Acc {:.4f}'.format(
      epoch, train_loss.result(), train_accuracy.result()))
  print('Time taken for 1 epoch: {:.4f} secs'.format(time.time() - start))

  val_loss.reset_states()
  val_accuracy.reset_states()
  for inp, targ in train_dataset.take(steps_per_val):
    val_step(inp, targ)
  validation_loss = val_loss.result()
  validation_acc = val_accuracy.result()

  display.clear_output()
  print("Validation Loss: {:.4f} Vaidation Acc {:.4f}\n".format(
      validation_loss, validation_acc))
  if validation_loss < best_val:
    transformer.save_weights(params.model_path)
    best_val = validation_loss


Validation Loss: 0.0000 Vaidation Acc 0.0000

Ep: 15, Current Best: 0.0
Iter 0 Loss 1.9330 Train Acc 0.3293


KeyboardInterrupt: ignored

# Evaluate

Here we test our model on select equations. Obviously, quality of the output varies depending on dataset and model size. Here an 'InvalidPrefixExpression' is thrown when the model outputs an invalid expression.

In [33]:
def evaluate(eqa, max_length=40):
  # inp sentence is portuguese, hence adding the start and end token
  #eqa = DATASET_ENV.sympy_to_prefix(eqa)
  #eqa = [DATASET_ENV.word2id[w] for w in eqa]
  #eqa = np.array(eqa)
  eqa = tf.convert_to_tensor([eqa])
  #sentence = tokenizers.pt.tokenize(sentence).to_tensor()

  encoder_input = eqa

  # as the target is english, the first word to the transformer should be the
  # english start token.
  output = tf.convert_to_tensor([START_TOKEN])
  output = tf.expand_dims(output, 0)

  for i in range(max_length):
    enc_padding_mask, combined_mask, dec_padding_mask = create_masks(
        encoder_input, output)

    # predictions.shape == (batch_size, seq_len, vocab_size)
    predictions, attention_weights = transformer(encoder_input, 
                                                 output,
                                                 False,
                                                 enc_padding_mask,
                                                 combined_mask,
                                                 dec_padding_mask)

    # select the last word from the seq_len dimension
    predictions = predictions[: ,-1:, :]  # (batch_size, 1, vocab_size)

    predicted_id = tf.argmax(predictions, axis=-1)
    predicted_id = tf.cast(predicted_id, tf.int32)

    # return the result if the predicted_id is equal to the end token
    if predicted_id == END_TOKEN:
      break

    # concatentate the predicted_id to the output which is given to the decoder
    # as its input.
    output = tf.concat([output, predicted_id], axis=-1)

  return output[:, 1:]

In [34]:
from sympy.parsing.sympy_parser import (parse_expr, standard_transformations, 
                                        implicit_multiplication_application,
                                        convert_xor)
from sympy import symbols, diff, init_printing, preorder_traversal, simplify

test_expr = "2x - 2x"
TRANSFORMATIONS = (standard_transformations + (implicit_multiplication_application,) + (convert_xor,))
test_expr = parse_expr(test_expr, transformations=TRANSFORMATIONS, evaluate=False)
test_expr

-2*x + 2*x

model output:

In [39]:
model_output = evaluate(sympy_to_sequences(test_expr, DATASET_ENV)).numpy()[0]
sequences_to_sympy(model_output, DATASET_ENV)

InvalidPrefixExpression: ignored

simplification according to Sympy:

In [99]:
simplify(test_expr)

0

In [100]:
test_expr = "2x^3+5x^3"
TRANSFORMATIONS = (standard_transformations + (implicit_multiplication_application,) + (convert_xor,))
test_expr = parse_expr(test_expr, transformations=TRANSFORMATIONS, evaluate=False)
test_expr

2*x**3 + 5*x**3

model output:

In [102]:
model_output = evaluate(sympy_to_sequences(test_expr, DATASET_ENV)).numpy()[0]
sequences_to_sympy(model_output, DATASET_ENV)

InvalidPrefixExpression: ignored

In [103]:
simplify(test_expr)

7*x**3