## Example 5. KdV Equation

#### This notebook solves the KdV equation $u_t+6uu_x+u_{xxx} = 0$ on the domain $[x_L,x_R]\times (t_0,T]$ with periodic boundary conditions $u(x_L,t) = u(x_R,t)$, using several multi-soliton solutions as initial data and as exact reference solutions. The KdV equation has infinitely many conserved quantities; the first three, which are considered here, are
\begin{align}
    \int u \ dx \ (\text{mass}) \\
    \int \frac{1}{2} u^2 \ dx \ (\text{energy}) \\
    \int \left(2u^3 -u_x^2  \right) dx \ (\text{Whitham}) \;.
\label{KdV_invariants_3}    
\end{align}

In [None]:
#Required libraries 
import numpy as np
import sympy as sp
import pandas as pd
import matplotlib.pyplot as plt
from IPython.display import clear_output
from scipy.optimize import fsolve,root,minimize,brute,fmin,fmin_slsqp

#-----------#
from RKSchemes import ImEx_schemes

In [None]:
# choose a particular case here
sol = 1; inv = 1
# Different cases: domain [xL, xR] of length L, N grid points, one step size
# per method in DT, and the time interval [t0, T].
if sol == 1 and inv == 1:
    xL = -20; xR = 60; L = xR-xL; N = 512; DT = [0.1, 0.1]; t0 = 0; T = 20

elif sol == 2 and inv == 1:
    xL = -80; xR = 80; L = xR-xL; N = 1024; DT = [0.1, 0.1]; t0 = -25; T = 25

elif sol == 3 and inv == 1:
    xL = -130; xR = 130; L = xR-xL; N = 1536; DT = [0.1, 0.1]; t0 = -50; T = 50

else:
    # Fail here with a clear message instead of a NameError in a later cell.
    raise ValueError("No parameter set defined for sol = %d, inv = %d." % (sol, inv))

In [None]:
#Domain and required functions
# xplot holds N+1 equispaced nodes including both endpoints; x drops the last
# node (x_R), which duplicates x_L under the periodic boundary condition.
xplot = np.linspace(xL, xR, N+1)
x = xplot[0:-1] 
dx = x[1] - x[0]
# Angular wavenumbers 2*pi*k/L in numpy's FFT ordering (k = 0..N/2-1, -N/2..-1).
xi = np.fft.fftfreq(N) * N * 2*np.pi / L

# Identity matrix
# (used to build the implicit stage matrices I + dt*a_ii*M)
I = np.eye(N)

def F_matrix(N):
    """Return the N x N discrete Fourier transform matrix.

    Column j is ``np.fft.fft(e_j)`` for the j-th unit vector e_j, so
    ``F_matrix(N) @ v`` equals ``np.fft.fft(v)``. All columns are obtained
    with a single FFT along axis 0 of the identity matrix instead of N
    separate Python-level FFT calls.

    Parameters
    ----------
    N : int
        Matrix size.

    Returns
    -------
    ndarray of complex, shape (N, N)
    """
    return np.fft.fft(np.eye(N), axis=0)

# Dense spectral representation of the third-derivative operator:
# M = F^{-1} diag((i*xi)^3) F, where (i*xi)^3 = -1j*xi^3.
dft_mat = F_matrix(N)
# The DFT matrix satisfies F^{-1} = F^H / N; this is exact and O(N^2),
# unlike a general O(N^3) numerical inversion.
inv_dft_mat = dft_mat.conj().T / N
# xi3[k, j] = -1j*xi[k]**3: scaling row k of F applies the symbol of d^3/dx^3
# to Fourier mode k.
xi3 = np.tile((-1j)*xi*xi*xi,(N,1)).T
M = inv_dft_mat @ (xi3 * dft_mat)

def eta0(w): # mass
    """Discrete mass: rectangle-rule sum of w times the grid spacing ``dx``.

    ``dx`` is the module-level spacing of the periodic grid.
    """
    total = np.sum(w)
    return total * dx

def eta1(w): # energy
    """Discrete energy: rectangle-rule approximation of the integral of w^2 / 2.

    Uses the module-level grid spacing ``dx``.
    """
    squared_sum = np.sum(w*w)
    return 0.5 * squared_sum * dx

