In [None]:
import numpy as np
import sympy

# a, b represent parameters; x, y represent data
a, b, x, y = sympy.symbols("a,b,x,y")
# Toy model: each data symbol is multiplied by its own parameter
model = a * x + b * y

Normally one would lambdify the complete model with respect to all free symbols

In [None]:
# Sample data: x counts up from 1 to 5 while y counts down from 5 to 1.
x_input = np.array([1, 2, 3, 4, 5])
y_input = np.array([5, 4, 3, 2, 1])

# Lambdify over all free symbols; the generated function takes (x, y, a, b).
eval_model = sympy.lambdify((x, y, a, b), model)

# Evaluate the complete model on the sample data with a=1 and b=2.
eval_model(x_input, y_input, 1, 2)

When using an `Estimator`, however, the data input is constant, which means that the data that is supplied is always the same. Hence, an optimization in the form of caching can be performed.
The idea is to traverse the `Expr` tree and identify sub-expressions that are independent of the parameters, since the parameters are the only parts that change during the optimization process.

In this example we fix `b` to the value of 2. Hence the sub-expression `b * y` can be precomputed.

In [None]:
# Symbols whose values do not change during the optimization:
# the fixed parameter b and the data symbols x, y
constant_symbols = {b, x, y}
# Filled by find_constant_subexpressions() with the constant sub-expressions it finds
constant_sub_expressions = []

def find_constant_subexpressions(expr) -> bool:
    """Collect the topmost parameter-independent sub-expressions of ``expr``.

    Walks the expression tree depth-first. A leaf (a node with empty ``args``)
    is constant when it is contained in the module-level ``constant_symbols``
    or when it is a plain number. A compound node is constant when all of its
    arguments are constant. Whenever a node turns out to be non-constant, its
    constant compound arguments are appended to the module-level
    ``constant_sub_expressions`` list, so only the topmost constant
    sub-expressions are recorded.

    Args:
        expr: Node of a SymPy-like expression tree, i.e. an object with an
            ``args`` tuple holding its child nodes.

    Returns:
        ``True`` if ``expr`` consists of constant symbols and numbers only.
        In that case nothing is recorded for ``expr`` itself; handling a
        fully constant expression is left to the caller.
    """
    if not expr.args:
        # Numeric literals (such as the 2 in ``2 * y``) never change during an
        # optimization, so they must not spoil an otherwise constant term.
        return expr in constant_symbols or bool(getattr(expr, "is_number", False))

    constant_args = []
    is_constant = True
    # Visit every argument (no short-circuit) so that constant terms nested
    # deeper inside non-constant arguments are still discovered.
    for arg in expr.args:
        if not find_constant_subexpressions(arg):
            is_constant = False
        elif arg.args:
            # Bare symbols and numbers are not worth caching, only compound terms.
            constant_args.append(arg)

    if not is_constant:
        # If this node were constant, its parent would record it as a whole.
        constant_sub_expressions.extend(constant_args)
    return is_constant

# Run the traversal on the model; this fills constant_sub_expressions.
# The return value tells whether the model as a whole is constant.
find_constant_subexpressions(model)

In [None]:
# Show the constant sub-expressions collected by the traversal
constant_sub_expressions

So let's lambdify those constant sub-expressions and replace each of them in the model with a placeholder symbol.

In [None]:
from typing import Dict, Callable

# Maps each placeholder symbol to a function that evaluates the constant
# sub-expression that the placeholder stands for.
constant_sub_functions: Dict[sympy.Symbol, Callable] = {}

for sub_expr in constant_sub_expressions:
    placeholder = sympy.Symbol(f"cached_{str(sub_expr)}")

    # Pass the arguments as a name-sorted tuple rather than a set: a set has no
    # defined iteration order, so the positional signature of the generated
    # function would be unpredictable (SymPy has deprecated set arguments).
    ordered_symbols = tuple(sorted(sub_expr.free_symbols, key=str))
    constant_sub_functions[placeholder] = sympy.lambdify(ordered_symbols, sub_expr)

    # Swap the constant term in the model for its placeholder.
    model = model.subs(sub_expr, placeholder)

In [None]:
# Show the model after the constant sub-expressions were replaced by placeholders
model

In [None]:
# Show the mapping of placeholder symbol -> lambdified constant sub-expression
constant_sub_functions

Now the evaluation of the full model uses the cached values as input

In [None]:
# Argument order: parameter a, data x, then one argument per cached placeholder
# (iterating the dict yields its placeholder keys in insertion order).
cached_model_arguments = (a, x, *constant_sub_functions)
cached_eval_model = sympy.lambdify(cached_model_arguments, model)

In [None]:
# Evaluate the first cached sub-expression once, with the fixed value 2 and the y data.
first_cached_function = list(constant_sub_functions.values())[0]
cached_values = first_cached_function(2, y_input)

In [None]:
# Evaluate the reduced model with a=1, the x data and the precomputed cached values
cached_eval_model(1, x_input, cached_values)

Which is the same as 

In [None]:
# Reference result: the complete model evaluated with a=1 and b=2
eval_model(x_input, y_input, 1, 2)

So this works!