In [1]:
import os
import time
import torch
import pickle
import sympy as sp
import numpy as np
from scipy.optimize import least_squares
from ssde import PDERecursionSolver
from ssde.execute import cython_recursion_execute as ce
from ssde.execute import python_execute as pe
from ssde.program import Program
from ssde.utils import jupyter_logging, cube
from ssde.const import make_const_optimizer
from ssde.pde import function_map
from ssde.task.recursion.recursion import make_regression_metric
from ssde.gradients import hessian, jacobian
from ssde.library import PlaceholderConstant


# Global configuration: RNG seed, domain bounds and the manufactured solution.
np.random.seed(0)
LEFT, RIGHT = -1, 1      # bounds of the spatial coordinates x1, x2, x3
TLEFT, TRIGHT = 0, 1     # bounds of the time coordinate t
# exist_ok=True replaces the check-then-create race and makes the cell
# idempotent; 'models' is created too because later cells pickle traversals
# into it.
os.makedirs('logs', exist_ok=True)
os.makedirs('models', exist_ok=True)
# Manufactured solution; the PDE source term is derived from it symbolically.
SOLUTION = sp.sympify("2.5*x1**4 - 1.3*x2**3 + 0.5*x3**2 - 1.7*t")

def calculate_source(solution):
    '''Derive the heat-equation source term for a manufactured solution.

    The source is ``f = du/dt - sum_i d^2u/dx_i^2``: the symbol named ``t``
    is differentiated once, every other free symbol twice with a minus sign.

    Parameters
    ----------
    solution : sympy.Expr
        Symbolic expression of the exact solution ``u``.

    Returns
    -------
    solution_func, source_func : callable
        NumPy-vectorised lambdas of ``u`` and ``f``. Both take their arguments
        positionally, in the order given by ``real_params``.
    real_params : list of str
        Names of the free symbols of ``solution``, sorted alphabetically.
    '''
    # Use the expression's own Symbol objects. Differentiating w.r.t. a name
    # string re-creates a plain Symbol, which would not match symbols carrying
    # assumptions. The former exec() call bound names that were never read, and
    # it executed text taken from the expression, so it is gone.
    symbols = sorted(solution.free_symbols, key=lambda s: s.name)
    real_params = [s.name for s in symbols]
    print(f'real_params:{real_params}')
    source = 0
    for sym in symbols:
        if sym.name == 't':
            source += sp.diff(solution, sym)
        else:
            source -= sp.diff(solution, sym, 2)
    print(f'source:{source}')
    solution_func = sp.lambdify(symbols, solution, modules='numpy')
    source_func = sp.lambdify(symbols, source, modules='numpy')
    return solution_func, source_func, real_params

def replace_nxexpr(traversals, index=0, new_traversal=None):
    """
    Replace Nxexpr in the traversal with the corresponding expression recursively.

    Every ``Nxexpr`` token of ``traversals[index]`` is substituted by the
    (recursively expanded) traversal of the next level. The last level is
    used verbatim.

    Parameters
    ----------
    traversals : list
        The list of traversals of the single variables, outermost first.
    index : int
        The index of current traversal.
    new_traversal : list, optional
        Prefix of already-expanded tokens. It is copied, never modified.

    Returns
    -------
    list
        ``new_traversal`` followed by the expanded tokens of level ``index``.
        An empty ``traversals`` yields just the prefix. Token objects are not
        copied: a sub-traversal inserted for several ``Nxexpr`` tokens shares
        the same token instances.
    """
    # Copy so that a caller-supplied prefix list is never mutated in place.
    result = [] if new_traversal is None else list(new_traversal)

    if index >= len(traversals):
        # Nothing left to expand (e.g. an empty ``traversals``).
        return result

    if index + 1 == len(traversals):
        # Innermost level: no further substitution.
        return result + list(traversals[index])

    for token in traversals[index]:
        if token.name != 'Nxexpr':
            result.append(token)
        else:
            result.extend(replace_nxexpr(traversals, index + 1))
    return result

def opti_nxexpr(traversal, x, y):
    """Build a residual function for fitting the ``Nxexpr`` placeholder values.

    The returned callable writes its argument, reshaped to a column, into
    every ``Nxexpr`` token of ``traversal`` (mutating those tokens). It then
    evaluates the traversal on ``x`` with ``ce`` and returns the flattened
    residual ``y_hat - y``, as expected by ``scipy.optimize.least_squares``.
    """
    def opti_consts(nxexpr):
        for placeholder in (tok for tok in traversal if tok.name == 'Nxexpr'):
            placeholder.value = nxexpr.reshape(-1, 1)
        residual = ce(traversal, x) - y
        return residual.ravel()
    return opti_consts

