In [None]:
import sys
import cmath
import math
import os
import h5py
import matplotlib.pyplot as plt   # plots
import matplotlib.animation as animation
import matplotlib.colors as clrs
from mpl_toolkits.mplot3d import Axes3D
from matplotlib import cm
from matplotlib.colors import Normalize, TwoSlopeNorm
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.ticker import FormatStrFormatter
from matplotlib.animation import FuncAnimation

import numpy as np
import time
import warnings

from liblibra_core import *
import util.libutil as comn
from libra_py import units
import libra_py.models.Holstein as Holstein
import libra_py.models.Phenol as Phenol

from libra_py import dynamics_plotting
import libra_py.dynamics.tsh.compute as tsh_dynamics
import libra_py.dynamics.tsh.plot as tsh_dynamics_plot
import libra_py.dynamics.exact.plot as exact_plot
import libra_py.data_savers as data_savers
import libra_py.data_read as data_read
import libra_py.data_visualize as data_visualize

import libra_py.dynamics.exact.compute as dvr
import libra_py.dynamics.exact.save as dvr_save

import re

%matplotlib inline
#from matplotlib.mlab import griddata
#%matplotlib inline 
#warnings.filterwarnings('ignore')

# Shared colour palette: curve i is drawn with colors[clrs_index[i]].
clrs_index = ["11", "21", "31", "41", "51", "61", "71", "81"]
colors = dict(zip(clrs_index, [
    "#CB3310",  # red
    "#0077BB",  # blue
    "#009988",  # teal
    "#EE7733",  # orange
    "#00FFFF",  # cyan
    "#EE3377",  # magenta
    "#AA3377",  # purple
    "#BBBBBB",  # grey
]))

In [None]:
# Global matplotlib styling: large (30 pt) fonts for publication-size figures.
plt.rc('axes', titlesize=30, labelsize=30)   # axes title and x/y label sizes
plt.rc('legend', fontsize=30)
for _tick_group in ('xtick', 'ytick'):
    plt.rc(_tick_group, labelsize=30)        # tick label sizes

plt.rc('figure', facecolor='white')
plt.rc('figure.subplot', top=0.88)           # leave headroom for titles/legends

# Font specification for text annotations.
font = dict(family='serif', color='blue', weight='bold', size=30)


# PES

In [None]:
class tmp:
    """Bare attribute container; model functions attach ham_dia, ovlp_dia, d1ham_dia, dc1_dia to it."""

def superexchange(q, params, full_id):
    """Three-state superexchange model with Gaussian couplings along one nuclear coordinate.

    Diabatic Hamiltonian, with x = q(0, indx) and g(x) = exp(-x^2 / 2):
        H_00 = 0,   H_11 = V_11,   H_22 = V_22
        H_01 = H_10 = A_01 * g(x)
        H_12 = H_21 = A_12 * g(x)
        H_02 = H_20 = A_02 * g(x)

    Args:
        q (MATRIX): nuclear coordinates; element 0 of column `indx` is used as x.
        params (dict): model parameters V_11, V_22, A_01, A_12, A_02. comn.check_input
            is given the defaults V_11=0.01, V_22=0.005, A_01=0.001, A_12=0.01, A_02=0.0.
        full_id (intList): trajectory identifier; its last entry selects the column of q.

    Returns:
        tmp: object with
            ham_dia   - 3x3 diabatic Hamiltonian (CMATRIX)
            ovlp_dia  - identity overlap matrix (CMATRIX)
            d1ham_dia - CMATRIXList holding dH/dx
            dc1_dia   - CMATRIXList holding a zero derivative-coupling matrix
    """
    critical_params = []
    default_params = { "V_11":0.01, "V_22":0.005, "A_01":0.001, "A_12":0.01, "A_02":0.0  }
    comn.check_input(params, default_params, critical_params)

    V_11 = params["V_11"]
    V_22 = params["V_22"]
    A_01 = params["A_01"]
    A_12 = params["A_12"]
    A_02 = params["A_02"]

    n = 3

    Hdia = CMATRIX(n, n)
    Sdia = CMATRIX(n, n)
    d1ham_dia = CMATRIXList();  d1ham_dia.append( CMATRIX(n, n) )
    dc1_dia = CMATRIXList();  dc1_dia.append( CMATRIX(n, n) )   # left zero

    Id = Cpp2Py(full_id)
    indx = Id[-1]

    x = q.col(indx).get(0)

    Sdia.identity()

    Hdia.set(0, 0,  0 )
    Hdia.set(1, 1,  V_11)
    Hdia.set(2, 2,  V_22)

    exp_factor = math.exp(-0.5 * x * x)
    dexp_factor = -x * exp_factor   # d/dx of exp(-x^2/2)

    # Symmetric off-diagonal couplings. A_02 (direct 0-2 coupling) must be applied
    # as well; with the default A_02 = 0.0 this leaves H_02 = 0 as before.
    for (i, j), A in (((0, 1), A_01), ((1, 2), A_12), ((0, 2), A_02)):
        Hdia.set(i, j, A * exp_factor);          Hdia.set(j, i, A * exp_factor)
        d1ham_dia[0].set(i, j, A * dexp_factor);  d1ham_dia[0].set(j, i, A * dexp_factor)

    obj = tmp()
    obj.ham_dia = Hdia
    obj.ovlp_dia = Sdia
    obj.d1ham_dia = d1ham_dia
    obj.dc1_dia = dc1_dia

    return obj

In [None]:
def compute_model(q, params, full_id):
    """Dispatch to the model Hamiltonian selected by params["model"].

    1 -> Holstein.Holstein2, 2 -> superexchange, 3 -> Phenol.Pollien_Arribas_Agostini.
    Any other model id yields None.
    """
    selected = params["model"]
    if selected == 1:
        return Holstein.Holstein2(q, params, full_id)
    if selected == 2:
        return superexchange(q, params, full_id)
    if selected == 3:
        return Phenol.Pollien_Arribas_Agostini(q, params, full_id)
    return None

In [None]:
# Parameter sets for the three models handled by compute_model.
model_params1 = dict(model=1, model0=1, nstates=2, E_n=[0.0, -0.01], x_n=[0.0, 0.5],
                     k_n=[0.002, 0.008], V=0.001)   # holstein
model_params2 = dict(model=2, model0=2, nstates=3)  # superexchange
model_params3 = dict(model=3, model0=3, nstates=3)  # phenol

all_model_params = [model_params1, model_params2, model_params3]

# Index (into all_model_params) of the model used below
model_indx = 0

model_params = all_model_params[model_indx]
NSTATES = model_params["nstates"]
list_states = list(range(NSTATES))

## 2D PES

In [None]:
def potential(q, params):
    """Evaluate compute_model at q using the fixed trajectory id [0, 0] (column 0 of q)."""
    return compute_model(q, params, Py2Cpp_int([0, 0]))

In [None]:
def get_E(nst, q, _model_params):
    """Diabatic diagonal energies and ascending adiabatic energies at geometry q.

    Args:
        nst (int): number of electronic states (size of the Hamiltonian block read).
        q (MATRIX): nuclear coordinates forwarded to `potential`.
        _model_params (dict): model parameters forwarded to `potential`.

    Returns:
        (H, E): H - real diagonal of the diabatic Hamiltonian, shape (nst,);
                E - eigenvalues of the diabatic Hamiltonian in ascending order, shape (nst,).
    """
    res = potential(q, _model_params)
    Hdia = np.array([[res.ham_dia.get(i, j) for j in range(nst)] for i in range(nst)],
                    dtype=complex)

    # The diabatic Hamiltonian is assumed Hermitian. eigvalsh then returns real
    # eigenvalues already sorted ascending, so no argsort is needed. It also keeps
    # imaginary round-off out of the real-valued PES arrays the results go into.
    E = np.linalg.eigvalsh(Hdia)
    H = np.diag(Hdia).real

    return H, E

In [None]:
def get_PES(Rs, Ts, _model_params, nst=3):
    """Scan diabatic and adiabatic energies on a 2D (R, T) grid.

    Args:
        Rs (sequence of float): grid values for coordinate 0 of q.
        Ts (sequence of float): grid values for coordinate 1 of q.
        _model_params (dict): model parameters forwarded to get_E.
        nst (int): number of electronic states (default 3).

    Returns:
        (PES_dia, PES_adi): arrays of shape (len(Rs), len(Ts), nst) holding the
        diabatic diagonal energies and the adiabatic energies from get_E.
    """
    nR, nT = len(Rs), len(Ts)

    PES_dia = np.zeros((nR, nT, nst))
    PES_adi = np.zeros((nR, nT, nst))

    for ix, R in enumerate(Rs):
        for iy, T in enumerate(Ts):
            # Use the grid values themselves (not Rs[0] + ix*dR), so non-uniform
            # grids and single-point axes work too.
            q = MATRIX(2, 1)
            q.set(0, 0, R)
            q.set(1, 0, T)

            H, E = get_E(nst, q, _model_params)
            PES_dia[ix, iy] = H
            PES_adi[ix, iy] = E

    return PES_dia, PES_adi


