In [None]:
# -*- coding: utf-8 -*-
"""
Solver for the 2D time-dependent Schrodinger equation, applied to the
double-slit diffraction of a particle wave, computed with FFT2 from scipy.fft

author: Han Chen
email: hansolo@vt.edu
license: MIT
Please feel free to use and modify this, but keep the above information. Thanks!

This 2D solution is built upon Jake Vanderplas's 1D solution (contact info below)
"""

"""
General Numerical Solver for the 1D Time-Dependent Schrodinger's equation.

author: Jake Vanderplas
email: vanderplas@astro.washington.edu
website: http://jakevdp.github.com
license: BSD
Please feel free to use and modify this, but keep the above information. Thanks!

Han Chen upgraded it to 2D
"""
import time
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import animation
from IPython.display import HTML

from scipy.fft import fft2,ifft2                 # scipy >= 1.4

class Schrodinger(object):
    """
    Numerical solution of the 2D time-dependent Schrodinger equation for an
    arbitrary static potential V(x, y), using the split-step Fourier method
    (half potential step, full kinetic step, half potential step).

    Grid arrays are indexed ``[row, column] = [y, x]``: wave functions and the
    potential have shape (M, N) with M = len(y) and N = len(x).
    """

    def __init__(self, x, y, psi_xy0, V_xy,
                 u0 = None, v0 = None, hbar=1, m=1, t0=0.0):
        """
        Parameters
        ----------
        x : array_like, float
            length-N array of evenly spaced x coordinates
        y : array_like, float
            M evenly spaced y coordinates, either as a length-M array or an
            (M, 1) column; it is stored as an (M, 1) column so that
            expressions in x and y broadcast to (M, N) grids.
        psi_xy0 : array_like, complex
            (M, N) array of the initial wave function at time t0
        V_xy : array_like, float
            (M, N) array giving the potential at each (x, y)
        u0, v0 : float, optional
            the minimum x / y wave-numbers.  Because of the workings of the
            fast fourier transform, the wave-numbers are defined on
              u0 <= u < u0 + 2*pi/dx,    v0 <= v < v0 + 2*pi/dy
            where dx = x[1]-x[0] and dy = y[1]-y[0].  If you expect nonzero
            momentum outside these ranges, modify the inputs accordingly.
            If not specified, each range is centred on zero.
        hbar : float
            value of Planck's constant (default = 1)
        m : float
            particle mass (default = 1)
        t0 : float
            initial time (default = 0)
        """
        # Validation of array inputs; y becomes an (M, 1) column.
        self.x = np.asarray(x)
        self.y = np.reshape(np.asarray(y), (-1, 1))
        psi_xy0 = np.asarray(psi_xy0)
        self.V_xy = np.asarray(V_xy)
        N = self.x.size
        M = self.y.size
        assert self.x.shape == (N,)
        assert psi_xy0.shape == (M, N)
        assert self.V_xy.shape == (M, N)

        # Set internal parameters
        self.hbar = hbar
        self.m = m
        self.time = t0
        self.dt_1 = None
        self.N = N
        self.M = M
        self.dx = self.x[1] - self.x[0]
        self.dy = self.y[1, 0] - self.y[0, 0]      # scalar spacing
        self.du = 2 * np.pi / (self.N * self.dx)
        self.dv = 2 * np.pi / (self.M * self.dy)
        self.psi_xy = None
        self.psi_mod_xy = None
        self.psi_uv = None
        self.psi_mod_uv = None

        # evolution operators (built by the dt setter) and plotting handles
        self.xy_evolve_half = None
        self.xy_evolve = None
        self.uv_evolve = None
        self.psi_xy_line = None
        self.psi_uv_line = None
        self.V_xy_line = None

        # set momentum (wave-number) scales; default ranges are centred on 0
        self.u0 = -0.5 * self.N * self.du if u0 is None else u0
        self.u = self.u0 + self.du * np.arange(self.N)

        self.v0 = -0.5 * self.M * self.dv if v0 is None else v0
        self.v = np.reshape(self.v0 + self.dv * np.arange(self.M), (-1, 1))

        # define the initial psi_xy and psi_uv
        self.psi_xy = psi_xy0
        self.compute_psi_mod_xy(self.psi_xy)  # compute psi_xy     -> psi_mod_xy
        self.compute_xy_to_uv()               # compute psi_mod_xy -> psi_mod_uv
        self.psi_uv = self.Psi_uv             # use property getter, get psi_uv

    def compute_psi_mod_xy(self, psi_xy):
        """Convert psi(x, y) into the FFT-ready array psi_mod_xy.

        The factor dx*dy/(2*pi) makes fft2 approximate the unitary 2D
        continuous Fourier transform (1/(2*pi) in 2D, not 1/sqrt(2*pi)), and
        exp(-i(u0 x + v0 y)) shifts the wave-number origin so that FFT bin
        [l, k] corresponds to (v[l], u[k]).
        """
        self.psi_mod_xy = (self.dx * self.dy / (2 * np.pi) * psi_xy
                           * np.exp(-1j * self.u[0] * self.x
                                    - 1j * self.v[0] * self.y))

    def compute_psi_mod_uv(self, psi_uv):
        """Inverse of the Psi_uv getter: phase-shift psi(u, v) into psi_mod_uv."""
        col_idx = np.arange(self.N)                   # u index, along columns
        row_idx = np.arange(self.M).reshape(-1, 1)    # v index, along rows
        self.psi_mod_uv = psi_uv * np.exp(1j * self.x[0] * self.du * col_idx
                                          + 1j * self.y[0, 0] * self.dv * row_idx)

    def compute_xy_to_uv(self):
        """psi_mod_xy -> psi_mod_uv via a forward 2D FFT."""
        self.psi_mod_uv = fft2(self.psi_mod_xy)

    def compute_uv_to_xy(self):
        """psi_mod_uv -> psi_mod_xy via an inverse 2D FFT."""
        self.psi_mod_xy = ifft2(self.psi_mod_uv)

    def time_step(self, dt_3, Nsteps = 1):
        """
        Advance the wave function by Nsteps split-step Fourier steps.

        Each step applies half a potential step in (x, y), a full kinetic
        step in (u, v), then another half potential step.

        Parameters
        ----------
        dt_3 : float
            the small time interval over which to integrate
        Nsteps : int, optional
            the number of intervals to compute.  The total change in time
            at the end of this method will be dt_3 * Nsteps (default 1).
        """
        self.dt = dt_3   # property setter (re)builds the evolution operators

        assert Nsteps > 0

        for _ in range(Nsteps):
            self.psi_mod_xy *= self.xy_evolve_half    # <--- half-step in xy
            self.compute_xy_to_uv()                   # <--- FFT2
            self.psi_mod_uv *= self.uv_evolve         # <--- one-step in uv
            self.compute_uv_to_xy()                   # <--- IFFT2
            self.psi_mod_xy *= self.xy_evolve_half    # <--- half-step in xy

        self.compute_xy_to_uv()                       # <--- FFT2, update psi_mod_uv
        self.time += dt_3 * Nsteps

    @property
    def Psi_xy(self):
        """psi(x, y) reconstructed from psi_mod_xy (inverse of compute_psi_mod_xy)."""
        return (self.psi_mod_xy * (2 * np.pi) / (self.dx * self.dy)
                * np.exp(1j * self.u[0] * self.x
                         + 1j * self.v[0] * self.y))

    @property
    def Psi_uv(self):
        """psi(u, v) on the (M, N) wave-number grid, from psi_mod_uv."""
        col_idx = np.arange(self.N)
        row_idx = np.arange(self.M).reshape(-1, 1)
        return (self.psi_mod_uv
                * np.exp(-1j * self.x[0] * self.du * col_idx
                         - 1j * self.y[0, 0] * self.dv * row_idx))

    @property
    def dt(self):
        return self.dt_1

    @dt.setter
    def dt(self, dt):
        # Rebuild the evolution operators only when the step size changes.
        if dt != self.dt_1:
            self.dt_1 = dt
            # exp(-i V dt / (2 hbar)): half a step of the potential term
            self.xy_evolve_half = np.exp(-1/2 * 1j * self.V_xy * dt / self.hbar)
            # exp(-i hbar (u^2 + v^2) dt / (2 m)): one full kinetic step
            self.uv_evolve = np.exp(-1/2 * 1j * self.hbar
                                    * (self.u * self.u + self.v * self.v)
                                    * dt / self.m)

    @property
    def time_elapsed(self):
        return self.time

