# Nonlinear Wave Equation

This notebook presents a symbolic solution of the nonlinear wave equation in three spatial dimensions. The equation is given by:

$$
\frac{\partial^2 u}{\partial t^2} = c^2 (\frac{\partial^2 u}{\partial x_1^2} + \frac{\partial^2 u}{\partial x_2^2} + \frac{\partial^2 u}{\partial x_3^2}) + u^2 + f(x,t)
$$
    
where $x = (x_1, x_2, x_3)$ is the spatial coordinate, $u(x,t)$ is the wave function, $c$ is the wave speed, and $f(x,t)$ is a forcing term.

We set the wave speed $c=1$.



In [1]:
import os
import time
import torch
import pickle
import sympy as sp
import numpy as np
from scipy.optimize import least_squares
from ssde import PDERecursionSolver
from ssde.execute import cython_recursion_execute as ce
from ssde.execute import python_execute as pe
from ssde.program import Program
from ssde.utils import jupyter_logging, cube
from ssde.const import make_const_optimizer
from ssde.pde import function_map
from ssde.task.recursion.recursion import make_regression_metric
from ssde.gradients import hessian, jacobian
from ssde.library import PlaceholderConstant


# Reproducibility: seed every RNG the notebook may draw from. numpy drives the
# random boundary samples below; torch is imported too, so seed it as well
# (NOTE(review): confirm whether the ssde solver also seeds internally).
np.random.seed(0)
torch.manual_seed(0)

# Spatial domain [LEFT, RIGHT] for x1, x2, x3 and time domain [TLEFT, TRIGHT].
LEFT, RIGHT = -1, 1
TLEFT, TRIGHT = 0, 1

# Output directories: 'logs' for solver logs and 'models' for saved models and
# pickled traversals (open(..., 'wb') does not create missing directories).
os.makedirs('logs', exist_ok=True)
os.makedirs('models', exist_ok=True)

# The solution is only used to get the bcs and source term here
SOLUTION = sp.sympify("exp(x1**2 + x3**2)*cos(x2)*exp(-0.5*t)")

def calculate_source(solution, c=1):
    ''' Calculate the source term of the PDE

    For u_tt = c**2 * (u_x1x1 + u_x2x2 + u_x3x3) + u**2 + f, return the
    forcing term f = u_tt - c**2 * laplacian(u) - u**2 implied by `solution`.

    Parameters
    ----------
    solution : sympy expression
        Known solution. The free symbol named 't' is treated as time; every
        other free symbol is treated as a spatial coordinate.
    c : number, optional
        Wave speed (default 1, the value used in this notebook).

    Returns
    -------
    solution_func, source_func : callable
        NumPy-backed functions of the variables, taken positionally in
        `real_params` order.
    real_params : list of str
        Sorted names of the free symbols, e.g. ['t', 'x1', 'x2', 'x3'].
    '''
    # Keep the actual Symbol objects. Differentiating by a name string would
    # re-create a plain Symbol, which does not match symbols that carry
    # assumptions (e.g. real=True), so the derivative would wrongly be 0.
    symbols = {symbol.name: symbol for symbol in solution.free_symbols}
    real_params = sorted(symbols)
    print(f'real_params:{real_params}')
    args = [symbols[name] for name in real_params]

    source = 0
    for name, symbol in zip(real_params, args):
        second_derivative = sp.diff(solution, symbol, 2)
        if name == 't':
            source += second_derivative
        else:
            source -= c**2 * second_derivative
    source -= solution**2
    print(f'source:{sp.N(sp.simplify(source))}')
    solution_func = sp.lambdify(args, solution, modules='numpy')
    source_func = sp.lambdify(args, source, modules='numpy')
    return solution_func, source_func, real_params