def CompositeSimposonInt(x,f):
    """Composite Simpson's rule for samples f on the uniform grid x.

    Parameters
    ----------
    x : array_like
        Uniformly spaced nodes; only the spacing x[1] - x[0] is used.
    f : array_like
        Function values at the nodes. Simpson's rule needs an even number of
        subintervals, i.e. an odd number (at least 3) of samples.

    Returns
    -------
    float
        Approximation of the integral of f over [x[0], x[-1]].

    Raises
    ------
    ValueError
        If f does not contain an odd number (>= 3) of samples; with an even
        count the weights below would be misaligned and the result wrong.
    """
    f = np.asarray(f)
    if f.size < 3 or f.size % 2 == 0:
        raise ValueError("CompositeSimposonInt needs an odd number (>= 3) of samples, got %d." % f.size)
    dx = (x[1]-x[0])
    # Interior nodes: odd indices get weight 4, even indices weight 2.
    approx = 1/3*dx*(f[0]+4*np.sum(f[1:-1:2])+2*np.sum(f[2:-1:2])+f[-1])
    return approx

def eta2(w): # whitham
    """Discrete Whitham invariant: integral of (2 u^3 - u_x^2).

    u_x is computed spectrally with the module-level wavenumbers ``xi``. Each
    integrand is closed periodically (its first sample appended at the end) and
    integrated with composite Simpson's rule over the nodes ``xplot``.
    """
    wx = np.real(np.fft.ifft(1j*xi*np.fft.fft(w)))
    cubic_term = CompositeSimposonInt(xplot, np.append(w, w[0])**3)
    gradient_term = CompositeSimposonInt(xplot, np.append(wx, wx[0])**2)
    return 2*cubic_term - gradient_term

def kdV_stiff_rhs(u): # stiff right hand side
    """Stiff part -u_xxx of the KdV right-hand side, evaluated spectrally.

    The third derivative uses the Fourier symbol (i*xi)^3 = -1j*xi^3 with the
    module-level wavenumbers ``xi``.
    """
    symbol = -1j*xi*xi*xi
    u_xxx = np.real(np.fft.ifft(symbol*np.fft.fft(u)))
    return -u_xxx

def kdV_non_stiff_rhs(u): # non-stiff right hand side with Spectral and SBP
    """Non-stiff part -6 u u_x written in split (skew-symmetric) form.

    Uses u u_x = (u u_x + (u^2)_x) / 3, with both derivatives computed
    spectrally from the module-level wavenumbers ``xi``.
    """
    ik = 1j*xi
    u_x = np.real(np.fft.ifft(ik*np.fft.fft(u)))
    usq_x = np.real(np.fft.ifft(ik*np.fft.fft(u*u)))
    return -6*(u*u_x + usq_x) / 3.

def rgam1(gammas,u,inc1,E1_old): # for one invariant
    """Residual of the single-invariant (energy) relaxation condition.

    Returns eta1(u + gammas*inc1) - E1_old wrapped in a one-element array, the
    form expected by scipy.optimize.fsolve.
    """
    residual = eta1(u + gammas*inc1) - E1_old
    return np.array([residual])

def rgam2(gammas,u,inc1,E1_old,inc2,E2_old): # for two invariant
    """Residuals of the two-invariant relaxation conditions.

    Evaluates energy (eta1) and Whitham (eta2) at u + g1*inc1 + g2*inc2 and
    returns their deviations from E1_old and E2_old as a length-2 array.
    """
    g1, g2 = gammas
    candidate = u + g1*inc1 + g2*inc2
    return np.array([eta1(candidate) - E1_old, eta2(candidate) - E2_old])

def norm_rgam2(gammas,u,inc1,E1_old,inc2,E2_old): # for two invariant
    """Euclidean norm of the two-invariant relaxation residual.

    Delegates to rgam2 so the residual is defined in one place; returns a
    scalar, e.g. for use as a minimization objective.
    """
    return np.linalg.norm(rgam2(gammas,u,inc1,E1_old,inc2,E2_old))
 