In [None]:
# 2D scan grid for the phenol model (all_model_params[2]):
#   Rs - OH length from 0.5 to 3.5 Angstrom, stored in Bohr
#   Ts - dihedral angle from 0 to 2 rad
Angst = 1.88973   # Bohr per Angstrom

nR, nT = 100, 100
dR, dT = 3.0 / nR, 2.0 / nT

Rs = [Angst * (0.5 + ix * dR) for ix in range(nR + 1)]
Ts = [iy * dT for iy in range(nT + 1)]

PES_dia, PES_adi = get_PES(Rs, Ts, all_model_params[2])

In [None]:
def plot_PES_3D(PES, Rs, Ts, title="Diabatic PES"):
    """Plot every state of a 2D PES scan as overlapping, uniformly coloured 3D surfaces.

    Args:
        PES (ndarray): energies in Hartree, shape (len(Rs), len(Ts), nstates).
        Rs (sequence): OH-length grid in Bohr; shown in Angstrom.
        Ts (sequence): dihedral-angle grid in rad.
        title (str): axes title; "" suppresses it.
    """
    bohr_per_angst = 1.88973
    ha_to_ev = 27.21
    z_lim = 10   # eV; energies above this are masked so lower surfaces stay visible

    Rs_t = [r / bohr_per_angst for r in Rs]
    PES_t = np.asarray(PES) * ha_to_ev
    PES_t = np.where(PES_t > z_lim, np.nan, PES_t)

    # meshgrid returns (len(Ts), len(Rs)) arrays, hence the .T on each state slice
    R_mesh, T_mesh = np.meshgrid(Rs_t, Ts)
    num_states = PES_t.shape[2]

    fig = plt.figure(figsize=(15, 13))
    ax = fig.add_subplot(111, projection='3d')

    for state in range(num_states):
        ax.plot_surface(
            R_mesh, T_mesh, PES_t[:, :, state].T,
            color=colors[clrs_index[state]], edgecolor='none', shade=False
        )

    ax.margins(x=0, y=0, z=0)

    ax.set_xticks([1.0, 2.0, 3.0])
    # The tick label must name the tick's actual position (pi/3)
    ax.set_yticks([0, math.pi/3])
    ax.set_yticklabels(["0", r"    $\pi/3$"])
    ax.set_zticks([2.0, 4.0, 6.0, 8.0, 10.0])
    ax.set_zlim([0, z_lim])

    ax.set_xlabel(r"OH length $r$ (Å)")
    ax.set_ylabel(r"Dih. $\theta$ (rad)")
    ax.set_zlabel("energy (eV)")
    if title != "":
        ax.set_title(title)

    ax.xaxis.labelpad = 30
    ax.yaxis.labelpad = 40
    ax.zaxis.labelpad = 30

    ax.view_init(elev=5, azim=275)

    ax.xaxis.pane.fill = False
    ax.yaxis.pane.fill = False
    ax.zaxis.pane.fill = False

    fig.tight_layout()
    #plt.savefig("fig1_pes_phenol.png", dpi=300)

In [None]:
plot_PES_3D(PES=PES_adi, Rs=Rs, Ts=Ts, title="")   # adiabatic surfaces, no title

## 1D potential energy curves (PES cuts along one coordinate)

In [None]:
def _scan_pec(_compute_model, model_params, xmin, xmax, dx, _ndof, _active_dof, _all_coordinates):
    """Scan one coordinate; return the grid and per-state diabatic/adiabatic diagonal energies."""
    comn.check_input(model_params, {}, ["nstates"])
    nstates = model_params["nstates"]

    # The small tolerance keeps e.g. 179.99999999 from truncating away the last grid point.
    nsteps = int((xmax - xmin) / dx + 1e-9) + 1
    X = [xmin + i * dx for i in range(nsteps)]

    ham = nHamiltonian(nstates, nstates, _ndof)   # ndia, nadi, nnucl
    ham.init_all(2)

    hdia = [[] for _ in range(nstates)]
    hadi = [[] for _ in range(nstates)]

    for x in X:
        scan_coord = MATRIX(_ndof, 1)
        for j in range(_ndof):
            scan_coord.set(j, 0, _all_coordinates[j])
        scan_coord.set(_active_dof, 0, x)

        ham.compute_diabatic(_compute_model, scan_coord, model_params)
        ham.compute_adiabatic(1)

        for k in range(nstates):
            hadi[k].append(ham.get_ham_adi().get(k, k).real)
            hdia[k].append(ham.get_ham_dia().get(k, k).real)

    return X, hdia, hadi


def plot_pec(_compute_model, model_params, xmin, xmax, dx, states_of_interest=[0,1], _ndof=1, _active_dof=0, _all_coordinates=[0.0], 
             show_nac_abs=0, frac_nac=1, _legend_ncol=1, _legend_font=30,  _loc="best", _bbox_to_anchor=""):
    """Plot adiabatic (solid) and diabatic (dotted) energies along one nuclear coordinate.

    Args:
        _compute_model (callable): model function passed to nHamiltonian.compute_diabatic.
        model_params (dict): model parameters; must contain "nstates".
        xmin, xmax, dx (float): scan range and step along the active coordinate.
        states_of_interest (list of int): states to draw (not mutated).
        _ndof (int): number of nuclear degrees of freedom.
        _active_dof (int): index of the scanned coordinate.
        _all_coordinates (list of float): values of all coordinates; the active one
            is overwritten by the scan value (not mutated).
        show_nac_abs, frac_nac, _legend_ncol, _legend_font, _loc, _bbox_to_anchor:
            accepted for call compatibility; currently unused (no NACs are plotted).
    """
    X, hdia, hadi = _scan_pec(_compute_model, model_params, xmin, xmax, dx,
                              _ndof, _active_dof, _all_coordinates)

    fig, ax1 = plt.subplots(figsize=(10, 6), dpi=300)
    ax1.tick_params(axis='both', labelsize=30)
    ax1.set_xlabel('$R$ (Bohr)', fontsize=30)
    ax1.set_ylabel('energy (Ha)', fontsize=30)
    ax1.margins(x=0)

    for k1 in states_of_interest:
        ax1.plot(X, hadi[k1], label='$E_{%i}$' % (k1), linewidth=5,
                 color=colors[clrs_index[k1]], zorder=5)

    # Draw on ax1 explicitly rather than through the pyplot state machine.
    # The ground-state diabat is raised above the adiabats (zorder 10).
    for k1 in states_of_interest:
        extra = {"zorder": 10} if k1 == 0 else {}
        ax1.plot(X, hdia[k1], label='$H_{%i%i}$' % (k1, k1), ls="dotted", lw=4,
                 color=colors[clrs_index[k1]], **extra)

    ax1.legend(frameon=False, ncol=2)

    fig.tight_layout()
    #plt.savefig("fig1_pes_holstein.png", dpi=300)

In [None]:
plot_pec(compute_model, model_params1, xmin=-4.0, xmax=5.0, dx=0.05, states_of_interest=[0, 1])   # Holstein

In [None]:
plot_pec(compute_model, model_params2, xmin=-10.0, xmax=10.0, dx=0.05, states_of_interest=[0, 1, 2])   # superexchange

## Holstein