@jupyter_logging("logs/single_var.log")
def solve_single_var(X_input, y_input, var_index, config_file, diff=None):
    """Fit a single-variable expression with ``PDERecursionSolver``.

    Parameters
    ----------
    X_input, y_input :
        Training inputs/labels, forwarded unchanged to ``model.fit``.
    var_index : int
        Index of the variable being identified. It is passed as
        ``start_n_var`` and used in the printed messages.
    config_file : str
        Path of the solver configuration file.
    diff : callable, optional
        Forwarded as ``diff`` to ``model.fit`` (``u_t_forward`` for the
        t direction in this notebook).

    Returns
    -------
    traversal, expr, model
        Traversal and sympy expression of ``program_`` on the fit result,
        and the solver instance.

    Notes
    -----
    Callers also pass ``log_path=...``. It is presumably consumed by the
    ``jupyter_logging`` decorator; confirm against ``ssde.utils``.
    """
    model = PDERecursionSolver(config_file)
    # perf_counter is monotonic, so the timing is immune to wall-clock changes.
    start_time = time.perf_counter()
    fitted = model.fit(X_input, y_input, start_n_var=var_index, diff=diff)
    print(f'Time used: {time.perf_counter() - start_time}')
    traversal = fitted.program_.traversal
    expr = fitted.program_.sympy_expr
    print(f'Identified var x{var_index}\'s parametric expression:')
    print(expr)
    print(f'Identified var x{var_index}\'s traversal:')
    print(traversal)
    return traversal, expr, model


# Numeric callables for the exact solution u and the derived heat-equation
# source term. Both take their arguments in the order of `real_params`
# (['t', 'x1', 'x2', 'x3'] for this SOLUTION, per the cell output).
solution_func, source_func, real_params = calculate_source(SOLUTION)

  from .autonotebook import tqdm as notebook_tqdm


real_params:['t', 'x1', 'x2', 'x3']
source:-30.0*x1**2 + 7.8*x2 - 2.7


## Expression of `t` 

In [2]:
# Training data for the t direction: u sampled on n_tbc_dim1 time levels at the
# first n_x123bc_dim1 spatial points returned by cube(LEFT, RIGHT, ...).
n_tbc_dim1 = 10
n_x123bc_dim1 = 4

t_bc = np.linspace(TLEFT, TRIGHT, n_tbc_dim1)
t_bc_dim1 = t_bc[:, None]
X123bc = cube(LEFT, RIGHT, n_x123bc_dim1)
X1bc, X2bc, X3bc = (X123bc[:, k] for k in range(3))
# 'ij' grid: axis 0 walks over time levels, axis 1 over spatial point indices.
X, yz_inx = np.meshgrid(t_bc, np.arange(n_x123bc_dim1), indexing='ij')
coordinates = np.column_stack(
    [X.ravel()] + [col[yz_inx].ravel() for col in (X1bc, X2bc, X3bc)]
)

# One row per time level, one column per spatial point.
y_bc_dim1 = solution_func(*coordinates.T).reshape(-1, n_x123bc_dim1)

In [3]:
# Slot 0 stays None; slot 1 carries the t-direction samples and labels.
X_input, y_input = [None, t_bc_dim1], [None, y_bc_dim1]

def u_t_forward(y, t):
    """Operator passed as ``diff`` when fitting the t direction.

    Returns entry ``(i=0, j=0)`` of ``jacobian(y, t)``, presumably du/dt;
    confirm against ``ssde.gradients``.
    """
    du_dt = jacobian(y, t, i=0, j=0)
    return du_dt
# Fit the t-direction skeleton, resetting the log file so it only holds this run.
LOG_PATH = 'logs/3dheat_t.log'
config_file = 'configs/config_heat_gp.json'
if os.path.exists(LOG_PATH):
    os.remove(LOG_PATH)
t_traversal, t_expr, t_model = solve_single_var(
    X_input, y_input, 1, config_file,
    log_path=LOG_PATH, diff=u_t_forward,
)

t_model.save('models/3dheat_t_model')
print('t model successfully saved')