def analytical_sol(sol,x,t): # different multi-soliton solution
    """Exact multi-soliton solution of u_t + 6 u u_x + u_xxx = 0.

    Parameters
    ----------
    sol : int
        Number of solitons: 1, 2 or 3.
    x : float or ndarray
        Spatial point(s).
    t : float
        Time.

    Returns
    -------
    float or ndarray
        u(x, t), elementwise in x.

    Raises
    ------
    ValueError
        If sol is not 1, 2 or 3. Previously this case failed with an
        UnboundLocalError on ``u``.
    """
    if sol == 1: # 1-soliton solution with speed c1
        c1 = 2; xi1 = x - c1*t
        c1sqrt = np.sqrt(c1)
        u = (c1/2)/np.cosh(c1sqrt*xi1/2)**2 
    elif sol == 2:  # 2-soliton solution with speeds c1 > c2
        c1=2; c2=1; xi1 = x - c1*t; xi2 = x - c2*t
        c1sqrt = np.sqrt(c1)
        c2sqrt = np.sqrt(c2)
        u = 2*(c1-c2)*(c1*np.cosh(c2sqrt*xi2/2)**2 +c2*np.sinh(c1sqrt*xi1/2)**2) / \
            ( (c1sqrt-c2sqrt)*np.cosh((c1sqrt*xi1+c2sqrt*xi2)/2)   \
             + (c1sqrt+c2sqrt)*np.cosh((c1sqrt*xi1-c2sqrt*xi2)/2) )**2
    elif sol == 3:  # 3-soliton solution
        b1 = 0.4; b2 = 0.7; b3 = 1
        X1 = np.sqrt(b1/2)*(x - 2*b1*t); X2 = np.sqrt(b2/2)*(x - 2*b2*t); X3 = np.sqrt(b3/2)*(x - 2*b3*t)  
        b1sqrt = np.sqrt(b1); b2sqrt = np.sqrt(b2); b3sqrt = np.sqrt(b3)
        # The cosech/coth factors divide by sinh(X2), so evaluation exactly at
        # X2 == 0 (x == 2*b2*t) produces inf/nan.
        sech2X1 = 1/np.cosh(X1)**2; sech2X3 = 1/np.cosh(X3)**2; cosech2X2 = 1/np.sinh(X2)**2
        tanhX1 = np.tanh(X1); tanhX3 = np.tanh(X3); cothX2 = np.cosh(X2)/np.sinh(X2)
        Num1 = 2*(b3-b1)*(b3*sech2X3 - b1*sech2X1) / (np.sqrt(2)*b3sqrt*tanhX3 - np.sqrt(2)*b1sqrt*tanhX1)**2
        Num2 = 2*(b1-b2)*(b2*cosech2X2 + b1*sech2X1) / (np.sqrt(2)*b1sqrt*tanhX1 - np.sqrt(2)*b2sqrt*cothX2)**2
        Num = 2*(b2 - b3)*(Num1 - Num2)
        Den1 = 2*(b1 - b2) / (np.sqrt(2)*b1sqrt*tanhX1 - np.sqrt(2)*b2sqrt*cothX2)
        Den2 = 2*(b3-b1) / (np.sqrt(2)*b3sqrt*tanhX3 - np.sqrt(2)*b1sqrt*tanhX1)
        Den = (Den1 - Den2)**2
        u = b1*sech2X1 - Num/Den
    else:
        raise ValueError("analytical_sol: sol must be 1, 2 or 3, got %r." % (sol,))
    return u

In [None]:
# Plotting n-soliton
w0 = analytical_sol(sol,xplot,t0)
fig = plt.figure()
plt.plot(xplot,w0,label = "t = %1.2f"%t0)
plt.title('%d-soliton solution'%sol)
plt.legend()
# The boundary values are taken at the initial time t0 (which is negative for
# the 2- and 3-soliton cases, so the label must not say t = 0); their
# difference checks that the solution is numerically periodic on [xL, xR].
print("Value of u(%d,%g) = %1.5e."%(xL,t0,w0[0]))
print("Value of u(%d,%g) = %1.5e."%(xR,t0,w0[-1]))
print("u(%d,%g)-u(%d,%g) = %1.5e. So numerically we have periodic BCs."%(xL,t0,xR,t0,w0[0]-w0[-1]))

### Numerical solution by baseline RK methods