In [None]:
def holstein_all(dirs, labels, title="Populations", onlegend=True, name_out="res.png",
                 dvr_file="holstein/DVR/exact-model0-icond0-p0_0/data.hdf"):
    """Compare several TSH runs with exact DVR results for the Holstein model.

    Three stacked panels:
      0: relative total-energy drift |E(t) - E(0)| / E(0) for every run;
      1: ground-state population (SH solid, SE dotted, DVR thick black);
      2: adiabatic 0-1 coherence (TSH coloured, DVR thick black).

    Args:
        dirs (list of str): TSH output directories, each containing mem_data.hdf.
        labels (list of str): legend label for each entry of dirs.
        title, onlegend, name_out: accepted for call compatibility; not used.
        dvr_file (str): HDF5 file with the exact (DVR) reference data.
    """
    au_per_fs = 41.0   # time conversion used throughout this notebook (a.u. -> fs)

    with h5py.File(dvr_file, 'r') as f:
        time_DVR = f["time/data"][:] / au_per_fs
        pop_adi_DVR = np.array(f["pop_adi/data"])
        coh_adi_DVR = np.array(f["coherence_adi/data"])

    # Keep each run's own time axis; runs need not share the same length.
    MQC_time, MQC_se_pop, MQC_sh_pop, MQC_coh, MQC_Etot = [], [], [], [], []
    for idir in dirs:
        with h5py.File(os.path.join(idir, "mem_data.hdf"), 'r') as f:
            MQC_time.append(np.array(f["time/data"]) / au_per_fs)
            MQC_se_pop.append(np.array(f["se_pop_adi/data"]))
            MQC_sh_pop.append(np.array(f["sh_pop_adi/data"]))
            MQC_coh.append(np.array(f["coherence_adi/data"]))
            Etot = np.array(f["Etot_ave/data"])
        MQC_Etot.append(np.abs(Etot - Etot[0]) / Etot[0])

    fig, axs = plt.subplots(nrows=3, figsize=(10, 16))

    ## Energy conservation
    axs[0].margins(x=0)
    axs[0].set_ylabel(r"$|\Delta E_{tot}|/E_{tot}(0)$")
    axs[0].yaxis.set_major_formatter(FormatStrFormatter('%.2f'))
    for i in range(len(dirs)):
        axs[0].plot(MQC_time[i], MQC_Etot[i], linewidth=5, color=colors[clrs_index[i]])

    # Empty proxy lines so the legend lists every method plus DVR
    for i in range(len(dirs)):
        axs[0].plot([], [], linewidth=5, color=colors[clrs_index[i]], label=labels[i])
    axs[0].plot([], [], linewidth=5, color='k', label="DVR")
    axs[0].legend(frameon=False, loc="upper left", bbox_to_anchor=(0, 1.03))

    ## Ground-state population
    axs[1].margins(x=0)
    axs[1].yaxis.set_major_formatter(FormatStrFormatter('%.2f'))
    axs[1].plot(time_DVR, pop_adi_DVR[:, 0], linewidth=7, color="k")
    for i in range(len(dirs)):
        axs[1].plot(MQC_time[i], MQC_sh_pop[i][:, 0], linewidth=5, color=colors[clrs_index[i]])
    for i in range(len(dirs)):
        axs[1].plot(MQC_time[i], MQC_se_pop[i][:, 0], linewidth=5,
                    linestyle="dotted", color=colors[clrs_index[i]], zorder=-10)
    axs[1].set_ylabel("population")

    ## 0-1 coherence
    axs[2].yaxis.set_major_formatter(FormatStrFormatter('%.2f'))
    axs[2].set_ylabel("coherence")
    axs[2].margins(x=0)
    axs[2].plot(time_DVR, coh_adi_DVR[:, 0, 1], linewidth=7, color="k")
    for i in range(len(dirs)):
        axs[2].plot(MQC_time[i], MQC_coh[i][:, 0, 1], linewidth=5, color=colors[clrs_index[i]])

    axs[-1].set_xlabel("time (fs)")

    fig.tight_layout()
    fig.subplots_adjust(top=0.9)  # Move subplots down to make space for the legend
    #plt.savefig(name_out, dpi=300)

In [None]:
# Holstein TSH runs, one directory per method, in legend order.
_holstein_runs = ["model0-method0-icond0-p0_0/", "model0-method1-icond0-p0_0/",
                  "model0-method2-icond0-p0_0-eps0_0/", "model0-method3-icond0-p0_0-w0_1"]
dirs = [F"holstein/TSH/{run}" for run in _holstein_runs]
labels = ["FSSH", "QTSH", "QTSH-SDM", "QTSH-XF"]

holstein_all(dirs, labels, title="", onlegend=True, name_out="fig2_holstein.png")

## Superexchange: scattering

In [None]:
def get_RT_DVR3(_model_indx, _icond_indx, _ks):
    """Per-state reflection/transmission probabilities from exact (DVR) densities.

    For each momentum k, reads a final density snapshot of the 3-state DVR run.
    Each state's density times the grid spacing is summed over grid points up to
    and including the largest negative x (reflection) and over the rest
    (transmission).

    Args:
        _model_indx (int): model index in the directory name (wfc<model>-...).
        _icond_indx (int): initial-condition index in the directory name (icond<n>).
        _ks (sequence of float): initial momenta; formatted with str() in the path.

    Returns:
        (R0, R1, R2, T0, T1, T2): ndarrays of length len(_ks).
    """
    nstates = 3
    R = np.zeros((nstates, len(_ks)))
    T = np.zeros((nstates, len(_ks)))

    for i, k in enumerate(_ks):
        # Runs with k < 11 are read at snapshot 40, faster ones at snapshot 20
        snap = 40 if k < 11 else 20
        fn = F"superexchange/DVR/wfc{_model_indx}-icond{_icond_indx}-p{k}/wfcr_snap_{snap}_dens_rep_1".replace(".", "_")

        # Column 0: grid; columns 1..nstates: per-state densities
        data_exact = data_read.get_data_from_file2(fn, list(range(nstates + 1)))
        grids = np.array(data_exact[0])
        dens = np.array([np.array(data_exact[1 + ist]) for ist in range(nstates)])

        # Index of the last grid point with x < 0 (the reflection/transmission split)
        neg = list(grids).index(max(grids[grids < 0]))
        dr = grids[1] - grids[0]

        P = dens * dr
        R[:, i] = P[:, :neg + 1].sum(axis=1)
        T[:, i] = P[:, neg + 1:].sum(axis=1)

    return R[0], R[1], R[2], T[0], T[1], T[2]

In [None]:
def get_RT_MQC3(_model_indx, _method_indx, _icond_indx, _ks):
    """Per-state reflection/transmission fractions from TSH final frames.

    A trajectory counts as transmitted (T) if its final coordinate q > 0 and
    reflected (R) if q < 0 (q == 0 is counted in neither). It is attributed to
    its final active state: 0, 1, or "2" for any other state index.

    Args:
        _model_indx, _method_indx, _icond_indx (int): indices in the directory name.
        _ks (sequence of float): initial momenta (formatted with one decimal).

    Returns:
        (R0, R1, R2, T0, T1, T2): ndarrays of length len(_ks), fractions of ntraj.
    """
    n = len(_ks)
    R0, R1, R2 = np.zeros(n), np.zeros(n), np.zeros(n)
    T0, T1, T2 = np.zeros(n), np.zeros(n), np.zeros(n)

    # NOTE(review): only method indices 5 and 6 get a parameter suffix here, while
    # other cells use 2 -> "-eps..." and 3 -> "-w0_1"; confirm the directory naming.
    if _method_indx == 5:
        suffix = "-eps0_1"
    elif _method_indx == 6:
        suffix = "-w0_1"
    else:
        suffix = ""

    for i, k in enumerate(_ks):
        dir_dyn = "superexchange/TSH/" + (F"model{_model_indx}-method{_method_indx}-icond{_icond_indx}-p{k:.1f}" + suffix).replace(".", "_")

        with h5py.File(dir_dyn + "/mem_data.hdf", 'r') as f:
            # Read only the final frame instead of the whole trajectory history
            q_end = np.asarray(f["q/data"][-1, :, 0])
            ntraj = q_end.shape[0]
            st_end = np.asarray(f["states/data"][-1, :ntraj])

        forward = q_end > 0.0
        backward = q_end < 0.0
        on0 = st_end == 0
        on1 = st_end == 1
        on2 = ~(on0 | on1)   # every other active state is lumped into state 2

        R0[i] = np.count_nonzero(on0 & backward) / ntraj
        T0[i] = np.count_nonzero(on0 & forward) / ntraj
        R1[i] = np.count_nonzero(on1 & backward) / ntraj
        T1[i] = np.count_nonzero(on1 & forward) / ntraj
        R2[i] = np.count_nonzero(on2 & backward) / ntraj
        T2[i] = np.count_nonzero(on2 & forward) / ntraj

    return R0, R1, R2, T0, T1, T2

In [None]:
# High-momentum scattering scan: hbar*k0 = 5, 6, ..., 20 a.u. (model 1, icond 0).
incr_k = 1.0
ks = [5.0 + ik * incr_k for ik in range(16)]

# Surface-hopping variants by method index 0..3 (FSSH, QTSH, QTSH-SDM, QTSH-XF).
# Each call returns (R0, R1, R2, T0, T1, T2).
(R0_FSSH, R1_FSSH, R2_FSSH,
 T0_FSSH, T1_FSSH, T2_FSSH) = get_RT_MQC3(1, 0, 0, ks)
(R0_QTSH, R1_QTSH, R2_QTSH,
 T0_QTSH, T1_QTSH, T2_QTSH) = get_RT_MQC3(1, 1, 0, ks)
(R0_QTSH_SDM, R1_QTSH_SDM, R2_QTSH_SDM,
 T0_QTSH_SDM, T1_QTSH_SDM, T2_QTSH_SDM) = get_RT_MQC3(1, 2, 0, ks)
(R0_QTSH_XF, R1_QTSH_XF, R2_QTSH_XF,
 T0_QTSH_XF, T1_QTSH_XF, T2_QTSH_XF) = get_RT_MQC3(1, 3, 0, ks)