# The fit ran on a single input column that the skeleton names 'x1';
# relabel those tokens as time before saving.
for token in (tok for tok in t_traversal if tok.name == 'x1'):
    token.name = 't'
with open('models/3dheat_t_traversal.pkl', 'wb') as f:
    pickle.dump(t_traversal, f)
print('t traversal successfully saved')

Saved Trainer state to models/3dheat_t_model/trainer.json.
t model successfully saved
t traversal successfully saved


## Expression of `x1`

In [11]:
# Sample points for recovering the Nxexpr values (the labels of the x1
# direction) at t = TLEFT: a dense x1 line crossed with n_x23bc random
# (x2, x3) pairs.
n_tbc, n_x1bc, n_x23bc = 1, 31, 4
t_bc = np.array([TLEFT])
X1_bc = np.linspace(LEFT, RIGHT, n_x1bc)[:, np.newaxis]
X_bc_dim2 = X1_bc
X23_bc = np.random.uniform(LEFT, RIGHT, (n_x23bc, 2))
X2_bc, X3_bc = X23_bc[:, 0], X23_bc[:, 1]
# 'ij' grid of shape (n_tbc, n_x1bc, n_x23bc), flattened row-major below.
T, X1, x23_inx = np.meshgrid(t_bc, X1_bc, np.arange(n_x23bc), indexing='ij')
x1_points = np.column_stack(
    (T.ravel(), X1.ravel(), X2_bc[x23_inx].ravel(), X3_bc[x23_inx].ravel())
)
y_bc = solution_func(*x1_points.T).reshape(-1, 1)

In [12]:
# Reload the saved t skeleton (a trusted local artefact: never unpickle
# untrusted files) and fit one Nxexpr value per sample point.
with open('models/3dheat_t_traversal.pkl', 'rb') as f:
    t_traversal = pickle.load(f)

opti_x1expr = opti_nxexpr(t_traversal, x1_points, y_bc)
consts = np.ones(len(x1_points))
res = least_squares(opti_x1expr, consts, method='lm')

# Regroup the fitted values: one row per x1 point, one column per
# (t level, (x2, x3) pair).
y_bc_reshaped = res.x.reshape(n_tbc, n_x1bc, n_x23bc, 1)
y_bc_transposed = y_bc_reshaped.swapaxes(0, 1)
y_bc_dim2 = y_bc_transposed.reshape(n_x1bc, -1)

# Optional sanity check (assumes the t skeleton reduces to Nxexpr at t = 0):
# consts_real = np.concatenate(
#     [2.5 * X1_bc**4 - 1.3 * X2_bc[i]**3 + 0.5 * X3_bc[i]**2 for i in range(n_x23bc)],
#     axis=1)
# print('abs error:', np.abs(consts_real - y_bc_dim2).mean())

abs error: 0.0


In [13]:
X_input, y_input = [None, X_bc_dim2], [None, y_bc_dim2]

# Fit the x1-direction skeleton (variable index 1), starting from a fresh log.
LOG_PATH = 'logs/3dheat_x1.log'
config_file = 'configs/config_heat_gp.json'
if os.path.exists(LOG_PATH):
    os.remove(LOG_PATH)
x1_traversal, x1_expr, x1_model = solve_single_var(
    X_input, y_input, 1, config_file, log_path=LOG_PATH,
)
x1_model.save('models/3dheat_x1_model')
print('x1 model successfully saved')
with open('models/3dheat_x1_traversal.pkl', 'wb') as f:
    pickle.dump(x1_traversal, f)
print('x1 traversal successfully saved')

Saved Trainer state to models/3dheat_x1_model/trainer.json.
x1 model successfully saved
x1 traversal successfully saved


In [2]:
# Stitch the saved skeletons into one traversal: each Nxexpr of the t skeleton
# is replaced by the x1 skeleton (see replace_nxexpr).
skeleton_paths = ['models/3dheat_t_traversal.pkl']
skeleton_paths += [f'models/3dheat_x{i}_traversal.pkl' for i in range(1, 2)]
traversals = []
for path in skeleton_paths:
    with open(path, 'rb') as f:
        traversals.append(pickle.load(f))
new_traversal = replace_nxexpr(traversals)

# Point every 'x<k>' variable token at column k of the (t, x1, x2, x3)
# coordinates; other tokens (including 't') are left untouched.
for token in new_traversal:
    if token.input_var is None or token.name[0] != 'x':
        continue
    token.input_var = int(token.name[1])