In [None]:
def Compute_Sol_Without_Relaxation(Mthdname,rkim, rkex, c, b, bhat, dt, f_stiff, f_non_stiff, T, u0, t0):    
    """Integrate u' = f_stiff(u) + f_non_stiff(u) on [t0, T] with a baseline ImEx RK method.

    Stage i solves (I + dt*rkim[i,i]*M) g_i = u_n + dt*sum_{j<i}(rkim[i,j]*Rim_j + rkex[i,j]*Rex_j),
    using the module-level matrices ``I`` and ``M``. It then stores
    Rim_i = f_stiff(g_i) and Rex_i = f_non_stiff(g_i). The step update is
    u_{n+1} = u_n + dt*sum_j b_j*(Rim_j + Rex_j).

    Parameters
    ----------
    Mthdname : str
        Method name shown in the per-step progress message.
    rkim, rkex : ndarray, shape (s, s)
        Implicit and explicit coefficient matrices.
    c, b, bhat : array_like
        Abscissae, weights and embedded weights; only ``b`` is used here.
    dt : float
        Step size. The last step is shortened so the run ends at T.
    f_stiff, f_non_stiff : callable
        Stiff and non-stiff right-hand sides.
    T, t0 : float
        Final and initial times.
    u0 : ndarray
        Initial state.

    Returns
    -------
    tt : ndarray, shape (n_steps + 1,)
        Output times.
    uu : ndarray, shape (n_steps + 1, len(u0))
        Solution at the output times.
    """
    # Local import: only scipy.optimize is imported at file level.
    from scipy.linalg import lu_factor, lu_solve

    s = len(rkim)
    Rim = np.zeros((s, len(u0)))
    Rex = np.zeros((s, len(u0)))
    # Accumulate in lists; np.append would copy the whole history every step.
    times = [t0]
    states = [np.array(u0, dtype=float)]
    # LU factors of I + dt*a_ii*M keyed by (dt, a_ii). The matrix only changes
    # when dt does (at most on the last step), so factorize once and reuse it
    # instead of a fresh dense solve at every stage of every step.
    factors = {}
    t = t0
    steps = 0

    # time loop
    while t < T and not np.isclose(t, T):
        clear_output(wait=True)
        if t + dt > T:
            dt = T - t
        u_n = states[-1]
        for i in range(s):
            rhs = u_n.copy()
            for j in range(i):
                rhs += dt*(rkim[i,j]*Rim[j,:] + rkex[i,j]*Rex[j,:])

            key = (dt, rkim[i,i])
            if key not in factors:
                factors[key] = lu_factor(I + dt*rkim[i,i]*M)
            g_i = lu_solve(factors[key], rhs)
            # g_i is complex because M is. Taking the real part explicitly
            # matches the previous implicit cast into the real arrays,
            # without emitting a ComplexWarning.
            Rim[i,:] = np.real(f_stiff(g_i))
            Rex[i,:] = np.real(f_non_stiff(g_i))

        inc = dt*sum([b[j]*(Rim[j]+Rex[j]) for j in range(s)])
        t += dt
        steps += 1
        times.append(t)
        states.append(u_n + inc)

        print("Method = Baseline %s: Step number = %d (time = %1.2f)"%(Mthdname,steps,times[-1]))

    return np.array(times, dtype=float), np.array(states)

### Numerical solution by multiple Relaxation RK methods

