In [None]:
import scipy.linalg as LA
from matplotlib import animation
from IPython.display import HTML
import numpy as np
import matplotlib.pyplot as plt

%matplotlib inline

In [None]:
def H1diag(t):
    """Kinetic energy k**2/2 on the FFT wavenumber grid (diagonal in Fourier space).

    Entries follow ``np.fft.fftfreq`` ordering, i.e. they line up element-wise
    with the output of ``np.fft.fft``. Reads the globals ``npoints`` and ``dx``;
    ``t`` is accepted for a uniform Hamiltonian interface but is not used.
    """
    wavenumbers = np.fft.fftfreq(npoints, d=dx) * (2*np.pi)
    return 0.5 * wavenumbers**2

def H2diag(t,psi):
    """Nonlinear (real-space diagonal) term |psi|**2 / 2.

    Computed as conj(psi)*psi, so a complex ``psi`` yields a complex array with
    zero imaginary part. ``t`` is accepted for a uniform Hamiltonian interface
    but is not used.
    """
    density = np.conjugate(psi)*psi
    return 0.5*density

In [None]:
def TimeStap(psi,t):
    psi  = np.exp(-1j*(dt/2)*H2diag(t,psi))*psi
    t   += dt/2
    psi_ = np.fft.fft(psi)
    psi_ = np.exp(-1j*dt*H1diag(t))*psi_
    t   += dt/2
    psi  = np.fft.ifft(psi_)
    psi  = np.exp(-1j*(dt/2)*H2diag(t,psi))*psi
    return psi

In [None]:
def make_grid(npoints,L):
    """Return ``npoints`` equally spaced points on the periodic box [-L/2, L/2).

    The right end point is excluded because it coincides with -L/2 under
    periodic (FFT) boundary conditions, so the spacing is exactly L/npoints,
    matching ``dx = L/npoints``. The array has complex dtype; callers take the
    real part for plotting.
    """
    # np.linspace guarantees exactly `npoints` samples. np.arange with a float
    # step computes its length as ceil((stop-start)/step), which rounding can
    # push to npoints+1 and break broadcasting against npoints-long arrays.
    return np.linspace(-L/2, L/2, npoints, endpoint=False, dtype=complex)

def grey_soliton(grid,nu=0.5,z0=0):
    """Grey-soliton profile centred at ``z0``.

    psi(x) = i*nu + sqrt(1 - nu**2) * tanh((x - z0) / (sqrt(2) * gamma)),
    with gamma = 1/sqrt(1 - nu**2). Far from ``z0``, |psi| -> 1, and ``nu = 0``
    gives the dark soliton. Finite only for |nu| < 1.
    """
    lorentz = 1/np.sqrt(1-nu**2)
    depth = 1/lorentz
    width = np.sqrt(2)*lorentz
    return 1j*nu + depth*np.tanh((grid-z0)/width)

def dark_soliton(grid,z0=0):
    """Dark soliton centred at ``z0``: the ``nu = 0`` case of ``grey_soliton``.

    The profile is tanh((x - z0)/sqrt(2)), which vanishes at x = z0.
    """
    return grey_soliton(grid, z0=z0, nu=0)

In [None]:
# Simulation parameters
npoints = 500              # number of grid points
L = 40                     # length of the periodic box
dx = L / npoints           # grid spacing
tsteps, dt = 3000, 0.025   # number of time steps and step size

In [None]:
# Two grey solitons with opposite nu, placed symmetrically at x = -10 and x = +10.
grid = make_grid(npoints, L)
psi0 = grey_soliton(grid, 0.5, -10) * grey_soliton(grid, -0.5, 10)