# To optimize the constants more accurately, evaluate the generated skeleton on
# more boundary points: a (t, x1) grid on the edge x2 = LEFT, x3 = RIGHT.
n_t, n_x1bc = 10, 21
t_refine = np.linspace(TLEFT, TRIGHT, n_t)
X1_refine = np.linspace(LEFT, RIGHT, n_x1bc)
X2_refine = np.array([LEFT])
X3_refine = np.array([RIGHT])
# Stacking the 'xy' meshgrid and transposing puts the coordinate axis last,
# so every row of `coordinates` is one (t, x1, x2, x3) point.
coordinates = np.array(
    np.meshgrid(t_refine, X1_refine, X2_refine, X3_refine)
).T.reshape(-1, 4)
y_refine = solution_func(*coordinates.T).reshape(-1, 1)

In [3]:
# Refine the numeric constants of the stitched skeleton on the boundary grid.
# Nxexpr placeholders (the part of the solution not identified yet) are fitted
# as free scalars initialised to 1 during this step, and their previous values
# are restored afterwards.
const_pos = [i for i, tok in enumerate(new_traversal) if isinstance(tok, PlaceholderConstant)]

# Save each placeholder's own original value, keyed by token identity. A single
# `temp` would hand every placeholder the value of the last one seen, and would
# capture the already-reset 1 when one token object occurs at several positions.
saved_nxexpr = {}  # id(token) -> (token, original value)
for i in const_pos:
    token = new_traversal[i]
    if token.name == 'Nxexpr':
        saved_nxexpr.setdefault(id(token), (token, token.value))
        token.value = 1
ini_const = np.array([new_traversal[i].value for i in const_pos], dtype=np.float64).ravel()


def refine_consts(consts):
    """Residual of the skeleton on ``coordinates`` for the trial constants."""
    for i, j in enumerate(const_pos):
        new_traversal[j].value = consts[i].reshape(-1, 1)
    y_hat = ce(new_traversal, coordinates)
    return (y_hat - y_refine).ravel()


res = least_squares(refine_consts, ini_const, method='lm')
consts = res.x
for i, j in enumerate(const_pos):
    if new_traversal[j].name != 'Nxexpr':
        new_traversal[j].value = consts[i].reshape(-1)
# Put the placeholders back to the values they had before refinement.
for token, original_value in saved_nxexpr.values():
    token.value = original_value

test_p = Program()
test_p.traversal = new_traversal
sym_expr = sp.simplify(sp.N(sp.expand(sp.sympify(test_p.sympy_expr)), 4))
print(f'Identified solution: {sym_expr}')

with open('models/3dheat_dim2_traversal.pkl', 'wb') as f:
    pickle.dump(new_traversal, f)

Identified solution: Nxexpr - 1.7*t + 2.5*x1**4


## Expression of `x2`

In [16]:
# Sample points for recovering the Nxexpr values of the dim-2 skeleton (the
# labels of the x2 direction): t and x1 fixed at their lower bounds, a dense
# x2 line crossed with n_x3bc random x3 values.
n_tbc, n_x1bc, n_x2bc, n_x3bc = 1, 1, 31, 4
t_bc = np.array([TLEFT])
X1_bc = np.array([LEFT])
X2_bc = np.linspace(LEFT, RIGHT, n_x2bc)[:, np.newaxis]
X3_bc = np.random.uniform(LEFT, RIGHT, (n_x3bc, 1))
t, X1, X2, X3 = np.meshgrid(t_bc, X1_bc, X2_bc, X3_bc, indexing='ij')

# The name x1_points is kept because the following cells read it.
x1_points = np.column_stack([grid.ravel() for grid in (t, X1, X2, X3)])
X_bc_dim3 = np.hstack((np.tile(X1_bc, (n_x2bc, 1)), X2_bc))
y_bc = solution_func(*x1_points.T).reshape(-1, 1)

In [17]:
# Fit one Nxexpr value per sample point of the dim-2 skeleton. Rows of
# y_bc_dim3 follow x2 and columns follow the sampled x3 values.
opti_x2expr = opti_nxexpr(new_traversal, x1_points, y_bc)
consts = np.full(len(x1_points), 1.0)
res = least_squares(opti_x2expr, consts, method='lm')
y_bc_dim3 = np.reshape(res.x, (-1, n_x3bc))
# Optional sanity check (assumes the dim-2 skeleton reduces to Nxexpr plus
# terms in t and x1 only):
# consts_real = np.concatenate(
#     [-1.3 * X2_bc**3 + 0.5 * X3_bc[i]**2 for i in range(n_x3bc)], axis=1)
# print('abs error:', np.abs(consts_real - y_bc_dim3).mean())