In [None]:
def compute_sol_multi_relaxation(Mthdname, rkim, rkex, c, b, bhat, dt, f_stiff, f_non_stiff, T, u0, t0, inv):
    """Integrate u' = f_stiff(u) + f_non_stiff(u) on [t0, T] with an ImEx RK method plus energy relaxation.

    Stages are computed as in the baseline method. Stage i solves
    (I + dt*rkim[i,i]*M) g_i = u_n + dt*sum_{j<i}(...) with the module-level
    ``I`` and ``M``. The baseline increment is
    inc1 = dt*sum_j b_j*(Rim_j + Rex_j).

    fsolve, started from gamma = 0, then solves
    eta1(u_n + (1 + gamma)*inc1) = eta1(u_n) via rgam1. The step is
    u_{n+1} = u_n + (1 + gamma)*inc1 and t_{n+1} = t_n + (1 + gamma)*dt.

    Parameters
    ----------
    Mthdname : str
        Method name shown in the per-step progress message.
    rkim, rkex : ndarray, shape (s, s)
        Implicit and explicit coefficient matrices.
    c, b, bhat : array_like
        Abscissae, weights and embedded weights; only ``b`` is used here.
    dt : float
        Step size. The last step is shortened when t + dt would pass T.
    f_stiff, f_non_stiff : callable
        Stiff and non-stiff right-hand sides.
    T, t0 : float
        Final and initial times.
    u0 : ndarray
        Initial state.
    inv : int
        Number of relaxed invariants. Only 1 (energy, via rgam1) is supported.

    Returns
    -------
    tt : ndarray
        Output (relaxed) times.
    uu : ndarray, shape (len(tt), len(u0))
        Solution at the output times.
    no_ier_one, no_ier_five, no_ier_else : int
        Number of steps on which fsolve returned ier == 1, ier == 5, or any
        other flag.

    Raises
    ------
    ValueError
        If inv != 1. Previously this failed with a broadcasting error inside
        rgam1.
    """
    # Local import: only scipy.optimize is imported at file level.
    from scipy.linalg import lu_factor, lu_solve

    if inv != 1:
        raise ValueError("compute_sol_multi_relaxation supports only inv = 1, got inv = %r." % (inv,))

    s = len(rkim)
    Rim = np.zeros((s, len(u0)))
    Rex = np.zeros((s, len(u0)))
    gamma0 = np.zeros(inv)
    # Accumulate in lists; np.append would copy the whole history every step.
    times = [t0]
    states = [np.array(u0, dtype=float)]
    # LU factors of I + dt*a_ii*M keyed by (dt, a_ii), reused across steps.
    factors = {}
    t = t0

    steps = 0; no_ier_five = 0; no_ier_one = 0; no_ier_else = 0
    
    # time loop
    while t < T and not np.isclose(t, T):
        clear_output(wait=True)
        if t + dt > T:
            dt = T - t
        u_n = states[-1]
        for i in range(s):
            rhs = u_n.copy()
            for j in range(i):
                rhs += dt*(rkim[i,j]*Rim[j,:] + rkex[i,j]*Rex[j,:])

            key = (dt, rkim[i,i])
            if key not in factors:
                factors[key] = lu_factor(I + dt*rkim[i,i]*M)
            g_i = lu_solve(factors[key], rhs)
            # g_i is complex because M is. The explicit real part matches the
            # previous implicit cast, without a ComplexWarning.
            Rim[i,:] = np.real(f_stiff(g_i))
            Rex[i,:] = np.real(f_non_stiff(g_i))
            
        inc1 = dt*sum([b[i]*(Rim[i]+Rex[i]) for i in range(s)])
        unew = u_n + inc1
        E1_old = eta1(u_n)
        # fsolve: rgam1 evaluates eta1(unew + gamma*inc1), so gamma = 0 is
        # the baseline step.
        gamma, _info, ier, _mesg = fsolve(rgam1,gamma0,args=(unew,inc1,E1_old),full_output=True)
        # fsolve returns a length-1 array. Use a scalar so that t and dt stay
        # floats and '%f' formatting does not rely on array->scalar conversion.
        gamma1 = gamma[0]

        steps += 1
        if ier == 1:
            no_ier_one += 1
        elif ier == 5:
            no_ier_five += 1
        else:
            no_ier_else += 1

        unew = unew + gamma1*inc1
        t += (1+gamma1)*dt
        times.append(t)
        states.append(unew)
        print("Method = Relaxation %s: At step number = %d (time = %1.2f), integer flag for fsolve = %d and γ1 = %f"%(Mthdname,steps,times[-1],ier,gamma1))
    return np.array(times, dtype=float), np.array(states), no_ier_one, no_ier_five, no_ier_else

### Compute solutions by all the methods

In [None]:
# Methods to compare. S, P, em_P and Sch_No are passed to ImEx_schemes as its
# four positional arguments (presumably stages, order, embedded order and a
# scheme index within RKSchemes -- TODO confirm against RKSchemes).
method_names = ["ImEx32_2","ImEx43"]; S = [4,6]; P = [3,4]; em_P = [2,3]; Sch_No = [3,4]
# Split right-hand side and initial data sampled on the periodic grid x at t0.
f_stiff = kdV_stiff_rhs; f_non_stiff = kdV_non_stiff_rhs; u0 = analytical_sol(sol,x,t0)

# Case tag used in output folder and file names.
eqn = 'KdV_%d_sol_%d_inv'%(sol,inv)

# Summary table: one row per method; scalar entries are broadcast to all rows.
data = {'Method': method_names,
        'Method_labels': ['ARK3(2)4L[2]SA','ARK4(3)4L[2]SA'],
        'Mthd_Save_Name': ['ARK32','ARK43'],
        'B: dt': DT,
        'R: dt': DT,
        'Domain':'[%d,%d]'%(xL,xR),
        'N': N,
        't0': t0,
        'tf': T
       }
df = pd.DataFrame(data)
# Columns for the fsolve exit-flag counts of the relaxation runs (filled below).
df['R: ier = 1'] = np.nan; df['R: ier = 5'] = np.nan; df['R: ier = else'] = np.nan