In [None]:
# Helper functions for gaussian wave-packets
def gauss_x(x, y, a, x0, y0, k0):
    """
    Normalized 2D Gaussian wave packet.

    Returns
        psi(x, y) = exp(-((x-x0)^2 + (y-y0)^2) / (2 a^2) + i k0 x) / (a sqrt(pi))
    sampled on the (M, N) grid spanned by the length-N array ``x`` and the
    M values of ``y`` (reshaped to a column).  The 2D prefactor
    1/(a sqrt(pi)) makes the integral of |psi|^2 over the plane equal to 1.

    Parameters
    ----------
    x : array_like, float
        x coordinates (length N)
    y : array_like, float
        y coordinates (M values; reshaped to an (M, 1) column)
    a : float
        packet width
    x0, y0 : float
        packet centre
    k0 : float
        mean wave-number along x
    """
    y = np.reshape(y, (-1, 1))
    r2 = (x - x0) ** 2 + (y - y0) ** 2
    return np.exp(-r2 / (2 * a ** 2) + 1j * k0 * x) / (a * np.sqrt(np.pi))

def square_barrier(x, y, x_wall, width, height, opening_1, opening_2):
    """
    Double-slit wall potential on the (M, N) grid spanned by ``x`` (length N)
    and ``y`` (M values, reshaped to a column).

    The potential equals ``height`` for x_wall < x < x_wall + width, except
    inside the two slits opening_1 < y < opening_2 and
    -opening_2 < y < -opening_1, and is 0 everywhere else.  All bounds are
    strict inequalities.
    """
    x_row = np.reshape(x, (1, -1))
    y_col = np.reshape(y, (-1, 1))
    in_wall = (x_row > x_wall) & (x_row < x_wall + width)           # (1, N)
    in_slit = (((y_col > opening_1) & (y_col < opening_2))
               | ((y_col > -opening_2) & (y_col < -opening_1)))     # (M, 1)
    # Boolean masks rather than np.where index tuples: for a column vector
    # np.where returns (rows, cols), and using that tuple as a row index
    # also selected row 0, zeroing the wall at y[0].
    return np.where(in_wall & ~in_slit, height, 0.0)