In [None]:
X_input, y_input = [None, X_bc_dim3], [None, y_bc_dim3]

# Fit the x2-direction skeleton (variable index 2), starting from a fresh log.
LOG_PATH = 'logs/3dheat_x2.log'
config_file = 'configs/config_heat_gp.json'
if os.path.exists(LOG_PATH):
    os.remove(LOG_PATH)
x2_traversal, x2_expr, x2_model = solve_single_var(
    X_input, y_input, 2, config_file, log_path=LOG_PATH,
)
x2_model.save('models/3dheat_x2_model')
print('x2 model successfully saved')
with open('models/3dheat_x2_traversal.pkl', 'wb') as f:
    pickle.dump(x2_traversal, f)
print('x2 traversal successfully saved')

In [5]:
# Stitch the dim-2 skeleton with the x2 skeleton (its Nxexpr is replaced by the
# x2 traversal), then refine all constants on a boundary grid.
with open('models/3dheat_dim2_traversal.pkl', 'rb') as f:
    dim2_traversal = pickle.load(f)
with open('models/3dheat_x2_traversal.pkl', 'rb') as f:
    x2_traversal = pickle.load(f)
new_traversal = replace_nxexpr([dim2_traversal, x2_traversal])

# The remaining Nxexpr placeholders are fitted as free scalars initialised to
# 1 during refinement. Remember each one's own value, keyed by token identity,
# so it can be restored afterwards. A single `temp` would hand every
# placeholder the last one's value.
saved_nxexpr = {}  # id(token) -> (token, original value)
for token in new_traversal:
    if token.name == 'Nxexpr':
        saved_nxexpr.setdefault(id(token), (token, token.value))
        token.value = 1
    # Point 'x<k>' tokens at column k of the (t, x1, x2, x3) coordinates.
    if token.input_var is not None and token.name[0] == 'x':
        token.input_var = int(token.name[1])

# To optimize the constants more accurately, evaluate the generated skeleton on
# more boundary points: a (t, x1, x2) grid on the face x3 = RIGHT.
n_t = 10
n_x1bc = 11
n_x2bc = 11
t_refine = np.linspace(TLEFT, TRIGHT, n_t)
X1_refine = np.linspace(LEFT, RIGHT, n_x1bc)
X2_refine = np.linspace(LEFT, RIGHT, n_x2bc)
X3_refine = np.array([RIGHT])
coordinates = np.array(np.meshgrid(t_refine, X1_refine, X2_refine, X3_refine)).T.reshape(-1, 4)
y_refine = solution_func(*coordinates.T).reshape(-1, 1)

# Optimize the constants.
const_pos = [i for i, tok in enumerate(new_traversal) if isinstance(tok, PlaceholderConstant)]
ini_const = np.array([new_traversal[i].value for i in const_pos], dtype=np.float64).ravel()


def refine_consts(consts):
    """Residual of the skeleton on ``coordinates`` for the trial constants."""
    for i, j in enumerate(const_pos):
        new_traversal[j].value = consts[i].reshape(-1, 1)
    y_hat = ce(new_traversal, coordinates)
    return (y_hat - y_refine).ravel()


res = least_squares(refine_consts, ini_const, method='lm')
consts = res.x
for i, j in enumerate(const_pos):
    if new_traversal[j].name != 'Nxexpr':
        new_traversal[j].value = consts[i].reshape(-1)
# Put the placeholders back to the values they had before refinement.
for token, original_value in saved_nxexpr.values():
    token.value = original_value

test_p = Program()
test_p.traversal = new_traversal
sym_expr = sp.simplify(sp.N(sp.expand(sp.sympify(test_p.sympy_expr)), 4))
print(f'Identified solution: {sym_expr}')
with open('models/3dheat_dim3_traversal.pkl', 'wb') as f:
    pickle.dump(new_traversal, f)

Identified solution: Nxexpr - 1.7*t + 2.5*x1**4 - 1.3*x2**3


## Expression of `x3`