# b_* hold baseline results, r_* relaxation results, one entry per method.
b_tt = []; b_uu = []; r_tt = []; r_uu = [];  
# Relaxation runs.
for idx in range(len(method_names)):
    dt = DT[idx]; rkim, rkex, c, b, bhat = ImEx_schemes(S[idx],P[idx],em_P[idx],Sch_No[idx])
    tt, uu,IF_1,IF_5,IF_else = compute_sol_multi_relaxation(method_names[idx],rkim, rkex, c, b, bhat, dt, f_stiff, f_non_stiff, T, u0, t0, inv)
    df.at[idx,'R: ier = 1'] = int(IF_1); df.at[idx,'R: ier = 5'] = int(IF_5); df.at[idx,'R: ier = else'] = int(IF_else)
    r_tt.append(tt); r_uu.append(uu)
    
# Baseline runs with the same schemes and step sizes.
for idx in range(len(method_names)):
    dt = DT[idx]; rkim, rkex, c, b, bhat = ImEx_schemes(S[idx],P[idx],em_P[idx],Sch_No[idx])
    tt, uu = Compute_Sol_Without_Relaxation(method_names[idx], rkim, rkex, c, b, bhat,dt, f_stiff, f_non_stiff, T, u0, t0)
    b_tt.append(tt); b_uu.append(uu)

In [None]:
# Display the per-method summary table (notebook cell output).
df

In [None]:
import os

# Output folder for this case's figures, e.g. Figures/KdV_1_sol_1_inv.
path = 'Figures/%s' % eqn
# exist_ok avoids the check-then-create race of os.path.exists + makedirs.
os.makedirs(path, exist_ok=True)

### Computing analytical solutions

In [None]:
def analytical_u_KdV(sol,tvec,u0):
    """Sample the exact solution on the periodic grid ``x`` at every time in tvec.

    Returns an array of shape (len(tvec), len(u0)). u0 only fixes the number of
    columns; the spatial points come from the module-level grid ``x``.
    """
    true_u = np.zeros((len(tvec), len(u0)))
    for row, t in enumerate(tvec):
        true_u[row] = analytical_sol(sol, x, t)
    return true_u

# Exact solutions at the output times of the baseline runs (one array per method)
b_UTrue = [analytical_u_KdV(sol, b_tt[k], u0) for k in range(len(method_names))]

# Exact solutions at the (relaxed) output times of the relaxation runs
r_UTrue = [analytical_u_KdV(sol, r_tt[k], u0) for k in range(len(method_names))]
    

### Compute and plot the changes in invariants by different methods

In [None]:
def _invariant_drift(eta, sols):
    """Return eta(u_n) - eta(u_0) for every row u_n of sols as a float array."""
    values = np.array([eta(u) for u in sols], dtype=float)
    return values - values[0]

# Change of each invariant relative to its initial value, per method, for the
# baseline (b_) and relaxation (r_) solutions. Explicit arrays avoid relying on
# NumPy-scalar coercion of a Python list in "list - scalar".
b_ETA_0 = []; r_ETA_0 = []; b_ETA_1 = []; r_ETA_1 = []; b_ETA_2 = []; r_ETA_2 = []
for i in range(len(df['Method'])):
    b_eta_0 = _invariant_drift(eta0, b_uu[i])
    b_eta_1 = _invariant_drift(eta1, b_uu[i])
    b_eta_2 = _invariant_drift(eta2, b_uu[i])
    r_eta_0 = _invariant_drift(eta0, r_uu[i])
    r_eta_1 = _invariant_drift(eta1, r_uu[i])
    r_eta_2 = _invariant_drift(eta2, r_uu[i])
    b_ETA_0.append(b_eta_0); b_ETA_1.append(b_eta_1); b_ETA_2.append(b_eta_2)
    r_ETA_0.append(r_eta_0); r_ETA_1.append(r_eta_1); r_ETA_2.append(r_eta_2)

In [None]:
# plotting invariants
# Font size    
font = {#'family' : 'normal',
'weight' : 'normal',
'size'   : 14}
plt.rc('font', **font)
plt.figure(figsize=(15, 4))