In [None]:
def animateTimeEvolution(xrange=None,yrange=(0,1.1)):
    """Animate the split-step evolution of the current ``psi0``.

    Frame k shows the density |psi|^2 after k calls to ``TimeStap``, starting
    from ``psi0`` (frame 0 is the initial condition). It is plotted against
    the real part of ``grid``. Reads the notebook globals ``psi0``, ``grid``,
    ``tsteps``, ``dt`` and ``L``.

    Parameters
    ----------
    xrange : (float, float), optional
        x-axis limits; defaults to (-L/2, L/2), evaluated at call time.
    yrange : (float, float), optional
        y-axis limits for the density.

    Returns
    -------
    str
        HTML/JavaScript from ``FuncAnimation.to_jshtml``; wrap it in
        ``IPython.display.HTML`` to display it.
    """
    if xrange is None:
        # Evaluated at call time so the default follows the current L.
        xrange = (-L/2, L/2)

    # Precompute every frame. Without an init_func, FuncAnimation also calls
    # the frame callback for its initial draw, so advancing psi inside the
    # callback (as a global) would desynchronise frames from simulation steps.
    densities = []
    psi = psi0
    for step in range(tsteps):
        densities.append(np.real(np.conjugate(psi)*psi))
        psi = TimeStap(psi, step*dt)

    fig, ax = plt.subplots()
    ax.set_xlim(*xrange)
    ax.set_ylim(*yrange)
    ax.set_xlabel('$x$')
    ax.set_ylabel(r'$|\psi|^2$')
    line, = ax.plot([], [])
    x = np.real(grid)

    def animate(frame):
        # `frame` is an integer index into the precomputed densities.
        line.set_data(x, densities[frame])
        return (line,)

    anim = animation.FuncAnimation(fig, animate,
                                   frames=len(densities),
                                   interval=50,  # wait time between frames in ms
                                   blit=True)
    html = anim.to_jshtml()
    # Close once so the inline backend does not also show a static figure.
    plt.close(fig)
    return html

In [None]:
animation_html = animateTimeEvolution()
HTML(animation_html)

In [None]:
# Sanity check of the initial condition: real and imaginary parts of psi0.
fig, ax = plt.subplots()
ax.plot(np.real(grid), np.real(psi0), label=r'$\mathrm{Re}\,\psi_0$')
ax.plot(np.real(grid), np.imag(psi0), label=r'$\mathrm{Im}\,\psi_0$')
ax.set_xlabel('$x$')
ax.set_ylabel(r'$\psi_0(x)$')
ax.set_title(r'Initial condition $\psi_0$')
ax.legend()
plt.show()

In [None]:
# Three grey solitons. On the periodic FFT grid, psi must (approximately)
# agree at x = -L/2 and x = +L/2. Far to the left, each factor tends to
# i*nu - sqrt(1-nu**2) = -exp(-1j*arcsin(nu)). Far to the right, it tends to
# exp(1j*arcsin(nu)). The product of three factors therefore matches across
# the boundary when arcsin(nu1) + arcsin(nu2) + arcsin(nu3) = pi/2, which
# fixes nu3 from nu1 and nu2.
nu1, nu2 = 0.3, -0.05
nu3 = np.cos(np.arcsin(nu1) + np.arcsin(nu2))  # ~0.967746031217134
grid  = make_grid(npoints, L)
psi0  = grey_soliton(grid, nu1, -10)
psi0 *= grey_soliton(grid, nu2,  -2)
psi0 *= grey_soliton(grid, nu3,   6)

In [None]:
three_soliton_html = animateTimeEvolution()
HTML(three_soliton_html)

In [None]:
def TimeEvolution(psi0):
    all_psis = np.zeros((tsteps+1,npoints))
    t = 0
    psi = np.exp(-1j*(dt/2)*H2diag(t,psi0))*psi0
    t += dt/2
    all_psis[0] = np.real(np.conjugate(psi)*psi)
    for i in range(tsteps):
        psi_ = np.fft.fft(psi)
        psi_ = np.exp(-1j*dt*H1diag(t))*psi_
        t += dt/2
        psi  = np.fft.ifft(psi_)
        psi  = np.exp(-1j*dt*H2diag(t,psi))*psi
        t += dt/2
        all_psis[i+1] = np.real(np.conjugate(psi)*psi)
    return all_psis

In [None]:
all_psis = TimeEvolution(psi0)

# Row k of all_psis is |psi|^2 after k split steps, i.e. at t = k*dt. The
# opening nonlinear half-step in TimeEvolution only changes the phase of psi,
# so it does not shift the density in time.
times = np.arange(tsteps + 1) * dt
X, Y = np.meshgrid(np.real(grid), times)
fig, ax = plt.subplots()
filled = ax.contourf(X, Y, all_psis, levels=np.linspace(0, 1.1, 20))
# Raw string: '\e' is an invalid escape sequence in a normal string literal.
ax.set_xlabel(r'$x$ $[\epsilon]$')
ax.set_ylabel('t')
fig.colorbar(filled, ax=ax, label=r'$|\psi|^2$')
plt.show()