# Discovering the KdV equation from data

In [None]:
import pysindy as ps

import matplotlib.pyplot as plt
import numpy as np
from scipy.io import loadmat


## Importing the Dataset

In [None]:
KdV_data = loadmat('./kdv_data.mat')

In [None]:
u = KdV_data['u']
x = KdV_data['x'].flatten()
t = KdV_data['t'].flatten()
dt = t[1] - t[0]
dx = x[1] - x[0]

Checking the shape of the data

In [None]:
print(' Number of time points:', t.shape, '\n Number of spatial points:', x.shape, '\n Shape of the u:', u.shape)

Therefore, the imported data have the shape u(T, X): the first axis indexes time and the second indexes space.
You should also check that time and space are sampled at fixed intervals (inspect `t` and `x`), since `dt` and `dx` above are computed from the first two samples only.
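One way to do this check programmatically is the hypothetical helper `check_uniform_grid` below. It compares every consecutive spacing with the first one; the relative tolerance is an assumed value. Synthetic grids stand in for the loaded `t` and `x`. With the real data you would call `check_uniform_grid(t)` and `check_uniform_grid(x)`.

```python
import numpy as np


def check_uniform_grid(grid, rtol=1e-6):
    """Return (is_uniform, step) for a 1-D coordinate array.

    The grid counts as uniform if every consecutive difference matches the
    first one within the relative tolerance `rtol` (an assumed default).
    """
    grid = np.asarray(grid, dtype=float).ravel()
    steps = np.diff(grid)
    step = steps[0]
    return bool(np.allclose(steps, step, rtol=rtol, atol=0.0)), step


# Synthetic grids standing in for the `t` and `x` loaded from kdv_data.mat
t_demo = np.linspace(0.0, 20.0, 201)
x_demo = np.linspace(-30.0, 30.0, 512, endpoint=False)
# An irregular grid, e.g. randomly spaced sample times
t_bad = np.sort(np.random.default_rng(0).uniform(0.0, 20.0, 201))

print(check_uniform_grid(t_demo))
print(check_uniform_grid(x_demo))
print(check_uniform_grid(t_bad)[0])
```

If a grid is not uniform, a single `dt` or `dx` no longer describes it, and derivative schemes that assume a fixed step should not be used as-is.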

## Visualise

It is always a good idea to sanity-check that the data make sense and to spot potential challenges before fitting anything.
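Alongside the plots, a couple of numerical checks catch common problems early. The hypothetical helper `sanity_report` below checks that the field's shape matches the (T, X) grid lengths and that there are no NaN or infinite values, and it reports the range of `u`. The toy arrays are placeholders for the loaded data.

```python
import numpy as np


def sanity_report(u, t, x):
    """Basic checks on a (T, X) field before fitting: shape, finiteness, range."""
    return {
        "shape_matches_grids": u.shape == (t.size, x.size),
        "all_finite": bool(np.isfinite(u).all()),
        "u_range": (float(np.min(u)), float(np.max(u))),
    }


# Toy arrays standing in for the loaded data
t = np.linspace(0.0, 1.0, 5)
x = np.linspace(0.0, 2.0, 4)
u = np.outer(t, x)  # shape (5, 4) = (T, X)
print(sanity_report(u, t, x))

u_bad = u.copy()
u_bad[2, 1] = np.nan  # simulate a corrupted sample
print(sanity_report(u_bad, t, x)["all_finite"])  # False
```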

In [None]:
plt.plot(x,u[-1,:])
plt.xlabel('x')
plt.ylabel('u')
plt.title('KdV solution at final time')
plt.show()

In [None]:
mid = len(x) // 2  # index of the middle of the spatial grid
plt.plot(t, u[:, mid])
plt.xlabel('t')
plt.ylabel('u')
plt.title('Temporal evolution at the middle of the spatial domain')
plt.show()

In [None]:
# Space-time plot of the solution u(x, t)
plt.figure()
plt.pcolormesh(x, t, u, shading='auto')
plt.colorbar(label='u')
plt.xlabel('x', fontsize=16)
plt.ylabel('t', fontsize=16)
plt.title(r'$u(x, t)$', fontsize=16)
plt.show()


Things to look out for:
- What are the scales of `t`, `x`, and `u`?
    - They give you a sense of the expected scale of the derivatives.
    - Do you need to rescale anything to make the learning easier?
- What are the signal's time scale and length scale?
    - Are you sampling frequently enough (i.e. above the Nyquist rate for those scales)?
- If the signal is noisy, can you at least make out the length scale and time scale? That will help you choose the right derivative scheme. For example, if the data are very noisy but finely sampled, you may want to use the weak formulation (`ps.WeakPDELibrary` in PySINDy), which integrates the data against smooth test functions and so filters out high-frequency noise.
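The first two questions can be partly answered numerically. The sketch below is one way to do it, under stated assumptions. It uses a synthetic single soliton of the standard KdV form u_t + 6uu_x + u_xxx = 0 in place of the loaded data, and it prints the ranges of `t`, `x` and `u`. A hypothetical helper, `spectral_cutoff`, then estimates the angular wavenumber (in space) and angular frequency (in time) below which 99% of the signal power lies. The 99% level and the "well below half of Nyquist" rule of thumb are assumed choices. Comparing the cutoffs with the Nyquist limits π/dx and π/dt indicates whether the sampling resolves the signal.

```python
import numpy as np


def soliton(x, t, c=1.0, x0=0.0):
    """Single-soliton solution of u_t + 6 u u_x + u_xxx = 0 with speed c."""
    return 0.5 * c / np.cosh(0.5 * np.sqrt(c) * (x - c * t - x0)) ** 2


def spectral_cutoff(signal, spacing, energy_fraction=0.99):
    """Angular wavenumber/frequency below which `energy_fraction` of the
    power of the (mean-removed) signal lies, estimated with a real FFT."""
    signal = np.asarray(signal, dtype=float)
    power = np.abs(np.fft.rfft(signal - signal.mean())) ** 2
    omega = 2.0 * np.pi * np.fft.rfftfreq(signal.size, d=spacing)
    cumulative = np.cumsum(power) / power.sum()
    return omega[np.searchsorted(cumulative, energy_fraction)]


# Synthetic stand-in for the loaded data (grid sizes and soliton are assumptions)
x = np.linspace(-30.0, 30.0, 512)
t = np.linspace(0.0, 20.0, 201)
dx, dt = x[1] - x[0], t[1] - t[0]
u = soliton(x[None, :], t[:, None], c=1.0, x0=-10.0)  # shape (T, X)

for name, arr in [("t", t), ("x", x), ("u", u)]:
    print(f"{name}: min {arr.min():.3g}, max {arr.max():.3g}")

# Compare the signal's content with the Nyquist limits pi / dx and pi / dt
k_cut, k_nyq = spectral_cutoff(u[0], dx), np.pi / dx
w_cut, w_nyq = spectral_cutoff(u[:, x.size // 2], dt), np.pi / dt
print(f"space: 99% of power below k = {k_cut:.2f} (Nyquist {k_nyq:.2f})")
print(f"time:  99% of power below w = {w_cut:.2f} (Nyquist {w_nyq:.2f})")
```

A cutoff well below the Nyquist value suggests the grid resolves the signal, while a cutoff close to it is a warning sign. The FFT cannot reveal content that was already aliased when the data were sampled, so this is a plausibility check rather than a proof.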

In [None]:
# Your code here

How would you prepare the data for PySINDy?
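One part of the answer concerns array layout. For PDE data, PySINDy expects the spatial axes first, the time axis second-to-last and the state variables last, i.e. (X, T, D) for a scalar field on a 1-D grid. The loaded `u` is (T, X), so it needs a transpose plus a trailing axis. The toy 4 × 3 array below is a hypothetical stand-in for `u`. It shows that `u.T[:, :, np.newaxis]` gives the same result as `u.T.reshape(u.shape[1], u.shape[0], 1)`.

```python
import numpy as np

# Toy stand-in for the loaded array: T = 4 time points, X = 3 spatial points
T, X = 4, 3
u = np.arange(T * X, dtype=float).reshape(T, X)  # u[i, j] = value at (t_i, x_j)

# Spatial axis first, time second-to-last, one state variable last: (X, T, 1)
model_data = u.T[:, :, np.newaxis]

# The equivalent reshape-based formulation
model_data_alt = u.T.reshape(u.shape[1], u.shape[0], 1)

print(model_data.shape)  # (3, 4, 1)
print(np.array_equal(model_data, model_data_alt))  # True
```

Note that `u.reshape(X, T, 1)` without the transpose would also have the right shape but would scramble the data, because reshaping does not swap axes.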

Which PySINDy feature library would you use?

In [None]:
# Define the library
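#   Here we use the standard PDE library, which computes the derivatives automatically (using finite differences by default).
#   Hint: the KdV equation contains a third spatial derivative, so derivative_order must be at least 3; the grid is passed via spatial_grid.
pde_lib = ps.PDELibrary ## Your code here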

# Define the optimizer 
optimizer = ps.STLSQ ## Your code here

# Define the SINDy model
model = ps.SINDy ## Your code here

# Reshape the data and fit the model
#   Note about PDELibrary and data shape:
#   For PDEs, PySINDy expects the spatial axes first, the time axis second-to-last and the state variables last,
#   i.e. shape (X, T, D), where X is the number of spatial points, T the number of time points, and D the number of
#   state variables (D=1 for a scalar PDE). Our u has shape (T, X), so we transpose it and add a trailing axis.
model_data = u.T.reshape(u.shape[1], u.shape[0], 1)

# Call the regression algorithm
model.fit(model_data, t=dt, feature_names=['u'])

# print identified PDE
print("Identified PDE: ")
model.print()
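The cell above leaves PySINDy to build the derivative library and run the sparse regression. To illustrate roughly what happens under the hood, the sketch below does the same steps by hand with NumPy only, under several assumptions:

- It uses an exact two-soliton solution of the standard form u_t + 6uu_x + u_xxx = 0 (the notebook's data may use different coefficients).
- It differentiates with second-order finite differences (`np.gradient`).
- It assembles a small hand-picked candidate library.
- It applies a plain sequentially thresholded least-squares loop.

Two solitons with different speeds are used because a single travelling wave also satisfies the advection equation u_t = -c u_x, so one soliton alone cannot distinguish KdV from advection. The names `kdv_two_soliton` and `stlsq`, the grid, the soliton parameters and the threshold of 0.5 are hypothetical choices. `ps.STLSQ` additionally offers ridge regularisation (`alpha`) and other options not reproduced here.

```python
import numpy as np


def kdv_two_soliton(x, t, k1=1.4, k2=0.7, x1=-20.0, x2=-10.0):
    """Exact two-soliton solution of u_t + 6 u u_x + u_xxx = 0 on a (T, X) grid.

    Hirota form: u = 2 (log f)_xx, f = 1 + e^eta1 + e^eta2 + A e^(eta1 + eta2),
    eta_i = k_i (x - x_i) - k_i**3 t, A = ((k1 - k2) / (k1 + k2))**2.
    (log f)_xx equals the variance of the wavenumbers [0, k1, k2, k1 + k2]
    under softmax weights of the four log-terms, which avoids overflow.
    """
    X, T = np.meshgrid(x, t)  # both have shape (len(t), len(x))
    eta1 = k1 * (X - x1) - k1**3 * T
    eta2 = k2 * (X - x2) - k2**3 * T
    log_a = 2.0 * np.log(abs(k1 - k2) / (k1 + k2))
    logs = np.stack([np.zeros_like(eta1), eta1, eta2, log_a + eta1 + eta2])
    weights = np.exp(logs - logs.max(axis=0))
    weights /= weights.sum(axis=0)
    ks = np.array([0.0, k1, k2, k1 + k2])[:, None, None]
    mean_k = (weights * ks).sum(axis=0)
    return 2.0 * (weights * (ks - mean_k) ** 2).sum(axis=0)


def stlsq(theta, target, threshold, max_iter=10):
    """Plain sequentially thresholded least squares (no ridge term)."""
    xi = np.linalg.lstsq(theta, target, rcond=None)[0]
    for _ in range(max_iter):
        small = np.abs(xi) < threshold
        xi[small] = 0.0
        if small.all():
            break
        # Refit using only the terms that survived the threshold
        xi[~small] = np.linalg.lstsq(theta[:, ~small], target, rcond=None)[0]
    return xi


# Synthetic data in the notebook's (T, X) layout
x = np.linspace(-30.0, 30.0, 1201)  # dx = 0.05
t = np.linspace(0.0, 15.0, 301)  # dt = 0.05
dx, dt = x[1] - x[0], t[1] - t[0]
u = kdv_two_soliton(x, t)

# Second-order central differences (one-sided at the array edges)
u_t = np.gradient(u, dt, axis=0)
u_x = np.gradient(u, dx, axis=1)
u_xx = np.gradient(u_x, dx, axis=1)
u_xxx = np.gradient(u_xx, dx, axis=1)

# Candidate terms; drop 3 edge points per axis affected by one-sided stencils
inner = (slice(3, -3), slice(3, -3))
candidates = {
    "1": np.ones_like(u), "u": u, "uu": u * u,
    "u_x": u_x, "uu_x": u * u_x,
    "u_xx": u_xx, "uu_xx": u * u_xx,
    "u_xxx": u_xxx, "uu_xxx": u * u_xxx,
}
names = list(candidates)
theta = np.column_stack([candidates[n][inner].ravel() for n in names])
xi = stlsq(theta, u_t[inner].ravel(), threshold=0.5)

print("u_t = " + " + ".join(f"({c:.3f}) {n}" for n, c in zip(names, xi) if c != 0))
```

On this synthetic data we would expect only the u·u_x and u_xxx terms to survive, with coefficients close to -6 and -1. The remaining small discrepancies come from finite-difference truncation error, which shrinks as dx and dt are reduced.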