# Finding the ground state via imaginary time evolution

## Physics

### Quantum harmonic oscillator

Consider the quantum harmonic oscillator (QHO), the quantum-mechanical analog
of the classical harmonic oscillator. Its Hamiltonian is

$$
H = \hat T + \hat V = \frac{\hat p^2}{2m} + \frac{1}{2}m \omega^2\hat x^2
$$

where $\hat p$ is the momentum operator, $m$ is the mass, $\omega$ is the
angular frequency, and $\hat x$ is the position operator.

Our goal is to find the ground state of a one-dimensional QHO for a given
angular frequency $\omega_x$. Starting from an initial state, imaginary time
evolution can be applied iteratively to drive that state toward the ground
state of our system.

### The imaginary time evolution
The imaginary time evolution is equivalent to applying the split-step operator

$$
e^{-\tfrac{i}{\hbar}H \mathrm{dt}} \approx e^{-\tfrac{i}{\hbar}\frac{\hat T}{2} \mathrm{dt}} e^{-\tfrac{i}{\hbar}\hat V \mathrm{dt}} e^{-\tfrac{i}{\hbar}\frac{\hat T}{2} \mathrm{dt}}
$$

with the time step replaced by an imaginary one, $\mathrm{dt} \to -i\,\mathrm{dt}$.
With this substitution the time evolution operator becomes
$e^{-\frac{1}{\hbar}H \mathrm{dt}}$, so the oscillating phase factors turn into
real-valued decay factors. Expanding the initial state in eigenstates of the
system, $\Psi = \sum_n a_n\Psi_n$ with $H \Psi_n = E_n \Psi_n$, shows that each
coefficient $a_n$ is scaled by $e^{-\frac{1}{\hbar} E_n \mathrm{dt}}$ in every
step. States with higher energy are therefore suppressed faster than the ground
state, and iterating the split-step operator leaves an approximation of the
ground state, provided the initial state has a nonzero overlap with it. Note
that the wavefunction has to be normalized after each application because the
split-step operator is no longer unitary.
Spoiler alert: this is covered inside the `split_step` function if the argument
`is_complex` is set to `True`.
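
The procedure described above can be sketched in plain NumPy, independent of the `split_step` function used later. This is a minimal illustration, not the library's implementation: it assumes dimensionless units ($\hbar = m = \omega = 1$, so the exact ground state energy is $0.5$), and the grid size, the step `dtau`, and the helper names `normalize`, `imag_time_step` and `energy` are choices made only for this example.

```python
import numpy as np

# Assumption for this sketch: dimensionless units hbar = m = omega = 1,
# so the exact ground state energy is 0.5.
n = 512
x = np.linspace(-20.0, 20.0, n, endpoint=False)
dx = x[1] - x[0]
k = 2.0 * np.pi * np.fft.fftfreq(n, d=dx)  # angular wavenumbers of the grid

V = 0.5 * x**2  # potential energy on the position grid
T = 0.5 * k**2  # kinetic energy on the wavenumber grid


def normalize(psi):
    """Rescale psi so that sum(|psi|^2) * dx equals 1."""
    return psi / np.sqrt(np.sum(np.abs(psi) ** 2) * dx)


def imag_time_step(psi, dtau):
    """One symmetric split step with dt -> -i*dtau, followed by normalization."""
    psi = np.fft.ifft(np.exp(-0.5 * T * dtau) * np.fft.fft(psi))  # half kinetic
    psi = np.exp(-V * dtau) * psi                                 # full potential
    psi = np.fft.ifft(np.exp(-0.5 * T * dtau) * np.fft.fft(psi))  # half kinetic
    return normalize(psi)


def energy(psi):
    """Total energy <T> + <V> of a normalized wavefunction."""
    psi_k = np.fft.fft(psi)
    e_kin = np.sum(T * np.abs(psi_k) ** 2) / np.sum(np.abs(psi_k) ** 2)
    e_pot = np.sum(V * np.abs(psi) ** 2) * dx
    return float(e_kin + e_pot)


# Start from the ground state of a ten times weaker trap (omega_init = 0.1)
psi = normalize(np.exp(-0.05 * x**2).astype(complex))
e_start = energy(psi)
for _ in range(2000):
    psi = imag_time_step(psi, dtau=0.01)
e_final = energy(psi)
```

After the loop, `e_final` should lie close to $0.5$, with a small residual offset caused by the finite step size of the splitting.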

## Code

### Initialize global variables

In [None]:
# plotting in jupyter notebook
import panel as pn
pn.extension()
from bokeh.io import output_notebook
output_notebook()

from scipy.constants import pi
from matterwave.rb87 import m as m_rb87
# here: mass of rb87
mass = m_rb87 # kg
# the angular frequency of the QHO whose ground state and energy we want to
# find:
omega_x = 2.*pi # rad/s
# omega_x_init is used to generate the initial state that will be evolved to the
# desired ground state (the generated state is the ground state of the QHO with
# angular frequency omega_x_init):
omega_x_init = 2.*pi*0.1 # rad/s
# time step for split-step:
dt = 1e-4 # s
# choose dimension x [m]

### Initialize the grid
The `FFTDimension` is created using `fft_dim_from_constraints`, which solves
the constraints that link the position and frequency space grids of a
discrete FFT.
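
To illustrate the kind of constraint involved, the following NumPy sketch builds a position grid and the matching frequency grid by hand. It is not how `fft_dim_from_constraints` works internally; in particular, it assumes that the `n` samples include both `pos_min` and `pos_max`, which may differ from the library's convention.

```python
import numpy as np

n = 2048
pos_min, pos_max = -200e-6, 200e-6  # m

# Assumption of this sketch: the n samples cover [pos_min, pos_max] inclusively.
d_pos = (pos_max - pos_min) / (n - 1)
pos = pos_min + d_pos * np.arange(n)

# Frequencies resolved by a length-n DFT with sample spacing d_pos,
# reordered so that they run from negative to positive values.
freq = np.fft.fftshift(np.fft.fftfreq(n, d=d_pos))
d_freq = freq[1] - freq[0]

# The two grids are not independent: n * d_pos * d_freq = 1.
grid_product = n * d_pos * d_freq
```

Fixing `n`, `pos_min` and `pos_max` therefore already determines the frequency spacing and extent; only the position of the frequency window (here roughly centred on zero) remains free.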

In [None]:
from fftarray import FFTDimension
from fftarray.fft_constraint_solver import fft_dim_from_constraints
# insert the constraints for the x dimension, all other free variables will be
# set accordingly
x_dim: FFTDimension = fft_dim_from_constraints(
    "x",
    n = 2048,
    pos_min = -200e-6,
    pos_max = 200e-6,
    freq_middle = 0.
)

### Initialize the wavefunction

As a first guess for our ground state, we initialize an `FFTArray` as the ground
state of the QHO with frequency `omega_x_init` and mass `mass`. For this, we use
the `get_ground_state_ho` function and pass it our `FFTDimension` and TensorLib
(here: `JaxTensorLib`).
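
For reference, the analytical QHO ground state is the Gaussian $\Psi_0(x) = \left(\frac{m\omega}{\pi\hbar}\right)^{1/4} e^{-\frac{m\omega x^2}{2\hbar}}$. The NumPy sketch below evaluates it on a plain array. It only illustrates the formula and makes no claim about how `get_ground_state_ho` is implemented; the function name `ground_state_ho` and the rounded Rb-87 mass are assumptions of this example.

```python
import numpy as np

hbar = 1.054571817e-34               # J s
# Assumption: Rb-87 mass rounded to 86.909 atomic mass units
mass = 86.909 * 1.66053906660e-27    # kg
omega_init = 2.0 * np.pi * 0.1       # rad/s

x = np.linspace(-200e-6, 200e-6, 2048)  # m
dx = x[1] - x[0]


def ground_state_ho(x, omega, mass):
    """Normalized ground state of a 1D harmonic oscillator, sampled at x."""
    prefactor = (mass * omega / (np.pi * hbar)) ** 0.25
    return prefactor * np.exp(-mass * omega * x**2 / (2.0 * hbar))


psi0 = ground_state_ho(x, omega_init, mass)

# Sanity checks: unit norm and the characteristic length sqrt(hbar / (m omega))
norm_check = np.sum(np.abs(psi0) ** 2) * dx
osc_length = np.sqrt(hbar / (mass * omega_init))
```

With these numbers the oscillator length is a few tens of micrometres, so the Gaussian fits comfortably inside the $\pm 200\,\mu\mathrm{m}$ grid.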

In [None]:
from fftarray import FFTArray
from matterwave import get_ground_state_ho
from fftarray.backends.jax_backend import JaxTensorLib

from matterwave.helpers import generate_panel_plot
# initialize the wavefunction as the ground state of the QHO with omega_x_init
wf_init: FFTArray = get_ground_state_ho(
    dim = x_dim,
    tlib = JaxTensorLib(),
    omega = omega_x_init,
    mass = mass
)
# plotting
generate_panel_plot(wf_init)

### Define the QHO potential
In this step, we define the QHO potential:
$$
V_\mathrm{QHO} = \frac{1}{2}m \omega^2\hat x^2
$$
from the position operator $\hat x$.

In [None]:
# define the position operator
x: FFTArray = x_dim.fft_array(tlib=JaxTensorLib(), space="pos")
# define the potential
V_qho: FFTArray = 0.5 * mass * omega_x**2. * x**2.

Note that `V_qho` is an instance of `FFTArray`, as it holds the values of the
potential along the `FFTDimension` "x" in position space.
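
Because the potential is simply a set of values on the position grid, its expectation value reduces to a weighted sum over $|\Psi(x)|^2$, and the kinetic energy is the analogous sum in frequency space. The NumPy sketch below shows this for the exact ground state, where both contributions should equal $\hbar\omega/4$. It is a stand-alone illustration with a rounded Rb-87 mass and does not reproduce the internals of `expectation_value` or `get_e_kin`.

```python
import numpy as np

hbar = 1.054571817e-34               # J s
k_B = 1.380649e-23                   # J / K
# Assumption: Rb-87 mass rounded to 86.909 atomic mass units
mass = 86.909 * 1.66053906660e-27    # kg
omega = 2.0 * np.pi                  # rad/s

n = 2048
x = np.linspace(-200e-6, 200e-6, n, endpoint=False)
dx = x[1] - x[0]
k = 2.0 * np.pi * np.fft.fftfreq(n, d=dx)  # angular wavenumbers in 1/m

# Exact ground state, normalized on the grid
psi = np.exp(-mass * omega * x**2 / (2.0 * hbar))
psi = psi / np.sqrt(np.sum(np.abs(psi) ** 2) * dx)

# Potential energy: sum of |psi|^2 * V over the position grid
V = 0.5 * mass * omega**2 * x**2
e_pot = np.sum(np.abs(psi) ** 2 * V) * dx

# Kinetic energy: average of (hbar k)^2 / (2 m) weighted by |psi_k|^2
psi_k = np.fft.fft(psi)
weights = np.abs(psi_k) ** 2 / np.sum(np.abs(psi_k) ** 2)
e_kin = np.sum(weights * (hbar * k) ** 2 / (2.0 * mass))

# Convert from joules to microkelvin by dividing by k_B * 1e-6
e_pot_uK = e_pot / (k_B * 1e-6)
e_kin_uK = e_kin / (k_B * 1e-6)
```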

### Imaginary time evolution
First we define the function that is iteratively called by `jax.lax.scan`.

In [None]:
from scipy.constants import Boltzmann
from matterwave import split_step, get_e_kin, expectation_value

def step(wf: FFTArray, *_):
    """
    The step function for the iteration using jax.lax.scan.

    Args:
        wf (FFTArray): The wavefunction which should be evolved.

    Returns:
        Tuple[FFTArray, dict]: Returns the wavefunction for the next iteration
        step and a dictionary containing the energy values.
    """
    # save the energy in µK to avoid small values (~1e-33 for Joule)
    # calculate the kinetic energy (result is returned in µK)
    E_kin = get_e_kin(wf, m=mass, return_microK=True)
    # calculate the potential energy (and convert to µK)
    E_pot = expectation_value(wf, V_qho) / (Boltzmann * 1e-6)
    # calculate the total energy
    E_tot = E_kin + E_pot
    # split-step application (set is_complex=True to use imaginary time step)
    wf = split_step(wf=wf, dt=dt, mass=mass, V=V_qho, is_complex=True)
    # split_step() normalizes the wavefunction if: is_complex=True
    # return the wavefunction for the next iteration step and the energies in a
    # dictionary for plotting (every iteration the energies are appended to a
    # list)
    return wf, {"E_kin": E_kin, "E_pot": E_pot, "E_tot": E_tot}


Now, we use the `jax.lax.scan` function to loop over the previously defined
`step` function. We additionally pass it our initial wavefunction `wf_init` and
the number of steps we want to perform (here: 10654 steps, which was found to
be enough for the total energy to converge).
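
A rough plausibility check for this step count follows from the decay factors $e^{-E_n \mathrm{dt}/\hbar}$. The estimate below is a sketch under two assumptions that are not stated in the text above: the initial Gaussian is symmetric around $x = 0$, so only even eigenstates contribute and the slowest-decaying admixture is $n = 2$ with $E_2 - E_0 = 2\hbar\omega_x$; and the splitting error of the split-step operator is ignored.

```python
import math

omega_x = 2.0 * math.pi  # rad/s
dt = 1e-4                # s
n_iter = 10654

# Total imaginary time covered by the iteration
tau = n_iter * dt

# Amplitude of the n = 2 admixture relative to the ground state after time tau.
# With E_2 - E_0 = 2 * hbar * omega_x, hbar cancels in the exponent.
suppression = math.exp(-2.0 * omega_x * tau)
```

Under these assumptions the relative amplitude of the leading excited-state admixture is reduced by roughly six orders of magnitude, which is consistent with the observation that the total energy has converged.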

`jax.lax.scan` requires that the step function returns an `FFTArray` in the same
space as its input. As `split_step` (and, thus, `step` as well) returns the
`FFTArray` in the frequency space, we need to transform the initial wavefunction
into the frequency space first, before passing it to `jax.lax.scan`.
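
As a mental model, `jax.lax.scan` with `xs=None` behaves like the plain Python loop below, which follows the loop semantics described in the JAX documentation. The real function traces and compiles the loop instead of running it step by step, which is why the carried value must keep the same structure in every iteration. `scan_reference` and `toy_step` are hypothetical names used only for this illustration.

```python
import numpy as np


def scan_reference(f, init, length):
    """Plain-Python model of jax.lax.scan(f, init, xs=None, length=length)."""
    carry = init
    outputs = []
    for _ in range(length):
        # f receives the carry and (here) no per-step input
        carry, y = f(carry, None)
        outputs.append(y)
    # Stack the per-step dictionaries into one dictionary of arrays
    stacked = {key: np.asarray([y[key] for y in outputs]) for key in outputs[0]}
    return carry, stacked


def toy_step(value, _):
    # Hypothetical stand-in for `step`: record the current value,
    # then pass half of it on to the next iteration.
    return 0.5 * value, {"value": value}


final, history = scan_reference(toy_step, init=8.0, length=4)
```

Here `final` is the carry after the last iteration, while `history["value"]` holds one entry per step, recorded before the update, in the same way that `step` computes the energies before applying `split_step`.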

In [None]:
from jax.lax import scan
# 10654 iteration steps are performed (this was found to be enough such that the
# total energy converges)
N_iter = 10654
# calls jax.lax.scan to start the iteration
# scan returns the final wavefunction and a dictionary containing the energies
# {"E_kin": [...], "E_pot": [...], "E_tot": [...]}
wf_final, energies = scan(f=step, init=wf_init.into(space="freq"), xs=None, length=N_iter)

The final wavefunction after iteration using `jax.lax.scan`:

In [None]:
generate_panel_plot(wf_final)

### Plot the energy convergence

The analytical solution for this problem is well known: the ground state
energy is
$$
    E_0 = \tfrac{1}{2}\hbar\omega
$$

In [None]:
from scipy.constants import hbar
# the analytical solution of the ground state energy
E_tot_analy = 0.5*omega_x*hbar/(Boltzmann * 1e-6) # µK

We plot the analytical solution together with the kinetic, potential and total
energy during the imaginary time evolution:

In [None]:
import numpy as np
from bokeh.plotting import figure, show
# plot the energy trend during the imaginary time evolution
plt = figure(
    width=800, height=400, min_border=50,
    title="Energy values during imaginary time evolution",
    x_axis_label="Iteration step",
    y_axis_label="Energy in µK"
)
x_num_iter = np.arange(N_iter)
# numerical solution
plt.line(
    x_num_iter, energies["E_kin"],
    line_width=1.5, color="red", legend_label="Kinetic Energy"
)
plt.line(
    x_num_iter, energies["E_pot"],
    line_width=1.5, color="green", legend_label="Potential Energy"
)
plt.line(
    x_num_iter, energies["E_tot"],
    line_width=1.5, color="blue", legend_label="Total Energy"
)
# analytical solution
plt.line(
    x_num_iter, np.full((N_iter,), E_tot_analy),
    line_width=1.5, line_dash="dashed", color="black",
    legend_label="Ground state energy (analytical solution)"
)
# show the plot
show(plt)