In [None]:
import panel as pn
# pn.extension()
from bokeh.io import output_notebook
output_notebook()

from fftarray.fft_constraint_solver import fft_dim_from_constraints
from fftarray import FFTArray, FFTDimension
import numpy as np
from fftarray.backends.jax_backend import JaxTensorLib
from jax import config
from scipy import constants
from matterwave.helpers import generate_panel_plot
from matterwave import get_ground_state_ho

config.update("jax_enable_x64", True)


In [None]:
# Fundamental constants in SI units.
hbar: float = constants.hbar
a_0: float = constants.physical_constants['Bohr radius'][0]


def coupling_fun(m, a):
    """Return the contact coupling constant g = 4*pi*hbar**2 * a / m used in the GPE.

    Args:
        m: atomic mass in kg.
        a: s-wave scattering length in m (SI, e.g. a multiple of ``a_0``).
    """
    return 4 * np.pi * hbar**2 * a / m


# --- Rubidium 87 ---
m_rb87: float = 86.909 * constants.atomic_mass  # mass in kg
a_rb87: float = 98 * a_0                        # s-wave scattering length
coupling_rb87: float = coupling_fun(m_rb87, a_rb87)

# --- Potassium 41 ---
m_k41: float = 40.962 * constants.atomic_mass   # mass in kg
a_k41: float = 60 * a_0                         # s-wave scattering length
coupling_k41: float = coupling_fun(m_k41, a_k41)

# --- Rb87-K41 interspecies s-wave scattering length ---
a_rb87_k41: float = 165.3 * a_0

In [None]:
# Define dimensions: x and y share one grid (centred at 0 in position and
# frequency, 100 µm wide, 2**12 points). A 2D array on this grid has
# 2**24 ≈ 16.8 M points (~268 MB if stored as complex128), so every extra
# 2D array is costly.
_shared_grid = dict(
    pos_middle=0,
    freq_middle=0,
    pos_extent=100e-6,
    n=2**12,
)
x_dim: FFTDimension = fft_dim_from_constraints(name="x", **_shared_grid)
y_dim: FFTDimension = fft_dim_from_constraints(name="y", **_shared_grid)

# Position-space coordinate array along x, displayed as a quick sanity check.
xarr = x_dim.fft_array(tlib=JaxTensorLib(), space="pos")
generate_panel_plot(xarr)

In [None]:
# Angular trap frequencies (2*pi * Hz) along (x, y) for Rb87.
trap_frequencies_rb = 2*np.pi*np.array([25, 400])
# Both species sit in the same trap. For a mass-independent trap potential
# (equal m * omega**2) this gives omega_K = sqrt(m_Rb / m_K) * omega_Rb. Use the
# exact masses defined above (not the mass numbers 87/41) so the K41 trap
# potential equals the Rb87 one.
# NOTE(review): assumes both species feel an identical potential — confirm.
trap_frequencies_k = np.sqrt(m_rb87 / m_k41) * trap_frequencies_rb
trap_minimum = np.zeros(2)

tlib = JaxTensorLib()


def _ho_ground_state_1d(dim: FFTDimension, omega: float, mass: float) -> FFTArray:
    """1D harmonic-oscillator ground state on ``dim`` using the shared ``tlib``."""
    return get_ground_state_ho(dim, tlib=tlib, omega=omega, mass=mass)


init_state_rb87_x: FFTArray = _ho_ground_state_1d(x_dim, trap_frequencies_rb[0], m_rb87)
init_state_rb87_y: FFTArray = _ho_ground_state_1d(y_dim, trap_frequencies_rb[1], m_rb87)
init_state_k41_x: FFTArray = _ho_ground_state_1d(x_dim, trap_frequencies_k[0], m_k41)
init_state_k41_y: FFTArray = _ho_ground_state_1d(y_dim, trap_frequencies_k[1], m_k41)

# Combine the 1D states into 2D wavefunctions. The ground state of a separable
# 2D trap is the PRODUCT psi(x, y) = psi_x(x) * psi_y(y); combining FFTArrays
# with different dimensions broadcasts them into a 2D array. (A sum
# psi_x(x) + psi_y(y) would be a cross-shaped state that does not decay along
# either axis.) If the 1D states are normalized, so is their product.
init_state_rb87: FFTArray = init_state_rb87_x * init_state_rb87_y
init_state_k41: FFTArray = init_state_k41_x * init_state_k41_y