def replace_nxexpr(traversals, index=0, new_traversal=None):
    """
    Recursively splice nested traversals into one flat traversal.

    Every token named 'Nxexpr' in ``traversals[index]`` is replaced by the
    fully expanded traversal built from ``traversals[index + 1:]``. The final
    traversal is appended unchanged. Token objects are reused, not copied, so
    the same object can appear at several positions of the result.

    Parameters
    ----------
    traversals : list
        The list of traversal of the single variable, outermost first.
    index : int
        The index of current traversal.
    new_traversal : list
        The result of the replacement. For a non-final level it is extended in
        place and returned.

    Returns
    -------
    list
        The expanded traversal.
    """
    out = [] if new_traversal is None else new_traversal
    if index == len(traversals) - 1:
        # Concatenation (not extend) yields a fresh list at the deepest level.
        return out + traversals[index]

    for token in traversals[index]:
        if token.name == 'Nxexpr':
            out.extend(replace_nxexpr(traversals, index + 1, []))
        else:
            out.append(token)
    return out

def opti_nxexpr(traversal, x, y):
    """
    Build a residual function for fitting per-point Nxexpr values.

    Parameters
    ----------
    traversal : list
        Token traversal. Every token named 'Nxexpr' receives the candidate
        values on each call.
    x : np.ndarray
        Evaluation points passed to ``ce`` together with ``traversal``.
    y : np.ndarray
        Target values at ``x`` (callers in this notebook pass shape (n, 1)).

    Returns
    -------
    callable
        ``residual(nxexpr) -> 1-D array`` for ``scipy.optimize.least_squares``.
        Calling it overwrites the Nxexpr tokens' ``value`` in place.
    """
    def opti_consts(nxexpr):
        # Write the candidate vector into every Nxexpr placeholder as a column.
        for token in filter(lambda tok: tok.name == 'Nxexpr', traversal):
            token.value = nxexpr.reshape(-1, 1)
        # NOTE(review): assumes ce(...) returns the same shape as y; an (n,)
        # prediction against an (n, 1) target would broadcast to (n, n).
        residual = ce(traversal, x) - y
        return residual.ravel()
    return opti_consts

@jupyter_logging("logs/single_var.log")
def solve_single_var(X_input, y_input, var_index, config_file, diff=None):
    """Fit a parametric expression for one variable with PDERecursionSolver.

    Parameters
    ----------
    X_input, y_input : list
        Passed straight to ``PDERecursionSolver.fit``.
    var_index : int
        Forwarded as ``start_n_var`` and used as ``x{var_index}`` in the
        printed messages.
    config_file : str
        Path of the solver configuration file.
    diff : callable, optional
        Differential operator forwarded to ``fit``.

    Returns
    -------
    traversal, expr, model
        Traversal and sympy expression of the fitted program, and the solver.

    NOTE(review): callers also pass ``log_path=...``; presumably that keyword
    is consumed by ``jupyter_logging`` — confirm.
    """
    model = PDERecursionSolver(config_file)
    start_time = time.time()
    result = model.fit(X_input, y_input, start_n_var=var_index, diff=diff)
    print(f'Time used: {time.time()-start_time}')
    program = result.program_
    traversal = program.traversal
    expr = program.sympy_expr
    print(f'Identified var x{var_index}\'s parametric expression:')
    print(expr)
    print(f'Identified var x{var_index}\'s traversal:')
    print(traversal)
    return traversal, expr, model


# Lambdified exact solution and forcing term; arguments follow real_params order.
solution_func, source_func, real_params = calculate_source(solution=SOLUTION)

real_params:['t', 'x1', 'x2', 'x3']
source:((-4.0*x1**2 - 4.0*x3**2 - 2.75)*exp(1.0*t + x1**2 + x3**2) - exp(0.5*t + 2*x1**2 + 2*x3**2)*cos(x2))*exp(-1.5*t)*cos(x2)


## Expression of `t` 

In [2]:
# Boundary data along t: n_tbc_dim1 time samples, each evaluated at the same
# n_x123bc_dim1 spatial points produced by `cube` on [LEFT, RIGHT].
n_tbc_dim1 = 10
n_x123bc_dim1 = 10

t_bc = np.linspace(TLEFT, TRIGHT, n_tbc_dim1)
t_bc_dim1 = t_bc[:, np.newaxis]
X123bc = cube(LEFT, RIGHT, n_x123bc_dim1)
X1bc, X2bc, X3bc = X123bc[:, 0], X123bc[:, 1], X123bc[:, 2]
# X[i, j] = t_bc[i] and yz_inx[i, j] = j selects the j-th spatial point.
X, yz_inx = np.meshgrid(t_bc, np.arange(n_x123bc_dim1), indexing='ij')
coordinates = np.column_stack([
    X.ravel(),
    X1bc[yz_inx].ravel(),
    X2bc[yz_inx].ravel(),
    X3bc[yz_inx].ravel(),
])

