# Sympy: substitution of nested functions

## Introduce the parameters

In [5]:
from sympy import symbols, Function, UnevaluatedExpr

# Two plain symbols and a chain of applied, undefined functions:
# f depends on x, g on (f, y) and k on (g, y).
x, y = symbols('x y')
f = Function('f')(x)
g = Function('g')(f, y)
k = Function('k')(g, y)

# Render each nested expression in turn.
display(f)
display(g)
display(k)



f(x)

g(f(x), y)

k(g(f(x), y), y)

## kwargs concept

In [6]:
def f_func1(**kwargs):
    """Return the fixed expression ``x + 1``.

    Keyword arguments are accepted but ignored; this version only
    introduces the ``**kwargs`` signature.
    """
    expr = x + 1
    return expr
def g_func1(**kwargs):
    """Return the fixed expression ``f + y``.

    Keyword arguments are accepted but ignored; this version only
    introduces the ``**kwargs`` signature.
    """
    expr = f + y
    return expr
def k_func1(**kwargs):
    """Return the fixed expression ``g + y``.

    Keyword arguments are accepted but ignored; this version only
    introduces the ``**kwargs`` signature.
    """
    expr = g + y
    return expr

# Called without arguments, each function simply returns its fixed expression.
display(f_func1())
display(g_func1())
display(k_func1())

x + 1

y + f(x)

y + g(f(x), y)

## introducing substitution

In [7]:
def f_func2(**kwargs):
    """Return ``x + 1`` with the keyword arguments applied via ``subs``.

    ``kwargs`` is handed to ``subs`` unchanged, so its keys are the keyword
    names as strings.
    """
    return (x + 1).subs(kwargs)

def g_func2(**kwargs):
    """Return ``f + y`` with the keyword arguments applied via ``subs``.

    ``kwargs`` is handed to ``subs`` unchanged, so its keys are the keyword
    names as strings.
    """
    return (f + y).subs(kwargs)

def k_func2(**kwargs):
    """Return ``g + y`` with the keyword arguments applied via ``subs``.

    ``kwargs`` is handed to ``subs`` unchanged, so its keys are the keyword
    names as strings.
    """
    return (g + y).subs(kwargs)

# Demonstration of substitution through **kwargs. Keyword names reach subs()
# as strings; a string such as 'x' matches the plain symbol x, but 'f' / 'g'
# do not match the applied functions f(x) / g(f(x), y).
print('--------------------------')
display(f_func2())
display(f_func2(x=2))
print('--------------------------')
display(g_func2())
display(g_func2(f=f_func2()))       # f=f_func2() has no effect: the string key 'f' does not match the applied function f(x)
display(g_func2(f=f_func2(),x=3))   # x inside f(x) becomes 3, but f(x) itself is still not replaced
print('--------------------------')
display(k_func2())
display(k_func2(g=g_func2()))      # g=g_func2() has no effect: the string key 'g' does not match the applied function g(f(x), y)
display(k_func2(g=g_func2(),x=2))  # x inside the nested f(x) becomes 2, but g(...) itself is still not replaced

--------------------------


x + 1

3

--------------------------


y + f(x)

y + f(x)

y + f(3)

--------------------------


y + g(f(x), y)

y + g(f(x), y)

y + g(f(2), y)

### introducing default parameters

In [8]:
def f_func3(x=x, **kwargs):
    """Return ``x + 1`` with ``x`` held unevaluated, then apply ``kwargs``.

    Args:
        x: Expression used for ``x``; defaults to the notebook-level ``x``.
            It is wrapped in ``UnevaluatedExpr`` before being added to 1.
        **kwargs: Passed unchanged to ``subs``.
    """
    held_x = UnevaluatedExpr(x)
    return (held_x + 1).subs(kwargs)

def g_func3(f=f, **kwargs):
    """Return ``f + y`` and apply ``kwargs`` via ``subs``.

    Args:
        f: Expression used for ``f``; defaults to the notebook-level ``f``.
            Unlike ``f_func3`` and ``k_func3``, it is used as-is and not
            wrapped in ``UnevaluatedExpr``.
        **kwargs: Passed unchanged to ``subs``.
    """
    return (f + y).subs(kwargs)

def k_func3(g=g, **kwargs):
    """Return ``g + y`` with ``g`` held unevaluated, then apply ``kwargs``.

    Args:
        g: Expression used for ``g``; defaults to the notebook-level ``g``.
            It is wrapped in ``UnevaluatedExpr`` before being added to ``y``.
        **kwargs: Passed unchanged to ``subs``.
    """
    held_g = UnevaluatedExpr(g)
    return (held_g + y).subs(kwargs)


