#!pip install ffmpeg-python

### Inspired by https://ben.land/post/2022/03/09/quantum-mechanics-simulation/

In [None]:
import numpy as np
import cupy as cp
import matplotlib.pyplot as plt
from matplotlib.colors import hsv_to_rgb
from IPython.display import HTML
from matplotlib.animation import FuncAnimation

from functools import wraps
import time


def timeit(func):
    """Wrap ``func`` so that every call reports its wall-clock duration.

    The elapsed time is measured with ``time.perf_counter`` and printed once the
    call returns; the wrapped function's result is passed back unchanged.
    """
    @wraps(func)
    def timed(*args, **kwargs):
        began = time.perf_counter()
        value = func(*args, **kwargs)
        elapsed = time.perf_counter() - began
        print(f'Function {func.__name__} Took {elapsed:.4f} seconds')
        return value

    return timed


## Simulation 

In [None]:
# simulation code
class SimulationBase:
    """Time-stepping core shared by the 1-D and 2-D Schrödinger simulations.

    Subclasses supply the spatial discretisation: ``dt_space`` returns the
    measure of one grid cell (used for normalisation) and ``gradsq`` the
    discrete Laplacian of a wave function.

    Args:
        method: integration scheme, ``'euler'`` or ``'rk4'``.
        normalize: if true, ``run`` renormalises the state after every step.
        save_every: ``run`` records the state after every ``save_every``-th
            step; with ``None`` only the initial state is returned.
        dt: time step; may be complex (e.g. ``-1e-1j`` for imaginary time).
        steps: number of steps performed by ``run``.
        mass: particle mass used in the kinetic term of ``d_dt``.
    """

    def __init__(self, method, normalize, save_every, dt, steps, mass=1):
        self.method = method
        self.normalize = normalize
        self.save_every = save_every
        self.dt = dt
        self.steps = steps
        self.mass = mass

    def dt_space(self):
        """Return the measure of one grid cell (implemented by subclasses)."""
        raise NotImplementedError('Method not implemented')

    def gradsq(self, phi):
        """Return the discrete Laplacian of ``phi`` (implemented by subclasses)."""
        raise NotImplementedError('Method not implemented')

    def d_dt(self, phi, h=1, V=0):
        """Return dphi/dt = (i*h / (2*mass)) * gradsq(phi) - (i/h) * V * phi.

        ``h`` plays the role of hbar; ``V`` is a scalar or an array shaped like ``phi``.
        """
        return (1j*h/2/self.mass) * self.gradsq(phi) - (1j/h)*V*phi

    def norm(self, phi, npy=np):
        """Return ``phi`` rescaled so that sum(|phi|**2) * dt_space() == 1.

        ``npy`` is the array module (``numpy`` or ``cupy``) used for every
        operation, including the square root.
        NOTE(review): ``run`` calls this with the default ``npy``; for CuPy
        states this presumably relies on CuPy's NumPy dispatch protocols.
        """
        total = npy.sum(npy.square(npy.abs(phi))) * self.dt_space()
        return phi / npy.sqrt(total)

    def euler(self, phi, dt, **kwargs):
        """Advance ``phi`` by one explicit (forward) Euler step of size ``dt``."""
        return phi + dt * self.d_dt(phi, **kwargs)

    def rk4(self, phi, dt, **kwargs):
        """Advance ``phi`` by one classical fourth-order Runge-Kutta step of size ``dt``."""
        k1 = self.d_dt(phi, **kwargs)
        k2 = self.d_dt(phi+dt/2*k1, **kwargs)
        k3 = self.d_dt(phi+dt/2*k2, **kwargs)
        k4 = self.d_dt(phi+dt*k3, **kwargs)
        return phi + dt/6*(k1+2*k2+2*k3+k4)

    @timeit
    def run(self, phi_sim, V=0, condition=None):
        """Evolve ``phi_sim`` for ``self.steps`` steps and return the saved states.

        Args:
            phi_sim: initial wave function (NumPy or CuPy array).
            V: potential, a scalar or an array shaped like ``phi_sim``.
            condition: optional callable applied to the state after every
                step, before normalisation (e.g. to project out known states).

        Returns:
            A list starting with the initial state, followed by the state after
            every ``save_every``-th step.

        Raises:
            ValueError: if ``self.method`` is not ``'euler'`` or ``'rk4'``
                (checked before any time step is taken).
        """
        steppers = {'euler': self.euler, 'rk4': self.rk4}
        if self.method not in steppers:
            raise ValueError(f'Unknown method {self.method}')
        step = steppers[self.method]

        simulation_steps = [phi_sim]
        for i in range(1, self.steps + 1):
            # evolution of the next time step
            phi_sim = step(phi_sim, self.dt, V=V)

            if condition is not None:
                phi_sim = condition(phi_sim)

            # keep the wave state normalised (essential for imaginary-time runs)
            if self.normalize:
                phi_sim = self.norm(phi_sim)

            # save current wave state
            if self.save_every is not None and i % self.save_every == 0:
                simulation_steps.append(phi_sim)

        return simulation_steps
    