# Rows follow t, columns follow the spatial sample points.
y_bc_dim1 = solution_func(*coordinates.T).reshape(-1, n_x123bc_dim1)

In [3]:
# Time-only fit: the solver sees t as its single input variable.
X_input = [None, t_bc_dim1]
y_input = [None, y_bc_dim1]

def u_tt_forward(y, t):
    """Second derivative of y w.r.t. its only input (presumably u_tt) via ssde's hessian."""
    return hessian(y, t, i=0, j=0)

LOG_PATH = 'logs/3dwave_t.log'
config_file = 'configs/config_wave3d.json'
# Start from an empty log for this run.
if os.path.exists(LOG_PATH):
    os.remove(LOG_PATH)
t_traversal, t_expr, t_model = solve_single_var(
    X_input, y_input, 1, config_file, log_path=LOG_PATH, diff=u_tt_forward)


t_model.save('models/3dwave_t_model')
print('t model successfully saved')
# The fitted variable is named x1; rename it to t so the traversal can later be
# spliced into the full (t, x1, x2, x3) expression.
for i in t_traversal:
    if i.name != 'x1':
        continue
    i.name = 't'
with open('models/3dwave_t_traversal.pkl', 'wb') as f:
    pickle.dump(t_traversal, f)
print('t traversal successfully saved')

Saved Trainer state to models/3dwave_t_model/trainer.json.
t model successfully saved
t traversal successfully saved


## Expression of `x1`

In [4]:
# Compute the value of const (the labels for the x1 regression): on the slice
# t = TLEFT, sample n_x1bc values of x1 against n_x23bc random (x2, x3) pairs.
n_tbc = 1
n_x1bc = 31
n_x23bc = 4
t_bc = np.array([TLEFT])
X1_bc = np.linspace(LEFT, RIGHT, n_x1bc)[:, np.newaxis]
X_bc_dim2 = X1_bc
X23_bc = np.random.uniform(LEFT, RIGHT, (n_x23bc, 2))
X2_bc, X3_bc = X23_bc[:, 0], X23_bc[:, 1]
T, X1, x23_inx = np.meshgrid(t_bc, X1_bc, np.arange(n_x23bc), indexing='ij')
# Point ordering is x1-major, (x2, x3)-sample-minor.
x1_points = np.column_stack([
    T.ravel(),
    X1.ravel(),
    X2_bc[x23_inx].ravel(),
    X3_bc[x23_inx].ravel(),
])
y_bc = solution_func(*x1_points.T).reshape(-1, 1)

In [6]:
# Reload the t skeleton and fit one free Nxexpr value per boundary point so the
# t-expression matches the exact solution there.
with open('models/3dwave_t_traversal.pkl', 'rb') as f:
    t_traversal = pickle.load(f)

opti_x1expr = opti_nxexpr(t_traversal, x1_points, y_bc)

consts = np.ones(len(x1_points))
res = least_squares(opti_x1expr, consts, method='lm')
# Reorder the fitted values from (t, x1, sample) to rows = x1, columns = sample.
y_bc_reshaped = res.x.reshape(n_tbc, n_x1bc, n_x23bc, 1)
y_bc_transposed = np.swapaxes(y_bc_reshaped, 0, 1)
y_bc_dim2 = y_bc_transposed.reshape((n_x1bc, -1))

In [7]:
# Fit the x1-dependence: inputs are the x1 samples, targets are the per-point
# Nxexpr values from the previous cell.
X_input = [None, X_bc_dim2]
y_input = [None, y_bc_dim2]

config_file = 'configs/config_wave3d.json'
LOG_PATH = 'logs/3dwave_x1.log'
# Start from an empty log. Use LOG_PATH for both the check and the removal so
# they cannot drift apart from the path passed to the solver.
if os.path.exists(LOG_PATH):
    os.remove(LOG_PATH)