# Demonstration of default parameters: an expression passed through a named
# parameter (f= for g_func3, g= for k_func3) is used directly, while anything
# that falls into **kwargs is still looked up by its string name.
print('--------------------------')
display(g_func3())
display(g_func3(f=f_func3()))               # f=f_func3() works because f is now a named parameter of g_func3
display(g_func3(f=f_func3(),x=3))           # x is substituted inside the UnevaluatedExpr, so 1 + 3 is left unevaluated
display(g_func3(f=f_func3(),x=3).doit())    # doit() evaluates the held sum
print('--------------------------')
display(k_func3())
display(k_func3(g=g_func3()))                # g=g_func3() works because g is a named parameter; UnevaluatedExpr keeps it grouped
display(k_func3(g=g_func3(),f=f_func3()))    # f=f_func3() does NOT work: f lands in kwargs as the string 'f', which does not match f(x)
display(k_func3(g=g_func3(),f=f_func3(),x=3)) # only x is replaced; f(x) becomes f(3)

--------------------------


y + f(x)

y + 1 + x

y + 1 + 3

y + 4

--------------------------


y + g(f(x), y)

y + (y + f(x))

y + (y + f(x))

y + (y + f(3))

#### Using an auxiliary parameter (does NOT work)

In [9]:
# Auxiliary plain symbols f1 and g1. The functions below use f1/g1 only as
# parameter names, so these globals are not referenced inside their bodies.
f1, g1= symbols('f1 g1')

def f_func4(x=x, **kwargs):
    """Return ``x + 1`` with ``x`` held unevaluated, then apply ``kwargs``.

    Args:
        x: Expression used for ``x``; defaults to the notebook-level ``x``.
            It is wrapped in ``UnevaluatedExpr`` before being added to 1.
        **kwargs: Passed unchanged to ``subs``.
    """
    held_x = UnevaluatedExpr(x)
    return (held_x + 1).subs(kwargs)

def g_func4(f1=f, **kwargs):
    """Return ``f1 + y`` with ``f1`` held unevaluated, then apply ``kwargs``.

    Args:
        f1: Expression standing in for ``f``; defaults to the notebook-level
            ``f``. It is wrapped in ``UnevaluatedExpr`` before the addition.
        **kwargs: Passed unchanged to ``subs``.
    """
    held_f = UnevaluatedExpr(f1)
    return (held_f + y).subs(kwargs)

def k_func4(g1=g, **kwargs):
    """Return ``g1 + y`` with ``g1`` held unevaluated, then apply ``kwargs``.

    Args:
        g1: Expression standing in for ``g``; defaults to the notebook-level
            ``g``. It is wrapped in ``UnevaluatedExpr`` before the addition.
        **kwargs: Passed unchanged to ``subs``.
    """
    held_g = UnevaluatedExpr(g1)
    return (held_g + y).subs(kwargs)

# Same experiment as with the default parameters, using the auxiliary names
# g1/f1: renaming the parameters does not change the outcome.
print('--------------------------')
display(k_func4())
display(k_func4(g1=g_func4()))                 # g1=g_func4() works because g1 is a named parameter of k_func4
display(k_func4(g1=g_func4(),f1=f_func4()))    # f1=f_func4() does NOT work: f1 lands in kwargs as the string 'f1', which does not match f(x)
display(k_func4(g1=g_func4(),f1=f_func4(),x=3)) # only x is replaced; f(x) becomes f(3)

--------------------------


y + g(f(x), y)

y + (y + f(x))

y + (y + f(x))

y + (y + f(3))

## This works too, but skip to the final solution.

In [10]:
from sympy import symbols, Function, UnevaluatedExpr
# Redefine the symbols (now including z) and the applied functions; here g
# and k take (x, y) directly instead of being nested in one another.
x, y, z = symbols('x y z')
f = Function('f')(x)
g = Function('g')(x, y)
k = Function('k')(x, y)

def f_func5(x=x, **kwargs):
    """Return ``x + 1`` with ``x`` held unevaluated, then apply ``kwargs``.

    Args:
        x: Expression used for ``x``; defaults to the notebook-level ``x``.
            It is wrapped in ``UnevaluatedExpr`` before being added to 1.
        **kwargs: Passed unchanged to ``subs``.
    """
    held_x = UnevaluatedExpr(x)
    return (held_x + 1).subs(kwargs)

def g_func5(f=f, **kwargs):
    """Return ``f + y`` and apply ``kwargs`` via ``subs``.

    Args:
        f: Expression used for ``f``; defaults to the notebook-level ``f``.
        **kwargs: Passed unchanged to ``subs``.
    """
    return (f + y).subs(kwargs)

def k_func5(g=g, **kwargs):
    """Return ``g + y + z`` and apply ``kwargs`` via ``subs``.

    Args:
        g: Expression used for ``g``; defaults to the notebook-level ``g``.
        **kwargs: Passed unchanged to ``subs``.
    """
    return (g + y + z).subs(kwargs)

# Nested dictionary, built in two passes: the first pass stores the default
# expressions of g and f; the second pass feeds them back in, so the stored
# g is already expressed through f's expression.
# (The former `db={}` initialisation was a dead store and has been dropped.)
db = {'g': g_func5(), 'f': f_func5()}
db = {'g': g_func5(**db), 'f': f_func5(**db)}

