In [None]:
%matplotlib inline

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import animation
from IPython.display import HTML
import numba

In [None]:
# Spatial grid parameters: half-width of the domain and the grid spacing.
x_max = 20
dx = 0.1

# Sample count: truncated 2 * x_max / dx intervals plus one endpoint.  An even
# count is bumped up by one so the count is always odd; the symmetric grid
# built from it then has a sample exactly at x = 0 (index n_points // 2).
n_points = int(2 * x_max / dx) + 1
n_points += 1 - n_points % 2

In [None]:
# Symmetric grid from -n_half*dx to +n_half*dx; with an odd n_points the
# spacing is exactly dx and the centre sample sits at x = 0.
n_half = n_points // 2
x = np.linspace(-n_half * dx, n_half * dx, n_points)

In [None]:
# Initial state: psi0(x) = sqrt(kappa) * exp(-kappa * |x|), which has unit L2
# norm and is the bound state of an attractive delta potential of strength
# kappa (with H = -1/2 d^2/dx^2 + V).  Stored as complex128 because the time
# stepper updates the wave function in place with complex increments.
kappa = 1
# Vectorized form of the original piecewise exp(+kappa*x) / exp(-kappa*x):
# for x < 0, -kappa*|x| == kappa*x exactly, so the values are unchanged.
psi0 = (np.sqrt(kappa) * np.exp(-kappa * np.abs(x))).astype(np.complex128)

In [None]:
# Quick look at the initial state (it is purely real, so Re psi0 is all of it).
fig_init, ax_init = plt.subplots()
ax_init.plot(x, np.real(psi0))
ax_init.set(xlabel="x", ylabel="Re psi0(x)", title="Initial wave function");

In [None]:
@numba.jit(nopython=True)
def iterate(psi, dt):
    """Advance ``psi`` in place by one forward-Euler step of the free
    Schroedinger equation  i dpsi/dt = -1/2 d^2 psi / dx^2.

    The second derivative is the 3-point central difference with periodic
    boundaries: the first and last samples are treated as neighbours.

    Parameters
    ----------
    psi : 1-D complex ndarray
        Wave-function samples on a uniform grid of spacing ``dx``.  It is
        overwritten with the new state.  Must hold at least 2 samples.
    dt : float
        Time step.

    Raises
    ------
    ValueError
        If ``psi`` has fewer than 2 samples.

    Notes
    -----
    * The loop bounds come from ``psi`` itself, not from a global grid size.
      Numba does not bounds-check by default, so a mismatch would silently
      read and write out of bounds.
    * ``dx`` is read from the notebook globals.  Numba freezes global values
      at compile time, so changing ``dx`` later has no effect until this cell
      is re-run.
    * Forward Euler is unconditionally unstable for this equation.  Every
      Fourier mode is multiplied by a factor of modulus slightly above 1 per
      step, so the norm slowly grows.  Keep ``dt`` small and monitor the norm.
    * No delta-potential term is applied.  Modelling one would need a cusp
      condition at the centre sample ``c``, e.g.
      ``psi[c] = (psi[c-1] + psi[c+1]) / (2 - 2*kappa*dx)``.
    """
    n = psi.shape[0]
    if n < 2:
        raise ValueError("psi must have at least 2 samples")
    dpsi = np.zeros_like(psi)
    # Periodic wrap: psi[-1] is the left neighbour of psi[0] and vice versa.
    dpsi[0] = 0.5 * 1j * (psi[1] + psi[-1] - 2 * psi[0]) / dx ** 2
    dpsi[-1] = 0.5 * 1j * (psi[0] + psi[-2] - 2 * psi[-1]) / dx ** 2
    for i in range(1, n - 1):
        dpsi[i] = 0.5 * 1j * (psi[i + 1] + psi[i - 1] - 2 * psi[i]) / dx ** 2
    # Apply the update only after every derivative has been computed from the
    # old state.
    psi += dpsi * dt

In [None]:
# Integrate from t = 0 to t_end with small forward-Euler steps, keeping a
# snapshot of the wave function every output_dt for plotting.
output_dt = 0.05

psi = psi0.copy()  # iterate() mutates its argument; never evolve psi0 itself
dt = 0.00001
# Steps between snapshots.  round() rather than int() so floating-point error
# in the ratio cannot truncate it to one step short.  max(1, ...) keeps the
# modulo below from dividing by zero when output_dt < dt.
save_every = max(1, round(output_dt / dt))
t_end = 2 * np.pi
points = int(t_end / dt) + 1  # loop index i runs over 0 .. int(t_end / dt)

