# Bragg beam splitter

## Physics

The Bragg beam splitter splits the wavefunction into two momentum states that share the same internal state. 
Here, we only sketch the main physics. 
For details, the interested reader is referred to [the textbook by Grynberg, Aspect and Fabre](https://www.cambridge.org/core/books/introduction-to-quantum-optics/F45DCE785DC8226D4156EC15CAD5FA9A). 

The formalism describing the Bragg atom-light interaction is based on a semi-classical model. 
The atoms are described quantum mechanically by a wavefunction, while the light is described classically, which is justified because the laser beams are intense enough to contain a large number of photons at all times.
The Hamiltonian of the system is:
$$
    \hat H = \frac{\hat p^2}{2m} - \hat{\vec D} \cdot \vec E_L (\hat{\vec r}, t)
$$
where $\hat p$ is the momentum operator, $m$ is the atomic mass, $\hat{\vec D}$ is the dipole operator and $\vec E_L$ is the electric field of the lasers. 

The electric field can be described by two counterpropagating laser beams with frequencies $\omega_L+\omega_r$ (driving the transition $| g,0 \rangle \rightarrow | e, \hbar k_L \rangle$) and $\omega_L$ (driving the stimulated emission $| e, \hbar k_L \rangle \rightarrow | g, 2\hbar k_L\rangle$). 
The frequency difference $\omega_r$ makes the two-photon process energy conserving, since $\frac{(2\hbar k_L)^2}{2m} = \hbar \omega_r$.
Both lasers are detuned by $\Delta$ from the single-photon resonance, so the excited state is only virtually populated: the absorption $|g,0 \rangle \rightarrow | e, \hbar k_L \rangle$ essentially only happens as part of the two-photon process, immediately followed by the stimulated emission. 
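
As a quick numerical check of the recoil relation, the sketch below evaluates $\omega_r$ for ⁸⁷Rb. It assumes a laser wavelength of 780.241 nm (the Rb D2 line), the exact SI value of $\hbar$ and a tabulated ⁸⁷Rb mass; the constants `k_L` and `hbarkv` imported from `matterwave.rb87` further below may use slightly different values. It also confirms that $\omega_r$ equals the form `2 * hbarkv * k_L` used later, under the assumption that `hbarkv` is the recoil velocity $\hbar k_L/m$.

```python
import math

# Constants (SI units); the 87Rb mass and D2 wavelength are assumed tabulated values
HBAR = 1.054571817e-34            # reduced Planck constant (J s)
M_RB87 = 1.443160648e-25          # mass of 87Rb (kg)
WAVELENGTH = 780.241e-9           # assumption: Rb D2 line (m)
k_L = 2 * math.pi / WAVELENGTH    # single-photon wavenumber (1/m)

# single-photon recoil velocity hbar k_L / m
v_recoil = HBAR * k_L / M_RB87

# two-photon recoil frequency from (2 hbar k_L)^2 / (2 m) = hbar omega_r
omega_r = (2 * HBAR * k_L) ** 2 / (2 * M_RB87 * HBAR)

# equivalent form used later in this notebook: w_r = 2 * hbarkv * k_L
w_r = 2 * v_recoil * k_L

print(f"recoil velocity:  {v_recoil * 1e3:.4f} mm/s")
print(f"omega_r / (2 pi): {omega_r / (2 * math.pi) / 1e3:.3f} kHz")
```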

Here, we consider a one-dimensional wavefunction $\Psi (x)$. 
Adiabatic elimination of the excited state leads to
$$
\hat H = -\frac{\hbar^2}{2m}\nabla^2 + 2 \hbar \Omega \cos ^2 \left( k_L x - \frac{\omega_r}{2} t \right)
$$
where $\Omega$ is the effective (two-photon) Rabi frequency. 
$\Omega$ is determined by the laser properties (intensities and detuning $\Delta$) and typically has a Gaussian temporal profile to ensure good velocity selectivity. 
If the atoms are freely falling, an acceleration term is included in the laser phase, $k_L x \rightarrow k_L \left(x - \frac{1}{2}a_\text{laser}t^2\right)$, so that the laser beams stay resonant with the falling atoms. 
An additional constant phase shift $\Phi_0$ is also common.
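
To see why this potential couples only momenta that differ by $2\hbar k_L$, one can expand the $\cos^2$ with $\cos^2\theta = \frac{1}{2}(1 + \cos 2\theta)$ (shown here without the acceleration and phase terms):
$$
2\hbar\Omega\cos^2\left(k_L x - \frac{\omega_r}{2}t\right) = \hbar\Omega + \frac{\hbar\Omega}{2}\left(e^{i(2k_L x - \omega_r t)} + e^{-i(2k_L x - \omega_r t)}\right).
$$
The constant $\hbar\Omega$ only adds a global phase. Since $e^{\pm 2ik_L x}|p\rangle = |p \pm 2\hbar k_L\rangle$, the remaining terms couple $|g,p\rangle$ to $|g,p\pm2\hbar k_L\rangle$ with strength $\hbar\Omega/2$, for example
$$
\langle g, 2\hbar k_L | \hat V | g, 0 \rangle = \frac{\hbar\Omega}{2} e^{-i\omega_r t}.
$$
In the interaction picture this matrix element acquires the factor $e^{i(E_{2\hbar k_L} - E_0)t/\hbar} = e^{i\omega_r t}$, so the transition $|g,0\rangle \leftrightarrow |g,2\hbar k_L\rangle$ is exactly resonant. With $E_p = p^2/2m$, i.e. $E_{\pm 2\hbar k_L} = \hbar\omega_r$ and $E_{4\hbar k_L} = 4\hbar\omega_r$, the transitions $|g,0\rangle \leftrightarrow |g,-2\hbar k_L\rangle$ and $|g,2\hbar k_L\rangle \leftrightarrow |g,4\hbar k_L\rangle$ are instead detuned by $2\omega_r$, which is why these higher orders are suppressed as long as $\Omega$ is small compared with $\omega_r$.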

After the atom-light interaction, the atom is left in a superposition of the states $|g,0\rangle$ and $|g,2\hbar k_L\rangle$, typically with small admixtures of higher orders such as $|g,-2\hbar k_L\rangle$ and $|g,4\hbar k_L\rangle$ [[Siemß 2020](https://link.aps.org/doi/10.1103/PhysRevA.102.033709)]. 
In the ideal case, the pulse transfers a momentum of $2\hbar k_L$ to the atom with $50\%$ probability.


## Implementation
This example shows how to propagate a matterwave through a Bragg beam splitter.
We use the `matterwave.split_step` routine to propagate a matterwave (described by an `FFTArray`) under the influence of the time-dependent Bragg beam splitter potential.

### Initialize global constants

In [None]:
# plotting in jupyter notebook
import panel as pn
pn.extension()
from bokeh.io import output_notebook
output_notebook()

from scipy.constants import pi
from matterwave.rb87 import k_L, hbarkv

# angular frequency of the harmonic trap used to initialize the ground state
# (of the quantum harmonic oscillator)
omega_x = 2*pi*10 # rad/s
# laser pulse parameters
# Rabi frequency. This specific value was found via a binary search to
# optimize a 50/50 split of the two momentum classes for this specific beam
# splitter duration and pulse form.
rabi_frequency = 25144.285917282104 # rad/s
phi_0 = 0. # phase jump
bragg_acc = 0. # bragg acceleration
sigma_bs = 25e-6 # temporal pulse width (s)
w_r = 2 * hbarkv * k_L # two-photon recoil frequency omega_r (rad/s), hbarkv = hbar k_L / m

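The value of `rabi_frequency` above can be sanity-checked with a two-level estimate. Writing $2\cos^2\theta = 1 + \cos 2\theta$ shows that the potential couples $|g,0\rangle$ and $|g,2\hbar k_L\rangle$ resonantly with matrix element $\hbar\Omega(t)/2$, so an equal split requires a pulse area $\int \Omega(t)\,\mathrm{d}t = \pi/2$ (a $\pi/2$ pulse). The sketch below computes the peak Rabi frequency $\Omega_0$ that gives this area, both for an untruncated Gaussian and for the envelope used later in this notebook (sampled over $\pm 4\sigma_\mathrm{bs}$ and shifted down by its edge value). It assumes `rabi_frequency` is an angular frequency in rad/s, as implied by its use in `V`. Since the two-level model ignores the off-resonant coupling to other momentum orders, the estimate should only roughly agree with the binary-search value; here it lands within a fraction of a percent.

```python
import math

# Pulse parameters as in the notebook
sigma_bs = 25e-6                  # Gaussian width of the pulse (s)
sampling_range_mult = 4.0         # pulse sampled over +-4 sigma_bs
rabi_frequency_binary_search = 25144.285917282104  # value found numerically (rad/s)

# Two-level pi/2 condition: integral of Omega(t) dt = pi / 2

# (a) untruncated Gaussian: area = Omega_0 * sigma * sqrt(2 pi)
area_gauss = sigma_bs * math.sqrt(2 * math.pi)
omega_0_gauss = (math.pi / 2) / area_gauss

# (b) envelope used in step_bs: exp(-u^2 / (2 sigma^2)) - exp(-mult^2 / 2)
#     for |u| <= mult * sigma, integrated analytically via the error function
half_width = sampling_range_mult * sigma_bs
edge_value = math.exp(-0.5 * sampling_range_mult**2)
area_trunc = (area_gauss * math.erf(sampling_range_mult / math.sqrt(2))
              - 2 * half_width * edge_value)
omega_0_trunc = (math.pi / 2) / area_trunc

print(f"pi/2 estimate, full Gaussian:      {omega_0_gauss:.1f} rad/s")
print(f"pi/2 estimate, truncated envelope: {omega_0_trunc:.1f} rad/s")
print(f"binary-search value:               {rabi_frequency_binary_search:.1f} rad/s")
```
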
#### Define the time grid
In order to simulate the time evolution of the matterwave, we discretize 
time into intervals of length $\Delta t$ within which the potential is
assumed to be constant. Then, for each time step, we apply the time evolution 
operator 
$|\Psi(t+\Delta t)\rangle = e^{-i \hat H \Delta t /\hbar}|\Psi(t)\rangle$ 
using the `split_step` function.
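
For readers unfamiliar with the split-step (split-operator) method, the sketch below shows the idea in plain NumPy: the potential is applied in position space and the kinetic energy in momentum space via FFT, in the symmetric ordering $e^{-i\hat V\Delta t/2\hbar}\,e^{-i\hat T\Delta t/\hbar}\,e^{-i\hat V\Delta t/2\hbar}$, which is accurate to second order in $\Delta t$. This is a generic textbook version, not the implementation of `matterwave.split_step`, whose ordering and conventions may differ; the helper `split_step_numpy` is hypothetical.

```python
import numpy as np

HBAR = 1.054571817e-34  # reduced Planck constant (J s)

def split_step_numpy(psi, x, dt, mass, V):
    """Hypothetical helper: one symmetric split-step update on a uniform 1D grid.

    psi  : complex wavefunction sampled on x (position space)
    x    : uniform position grid (m)
    dt   : time step (s)
    mass : particle mass (kg)
    V    : potential energy on the grid (J), assumed constant during dt
    """
    dx = x[1] - x[0]
    # angular wavenumbers in numpy's FFT ordering
    k = 2 * np.pi * np.fft.fftfreq(x.size, d=dx)
    half_kick = np.exp(-0.5j * V * dt / HBAR)            # exp(-i V dt / (2 hbar))
    drift = np.exp(-1j * HBAR * k**2 * dt / (2 * mass))  # exp(-i T dt / hbar), T = hbar^2 k^2 / 2m
    psi = half_kick * psi
    psi = np.fft.ifft(drift * np.fft.fft(psi))
    return half_kick * psi

# usage: one step of a normalized Gaussian wave packet of 87Rb on a +-50 um grid
mass = 1.443160648e-25
x = np.linspace(-50e-6, 50e-6, 1024, endpoint=False)
dx = x[1] - x[0]
psi0 = np.exp(-x**2 / (2 * (2e-6) ** 2)).astype(complex)
psi0 /= np.sqrt(np.sum(np.abs(psi0) ** 2) * dx)

psi_free = split_step_numpy(psi0, x, 1e-6, mass, np.zeros_like(x))
# a constant potential V0 only adds the global phase exp(-i V0 dt / hbar)
V0 = HBAR * 1e3
psi_const = split_step_numpy(psi0, x, 1e-6, mass, np.full_like(x, V0))
```

Both factors are pure phases, so each step is unitary and conserves the norm of the wavefunction.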

Here, we sample the Gaussian temporal profile of the Bragg beam splitter
potential (Gaussian width $\sigma_\mathrm{bs}$) up to $4\sigma_\mathrm{bs}$ on
each side of its peak, i.e. over $8\sigma_\mathrm{bs} = 200$ µs in total, with a
step size of $\Delta t = 1$ µs. Afterwards, we let the matterwave propagate
freely for $2.5$ ms ($50$ steps of $\Delta t = 50$ µs) to illustrate the
separation of both momentum states in position space.

```text
 laser start      intensity peak        laser end            simulation end
------|-----------------|-------------------|----------------------|-----> time
      |-- 4 sigma_bs ---|--- 4 sigma_bs ----|-- free propagation --|
```

In [None]:
import numpy as np
# number of Gaussian widths to sample on each side of the peak before the
# intensity is set to zero:
sampling_range_mult = 4. # * sigma_bs
dt_bs = 1e-6 # time step size for the beam splitter (s)
# number of time steps on each side of the peak
# = sampling_range_mult * sigma_bs / dt_bs
steps_limit = int(round(sigma_bs * sampling_range_mult / dt_bs))
t_offset = steps_limit*dt_bs # time of the intensity peak
nt_bs = 2*steps_limit # number of time steps for beam splitter
dt_free = 5e-5 # defines time step size for free propagation
nt_free = 50 # number of time steps for free propagation
# time lists
t_list_bs = np.arange(nt_bs)*dt_bs
t_list_free = t_list_bs[-1]+np.arange(1,nt_free+1)*dt_free
t_list = np.concatenate((t_list_bs, t_list_free))

### Initialize the wavefunction
We initialize the wavefunction as the ground state of a quantum harmonic oscillator with frequency $\omega_\mathrm{QHO} = 2\pi\times 10$ Hz (`omega_x`).
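
The harmonic oscillator ground state is known analytically, $\psi_0(x) = \left(\frac{m\omega}{\pi\hbar}\right)^{1/4}\exp\left(-\frac{m\omega x^2}{2\hbar}\right)$, with position width $\sigma_x = \sqrt{\hbar/(2m\omega)}$ and momentum width $\sigma_p = \hbar/(2\sigma_x)$. The plain-NumPy sketch below evaluates it on a grid similar to the one used next ($\pm 50$ µm) and compares the momentum width with $\hbar k_L$. It is a cross-check under assumed constants (⁸⁷Rb mass, λ = 780.241 nm), not a reimplementation of `get_ground_state_ho`. For $\omega = 2\pi\times 10$ Hz the momentum spread is only a few percent of $\hbar k_L$, so the diffracted orders, separated by $2\hbar k_L$, are well resolved.

```python
import math
import numpy as np

HBAR = 1.054571817e-34            # reduced Planck constant (J s)
M_RB87 = 1.443160648e-25          # mass of 87Rb (kg)
k_L = 2 * math.pi / 780.241e-9    # assumption: Rb D2 wavelength
omega_x = 2 * math.pi * 10        # trap angular frequency (rad/s)

# position grid comparable to x_dim (+-50 um)
x = np.linspace(-50e-6, 50e-6, 2048, endpoint=False)
dx = x[1] - x[0]

# analytic ground state of the harmonic oscillator
alpha = M_RB87 * omega_x / HBAR
psi0 = (alpha / math.pi) ** 0.25 * np.exp(-alpha * x**2 / 2)

# analytic widths: sigma_x = sqrt(hbar / (2 m omega)), sigma_k = sigma_p / hbar = 1 / (2 sigma_x)
sigma_x = math.sqrt(HBAR / (2 * M_RB87 * omega_x))
sigma_k = 1 / (2 * sigma_x)

# numerical cross-checks on the grid
norm = np.sum(np.abs(psi0) ** 2) * dx
sigma_x_grid = math.sqrt(np.sum(x**2 * np.abs(psi0) ** 2) * dx / norm)

print(f"sigma_x = {sigma_x * 1e6:.2f} um, sigma_k / k_L = {sigma_k / k_L:.3f}")
```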

In [None]:
from fftarray import FFTDimension, FFTArray
from fftarray.backends.jax_backend import JaxTensorLib
from fftarray.fft_constraint_solver import fft_dim_from_constraints
from matterwave.rb87 import m as mass_rb87
from matterwave import get_ground_state_ho
from matterwave.helpers import generate_panel_plot

# coordinate grid
x_dim: FFTDimension = fft_dim_from_constraints(
    name = "x",
    pos_min = -50e-6,
    pos_max = 50e-6,
    freq_middle = 0.,
    freq_extent = 32*k_L,
    loose_params = ["freq_extent"]
)

# initialize FFTArray as harmonic oscillator ground state
wf_init: FFTArray = get_ground_state_ho(
    dim = x_dim,
    tlib = JaxTensorLib(),
    omega = omega_x,
    mass = mass_rb87
)

generate_panel_plot(wf_init)

### Define the potential
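Now, we implement the external potential
$$
    V(x, t) = 2 \hbar \Omega(t) \cos ^2 \left( k_L \left(x - \frac{1}{2} a_\text{laser} t^2\right) - \frac{\omega_r}{2} t + \frac{\Phi_0}{2} \right)
$$
where the time-dependent Rabi frequency is a Gaussian centered at the intensity peak $t_0$ (`t_offset`),
$$
    \Omega(t) = \Omega_0 \exp\left(-\frac{(t - t_0)^2}{2\sigma_\mathrm{bs}^2}\right).
$$
With $a_\text{laser} = \Phi_0 = 0$ (`bragg_acc` and `phi_0` above), this is the potential term of the Hamiltonian in the physics section.
The function `V` receives the normalized envelope as `ramp`, i.e. $\Omega(t) = \Omega_0 \cdot$ `ramp`.
In the step function below, the Gaussian is additionally shifted down by its value at the pulse edges (`gauss_offset`), so that the potential vanishes when the pulse is switched on and off.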

In [None]:
from scipy.constants import hbar

# position operator
x: FFTArray = x_dim.fft_array(JaxTensorLib(), space="pos")

def V(ramp: float, t: float) -> FFTArray:
    """Bragg pulse potential.

    Parameters
    ----------
    ramp : float
        The pulse ramp (scaling the rabi frequency).
    t : float
        The global lab time.

    Returns
    -------
    FFTArray
        The potential at time t.
    """
    return rabi_frequency * ramp * 2. * hbar * np.cos(
        k_L * (x - 0.5 * bragg_acc * t**2)
        - 0.5 * w_r * t
        + phi_0/2.
    )**2

Now, we plot the potential at peak intensity (`t=t_offset`).

In [None]:
from bokeh.plotting import figure, show
# plot the potential over the position grid at peak intensity
plt = figure(
    width=800, height=400, min_border=50,
    title="The potential versus the wavefunction's position grid",
    x_axis_label="x [m]",
    y_axis_label="Potential/hbar [rad/s]"
)
plt.line(
    x_dim.np_array(space="pos"), V(ramp=1, t=t_offset).values/hbar
)
show(plt)

### Propagate the wavefunction
Now, we implement the iterative procedure that evolves the wavefunction in time under the previously defined potential, followed by the free propagation. 
For this, we define two step functions (one for the pulse, one for the free propagation) that we iterate with the `jax.lax.scan` function.
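
`jax.lax.scan(f, init, xs)` threads a carry through `f` and stacks the per-step outputs along a new leading axis; JAX documents its semantics as equivalent to a Python loop. The sketch below reproduces that loop in plain NumPy (the name `scan_reference` is hypothetical), including stacking a dictionary of outputs key by key, as `scan` does for pytrees. It is only meant to clarify what, for example, `wf_data_bs["abs_pos"]` contains: one row per time step, each holding the density after that step. Unlike this loop, the real `scan` also requires the carry to keep the same structure, shape and dtype on every iteration.

```python
import numpy as np

def scan_reference(f, init, xs):
    """Hypothetical plain-Python equivalent of jax.lax.scan for dict outputs."""
    carry = init
    outputs = []
    for x in xs:
        carry, y = f(carry, x)
        outputs.append(y)
    # stack each output field along a new leading (time) axis
    stacked = {key: np.stack([y[key] for y in outputs]) for key in outputs[0]}
    return carry, stacked

# toy step: the carry decays, the outputs record its squared norm and its value
def toy_step(state, t):
    state = state * np.exp(-t)
    return state, {"norm2": np.sum(state**2), "state": state}

final, data = scan_reference(toy_step, np.ones(3), np.array([0.0, 0.5, 1.0]))
print(data["norm2"].shape, data["state"].shape)
```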

#### Define the step functions

<div class="alert alert-block alert-info">
<b>Note:</b> We use jax.numpy to compute the temporal ramp of the 
Bragg beam potential so that it can be evaluated inside jax.lax.scan. The 
potential itself can be written with numpy functions, because these are 
handled internally via jax.numpy when the FFTArray is initialized with 
JaxTensorLib.
</div>

In [None]:
import jax.numpy as jnp
from matterwave import split_step, propagate

# helper function for the temporal ramp of the beam's intensity
def gauss(t: float, sigma: float):
    """Evaluate a Gaussian function.

    Parameters
    ----------
    t : float
        Time value.
    sigma : float
        Width.

    Returns
    -------
    float
        The Gaussian function at t with width sigma.
    """
    return jnp.exp(-0.5 * (t / sigma)**2)

# value of the Gaussian at the pulse edges; subtracting it ensures that the
# potential is zero when we start and stop applying it
gauss_offset = gauss(t = -t_offset, sigma = sigma_bs)

def step_bs(wf: FFTArray, t: float):
    """Step function for jax.lax.scan. Apply the Bragg beam splitter pulse.

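    Parameters
    ----------
    wf : FFTArray
        The wavefunction.
    t : float
        The time.

    Returns
    -------
    tuple[FFTArray, dict]
        The propagated wavefunction (the scan carry) and the position and
        frequency space densities after this step (the scan output).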
    """
    ramp = gauss(t = t-t_offset, sigma = sigma_bs) - gauss_offset
    wf = split_step(wf,
        dt = dt_bs,
        mass = mass_rb87,
        V = V(ramp, t)
    )
    return wf, {
        "abs_pos": np.abs(wf.into(space="pos")).values**2,
        "abs_freq": np.abs(wf.into(space="freq")).values**2
    }

def step_free(wf: FFTArray, t: float):
    """Step function for jax.lax.scan. Freely propagate the wavefunction.

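    Parameters
    ----------
    wf : FFTArray
        The wavefunction.
    t : float
        The time.

    Returns
    -------
    tuple[FFTArray, dict]
        The propagated wavefunction (the scan carry) and the position and
        frequency space densities after this step (the scan output).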
    """
    wf = propagate(wf, dt = dt_free, mass = mass_rb87)
    return wf, {
        "abs_pos": np.abs(wf.into(space="pos")).values**2,
        "abs_freq": np.abs(wf.into(space="freq")).values**2
    }


#### Iterating the step functions

<div class="alert alert-block alert-warning">
<b>Important:</b> The FFTArray passed into and returned from the step function 
of jax.lax.scan must have the same attributes. This implies that the input and 
output need to be in the same space. Since split_step returns the FFTArray in 
frequency space, we also need to transform the initial FFTArray into 
frequency space before passing it to jax.lax.scan.
</div>

In [None]:
from jax.lax import scan

# bragg beam splitter pulse
wf_final_bs, wf_data_bs = scan(
    f = step_bs,
    init = wf_init.into(space="freq"), # initial wavefunction in frequency space
    xs = t_list_bs
)
# free propagation
wf_final_free, wf_data_free = scan(
    f = step_free,
    init = wf_final_bs, # already in frequency space
    xs = t_list_free
)

<div class="alert alert-block alert-info">
<b>Note:</b> If you run the Bokeh app within VSCode, you may 
encounter an error like
"ERROR:bokeh.server.views.ws:Refusing websocket connection from Origin ...". 
To bypass this, follow the instructions given in the error message 
and set the environment variable BOKEH_ALLOW_WS_ORIGIN to the specific string given there.
</div>

In [8]:
from bokeh.layouts import layout
from bokeh.io import curdoc
from bokeh.models import Slider, Select, Button

# the animation steps through the beam splitter part of t_list in increments of
# anim_time_step_mult time steps and through the free propagation step by step
anim_time_step_mult = 5

# x values
x_list = x_dim.np_array(space="pos")
# create time list for animation
# (only take every anim_time_step_mult'th element of t_list_bs)
plt_t_list = np.concatenate((
    t_list_bs[::anim_time_step_mult],
    t_list_free
))

ramp_list = [
    float(gauss(t = t-t_offset, sigma = sigma_bs) - gauss_offset)
    for t in plt_t_list]

# initial values
wf_init_abs_pos = np.abs(wf_init.into(space="pos")).values**2
wf_init_abs_freq = np.abs(wf_init.into(space="freq")).values**2
# list of all position space values
plt_wf_final_pos = np.concatenate((
    # initial value
    [wf_init_abs_pos],
    # beam splitter values: take only every anim_time_step_mult'th value
    wf_data_bs["abs_pos"][anim_time_step_mult::anim_time_step_mult],
    # free propagation values
    wf_data_free["abs_pos"]
))
# list of all momentum space values
plt_wf_final_freq = np.concatenate((
    # initial value
    [wf_init_abs_freq],
    # beam splitter values: take only every anim_time_step_mult'th value
    wf_data_bs["abs_freq"][anim_time_step_mult::anim_time_step_mult],
    # free propagation values
    wf_data_free["abs_freq"]
))

plt_V = figure(
    width=600, height=200,
    title="Pulse ramp",
    x_axis_label="t [s]",
    y_axis_label="Ramp"
)
plt_V_line = plt_V.line(plt_t_list, ramp_list)
plt_V_scatter = plt_V.scatter([plt_t_list[0]], [ramp_list[0]], color="red")

plt_pos = figure(
    width=500, height=400, min_border=50,
    title="Wavefunction in position space",
    x_axis_label="x pos coordinate [m]",
    y_axis_label="Probability density"
)
plt_pos_line = plt_pos.line(x_dim.np_array(space="pos"), plt_wf_final_pos[0])
plt_freq = figure(
    width=500, height=400, min_border=50,
    title="Wavefunction in frequency space",
    x_axis_label="x freq coordinate [1/m]",
    y_axis_label="Probability density"
)
plt_freq_line = plt_freq.line(x_dim.np_array(space="freq"), plt_wf_final_freq[0])

t_snapshots = {
    "animated": None,
    "before pulse": 0,
    "at pulse peak": t_offset,
    "after pulse": 2*t_offset,
    "after free propagation": plt_t_list[-1]
}

def bkapp(doc):

    t_select = Select(
        title="Show wavefunction ...",
        value="animated",
        options=list(t_snapshots.keys())
    )

    t_slider = Slider(
        title="Time in µs",
        value=0.,
        start=float(t_list[0])/1e-6,
        end=float(t_list[-1])/1e-6,
        step=dt_bs/1e-6
    )

    def time_changed(attr, old, new):
        """Called when the time slider or select was updated."""
        if t_select.value == "animated":
            t = t_slider.value * 1e-6
        else:
            t = t_snapshots[t_select.value]
            t_slider.value = t/1e-6
        tj = (np.abs(plt_t_list - t)).argmin()
        plt_V_scatter.data_source.data["x"] = [t]
        plt_V_scatter.data_source.data["y"] = [ramp_list[tj]]
        plt_pos_line.data_source.data["y"] = plt_wf_final_pos[tj]
        plt_freq_line.data_source.data["y"] = plt_wf_final_freq[tj]

    def update_time():
        """Iteratively called in every animation step."""
        time = t_slider.value*1e-6
        if time <= nt_bs*dt_bs:
            time += anim_time_step_mult*dt_bs
        else:
            time += dt_free
            if time > float(t_list[-1]):
                time = float(t_list[0])
        t_slider.value = time/1e-6 # triggers time_changed

    t_anim_btn = Button(label='► Play', width=60, margin=(5, 10, 5, 5))
    callback_id = None
    def animate():
        """Animate the wavefunction in time. Play and pause the animation."""
        nonlocal callback_id
        if t_anim_btn.label == '► Play':
            t_anim_btn.label = '❚❚ Pause'
            t_select.value = "animated"
            callback_id = curdoc().add_periodic_callback(update_time, 100)
        else:
            t_anim_btn.label = '► Play'
            curdoc().remove_periodic_callback(callback_id)

    t_select.on_change("value", time_changed)
    t_slider.on_change("value", time_changed)
    t_anim_btn.on_click(animate)

    anim_layout = layout([
        [t_select],
        [t_anim_btn, t_slider],
        [plt_V],
        [plt_pos, plt_freq],
    ])
    doc.add_root(anim_layout)

show(bkapp)

### The final wavefunction

In [None]:
generate_panel_plot(wf_final_free)
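
To put a number on the split, one can integrate the momentum density over a window around each diffraction order. The sketch below does this for a uniform grid of angular wavenumbers $k$ and a density such as `plt_wf_final_freq[-1]`. If the frequency coordinate returned by `x_dim.np_array(space="freq")` is $k/2\pi$ (an assumption that should be checked against fftarray's convention), pass `2*np.pi*freq` as `k`. The helper `order_populations` and its window half-width of $k_L$ are hypothetical choices; here the helper is demonstrated on synthetic data for an ideal 50/50 split.

```python
import numpy as np

def order_populations(k, density, k_L, orders=(-2, 0, 2, 4)):
    """Hypothetical helper: fraction of a momentum density in each diffraction order.

    k       : uniform grid of angular wavenumbers (1/m), any ordering
    density : |psi(k)|^2 sampled on k
    k_L     : laser wavenumber (1/m)
    orders  : momenta n * hbar * k_L to evaluate (even n for Bragg orders)

    Each window has half-width k_L around n * k_L, so the windows of
    neighbouring even orders touch but do not overlap.
    """
    total = np.sum(density)  # uniform grid: the spacing dk cancels in the ratio
    return {n: np.sum(density[np.abs(k - n * k_L) < k_L]) / total for n in orders}

# demonstration on synthetic data: an ideal 50/50 split (units with k_L = 1)
k = np.linspace(-10.0, 10.0, 4001)
width = 0.05
density = (np.exp(-k**2 / (2 * width**2))
           + np.exp(-(k - 2.0) ** 2 / (2 * width**2)))
pops = order_populations(k, density, k_L=1.0)
print(pops)
```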