# Nonlinear Wave Equation

This notebook presents a symbolic solution of the nonlinear wave equation (Klein-Gordon equation) in three dimension. The equation is given by:

$$
\frac{\partial^2 u}{\partial t^2} = c^2 (\frac{\partial^2 u}{\partial x_1^2} + \frac{\partial^2 u}{\partial x_2^2}) - u  - u^3 + f(x,t)
$$
    
where $u(x,t)$ is the wave function, $c$ is the wave speed, and $f(x,t)$ is a forcing term. 

We set the wave speed $c=1$.



In [1]:
import os
import time
import torch
import pickle
import sympy as sp
import numpy as np
from scipy.optimize import least_squares
from ssde import PDERecursionSolver
from ssde.execute import cython_recursion_execute as ce
from ssde.execute import python_execute as pe
from ssde.program import Program
from ssde.utils import jupyter_logging, rect
from ssde.const import make_const_optimizer
from ssde.pde import function_map
from ssde.task.recursion.recursion import make_regression_metric
from ssde.gradients import hessian, jacobian
from ssde.library import PlaceholderConstant


np.random.seed(0)
LEFT, RIGHT = -1, 1
TLEFT, TRIGHT = 0, 1
if not os.path.exists('logs'):
    os.makedirs('logs')

# The solution is only used to get the bcs and source term here
SOLUTION = sp.sympify("exp(x1**2)*sin(x2)*exp(-0.5*t)")

def calculate_source(solution):
    ''' Calculate the source term of the PDE '''
    # solution: sympy expr of the solution
    real_params = dict()
    for symbol in sp.preorder_traversal(solution):
        if isinstance(symbol, sp.Symbol):
            exec('%s = sp.Symbol("%s")' % (symbol.name, symbol.name))
            if symbol.name not in real_params:
                real_params[symbol.name] = None
    real_params = sorted(list(real_params.keys()))
    print(f'real_params:{real_params}')
    source = 0
    for i in real_params:
        if i == 't':
            source += sp.diff(solution, i, 2)
            continue
        source -= sp.diff(solution, i, 2)
    source += solution + solution**3
    print(f'source:{sp.N(sp.simplify(source))}')
    solution_func = sp.lambdify(real_params, solution, modules='numpy')
    source_func = sp.lambdify(real_params, source, modules='numpy')
    return solution_func, source_func, real_params

def replace_nxexpr(traversals, index=0, new_traversal=None):
    """
    Replace Nxexpr in the traversal with the corresponding expression recursively.

    Parameters
    ----------
    traversals : list
        The list of traversal of the single variable.
    index : int
        The index of current traversal.
    new_traversal : list
        The result of the replacement.
    """
    if new_traversal is None:
        new_traversal = []

    if index + 1 == len(traversals):
        return new_traversal + traversals[index]
    
    current_ls = traversals[index]
    for token in current_ls:
        if token.name != 'Nxexpr':
            new_traversal.append(token)
        else:
            sub_result = replace_nxexpr(traversals, index+1, [])
            new_traversal.extend(sub_result)
    return new_traversal

def opti_nxexpr(traversal, x, y):
    def opti_consts(nxexpr):
        for token in traversal:
            if token.name == 'Nxexpr':
                token.value = nxexpr.reshape(-1,1)
        y_hat = ce(traversal, x)
        return (y_hat-y).ravel()
    return opti_consts

@jupyter_logging("logs/single_var.log")
def solve_single_var(X_input, y_input, var_index, config_file, diff=None):
    model = PDERecursionSolver(config_file)
    start_time = time.time()
    config = model.fit(X_input, y_input, start_n_var=var_index,diff=diff)
    print(f'Time used: {time.time()-start_time}')
    traversal = config.program_.traversal
    expr = config.program_.sympy_expr
    print(f'Identified var x{var_index}\'s parametirc expression:')
    print(expr)
    print(f'Identified var x{var_index}\'s traversal:')
    print(traversal)
    return traversal, expr, model


solution_func, source_func, real_params = calculate_source(SOLUTION)

  from .autonotebook import tqdm as notebook_tqdm


real_params:['t', 'x1', 'x2']
source:((0.25 - 4.0*x1**2)*exp(1.5*t + x1**2) + exp(0.5*t + 3*x1**2)*sin(x2)**2)*exp(-2.0*t)*sin(x2)