def parabolic_barrier(x, a, b, h):
    """Parabola a*(x - h)**2 + b: vertex at x = h with value b, curvature set by a."""
    offset = x - h
    return b + a * offset ** 2

In [None]:
# Simulation set-up
# specify time steps and duration
dt = 0.01                  # time step of the split-step integrator
N_steps = 25               # integrator steps between animation frames
t_max = 210                # total simulated time
frames = int(t_max / float(N_steps * dt))

# specify constants
hbar = 1     # Planck's constant
m = 3.0      # particle mass

# specify the grid: N points in x, M points in y
N = 2**8
M = 2**7
x = np.linspace(-100, 100, N, endpoint=True)
y = np.linspace(-100, 100, M, endpoint=True)
y = np.reshape(y, (-1, 1))   # column, so x/y expressions broadcast to (M, N)

# specify the double-slit potential
V0 = 1.5           # energy scale used to set the initial momentum below
a = 10             # width of the initial Gaussian packet
opening_1 = 3      # the slits span opening_1 < |y| < opening_2
opening_2 = 10
x_wall = 0         # the wall spans x_wall < x < x_wall + width
width = 2
height = 1e8
V_xy = square_barrier(x, y, x_wall, width, height,
                      opening_1, opening_2)
# High-potential strips at the left/right domain edges (|x| > 98); presumably
# they keep the packet from wrapping around the periodic FFT boundary in x.
edge_height = 1e8
V_xy[:, x < -98] = edge_height
V_xy[:, x > 98] = edge_height