# Exact DVR reference.
# NOTE(review): `ks_DVR` is not defined anywhere in this part of the notebook; this
# line only runs if it already exists in the kernel - TODO define it explicitly.
(R0_QD, R1_QD, R2_QD,
 T0_QD, T1_QD, T2_QD) = get_RT_DVR3(1, 0, ks_DVR)

In [None]:
# Low-momentum scattering scan: hbar*k0 = 3.0, 3.2, ..., 5.0 a.u. (model 0, icond 0).
# NOTE(review): get_RT_MQC3_low, get_RT_DVR3_low and ks_DVR_low are not defined in this
# part of the notebook; the cell only runs if they exist in the kernel - TODO define them.
incr_k = 0.2
ks_low = [3.0 + ik * incr_k for ik in range(11)]

# Each call returns (R0, R1, R2, T0, T1, T2).
(R0_FSSH_low, R1_FSSH_low, R2_FSSH_low,
 T0_FSSH_low, T1_FSSH_low, T2_FSSH_low) = get_RT_MQC3_low(0, 0, 0, ks_low)
(R0_QTSH_low, R1_QTSH_low, R2_QTSH_low,
 T0_QTSH_low, T1_QTSH_low, T2_QTSH_low) = get_RT_MQC3_low(0, 1, 0, ks_low)
(R0_QD_low, R1_QD_low, R2_QD_low,
 T0_QD_low, T1_QD_low, T2_QD_low) = get_RT_DVR3_low(0, 0, ks_DVR_low)

In [None]:
def _style_scatter_panel(axis, ylabel, yticks, xticks=None):
    """Common axis styling for the three scattering panels (50 pt fonts, 1-decimal y labels)."""
    axis.margins(x=0)
    axis.xaxis.set_tick_params(labelsize=50)
    axis.yaxis.set_tick_params(labelsize=50)
    axis.yaxis.set_major_formatter(FormatStrFormatter('%.1f'))
    if xticks is not None:
        axis.set_xticks(xticks)
    axis.set_yticks(yticks)
    axis.set_ylabel(ylabel, fontsize=50)
    axis.set_xlabel(r"$\hbar k_{0}$ (a.u.)", fontsize=50)


def _plot_channel(axis, k_exact, exact, k_mqc, series):
    """DVR curve (thick black) plus one scatter per (values, marker, colour index, label)."""
    axis.plot(k_exact, exact, linewidth=10, color="k", label="DVR", zorder=-10)
    for values, marker, cidx, label in series:
        axis.scatter(k_mqc, values, marker=marker, color=colors[clrs_index[cidx]],
                     s=150, clip_on=False, label=label)


fig, ax = plt.subplots(ncols=3, figsize=(27, 9))

## Panel 0: transmission, high k (all four TSH variants)
_style_scatter_panel(ax[0], r"transmission $T_{n}$", [1.0, 0.8, 0.2, 0.0], xticks=[5, 10, 20])
_plot_channel(ax[0], ks_DVR, T0_QD, ks, [(T0_FSSH, "v", 0, "FSSH"),
                                         (T0_QTSH, "v", 1, "QTSH"),
                                         (T0_QTSH_SDM, "o", 2, "QTSH-SDM"),
                                         (T0_QTSH_XF, "o", 3, "QTSH-XF")])
# The legend is built now, so it lists only the n=0 artists
ax[0].legend(handlelength=0.7, frameon=False, fontsize=50, labelspacing=0.3, loc="center")
ax[0].annotate(r"$n=0$", (15, 0.85), fontsize=50)
_plot_channel(ax[0], ks_DVR, T1_QD, ks, [(T1_FSSH, "v", 0, "FSSH"),
                                         (T1_QTSH, "v", 1, "QTSH"),
                                         (T1_QTSH_SDM, "o", 2, "QTSH-SDM"),
                                         (T1_QTSH_XF, "o", 3, "QTSH-XF")])
ax[0].annotate(r"$n=1$", (15, 0.08), fontsize=50)

## Panel 1: transmission, low k
_style_scatter_panel(ax[1], r"transmission $T_{n}$", [0.8, 0.2, 0.0])
_plot_channel(ax[1], ks_DVR_low, T0_QD_low, ks_low, [(T0_FSSH_low, "v", 0, "FSSH"),
                                                     (T0_QTSH_low, "v", 1, "QTSH")])
ax[1].legend(handlelength=0.7, frameon=False, fontsize=50, labelspacing=0.3, loc="center")
ax[1].annotate(r"$n=0$", (4.35, 0.67), fontsize=50)
_plot_channel(ax[1], ks_DVR_low, T1_QD_low, ks_low, [(T1_FSSH_low, "v", 0, "FSSH"),
                                                     (T1_QTSH_low, "v", 1, "QTSH")])
ax[1].annotate(r"$n=1$", (4.35, 0.03), fontsize=50)

## Panel 2: reflection, low k
_style_scatter_panel(ax[2], r"reflection $R_{n}$", [0.2, 0.0])
_plot_channel(ax[2], ks_DVR_low, R0_QD_low, ks_low, [(R0_FSSH_low, "v", 0, "FSSH"),
                                                     (R0_QTSH_low, "v", 1, "QTSH")])
ax[2].annotate(r"$n=0$", (3.6, 0.1), fontsize=50)

fig.tight_layout()

#plt.savefig("fig2_scattering.png", dpi=300)

## Superexchange: dynamics evolution

In [None]:
# Compute the canonical momentum
def get_p_cano(file):
    """Canonical momentum of every trajectory stored in a TSH mem_data.hdf file.

        p_cano = p + sum_{i<j} 2 * Im(C_i C_j^*) * Re(d_ij)

    p is the momentum of nuclear DOF 0, C the adiabatic amplitudes and d_ij the
    adiabatic derivative coupling of DOF 0.

    Args:
        file (str): path to the HDF5 file.

    Returns:
        ndarray with the shape of f["p/data"][:, :, 0] (time x trajectory).
    """
    with h5py.File(file, 'r') as f:
        p = np.array(f["p/data"][:, :, 0])
        C_adi = np.array(f["Cadi/data"][:, :, :])
        dc1_adi = np.array(f["dc1_adi/data"][:, :, :, :, :])

    ntraj = p.shape[1]
    nst = C_adi.shape[2]

    p_cano = np.copy(p)

    # Vectorised over time steps and trajectories. As in p, axis 0 of C_adi and
    # dc1_adi is taken as time and axis 1 as the trajectory index.
    for ist in range(nst):
        for jst in range(ist + 1, nst):
            coh = (C_adi[:, :ntraj, ist] * C_adi[:, :ntraj, jst].conj()).imag
            p_cano += 2 * coh * dc1_adi.real[:, :ntraj, 0, ist, jst]

    return p_cano

In [None]:
# Compute the diagonal and coherence energy
def get_E_cano(file, mass=2000.):
    """Trajectory-averaged diagonal and coherence energies from a TSH mem_data.hdf file.

        E_diag(t) = < p_cano^2 / (2 M) + Re(Hvib)_{aa} >      (a = active state)
        E_coh(t)  = < sum_{i<j} 2 Im(C_i C_j^*) Im(Hvib)_{ij} >

    < . > averages over trajectories; p_cano comes from get_p_cano.

    Args:
        file (str): path to the HDF5 file.
        mass (float): nuclear mass M in a.u. (default 2000, as previously hard-coded).

    Returns:
        (Ediag, Ecoh, Etot_cano): per-time-step arrays shaped like Etot_ave/data,
        with Etot_cano = Ediag + Ecoh.
    """
    with h5py.File(file, 'r') as f:
        Etot = np.array(f["Etot_ave/data"])
        states = np.array(f["states/data"])
        C_adi = np.array(f["Cadi/data"][:, :, :])
        hvib_adi = np.array(f["hvib_adi/data"][:, :, :, :])

    p_cano = get_p_cano(file)

    # assumed layouts: C_adi (time, traj, state), hvib_adi (time, traj, state, state),
    # states (time, traj) - as implied by the indexing below
    ntraj = C_adi.shape[1]
    nst = C_adi.shape[2]

    # Diagonal part: kinetic energy from the canonical momentum ...
    Ediag = np.zeros(Etot.shape)
    Ediag += 0.5 * np.sum(p_cano[:, :ntraj] ** 2, axis=1) / mass
    # ... plus the diagonal Hvib element of each trajectory's active state
    hdiag = np.diagonal(hvib_adi.real[:, :ntraj], axis1=2, axis2=3)    # (time, traj, state)
    active = states[:, :ntraj, None] == np.arange(nst)                  # one-hot active state
    Ediag += np.sum(hdiag * active, axis=(1, 2))

    # Coherence part
    Ecoh = np.zeros(Etot.shape)
    for ist in range(nst):
        for jst in range(ist + 1, nst):
            coh = (C_adi[:, :ntraj, ist] * C_adi[:, :ntraj, jst].conj()).imag
            Ecoh += np.sum(2 * coh * hvib_adi.imag[:, :ntraj, ist, jst], axis=1)

    Ediag /= ntraj
    Ecoh /= ntraj

    Etot_cano = Ediag + Ecoh

    return Ediag, Ecoh, Etot_cano