## Expression of `t` 

In [2]:
# boundary conditions of t direction
n_tbc_dim1 = 10
n_x12bc_dim1 = 10

t_bc = np.linspace(TLEFT, TRIGHT, n_tbc_dim1)
t_bc_dim1 = t_bc.reshape(-1, 1)
X12bc = rect([LEFT, LEFT], [RIGHT, RIGHT], n_x12bc_dim1)
X1bc = X12bc[:, 0]
X2bc = X12bc[:, 1]
X, yz_inx = np.meshgrid(t_bc, np.arange(n_x12bc_dim1), indexing='ij')
coordinates = np.stack([X.flatten(), X1bc[yz_inx].flatten(), X2bc[yz_inx].flatten()], axis=-1)

y_bc_dim1 = solution_func(*coordinates.transpose(1,0)).reshape(-1, n_x12bc_dim1)



In [3]:
X_input = [None, t_bc_dim1]
y_input = [None, y_bc_dim1]

def u_tt_forward(y, t):
    return hessian(y, t, i=0, j=0)
LOG_PATH = 'logs/2dwave_t.log'
config_file = 'configs/config_wave2d.json'
if os.path.exists(LOG_PATH):
    os.remove(LOG_PATH)
t_traversal, t_expr, t_model = solve_single_var(X_input, y_input, 
                                                1, config_file, log_path=LOG_PATH,diff=u_tt_forward)


t_model.save('models/2dwave_t_model')
print('t model successfully saved')
for i in t_traversal:
    if i.name == 'x1':
        i.name = 't'
with open('models/2dwave_t_traversal.pkl', 'wb') as f:
    pickle.dump(t_traversal, f)
print('t traversal successfully saved')

Saved Trainer state to models/2dwave_t_model_2/trainer.json.
t model successfully saved
t traversal successfully saved


## Expression of `X1` 

In [2]:
# Compute the value of const (actually the label of x2)
n_tbc = 1
n_x1bc = 20
n_x2bc = 4
t_bc = np.array([TLEFT])
X1_bc = np.linspace(LEFT, RIGHT, n_x1bc).reshape(-1,1)
X_bc_dim2 = X1_bc
X2_bc = np.random.uniform(LEFT, RIGHT, (n_x2bc,))
X1, X2, X3 = np.meshgrid(t_bc, X1_bc, X2_bc, indexing='ij')
x1_points = np.stack([X1.ravel(), X2.ravel(), X3.ravel()], axis=-1)
y_bc = solution_func(*x1_points.T).reshape(-1, 1)

In [3]:
with open('models/2dwave_t_traversal.pkl', 'rb') as f:
    t_traversal = pickle.load(f)
opti_x1expr = opti_nxexpr(t_traversal, x1_points, y_bc)

consts = np.ones(x1_points.shape[0])
res = least_squares(opti_x1expr, consts, method='lm')
y_bc_reshaped = res.x.reshape(n_tbc, n_x1bc, n_x2bc,1)
y_bc_transposed = y_bc_reshaped.transpose(1,0,2,3)
y_bc_dim2 = y_bc_transposed.reshape((n_x1bc, -1))

In [4]:
X_input = [None, X_bc_dim2]
y_input = [None, y_bc_dim2]

config_file = 'configs/config_wave2d.json'
LOG_PATH = 'logs/2dwave_x1.log'
if os.path.exists(LOG_PATH):
    os.remove(LOG_PATH)

x1_traversal, x1_expr, x1_model = solve_single_var(X_input, y_input, 
                                                1, config_file, log_path=LOG_PATH)
x1_model.save('models/2dwave_x1_model')
print('x1 model successfully saved')
with open('models/2dwave_x1_traversal.pkl', 'wb') as f:
    pickle.dump(x1_traversal, f)
print('x1 traversal successfully saved')

Saved Trainer state to models/2dwave_x1_model_1/trainer.json.
x1 model successfully saved
x1 traversal successfully saved


In [18]:
traversals = []
with open(f'models/2dwave_t_traversal.pkl', 'rb') as f:
    traversals.append(pickle.load(f))
for i in range(1,2):
    with open(f'models/2dwave_x{i}_traversal.pkl', 'rb') as f:
        traversals.append(pickle.load(f))
