In [None]:
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

## Numerically evolve the time-dependent Schrödinger equation ##

Here we use FFT methods to evolve the Schrödinger equation numerically (in units with $\hbar=1$).  We "split" the Hamiltonian into a piece depending only on momentum ($p^2/2m$) and a piece depending only on position ($V(x)$).  Each piece is diagonal, and hence trivial to apply, in its own basis: the potential in position space and the kinetic term in momentum space.  So we simply switch back and forth between position and momentum space with FFTs.

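Specifically, the "evolution operator" for one time step $\Delta t$ is approximated by the symmetric (Strang) splitting
$$
  U = \exp\left[-iH\Delta t\right] =
  \exp[-i V \Delta t/2]
  \exp\left[-i \frac{p^2}{2m} \Delta t\right]
  \exp[-i V \Delta t/2] + \mathcal{O}\left(\Delta t^3\right)
$$
The error per step is $\mathcal{O}(\Delta t^3)$, so after a fixed total time (about $1/\Delta t$ steps) the accumulated error is $\mathcal{O}(\Delta t^2)$: the scheme is second-order accurate.  Every factor is a pure phase in its own basis, so the scheme is exactly unitary and conserves the norm of $\psi$.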

In [None]:
# The domain is a 1D ring (periodic boundary conditions), which lets us
# apply FFTs directly with no padding.
N = 128  # The number of grid points.

def potential():
    """Return the sample points and the potential evaluated on them.

    Returns
    -------
    (x, V) : tuple of numpy.ndarray
        ``x`` contains the N points (n - N/2)/N for n = 0..N-1, covering
        [-1/2, 1/2) with spacing 1/N (one full period of V); ``V`` is
        1 - cos(2 pi x) at those points.
    """
    grid = (np.arange(N) - N / 2) / N
    return grid, 1 - np.cos(2 * np.pi * grid)

In [None]:
# Plot the confining potential on the simulation grid.
fig,ax = plt.subplots(1,1,figsize=(8,6))
xx,Vx  = potential()
ax.plot(xx,Vx)
ax.set_xlabel(r'$x$')
ax.set_ylabel(r'$V(x)$')
# Title so the figure stands alone; trailing ';' suppresses the Text repr.
ax.set_title(r'Potential $V(x)=1-\cos(2\pi x)$');

In [None]:
xx,Vx = potential()
# Angular wavenumbers for the kinetic step, in np.fft.fft ordering.
# potential() samples x = (n - N/2)/N, i.e. grid spacing 1/N on a ring of
# length 1, so the sample spacing passed to fftfreq must be d=1/N.  This gives
# k = 2*pi*j (j integer), the eigenvalues of p = -i d/dx on that ring.
# With the default d=1 the values would be k/N.  That silently makes the
# kinetic term k**2/(2*N**2), i.e. a particle of mass N**2 whose dynamics
# change with grid resolution.
kk    = 2 * np.pi * np.fft.fftfreq(N, d=1.0/N)
# Initialize Psi(x) to be a Gaussian at x=0, moving to the right.
def set_ics(x=None, x0=0.0, width=0.1, k0=20.0):
    """Set up the initial conditions: a Gaussian wavepacket.

    psi(x) = exp(-(x - x0)^2 / (2 width^2)) * exp(+i k0 x), scaled so that
    sum(|psi|^2) over the grid points is 1.  Under a kinetic propagator
    exp(-i k^2 dt / (2m)) a packet with k0 > 0 travels toward +x.

    Parameters
    ----------
    x : array_like, optional
        Grid points; defaults to the notebook-level ``xx``.
    x0 : float
        Centre of the packet.
    width : float
        Width (standard deviation) of the amplitude envelope.
    k0 : float
        Mean wavenumber; positive values move right, negative move left.

    Returns
    -------
    numpy.ndarray
        Complex wavefunction on the grid with unit discrete norm.
    """
    if x is None:
        x = xx
    x = np.asarray(x, dtype=float)
    envelope = np.exp(-0.5 * ((x - x0) / width) ** 2)
    # exp(+i k0 x) carries mean momentum +k0, i.e. rightward motion for k0 > 0.
    psi = envelope * np.exp(1j * k0 * x)
    # Normalize to a convenient value: unit sum of |psi|^2 over the grid.
    psi /= np.sqrt(np.sum(np.abs(psi) ** 2))
    return psi
# Plot the initial wavefunction (real and imaginary parts) and |psi|^2.
psi    = set_ics()
fig,ax = plt.subplots(1,1,figsize=(8,6))
ax.plot(xx,psi.real,'b-',label='Real')
ax.plot(xx,psi.imag,'r:',label='Imag')
ax.plot(xx,np.abs(psi)**2,'k-',label='Prob')
ax.legend()
ax.set_xlabel(r'$x$')
ax.set_ylabel(r'$\psi(x)$, $|\psi(x)|^2$')
# Title so the figure stands alone; trailing ';' suppresses the Text repr.
ax.set_title('Initial conditions');

In [None]:
def se_evolve(psi,dt=0.01,V=None,k=None,mass=1.0):
    """Advance psi by one time step with the symmetric split-step FFT method.

    Applies exp(-i V dt/2) in position space, then the kinetic factor
    exp(-i k^2 dt / (2 mass)) in Fourier space, then exp(-i V dt/2) again
    (units with hbar = 1).

    Parameters
    ----------
    psi : array_like
        Wavefunction on the periodic grid.  It is not modified, and may be
        real- or complex-valued.
    dt : float
        Time step.
    V : array_like, optional
        Potential on the grid; defaults to the notebook-level ``Vx``.
    k : array_like, optional
        Angular wavenumbers in ``np.fft.fft`` ordering; defaults to the
        notebook-level ``kk``.
    mass : float
        Particle mass (default 1, matching the original p^2/2 kinetic term).

    Returns
    -------
    numpy.ndarray
        A new complex array holding the evolved wavefunction.
    """
    if V is None:
        V = Vx
    if k is None:
        k = kk
    half_potential = np.exp(-0.5j * V * dt)
    # Out-of-place products: an in-place ``psi *= ...`` would overwrite the
    # caller's array and raises a casting error when psi is real-valued.
    psi = half_potential * np.asarray(psi)
    psi = np.fft.ifft(np.fft.fft(psi) * np.exp(-0.5j * k**2 * dt / mass))
    return half_potential * psi

In [None]:
maxiter = 5000   # total number of time steps
pltiter = 1000   # plot |psi|^2 every pltiter steps
dt      = 0.01   # time step handed to se_evolve
#
fig,ax  = plt.subplots(1,1,figsize=(10,6))
psi     = set_ics()
norm0   = np.sum(np.abs(psi)**2)
# Plot the initial state, then the state after every pltiter completed
# steps, so each label matches the time the curve actually shows.
ax.plot(xx,np.abs(psi)**2,label='t = 0')
for step in range(1,maxiter+1):
    psi = se_evolve(psi,dt=dt)
    if step%pltiter==0:
        ax.plot(xx,np.abs(psi)**2,label='t = {:g}'.format(step*dt))
ax.legend()
ax.set_xlabel(r'$x$')
ax.set_ylabel(r'Probability')
ax.set_title(r'$|\psi(x)|^2$ during split-step evolution ($\Delta t={:g}$)'.format(dt))
norm = np.sum(np.abs(psi)**2)
print("Final probability is {:12.4e}".format(norm))
# Every split-step factor is a pure phase, so the norm must be conserved.
assert np.isclose(norm, norm0, rtol=1e-8), "norm not conserved"