In [None]:
def comp_epc(file, file_DVR, pos_label=0, broken = (-0.1, 0.3, 3.8, 4.2), title="QTSH", onlegend=True, name_out="res.png"):
    """Four-panel figure: energy components (broken y-axis), populations and coherences vs DVR.

    Args:
        file (str): TSH mem_data.hdf path.
        file_DVR (str): DVR data.hdf path with the exact reference.
        pos_label (float): vertical axes coordinate of the shared energy y-label.
        broken (tuple): (lower_min, lower_max, upper_min, upper_max) y-ranges in mHa
            of the lower and upper energy panels; each range is padded by 10 %.
        title (str): figure title.
        onlegend (bool): whether to draw the legends.
        name_out (str): prefix of the precomputed energy files read here:
            <name_out>_edig.txt, <name_out>_ecoh.txt and <name_out>_etot_cano.txt.
    """
    au_per_fs = 41.0   # time conversion used throughout this notebook (a.u. -> fs)

    with h5py.File(file_DVR, 'r') as f:
        # The DVR reference is plotted against its own time grid
        time_DVR = np.array(f["time/data"]) / au_per_fs
        pop_adi_DVR = np.array(f["pop_adi/data"])
        coh_adi_DVR = np.array(f["coherence_adi/data"])

    with h5py.File(file, 'r') as f:
        time_fs = np.array(f["time/data"]) / au_per_fs
        se_pop = np.array(f["se_pop_adi/data"])
        sh_pop = np.array(f["sh_pop_adi/data"])
        coh_adi = np.array(f["coherence_adi/data"])

    nst = coh_adi.shape[1]

    Ediag = np.loadtxt(name_out + "_edig.txt")
    Ecoh = np.loadtxt(name_out + "_ecoh.txt")
    Etot_cano = np.loadtxt(name_out + "_etot_cano.txt")
    energy_series = [(Ediag, 0, r"$E_{diag}$"), (Ecoh, 1, r"$E_{coh}$"), (Etot_cano, 2, r"$E_{tot}$")]

    # axs[0] and axs[1] together form one energy panel with a broken y-axis
    fig, axs = plt.subplots(nrows=4, figsize=(10, 16), gridspec_kw={'height_ratios': [0.8, 0.8, 2, 2]})
    fig.suptitle(title, fontsize=30)

    for eax in axs[:2]:
        eax.margins(x=0)
        eax.yaxis.set_major_formatter(FormatStrFormatter('%.1f'))
        for values, cidx, label in energy_series:
            eax.plot(time_fs, values * 1000, color=colors[clrs_index[cidx]], linewidth=5, label=label)

    axs[0].set_ylabel(r"energy (mHa)")
    axs[0].yaxis.set_label_coords(-0.1, pos_label)
    if onlegend:
        axs[0].legend(frameon=False, ncol=3, fontsize=30, loc='upper right', bbox_to_anchor=(1.0, 0.80),
                      handlelength=1.5, columnspacing=1.0)

    # Upper panel shows broken[2:4], lower panel broken[0:2], each padded by 10 %
    pad = (broken[3] - broken[2]) * 0.1
    axs[0].set_ylim(broken[2] - pad, broken[3] + pad)
    pad = (broken[1] - broken[0]) * 0.1
    axs[1].set_ylim(broken[0] - pad, broken[1] + pad)

    # Hide the spines between axs[0] and axs[1]
    axs[0].spines.bottom.set_visible(False)
    axs[1].spines.top.set_visible(False)
    axs[0].xaxis.tick_top()
    axs[0].tick_params(labeltop=False)  # don't put tick labels at the top
    axs[1].xaxis.tick_bottom()

    # Slanted "break" marks at the cut
    d = .5  # proportion of vertical to horizontal extent of the slanted line
    kwargs = dict(marker=[(-1, -d), (1, d)], markersize=12,
                  linestyle="none", color='k', mec='k', mew=1, clip_on=False)
    axs[0].plot([0, 1], [0, 0], transform=axs[0].transAxes, **kwargs)
    axs[1].plot([0, 1], [1, 1], transform=axs[1].transAxes, **kwargs)

    # Populations
    axs[2].margins(x=0)
    axs[2].set_ylabel("population")
    axs[2].yaxis.set_major_formatter(FormatStrFormatter('%.1f'))
    axs[2].set_yticks([0.0, 0.5, 1.0])
    axs[2].plot([], [], color="k", linewidth=5, label="DVR")
    for ist in range(nst):
        axs[2].plot(time_DVR, pop_adi_DVR[:, ist], linewidth=7, color="k")
        axs[2].plot(time_fs, sh_pop[:, ist], linewidth=5, color=colors[clrs_index[ist]],
                    label=rF"$P^{{SH}}_{{{ist}}}$")
    for ist in range(nst):
        axs[2].plot(time_fs, se_pop[:, ist], linestyle="dotted", linewidth=5, color=colors[clrs_index[ist]],
                    label=rF"$P^{{SE}}_{{{ist}}}$")
    if onlegend:
        axs[2].legend(ncol=2, frameon=False, fontsize=30, loc='upper right', labelspacing=0.6, bbox_to_anchor=(1, 1.01))

    # Coherences of every pair i < j
    axs[3].plot([], [], color="k", label="DVR", linewidth=5)
    cnt = -1
    for ist in range(nst):
        for jst in range(ist + 1, nst):
            cnt += 1
            axs[3].plot(time_DVR, coh_adi_DVR[:, ist, jst], color="k", linewidth=7)
            axs[3].plot(time_fs, coh_adi[:, ist, jst], color=colors[clrs_index[cnt]], linewidth=5,
                        label=rF"$\langle |\rho_{{{ist}{jst}}}|^2 \rangle$")

    axs[3].margins(x=0)
    axs[3].yaxis.set_major_formatter(FormatStrFormatter('%.1f'))
    axs[3].set_yticks([0.0, 0.1, 0.2])
    axs[3].set_ylabel("coherence")
    axs[3].set_xlabel("time (fs)")
    if onlegend:
        axs[3].legend(ncol=2, frameon=False, loc="center right", fontsize=30)

    fig.tight_layout()
    #plt.savefig(name_out + ".png", dpi=300)

In [None]:
# Method 3 vs DVR at hbar*k0 = 4.0 a.u. (model 0, icond 0)
_model_indx, _method_indx, _icond_indx, _k = 0, 3, 0, 4.0

# Directory-name suffix carrying method-specific parameters
dir_name_suff = {2: "-eps0_0", 3: "-w0_1"}.get(_method_indx, "")

file = (F"superexchange/TSH/model{_model_indx}-method{_method_indx}-icond{_icond_indx}-p{_k:.1f}"
        + dir_name_suff).replace(".", "_") + "/mem_data.hdf"
file_DVR = F"superexchange/DVR/exact-model{_model_indx}-icond{_icond_indx}-p{_k:.1f}".replace(".", "_") + "/data.hdf"

comp_epc(file, file_DVR, pos_label=-0.2, title=rF"$\hbar k_{0}=4.0$ a.u.", onlegend=False, name_out="fig4_qtsh_4")

In [None]:
# Method 3 vs DVR at hbar*k0 = 5.0 a.u. (model 0, icond 0)
_model_indx, _method_indx, _icond_indx, _k = 0, 3, 0, 5.0

# Directory-name suffix carrying method-specific parameters
dir_name_suff = {2: "-eps0_1", 3: "-w0_1"}.get(_method_indx, "")

file = (F"superexchange/TSH/model{_model_indx}-method{_method_indx}-icond{_icond_indx}-p{_k:.1f}"
        + dir_name_suff).replace(".", "_") + "/mem_data.hdf"
file_DVR = F"superexchange/DVR/exact-model{_model_indx}-icond{_icond_indx}-p{_k:.1f}".replace(".", "_") + "/data.hdf"

comp_epc(file, file_DVR, pos_label=-0.25, broken = (-0.1, 0.5, 5.8, 6.5), title=rF"$\hbar k_{0}=5.0$ a.u.", onlegend=True, name_out="fig4_qtsh_5")

## Superexchange: linear correlation

In [None]:
# Correlation between energy non-conservation and SE/SH population inconsistency
# for method 1 over hbar*k0 = 3.0, 3.2, ..., 5.0 a.u. (model 0, icond 0).
_model_indx = 0
_method_indx = 1
_icond_indx = 0

ks = [3.0 + 0.2*ik for ik in range(11)]