psis = []  # saved wave-function snapshots
ts = []    # time of each snapshot (state before the step at index i)
for i in range(points):
    if i % save_every == 0:
        psis.append(psi.copy())
        ts.append(i * dt)
    iterate(psi, dt)

In [None]:
%%capture

fig, ax = plt.subplots()

def plot(i):
    """Draw animation frame ``i``: real and imaginary parts of snapshot i."""
    ax.clear()  # clearing also drops labels, so they are re-applied per frame
    ax.plot(x, np.real(psis[i]), label="Re psi")
    ax.plot(x, np.imag(psis[i]), label="Im psi")
    ax.set_ylim(-1.1, 1.1)
    ax.set(xlabel="x", title=f"t = {ts[i]:.2f}")
    ax.legend(loc="upper right")

# to_html5_video() renders every frame and returns the HTML5 <video> markup as
# a string; the HTML(anim) cell that follows displays it.
anim = animation.FuncAnimation(fig, plot, frames=len(psis), interval=20).to_html5_video()
plt.close(fig)  # the static figure is not needed once the video is encoded

In [None]:
# Embed the HTML5 video markup (a string from to_html5_video) built by the cell above.
HTML(anim)

In [None]:
%%capture

fig, ax = plt.subplots()

def plot(i):
    """Draw animation frame ``i``: modulus |psi| of snapshot i."""
    ax.clear()  # clearing also drops labels, so they are re-applied per frame
    ax.plot(x, np.abs(psis[i]))
    # |psi| is non-negative, so the axis starts at 0 instead of -1.1.
    ax.set_ylim(0, 1.1)
    ax.set(xlabel="x", ylabel="|psi(x)|", title=f"t = {ts[i]:.2f}")

# to_html5_video() renders every frame and returns the HTML5 <video> markup as
# a string; the HTML(anim) cell that follows displays it.
anim = animation.FuncAnimation(fig, plot, frames=len(psis), interval=20).to_html5_video()
plt.close(fig)  # the static figure is not needed once the video is encoded

In [None]:
# Embed the HTML5 video markup (a string from to_html5_video) built by the cell above.
HTML(anim)

In [None]:
# Discrete L2 norm of each snapshot (Riemann sum of |psi|^2 with spacing dx).
# Exact Schroedinger evolution conserves the norm, so any drift here reflects
# the error of the time stepper.
norms = [dx * np.sum(np.abs(snapshot) ** 2) for snapshot in psis]
fig_norm, ax_norm = plt.subplots()
ax_norm.plot(ts, norms)
ax_norm.set(xlabel="t", ylabel="norm", title="Norm of psi over time");

In [None]:
import solver

In [None]:
# Solver from the local ``solver`` module (its source is not shown in this notebook).
# NOTE(review): the positional arguments look like (x_max, dx) by analogy with
# the grid set up near the top of this notebook, and delta_depth presumably
# sets the strength of a delta potential — confirm against solver.py.
e = solver.EulerSolverDelta(10, 0.1, delta_depth=1.0)

In [None]:
%load_ext autoreload

In [None]:
%autoreload 2

In [None]:
# NOTE(review): the arguments presumably mirror the hand-written loop earlier
# in the notebook, i.e. (t_end, dt, output_dt) — confirm against
# EulerSolverDelta.execute.
# This rebinds the globals ts and psis, so re-running the earlier animation or
# norm cells after this one will show the solver's results instead.
ts, psis = e.execute(2 * np.pi, 0.0005, 0.05)

In [None]:
%%capture

fig, ax = plt.subplots()

def plot(i):
    """Draw animation frame ``i``: modulus |psi| of the solver's snapshot i on e.x."""
    ax.clear()  # clearing also drops labels, so they are re-applied per frame
    ax.plot(e.x, np.abs(psis[i]))
    # |psi| is non-negative, so the axis starts at 0 instead of -1.1.
    ax.set_ylim(0, 1.1)
    ax.set(xlabel="x", ylabel="|psi(x)|", title=f"frame {i}")

# to_html5_video() renders every frame and returns the HTML5 <video> markup as
# a string; the HTML(anim) cell that follows displays it.
anim = animation.FuncAnimation(fig, plot, frames=len(psis), interval=20).to_html5_video()
plt.close(fig)  # the static figure is not needed once the video is encoded

In [None]:
# Embed the HTML5 video markup (a string from to_html5_video) built by the cell above.
HTML(anim)