# Contents
- [Imports](#Imports)
- [Parameters](#Parameters)
- [Scans](#Scans)
- [Plot_alpha](#Plot_alpha)
- [Plot_1Ds](#Plot_1Ds)
- [Debugging](#Debugging)
- [Convergence](#Convergence)

# Imports

In [None]:
%matplotlib notebook
import matplotlib.pyplot as plt
import numpy as np
from arc import *
from time import time
from scipy.optimize import curve_fit
from scipy.interpolate import interp1d
from typing import List, Dict
import gc
import os

from basics import *
from utility import *
from floquet_hamiltonian import *


# Angular constants; tau (one full turn) is used throughout to convert Hz <-> radians/s.
pi, tau = np.pi, 2 * np.pi

# Parameters
## Computational Parameters

In [None]:
# Reset the cached basis objects whenever the computational parameters are changed.
# Downstream cells rebuild whichever of these is None.
basis = levels = basis_n = None

# Results of the most recent scan are cleared as well.
energies = eigenstates = None

# Computational Parameters
scheme = ("red", "mid", "low")[2]

# Target level: third argument is 1/2 for the "low" scheme and 3/2 otherwise
# (arguments are presumably n, l, j - see how state.n/.l/.j are compared below).
t_level = RydStateFS(52, 1, 1/2 if scheme == "low" else 3/2)

max_det = tau*40e9  # radians/s
dl = 2
n_max = 2
single_side = False  # True currently breaks code. Issue in utility.level_starts() function.

# Bundle of the computational parameters handed to the save/polarizability helpers.
comp = dict(max_det=max_det, dl=dl, n_max=n_max)

## Physical Parameters

In [None]:
# Physical Parameters
Edc = 0  # V/m

# Per-scheme AC field amplitude (V/m), ellipticity and drive frequency (radians/s).
if scheme == "red":
    zerox = {1: 78.6, 2: 82.6, 3: 85.6, 4: 89.6}
    Eac = 89.6 # 19.5  # V/m
    ellipticity = 0.0120
    field_omega = tau*4.780e9  #radians/s
elif scheme == "mid":
    Eac = 40.0  # zerox here
    ellipticity = 0
    field_omega = tau*5.095e9  #radians/s # Mid Detuning
elif scheme == "low":
    Eac = 30  # zerox here
    ellipticity = 0
    # resonance at 6720.5MHz
    field_omega = tau*6.62e9  #radians/s # 52P1/2
else:
    # Fail here with a clear message instead of a NameError on Eac/field_omega below.
    raise ValueError(f"unknown scheme {scheme!r}; expected 'red', 'mid' or 'low'")
energy_bands = ellipticity > 1e-3  # TODO, choose a nice threshold for this

theta = 0.0*pi/2  # angle between quantization axis and electric field (radians)

# Exactly one of these is normally set below to select the scanned parameter.
Eac_list = None
Edc_list = None
ellip_list = None
theta_list = None

fields = {
    "Eac": Eac,
    "ellipticity": ellipticity,
    "field_omega": field_omega,
    "Edc": Edc,
    "theta": theta
}

#Edc_list = np.linspace(0,15,100)  # V/m
Eac_list = np.linspace(0,100,300)  # V/m
#ellip_list = np.linspace(0,np.sqrt(0.07),100)**2
#theta_list = np.linspace(0,pi,100)

# Pick the scanned parameter; the first list that is set wins (Edc, Eac, ellipticity, theta).
if Edc_list is not None:
    varied = ("Edc", Edc_list)
elif Eac_list is not None:
    varied = ("Eac", Eac_list)
elif ellip_list is not None:
    varied = ("ellipticity", ellip_list)
elif theta_list is not None:
    varied = ("theta", theta_list)
else:
    # Without this branch a re-run of the cell would silently keep the `varied`
    # left over from a previous run (or leave it undefined on a fresh kernel).
    raise ValueError("no scan list set: define one of Edc_list, Eac_list, ellip_list, theta_list")

dataset = 60
if dataset == 60:
    # measured resonance between nP3/2 -> nD3/2 levels
    p3o2_d3o2_res = 4839*1e6*tau
    # for 685nm at +80MHz wrt 4-6' transition. Generator at 191.25MHz
elif dataset == 80:
    p3o2_d3o2_res = 4840*1e6*tau
else:
    # Otherwise p3o2_d3o2_res would be undefined (or stale) in the line below.
    raise ValueError(f"unknown dataset {dataset!r}; expected 60 or 80")
# effective laser induced AC stark shift
laser_dAC = detuning(RydStateFS(51,2,3/2), t_level, p3o2_d3o2_res)

# Build Basis Lists (Define Hilbert Space)

In [None]:
# Clear the results of any earlier scan.
energies = eigenstates = None

# (Re)build the atomic basis only when the parameter cell has reset it to None.
if basis is None or levels is None:
    levels, basis = build_basis(t_level, max_det, single_side=single_side, dl=dl)
    # Indices of every basis state sharing n, l and j with the target level.
    t_inds = [
        ind
        for ind, state in enumerate(basis)
        if [state.n, state.l, state.j] == [t_level.n, t_level.l, t_level.j]
    ]

# Basis extended by n_max (built by build_n_basis from the atomic basis).
if basis_n is None:
    basis_n = build_n_basis(basis, n_max)

print(len(basis), len(basis_n))
basis_print(levels)

# Scans
## Scan Parameter specified in physical parameters

In [None]:
# Clear the results of any earlier scan before starting a new one.
energies = None
eigenstates = None

# Rebuild the basis if the parameter cell reset it. Uses the `single_side`
# parameter (as the basis-building cell does) rather than a hard-coded False.
if basis is None or levels is None:
    levels, basis = build_basis(t_level, max_det, single_side=single_side, dl=dl)
    t_inds = [ind for ind, state in enumerate(basis) if [state.n, state.l, state.j] == [t_level.n, t_level.l, t_level.j]]
if basis_n is None:
    basis_n = build_n_basis(basis, n_max)


# Copy of the field settings with the scanned parameter replaced by its array of values.
# Only used by the (currently disabled) load/save code below.
fields_p = dict(fields)
fields_p[varied[0]] = varied[1]

#old_scans = eigen_find(t_level,comp, fields_p)
#loading = False
# TODO : this is broken rn. Implement eigen_terpolate then fix this.
#if loading and len(old_scans) > 0: 
#    energies = np.load(old_scans[0][0].format("energies"))
#    eigenstates = np.load(old_scans[0][0].format("eigenstates"))
#else:
H0 = build_H0(basis, t_level, offset_zeeman=1e6)
#H0[t_inds,t_inds] += laser_dAC
f_inds = [(ind, state.j) for ind, state in enumerate(basis) if state.l == 2]
f_jsplit = 3e6

# Run the scan over the parameter selected in `varied`; all other field
# values come from `fields`.
energies, eigenstates = floquet_loop(
    basis,
    H0,
    n_max=n_max,
    varied=varied,
    energy_bands=energy_bands,
    **fields
)

#f_name = f"convergenceTest-nmax={n_max}-dw={max_det}-dl={dl}-{{}}"
#filepath = "Shirley-Floquet_results\\convergence_tests"
#np.save(os.path.join(filepath,f_name.format("energies")), energies)
#np.save(os.path.join(filepath,f_name.format("eigenstates")), eigenstates)
#eigen_save(t_level, comp, fields_p, energies, eigenstates)

# Scan $E_{AC}$ or $\theta$ and check 2nd + 4th order DC polarizability 

In [None]:
def check_polarizability(
    basis: List[RydStateFS], 
    H0: np.ndarray,
    dc_end: float, 
    samples: int,
    comp: Dict[str, int],
    fields: Dict[str, float],
    energy_bands: bool=False
) -> "tuple[np.ndarray, np.ndarray, np.ndarray]":
    """
    Determines the polarizability of all states in basis for AC/DC field parameters
    provided in fields dict.

    The DC field is scanned from 0 to dc_end with floquet_loop and, for every row of
    the returned energies, the energy-vs-field curve is fitted with
        func = e0 - 1/2*alpha*(v-v0)**2 + beta*(v-v0)**4
    The scan is also written to disk through eigen_save (which uses the
    notebook-level ``t_level``).

    Args:
        basis: list of atomic states considered in computation
        H0: Hamiltonian of unperturbed atom system (expected to be diagonal). Matrix
            elements should be reported in radial frequency (radians/s)
        dc_end: maximum electric field strength to be sampled. dc electric field values
            are sampled from 0 to dc_end
        samples: number of dc electric field strength values to take. System is prone to crashing when
            too few samples are provided. Stable for when 100 samples per 20 volt span
        comp: dict of computational parameters used for this computation. Keys:
            "max_det" : float, maximum energy difference between t_level and any level included
                in the basis used in the computation
            "dl" : int, maximum difference between t_level.l and the orbital angular momentum
                quantum number of other states in the basis, |l-lp| <= dl
            "n_max" : int, maximum number of fourier components of the AC field to include in the
                computation
        fields: dict of AC and DC field values. It is not modified. Currently the following
            parameters are expected to be represented in fields:
                "Eac" : Electric field strength of the AC field (V/m)
                "ellipticity" : ellipticity of the AC field. Field polarization defined as
                    e_ac = sqrt(1-ellipticity)e_pi + sqrt(ellipticity)e_+
                "field_omega" : oscillation frequency of the AC. Reported in radial frequency (radians/s)
                "Edc" : Ignored if included in fields dict
                "theta" : angle between quantization axis and electric field direction (radians)
        energy_bands: forwarded unchanged to floquet_loop.

    Returns:
        popts: fit parameters, popts.shape = (len(basis)*(2*n_max+1), 4) with columns
            v0, alpha, e0, beta of the fit function above. Rows whose fit did not
            converge are filled with NaN.
        perrs: one-sigma uncertainties of popts (square root of the covariance diagonal),
            same shape as popts.
        es0: ``eigenstates[..., 0]``, i.e. the eigenstates at the first sampled
            DC field value (Edc = 0).
    """
    # Use the value from `comp` for both the Floquet computation and the array
    # sizes, so the two can never disagree with a notebook-level n_max.
    n_floquet = comp["n_max"]
    dc_values = np.linspace(0, dc_end, samples)
    dcs = ("Edc", dc_values)
    energies, eigenstates = floquet_loop(
        basis, H0, n_max=n_floquet, varied=dcs, energy_bands=energy_bands, **fields
    )

    def quartic(v, v0, alpha, e0, beta):
        return e0 - 1/2*alpha*(v - v0)**2 + beta*(v - v0)**4

    # Switch between the quartic fit (True) and a crude max-shift estimate (False).
    fits = True

    dim = len(basis)*(2*n_floquet + 1)
    popts = np.zeros((dim, 4), dtype=float)
    perrs = np.zeros(popts.shape, dtype=float)

    for i in range(dim):
        band_energies = energies[i, :]
        if fits:
            # Starting point: no offset, curvature from the end-to-end shift,
            # e0 from the mid-scan energy, no quartic term.
            guess = [
                0.0,
                (band_energies[-1] - band_energies[0])/dc_end**2,
                band_energies[samples//2],
                0,
            ]
            try:
                popt, pcov = curve_fit(quartic, dc_values, band_energies, p0=guess)
            except RuntimeError:
                # curve_fit raises RuntimeError when the least-squares minimization fails.
                print("WARNING: Error fitting polarizability data")
                popts[i] = np.nan
                perrs[i] = np.nan
            else:
                popts[i] = popt
                perrs[i] = np.sqrt(np.diag(pcov))
        else:
            # NOTE(review): this branch stores its estimate in column 2 (the e0 slot),
            # while callers read alpha from column 1 - confirm before enabling it.
            shift = band_energies - band_energies[0]
            dmax_ind = np.argmax(np.abs(shift))
            dmax_abs = np.abs(shift[dmax_ind])
            print(dmax_abs, dmax_ind)
            popts[i, 2] = shift[dmax_ind]/dc_values[dmax_ind]**2
            perrs[i, 2] = np.nan

    # Record the scanned DC values on a copy so the caller's dict is left untouched.
    saved_fields = dict(fields)
    saved_fields["Edc"] = dc_values
    eigen_save(t_level, comp, saved_fields, energies, eigenstates)
    es0 = eigenstates[..., 0]
    return popts, perrs, es0

# Drop the results of any earlier scan so their memory can be reclaimed.
try:
    del energies
    del eigenstates
except NameError:
    pass

# build basis if not already built (same `single_side` setting as the basis-building cell)
if basis is None or levels is None:
    levels, basis = build_basis(t_level, max_det, single_side=single_side, dl=dl)
    t_inds = [ind for ind, state in enumerate(basis) if [state.n, state.l, state.j] == [t_level.n, t_level.l, t_level.j]]
if basis_n is None:
    basis_n = build_n_basis(basis, n_max)

fields_p = dict(fields)
fields_p[varied[0]] = varied[1]


H0 = build_H0(basis, t_level, offset_zeeman=1e6)
#H0[t_inds,t_inds] += laser_dAC
f_inds = [(ind, state.j) for ind, state in enumerate(basis) if state.l == 2]
f_jsplit = 3e6

samples = 100
Emax = 15
# set Eac to 2nd order zero-crossing when theta is scanned
if scheme == "red":
    zerox = {1: 78.6, 2: 82.6, 3: 85.6, 4: 89.6}
    Eac = zerox[4]
elif scheme == "mid":
    # `zerx` is produced by the zero-crossing cell further down; on a fresh kernel it
    # does not exist yet, so fall back to NaN instead of failing with a NameError.
    try:
        third_zero = zerx[2:].mean()
    except NameError:
        third_zero = np.nan
    zerox = [0, 27.20, 30.55, third_zero]
    Eac = zerox[1]
elif scheme == "low":
    zerox = 46.81052748704767
    Eac = zerox


# Eac and theta values to sample
Eacs = np.linspace(0,100,41)  # V/m
thetas = np.linspace(-pi/2,pi/2,21)  # radians

sweep_ac = False  # scan Eac Values
sweep_theta = True  # scan theta values


def _sweep_polarizability(key, values, base_fields):
    """Run check_polarizability once per value of one field parameter.

    Args:
        key: name of the entry of the fields dict that is swept ("Eac" or "theta").
        values: sequence of values assigned to that entry, one per sweep point.
        base_fields: field settings shared by all sweep points. It is copied for
            every point, so neither it nor the notebook-level `fields` is modified.

    Returns:
        (popts, perrs, es0s), each with one leading entry per sweep value.
        Points for which check_polarizability raised ValueError are filled with NaN.
    """
    n_rows = len(basis)*(2*n_max + 1)
    sweep_popts = np.zeros((len(values), n_rows, 4), dtype=float)
    sweep_perrs = np.zeros(sweep_popts.shape, dtype=float)
    # NOTE(review): real-valued storage - if the eigenstates are complex their
    # imaginary part is dropped on assignment; confirm this is intended.
    sweep_es0s = np.zeros((len(values), len(basis_n), len(basis_n)), dtype=float)

    for i, value in enumerate(values):
        point_fields = dict(base_fields)
        point_fields[key] = value
        try:
            sweep_popts[i], sweep_perrs[i], sweep_es0s[i] = check_polarizability(
                basis, H0, Emax, samples, comp, point_fields, energy_bands
            )
        except ValueError:
            # A failed point is marked with NaN and the sweep continues.
            print(f"WARNING: failed to check polarizability for field values :\n{point_fields}")
            sweep_popts[i] = np.nan
            sweep_perrs[i] = np.nan
            sweep_es0s[i] = np.nan
    return sweep_popts, sweep_perrs, sweep_es0s


if sweep_ac:
    popts, perrs, es0s = _sweep_polarizability("Eac", Eacs, {**fields, "theta": theta})
    alphas = popts[...,1]
    dalphas = perrs[...,1]
    popts_eac = popts
    perrs_eac = perrs
    es0s_eac = es0s

if sweep_theta:
    popts, perrs, es0s = _sweep_polarizability("theta", thetas, {**fields, "Eac": Eac})
    alphas = popts[...,1]
    dalphas = perrs[...,1]
    popts_theta2 = popts
    perrs_theta2 = perrs
    es0s_theta2 = es0s

In [None]:
raise RuntimeError  # we dont want these to run unless the data being saved is known to be good.
#fname = "mid_det"
fname = "low"
results_dir = r"Shirley-Floquet_results\PolarizabilityPlots"


def _result_path(suffix):
    """Return the output path for one array: results_dir joined with fname + suffix."""
    return os.path.join(results_dir, fname + suffix)


# Eac sweep: fit parameters, their errors, the sampled Eac values and the es0s array.
np.save(_result_path("popts_eac.npy"), popts_eac)
np.save(_result_path("perrs_eac.npy"), perrs_eac)
np.save(_result_path("eac.npy"), Eacs)
np.save(_result_path("e0s.npy"), es0s_eac)

# theta sweep: sampled angles, fit parameters, their errors and the es0s array.
np.save(_result_path("thetas.npy"), thetas)
np.save(_result_path("popts_theta.npy"), popts_theta)
np.save(_result_path("perrs_theta.npy"), perrs_theta)
np.save(_result_path("e0s_theta.npy"), es0s_theta)

#np.save(os.path.join(results_dir,fname+"popts_theta2.npy"),popts_theta2)
#np.save(os.path.join(results_dir,fname+"perrs_theta2.npy"),perrs_theta2)
#np.save(os.path.join(results_dir,fname+"e0s_theta2.npy"),es0s_theta2)

## Plot polarizability scan results
### Alpha

In [None]:
def _plot_target_fit_param(col, prefactor, ylabel):
    """Plot one fitted coefficient of the target states against the swept parameter.

    Draws one error-bar series per (index, state) pair of `p_inds`/`t_states`, using
    the Eac sweep arrays when `sweep_ac` is set and the theta sweep arrays otherwise.

    Args:
        col: column of the fit-parameter arrays to plot (1 is alpha and 3 is beta
            in the fit performed by check_polarizability).
        prefactor: factor applied to the data before dividing by tau
            (1e-3 for kHz, 1 for Hz).
        ylabel: label of the y axis.

    Returns:
        The (figure, axes) pair that was drawn and shown.
    """
    fig, ax = plt.subplots(1, 1, figsize=(6, 4))
    for band_c, (i, state) in enumerate(zip(p_inds, t_states), start=1):
        print(i, state.ket())
        k = n_max*len(basis) + i
        lab = f"band {band_c}" if energy_bands else basis_n[k].ket()
        if sweep_ac:
            ax.errorbar(Eacs, popts_eac[:, k, col]*prefactor/tau, label=lab,
                        yerr=perrs_eac[:, k, col]*prefactor/tau, fmt=".")
        elif sweep_theta:
            ax.errorbar(thetas/pi, popts_theta[:, k, col]*prefactor/tau, label=lab,
                        yerr=perrs_theta[:, k, col]*prefactor/tau, fmt=".")
    ax.legend(loc=1, prop={"size": 9})
    ax.axhline(0, ls=":")
    #ax.set_ylim(0,max(np.abs(alphas[0,:]))*1e-3/tau)
    #ax.set_yscale("log")
    # Backslashes are doubled so the TeX commands no longer rely on invalid
    # Python escape sequences; the resulting strings are unchanged.
    if sweep_ac:
        ax.set_xlabel("AC field strength (V/m)")
        ax.set_title(f"Polarizabilities with field params:\n$\\epsilon$ = {ellipticity}, $\\omega_d$ = $2\\pi \\cdot${field_omega*1e-6/tau:.1f}MHz, $\\theta$ = {theta/pi}$\\cdot\\pi$")
    elif sweep_theta:
        ax.set_xlabel("$\\theta/\\pi$")
        ax.set_title(f"Polarizabilities with field params:\n$E_{{ac}}$ = {Eac}$V/m$, $\\epsilon$ = {ellipticity}, $\\omega_d$ = $2\\pi \\cdot${field_omega*1e-6/tau:.1f}MHz")

    ax.set_ylabel(ylabel)
    #ax.set_ylabel("$\\alpha'/\\alpha$ (%)")
    for item in [ax.xaxis.label, ax.yaxis.label, ax.title]:
        item.set_fontsize(14)

    fig.tight_layout()
    fig.show()
    return fig, ax


# 2nd-order coefficient (alpha) in kHz/(V/m)^2; this figure is written to disk.
fig, ax = _plot_target_fit_param(1, 1e-3, "$\\alpha$'/$2\\pi$ $KHz/(V/m)^2$")
omega_for_humans = int(field_omega*1e-6/tau)
fig.savefig(f"polarizabilities_omegad-{omega_for_humans}MHz_eps_{ellipticity:.3f}-around{Eacs.mean():.0f}Eac.png")

# 4th-order coefficient (beta) in Hz/(V/m)^4; shown only, not saved.
fig, ax = _plot_target_fit_param(3, 1, "$\\beta$'/$2\\pi$ $Hz/(V/m)^4$")
omega_for_humans = int(field_omega*1e-6/tau)
#fig.savefig(f"polarizabilities_omegad-{omega_for_humans}MHz_eps_{ellipticity:.3f}-around{Eacs.mean():.0f}Eac.png")

In [None]:
# Plot the fitted 4th-order coefficient (column 3 of the fit parameters) for every
# basis state that compares equal to t_level, against the swept parameter.
fig, ax = plt.subplots(1,1, figsize=(6,4))
band_c = 0
x_ac = Eacs
for i, state in enumerate(basis):
    if state == t_level:  # check n,l,j quantum numbers
        band_c += 1
        k = n_max*len(basis) + i
        lab = f"band {band_c}" if energy_bands else basis_n[k].ket()
        if sweep_ac:
            ax.errorbar(x_ac, popts_eac[:,k,3]/tau, label=lab, yerr=perrs_eac[:,k,3]/tau, fmt=".")
        elif sweep_theta:
            ax.errorbar(thetas/pi, popts_theta[:,k,3]/tau, label=lab, yerr=perrs_theta[:,k,3]/tau, fmt=".")
ax.legend(loc=4,prop={"size":9})
ax.axhline(0,ls=":")
#ax.set_ylim(0,max(np.abs(alphas[0,:]))*1e-3/tau)
#ax.set_yscale("log")
# Backslashes are doubled so the TeX commands no longer rely on invalid Python
# escape sequences; the resulting strings are unchanged.
if sweep_ac:
    ax.set_xlabel("AC field strength (V/m)")
    ax.set_title(f"Polarizabilities with field params:\n$\\epsilon$ = {ellipticity}, $\\omega_d$ = $2\\pi \\cdot${field_omega*1e-6/tau:.1f}MHz, $\\theta$ = {theta/pi}$\\cdot\\pi$")
elif sweep_theta:
    ax.set_xlabel("$\\theta/\\pi$")
    ax.set_title(f"Polarizabilities with field params:\n$E_{{ac}}$ = {Eac}$V/m$, $\\epsilon$ = {ellipticity}, $\\omega_d$ = $2\\pi \\cdot${field_omega*1e-6/tau:.1f}MHz")
    
ax.set_ylabel("$\\beta$'/$2\\pi$ $Hz/(V/m)^4$")
#ax.set_ylabel("$\\alpha'/\\alpha$ (%)")
for item in [ax.xaxis.label, ax.yaxis.label, ax.title]:
    item.set_fontsize(14)
    
fig.tight_layout()
fig.show()
omega_for_humans = int(field_omega*1e-6/tau)
# NOTE(review): this is the same file name the alpha figure is saved under, so this
# beta figure overwrites it - confirm whether a distinct name is wanted.
fig.savefig(f"polarizabilities_omegad-{omega_for_humans}MHz_eps_{ellipticity:.3f}-around{Eacs.mean():.0f}Eac.png")

### Plot polarizabilities in mid case

In [None]:
# Four-panel figure: left column shows alpha (top) and beta (bottom) against Eac,
# right column shows the same two coefficients against theta for two theta sweeps.
# Requires `zerx` from the zero-crossing cell and a list-valued `zerox` (mid scheme).
thrs = 2e-2  # a series is drawn only if its weight exceeds this at some Eac
fig, axar = plt.subplots(2, len(p_inds), figsize=(4.5*len(p_inds), 3.5*2))
count = 0
zero_x = [27.20, 30.55, zerx[2:].mean()]
cv = ["black", "tab:blue", "tab:orange"]
# NOTE(review): p_inds is used here as a direct index into the fit arrays, whereas
# other cells add n_max*len(basis) to it - confirm which convention p_inds follows.
alpha0 = popts_eac[0, np.array(p_inds), 1].mean()*1e-3/tau


def _scatter_weighted(ax, k, col, prefactor, weights, marker, color):
    """Draw column `col` of the Eac sweep for row k, one point at a time.

    Each point after the first uses weights[it] as its opacity. The first point is
    drawn fully opaque and carries the legend label; when its own weight differs
    from 1 by more than 0.5 it is placed at x = -5 instead of at its Eac value.
    """
    ydat = popts_eac[:, k, col]*prefactor/tau
    yerr = perrs_eac[:, k, col]*prefactor/tau
    kt = basis_n[k]
    ls = {0: "S", 1: "P", 2: "D", 3: "F"}
    first_label = f"|{kt.n},{ls[kt.l]},{int(2*kt.j)}/2,{int(2*kt['mj'])}/2,m={kt['nphot']}>'"
    for it in range(len(Eacs)):
        lab = first_label if it == 0 else ""
        alph = weights[it] if (it > 0) else 1
        if abs(weights[it] - alph) > 0.5:
            xdat = [-5]
        else:
            xdat = Eacs[it:it+1]
        ax.errorbar(xdat, ydat[it:it+1], yerr=yerr[it:it+1], fmt=marker, ms=3, c=color, alpha=alph, label=lab)


for i, t in enumerate(p_inds):
    for k in range(len(basis_n))[::-1]:
        # |es0s_eac[:, t, k]|^2 - presumably the overlap between state t and row k
        # at each Eac; verify against what check_polarizability returns as es0.
        alphs = np.abs(es0s_eac[:, t, k])**2
        if any(alphs > thrs):
            color = f"C{count}"
            #color = f"C{i}"
            ax = axar[0,0]
            ax.set_ylabel("$\\alpha$'/$2\\pi$ $kHz/(V/m)^2$")
            _scatter_weighted(ax, k, 1, 1e-3, alphs, "os"[i], color)
            #ax.axhline(0,ls=":",c="black")
            #ax.vlines(zero_x,*ax.get_ylim(),linewidth=1,ls=":")

            ax = axar[1,0]
            ax.set_ylabel("$\\beta$'/$2\\pi$ $Hz/(V/m)^4$")
            _scatter_weighted(ax, k, 3, 1, alphs, "os"[i], color)
            #ax.vlines(zero_x,*ax.get_ylim(),linewidth=1,ls=":")
            #ax.set_ylim(*lm)
            #ax.axhline(0,ls=":",c="black")
            count += 1

# Right column: theta sweeps for the first target index. The marker reuses `i`,
# i.e. the last index of the loop over p_inds above.
t = p_inds[0]
x_data = thetas/pi*180

y_data = popts_theta[:,t,1]*1e-3/tau
y_err = perrs_theta[:,t,1]*1e-3/tau
axar[0,1].errorbar(x_data,y_data,yerr=y_err,ms=3, fmt="os"[i],label=f"$E_{{AC}}={zerox[2]:.1f}V/m$",c=cv[1])
y_data = popts_theta[:,t,3]/tau
y_err = perrs_theta[:,t,3]/tau
axar[1,1].errorbar(x_data,y_data,yerr=y_err,ms=3, fmt="os"[i],label=f"$E_{{AC}}={zerox[2]:.1f}V/m$",c=cv[1])

y_data = popts_theta2[:,t,1]*1e-3/tau
y_err = perrs_theta2[:,t,1]*1e-3/tau
axar[0,1].errorbar(x_data,y_data,yerr=y_err,ms=3, fmt="os"[i],label=f"$E_{{AC}}={zerox[3]:.1f}V/m$",c=cv[2])
y_data = popts_theta2[:,t,3]/tau
y_err = perrs_theta2[:,t,3]/tau
axar[1,1].errorbar(x_data,y_data,yerr=y_err,ms=3, fmt="os"[i],label=f"$E_{{AC}}={zerox[3]:.1f}V/m$",c=cv[2])
axar[1,1].set_ylabel("$\\beta$'/$2\\pi$ $Hz/(V/m)^4$")

for ax in axar[:,1]:
    ax.axhline(0,c="black",ls=":")
    ax.legend(prop={"size":8})

# Mark the zero crossings on the Eac panels without letting vlines change the y limits.
for ax in axar[:,0]:
    lm = ax.get_ylim()
    ax.vlines(zero_x,*ax.get_ylim(),linewidth=1,ls=":",colors=cv)
    ax.set_ylim(*lm)
    ax.axhline(0,ls=":",c="black")            
    
# NOTE(review): get_shared_x_axes().join() is deprecated in recent Matplotlib
# releases; left unchanged because Axes.sharex() also shares tick formatting,
# which would interact with the set_xticklabels([]) calls below.
axar[0,0].set_xlim(0,axar[0,0].get_xlim()[1])
axar[0,0].get_shared_x_axes().join(axar[0,0], axar[1,0])
axar[0,0].set_xticklabels([])
axar[1,0].set_xlim(axar[0,0].get_xlim())
axar[1,0].set_xlabel("$E_{AC}$ $(V/m)$")

#axar[0,1].set_xlim(0,axar[0,0].get_xlim()[1])
axar[0,1].get_shared_x_axes().join(axar[0,1], axar[1,1])
axar[0,1].set_xticklabels([])
axar[1,1].set_xlim(axar[0,1].get_xlim())
axar[1,1].set_xlabel("$\\theta$ (degrees)")

#axar[0,0].axhline(alpha0,ls=":")
#axar[0,1].axhline(alpha0*0.2,ls=":")

# Top row: relabel the alpha axes and add a right-hand axis in percent of alpha0.
for ax in axar[0,:]:
    # NOTE(review): the plotted data are scaled by 1e-3/tau, which matches the
    # kHz/(V/m)^2 label set above rather than MHz/(V/cm)^2 - confirm the intended unit.
    ax.set_ylabel("$\\alpha'$ ($MHz/(V/cm)^2$)")
    #ax.set_ylabel("$-\\alpha'$ $(KHz/(V/m)^2)$")
    ax.yaxis.set_ticks_position('left')
    ax.xaxis.set_ticks_position('bottom')

    #ax.axhline(-alpha0*1e-3/tau)
    ax2 = ax.twinx()
    #for k in basis_t:
    #    ax2.plot(x_ac,100*popts_eac[:,n_max*len(basis)+k,1]/alpha0)
    low,high = ax.get_ylim()
    #ax2.set_ylim(-100*low/(alpha0*1e-3/tau),-100*high/(alpha0*1e-3/tau))
    if ax == axar[0,1]:
        #ax.axhline(alpha0*1e-3/tau)
        ax2.set_yticks(np.arange(-50,60,10))
        ax2.set_ylim(100*low/(alpha0),100*high/(alpha0))
        ax2.set_ylabel("$\\alpha'/\\alpha_0$ (%)")
    else:
        ax2.set_yticks(np.arange(-200,150,25))
        ax2.set_ylim(100*low/(alpha0)+48,100*high/(alpha0))


    for item in [ax.xaxis.label, ax.yaxis.label, ax.title]:
        item.set_fontsize(12)    
    for item in [ax2.xaxis.label, ax2.yaxis.label, ax2.title]:
        item.set_fontsize(12)    


axar[0,0].set_ylim(-100,200)
axar[0,0].legend(prop={"size":8})
#axar[1,0].legend(prop={"size":8})
fig.tight_layout()
fig.show()
# Backslash doubled: "\M" is an invalid escape sequence; the path itself is unchanged.
fig.savefig("Plots\\MidDetuning.png")

### find and plot zero-crossings
#### mid case

In [None]:
# Quadratic interpolation of column 1 of `popts` (row p_inds[0]) over the sampled
# Eac values, evaluated on a dense grid.
fun = interp1d(Eacs, popts[:, p_inds[0], 1]*1e-3/tau, kind="quadratic")
xlin = np.linspace(min(Eacs), max(Eacs), 50000)
dense_vals = fun(xlin)

# Grid points at which the interpolated curve lies within 1e-2 of zero.
near_zero = np.argwhere(np.abs(dense_vals) < 1e-2)

fig, ax = plt.subplots(1, 1)
ax.plot(xlin, dense_vals)
ax.vlines(xlin[near_zero], *ax.get_ylim())
fig.show()

# Eac values of those grid points; the mean of all but the first two is used
# as the third zero crossing.
zerx = xlin[near_zero][:, 0]
zero_x = [27.20, 30.55, zerx[2:].mean()]

#### simple cases

In [None]:
# For every row belonging to the target level, plot alpha against Eac and locate
# the field at which alpha crosses zero by interpolating Eac as a function of alpha.
fig, ax = plt.subplots(1,1, figsize=(6,4))

basis_t = np.array([i+len(basis)*n_max for i, state in enumerate(basis) if [state.n, state.l, state.j] == [t_level.n, t_level.l, t_level.j]])
colors = ["blue","orange","green","red"]
for i,k in enumerate(basis_t):
    # The first two sweep points are skipped; NaN entries (failed fits) are dropped.
    x_dat = Eacs[2:]
    y_dat = 1e-3*popts_eac[2:,k,1]/tau
    y_err = 1e-3*perrs_eac[2:,k,1]/tau
    good = ~np.isnan(y_dat)
    y_dat = y_dat[good]
    y_err = y_err[good]
    x_dat = x_dat[good]
    
#    y_dat[1]=np.NaN
#    y_err[1]=np.NaN
    
    ax.errorbar(x_dat,y_dat,yerr=y_err,fmt=".",c=colors[i],label=f"band {i}")
    ax.axhline(0,ls=":")
    if len(y_dat) < 3:
        # A quadratic interpolant needs at least three points.
        print(f"band {i}: too few valid points to interpolate a zero-crossing")
        continue
    # Inverse interpolation: Eac as a function of alpha.
    cal_fun = interp1d(y_dat,x_dat,kind="quadratic")
    xlin=np.linspace(min(y_dat),max(y_dat),1000)
    ax.plot(cal_fun(xlin),xlin,colors[i])
    #zerox = np.interp(0,y_dat,x_dat)
    if not (y_dat.min() <= 0 <= y_dat.max()):
        # interp1d raises ValueError outside the sampled range, i.e. when alpha
        # never changes sign; report it and move on to the next band.
        print(f"band {i}: alpha does not cross zero in the sampled Eac range")
        continue
    print(f"band {i} zero-crossing at Eac = {cal_fun(0)} V/m")
    ax.axvline(cal_fun(0),ls=":",c=colors[i])
ax.legend()
fig.show()

In [None]:
# Two stacked panels of alpha for the target-level states with mj >= 0:
# top against Eac, bottom against theta, each with a right-hand axis in percent of alpha0.
# Requires a scalar `zerox` (low scheme) for the vertical marker.
fig, axar = plt.subplots(2,1, figsize=(4.5,6.5))
x_ac = Eacs
# basis_t already contains the n_max*len(basis) offset, so it indexes the fit arrays
# directly; adding the offset a second time would select different rows.
basis_t = np.array([i+len(basis)*n_max for i, state in enumerate(basis) if [state.n, state.l, state.j] == [t_level.n, t_level.l, t_level.j]])
alpha0 = popts_eac[0,basis_t,1].mean()
for i, state in enumerate(basis):
    if state["mj"] < 0:
        continue
    if state == t_level:  # check n,l,j quantum numbers
        k = n_max*len(basis) + i
        # Series are drawn without legend labels.
        axar[0].errorbar(x_ac, popts_eac[:,k,1]*1e-3/tau, label="", yerr=perrs_eac[:,k,1]*1e-3/tau, fmt=".")
        axar[1].errorbar(thetas/pi*180, popts_theta[:,k,1]*1e-3/tau, label="", yerr=perrs_theta[:,k,1]*1e-3/tau, fmt=".")
axar[0].legend(loc=1,prop={"size":8})
axar[1].legend(loc="upper center",prop={"size":8})
axar[0].axhline(0,ls=":")
axar[1].axhline(0,ls=":")
#ax.set_ylim(0,max(np.abs(alphas[0,:]))*1e-3/tau)
#ax.set_yscale("log")
# x_ac is Eacs, which is sampled in V/m.
axar[0].set_xlabel("$E_{AC}$ (V/m)")
axar[1].set_xlabel("$\\theta$ (Degrees)")
axar[1].set_xticks(np.arange(-90,114,45))

axar[0].set_yticks(np.arange(-50,150,50))
axar[1].set_yticks(np.arange(-0,80,20))
axar[0].axvline(zerox,ls=":",c="tab:blue")

axar[0].set_ylim(-50,150)
axar[1].set_ylim(-5,50)

# Percent ticks of the right-hand axis for the top and bottom panel respectively.
percent_ticks = (np.arange(-50,150,50), np.arange(0,50,10))
for ax, ticks in zip(axar, percent_ticks):
    # ax.set_ylabel("$\\alpha$'/$2\pi$ $KHz/(V/m)^2$")
    # NOTE(review): the plotted data are scaled by 1e-3/tau, i.e. kHz/(V/m)^2 if the
    # fit is in radians/s per (V/m)^2 - confirm that MHz/(V/cm)^2 is the intended label.
    ax.set_ylabel("$\\alpha'$ ($MHz/(V/cm)^2$)")
    #ax.set_ylabel("$-\\alpha'$ $(KHz/(V/m)^2)$")
    ax.yaxis.set_ticks_position('left')
    ax.xaxis.set_ticks_position('bottom')
    
    # Right-hand axis: the same range expressed relative to alpha0.
    ax2 = ax.twinx()
    low,high = ax.get_ylim()
    ax2.set_yticks(ticks)
    ax2.set_ylim(100*low/(alpha0*1e-3/tau),100*high/(alpha0*1e-3/tau))
    ax2.set_ylabel("$\\alpha'/\\alpha_0$ (%)")
    
    for item in [ax.xaxis.label, ax.yaxis.label, ax.title]:
        item.set_fontsize(12)    
    for item in [ax2.xaxis.label, ax2.yaxis.label, ax2.title]:
        item.set_fontsize(12)    
    
#axar[1].set_yticks(np.arange(0,75,25))
    
fig.tight_layout()
fig.show()
omega_for_humans = int(field_omega*1e-6/tau)
fig.savefig(f"Low-PolarizabilityAnisotropy_TwoAxes.png")
fig.savefig(f"polarizabilities_omegad-{omega_for_humans}MHz_eps_{ellipticity:.3f}-around{Eacs.mean():.0f}Eac.png")

In [None]:
# Two stacked panels of beta for the target-level states with mj >= 0:
# top against Eac, bottom against theta.
# Requires a scalar `zerox` (low scheme) for the vertical marker.
fig, axar = plt.subplots(2,1, figsize=(4.5,6.5))
x_ac = Eacs
# basis_t already contains the n_max*len(basis) offset, so it indexes the fit arrays directly.
basis_t = np.array([i+len(basis)*n_max for i, state in enumerate(basis) if [state.n, state.l, state.j] == [t_level.n, t_level.l, t_level.j]])
alpha0 = popts_eac[0,basis_t,1].mean()
# Hide the second Eac point in this figure only. The blanking is done on a copy so
# the stored fit results are not overwritten.
popts_eac_shown = popts_eac.copy()
popts_eac_shown[1,:] = np.nan
for i, state in enumerate(basis):
    if state["mj"] < 0:
        continue
    if state == t_level:  # check n,l,j quantum numbers
        k = n_max*len(basis) + i
        # Series are drawn without legend labels.
        axar[0].errorbar(x_ac, popts_eac_shown[:,k,3]/tau, label="", yerr=perrs_eac[:,k,3]/tau, fmt=".")
        axar[1].errorbar(thetas/pi*180, popts_theta[:,k,3]/tau, label="", yerr=perrs_theta[:,k,3]/tau, fmt=".")
axar[0].legend(loc=4,prop={"size":8})
axar[1].legend(loc="upper center",prop={"size":8})
axar[0].axhline(0,ls=":")
axar[1].axhline(0,ls=":")
#ax.set_ylim(0,max(np.abs(alphas[0,:]))*1e-3/tau)
#ax.set_yscale("log")
axar[0].set_xlabel("$E_{AC}$ (V/m)")
axar[1].set_xlabel("$\\theta$ (Degrees)")
axar[1].set_xticks(np.arange(-90,114,45))
#axar[0].set_ylim(-60,210)
#axar[1].set_ylim(-30,100)
axar[0].axvline(zerox,ls=":",c="tab:blue")


for ax in axar:
    # ax.set_ylabel("$\\alpha$'/$2\pi$ $KHz/(V/m)^2$")
    ax.set_ylabel("$\\beta'$ ($Hz/(V/m)^4$)")
    ax.yaxis.set_ticks_position('left')
    ax.xaxis.set_ticks_position('bottom')

    # No twin (percent) axis is created in this figure, so only the primary
    # axis labels are resized.
    for item in [ax.xaxis.label, ax.yaxis.label, ax.title]:
        item.set_fontsize(12)    
    
#axar[0].set_yticks(np.linspace(-50,200,6))
#axar[1].set_yticks(np.linspace(-25,100,6))
#axar[1].set_yticks(np.arange(0,75,25))
    
fig.tight_layout()
fig.show()
omega_for_humans = int(field_omega*1e-6/tau)
fig.savefig(f"Plots\\Betas_{omega_for_humans}MHz.png")
#fig.savefig(f"AvoidedX-PolarizabilityAnisotropy_TwoAxes.png")
fig.savefig(f"polarizabilities_omegad-{omega_for_humans}MHz_eps_{ellipticity:.3f}-around{Eacs.mean():.0f}Eac.png")

In [None]:
# Directory holding the saved polarizability arrays (Windows-style relative path).
# The np.save calls below are disabled; the cell that follows loads these files.
results_dir = r"Shirley-Floquet_results\PolarizabilityPlots"
#np.save(os.path.join(results_dir,"popts_eac.npy"),popts_eac)
#np.save(os.path.join(results_dir,"perrs_eac.npy"),perrs_eac)
#np.save(os.path.join(results_dir,"popts_theta.npy"),popts_theta)
#np.save(os.path.join(results_dir,"perrs_theta.npy"),perrs_theta)

In [None]:
# Sweep grids assumed for the stored arrays: 21 points each.
# NOTE(review): the grids are hard-coded rather than loaded - confirm they match the files.
Eacs = np.linspace(0, 100, 21)
thetas = np.linspace(-pi/2, pi/2, 21)


def _load_result(file_name):
    """Load one saved array from results_dir."""
    return np.load(os.path.join(results_dir, file_name))


# Fit parameters and their errors for the Eac sweep and the theta sweep.
popts_eac = _load_result("popts_eac.npy")
perrs_eac = _load_result("perrs_eac.npy")
popts_theta = _load_result("popts_theta.npy")
perrs_theta = _load_result("perrs_theta.npy")

In [None]:
# Plot beta (column 3 of the fit parameters) for every basis state that compares
# equal to t_level, against Eac or theta depending on the active sweep.
fig, ax = plt.subplots(1,1, figsize=(4.5,4))
band_c = 0
for i, state in enumerate(basis):
    if state == t_level:  # check n,l,j quantum numbers
        band_c += 1
        k = n_max*len(basis) + i
        lab = f"band {band_c}" if energy_bands else basis_n[k].ket()
        if sweep_ac:
            ax.errorbar(Eacs, popts_eac[:,k,3]/tau, label=lab, yerr=perrs_eac[:,k,3]/tau, fmt=".")
        elif sweep_theta:
            # Error bars come from the theta sweep, matching the plotted values.
            ax.errorbar(180*thetas/pi, popts_theta[:,k,3]/tau, label=lab, yerr=perrs_theta[:,k,3]/tau, fmt=".")
ax.legend(loc=3,prop={"size":9})
ax.axhline(0,ls=":")
#ax.set_ylim(0,max(np.abs(alphas[0,:]))*1e-3/tau)
#ax.set_yscale("log")
if sweep_ac:
    ax.set_xlabel("AC field strength (V/m)")
elif sweep_theta:
    ax.set_xlabel("$\\theta$ (Degrees)")
    
ax.yaxis.set_ticks_position('left')
ax.xaxis.set_ticks_position('bottom')
ax.set_ylabel("$\\beta$ $(Hz/(V/m)^4)$")
for item in [ax.xaxis.label, ax.yaxis.label, ax.title]:
    item.set_fontsize(12)
    
fig.tight_layout()
fig.show()
omega_for_humans = int(field_omega*1e-6/tau)
fig.savefig(r"Plots/beta_theta")

In [None]:
# Log-scale plot built from the generic `popts` / `perrs` arrays (defined
# elsewhere in the notebook) for every basis state equal to t_level.
# NOTE(review): the AC branch plots the fit *errors* (perrs column 3), not the
# fitted values, while the theta branch plots values with error bars -- confirm
# this asymmetry is intended.
fig, ax = plt.subplots(1,1, figsize=(6,4))
band_c = 0
for i, state in enumerate(basis):
    if state == t_level:  # check n,l,j quantum numbers
        band_c +=1
        lab = t_level.ket()+f"band {band_c}" if energy_bands else basis_n[n_max*len(basis)+i].ket()
        if sweep_ac:
            ax.plot(Eacs, perrs[:,n_max*len(basis)+i,3], label=lab)
        elif sweep_theta:
            ax.errorbar(thetas/pi, 100*popts[:,n_max*len(basis)+i,3]*1e-3/tau, label=lab,yerr=perrs[:,n_max*len(basis)+i,3]*1e-3/tau,fmt=".")
ax.legend(loc=1,prop={"size":9})
ax.axhline(0,ls=":")
#ax.set_ylim(0,max(np.abs(alphas[0,:]))*1e-3/tau)
#ax.set_yscale("log")
# LaTeX backslashes are doubled so the literals contain no invalid escape
# sequences; the resulting strings are unchanged.
if sweep_ac:
    ax.set_xlabel("AC field strength (V/m)")
    ax.set_title(f"Polarizabilities with field params:\n$\\epsilon$ = {ellipticity}, $\\omega_d$ = $2\\pi \\cdot${field_omega*1e-6/tau:.1f}MHz, $\\theta$ = {theta/pi}$\\cdot\\pi$")
elif sweep_theta:
    ax.set_xlabel("$\\theta/\\pi$")
    ax.set_title(f"Polarizabilities with field params:\n$E_{{ac}}$ = {Eac}$V/m$, $\\epsilon$ = {ellipticity}, $\\omega_d$ = $2\\pi \\cdot${field_omega*1e-6/tau:.1f}MHz")

# Only one y label is set: the earlier alpha label was overwritten immediately
# by this one and never reached the figure.
ax.set_ylabel("$\\beta (KHz/V^4)$")
ax.set_yscale("log")
for item in [ax.xaxis.label, ax.yaxis.label, ax.title]:
    item.set_fontsize(14)

fig.tight_layout()
fig.show()
omega_for_humans = int(field_omega*1e-6/tau)
fig.savefig(f"polarizabilities_omegad-{omega_for_humans}MHz_eps_{ellipticity:.3f}-around{Eacs.mean():.0f}Eac.png")

# Plot_1Ds
## Plot all relevant energies and one overlap

In [None]:
# Two stacked panels. For each t in t_inds (skipping entries whose shifted
# state k = t + n_max*len(basis) has negative mj), every index j for which
# |eigenstates[k, j, :]|**2 exceeds 1e-2 somewhere contributes one curve per
# panel: energies[j, :] scaled by 1e-6/tau (top) and the population (bottom).
fig, axar = plt.subplots(2, 1)
for t in t_inds:
    k = t + n_max*len(basis)
    if basis_n[k]["mj"]<0:
        continue
    # This label is overwritten on the first pass of the inner loop.
    label = basis_n[t].ket()

    for j in range(len(basis_n)):
        label = basis_n[j].ket()
        if not any(np.abs(eigenstates[k,j,:])**2 > 1e-2):
            continue
        axar[0].plot(varied[1],1e-6*energies[j,:]/tau,label=label)
        axar[1].plot(varied[1],np.abs(eigenstates[k,j,:])**2,label=label)

#    ax.set_yscale("log")
for ax in axar:
    ax.legend()
fig.show()

In [None]:
# Three stacked panels versus the swept values in varied[1]:
#   panel 0: energies (x 1e-6/tau) of dressed states overlapping the target level,
#   panel 1: population on component ii-1, panel 2: population on component ii,
# where ii is the LAST index collected in iis (see note below).
thrs = 1e-2# threshold for caring about population
thrs2 = 1e-2
fig,axar = plt.subplots(3,1,figsize=(4.5,7),gridspec_kw={'height_ratios': [2.5, 1.6, 1.5]})
iis = []
interesting_inds = []
e_inds = []
# Indices, offset by len(basis)*n_max, of all basis states sharing n, l, j with t_level.
for i, state in enumerate(basis):
#    if [state.n,state.l,state.j,state["mj"]] == [t_level.n, t_level.l, t_level.j,3/2]:
    if [state.n,state.l,state.j] == [t_level.n, t_level.l, t_level.j]:
        iis.append(i+len(basis)*n_max)
# A dressed state k is "interesting" if its population on any of those
# components exceeds thrs at some sweep point; e_inds uses thrs2 and selects
# which energies are drawn in panel 0.
for ii in iis:
    for k, p in enumerate(eigenstates[:,ii,:]):
        if max(np.absolute(p)**2)>thrs:
            interesting_inds.append(k)
        if max(np.absolute(p)**2)>thrs2:
            e_inds.append(k)
interesting_inds = list(set(interesting_inds))
# NOTE: from here on `ii` keeps the last value of the loop above, so the
# overlap panels use components ii-1 and ii of the target level.
# NOTE(review): presumably those are the mj = 1/2 and 3/2 components, as the
# y labels suggest -- confirm against the basis ordering.
ls = {0:"S",1:"P",2:"D",3:"F"}
for k in interesting_inds:
    kt = basis_n[k]
    # Check for a missing mj BEFORE comparing or formatting it: `None < 0` and
    # `int(2*None)` both raise, which made the unlabeled branches unreachable.
    if kt["mj"] is not None and kt["mj"] < 0:
        continue
    if theta_list is None:
        if kt["mj"] is not None:
            lab = f"|{kt.n},{ls[kt.l]},{int(2*kt.j)}/2,{int(2*kt['mj'])}/2,m={kt['nphot']}>'"
            if k in e_inds:
                axar[0].plot(varied[1], 1e-6*np.real(energies[k,:]-0*energies[k,0])/tau, label=lab)
            y_st = np.absolute(eigenstates[k,ii-1,:])**2
            if any(y_st > thrs):
                axar[1].plot(varied[1], y_st, label=lab)
            y_st = np.absolute(eigenstates[k,ii,:])**2
            if any(y_st > thrs):
                axar[2].plot(varied[1], y_st, label=lab)
        else:
            axar[0].plot(varied[1], 1e-6*np.real(energies[k,:]-0*energies[k,0])/tau) #, label=basis_n[k].ket())
            axar[1].plot(varied[1], np.absolute(eigenstates[k,ii-1,:])**2) # , label=basis_n[k].ket())

    else:
        if kt["mj"] is not None:
            axar[0].plot(varied[1]/pi, 1e-6*np.real(energies[k,:]-0*energies[k,0])/tau, label=basis_n[k].ket())
            axar[1].plot(varied[1]/pi, np.absolute(eigenstates[k,ii,:])**2, label=basis_n[k].ket())
        else:
            axar[0].plot(varied[1]/pi, 1e-6*np.real(energies[k,:]-0*energies[k,0])/tau) #, label=basis_n[k].ket())
            axar[1].plot(varied[1]/pi, np.absolute(eigenstates[k,ii,:])**2)# , label=basis_n[k].ket())
    #print(f"state = {basis[k%len(basis)].ket()}, dE = 2pix{1e-6*np.real(Energies[k,-1]-Energies[k,0])/tau}MHz")
#axar[0].set_title("Change in Stark Shifts\n$E_{ac}$ = "+f"{Eac}V/m\n" + "$E_{dc}$ = " + f"{Edc}V/m")
#axar[0].set_ylim(-1.05e3,300)
#axar[1].set_title(f"Overlap with {basis_n[ii-1].ket()}")
#axar[1].set_yscale("log")
axar[1].set_ylabel("$P_{tl;1/2}$")
axar[2].set_ylabel("$P_{tl;3/2}$")
#axar[1].set_ylim(thrs*0.8,1)
# Axis labels depend on which sweep list is set. LaTeX backslashes are doubled
# so the literals contain no invalid escape sequences (string values unchanged).
if Edc_list is not None:
    axar[0].set_ylabel("DC Stark Shift/$2\\pi$ (MHz)")
    axar[0].set_xlabel("DC Electric Field strength (V/m)")
    axar[1].set_xlabel("DC Electric Field strength (V/m)")
elif Eac_list is not None:
    #axar[0].set_title(f"Energy Levels\n$\omega_D = 2\pi\cdot{1e-6*field_omega/tau:.0f}$; $\epsilon$ = {ellipticity}")#"\n $E_{{dc}}$ = {Edc}V/m")
    axar[0].set_ylabel("$\\nu-\\nu_t$ (MHz)")
    #axar[0].set_xlabel("$E_{AC}$ (V/m)")
    axar[2].set_xlabel("$E_{AC}$ (V/m)")
elif theta_list is not None:
    #axar[0].set_ylabel("Change in DC Stark Shift/$2\pi$ (MHz)")
    axar[0].set_xlabel("$\\theta/\\pi$ (radians)")
    axar[1].set_xlabel("Angle Between AC and DC fields/$\\pi$/ (radians)")
axar[0].legend(loc=3,prop={"size":8})
axar[1].legend(loc=1,prop={"size":8})
axar[2].legend(loc=1,prop={"size":8})

# Link the three x axes and show tick labels on the bottom panel only.
# NOTE(review): get_shared_x_axes().join() is deprecated in recent Matplotlib
# releases -- confirm the pinned version before upgrading.
axar[0].get_shared_x_axes().join(axar[0], axar[1])
axar[0].get_shared_x_axes().join(axar[1], axar[2])
axar[0].set_xticklabels([])
axar[1].set_xticklabels([])
axar[2].set_xlim(axar[0].get_xlim())
for ax in axar:
    ax.yaxis.set_ticks_position('left')
    ax.xaxis.set_ticks_position('bottom')

    for item in [ax.xaxis.label, ax.yaxis.label, ax.title]:
        item.set_fontsize(14)
#axar[1].legend()
fig.tight_layout()
fig.savefig("Plots/AvoidedXEac.png")
fig.show()

In [None]:
# Variant of the three-panel avoided-crossing plot in which the energy curves
# of panel 0 are drawn in short segments whose opacity follows the summed
# population on the target-level components.
thrs = 0.2# threshold for caring about population
thrs2 = 0.5
fig,axar = plt.subplots(3,1,figsize=(4.5,7),gridspec_kw={'height_ratios': [2.5, 1.6, 1.5]})
iis = []
interesting_inds = []
e_inds = []
# Indices, offset by len(basis)*n_max, of all basis states sharing n, l, j with t_level.
for i, state in enumerate(basis):
#    if [state.n,state.l,state.j,state["mj"]] == [t_level.n, t_level.l, t_level.j,3/2]:
    if [state.n,state.l,state.j] == [t_level.n, t_level.l, t_level.j]:
        iis.append(i+len(basis)*n_max)
# thrs selects states for the population panels, thrs2 those whose energy is drawn.
for ii in iis:
    for k, p in enumerate(eigenstates[:,ii,:]):
        if max(np.absolute(p)**2)>thrs:
            interesting_inds.append(k)
        if max(np.absolute(p)**2)>thrs2:
            e_inds.append(k)
interesting_inds = list(set(interesting_inds))
count=0  # running index into the default colour cycle ("C0", "C1", ...)

ls = {0:"S",1:"P",2:"D",3:"F"}
for k in interesting_inds:
    kt = basis_n[k]
    # Check for a missing mj BEFORE comparing or formatting it: `None < 0` and
    # `int(2*None)` both raise, which made the unlabeled branches unreachable.
    if kt["mj"] is not None and kt["mj"] < 0:
        continue
    if theta_list is None:
        if kt["mj"] is not None:
            lab = f"|{kt.n},{ls[kt.l]},{int(2*kt.j)}/2,{int(2*kt['mj'])}/2,m={kt['nphot']}>'"
            if k in e_inds:
                color = f"C{count}"
                count+=1
                # Total population of state k on all target-level components.
                # This loop also leaves `ii` at the last entry of iis, which the
                # two population panels below rely on.
                y_st = np.zeros(energies.shape[1],dtype=float)
                for ii in iis:
                    y_st += np.abs(eigenstates[k,ii,:])**2
                # NOTE(review): segments start at every second sample and span
                # two samples, so the interval between an odd sample and the next
                # even one is not drawn -- confirm this gap is intended.
                for i in range(len(energies[k,:]))[::2]:
                    if i == 0:
                        # First segment is fully opaque and carries the legend entry.
                        axar[0].plot(varied[1][i:i+2], 1e-6*np.real(energies[k,i:i+2]-0*energies[k,0])/tau, label=lab, alpha=y_st[i]*0+1,c=color)
                    else:
                        axar[0].plot(varied[1][i:i+2], 1e-6*np.real(energies[k,i:i+2]-0*energies[k,0])/tau, alpha=y_st[i],c=color)
            y_st = np.absolute(eigenstates[k,ii-1,:])**2
            if any(y_st > thrs):
                axar[1].plot(varied[1], y_st, label=lab)
            y_st = np.absolute(eigenstates[k,ii,:])**2
            if any(y_st > thrs):
                axar[2].plot(varied[1], y_st, label=lab)
        else:
            axar[0].plot(varied[1], 1e-6*np.real(energies[k,:]-0*energies[k,0])/tau) #, label=basis_n[k].ket())
            axar[1].plot(varied[1], np.absolute(eigenstates[k,ii-1,:])**2) # , label=basis_n[k].ket())

    else:
        if kt["mj"] is not None:
            axar[0].plot(varied[1]/pi, 1e-6*np.real(energies[k,:]-0*energies[k,0])/tau, label=basis_n[k].ket())
            axar[1].plot(varied[1]/pi, np.absolute(eigenstates[k,ii,:])**2, label=basis_n[k].ket())
        else:
            axar[0].plot(varied[1]/pi, 1e-6*np.real(energies[k,:]-0*energies[k,0])/tau) #, label=basis_n[k].ket())
            axar[1].plot(varied[1]/pi, np.absolute(eigenstates[k,ii,:])**2)# , label=basis_n[k].ket())
    #print(f"state = {basis[k%len(basis)].ket()}, dE = 2pix{1e-6*np.real(Energies[k,-1]-Energies[k,0])/tau}MHz")
#axar[0].set_title("Change in Stark Shifts\n$E_{ac}$ = "+f"{Eac}V/m\n" + "$E_{dc}$ = " + f"{Edc}V/m")
#axar[0].set_ylim(-1.05e3,300)
#axar[1].set_title(f"Overlap with {basis_n[ii-1].ket()}")
#axar[1].set_yscale("log")
axar[1].set_ylabel("$P_{tl;1/2}$")
axar[2].set_ylabel("$P_{tl;3/2}$")
#axar[1].set_ylim(thrs*0.8,1)
# Axis labels depend on which sweep list is set. LaTeX backslashes are doubled
# so the literals contain no invalid escape sequences (string values unchanged).
if Edc_list is not None:
    axar[0].set_ylabel("DC Stark Shift/$2\\pi$ (MHz)")
    axar[0].set_xlabel("DC Electric Field strength (V/m)")
    axar[1].set_xlabel("DC Electric Field strength (V/m)")
elif Eac_list is not None:
    #axar[0].set_title(f"Energy Levels\n$\omega_D = 2\pi\cdot{1e-6*field_omega/tau:.0f}$; $\epsilon$ = {ellipticity}")#"\n $E_{{dc}}$ = {Edc}V/m")
    axar[0].set_ylabel("$\\nu-\\nu_t$ (MHz)")
    #axar[0].set_xlabel("$E_{AC}$ (V/m)")
    axar[2].set_xlabel("$E_{AC}$ (V/m)")
elif theta_list is not None:
    #axar[0].set_ylabel("Change in DC Stark Shift/$2\pi$ (MHz)")
    axar[0].set_xlabel("$\\theta/\\pi$ (radians)")
    axar[1].set_xlabel("Angle Between AC and DC fields/$\\pi$/ (radians)")
axar[0].legend(loc=3,prop={"size":8})
axar[1].legend(loc=1,prop={"size":8})
axar[2].legend(loc=1,prop={"size":8})

# Link the three x axes and show tick labels on the bottom panel only.
# NOTE(review): get_shared_x_axes().join() is deprecated in recent Matplotlib
# releases -- confirm the pinned version before upgrading.
axar[0].get_shared_x_axes().join(axar[0], axar[1])
axar[0].get_shared_x_axes().join(axar[1], axar[2])
axar[0].set_xticklabels([])
axar[1].set_xticklabels([])
axar[2].set_xlim(axar[0].get_xlim())
for ax in axar:
    ax.yaxis.set_ticks_position('left')
    ax.xaxis.set_ticks_position('bottom')

    for item in [ax.xaxis.label, ax.yaxis.label, ax.title]:
        item.set_fontsize(14)
#axar[1].legend()
fig.tight_layout()
fig.savefig("Plots/AvoidedXEac.png")
fig.show()

In [None]:
# Side-by-side plot: energies (x 1e-6/tau) of the dressed states overlapping
# the target level (left) and their population on one target component (right,
# log scale), against the swept values in varied[1].
thrs = 0.11 # threshold for caring about population
fig,axar = plt.subplots(1,2,figsize=(10,5))
iis = []
interesting_inds = []
# Indices, offset by len(basis)*n_max, of all basis states sharing n, l, j with t_level.
for i, state in enumerate(basis):
#    if [state.n,state.l,state.j,state["mj"]] == [t_level.n, t_level.l, t_level.j,3/2]:
    if [state.n,state.l,state.j] == [t_level.n, t_level.l, t_level.j]:
        iis.append(i+len(basis)*n_max)
for ii in iis:
    for k, p in enumerate(eigenstates[:,ii,:]):
        if max(np.absolute(p)**2)>thrs:
            interesting_inds.append(k)
interesting_inds = list(set(interesting_inds))
# NOTE: `ii` keeps the last entry of iis from the loop above; the non-theta
# branch plots the overlap with component ii-1, the theta branch with ii.
band_c = 0  # only referenced by the commented-out band label below
for k in interesting_inds:
    band_c +=1
#    lab = t_level.ket()+f"band {band_c}" if energy_bands else basis_n[k].ket()
    lab = basis_n[k].ket()
    if theta_list is None:
        axar[0].plot(varied[1], 1e-6*np.real(energies[k,:]-0*energies[k,0])/tau, label=lab)
        axar[1].plot(varied[1], np.absolute(eigenstates[k,ii-1,:])**2, label=lab)
    else:
        axar[0].plot(varied[1]/pi, 1e-6*np.real(energies[k,:]-0*energies[k,0])/tau, label=basis_n[k].ket())
        axar[1].plot(varied[1]/pi, np.absolute(eigenstates[k,ii,:])**2, label=basis_n[k].ket())
    #print(f"state = {basis[k%len(basis)].ket()}, dE = 2pix{1e-6*np.real(Energies[k,-1]-Energies[k,0])/tau}MHz")
# Default title; the AC-sweep branch below replaces it. (An earlier title that
# was overwritten on the very next line has been dropped.) LaTeX backslashes
# are doubled so the literals contain no invalid escape sequences.
axar[0].set_title(f"DC field response\n$E_{{ac}}$ = {Eac}V/m, $\\epsilon$ = {ellipticity:.3f}, $\\theta$ = {theta/pi}$\\pi$")
#axar[0].set_ylim(-40,2)
axar[1].set_title(f"Overlap with {basis_n[ii-1].ket()}")
axar[1].set_yscale("log")
axar[1].set_ylim(thrs*0.8,1)
axar[0].legend(loc=3,prop={"size":14})
if Edc_list is not None:
    axar[0].set_ylabel("DC Stark Shift/$2\\pi$ (MHz)")
    axar[0].set_xlabel("DC Electric Field strength (V/m)")
    axar[1].set_xlabel("DC Electric Field strength (V/m)")
elif Eac_list is not None:
    axar[0].set_title("Change in Stark Shifts\n$\\epsilon$ = "+f"{ellipticity}\n" + "$E_{dc}$ = " + f"{Edc}V/m")
    axar[0].set_ylabel("AC Stark Shift/$2\\pi$ (MHz)")
    axar[0].set_xlabel("AC Electric Field strength (V/m)")
    axar[1].set_xlabel("AC Electric Field strength (V/m)")
elif theta_list is not None:
    axar[0].set_ylabel("Change in DC Stark Shift/$2\\pi$ (MHz)")
    axar[0].set_xlabel("$\\theta/\\pi$ (radians)")
    axar[1].set_xlabel("Angle Between AC and DC fields/$\\pi$/ (radians)")
#axar[1].legend()
for txt in [axar[0].xaxis.label,axar[0].yaxis.label,axar[0].title]:
    txt.set_fontsize(14)
fig.tight_layout()
fig.show()

## Fit to quadratic shifts, compare to alphas

In [None]:
# Fit models for the energy of a band versus the swept quantity v:
#   quad:     E(v) = -alpha/2 * v**2 + e0
#   four_pol: E(v) = -alpha/2 * v**2 + beta * v**4 + e0
quad = lambda v, alpha, e0: -1/2*alpha*(v-0)**2+e0
four_pol = lambda v, alpha, beta, e0: -1/2*alpha*(v-0)**2+e0+beta*(v-0)**4
interesting_inds = list(set(interesting_inds))

func = four_pol
dim = len(basis_n)
#alphas = np.zeros(dim, dtype=float)
#dalphas = np.zeros(dim, dtype=float)
print("\n".join(basis_n[k].ket() for k in interesting_inds))
# One fit and one two-panel figure (fit, residuals) per index in iis, which is
# left over from an earlier plotting cell.
for k in iis:
    band_energies = energies[k,:]
#    guess = guess = [(band_energies[-1]-band_energies[0])/varied[1][-1]**2, band_energies[0]]
    # Initial guess: alpha from the end-to-end energy change assuming a pure
    # quadratic, beta = 0, offset = first energy.
    guess = [-2*(band_energies[-1]-band_energies[0])/varied[1][-1]**2, 0, band_energies[0]]
    try:
        popt, pcov = curve_fit(func, varied[1], band_energies, p0 = guess)
        perr = np.sqrt(np.diag(pcov))
    except RuntimeError:
        # Fit did not converge: report NaNs. `np.nan` is used because the
        # `np.NaN` alias no longer exists in NumPy 2.x; the two lists are kept
        # separate so they do not alias each other.
        popt = [np.nan]*len(guess)
        perr = [np.nan]*len(guess)
    print(f"state {basis_n[k].ket()} has polarizability fits:\n\t {popt}\n\t {perr}")

    fig,axar = plt.subplots(1,2,figsize=(10,4))
    ax=axar[0]
    ax.plot(varied[1],(band_energies - band_energies[0])*1e-6/tau, label = "Band energy")
    ax.plot(varied[1], (func(varied[1],*popt)-band_energies[0])*1e-6/tau, ls="-.", label = "Quadratic fit to band energy")
    ax.plot(varied[1], func(varied[1],guess[0],0,0*band_energies[0])*1e-6/tau, ls=":", label="Manual Guess")
    #ax.plot(varied[1], (quad(varied[1],popt[0], alphas[k], popt[2])-band_energies[0])*1e-6/tau, ls = ":", label = "fit from alphas")
    ax.legend()
    ax.set_xlabel(varied[0])
    # LaTeX backslashes are doubled so the literals contain no invalid escape
    # sequences; the resulting strings are unchanged.
    ax.set_ylabel("DC Stark Shift/$2\\pi$ (MHz)")
    ax.set_title(
        f"Quadratic fit on state {basis_n[k].ket()}\n$\\alpha$' = 2$\\pi \\cdot$ {popt[0]*1e-3/tau:.2f} $\\pm$ {perr[0]*1e-3/tau:.2f} $KHz/(V/m)^2$\n$\\beta$' = 2$\\pi \\cdot$ {popt[1]/tau:.2e} $\\pm$ {perr[1]/tau:.3e} $Hz/(V/m)^4$")# \n$\\alpha$'2 = 2$\pi \cdot$ {alphas[k]*1e-3/tau:.2f} $\pm$ {dalphas[k]*1e-3/tau:.2f} $KHz/(V/m)^2$")
    ax=axar[1]
    ax.plot(varied[1], (band_energies-func(varied[1],*popt))/tau, label = "Fit Residuals")
    ax.plot(varied[1], (band_energies-quad(varied[1],guess[0],guess[2]))/tau, ls=":", label="Guess residuals")
    ax.legend()
    ax.set_xlabel(varied[0])
    ax.set_ylabel("Fit Residuals (Hz)")
    fig.tight_layout()
    fig.show()

## Plot Energies for each relevant m_j

In [None]:
# One figure per basis state equal to look_at: absolute energies (x 1e-6/tau)
# of the dressed states overlapping it (left) and their population on that
# state's component ii = i + len(basis)*n_max (right, log scale).
thrs = 2e-1  # threshold for caring about population
look_at = t_level
#look_at = RydStateFS(51,2,5/2)
for i, state in enumerate(basis):
    if not (state == look_at):
        continue
    ii = i+len(basis)*n_max
    # Dressed states whose population on component ii exceeds thrs somewhere in the sweep.
    interesting_inds = [
        k for k, p in enumerate(eigenstates[:,ii,:])
        if max(np.absolute(p)**2)>thrs
    ]
    fig,axar = plt.subplots(1,2,figsize=(8,6))
    for k in interesting_inds:
        axar[0].plot(varied[1], 1e-6*np.real(energies[k,:])/tau, label=basis_n[k].ket())
        axar[1].plot(varied[1], np.absolute(eigenstates[k,ii,:])**2, label=basis_n[k].ket())

    # Left panel: energies.
    axar[0].set_title("Energies")
    #axar[0].set_ylim(-1e3,6e3)
    axar[0].legend()
    axar[0].set_ylabel("Rydberg State Energy/$2\pi$ (MHz)")

    # Right panel: populations, with a dotted reference line at 0.5.
    axar[1].axhline(0.5,ls=":")
    axar[1].set_title("Probability")
    axar[1].set_yscale("log")
    axar[1].set_ylim(thrs*0.8,1)
    #axar[1].legend()

    # x labels follow whichever sweep list is set (DC takes precedence).
    if Edc_list is not None:
        axar[0].set_xlabel("DC Electric Field strength (V/m)")
        axar[1].set_xlabel("DC Electric Field strength (V/m)")
    elif Eac_list is not None:
        axar[0].set_xlabel("AC Electric Field strength (V/m)")
        axar[1].set_xlabel("AC Electric Field strength (V/m)")

    fig.suptitle(f"State = {state.ket()}")
    fig.tight_layout()
    fig.show()

## Plot Relative Shifts for each relevant m_j

In [None]:
# One figure per basis state equal to look_at: energy change relative to the
# first sweep point (x 1e-6/tau) of the dressed states overlapping it (left)
# and their population on component ii = i + len(basis)*n_max (right, log scale).
thrs = 5e-2  # threshold for caring about population
look_at = t_level
#look_at = RydStateFS(51,2,5/2)
for i, state in enumerate(basis):
    if not (state == look_at):
        continue
    ii = i+len(basis)*n_max
    # Dressed states whose population on component ii exceeds thrs somewhere in the sweep.
    interesting_inds = [
        k for k, p in enumerate(eigenstates[:,ii,:])
        if max(np.absolute(p)**2)>thrs
    ]
    fig,axar = plt.subplots(1,2,figsize=(8,6))
    for k in interesting_inds:
        axar[0].plot(varied[1], 1e-6*np.real(energies[k,:]-energies[k,0])/tau, label=basis_n[k].ket())
        axar[1].plot(varied[1], np.absolute(eigenstates[k,ii,:])**2, label=basis_n[k].ket())

    # Left panel: shifts.
    axar[0].set_title("Energies")
    #axar[0].set_ylim(-1e3,6e3)
    axar[0].legend()
    axar[0].set_ylabel("State DC Stark Shift/$2\pi$ (MHz)")

    # Right panel: populations.
    axar[1].set_title("Probability")
    axar[1].set_yscale("log")
    axar[1].set_ylim(thrs*0.8,1)
    #axar[1].legend()

    # x labels follow whichever sweep list is set (DC takes precedence).
    if Edc_list is not None:
        axar[0].set_xlabel("DC Electric Field strength (V/m)")
        axar[1].set_xlabel("DC Electric Field strength (V/m)")
    elif Eac_list is not None:
        axar[0].set_xlabel("AC Electric Field strength (V/m)")
        axar[1].set_xlabel("AC Electric Field strength (V/m)")

    fig.suptitle(f"State = {state.ket()}")
    fig.tight_layout()
    fig.show()

In [None]:
# Same plot as the preceding "relative shifts" cell (the code is a duplicate):
# per basis state equal to look_at, energy change relative to the first sweep
# point (left) and population on component ii (right, log scale).
thrs = 5e-2  # threshold for caring about population
look_at = t_level
#look_at = RydStateFS(51,2,5/2)
for i, state in enumerate(basis):
    if not (state == look_at):
        continue
    ii = i+len(basis)*n_max
    # Dressed states whose population on component ii exceeds thrs somewhere in the sweep.
    interesting_inds = [
        k for k, p in enumerate(eigenstates[:,ii,:])
        if max(np.absolute(p)**2)>thrs
    ]
    fig,axar = plt.subplots(1,2,figsize=(8,6))
    for k in interesting_inds:
        axar[0].plot(varied[1], 1e-6*np.real(energies[k,:]-energies[k,0])/tau, label=basis_n[k].ket())
        axar[1].plot(varied[1], np.absolute(eigenstates[k,ii,:])**2, label=basis_n[k].ket())

    # Left panel: shifts.
    axar[0].set_title("Energies")
    #axar[0].set_ylim(-1e3,6e3)
    axar[0].legend()
    axar[0].set_ylabel("State DC Stark Shift/$2\pi$ (MHz)")

    # Right panel: populations.
    axar[1].set_title("Probability")
    axar[1].set_yscale("log")
    axar[1].set_ylim(thrs*0.8,1)
    #axar[1].legend()

    # x labels follow whichever sweep list is set (DC takes precedence).
    if Edc_list is not None:
        axar[0].set_xlabel("DC Electric Field strength (V/m)")
        axar[1].set_xlabel("DC Electric Field strength (V/m)")
    elif Eac_list is not None:
        axar[0].set_xlabel("AC Electric Field strength (V/m)")
        axar[1].set_xlabel("AC Electric Field strength (V/m)")

    fig.suptitle(f"State = {state.ket()}")
    fig.tight_layout()
    fig.show()

# Analytical comparison (NOT COMPLETE)

In [None]:
# Inputs for the (unfinished) analytical comparison. The numeric values below
# are placeholders, as the TODO markers indicate.
# Mixing Parameters : TODO
pt = 0.7
pm = 0.3
# Mixing angle defined by cos(th)**2 == pt, and sin(2*th)**2 derived from it.
th = np.arccos(np.sqrt(pt))
s2th = np.sin(2*th)**2
# Effective tensor polarizabilities : TODO
alpha2_t = 9
alpha2_m = -7
# Quantum numbers
j = t_level.j
# All mj values from -j to +j in unit steps.
mjs = np.arange(-j,j+1,1)


In [None]:
def anisotropy_coefficient(mj):
    """Return the weighted sum of three mj-dependent terms for the analytic model.

    Reads the notebook-level values ``j``, ``pt``, ``pm``, ``alpha2_t``,
    ``alpha2_m`` and ``s2th``, plus ``d10``, ``omegam`` and ``omegap``.

    NOTE(review): ``d10``, ``omegam`` and ``omegap`` are not defined in the
    visible part of the notebook (this section is marked NOT COMPLETE).
    NOTE(review): the first term divides by ``2*j - 1``, which is zero for
    j = 1/2 -- confirm the intended handling for that case.
    """
    three_mj_sq = 3*mj**2
    t1 = -3/4*(three_mj_sq-j*(j+1))/j/(2*j-1)
    t2 = -3/4*(three_mj_sq-(j+2)*(j+1))/(j+1)/(2*j+1)
    t3 = ((1+j)**2-0.5*(3+2*j)*mj-mj**2)/((j+1)*(2*j+1)*(2*j+3))

    target_part = pt*alpha2_t*t1
    mixed_part = pm*alpha2_m*t2
    coupling_part = 0.5*s2th*d10*t3/(omegam-omegap)
    return target_part+mixed_part+coupling_part

# Debugging

In [None]:
# Scratch check: build a 9x9 diagonal test matrix, then diagonalise H0
# (defined elsewhere) and print its eigenvalues relative to laser_dAC
# (x 1e-6/tau) and the diagonal of the eigenvector matrix minus one.
Ht = np.diag([1,1,1,1,2,2,2,2,2])
print(Ht)
e, ev = np.linalg.eig(H0)
print((e-laser_dAC)*1e-6/tau)
print(np.diag(ev)-1)

In [None]:
# Build one Floquet Hamiltonian at fixed field values and display the sum of
# its off-diagonal elements.
# NOTE: this rebinds the notebook-level Eac, Edc, ellipticity and theta.
Eac, Edc, ellipticity, theta = Eac_list[1], 0, 0.012, 0
HF = build_floquet(basis, H0, Eac, ellipticity, field_omega, Edc, theta, n_max)
(HF-np.diag(np.diag(HF))).sum()

In [None]:
# Diagonalise HF and print the sum of the off-diagonal entries of its
# eigenvector matrix.
e,ev =np.linalg.eig(HF)
print((ev-np.diag(np.diag(ev))).sum())

In [None]:
# NOTE(review): Ht as built from a 9-entry diagonal in this Debugging section
# is 9x9, while this vector has 4 entries, so the product would raise a shape
# error unless Ht has been redefined -- confirm.
np.dot(Ht,[1,1,0,0])

# Debugging

In [None]:
# Build a Floquet Hamiltonian with ellipticity 0.02, the last DC field value
# and angle pi/2, then show the real part of the sub-matrix spanning the two
# Fourier blocks at positions n_max-1 and n_max.
# (A bare `fields` expression and a self-assignment of H_floquet were removed:
# neither had any effect in the middle of a cell.)
H_floquet = build_floquet(basis,H0,Eac,0.02,field_omega,Edc_list[-1],pi/2,n_max)
fig,ax = plt.subplots(1,1)
shown = slice((n_max-1)*len(basis), (n_max+1)*len(basis))
ax.imshow(np.real(H_floquet[shown, shown]))
fig.show()

In [None]:
# The bare `fields` expression is not the last statement of the cell, so it
# displays nothing; the field values are passed on through **fields below.
fields
# Diagonalise the Floquet Hamiltonian once for the current field settings.
eigenvalues, eigenvectors = floquet_diag(basis, H0, n_max = n_max, **fields)

In [None]:
# Debug walk-through of how eigenvectors are assigned to labelled bands.
# Only slot 0 of the last axis of `energies` / `eigenstates` is filled here.
energies = np.zeros((len(basis)*(2*n_max+1),2), dtype=complex)
eigenstates = np.zeros((len(basis)*(2*n_max+1), len(basis)*(2*n_max+1), 2), dtype=complex)
starts = level_starts(levels)
ips_summer = level_projector(basis, levels, n_max)

# compute overlaps wrt unperturbed eigenstates
ips = np.abs(eigenvectors)**2
# sum over all zeeman states in each level
ips_levels = np.dot(ips, ips_summer)
# sum over all level in each fourier sub basis
used_inds = []
troublesome_level = RydStateFS(51,2,5/2)  # not referenced again in this cell
for j, level in enumerate(levels):
    for k, n in enumerate(range(-n_max, n_max+1)):
        print(f"finding good eigenvectors for |level, n> = |{level.ket()},{n}>")
        # eigenvectors that have >50% population in this level
        thrsh = 0.5
        inds_l = np.argwhere(ips_levels[:, j+k*len(levels)] > thrsh)
        # if the 50% threshold is too high to accommodate all m levels, lower the
        # threshold incrementally
        # NOTE(review): there is no lower bound or iteration cap on this loop;
        # presumably enough candidates always appear -- confirm.
        while len(inds_l) < 2*level.j+1:
            thrsh *= 0.95
            print(f"expansion required, threshold reduced to {thrsh}")
            inds_l = np.argwhere(ips_levels[:, j + k * len(levels)] > thrsh)

        # find the band that has the greatest overlap with each mj level
        print(f"inds_l = {inds_l[:,0]}")
        # First row of this level inside Fourier block k of the full basis.
        strt = starts[j]+k*len(basis)
        for a in range(int(2*level.j + 1)):
            # print(f"m = {-level.j + a}")
            # print(f"sub_ips = {ips[inds_l, strt + a]}")
            # Position (within inds_l) of the candidate with the largest overlap
            # on basis component strt + a.
            ev_ind = np.argmax(ips[inds_l, strt + a])
            # A repeated index means one eigenvector was assigned to two bands;
            # this is only reported, not prevented.
            if inds_l[ev_ind, 0] in used_inds:
                print(f"WARNING: index {inds_l[ev_ind, 0]} has been used")
            print(f"inds used = {inds_l[ev_ind,0]}")
            used_inds.append(inds_l[ev_ind, 0])
            # print(f"new inds = {inds_l[ev_ind,0]}")
            # print(f"New energy = {eigenvalues[inds_l[ev_ind, 0]]}")
            eigenstates[strt + a, :, 0] = eigenvectors[inds_l[ev_ind, 0], :]
            energies[strt + a, 0] = np.real(eigenvalues[inds_l[ev_ind, 0]])


In [None]:
# Print, for a few hand-picked eigenvector indices, the overlaps with the
# basis components of level j = 2 in the Fourier block n = -2.
# NOTE: rebinds the notebook-level n, k, j, strt and end.
n = -2
k = n+n_max
j = 2
strt = starts[j]+k*len(basis)
end = starts[j+1]+k*len(basis)
ind_l = [32, 33, 34, 35]
for ind in ind_l:
    print(f"overlaps in range {basis_n[strt].ket()}-{basis_n[end-1].ket()} for ind = {ind}")
    print(f"\t{ips[ind,strt:end]}")

In [None]:
# Display the ket label of basis entry j (j as currently bound in the notebook).
basis[j].ket()

In [None]:
# Display the current start index.
strt

In [None]:
# Display the summed overlap of eigenvector 32 with column j + k*len(levels) of ips_levels.
ips_levels[32,j+k*len(levels)]

In [None]:
# Print the overlaps of eigenvector 32, with entries not above thrs shown as zero.
# NOTE: rebinds the notebook-level `thrs` used by the plotting cells.
thrs = 1e-2
print([f"{ev*(ev > thrs):.2e}" for ev in ips[32]])

In [None]:
# Print the current list of levels using the project helper.
basis_print(levels)

# Convergence

In [None]:
class ConvTest:
    """Settings and (optionally) results of one convergence-test run.

    Attributes set here: ``n_max``, ``max_det`` (taken from the ``dw``
    argument), ``dl``, ``energies`` and ``eigenstates``. Other notebook cells
    attach further attributes (``desc``, ``sub_nrgs``) after construction.
    """

    def __init__(self, n_max, dw, dl, energies=None, eigenstates=None):
        # `dw` is stored under the attribute name `max_det`.
        self.n_max, self.max_det, self.dl = n_max, dw, dl
        self.energies, self.eigenstates = energies, eigenstates

    def __repr__(self):
        return "ConvTest({},{},{})".format(self.n_max, self.max_det, self.dl)

    def __str__(self):
        # Same text as the repr.
        return repr(self)

In [None]:
# Load one ConvTest per "...eigenstates.npy" file found in the results
# directory, attaching the matching "...energies.npy" array to it.
filepath = "Shirley-Floquet_results\\convergence_tests"

# One slot per (eigenstates, energies) file pair.
ConvTests = [None]*(len(os.listdir(filepath))//2)
for i, file in enumerate(os.listdir(filepath)):
    test_str = "eigenstates.npy"
    nrg = "energies.npy"
    print(file)
    if file[-len(test_str):] == test_str:
        print(file)
        # Fields 1..3 of the dash-separated file name carry the run settings.
        settings = file.split("-")[1:4]
        # NOTE(review): each fragment is executed as Python code. Judging by the
        # constructor call below it is expected to bind nmax, dw and dl
        # (presumably fragments look like "nmax=3") -- confirm. This also rebinds
        # the notebook-level `dl`, and exec must never see untrusted file names.
        for setting in settings:
            exec(setting)
        # NOTE(review): the slot index i//2 assumes each eigenstates file sits at
        # a distinct pair position in the directory listing; os.listdir does not
        # guarantee any order -- confirm, or sort the listing.
        ConvTests[i//2] = ConvTest(nmax,dw,dl)
        ConvTests[i//2].energies = np.load(os.path.join(filepath,file[:-len(test_str)]+nrg))
        # ConvTests[i//2].eigenstates = np.load(os.path.join(filepath,file))

In [None]:
# Display the loaded tests (uses ConvTest.__repr__).
ConvTests

In [None]:
# For every convergence test, rebuild the basis with that test's settings and
# store, keyed by ket label, the energy rows of the states sharing n, l, j with
# t_level (rows offset by test.n_max * len(basis)).
# NOTE: this rebinds the notebook-level `levels` and `basis` on every pass.
for test in ConvTests:
    levels, basis = build_basis(t_level, test.max_det, single_side=False, dl=test.dl)
    dim = len(basis)
    target_inds = [
        i for i, state in enumerate(basis)
        if (state.n, state.l, state.j) == (t_level.n, t_level.l, t_level.j)
    ]
    t_level_nrgs = dict(
        (basis[k].ket(), test.energies[test.n_max*dim + k]) for k in target_inds
    )
    #t_level_estates = {basis[k].ket(): test.eigenstates[test.n_max*dim + k] for k in target_inds}
    test.sub_nrgs = t_level_nrgs
    #test.sub_estates = t_level_estates

In [None]:
# Full legend labels for the five loaded tests, addressed by list position.
# NOTE(review): the positions depend on the order in which the files were
# loaded -- confirm they still match the settings named here.
ConvTests[0].desc = "$\Delta l = 2$, $m_{max} = 2$"#", $\omega_{max} = 2\pi\\times30GHz$"
ConvTests[1].desc = "$\Delta l = 1$, $m_{max} = 3$, $\omega_{max} = 2\pi\\times30GHz$"
ConvTests[2].desc = "$\Delta l = 2$, $m_{max} = 3$"#", $\omega_{max} = 2\pi\\times30GHz$"
ConvTests[3].desc = "$\Delta l = 1$, $m_{max} = 3$, $\omega_{max} = 2\pi\\times60GHz$"
ConvTests[4].desc = "$\Delta l = 1$, $m_{max} = 4$, $\omega_{max} = 2\pi\\times30GHz$"

In [None]:
# Short legend labels: test 1 is called "Baseline" and the others are named
# by a setting. Running this cell replaces any labels set previously.
# NOTE(review): the positions depend on the order in which the files were
# loaded -- confirm they still match the settings named here.
ConvTests[0].desc = "$\Delta l = 2$, $m_{max} = 2$"#", $\omega_{max} = 2\pi\\times30GHz$"
ConvTests[1].desc = "Baseline"
ConvTests[2].desc = "$\Delta l = 2$, $m_{max} = 3$"#", $\omega_{max} = 2\pi\\times30GHz$"
ConvTests[3].desc = "$\omega_{max} = 2\pi\\times60GHz$"
ConvTests[4].desc = "$m_{max} = 4$"

In [None]:
# One two-panel figure per target-level sublevel m:
#   top:    each test's energy change relative to its own first sweep point,
#   bottom: each test's energy minus that of ConvTests[1], the reference here.
# Both are scaled by 1e-6/tau. Tests are drawn in the order
# ConvTests[1:], then ConvTests[0].
ms = int(2*t_level.j+1)
#fig,axar = plt.subplots(ms,2,figsize=(8,ms*3))
#names = ["$\Delta l = 2, m_{max}-1$", "Baseline", "$\Delta l = 2$", "$2\\times \omega_{max}$", "$m_{max} + 1$"]
sub_list = ConvTests[1:]+[ConvTests[0]]
for m in range(ms):
    fig,axar = plt.subplots(2,1,figsize=(4,4.5),gridspec_kw={'height_ratios': [3, 2]})#,sharex=True)
    for t,test in enumerate(sub_list):
        # Keys of sub_nrgs are ket labels; the m-th key selects sublevel m.
        m_key = list(test.sub_nrgs)[m]
        axar[1].plot(varied[1],(test.sub_nrgs[m_key]-ConvTests[1].sub_nrgs[m_key])*1e-6/tau,label=test.desc)
        axar[0].plot(varied[1],(test.sub_nrgs[m_key]-test.sub_nrgs[m_key][0])*1e-6/tau,label=test.desc)

    #axar[1].set_title(m_key)
    axar[1].set_ylabel("Correction (MHz)")
    axar[1].set_xlabel("$E_{DC}$ (V/m)")
    #axar[1].legend(loc="center right",prop={"size":9})

    #axar[0].set_title(m_key)
    axar[0].set_ylabel("$\Delta_{DC}$ (MHz)")
    #axar[0].set_xlabel("$E_{DC}$ (V/m)")
    axar[0].legend(loc=3,prop={"size":9})

    # Link the x axes and keep tick labels on the bottom panel only.
    axar[0].get_shared_x_axes().join(axar[0], axar[1])
    axar[0].set_xticklabels([])
    axar[1].set_xlim(axar[0].get_xlim())

    for ax in axar:
        ax.yaxis.set_ticks_position('left')
        ax.xaxis.set_ticks_position('bottom')
        for txt in [ax.xaxis.label, ax.yaxis.label, ax.title]:
            txt.set_fontsize("10")

    fig.tight_layout()
    fig.show()

In [None]:
# One 2x2 figure per target-level sublevel m. Each column shows one group of
# tests; within a column the first test of the group is the reference:
#   top:    each test's energy change relative to its own first sweep point,
#   bottom: each test's energy minus the reference's energy.
# Both are scaled by 1e-6/tau.
ms = int(2*t_level.j+1)
#fig,axar = plt.subplots(ms,2,figsize=(8,ms*3))
#names = ["$\Delta l = 2, m_{max}-1$", "Baseline", "$\Delta l = 2$", "$2\\times \omega_{max}$", "$m_{max} + 1$"]

sub_lists = [ConvTests[1:]+[ConvTests[0]],[ConvTests[0],ConvTests[2]]]
labs="abcd"  # only referenced by the commented-out panel letters below
for m in range(ms):
    fig,axar = plt.subplots(2,2,figsize=(7,4.5),gridspec_kw={'height_ratios': [3, 2]})#,sharex=True)
    for i,sub_list in enumerate(sub_lists):
        ax = axar[1,i]
        for t,test in enumerate(sub_list):
            # Keys of sub_nrgs are ket labels; the m-th key selects sublevel m.
            m_key = list(test.sub_nrgs.keys())[m]
            ax.plot(varied[1],(test.sub_nrgs[m_key]-sub_list[0].sub_nrgs[m_key])*1e-6/tau,label=test.desc)
        #ax.set_title(m_key)
        ax.set_ylabel("Correction (MHz)")
        ax.set_xlabel("$E_{DC}$ (V/m)")
        #ax.legend(loc="center right",prop={"size":9})
    #ax.set_ylim(-1,1)

        ax = axar[0,i]
        for t,test in enumerate(sub_list):
            m_key = list(test.sub_nrgs.keys())[m]
            ax.plot(varied[1],(test.sub_nrgs[m_key]-test.sub_nrgs[m_key][0])*1e-6/tau,label=test.desc)
        #ax.set_title(m_key)
        # Backslash doubled so the literal has no invalid escape sequence.
        ax.set_ylabel("$\\Delta_{DC}$ (MHz)")
        #ax.set_xlabel("$E_{DC}$ (V/m)")
        ax.legend(loc=3,prop={"size":9})

        # Link the x axes of this column; tick labels on the bottom panel only.
        axar[0,i].get_shared_x_axes().join(axar[0,i], axar[1,i])
        axar[0,i].set_xticklabels([])
        axar[1,i].set_xlim(axar[0,i].get_xlim())
        #axar[0,i].text(0,3, labs[2*i])
        #axar[1,i].text(-2,0, labs[2*i+1])

    # Tick placement and font sizes for all four panels, applied once per
    # figure rather than once per column.
    for ax in axar.flat:
        ax.yaxis.set_ticks_position('left')
        ax.xaxis.set_ticks_position('bottom')
        for txt in [ax.xaxis.label, ax.yaxis.label, ax.title]:
            txt.set_fontsize("11")

    fig.tight_layout()
    # Saved under "Plots/", the directory name the other cells of this notebook
    # use (it was lower-case here, which differs on case-sensitive file systems).
    fig.savefig(f"Plots/Convergence_DCStark{m}n.png")
    fig.show()

In [None]:
# Scratch test: write a text label at data coordinates (5, 5) on the
# bottom-right panel of the most recent 2x2 figure and redraw it.
axar[1,1].text(5,5,"Test")
fig.show()

In [None]:
# Fit E(x) - E(x0) = -alpha/2 * x**2 + beta * x**4 for every convergence test
# and every target-level sublevel; popts[i, j] holds (alpha, beta) for test i
# and sublevel j, perrs the matching one-sigma errors.
# NOTE: rebinds the notebook-level `func`, `popts` and `perrs`.
func = lambda x, alpha, beta: -1/2*alpha*x**2+beta*x**4
# NOTE(review): the sublevel axis is sized with len(t_inds) but filled for
# range(ms) below -- presumably the two are equal; confirm.
popts = np.zeros((len(ConvTests),len(t_inds),2),dtype=float)
perrs = np.zeros(popts.shape,dtype=float)

# Rotate the list so that the former ConvTests[1] comes first.
# NOTE: this modifies ConvTests in place of the old binding, so running the
# cell again rotates the list once more.
ConvTests=ConvTests[1:]+[ConvTests[0]]
for j,m in enumerate(range(ms)):
    for i,test in enumerate(ConvTests):
        # Keys of sub_nrgs are ket labels; the m-th key selects sublevel m.
        m_key = list(test.sub_nrgs.keys())[m]
        x_dat = varied[1]
        # Energy change relative to the first sweep point.
        y_dat = test.sub_nrgs[m_key]-test.sub_nrgs[m_key][0]
        # Initial guess: alpha from the end point assuming a pure quadratic, beta = 0.
        guess = [-2*y_dat[-1]/x_dat[-1]**2,0]
        popts[i,j], pcov = curve_fit(func,x_dat,y_dat,p0=guess)
        perrs[i,j] = np.sqrt(np.diag(pcov))

In [None]:
# Grouped bar chart of how the fitted coefficients change relative to the first
# entry of ConvTests (index 0 of popts): alpha (x 1e-3/tau) on top, beta (/tau)
# below. One group per remaining test, one bar per sublevel m inside a group.
fig,axar=plt.subplots(2,1,figsize=(6,4))
w=1/(ms+1)  # bar width: ms bars plus one bar-width of spacing per group
for m in range(ms):
    # Bar positions for sublevel m, centred on the integer group positions.
    locs = np.arange(len(ConvTests)-1)+(m-ms/2+1/2)*w
    axar[0].bar(locs,(popts[1:,m,0]-popts[0,m,0])*1e-3/tau,width=w,label=f"band {m}")
    axar[1].bar(locs,(popts[1:,m,1]-popts[0,m,1])/tau,width=w)

# Axis decoration is done once, after all bars are drawn (it was previously
# repeated on every pass of the loop, together with a list comprehension whose
# result was discarded). Backslashes are doubled so the literals contain no
# invalid escape sequences; the resulting strings are unchanged.
axar[0].set_ylabel("$\\Delta\\alpha'$ ($kHz/(V/m)^2$)")
axar[0].legend()
axar[1].set_ylabel("$\\Delta\\beta'$ ($Hz/(V/m)^4$)")
#axar[1].tick_params(axis='x', labelrotation = 90)
# Link the x axes; group labels appear on the bottom panel only.
axar[0].get_shared_x_axes().join(axar[0], axar[1])
axar[0].set_xticklabels([])
axar[1].set_xlim(axar[0].get_xlim())
axar[1].set_xticks(np.arange(len(ConvTests)-1),minor=False)
axar[1].set_xticklabels([test.desc for test in ConvTests[1:]])
# Final x range uses the bar positions of the last sublevel.
axar[1].set_xlim(-w*3,max(locs)+w)
for ax in axar:
    ax.yaxis.set_ticks_position('left')
    ax.xaxis.set_ticks_position('bottom')
    ax.minorticks_off()

    for txt in [ax.xaxis.label, ax.yaxis.label, ax.title]:
        txt.set_fontsize("10")
fig.tight_layout()
fig.savefig("Plots/ConvergencePolarizabilitiesn.png")
fig.show()