tavg_de = []   # time-RMS of [E_tot(t) - E_tot(0)] / E_tot(0), one value per k
tavg_dP = []   # time-RMS of P^SE - P^SH, averaged over states, one value per k

for k in ks:
    dir_dyn = F"superexchange/TSH/model{_model_indx}-method{_method_indx}-icond{_icond_indx}-p{k:.1f}"
    dir_dyn = dir_dyn.replace(".", "_")

    with h5py.File(dir_dyn + "/mem_data.hdf", "r") as f:
        etot = np.array(f["Etot_ave/data"])
        se_pop = np.array(f["se_pop_adi/data"])
        sh_pop = np.array(f["sh_pop_adi/data"])

    rel_de = (etot - etot[0]) / etot[0]
    tavg_de.append(np.sqrt(np.average(rel_de * rel_de)))

    # Number of states taken from the data rather than hard-coded
    nst = se_pop.shape[1]
    dP = np.sum((se_pop - sh_pop) ** 2, axis=1)
    tavg_dP.append(np.sqrt(np.average(dP) / nst))

# Work on explicit arrays; tavg_de / tavg_dP stay lists for the plotting cell.
de_arr = np.asarray(tavg_de)
dP_arr = np.asarray(tavg_dP)

coefficients = np.polyfit(de_arr, dP_arr, 1)  # Linear fit (degree=1)
slope, intercept = coefficients

fit_line = slope * de_arr + intercept

# Coefficient of determination R^2 of the fit
y_mean = np.mean(dP_arr)
ss_total = np.sum((dP_arr - y_mean) ** 2)
ss_residual = np.sum((dP_arr - fit_line) ** 2)
r2 = 1 - (ss_residual / ss_total)

x = np.linspace(0.0, 0.2, 100)
fit_line_for_plot = slope * x + intercept

In [None]:
fig, ax = plt.subplots(figsize = (10, 9))

ax.xaxis.set_major_formatter(FormatStrFormatter('%.2f'))
ax.yaxis.set_major_formatter(FormatStrFormatter('%.2f'))

ax.set_xlabel(r"RMSE of $[E_{tot}-E_{tot}(0)]/E_{tot}(0)$")
ax.set_ylabel(r"RMSE of $P^{SE}-P^{SH}$")

ax.set_xlim([0.0, tavg_de[0]+0.01])
ax.set_ylim([tavg_dP[-2]-0.015, tavg_dP[0]+0.015])

ax.plot(x, fit_line_for_plot, color='k', lw=5, linestyle="dashed", label=r'Linear Fit, $R^{2}=$' + f'{r2:.2f}')
ax.scatter(tavg_de, tavg_dP, color=colors[clrs_index[0]], label="QTSH")

ax.legend(frameon=False, handlelength=1.5, fontsize=30)

# Hand-tuned (dx, dy) offsets for crowded point labels. Keys are k rounded to one
# decimal, so accumulated float error in ks (3.0 + 0.2*ik) cannot break the lookup.
label_offsets = {
    4.8: (-0.002, -0.012),
    4.6: (-0.01, 0.003),
    5.0: (0.002, -0.004),
    3.0: (-0.002, 0.002),
    4.2: (-0.005, 0.004), 4.0: (-0.005, 0.004),
    3.6: (-0.0025, 0.004), 3.4: (-0.0025, 0.004),
    3.2: (0.001, -0.006),
}
default_offset = (0.001, 0.001)

for k, de_val, dp_val in zip(ks, tavg_de, tavg_dP):
    off_x, off_y = label_offsets.get(round(k, 1), default_offset)
    ax.annotate(F"{k:.1f}", (de_val + off_x, dp_val + off_y), fontsize=30, color=colors[clrs_index[0]])

fig.tight_layout()

#plt.savefig("fig5_pE.png", dpi=300)

## Superexchange: QTSH-SDM

In [None]:
_model_indx, _method_indx, _icond_indx, _k = 0, 2, 0, 4.0

# Method-specific run-directory suffix: method 2 -> "-eps...", method 3 -> "-w...", else none.
# NOTE(review): "-eps0_0" is used for C=0.01 Ha; presumably the run directory was named with
# one-decimal formatting of C -- confirm.
dir_name_suff = {2: "-eps0_0", 3: "-w0_1"}.get(_method_indx, "")

# TSH run and its exact (DVR) reference; every "." in the directory name becomes "_"
file = (
    F"superexchange/TSH/model{_model_indx}-method{_method_indx}-icond{_icond_indx}-p{_k:.1f}"
    + dir_name_suff
).replace(".", "_") + "/mem_data.hdf"
file_DVR = F"superexchange/DVR/exact-model{_model_indx}-icond{_icond_indx}-p{_k:.1f}".replace(".", "_") + "/data.hdf"

comp_epc(file, file_DVR,
         broken=(-0.1, 0.3, 3.8, 4.1), title=r"$C=0.01$ Ha", onlegend=False,
         name_out="fig6_qtsh_sdm_0_01_4")

In [None]:
_model_indx, _method_indx, _icond_indx, _k = 0, 2, 0, 4.0

# Method-specific run-directory suffix: method 2 -> "-eps...", method 3 -> "-w...", else none
dir_name_suff = {2: "-eps0_1", 3: "-w0_1"}.get(_method_indx, "")

# TSH run and its exact (DVR) reference; every "." in the directory name becomes "_"
file = (
    F"superexchange/TSH/model{_model_indx}-method{_method_indx}-icond{_icond_indx}-p{_k:.1f}"
    + dir_name_suff
).replace(".", "_") + "/mem_data.hdf"
file_DVR = F"superexchange/DVR/exact-model{_model_indx}-icond{_icond_indx}-p{_k:.1f}".replace(".", "_") + "/data.hdf"

comp_epc(file, file_DVR,
         broken=(-0.1, 0.3, 3.8, 4.1), title=r"$C=0.1$ Ha", onlegend=False,
         name_out="fig6_qtsh_sdm_0_1_4")

In [None]:
_model_indx, _method_indx, _icond_indx, _k = 0, 2, 0, 4.0

# Method-specific run-directory suffix: method 2 -> "-eps...", method 3 -> "-w...", else none
dir_name_suff = {2: "-eps2_0", 3: "-w0_1"}.get(_method_indx, "")

# TSH run and its exact (DVR) reference; every "." in the directory name becomes "_"
file = (
    F"superexchange/TSH/model{_model_indx}-method{_method_indx}-icond{_icond_indx}-p{_k:.1f}"
    + dir_name_suff
).replace(".", "_") + "/mem_data.hdf"
file_DVR = F"superexchange/DVR/exact-model{_model_indx}-icond{_icond_indx}-p{_k:.1f}".replace(".", "_") + "/data.hdf"

comp_epc(file, file_DVR,
         broken=(-0.1, 0.3, 3.8, 4.1), title=r"$C=2.0$ Ha", onlegend=False,
         name_out="fig6_qtsh_sdm_2_0_4")

## Superexchange: QTSH-IDA

In [None]:
_model_indx, _method_indx, _icond_indx, _k = 0, 5, 0, 4.0

# Method-specific run-directory suffix: method 2 -> "-eps...", method 3 -> "-w...", else none
dir_name_suff = {2: "-eps0_0", 3: "-w0_1"}.get(_method_indx, "")

# TSH run and its exact (DVR) reference; every "." in the directory name becomes "_"
file = (
    F"superexchange/TSH/model{_model_indx}-method{_method_indx}-icond{_icond_indx}-p{_k:.1f}"
    + dir_name_suff
).replace(".", "_") + "/mem_data.hdf"
file_DVR = F"superexchange/DVR/exact-model{_model_indx}-icond{_icond_indx}-p{_k:.1f}".replace(".", "_") + "/data.hdf"

comp_epc(file, file_DVR,
         broken=(-0.2, 0.1, 4.0, 4.6), pos_label=-0.25,
         title=rF"$\hbar k_{0}=4.0$ a.u.", onlegend=True, name_out="s_qtsh_ida_4")

In [None]:
_model_indx, _method_indx, _icond_indx, _k = 0, 5, 0, 5.0

# Method-specific run-directory suffix: method 2 -> "-eps...", method 3 -> "-w...", else none
dir_name_suff = {2: "-eps0_0", 3: "-w0_1"}.get(_method_indx, "")

# TSH run and its exact (DVR) reference; every "." in the directory name becomes "_"
file = (
    F"superexchange/TSH/model{_model_indx}-method{_method_indx}-icond{_icond_indx}-p{_k:.1f}"
    + dir_name_suff
).replace(".", "_") + "/mem_data.hdf"
file_DVR = F"superexchange/DVR/exact-model{_model_indx}-icond{_icond_indx}-p{_k:.1f}".replace(".", "_") + "/data.hdf"

comp_epc(file, file_DVR,
         broken=(-0.2, 0.1, 4.0, 4.6), pos_label=-0.25,
         title=rF"$\hbar k_{0}=5.0$ a.u.", onlegend=True, name_out="s_qtsh_ida_5")