x1_traversal, x1_expr, x1_model = solve_single_var(X_input, y_input,
                                                1, config_file, log_path=LOG_PATH)
x1_model.save('models/3dwave_x1_model')
print('x1 model successfully saved')
with open('models/3dwave_x1_traversal.pkl', 'wb') as f:
    pickle.dump(x1_traversal, f)
print('x1 traversal successfully saved')

Saved Trainer state to models/3dwave_x1_model/trainer.json.
x1 model successfully saved
x1 traversal successfully saved


In [2]:
# Splice the t skeleton and the x1 skeleton into one traversal.
traversals = []
with open('models/3dwave_t_traversal.pkl', 'rb') as f:
    traversals.append(pickle.load(f))
for i in range(1, 2):  # spatial skeletons solved so far: x1
    with open(f'models/3dwave_x{i}_traversal.pkl', 'rb') as f:
        traversals.append(pickle.load(f))
new_traversal = replace_nxexpr(traversals)
# Point each xK input token at column K of the (t, x1, x2, x3) coordinates.
for token in new_traversal:
    if token.input_var is None or token.name[0] != 'x':
        continue
    token.input_var = int(token.name[1])

# To optimize Nxexpr more accurately, use more boundary points to refine the
# previously generated skeleton.
n_t = 10
n_x1bc = 21
t_refine = np.linspace(TLEFT, TRIGHT, n_t)
X1_refine = np.linspace(LEFT, RIGHT, n_x1bc)
X2_refine, X3_refine = np.array([LEFT]), np.array([RIGHT])
# Stacked grid transposed so the last axis holds (t, x1, x2, x3).
coordinates = np.array(np.meshgrid(t_refine, X1_refine, X2_refine, X3_refine)).T.reshape(-1, 4)
y_refine = solution_func(*coordinates.T).reshape(-1, 1)

In [3]:
# Refine the spliced skeleton's constants on the refinement grid. During the fit
# each Nxexpr placeholder is treated as a scalar constant starting at 1;
# afterwards its previously fitted value is restored.
const_pos = [i for i, t in enumerate(new_traversal) if isinstance(t, PlaceholderConstant)]
# Save each Nxexpr token's value once, keyed by object identity. replace_nxexpr
# can place the same token object at several positions, and distinct Nxexpr
# tokens must each get their own value back. A single shared variable would
# restore the last saved value everywhere, or capture the already-set 1.
nx_saved = {}
for i in const_pos:
    token = new_traversal[i]
    if token.name == 'Nxexpr':
        nx_saved.setdefault(id(token), token.value)
        token.value = 1
ini_const = np.array([new_traversal[i].value for i in const_pos], dtype=np.float64).ravel()

def refine_consts(consts):
    """Residual on the refinement grid for candidate constants (const_pos order)."""
    for i, j in enumerate(const_pos):
        new_traversal[j].value = consts[i].reshape(-1, 1)
    y_hat = ce(new_traversal, coordinates)
    return (y_hat - y_refine).ravel()

res = least_squares(refine_consts, ini_const, method='lm')
consts = res.x
for i, j in enumerate(const_pos):
    token = new_traversal[j]
    if token.name != 'Nxexpr':
        token.value = consts[i].reshape(-1)
        token.parse_value()
    else:
        token.value = nx_saved[id(token)]


test_p = Program()
test_p.traversal = new_traversal
sym_expr = sp.simplify(sp.N(sp.expand(sp.sympify(test_p.sympy_expr)), 4))
print(f'Identified solution: {sym_expr}')

with open('models/3dwave_dim2_traversal.pkl', 'wb') as f:
    pickle.dump(new_traversal, f)

Identified solution: exp(Nxexpr - 0.4999999998224172*t + 1.0000000036652303*x1**2)


## Expression of `x2`

In [11]:
# Compute the value of const (the labels for the x2 regression): on t = TLEFT
# and x1 = LEFT, sample n_x2bc values of x2 against n_x3bc random x3 values.
n_tbc = 1
n_x1bc = 1
n_x2bc = 31
n_x3bc = 4
t_bc = np.array([TLEFT])
X1_bc = np.array([LEFT])
X2_bc = np.linspace(LEFT, RIGHT, n_x2bc)[:, np.newaxis]
X3_bc = np.random.uniform(LEFT, RIGHT, (n_x3bc, 1))
t, X1, X2, X3 = np.meshgrid(t_bc, X1_bc, X2_bc, X3_bc, indexing='ij')