new_traversal = replace_nxexpr(traversals)
for token in new_traversal:
    if token.input_var is not None and token.name[0] == 'x':
        token.input_var = int(token.name[1])

test_p = Program()
test_p.traversal = new_traversal
sym_expr = sp.simplify(sp.N(sp.expand(sp.sympify(test_p.sympy_expr)),4))
print(f'Identified solution: {sym_expr}')

# To optimize nxexpr more accurately, we need use more points on the boundary to optimize the generated skeleton before
n_t = 50
n_x1bc = 100
t_refine = np.linspace(TLEFT, TRIGHT, n_t)
X1_refine = np.linspace(LEFT, RIGHT, n_x1bc)
X2_refine = np.array([LEFT]), np.array([RIGHT])
coordinates = np.array(np.meshgrid(t_refine, X1_refine, X2_refine)).T.reshape(-1, 3)
y_refine = solution_func(*coordinates.T).reshape(-1, 1)

Identified solution: 0.5403*Nxexpr*exp(-0.50023418340810012*t + x1**2)


In [19]:
const_pos = [i for i, t in enumerate(new_traversal) if isinstance(t, PlaceholderConstant)]
for i in const_pos:
    if new_traversal[i].name == 'Nxexpr':
        temp = new_traversal[i].value
        new_traversal[i].value = 1
ini_const = np.array([new_traversal[i].value for i in const_pos], dtype=np.float64).ravel()
def refine_consts(consts):
    for i, j in enumerate(const_pos):
        new_traversal[j].value = consts[i].reshape(-1,1)
    y_hat = ce(new_traversal, coordinates)
    return (y_hat-y_refine).ravel()
res = least_squares(refine_consts, ini_const, method='lm')
consts = res.x
for i, j in enumerate(const_pos):
    if new_traversal[j].name != 'Nxexpr':
        new_traversal[j].value = consts[i].reshape(-1)
        new_traversal[j].parse_value()
    else:
        new_traversal[j].value = temp
        

test_p = Program()
test_p.traversal = new_traversal
sym_expr = sp.simplify(sp.N(sp.expand(sp.sympify(test_p.sympy_expr)),4))
print(f'Identified solution: {sym_expr}')

with open('models/2dwave_dim2_traversal.pkl', 'wb') as f:
    pickle.dump(new_traversal, f)

Identified solution: 0.5403*Nxexpr*exp(-0.5002341886109848*t + x1**2)


## Expression of `X2`

In [2]:
# Compute the value of const (actually the label of x2)
n_tbc = 1
n_x1bc = 1
n_x2bc = 20
t_bc = np.array([TLEFT])
X1_bc = np.array([LEFT])
X2_bc = np.linspace(LEFT, RIGHT, n_x2bc).reshape(-1,1)
X1, X2, X3 = np.meshgrid(t_bc, X1_bc, X2_bc, indexing='ij')

x1_points = np.stack([X1.ravel(), X2.ravel(), X3.ravel()], axis=-1)
X_bc_dim3 = x1_points[:,1:]
y_bc = solution_func(*x1_points.T).reshape(-1, 1)

In [7]:
# with open('models/2dwave_t_traversal.pkl', 'rb') as f:
#     t_traversal = pickle.load(f)
# with open('models/2dwave_x1_traversal.pkl', 'rb') as f:
#     x1_traversal = pickle.load(f)
# # replace the nxexpr with x2_traversal
# new_traversal = replace_nxexpr([t_traversal, x1_traversal])
with open('models/2dwave_dim2_traversal.pkl','rb') as f:
    new_traversal = pickle.load(f)
for token in new_traversal:
    if token.input_var is not None and token.name[0] == 'x':
        token.input_var = int(token.name[1])

opti_x2expr = opti_nxexpr(new_traversal, x1_points, y_bc)
consts = np.ones(x1_points.shape[0])
res = least_squares(opti_x2expr, consts, method='lm')
y_bc_dim3 = res.x.reshape(-1,1)
X_input = [None, X_bc_dim3]
y_input = [None, y_bc_dim3]

In [8]:
config_file = 'configs/config_wave2d.json'
LOG_PATH = 'logs/2dwave_x2.log'
if os.path.exists(LOG_PATH):
    os.remove(LOG_PATH)