# One panel per method: drift of each invariant for the baseline (dotted) and
# relaxation (solid) solutions. Raw strings keep the LaTeX backslashes verbatim
# without invalid-escape warnings (SyntaxWarning on Python >= 3.12).
for i in range(len(df['Method'])):
    plt.subplot(1,len(df['Method']),i+1)
    plt.plot(b_tt[i],b_ETA_0[i],':k',label=r"Baseline: $\eta_{0}(U(t))-\eta_{0}(U(0))$")
    plt.plot(r_tt[i],r_ETA_0[i],'-r',label=r"Relaxation: $\eta_{0}(U(t))-\eta_{0}(U(0))$")
   
    if inv == 1:
        ncol = 4
        plt.plot(b_tt[i],b_ETA_1[i],':b',label=r"Baseline: $\eta_{1}(U(t))-\eta_{1}(U(0))$")
        plt.plot(r_tt[i],r_ETA_1[i],'-g',label=r"Relaxation: $\eta_{1}(U(t))-\eta_{1}(U(0))$")
        plt.title(r'%s with $\Delta t$ = %.1f'%(df['Method_labels'][i],df['B: dt'][i]))
    elif inv == 2:
        ncol = 3
        plt.plot(b_tt[i],b_ETA_1[i],':b',label=r"Baseline: $\eta_{1}(U(t))-\eta_{1}(U(0))$")
        plt.plot(r_tt[i],r_ETA_1[i],'-g',label=r"Relaxation: $\eta_{1}(U(t))-\eta_{1}(U(0))$")
        plt.plot(b_tt[i],b_ETA_2[i],':m',label=r"Baseline: $\eta_{2}(U(t))-\eta_{2}(U(0))$")
        plt.plot(r_tt[i],r_ETA_2[i],'-y',label=r"Relaxation: $\eta_{2}(U(t))-\eta_{2}(U(0))$")
        plt.title(r'%s with $\Delta t$ = %.2f'%(df['Method_labels'][i],df['B: dt'][i]))
        
    plt.xlabel('$t$')
    # symlog keeps both signs visible down to the 1e-14 round-off level
    plt.yscale("symlog", linthresh=1.e-14)
    plt.yticks([-1.e-2, -1.e-6, -1.e-10, -1.e-14, 1.e-14, 1.e-10, 1.e-6, 1.e-2])
    
plt.tight_layout()
# Shared legend taken from the last panel (all panels use the same labels).
ax = plt.gca()
handles, labels = ax.get_legend_handles_labels()
plt.figlegend(handles, labels, loc='upper center', ncol = ncol, bbox_to_anchor=(0.5, 1.2))
plt.savefig('./Figures/%s/%s_InvariantVsTime_dt%1.0e_tf%d.pdf'%(eqn,eqn,df['B: dt'][i],T),format='pdf', bbox_inches="tight",transparent=True)


### Maximum invariant errors

In [None]:
# For each method, print the largest absolute drift of every invariant over
# the whole run, first for the baseline and then for the relaxation solution.
for name, b0, b1, b2, r0, r1, r2 in zip(df['Method'], b_ETA_0, b_ETA_1, b_ETA_2, r_ETA_0, r_ETA_1, r_ETA_2):
    print("Baseline %s: Max of mass invariant error = %1.2e. \n"%(name,np.max(np.abs(b0))))
    print("Baseline %s: Max of energy invariant error = %1.2e. \n"%(name,np.max(np.abs(b1))))
    print("Baseline %s: Max of Whitham invariant error = %1.2e. \n"%(name,np.max(np.abs(b2))))

    print("Relaxation %s: Max of mass invariant error = %1.2e. \n"%(name,np.max(np.abs(r0))))
    print("Relaxation %s: Max of energy invariant error = %1.2e. \n"%(name,np.max(np.abs(r1))))
    print("Relaxation %s: Max of Whitham invariant error = %1.2e. \n"%(name,np.max(np.abs(r2))))

### Compute and plot errors by different methods

In [None]:
# Max-norm (over x) error against the exact solution at every output time,
# one array per method, for the baseline (b_) and relaxation (r_) runs.
b_ERR = [np.max(np.abs(b_uu[k] - b_UTrue[k]), axis=1) for k in range(len(df['Method']))]
r_ERR = [np.max(np.abs(r_uu[k] - r_UTrue[k]), axis=1) for k in range(len(df['Method']))]

In [None]:
# Plotting
# Reference-slope parameters for the error plot below. Each reference line is
# drawn as cons_mult[i] * t**p[i] for t in [slope_st_pt, T - t0]; only the
# first len(df['Method']) entries of each list are used.
# NOTE(review): the parameter-setup cell only defines inv == 1 cases, so the
# inv == 2 branches are currently unreachable.
if sol == 1 and inv == 1:
    sl2_cons_mult = [8e-4,3e-4]; sl2_p = [2,2]
    sl1_cons_mult = [2.5e-3,3e-4]; sl1_p = [1,1]
    slope_st_pt = 5