class Simulation1D(SimulationBase):
    """One-dimensional simulation on 5000 evenly spaced NumPy grid points in [-10, 10].

    The particle mass is fixed at 100; the remaining arguments are forwarded
    to ``SimulationBase``.
    """

    def __init__(self, method='rk4', normalize=True, save_every=1000, dt=1e-1, steps=40000):
        super().__init__(method, normalize, save_every, dt, steps, mass=100)
        grid = np.linspace(-10, 10, 5000)
        self.x = grid
        self.deltax = grid[1] - grid[0]

    def dt_space(self):
        """Return the grid spacing dx (normalisation measure and ``gradsq`` divisor)."""
        return self.deltax

    def gradsq(self, phi):
        """Three-point second difference of ``phi`` divided by dx.

        Neighbours beyond either end of the grid are treated as zero.
        NOTE(review): a true second derivative divides by dx**2. Dividing by dx
        once is equivalent to the true Laplacian with an effective mass of
        mass/dx. The notebook's dt and mass appear to be tuned around this, so
        changing it would need retuning.
        """
        second_diff = -2 * phi
        second_diff[:-1] += phi[1:]
        second_diff[1:] += phi[:-1]
        return second_diff / self.dt_space()

    def wave_packet(self, pos=0, mom=0, sigma=0.1):
        """Return a normalised Gaussian packet centred at ``pos``.

        The phase factor is exp(-1j*mom*x) (note the minus sign) and the
        envelope is exp(-(x - pos)**2 / sigma**2).
        """
        carrier = np.exp(-1j * mom * self.x)
        envelope = np.exp(-np.square(self.x - pos) / sigma / sigma, dtype=np.complex128)
        return self.norm(carrier * envelope)


    
class Simulation2D(SimulationBase):
    """Two-dimensional simulation on a 500 x 500 CuPy grid over [-10, 10]^2.

    The mesh uses ``indexing='ij'``, so ``phi[i, j]`` is the value at
    ``(x[i], y[j])``. ``extent`` is a host (NumPy) copy of the grid bounds for
    plotting. The particle mass is fixed at 1000.
    """

    def __init__(self, method='rk4', normalize=True, save_every=100, dt=1e-1, steps=4000):
        super().__init__(method, normalize, save_every, dt, steps, mass=1000)
        x = cp.linspace(-10, 10, 500)
        y = cp.linspace(-10, 10, 500)
        self.extent = cp.asnumpy(cp.asarray([cp.min(x), cp.max(x), cp.min(y), cp.max(y)]))
        deltax = x[1] - x[0]
        deltay = y[1] - y[0]
        # Area of one cell. It serves as the normalisation measure and, because
        # dx == dy on this grid, also as the h**2 divisor of the Laplacian.
        self.deltaxy = deltax * deltay
        self.xv, self.yv = cp.meshgrid(x, y, indexing='ij')

    def dt_space(self):
        """Return the area of one grid cell (dx * dy)."""
        return self.deltaxy

    def wave_packet(self, p_x=0, p_y=0, disp_x=0, disp_y=0, sqsig=0.5):
        """Return a normalised 2-D Gaussian packet.

        It is centred at ``(disp_x, disp_y)`` with phase exp(1j*(p_x*x + p_y*y))
        and envelope exp(-((x-disp_x)**2 + (y-disp_y)**2) / sqsig).
        """
        xv, yv = self.xv, self.yv
        return self.norm(cp.exp(1j*xv*p_x) * cp.exp(1j*yv*p_y)
                         * cp.exp(-cp.square(xv-disp_x)/sqsig, dtype=cp.complex128)
                         * cp.exp(-cp.square(yv-disp_y)/sqsig, dtype=cp.complex128), npy=cp)

    def gradsq(self, phi):
        """Five-point discrete Laplacian of ``phi``.

        Neighbours beyond the grid edges are treated as zero.
        """
        gradphi = -4*phi
        gradphi[:-1, :] += phi[1:, :]
        gradphi[1:, :] += phi[:-1, :]
        gradphi[:, :-1] += phi[:, 1:]
        gradphi[:, 1:] += phi[:, :-1]
        return gradphi / self.dt_space()
    