x2_traversal, x2_expr, x2_model = solve_single_var(X_input, y_input, 
                                                2, config_file, log_path=LOG_PATH)
x2_model.save('models/2dwave_x2_model')
print('x2 model successfully saved')
with open('models/2dwave_x2_traversal.pkl', 'wb') as f:
    pickle.dump(x2_traversal, f)
print('x2 traversal successfully saved')

Saved Trainer state to models/2dwave_x2_model/trainer.json.
x2 model successfully saved
x2 traversal successfully saved


## CFS of 2d wave equation

In [2]:
traversals = []
with open(f'models/2dwave_t_traversal.pkl', 'rb') as f:
        traversals.append(pickle.load(f))
for i in range(1,3):
    with open(f'models/2dwave_x{i}_traversal.pkl', 'rb') as f:
        traversals.append(pickle.load(f))

new_traversal = replace_nxexpr(traversals)

for token in new_traversal:
    if token.input_var is not None:
        if token.name[0] == 'x':
            token.input_var = int(token.name[1])

test_p = Program()
test_p.traversal = new_traversal
sym_expr = sp.simplify(sp.N(sp.expand(sp.sympify(test_p.sympy_expr)),4))
print(f'Identified solution: {sym_expr}')

ini_consts = []
for token in new_traversal:
    if token.name == 'const':
        ini_consts.append(token.value)
ini_consts = [torch.tensor(i, requires_grad=True) for i in ini_consts]

Identified solution: 1.0*exp(-0.50023418340810012*t + x1**2)*sin(x2)


In [3]:
# samples for the domain and boundary for refine
n_samples, n_x12_bc, n_t_bc,n_ic = 1000, 100, 10, 1000
X = np.random.uniform(LEFT, RIGHT, (n_samples, 3))
X12_bc = rect([LEFT, LEFT], [RIGHT, RIGHT], n_x12_bc)
X1_bc = X12_bc[:, 0]
X2_bc = X12_bc[:, 1]
Xt_bc = np.linspace(TLEFT, TRIGHT, n_t_bc)
indx = np.meshgrid(Xt_bc, np.arange(n_x12_bc), indexing='ij')
X_bc = np.stack([indx[0].flatten(), X1_bc[indx[1]].flatten(), X2_bc[indx[1]].flatten()], axis=-1)
X_ic = np.concatenate([np.ones((n_ic,1)) * TLEFT, np.random.uniform(LEFT,RIGHT,(n_ic,2))], axis=1)
X_ibc = np.concatenate([X_bc, X_ic], axis=0)
X_combine = np.concatenate([X, X_ibc], axis=0)
X_combine_torch = torch.tensor(X_combine, dtype=torch.float32, requires_grad=True)
y = source_func(*X_combine.T).reshape(-1, 1)
y_ibc = solution_func(*X_ibc.T).reshape(-1, 1)
FORWARD_NUM = y_ibc.shape[0]

In [4]:
y_input = [y, y_ibc]
y_input_torch = [torch.tensor(i, dtype=torch.float32, requires_grad=True) for i in y_input]
consts_index = [i for i in range(len(new_traversal)) if new_traversal[i].name == 'const']
metric,_,_ = make_regression_metric("neg_smse_torch", y_input)
def pde_r(consts):
    for i in range(len(consts)):
        new_traversal[consts_index[i]].torch_value = consts[i]
    y = pe(new_traversal, X_combine_torch)
    f = function_map['wave2d'](y, X_combine_torch)
    y_hat = [f, y[-FORWARD_NUM:,0:1]]
    r = metric(y_input_torch,y_hat)
    obj = -r
    return obj

optimized_consts, smse = make_const_optimizer('torch')(pde_r, ini_consts)
for i in range(len(optimized_consts)):
    new_traversal[consts_index[i]].value = optimized_consts[i]
    new_traversal[consts_index[i]].parse_value()
test_p = Program()
test_p.traversal = new_traversal
sym_expr = sp.simplify(sp.N(sp.expand(sp.sympify(test_p.sympy_expr)),4))
print(f'Identified solution: {sym_expr}')

Identified solution: 1.0*exp(-0.5000248*t + x1**2)*sin(x2)