# specify initial momentum and packet position
p0 = np.sqrt(2 * m * 0.2 * V0)   # momentum for kinetic energy 0.2 * V0
k0 = p0 / hbar
x0 = -30
y0 = 0
psi_xy_0 = gauss_x(x, y, a, x0, y0, k0)

# define the Schrodinger object which performs the calculations
S = Schrodinger(x=x,
                y=y,
                psi_xy0=psi_xy_0,
                V_xy=V_xy,
                u0=None,
                v0=None,
                hbar=hbar,
                m=m,
                t0=0.0)

# plt.contourf(x,y.flatten(),V_xy)

In [None]:
%%time
arr_x = S.x;                      arr_y = S.y
arr_u = S.u;                      arr_v = S.v
arr_V_xy = S.V_xy
arr_psi_xy = np.zeros((frames, len(arr_y), len(arr_x)))
arr_psi_uv = arr_psi_xy.copy()
arr_t = np.zeros(frames)
arr_p = arr_t.copy()    # <--- important! create a copy

# initial condition:
i = 0
arr_psi_xy[i, :, :] = 4 * abs(S.Psi_xy)
arr_psi_uv[i, :, :] =     abs(S.Psi_uv)
arr_t[i] = S.time_elapsed
arr_p[i] = (x0 + S.time_elapsed * p0 / m)

for i in range(1,len(arr_t)):
    S.time_step(dt, N_steps)
    arr_psi_xy[i, :, :] = 4 * abs(S.Psi_xy)
    arr_psi_uv[i, :, :] =     abs(S.Psi_uv)
    arr_t[i] = S.time_elapsed
    arr_p[i] = (x0 + S.time_elapsed * p0 / m)

In [None]:
# Set up the figure for the animation
psi_max = arr_psi_xy.max()
psi_min = arr_psi_xy.min()
cmap = plt.get_cmap('PiYG_r')

fig1, ax1 = plt.subplots(1, 1, figsize=(9, 8))
ax1.axis('equal')
ax1.axis([-50, 50, -50, 50])
ax1.set_xlabel('x')
ax1.set_ylabel('y')
plt.close(fig1)   # don't display the static figure; only the animation is shown

# Loop-invariant geometry for drawing the wall (white) around the two slits.
y_grid = np.ravel(arr_y)
y_top = np.max(y)
y_bottom = np.min(y)
wall_x = [x_wall, x_wall + width]


def animate(i):
    """Redraw frame i: |psi(x, y, t_i)| contours with the slit wall in white."""
    # Newer Matplotlib no longer allows assigning to Axes.collections, so
    # remove the previous frame's contour/fill artists one by one.
    for artist in list(ax1.collections):
        artist.remove()
    ax1.contourf(arr_x, y_grid, arr_psi_xy[i], cmap=cmap,
                 vmin=psi_min, vmax=psi_max)
    ax1.fill_between(wall_x, [opening_2, opening_2],
                     [y_top, y_top], facecolor='white')
    ax1.fill_between(wall_x, [opening_1, opening_1],
                     [-opening_1, -opening_1], facecolor='white')
    ax1.fill_between(wall_x, [-opening_2, -opening_2],
                     [y_bottom, y_bottom], facecolor='white')
    ax1.set_title(r"$|\psi (x,y,t)|$, t = %.2f" % arr_t[i])


# blit=False: every frame is fully redrawn (animate returns no artists).
anim = animation.FuncAnimation(fig1, animate,
                               frames=frames, interval=30, blit=False)

In [None]:
%%time
HTML(anim.to_jshtml())