## Plotting

In [None]:

class DrawParticle1D:
    """Draw a 1-D complex wave function on a matplotlib Axes.

    The magnitude |psi| is drawn as a black curve. The area under it is filled
    with colours whose hue encodes the phase arg(psi).

    Args:
        particle1d: simulation object whose ``x`` attribute holds the grid.
        psi: complex wave function sampled on ``particle1d.x``.
    """

    def __init__(self, particle1d, psi):
        self.particle1d = particle1d
        self.psi = psi

    @staticmethod
    def polygon(x1, y1, x2, y2, c, ax=None):
        """Fill the trapezoid between segment (x1, y1)-(x2, y2) and y=0 with colour ``c``."""
        if ax is None:
            ax = plt.gca()
        polygon = plt.Polygon([(x1, y1), (x2, y2), (x2, 0), (x1, 0)], color=c)
        ax.add_patch(polygon)

    def draw(self, ax=None, **kwargs):
        """Plot |psi| with phase colouring on ``ax`` (the current Axes if ``None``).

        ``**kwargs`` is accepted so the signature matches ``DrawParticle2D.draw``;
        it is ignored here.
        """
        x = self.particle1d.x
        y = self.psi
        mag = np.abs(y)
        # np.angle lies in [-pi, pi]; wrap it to a hue in [0, 1]
        phase = np.angle(y) / (2 * np.pi)
        negative = phase < 0.0
        phase[negative] = 1 + phase[negative]

        # hue = phase, saturation 0.5, value 1
        hsv = np.asarray([phase, np.full_like(phase, 0.5), np.ones_like(phase)]).T
        rgb = hsv_to_rgb(hsv[None, :, :])[0]

        if ax is None:
            ax = plt.gca()
        ax.plot(x, mag, color='k')

        # Only fill segments whose two end points both exceed 1% of the peak
        # magnitude, skipping the negligible tails.
        visible = mag > np.max(mag) * 1e-2
        for n in np.flatnonzero(visible[:-1] & visible[1:]):
            DrawParticle1D.polygon(x[n], mag[n], x[n + 1], mag[n + 1], rgb[n], ax=ax)

        ax.set_xlabel('Position')
        ax.set_xlim(-2, 2)
        ax.set_ylim(0, 2)


class DrawParticle2D:
    """Draw a 2-D complex wave function (CuPy array) as an HSV-coloured image.

    Args:
        particle2d: simulation object providing ``extent`` for ``imshow``.
        psi: complex CuPy array. ``psi[i, j]`` is assumed to sit at
            ``(x[i], y[j])``, the ``indexing='ij'`` layout of ``Simulation2D``.
    """

    def __init__(self, particle2d, psi):
        self.particle2d = particle2d
        self.psi = psi

    @staticmethod
    def to_image(z, z_min=0, z_max=None, abssq=False):
        """Convert a complex CuPy array into an RGB image held in host memory.

        - Hue encodes the phase of ``z``; it is constant 1.0 when ``abssq`` is true.
        - Saturation is fixed at 0.5.
        - Value is ``|z|`` mapped linearly from ``[z_min, z_max]`` to ``[0, 1]``
          (clipped). ``None`` for either bound means min/max of ``|z|``.

        The array is transposed, so image rows follow the second index of ``z``.
        """
        hue = cp.ones(z.shape) if abssq else cp.angle(z) / (2 * cp.pi)
        negative = hue < 0.0
        hue[negative] = 1.0 + hue[negative]
        mag = cp.abs(z)
        if z_max is None:
            z_max = cp.max(mag)
        if z_min is None:
            z_min = cp.min(mag)
        span = z_max - z_min
        if float(span) != 0:
            val = cp.clip((mag - z_min) / span, 0.0, 1.0)
        else:
            # flat magnitude (e.g. an all-zero field): avoid 0/0 -> NaN
            val = cp.zeros_like(mag)
        hsv_im = cp.transpose(cp.asarray([hue, cp.full_like(hue, 0.5), val]))
        return hsv_to_rgb(hsv_im.get())

    def draw(self, z_min=None, z_max=None, abssq=False, ax=None, **kwargs):
        """Show the wave function on ``ax`` (the current Axes if ``None``).

        Returns the ``AxesImage``. Extra ``kwargs`` are forwarded to ``imshow``.
        ``origin`` defaults to ``'lower'``: image row 0 holds the smallest y
        value, so this keeps the picture consistent with ``extent``.
        """
        if ax is None:
            ax = plt.gca()
        kwargs.setdefault('origin', 'lower')
        ax.set_xlim(-2, 2)
        ax.set_ylim(-2, 2)
        return ax.imshow(DrawParticle2D.to_image(self.psi, z_min, z_max, abssq),
                         extent=self.particle2d.extent,
                         interpolation='bilinear',
                         **kwargs)

    