## Superexchange: QTSH-XF

In [None]:
_model_indx, _method_indx, _icond_indx, _k = 0, 3, 0, 5.0

# Run-directory suffix, using the same method numbering as the QTSH-SDM / QTSH-IDA cells:
# method 2 -> "-eps...", method 3 (QTSH-XF) -> "-w...". The "-w" value appears to encode
# sigma (w0_1 <-> sigma = 0.1 Bohr). The former checks (== 5 / == 6) could never match
# method 3, so all sigma panels silently loaded the same suffix-less run.
if _method_indx == 2:
    dir_name_suff = "-eps0_0"
elif _method_indx == 3:
    dir_name_suff = "-w0_1"
else:
    dir_name_suff = ""

# TSH run and its exact (DVR) reference; every "." in the directory name becomes "_"
file = (
    F"superexchange/TSH/model{_model_indx}-method{_method_indx}-icond{_icond_indx}-p{_k:.1f}"
    + dir_name_suff
).replace(".", "_") + "/mem_data.hdf"
file_DVR = F"superexchange/DVR/exact-model{_model_indx}-icond{_icond_indx}-p{_k:.1f}".replace(".", "_") + "/data.hdf"

comp_epc(file, file_DVR, pos_label=-0.3, broken = (-0.1, 0.5, 6.0, 6.5), title=r"$\sigma=0.1$ Bohr", onlegend=True, name_out="fig7_qtsh_xf_0_1_5")

In [None]:
_model_indx, _method_indx, _icond_indx, _k = 0, 3, 0, 5.0

# Run-directory suffix, using the same method numbering as the QTSH-SDM / QTSH-IDA cells:
# method 2 -> "-eps...", method 3 (QTSH-XF) -> "-w...". The "-w" value appears to encode
# sigma (w1_0 <-> sigma = 1.0 Bohr). The former checks (== 5 / == 6) could never match
# method 3, so all sigma panels silently loaded the same suffix-less run.
if _method_indx == 2:
    dir_name_suff = "-eps0_0"
elif _method_indx == 3:
    dir_name_suff = "-w1_0"
else:
    dir_name_suff = ""

# TSH run and its exact (DVR) reference; every "." in the directory name becomes "_"
file = (
    F"superexchange/TSH/model{_model_indx}-method{_method_indx}-icond{_icond_indx}-p{_k:.1f}"
    + dir_name_suff
).replace(".", "_") + "/mem_data.hdf"
file_DVR = F"superexchange/DVR/exact-model{_model_indx}-icond{_icond_indx}-p{_k:.1f}".replace(".", "_") + "/data.hdf"

comp_epc(file, file_DVR, pos_label=-0.3, broken = (-0.1, 0.5, 6.0, 6.5), title=r"$\sigma=1.0$ Bohr", onlegend=True, name_out="fig7_qtsh_xf_1_0_5")

In [None]:
_model_indx, _method_indx, _icond_indx, _k = 0, 3, 0, 5.0

# Run-directory suffix, using the same method numbering as the QTSH-SDM / QTSH-IDA cells:
# method 2 -> "-eps...", method 3 (QTSH-XF) -> "-w...". The "-w" value appears to encode
# sigma (w5_0 <-> sigma = 5.0 Bohr). The former checks (== 5 / == 6) could never match
# method 3, so all sigma panels silently loaded the same suffix-less run.
if _method_indx == 2:
    dir_name_suff = "-eps0_0"
elif _method_indx == 3:
    dir_name_suff = "-w5_0"
else:
    dir_name_suff = ""

# TSH run and its exact (DVR) reference; every "." in the directory name becomes "_"
file = (
    F"superexchange/TSH/model{_model_indx}-method{_method_indx}-icond{_icond_indx}-p{_k:.1f}"
    + dir_name_suff
).replace(".", "_") + "/mem_data.hdf"
file_DVR = F"superexchange/DVR/exact-model{_model_indx}-icond{_icond_indx}-p{_k:.1f}".replace(".", "_") + "/data.hdf"

comp_epc(file, file_DVR, pos_label=-0.3, broken = (-0.1, 0.5, 6.0, 6.5), title=r"$\sigma=5.0$ Bohr", onlegend=True, name_out="fig7_qtsh_xf_5_0_5")

## Phenol

In [None]:
# Compute the dissociation probability from DVR
def plot_disso_DVR(pref, nstep=4001, freq=100, nst=3, r_cut=None,
                   nx=4096, ny=4096, dx=0.02, dy=0.03, out_file="disso.txt"):
    """Compute the DVR dissociation probability of every stored snapshot and save it.

    Despite its name this function does not plot. For each snapshot it sums the
    densities of all `nst` states on the 2D grid, integrates the part whose column-0
    coordinate exceeds `r_cut` (weight dx*dy), writes one value per snapshot to
    `out_file` with ``np.savetxt`` and returns the same array.

    Args:
        pref (str): directory containing files ``wfcr_snap_{snap}_state_{ist}_dens_rep_1``;
            column 0 is the dissociative coordinate and column 2 the density.
        nstep (int): number of DVR steps; a snapshot exists at steps 0, freq, 2*freq, ... < nstep.
        freq (int): number of steps between snapshots.
        nst (int): number of electronic states whose densities are summed (>= 1).
        r_cut (float or None): threshold on column 0, in the same units as the files;
            None means ``2.6*units.Angst``.
        nx, ny (int): grid shape each snapshot is reshaped to.
        dx, dy (float): grid spacings used as the integration weight.
        out_file (str): text file the probabilities are written to.

    Returns:
        numpy.ndarray: dissociation probability per snapshot, in snapshot order.

    Raises:
        ValueError: if `nst` < 1.
    """
    if nst < 1:
        raise ValueError("plot_disso_DVR: nst must be at least 1")
    if r_cut is None:
        r_cut = 2.6*units.Angst

    # Exactly one entry per step 0, freq, 2*freq, ... < nstep. The former
    # nstep // freq + 1 left a spurious trailing zero whenever freq divides nstep.
    n_snapshots = len(range(0, nstep, freq))
    P_disso = np.zeros(n_snapshots)

    for isnap in range(n_snapshots):
        wfc_dens = None
        for ist in range(nst):
            filename = f"{pref}/wfcr_snap_{isnap}_state_{ist}_dens_rep_1"
            data_exact = np.loadtxt(filename, ndmin=2)
            if wfc_dens is None:
                # Own copy, so accumulation never writes back into the loaded array
                wfc_dens = data_exact[:, 2].copy()
            else:
                wfc_dens += data_exact[:, 2]

        # Grid taken from the last state file (all state files are expected to share it);
        # the reshape also checks that each file has nx*ny rows
        x = data_exact[:, 0].reshape(nx, ny)
        wfc_dens = wfc_dens.reshape(nx, ny)

        mask = x > r_cut
        P_disso[isnap] = np.sum(wfc_dens[mask]) * dx * dy

    np.savetxt(out_file, P_disso)
    return P_disso


In [None]:
# Produces disso.txt, the DVR dissociation curve that phenol_all loads
plot_disso_DVR(pref="phenol/DVR/wfc2-icond0-p15_0");