# Points are x2-major, x3-minor. The name x1_points is kept from the x1 stage
# because the next cell reads it.
x1_points = np.column_stack([grid.ravel() for grid in (t, X1, X2, X3)])
X_bc_dim3 = np.hstack((np.tile(X1_bc, (n_x2bc, 1)), X2_bc))
y_bc = solution_func(*x1_points.T).reshape(-1, 1)

In [12]:
# Fit one Nxexpr value per point for the (t, x1) skeleton. Rows of the result
# follow x2 and columns follow the n_x3bc random x3 samples.
opti_x2expr = opti_nxexpr(new_traversal, x1_points, y_bc)
consts = np.ones(len(x1_points))
res = least_squares(opti_x2expr, consts, method='lm')
y_bc_dim3 = res.x.reshape(-1, n_x3bc)
X_input, y_input = [None, X_bc_dim3], [None, y_bc_dim3]

In [14]:
# Fit the x2-dependence (var_index 2) and persist the model and its traversal.
config_file = 'configs/config_wave3d.json'
LOG_PATH = 'logs/3dwave_x2.log'
if os.path.exists(LOG_PATH):
    os.remove(LOG_PATH)  # start from an empty log
x2_traversal, x2_expr, x2_model = solve_single_var(
    X_input, y_input, 2, config_file, log_path=LOG_PATH)
x2_model.save('models/3dwave_x2_model')
print('x2 model successfully saved')
with open('models/3dwave_x2_traversal.pkl', 'wb') as f:
    pickle.dump(x2_traversal, f)
print('x2 traversal successfully saved')

Saved Trainer state to models/3dwave_x2_model/trainer.json.
x2 model successfully saved
x2 traversal successfully saved


In [4]:
# Splice the x2 skeleton into the refined (t, x1) skeleton, then refine constants.
with open('models/3dwave_dim2_traversal.pkl', 'rb') as f:
    dim2_traversal = pickle.load(f)
with open('models/3dwave_x2_traversal.pkl', 'rb') as f:
    x2_traversal = pickle.load(f)
new_traversal = replace_nxexpr([dim2_traversal, x2_traversal])

# Save each Nxexpr token's value once, keyed by object identity (the same token
# object may appear at several positions), and treat it as a scalar constant
# starting at 1 during the refinement fit. Also point each xK input token at
# column K of the (t, x1, x2, x3) coordinates.
nx_saved = {}
for token in new_traversal:
    if token.name == 'Nxexpr':
        nx_saved.setdefault(id(token), token.value)
        token.value = 1
    if token.input_var is not None and token.name[0] == 'x':
        token.input_var = int(token.name[1])


# To optimize Nxexpr more accurately, use more boundary points to refine the
# previously generated skeleton.
n_t = 10
n_x1bc = 11
n_x2bc = 11
t_refine = np.linspace(TLEFT, TRIGHT, n_t)
X1_refine = np.linspace(LEFT, RIGHT, n_x1bc)
X2_refine = np.linspace(LEFT, RIGHT, n_x2bc)
X3_refine = np.array([RIGHT])
coordinates = np.array(np.meshgrid(t_refine, X1_refine, X2_refine, X3_refine)).T.reshape(-1, 4)
y_refine = solution_func(*coordinates.T).reshape(-1, 1)

# Optimize
const_pos = [i for i, t in enumerate(new_traversal) if isinstance(t, PlaceholderConstant)]
ini_const = np.array([new_traversal[i].value for i in const_pos], dtype=np.float64).ravel()

def refine_consts(consts):
    """Residual on the refinement grid for candidate constants (const_pos order)."""
    for i, j in enumerate(const_pos):
        new_traversal[j].value = consts[i].reshape(-1, 1)
    y_hat = ce(new_traversal, coordinates)
    return (y_hat - y_refine).ravel()