# drawing code
class ParticleRender:
    """Turn a list of saved simulation states into a matplotlib animation."""

    @staticmethod
    def animate(particle, simulation_steps, init_func=None):
        """Build a ``FuncAnimation`` with one frame per entry of ``simulation_steps``.

        Args:
            particle: the simulation object that produced the states. A
                ``Simulation2D`` (or subclass) selects ``DrawParticle2D``;
                anything else selects ``DrawParticle1D``, which reads ``.x``.
            simulation_steps: sequence of wave functions, e.g. from ``run``.
            init_func: optional callable ``init_func(ax)`` applied after every
                frame is drawn (overlays, axis limits).
        """
        fig, ax = plt.subplots()
        draw_obj_cls = DrawParticle2D if isinstance(particle, Simulation2D) else DrawParticle1D

        def draw_frame(frame):
            draw_obj_cls(particle, simulation_steps[frame]).draw(ax=ax)
            if init_func is not None:
                init_func(ax)

        # draw first frame
        draw_frame(0)

        def next_frame(frame):
            ax.clear()
            draw_frame(frame)

        anim = FuncAnimation(fig, next_frame, frames=len(simulation_steps), interval=100)
        # close only this figure so the notebook does not also show a static copy
        plt.close(fig)
        return anim

    @staticmethod
    def get_sim_video(particle, results, init_func=None):
        """Render ``results`` as an HTML5 video string (see ``animate`` for arguments)."""
        anim = ParticleRender.animate(particle, results, init_func=init_func)
        return anim.to_html5_video()
    




## A free, stationary particle

### 1D

In [None]:
simple_particle = Simulation1D()

# Free Gaussian packet with zero mean momentum (V = 0). The run must actually
# execute: `results` is consumed by the video below.
results = simple_particle.run(simple_particle.wave_packet(), V=0)
HTML(ParticleRender.get_sim_video(simple_particle, results))


### 2D

In [None]:
# Free 2-D packet with momentum p_x=10, evolved with no potential.
simple_particle_2d = Simulation2D()
initial_packet_2d = simple_particle_2d.wave_packet(p_x=10, sqsig=0.5)
results = simple_particle_2d.run(initial_packet_2d, V=0)
HTML(ParticleRender.get_sim_video(simple_particle_2d, results))


## A particle in a box

### 1D

In [None]:
simple_particle = Simulation1D()
# V = 0 inside -2 < x < 2 and V = 1 elsewhere: a finite square well acting as a box.
inside_box = (simple_particle.x > -2) & (simple_particle.x < 2)
box_potential = np.where(inside_box, 0, 1)
results = simple_particle.run(simple_particle.wave_packet(mom=40, sigma=0.2), V=box_potential)


def box_init(ax):
    """Shade the box walls and zoom the view onto the box."""
    for left, right in ((2, 3), (-3, -2)):
        ax.axvspan(left, right, alpha=0.2, color='red')
    ax.set_xlim(-3, 3)
    ax.set_ylim(0, 2)


HTML(ParticleRender.get_sim_video(simple_particle, results, init_func=box_init))


### 2D

In [None]:
simple_particle_2d = Simulation2D(steps=8000)
# Square box |x| < 2, |y| < 2 with V = 0 inside and V = 1 outside. Built with
# cp.where so the potential is a CuPy array like the grid, without relying on
# NumPy forwarding np.where to CuPy.
box_potential = cp.where((simple_particle_2d.xv > -2) &
                         (simple_particle_2d.xv < 2) &
                         (simple_particle_2d.yv > -2) &
                         (simple_particle_2d.yv < 2),
                         0, 1)