In [20]:
# Compute the value of const (actually the label of x2)
n_tbc = 1
n_x1bc = 1
n_x2bc = 1
n_x3bc = 31
t_bc = np.array([TLEFT])
X1_bc = np.array([LEFT])
X2_bc = np.array([LEFT])
X3_bc = np.linspace(LEFT, RIGHT, n_x3bc).reshape(-1,1)
T, X1, X2, X3 = np.meshgrid(t_bc, X1_bc, X2_bc, X3_bc, indexing='ij')

x1_points = np.stack([T.ravel(), X1.ravel(), X2.ravel(), X3.ravel()], axis=-1)
X_bc_dim4 = x1_points[:,1:]
y_bc = solution_func(*x1_points.T).reshape(-1, 1)
dim3_traversal = new_traversal
opti_x3expr = opti_nxexpr(dim3_traversal, x1_points, y_bc)
consts = np.ones(x1_points.shape[0])
res = least_squares(opti_x3expr, consts, method='lm')
y_bc_dim4 = res.x.reshape(-1,1)

In [21]:
# Fit the x3-direction skeleton (variable index 3), starting from a fresh log.
LOG_PATH = 'logs/3dheat_x3.log'
if os.path.exists(LOG_PATH):
    os.remove(LOG_PATH)
config_file = 'configs/config_heat_gp.json'
X_input, y_input = [None, X_bc_dim4], [None, y_bc_dim4]
x3_traversal, x3_expr, x3_model = solve_single_var(
    X_input, y_input, 3, config_file, log_path=LOG_PATH,
)
x3_model.save('models/3dheat_x3_model')
print('x3 model successfully saved')
with open('models/3dheat_x3_traversal.pkl', 'wb') as f:
    pickle.dump(x3_traversal, f)
print('x3 traversal successfully saved')

Saved Trainer state to models/3dheat_x3_model/trainer.json.
x3 model successfully saved
x3 traversal successfully saved


# Closed-form solution (CFS) of the 3D heat equation

In [2]:
# Final stitch: the dim-3 skeleton's Nxexpr is replaced by the x3 skeleton.
with open('models/3dheat_dim3_traversal.pkl', 'rb') as f:
    dim3_traversal = pickle.load(f)
with open('models/3dheat_x3_traversal.pkl', 'rb') as f:
    x3_traversal = pickle.load(f)
new_traversal = replace_nxexpr([dim3_traversal, x3_traversal])

# Point 'x<k>' tokens at column k of the (t, x1, x2, x3) coordinates.
for token in new_traversal:
    if token.input_var is not None and token.name[0] == 'x':
        token.input_var = int(token.name[1])

# Trainable torch copies of the current 'const' values, in traversal order.
ini_consts = [
    torch.tensor(tok.value, requires_grad=True)
    for tok in new_traversal
    if tok.name == 'const'
]

In [3]:
def sample_heat3d(n_X, n_t, n_X_bc, n_t_bc, n_ic, seed=0):
    '''Sample collocation points for refining the 3-D heat solution.

    Parameters
    ----------
    n_X : int
        Number of random interior points in space.
    n_t : int
        Number of random time instants paired with every interior point.
    n_X_bc : int
        Number of spatial boundary points (the first ``n_X_bc`` rows of
        ``cube(LEFT, RIGHT, n_X_bc)``).
    n_t_bc : int
        Number of random time instants paired with every boundary point.
    n_ic : int
        Number of random spatial points for the initial condition t = TLEFT.
    seed : int
        Seed for NumPy's global RNG, reseeded on every call.

    Returns
    -------
    X_combine : ndarray, shape (n_t*n_X + n_t_bc*n_X_bc + n_ic, 4)
        Interior, boundary and initial points stacked in that order, with
        columns (t, x1, x2, x3).
    X_ibc : ndarray
        Boundary and initial points only (the tail of ``X_combine``).
    FORWARD_NUM : int
        Number of rows of ``X_ibc``.
    '''
    def _space_time_product(times, space, n_space):
        # Pair every time instant with each of the first n_space spatial rows;
        # time varies slowest.
        t_grid, idx = np.meshgrid(times, np.arange(n_space), indexing='ij')
        columns = [t_grid.flatten()] + [space[:, k][idx].flatten() for k in range(3)]
        return np.stack(columns, axis=-1)

    # The RNG draws happen in a fixed order so the samples are reproducible.
    np.random.seed(seed)
    X_spatio = np.random.uniform(LEFT, RIGHT, (n_X, 3))
    X_tempo = np.random.uniform(TLEFT, TRIGHT, (n_t, 1))
    X_interior = _space_time_product(X_tempo, X_spatio, n_X)

    X_bc_spatio = cube(LEFT, RIGHT, n_X_bc)
    X_bc_tempo = np.random.uniform(TLEFT, TRIGHT, (n_t_bc, 1))
    X_bc = _space_time_product(X_bc_tempo, X_bc_spatio, n_X_bc)

    X_ic_spatio = np.random.uniform(LEFT, RIGHT, (n_ic, 3))
    t_initial = np.full((len(X_ic_spatio), 1), TLEFT, dtype=float)
    X_ic = np.hstack((t_initial, X_ic_spatio))

    X_ibc = np.vstack((X_bc, X_ic))
    X_combine = np.vstack((X_interior, X_bc, X_ic))
    return X_combine, X_ibc, X_ibc.shape[0]