res = least_squares(refine_consts, ini_const, method='lm')
consts = res.x
for i, j in enumerate(const_pos):
    token = new_traversal[j]
    if token.name != 'Nxexpr':
        token.value = consts[i].reshape(-1)
        token.parse_value()
    else:
        token.value = nx_saved[id(token)]

test_p = Program()
test_p.traversal = new_traversal
sym_expr = sp.simplify(sp.N(sp.expand(sp.sympify(test_p.sympy_expr)), 4))
print(f'Identified solution: {sym_expr}')
with open('models/3dwave_dim3_traversal.pkl', 'wb') as f:
    pickle.dump(new_traversal, f)

Identified solution: exp(Nxexpr - 0.5*t + 1.0*x1**2)*cos(x2)**1.0


## Expression of `x3`

In [17]:
# Compute the value of const (the labels for the x3 regression): on t = TLEFT,
# x1 = LEFT and x2 = LEFT, sample n_x3bc values of x3.
n_tbc = 1
n_x1bc = 1
n_x2bc = 1
n_x3bc = 31
t_bc = np.array([TLEFT])
X1_bc = np.array([LEFT])
X2_bc = np.array([LEFT])
X3_bc = np.linspace(LEFT, RIGHT, n_x3bc).reshape(-1, 1)
T, X1, X2, X3 = np.meshgrid(t_bc, X1_bc, X2_bc, X3_bc, indexing='ij')

x1_points = np.stack([T.ravel(), X1.ravel(), X2.ravel(), X3.ravel()], axis=-1)
X_bc_dim4 = x1_points[:, 1:]  # (x1, x2, x3) columns
y_bc = solution_func(*x1_points.T).reshape(-1, 1)
# Load the refined (t, x1, x2) skeleton from disk, as the t and final stages do,
# instead of relying on `new_traversal` still holding it in the kernel.
with open('models/3dwave_dim3_traversal.pkl', 'rb') as f:
    dim3_traversal = pickle.load(f)
opti_x3expr = opti_nxexpr(dim3_traversal, x1_points, y_bc)
consts = np.ones(x1_points.shape[0])
res = least_squares(opti_x3expr, consts, method='lm')
y_bc_dim4 = res.x.reshape(-1, 1)

In [18]:
# Fit the x3-dependence (var_index 3) and persist the model and its traversal.
X_input, y_input = [None, X_bc_dim4], [None, y_bc_dim4]
LOG_PATH = 'logs/3dwave_x3.log'
config_file = 'configs/config_wave3d.json'
if os.path.exists(LOG_PATH):
    os.remove(LOG_PATH)  # start from an empty log
x3_traversal, x3_expr, x3_model = solve_single_var(
    X_input, y_input, 3, config_file, log_path=LOG_PATH)
x3_model.save('models/3dwave_x3_model')
print('x3 model successfully saved')
with open('models/3dwave_x3_traversal.pkl', 'wb') as f:
    pickle.dump(x3_traversal, f)
print('x3 traversal successfully saved')

Saved Trainer state to models/3dwave_x3_model/trainer.json.
x3 model successfully saved
x3 traversal successfully saved


## CFS of the 3D nonlinear wave equation

In [9]:
# Final expression: splice the x3 skeleton into the (t, x1, x2) skeleton.
with open('models/3dwave_dim3_traversal.pkl', 'rb') as f:
    dim3_traversal = pickle.load(f)
with open('models/3dwave_x3_traversal.pkl', 'rb') as f:
    x3_traversal = pickle.load(f)
new_traversal = replace_nxexpr([dim3_traversal, x3_traversal])

for token in new_traversal:
    # Point each xK input token at column K of the (t, x1, x2, x3) inputs.
    if token.input_var is not None and token.name[0] == 'x':
        token.input_var = int(token.name[1])
    # Call parse_value() on every 'const' token, as the earlier stages do
    # after fitting.
    if token.name == 'const':
        token.parse_value()

test_p = Program()
test_p.traversal = new_traversal
expanded = sp.expand(sp.sympify(test_p.sympy_expr))
sym_expr = sp.simplify(sp.N(expanded, 4))
print(f'Identified solution: {sym_expr}')

Identified solution: exp(-0.5*t + 1.0*x1**2 + x3**2)*cos(x2)**1.0