results = simple_particle_2d.run(simple_particle_2d.wave_packet(p_x=20, p_y=0, sqsig=0.3), V=box_potential)
HTML(ParticleRender.get_sim_video(simple_particle_2d, results))



## A particle encounters a barrier

### 1D

In [None]:
simple_particle = Simulation1D()
# Thin, weak barrier of height 3.5e-2 occupying 2.4 < x < 2.6.
in_barrier = (simple_particle.x > 2.4) & (simple_particle.x < 2.6)
barrier_weak_potential = np.where(in_barrier, 3.5e-2, 0)
# wave_packet uses an exp(-1j*mom*x) phase, so a negative mom gives a phase
# increasing with x, i.e. a packet heading toward +x (the barrier).
wave_packet = simple_particle.wave_packet(mom=-40, sigma=0.2)
results = simple_particle.run(wave_packet, V=barrier_weak_potential)


def barrier_init(ax):
    """Shade the barrier and set the viewing window."""
    ax.axvspan(2.4, 2.6, alpha=0.2, color='orange')
    ax.set_xlim(-2, 5)
    ax.set_ylim(-1, 3)


HTML(ParticleRender.get_sim_video(simple_particle, results, init_func=barrier_init))


### 2D

In [None]:
# Create this cell's simulation first, so the barrier is built from its own
# grid rather than from an object left over from an earlier cell.
simple_particle_2d = Simulation2D()
barrier_weak_potential = cp.where((simple_particle_2d.xv > 2.4) & (simple_particle_2d.xv < 2.6), 1e-1, 0)


def barrier_init(ax):
    """Shade the barrier strip and set the horizontal viewing window."""
    ax.axvspan(2.4, 2.6, alpha=0.3, color='orange')
    ax.set_xlim(-2, 5)


results = simple_particle_2d.run(simple_particle_2d.wave_packet(p_x=15, p_y=0, sqsig=1), V=barrier_weak_potential)
HTML(ParticleRender.get_sim_video(simple_particle_2d, results, init_func=barrier_init))


## A particle in a quadratic potential

In [None]:
simple_particle = Simulation1D(steps=100000)
# Harmonic potential V(x) = 0.01 * x**2 (reused by the eigenstate cells below).
quadratic_potential = 1e-2 * np.square(simple_particle.x)
wave_packet = simple_particle.wave_packet(mom=-40, sigma=0.2)
results = simple_particle.run(wave_packet, V=quadratic_potential)


def quadratic_init(ax):
    """Shade a parabola (x**2 - 3, filled down to -3) as a sketch of the well and set the view."""
    x = simple_particle.x
    ax.fill_between(x, np.square(x) - 3, -3, color='orange', alpha=0.2)
    ax.set_xlim(-3, 3)
    ax.set_ylim(-0.5, 3)


HTML(ParticleRender.get_sim_video(simple_particle, results, init_func=quadratic_init))


## An aside on Eigenstates

#### This simulation framework is fun for generating visualizations, but it can also be used to do real science: finding the ground and excited states of a system. This is done with a technique called imaginary time evolution. Essentially, replacing $dt$ with $-i\,dt$ in the simulation and propagating in “imaginary time” damps out every eigenstate except the lowest-energy one.
#### To see why, expand the wave function in energy eigenstates, $\psi = \sum_n c_n \phi_n$. Evolving for an imaginary time $\tau$ multiplies each coefficient by $e^{-E_n \tau/\hbar}$. Critically, this factor goes to zero faster for higher-energy states, so the lowest-energy state is the last to disappear. If we therefore require the wave function to remain normalized, which the `run` method already does (with `normalize=True`, the default), evolving in imaginary time leaves only the lowest-energy eigenstate.

In [None]:
# Imaginary-time run: dt = -0.1j plus per-step renormalisation drives the
# packet towards the ground state of the quadratic potential.
simple_particle = Simulation1D(dt=-1e-1j, steps=50000)
wave_packet = simple_particle.wave_packet(mom=-40, sigma=0.2)
sim_quad_0 = simple_particle.run(wave_packet, V=quadratic_potential)
# Pass the simulation object, not its grid: the renderer reads `.x` from it.
HTML(ParticleRender.get_sim_video(simple_particle, sim_quad_0, init_func=quadratic_init))


#### To generate an excited state, in principle, one could:
- take any wave packet
- remove the ground state from it (i.e. set its coefficient to zero)
- perform the same procedure on the resulting wave function to find the first excited eigenstate