generate_panel_plot(init_state_rb87)

In [None]:
# Display the initial 2D K41 state `init_state_k41` (built in the previous cell).
generate_panel_plot(init_state_k41)

In [None]:
# Define imaginary time evolution to find ground state of the system with GPE (first single species)

from functools import partial
from typing import Any, Callable
from matterwave import propagate
from matterwave.split_step import get_V_prop
from matterwave.wf_tools import normalize
import jax.numpy as jnp

# Position-space coordinate arrays; combining them broadcasts to the 2D grid.
x_fftarray = x_dim.fft_array(tlib=tlib, space="pos")
y_fftarray = y_dim.fft_array(tlib=tlib, space="pos")


def _harmonic_trap(mass: float, omegas) -> FFTArray:
    """Return 0.5 * mass * (omega_x**2 * (x - x0)**2 + omega_y**2 * (y - y0)**2)."""
    dx = x_fftarray - trap_minimum[0]
    dy = y_fftarray - trap_minimum[1]
    return 0.5 * mass * (omegas[0]**2 * dx**2 + omegas[1]**2 * dy**2)


trap_potential_rb87 = _harmonic_trap(m_rb87, trap_frequencies_rb)
trap_potential_k41 = _harmonic_trap(m_k41, trap_frequencies_k)

def gpe_potential(
    pos_state: FFTArray,
    coupling_constant: float,
    trap_potential: FFTArray,
    num_atoms: float = 1e5,
) -> FFTArray:
    """Mean-field Gross-Pitaevskii potential ``N * g * |psi|**2 + V_trap``.

    Args:
        pos_state: Wavefunction, expected in position space because
            ``|psi|**2`` is taken pointwise (callers convert with
            ``.into(space="pos")`` before calling).
        coupling_constant: Contact coupling ``g`` (e.g. from ``coupling_fun``).
        trap_potential: External potential on the same grid as ``pos_state``.
        num_atoms: Atom number ``N`` scaling the mean-field term. Annotated as
            float to match its default ``1e5``; it only enters as a multiplier.

    Returns:
        The total (state-dependent) potential.

    NOTE(review): ``coupling_fun`` gives the 3D coupling 4*pi*hbar**2*a/m, but
    the state here lives on a 2D grid. A quasi-2D simulation usually needs an
    effective 2D coupling (e.g. g / (sqrt(2*pi) * a_z)). Confirm intended.
    """
    self_interaction = num_atoms * coupling_constant * np.abs(pos_state)**2
    return self_interaction + trap_potential

def imaginary_time_evolution_single_species(
    state: FFTArray,
    dt: float,
    mass: float,
    V: Callable[[FFTArray], Any],
) -> FFTArray:
    """Perform ONE symmetric split step (kinetic/2, potential, kinetic/2)
    in imaginary time, then renormalize.

    The real step ``dt`` is multiplied by ``-1j`` before being handed to the
    propagators. Assuming ``propagate``/``get_V_prop`` implement
    ``exp(-i H dt / hbar)``, this turns them into damping operators, which is
    what drives the state towards the ground state over many steps.

    Args:
        state: Wavefunction to advance. It is converted to position space
            before the potential step.
        dt: Real-valued step size (SI units, presumably seconds).
        mass: Particle mass in kg.
        V: Callable mapping the position-space state to a potential; it may
            depend on the state (e.g. the GPE mean-field term).

    Returns:
        The normalized state after one step. Its space is whatever the final
        ``propagate`` call returns.
    """
    # dt -> -i*dt selects imaginary-time propagation; used for every sub-step.
    complex_factor = -1.j

    # Apply half kinetic propagator
    state = propagate(state, dt=complex_factor * 0.5 * dt, mass=mass)

    # Apply full potential propagator; V is evaluated on the position-space state.
    state = state.into(space="pos")
    V_prop = get_V_prop(V=V(state), dt=complex_factor * dt)
    state = state * V_prop

    # Apply second half kinetic propagator
    state = propagate(state, dt=complex_factor * 0.5 * dt, mass=mass)

    # Imaginary-time steps do not preserve the norm; restore it.
    return normalize(state)

# Bind the Rb87 trap and coupling so the potential depends on the state only.
reduced_gpe_potential_rb87 = partial(
    gpe_potential,
    trap_potential=trap_potential_rb87,
    coupling_constant=coupling_rb87,
)