# Collocation data: 100 interior points x 100 times, 100 boundary points x 100
# times, and 1000 initial-condition points.
n_X, n_t, n_X_bc, n_t_bc, n_ic = 100, 100, 100, 100, 1000
X_combine, X_ibc, FORWARD_NUM = sample_heat3d(n_X, n_t, n_X_bc, n_t_bc, n_ic)

# Targets: the source term on every point, and the exact solution on the
# boundary/initial points only.
source_real = source_func(*X_combine.T).reshape(-1, 1)
solution_real = solution_func(*X_ibc.T).reshape(-1, 1)
label = [source_real, solution_real]

# requires_grad on the inputs presumably lets the PDE operator differentiate
# the model output w.r.t. the coordinates.
X_combine_torch = torch.tensor(X_combine, requires_grad=True)
label_torch = [torch.tensor(target, requires_grad=True) for target in label]

In [4]:
# Traversal positions of the trainable 'const' tokens, in the same order as
# ini_consts.
consts_index = [pos for pos, tok in enumerate(new_traversal) if tok.name == 'const']
metric, _, _ = make_regression_metric("neg_smse_torch", label)
def opti_pde(metric, X_combine_torch, label_torch, new_traversal,
             const_positions=None, n_forward=None):
    """Build the objective for refining the constants against the 3-D heat PDE.

    Parameters
    ----------
    metric : callable
        ``metric(labels, predictions)`` returning a score; the objective is
        ``-score``.
    X_combine_torch : torch.Tensor
        Collocation points. The last ``n_forward`` rows are the
        boundary/initial points.
    label_torch : list
        Targets compared with ``[f, y_ibc]``, where ``f`` is the heat3d
        operator output and ``y_ibc`` the prediction on the tail rows.
    new_traversal : list
        Token traversal whose constant tokens receive the trial values.
    const_positions : list of int, optional
        Traversal positions of the trainable constants. Defaults to the global
        ``consts_index``, looked up when the objective is called.
    n_forward : int, optional
        Number of boundary/initial rows. Defaults to the global
        ``FORWARD_NUM``, looked up when the objective is called.

    Returns
    -------
    callable
        ``pde_r(consts)``: writes ``consts`` into the traversal's
        ``torch_value`` fields and returns ``-metric(...)``.
    """
    def pde_r(consts):
        # Resolve the defaults lazily, so they keep following the notebook
        # globals exactly as the former implicit references did.
        positions = consts_index if const_positions is None else const_positions
        forward_num = FORWARD_NUM if n_forward is None else n_forward
        for i, value in enumerate(consts):
            new_traversal[positions[i]].torch_value = value
        y = pe(new_traversal, X_combine_torch)
        f = function_map['heat3d'](y, X_combine_torch)
        y_hat = [f, y[-forward_num:, 0:1]]
        return -metric(label_torch, y_hat)
    return pde_r
# Optimise the constants against the PDE residual plus the boundary/initial
# data, write them back into the traversal and print the simplified result.
pde_r = opti_pde(metric, X_combine_torch, label_torch, new_traversal)
const_optimizer = make_const_optimizer('torch')
optimized_consts, smse = const_optimizer(pde_r, ini_consts)
for k, value in enumerate(optimized_consts):
    token = new_traversal[consts_index[k]]
    token.value = value
    token.parse_value()

test_p = Program()
test_p.traversal = new_traversal
expanded = sp.expand(sp.sympify(test_p.sympy_expr))
sym_expr = sp.simplify(sp.N(expanded, 4))
print(f'Identified solution: {sym_expr}')

Identified solution: -1.7*t + 2.5*x1**4 - 1.3*x2**3 + 0.5*x3**2