elif sol == 2 and inv == 1:
    # This branch also defines sl0_*. The plot pairs sl1_cons_mult with
    # sl1_p (0.5) and sl0_cons_mult with sl0_p (1.5).
    sl2_cons_mult = [10e-4,3e-4,3e-4]; sl2_p = [2,2]
    sl1_cons_mult = [6e-3,10e-4,4e-4]; sl0_p = [1.5,1.5]
    sl0_cons_mult = [6e-4,1e-4,4e-4]; sl1_p = [0.5,0.5]
    slope_st_pt = 10
elif sol == 2 and inv == 2:
    sl2_cons_mult = [4e-4,1e-8]; sl2_p = [2,2]
    sl1_cons_mult = [10e-4,5e-8]; sl1_p = [1,1] 
    slope_st_pt = 5
elif sol == 3 and inv == 1:
    # NOTE(review): sl1_p equals sl2_p here (both 2); confirm whether a
    # different exponent was intended for the sl1 reference line.
    sl2_cons_mult = [10e-4,3e-4,3e-4]; sl2_p = [2,2,2]
    sl1_cons_mult = [1e-4,5e-5,5e-5]; sl1_p = [2,2,2]  
    slope_st_pt = 10
elif sol == 3 and inv == 2:
    sl2_cons_mult = [5e-6,2e-7,2e-9]; sl2_p = [2,2,2]
    sl1_cons_mult = [9e-5,8e-5,20e-9]; sl1_p = [1,1,1]  
    slope_st_pt = 10

# Font size    
font = {#'family' : 'normal',
'weight' : 'normal',
'size'   : 14}
plt.rc('font', **font)
plt.figure(figsize=(15, 4))

# One log-log panel per method: max-norm error versus shifted time t - t0 for
# the baseline and relaxation runs, plus reference slopes O(t^p). Raw strings
# keep the LaTeX backslashes verbatim without invalid-escape warnings.
for i in range(len(df['Method'])):
    plt.subplot(1,len(df['Method']),i+1)
    plt.plot(-df['t0'][0]+b_tt[i],b_ERR[i],':',color='orangered',label="Baseline")
    plt.plot(-df['t0'][0]+r_tt[i],r_ERR[i],'-g',label="Relaxation")
    sl_b = np.linspace(slope_st_pt,-df['t0'][0]+df['tf'][0],100)
    sl_r = np.linspace(slope_st_pt,-df['t0'][0]+df['tf'][0],100)
    
    if sol == 2 and inv ==1:
        plt.plot(sl_r,sl1_cons_mult[i]*sl_r**sl1_p[i],'--',color='0.5',label=r"$\mathcal{O}(t^{%1.1f})$"%(sl1_p[i]))
        plt.plot(sl_r,sl0_cons_mult[i]*sl_r**sl0_p[i],'-k',label=r"$\mathcal{O}(t^{%1.1f})$"%(sl0_p[i]))
        plt.plot(sl_b,sl2_cons_mult[i]*sl_b**sl2_p[i],'-',color='0.5',label=r"$\mathcal{O}(t^{%1.1f})$"%(sl2_p[i]))
    else:
        plt.plot(sl_r,sl1_cons_mult[i]*sl_r**sl1_p[i],'--',color='0.5',label=r"$\mathcal{O}(t^{%1.1f})$"%(sl1_p[i]))
        plt.plot(sl_b,sl2_cons_mult[i]*sl_b**sl2_p[i],'-',color='0.5',label=r"$\mathcal{O}(t^{%1.1f})$"%(sl2_p[i]))

    if sol == 1 and inv ==1:
        plt.xlabel('$t$')
    else:
        plt.xlabel('$t+%d$'%(-df['t0'][0]))
        
    plt.ylabel('Error in u')
    plt.xscale("log"); plt.yscale("log")
    
    if inv == 1:
        plt.title(r'%s with $\Delta t$ = %.1f'%(df['Method_labels'][i],df['B: dt'][i]))
    elif inv == 2:
        plt.title(r'%s with $\Delta t$ = %.2f'%(df['Method_labels'][i],df['B: dt'][i]))
        
# Shared legend taken from the last panel (all panels use the same labels).
ax = plt.gca()
handles, labels = ax.get_legend_handles_labels()
plt.figlegend(handles, labels, loc='upper center', ncol=5, bbox_to_anchor=(0.5, 1.15))
plt.savefig('./Figures/%s/%s_ErrorVsTime_dt%1.0e_tf%d.pdf'%(eqn,eqn,df['B: dt'][i],T),format='pdf', bbox_inches="tight")