<a href="https://colab.research.google.com/github/OleKrarup123/NLSE-vector-solver/blob/main/Quantum_SSFM.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## About this notebook

This notebook attempts to solve the time-dependent Schrödinger equation with the split-step Fourier method (SSFM) in order to model the temporal evolution of a 1D quantum-mechanical system.

A similar approach was used elsewhere to model the propagation of pulses subject to attenuation, dispersion and nonlinear effects in an optical fiber:

https://github.com/OleKrarup123/NLSE-vector-solver

https://colab.research.google.com/drive/1XyLYqrohf5GL6iFSrS6VlHoj_eSm-kAG#scrollTo=K42UGCF-Wrt4

Additional information about the SSFM can be found [here](https://www.researchgate.net/publication/281441538_An_introduction_to_the_Split_Step_Fourier_Method_using_MATLAB).


First, we define the Schrodinger Equation, where $ψ(x,t)$ is abbreviated with $ψ$:

$$i\hbar \partial_tψ =  \frac{-\hbar^2}{2m}  \partial^2_x\psi +V(x,t)ψ$$
$$ $$ 
$$ \partial_tψ =  \frac{i\hbar}{2m}  \partial^2_x\psi - \frac{i}{\hbar}V(x,t)ψ$$

Set $V(x,t)=\hbar\omega_0 U(x,t)$, where $\hbar\omega_0$ is a characteristic energy of the potential, so that $U$ is dimensionless:
$$ $$ 
$$ \partial_tψ =  i\frac{\hbar}{2m}  \partial^2_x\psi - i\omega_0 U(x,t)ψ.$$
Now define $\tau = t\omega_0$ (causing $\tau$ to be in units of "radians"), so 
$$ \omega_0 \partial_\tauψ =  i\frac{\hbar}{2m}  \partial^2_x\psi - i\omega_0 U(x,t)ψ. $$
$$ $$
$$  \partial_\tauψ =  i\frac{\hbar}{2m\omega_0}  \partial^2_x\psi - i U(x,t)ψ. $$
$$ $$

Define $s=\frac{x}{\sqrt{\frac{\hbar}{2m\omega_0}}}$, so
$$  \partial_\tauψ =  i \partial^2_s\psi - i U(s,\tau)ψ. \quad (1) $$ 
$$ $$
It is equation (1) that we wish to solve using the split-step method.


## Import useful libraries

In [None]:
import numpy as np
from scipy.fftpack import fft, ifft, fftshift, ifftshift, fftfreq

import matplotlib.pyplot as plt
from matplotlib import cm

pi = np.pi  # shorthand for np.pi used throughout the notebook (module level, so no `global` needed)

## Define constants and simulation parameters

In [None]:
# Physical constants in SI units (exact values from the 2019 SI redefinition where applicable).
# NOTE: previously the values of h and hbar were swapped (6.62607015e-34 is h, not hbar),
# which made every energy/momentum scale derived from them wrong by a factor 2*pi.
h = 6.62607015e-34                               # Planck's constant in J*s
hbar = h / (2 * np.pi)                           # Planck's reduced constant in J*s/rad
mH = 1.6735e-27                                  # Hydrogen atom mass in kg (also used as the "amu" unit in printouts)
c = 299792458.0                                  # Speed of light in m/s (exact)
joule_to_ev = 1 / 1.602176634e-19                # Conversion from J to eV (1/elementary charge)
kelvin_to_joule = 1.380649e-23                   # Boltzmann constant: kelvin to joule (exact)
kelvin_to_ev = kelvin_to_joule * joule_to_ev     # Boltzmann constant: kelvin to eV


In [None]:
N = 1 << 15  # grid size: number of sample points along the x-axis (2**15)
print(N)




In [None]:
def getMomentumFromPos(pos):
    """Return the momentum grid conjugate to the position grid ``pos``.

    The FFT sample frequencies of ``pos`` are put in fftshift order and scaled
    by the module-level constant ``h``. Only the first spacing
    ``pos[1] - pos[0]`` is used, so ``pos`` should be uniformly spaced.
    """
    spacing = pos[1] - pos[0]
    spatial_freqs = fftfreq(len(pos), d=spacing)
    return h * fftshift(spatial_freqs)

def getPosFromMomentum(momentum):
    """Return the position grid conjugate to the momentum grid ``momentum``.

    Companion of getMomentumFromPos: the FFT sample frequencies of the momentum
    samples are put in fftshift order and scaled by the module-level constant
    ``h``. Only ``momentum[1] - momentum[0]`` is used as the spacing, so the
    grid should be uniform.
    """
    step = momentum[1] - momentum[0]
    return h * fftshift(fftfreq(len(momentum), d=step))

In [None]:
#Define a harmonic (x**2) potential from particle mass and width of classical turning range
def harmonicPotential(x,m,width):
    """Return the harmonic potential V(x) = m*omega**2*x**2/2 and its characteristic scales.

    omega is chosen so that a classical particle with energy E_char = hbar*omega
    turns around at x = +-width/2, i.e. ``width`` is the classical turning range
    at that energy (0.5*m*omega**2*(width/2)**2 == hbar*omega).

    Args:
      x: position grid in m.
      m: particle mass in kg.
      width: width of the classical turning range in m.

    Returns:
      (V, x_char, p_char, E_char):
        V      -- potential energy in J evaluated on x,
        x_char -- sqrt(hbar/(2*m*omega)), which equals width/4,
        p_char -- sqrt(m*omega*hbar/2) (= hbar/(2*x_char)),
        E_char -- hbar*omega.
    """
    # Characteristic angular frequency from particle mass and classical turning range.
    omega = 2*hbar/(m*(width/2)**2)

    # Characteristic distance, momentum and energy of the harmonic potential.
    x_char = np.sqrt(hbar/(2*m*omega))  # Note: x_char = width/4
    p_char = np.sqrt(m*omega*hbar/2)
    E_char = hbar*omega

    # 0.5*m*omega**2 has units of J/m**2, so it must multiply x**2 itself.
    # (Multiplying by the dimensionless (x/x_char)**2 gave a value with wrong units.)
    V = 0.5*m*omega**2*x**2
    return V, x_char, p_char, E_char

#Define an infinite well potential from particle mass and width
def infiniteWellPotential(x,m,width):
    """Return an effectively infinite square well and its characteristic scales.

    The well is centred on x = 0: V is 0.0 where |x| < width/2 and 1e100
    (standing in for infinity) everywhere else.

    Args:
      x: position grid.
      m: particle mass.
      width: width of the well.

    Returns:
      (V, x_char, p_char, E_char) with
        x_char = width,
        p_char = hbar*pi/width,
        E_char = h**2/(8*m*width**2)  (the ground-state energy h^2/(8 m L^2)
                                      when ``h`` holds Planck's constant).
    """
    # Start with a "wall" everywhere, then open up the interior of the well.
    V = np.ones_like(x)*1e100
    interior = np.abs(x) - width/2 < 0
    V[interior] = 0.0

    x_char = width
    p_char = hbar*pi/width
    E_char = h**2/8/m/width**2
    return V, x_char, p_char, E_char

#Class for storing info about potentials. 
class potentials:
    """Simulation grids plus harmonic and infinite-well potentials for one particle.

    Builds a position grid of N points spanning [-5*width/2, 5*width/2], its
    conjugate momentum grid (via getMomentumFromPos), and evaluates
    harmonicPotential and infiniteWellPotential on the position grid. The
    characteristic scales those helpers return are stored as attributes together
    with the derived angular frequencies (E/hbar) and frequencies (omega/2pi).
    A summary is printed on construction.

    Args:
      N: number of grid points.
      m: particle mass in kg.
      width: trap width in m.
    """

    def __init__(self, N, m, width):
        self.N = N
        self.m = m
        self.width = width

        # Position grid: 5 trap widths wide, centred on x = 0.
        self.xrange = np.linspace(-5*width/2, 5*width/2, N)
        self.dx = self.xrange[1] - self.xrange[0]
        self.xmax = np.max(self.xrange)
        self.xmin = np.min(self.xrange)

        # Momentum grid conjugate to the position grid.
        self.prange = getMomentumFromPos(self.xrange)
        self.dp = self.prange[1] - self.prange[0]
        self.pmax = np.max(self.prange)
        self.pmin = np.min(self.prange)

        # Harmonic potential and its characteristic length/momentum/energy scales.
        self.harmpot, self.x_harm, self.p_harm, self.E_harm = harmonicPotential(self.xrange, m, width)
        self.omega_harm = self.E_harm/hbar
        self.f_harm = self.omega_harm/2/pi

        # Infinite square well and its characteristic scales.
        self.IWpot, self.x_IW, self.p_IW, self.E_IW = infiniteWellPotential(self.xrange, m, width)
        self.omega_IW = self.E_IW/hbar
        self.f_IW = self.omega_IW/2/pi

        self.self_describe()

    def self_describe(self):
        """Print the grid and potential parameters of this instance.

        Momenta and masses are printed in multiples of the module-level ``mH``
        (labelled "amu"), lengths in um/nm and energies in ueV.
        """
        print("### Configuration Parameters ###")
        print(f" Number of points = {self.N}")
        print(f" Start pos, xmin = {self.xmin*1e6}um")
        print(f" Stop  pos, xmax = {self.xmax*1e6}um")
        print(f" Pos resolution, dx = {self.dx*1e9}nm")
        print("  ")
        print(f" Start momentum= {self.pmin/mH}amu m/s")
        print(f" Stop momentum = {self.pmax/mH}amu m/s")
        print(f" Momentum resolution, dp = {self.dp/mH}amu m/s")
        print("   ")
        # Use this instance's width and mass; the module-level globals `width`
        # and `m` need not match the values this instance was built with.
        print(f" Specified width = {self.width*1e6}um and mass = {self.m/mH}amu")
        print("   ")
        print(" Harmonic potential parameters:")
        print(f" x_harm = {self.x_harm*1e6}um")
        print(f" p_harm = {self.p_harm/mH} amu m/s")
        print(f" E_harm = {self.E_harm*joule_to_ev*1e6}ueV")
        print(f" f_harm = {self.f_harm}Hz")
        print("   ")
        print(" Infinite square well parameters:")
        print(f" x_IW = {self.x_IW*1e6}um")
        print(f" p_IW = {self.p_IW/mH} amu m/s")
        print(f" E_IW = {self.E_IW*joule_to_ev*1e6}ueV")
        print(f" f_IW = {self.f_IW}Hz")
        print("   ")




# Initialize the potentials for a particle with A nucleons (here 87Rb), its mass
# approximated as A hydrogen masses.
A = 87         # number of nucleons in the atom to be simulated
m = A * mH     # mass of the atom to be simulated, in kg
width = 15e-6  # width of the trapping potentials, in m

pots = potentials(N, m, width)

In [None]:
# Plot both trapping potentials in micro-electronvolts; the y-range is capped at
# the largest harmonic-potential value so the 1e100 "infinite" walls don't swamp it.
plt.figure()
for potential in (pots.harmpot, pots.IWpot):
    plt.plot(pots.xrange, potential*joule_to_ev*1e6)
plt.axis([np.min(pots.xrange), np.max(pots.xrange), 0, np.max(pots.harmpot)*joule_to_ev*1e6])
plt.show()

# Compare the position grid with its position -> momentum -> position round trip.
plt.figure()
plt.plot(pots.xrange)
plt.plot(getPosFromMomentum(getMomentumFromPos(pots.xrange)))
plt.show()

# Compare the momentum grid (in units of mH * m/s) with its momentum -> position -> momentum round trip.
plt.figure()
plt.plot(pots.prange/mH)
plt.plot(getMomentumFromPos(getPosFromMomentum(pots.prange))/mH)
plt.show()





In [None]:


def getProbDens(amplitude):
    """Return the probability density |amplitude|**2 (elementwise for arrays)."""
    return np.abs(amplitude)**2


def getProb(pos_or_freq, amplitude, tol=1e-7):
    """Integrate the probability density of ``amplitude`` over the grid ``pos_or_freq``.

    The integral uses the trapezoidal rule and must equal 1 within ``tol``,
    i.e. the wavefunction has to be normalized on this grid.

    Args:
      pos_or_freq: sample grid (position or momentum) matching ``amplitude``.
      amplitude: wavefunction samples.
      tol: allowed absolute deviation of the integrated probability from 1.

    Returns:
      The integrated probability.

    Raises:
      ValueError: if the integrated probability is not within ``tol`` of 1
        (including NaN).
    """
    # np.trapz is deprecated since NumPy 2.0 in favour of np.trapezoid.
    integrate = getattr(np, "trapezoid", None) or np.trapz
    prob = integrate(getProbDens(amplitude), pos_or_freq)

    # Explicit check instead of `assert`, so validation still runs under `python -O`.
    if not np.abs(prob-1) < tol:
        raise ValueError(f"ERROR: Integrated probability is {prob}, which is not equal to 1!!!")
    return prob

In [None]:
def QFFT(x,psi_x):
    """Transform a position-space wavefunction to momentum space.

    Args:
      x: uniformly spaced position grid. The matching momentum grid is
         getMomentumFromPos(x), i.e. p = h*f with f the FFT frequencies.
      psi_x: wavefunction sampled on x; must be normalized on x (getProb checks this).

    Returns:
      psi_p sampled on getMomentumFromPos(x) (fftshift order), scaled so that the
      integral of |psi_p|**2 over p is 1.

    Raises:
      ValueError: if the total probability changes by more than 1e-7 between
        position and momentum space. getProb itself also rejects unnormalized input.
    """
    posProb = getProb(x, psi_x)
    p = getMomentumFromPos(x)
    dx = x[1] - x[0]

    # Discrete version of psi(p) = h**-0.5 * integral psi(x) exp(-2j*pi*p*x/h) dx.
    # Since p = h*f, dp = h/(N*dx); by Parseval sum|fft|**2 = N*sum|psi_x|**2, so the
    # dx/sqrt(h) factor makes sum|psi_p|**2*dp equal sum|psi_x|**2*dx. With dx alone the
    # momentum-space probability came out as h instead of 1.
    # fft treats x[0] as the origin, which only adds a p-dependent phase; |psi_p| is unaffected.
    psi_p = fftshift(fft(psi_x))*dx/np.sqrt(h)
    momentumProb = getProb(p, psi_p)

    err = np.abs(momentumProb/posProb - 1)
    # Explicit check instead of `assert`, so it still runs under `python -O`.
    if not err < 1e-7:
        raise ValueError(f'ERROR = {err}: Total prob. changed when going from pos to momentum!!!')
    return psi_p

In [None]:
def IW_states(x,w,n):
    """Return the n-th energy eigenstate of an infinite square well of width ``w``.

    The well is centred on x = 0 and occupies |x| < w/2. With k_n = n*pi/w the
    state is sqrt(2/w)*cos(k_n*x) for odd n and sqrt(2/w)*sin(k_n*x) for even n
    inside the well, and exactly 0 outside.

    Args:
      x: position grid (array).
      w: well width, in the same units as x.
      n: quantum number, n >= 1.

    Returns:
      Complex array of the same shape as x.

    Raises:
      ValueError: if n < 1.
    """
    if n < 1:
        raise ValueError(f"Quantum number n must be >= 1, got {n}")

    kn = n*np.pi/w
    # The support is set by the width passed in (w), not by any module-level global.
    inside = np.abs(x) < w/2
    profile = np.cos if n % 2 else np.sin
    return np.sqrt(2/w)*profile(kn*x)*(1+0j)*inside


In [None]:
# First three infinite-well eigenstates on the simulation grid.
wf_test1=IW_states(pots.xrange,pots.width,1)
wf_test2=IW_states(pots.xrange,pots.width,2)
wf_test3=IW_states(pots.xrange,pots.width,3)
test_states = [wf_test1, wf_test2, wf_test3]

# Position space: well outline (black) plus the probability density of each
# state, labelled with its integrated probability.
plt.figure()
plt.plot(pots.xrange, pots.IWpot*joule_to_ev*1e6, 'k-')
for wf in test_states:
    plt.plot(pots.xrange, getProbDens(wf), label=f"Prob = {getProb(pots.xrange, wf)}")
plt.legend(bbox_to_anchor=(1.05, 0.3))
plt.axis([np.min(pots.xrange)/4, np.max(pots.xrange)/4, 0, np.max(getProbDens(wf_test1))])
plt.show()

# Momentum space. QFFT needs the position grid as well as the wavefunction
# (the old single-argument calls raised TypeError). Each transform is computed
# once and reused for both the curve and its label.
momentum_states = [QFFT(pots.xrange, wf) for wf in test_states]
momentum_densities = [getProbDens(wf_p) for wf_p in momentum_states]
plt.figure()
for wf_p, dens_p in zip(momentum_states, momentum_densities):
    plt.plot(pots.prange, dens_p, label=f"Prob = {getProb(pots.prange, wf_p)}")
plt.legend(bbox_to_anchor=(1.05, 0.3))
# The y-limit must come from the momentum densities; the position-space density
# has different units and magnitude.
plt.axis([np.min(pots.prange)/4, np.max(pots.prange)/4, 0, max(np.max(d) for d in momentum_densities)])
plt.show()