In [None]:
# Display the last saved state of the imaginary-time run (used below as the ground state phi_0).
sim_quad_0[-1]

In [None]:
psi = simple_particle.wave_packet(mom=40)
phi_0 = sim_quad_0[-1]
# Remove the ground-state component: Phi_1 = psi - <phi_0|psi> phi_0, with
# the inner product discretised as sum(conj(phi_0) * psi) * dx.
Phi_1 = psi - np.sum(np.conjugate(phi_0)*psi)*simple_particle.deltax*phi_0
# Static plot of the result (ParticleRender has no complex_plot helper).
DrawParticle1D(simple_particle, Phi_1).draw()

#### Numerical error in integrating the Schrödinger equation will invariably put some infinitesimal probability back into the ground state, causing imaginary time evolution to collapse to it once again. The canonical (quick and dirty) solution is simply to remove the ground state from the wave function after each time step, so that its coefficient stays approximately zero, and then to normalize the wave function again.

In [None]:
def orthogonal_to(deltax, states):
    """Build a step condition that projects each of ``states`` out of a wave function.

    The returned callable subtracts ``<state|phi> * state`` for every entry of
    ``states``, one after the other, using ``sum(conj(state) * phi) * deltax``
    as the inner product. The states are assumed normalised under that inner
    product (as ``SimulationBase.norm`` produces). ``states`` is read on every
    call, not copied.
    """
    def orthogonalize(phi):
        for basis_state in states:
            overlap = np.sum(np.conjugate(basis_state) * phi) * deltax
            phi = phi - overlap * basis_state
        return phi

    return orthogonalize

simple_particle = Simulation1D(dt=-1e-1j, steps=50000)
# Imaginary-time evolution with phi_0 projected out after every step; this is
# expected to converge to the first excited state.
sim_quad_1 = simple_particle.run(Phi_1, V=quadratic_potential,
                                 condition=orthogonal_to(simple_particle.deltax, [phi_0]))

# Pass the simulation object, not its grid: the renderer reads `.x` from it.
HTML(ParticleRender.get_sim_video(simple_particle, sim_quad_1, init_func=quadratic_init))


In [None]:
simple_particle = Simulation1D(dt=-1e-1j, steps=50000)
phi_1 = sim_quad_1[-1]
# Remove phi_1 from the starting packet; phi_0 and phi_1 are also projected
# out after every step by the condition below.
Phi_2 = psi - np.sum(np.conjugate(phi_1)*psi)*simple_particle.deltax*phi_1

sim_quad_2 = simple_particle.run(Phi_2, V=quadratic_potential,
                                 condition=orthogonal_to(simple_particle.deltax, [phi_0, phi_1]))

# Pass the simulation object, not its grid: the renderer reads `.x` from it.
HTML(ParticleRender.get_sim_video(simple_particle, sim_quad_2, init_func=quadratic_init))


In [None]:
simple_particle = Simulation1D(dt=-1e-1j, steps=50000)
phi_2 = sim_quad_2[-1]
# Remove phi_2 from the starting packet; phi_0..phi_2 are also projected out
# after every step by the condition below.
Phi_3 = psi - np.sum(np.conjugate(phi_2)*psi)*simple_particle.deltax*phi_2

sim_quad_3 = simple_particle.run(Phi_3, V=quadratic_potential,
                                 condition=orthogonal_to(simple_particle.deltax, [phi_0, phi_1, phi_2]))

# Pass the simulation object, not its grid: the renderer reads `.x` from it.
HTML(ParticleRender.get_sim_video(simple_particle, sim_quad_3, init_func=quadratic_init))


#### Now, you might ask yourself: how can we verify that these states are in fact eigenstates without having to look up and plot the analytic solutions of the quantum harmonic oscillator (QHO)?
#### The answer is simple: evolve them in real time. For an eigenstate only the phase should change, while the magnitude stays fixed!

In [None]:
phi_3 = sim_quad_3[-1]

# Real-time evolution (dt = 0.1): for an eigenstate only the phase should rotate.
simple_particle = Simulation1D(dt=1e-1, steps=10000)
result = simple_particle.run(phi_3, V=quadratic_potential)

# Pass the simulation object, not its grid: the renderer reads `.x` from it.
HTML(ParticleRender.get_sim_video(simple_particle, result, init_func=quadratic_init))