In [None]:
def phenol_all(dirs, labels, title="Populations", onlegend=True, name_out="res.png",
               dvr_file="phenol/DVR/exact-model2-icond0-p15_0/data.hdf",
               disso_file="disso.txt"):
    """Plot the phenol benchmark of several MQC runs against the DVR reference.

    Builds a 4x2 figure:
      * left column: |E_tot(t) - E_tot(0)| / E_tot(0) (row 0), then the adiabatic
        population of each state (SH solid, SE dotted, DVR black);
      * right column: dissociation probability (row 0), then the coherences
        (1,2), (0,2) and (0,1).
    Times are divided by 41 to give the "time (fs)" axis.

    Args:
        dirs (list[str]): MQC run directories, each containing ``mem_data.hdf``.
        labels (list[str]): legend label for each entry of `dirs`.
        title (str): unused; kept for call compatibility.
        onlegend (bool): draw the shared figure legend when True.
        name_out (str): output image path.
        dvr_file (str): DVR ``data.hdf`` with ``time``, ``pop_adi`` and ``coherence_adi``.
        disso_file (str): text file with one DVR dissociation probability per snapshot
            (the file ``plot_disso_DVR`` writes by default).

    Raises:
        ValueError: if `dirs` is empty.
    """
    if not dirs:
        raise ValueError("phenol_all: `dirs` must contain at least one run directory")

    # A trajectory counts as dissociated once its first nuclear coordinate exceeds r_cut
    r_cut = 2.6*units.Angst

    ## DVR reference
    with h5py.File(dvr_file, 'r') as f:
        time_DVR = f["time/data"][:]/41.0
        pop_adi_DVR = np.array(f["pop_adi/data"])
        coh_adi_DVR = np.array(f["coherence_adi/data"])

    ## MQC runs (single pass per file). Each run keeps its own time axis so runs of
    ## different length are plotted against their own times.
    MQC_time = []
    MQC_se_pop = []
    MQC_sh_pop = []
    MQC_coh = []
    MQC_Etot = []
    MQC_disso = []

    for idir in dirs:
        with h5py.File(idir + "/mem_data.hdf", 'r') as f:
            t_run = np.array(f["time/data"])/41.0
            se_pop = np.array(f["se_pop_adi/data"])
            sh_pop = np.array(f["sh_pop_adi/data"])
            coh_adi = np.array(f["coherence_adi/data"])
            Etot = np.array(f["Etot_ave/data"])
            q = np.array(f["q/data"])

        # Relative total-energy drift |E(t) - E(0)| / E(0)
        Etot0 = Etot[0]
        Etot = np.abs(Etot - Etot0)/Etot0

        # Fraction of trajectories beyond r_cut along the first nuclear DOF at each step
        nt = min(len(t_run), q.shape[0])
        ntraj = q.shape[1]
        P_disso = (q[:nt, :, 0] > r_cut).sum(axis=1) / ntraj

        MQC_time.append(t_run)
        MQC_se_pop.append(se_pop)
        MQC_sh_pop.append(sh_pop)
        MQC_coh.append(coh_adi)
        MQC_Etot.append(Etot)
        MQC_disso.append(P_disso)

    # Populations occupy rows 1..nst of the left column, so this 4-row layout
    # supports at most 3 states
    nst = MQC_se_pop[-1].shape[1]
    run_colors = [colors[clrs_index[i]] for i in range(len(dirs))]

    fig, axs = plt.subplots(nrows=4, ncols=2, figsize=(12,15) )

    ## Energy conservation
    ax = axs[0][0]
    ax.margins(x=0)
    ax.set_ylabel(r"$|\Delta E_{tot}|/E_{tot}(0)$")
    ax.yaxis.set_major_formatter(FormatStrFormatter('%.2f'))
    for i in range(len(dirs)):
        ax.plot(MQC_time[i], MQC_Etot[i], linewidth=5, color=run_colors[i])

    # Empty lines that carry the legend entries (one per run, then DVR)
    for i in range(len(dirs)):
        ax.plot([],[], linewidth=5, color=run_colors[i], label=labels[i])
    ax.plot([],[], linewidth=5, color='k', label="DVR")

    ## Populations
    for ist in range(nst):
        ax = axs[ist+1][0]
        ax.margins(x=0)
        ax.yaxis.set_major_formatter(FormatStrFormatter('%.2f'))
        ax.plot(time_DVR, pop_adi_DVR[:,ist], linewidth=7, color="k")
        for i in range(len(dirs)):
            ax.plot(MQC_time[i], MQC_sh_pop[i][:,ist], linewidth=5, color=run_colors[i])
        for i in range(len(dirs)):
            ax.plot(MQC_time[i], MQC_se_pop[i][:,ist], linewidth=5,
                    linestyle='dotted', color=run_colors[i], zorder=-10)
        ax.set_ylabel(rF"$\langle  \rho_{{{ist}{ist}}} \rangle $")

    axs[-1][0].set_xlabel("time (fs)")

    ## Dissociation probability
    P_disso_DVR = np.loadtxt(disso_file, ndmin=1)
    # Snapshot k -> 100*k steps, divided by 41 for fs; presumably matches the freq=100
    # used when disso_file was written -- confirm if that changes
    time_DVR2 = np.arange(len(P_disso_DVR))*100/41

    ax = axs[0][1]
    ax.plot(time_DVR2, P_disso_DVR, lw=7, color='k')
    ax.margins(x=0)
    for i in range(len(dirs)):
        ax.plot(MQC_time[i][:len(MQC_disso[i])], MQC_disso[i], lw=5, color=run_colors[i])
    ax.set_ylabel(r"$P^{disso}$")
    ax.yaxis.set_major_formatter(FormatStrFormatter('%.2f'))

    ## Coherences, one pair per row of the right column
    for row, (ist, jst) in enumerate([(1, 2), (0, 2), (0, 1)], start=1):
        ax = axs[row][1]
        ax.yaxis.set_major_formatter(FormatStrFormatter('%.2f'))
        ax.set_ylabel(rF"$\langle | \rho_{{{ist}{jst}}}|^{2} \rangle $")
        ax.margins(x=0)
        ax.plot(time_DVR, coh_adi_DVR[:,ist, jst], linewidth=7, color="k")
        for i in range(len(dirs)):
            ax.plot(MQC_time[i], MQC_coh[i][:,ist, jst], linewidth=5, color=run_colors[i])

    axs[-1][1].set_xlabel("time (fs)")

    for ax in axs.flat:
        ax.set_xticks([0, 20, 40, 60, 80])

    if onlegend:
        fig.legend(loc='upper center', ncol=3, fontsize=30, frameon=False, bbox_to_anchor=(0.565,1.01))
    fig.tight_layout()
    fig.subplots_adjust(top=0.9)  # Move subplots down to make space for the legend

    fig.savefig(name_out, dpi=300)

In [None]:
# Legend label -> phenol TSH run directory (insertion order fixes the plotting order)
phenol_runs = {
    "FSSH":     "model2-method0-icond0-p15_0",
    "QTSH":     "model2-method1-icond0-p15_0",
    "QTSH-SDM": "model2-method2-icond0-p15_0-eps1_0",
    "QTSH-XF":  "model2-method4-icond0-p15_0",
}

dirs = ["phenol/TSH/" + run_dir for run_dir in phenol_runs.values()]
labels = list(phenol_runs.keys())

phenol_all(dirs, labels, title="Populations", onlegend=True, name_out="fig8_phenol.png")

## DVR movie

In [None]:
def read_DVR_wp(_model_indx, _icond_indx, _k, _snap, nstates=3):
    """Read the per-state nuclear densities of one superexchange DVR snapshot.

    Reads ``superexchange/DVR/wfc{model}-icond{icond}-p{k:.1f}/wfcr_snap_{snap}_dens_rep_1``
    (every "." in the directory name replaced by "_") through
    ``data_read.get_data_from_file2``.

    Args:
        _model_indx (int): model index in the directory name.
        _icond_indx (int): initial-condition index in the directory name.
        _k (float): initial momentum, formatted with one decimal in the directory name.
        _snap (int): snapshot index.
        nstates (int): number of state columns read after the grid column.

    Returns:
        tuple: ``(grids, wfcs)``; ``grids`` is column 0 as a 1D array and ``wfcs`` is an
        array with one row per state, built from columns 1..nstates.
    """
    dir_name = F"superexchange/DVR/wfc{_model_indx}-icond{_icond_indx}-p{_k:.1f}".replace(".","_")
    fn = dir_name + F"/wfcr_snap_{_snap}_dens_rep_1"

    # Column 0 is the grid; columns 1..nstates hold one density per state
    which_cols = list(range(nstates + 1))
    data_exact = data_read.get_data_from_file2(fn, which_cols)

    grids = np.array(data_exact[0])
    wfcs = np.array([np.array(data_exact[1 + ist]) for ist in range(nstates)])

    return grids, wfcs

In [None]:
%matplotlib notebook

# Create the figure and axes
fig, ax = plt.subplots(figsize=(10,6))
plt.subplots_adjust(left=0.15, bottom=0.25)  # Adjust layout to make space for slider

# Initial value of k
k = 5.0
snap_init = 0

grids, wfcs = read_DVR_wp(1, 0, k, snap_init)
lines = [ax.plot(grids, wfcs[ist], color=colors[clrs_index[ist]], label=f"WP {ist}")[0] for ist in range(3)]

# Add a legend
ax.legend(frameon=False)
ax.set_xlim([-50,50])   

title = ax.text(0.02,0.90, "", fontsize=20, transform=ax.transAxes)

# Update function for animation
def update(frame):
    _snap = frame
    grids, wfcs = read_DVR_wp(0, 0, k, _snap)
    
    # Update each line in the plot with new data
    for ist, line in enumerate(lines):
        line.set_ydata(wfcs[ist])
    
    ax.set_xlabel(r"$R$ (Bohr)")
    ax.set_ylabel("nuclear density", x=3.1)
    ax.set_title(r"Superexchange, $\hbar k_{0}=$ 5.0 a.u.")
    
    title.set_text(rF"$t=$ {_snap*100*10/41:.1f} fs")
    
    ax.margins(x=0)
    return lines, title

# Create the animation
ani = FuncAnimation(fig, update, frames=range(41), interval=200, blit=True)

movie_name = F'DVR_superexchange_{k:.0f}'
ani.save(movie_name.replace(".","_") + ".gif", fps=80, dpi=300)