display(k_func5())
display(k_func5(g=g_func5()))
display(k_func5(g=g_func5(**db), f=f_func5(**db)))
# f=z falls into kwargs under the string key 'f'; only x=3 changes the result.
display(k_func5(g=g_func5(**db), f=z, x=3))
display(k_func5(**db))
display(k_func5(**db, x=3))

y + z + g(x, y)

2*y + z + f(x)

2*y + z + 1 + x

2*y + z + 1 + 3

2*y + z + 1 + x

2*y + z + 1 + 3

# Final solution (Works !!!)
- Nested-function substitution was not working because keyword arguments arrive inside the function as a dictionary (`kwargs`) whose keys are always strings. `subs` can match a string key to a plain symbol, but not to an applied function such as `f(x)`, so the function is never replaced. The key is to convert each dictionary key back into the object it names (e.g. with `eval()`) so that `subs` recognizes it. This simplifies the problem.
    
- Also, the unevaluated expression won't be needed inside the functions, as it can be defined once, globally.
- The default values won't be needed either.

In [3]:
from sympy import symbols, Function, UnevaluatedExpr
x, y, z = symbols('x y z')
# f, g and k are created while x is still the plain symbol.
f = Function('f')(x)
g = Function('g')(x, y)
k = Function('k')(x, y)
# Rebind x once, globally, to its unevaluated form; the functions below read
# this global instead of wrapping x themselves.
x = UnevaluatedExpr(x)

def f_func6(**kwargs):
    """Return ``x + 1`` with the keyword substitutions applied.

    Each keyword name is resolved to the notebook-level object of the same
    name (for example ``f`` -> the applied function ``f(x)``), and that
    object is used as the key passed to ``subs``.

    The names are looked up in ``globals()`` rather than passed to
    ``eval``: the lookup is all that is needed, and it never executes
    arbitrary code if a key string arrives through ``**some_dict``.

    Args:
        **kwargs: ``name=replacement`` pairs; every name must be defined at
            notebook (module) level.

    Returns:
        The expression ``x + 1`` after substitution.

    Raises:
        NameError: If a keyword name is not defined at module level (the
            same exception type an unknown name produced under ``eval``).
    """
    namespace = globals()
    try:
        replacements = {namespace[name]: value for name, value in kwargs.items()}
    except KeyError as err:
        raise NameError("name %r is not defined" % (err.args[0],)) from None
    return (x + 1).subs(replacements)

def g_func6(**kwargs):
    """Return ``f + y`` with the keyword substitutions applied.

    Each keyword name is resolved to the notebook-level object of the same
    name (for example ``f`` -> the applied function ``f(x)``), and that
    object is used as the key passed to ``subs``.

    The names are looked up in ``globals()`` rather than passed to
    ``eval``: the lookup is all that is needed, and it never executes
    arbitrary code if a key string arrives through ``**some_dict``.

    Args:
        **kwargs: ``name=replacement`` pairs; every name must be defined at
            notebook (module) level.

    Returns:
        The expression ``f + y`` after substitution.

    Raises:
        NameError: If a keyword name is not defined at module level (the
            same exception type an unknown name produced under ``eval``).
    """
    namespace = globals()
    try:
        replacements = {namespace[name]: value for name, value in kwargs.items()}
    except KeyError as err:
        raise NameError("name %r is not defined" % (err.args[0],)) from None
    return (f + y).subs(replacements)

def k_func6(**kwargs):
    """Return ``g + y + z`` with the keyword substitutions applied.

    Each keyword name is resolved to the notebook-level object of the same
    name (for example ``g`` -> the applied function ``g(x, y)``), and that
    object is used as the key passed to ``subs``.

    The names are looked up in ``globals()`` rather than passed to
    ``eval``: the lookup is all that is needed, and it never executes
    arbitrary code if a key string arrives through ``**some_dict``.

    Args:
        **kwargs: ``name=replacement`` pairs; every name must be defined at
            notebook (module) level.

    Returns:
        The expression ``g + y + z`` after substitution.

    Raises:
        NameError: If a keyword name is not defined at module level (the
            same exception type an unknown name produced under ``eval``).
    """
    namespace = globals()
    try:
        replacements = {namespace[name]: value for name, value in kwargs.items()}
    except KeyError as err:
        raise NameError("name %r is not defined" % (err.args[0],)) from None
    return (g + y + z).subs(replacements)

# Dictionary (does not need to be nested anymore): a single pass is enough.
# (The former `db={}` initialisation was a dead store and has been dropped.)
db = {'g': g_func6(), 'f': f_func6()}

display(k_func6())
display(k_func6(g=g_func6()))
display(k_func6(g=g_func6(**db), f=f_func6(**db)))
display(k_func6(g=g_func6(**db), f=z, x=3))
display(k_func6(**db))
display(k_func6(**db, x=3))

y + z + g(x, y)

2*y + z + f(x)

2*y + z + 1 + x

2*y + z + 4

2*y + z + 1 + x

2*y + z + 4