init_state_rb87 = normalize(init_state_rb87)

# A SINGLE imaginary-time step (dt = 1e-5, SI units) as a smoke test of the
# split-step routine. Despite the name, this is not a converged ground state;
# `ground_state_rb87` is reassigned by the `scan` loop in the next cell.
ground_state_rb87 = imaginary_time_evolution_single_species(
    state=init_state_rb87, dt=1e-5, mass=m_rb87, V=reduced_gpe_potential_rb87
)

generate_panel_plot(ground_state_rb87)

In [None]:
from matterwave.wf_tools import expectation_value, get_e_kin
kb: float = constants.Boltzmann


def imaginary_time_step_single_species(
    state: FFTArray,
    *_,
    dt: float = 1e-5,
    mass: float = m_rb87,
    V: Callable[[FFTArray], Any] = reduced_gpe_potential_rb87,
):
    """``jax.lax.scan`` body: record energies of ``state``, then take one
    imaginary-time split step.

    Args:
        state: Current wavefunction (the scan carry). The scan in this
            notebook starts from a frequency-space state.
        *_: Per-iteration scan input; unused (the scan runs with ``xs=None``).
        dt: Step size forwarded to ``imaginary_time_evolution_single_species``.
        mass: Particle mass in kg, used for the kinetic energy and the step.
        V: State-dependent potential used for ``E_pot`` and the step.

    The keyword-only defaults are the Rb87 values. They are bound when this
    cell runs, so re-run it after changing them upstream. ``scan`` only passes
    ``(carry, x)`` positionally, so the defaults are what it uses.

    Returns:
        ``(new_state, energies)``. ``energies`` holds ``E_kin``, ``E_pot`` and
        ``E_tot`` of the INPUT state (before the step), all in µK.
    """
    E_kin = get_e_kin(state, m=mass, return_microK=True)

    # Potential energy <V>, converted from J to µK by dividing by kB * 1e-6.
    state = state.into(space="pos")
    E_pot = expectation_value(state, V(state)) / (kb * 1e-6)
    state = state.into(space="freq")

    # NOTE(review): V contains the full mean-field term N*g*|psi|^2. If
    # expectation_value computes <psi|V|psi>, this sum is the chemical
    # potential mu; the GP energy functional weights the interaction by 1/2.
    E_tot = E_kin + E_pot

    new_state = imaginary_time_evolution_single_species(
        state,
        dt=dt,
        mass=mass,
        V=V,
    )
    return new_state, {"E_kin": E_kin, "E_pot": E_pot, "E_tot": E_tot}

from jax.lax import scan

# Number of imaginary-time steps. With the step size of 1e-5 used by
# imaginary_time_step_single_species, this covers N_iter * 1e-5 of imaginary
# time in total.
N_iter = 100

# `scan` threads the state through N_iter steps (no per-step inputs, hence
# xs=None). It stacks the per-step energy dicts into arrays of length N_iter.
ground_state_rb87, energies = scan(
    imaginary_time_step_single_species,
    init_state_rb87.into(space="freq"),
    None,
    length=N_iter,
)


In [None]:
# Display the Rb87 state after the N_iter imaginary-time steps of the scan.
generate_panel_plot(ground_state_rb87)

In [None]:
import numpy as np
from bokeh.plotting import figure, show
# Plot how each energy contribution evolves during the imaginary-time evolution.
plt = figure(
    width=800, height=400, min_border=50,
    y_axis_type="log",
    title="Energy values during imaginary time evolution",
    x_axis_label="Iteration step",
    y_axis_label="Energy in µK"
)
x_num_iter = np.arange(N_iter)

# numerical solution: one line per energy contribution.
# `energies` comes out of `jax.lax.scan` as JAX arrays. Convert them explicitly
# to NumPy so Bokeh receives plain ndarrays it can serialize.
# NOTE(review): if the energies come back complex-valued, the log axis will
# fail; take the real part upstream if needed.
for key, color, label in (
    ("E_kin", "red", "Kinetic Energy"),
    ("E_pot", "green", "Potential Energy"),
    ("E_tot", "blue", "Total Energy"),
):
    plt.line(
        x_num_iter, np.asarray(energies[key]),
        line_width=1.5, color=color, legend_label=label
    )

# show the plot
show(plt)
