In [None]:
import sys
import cmath
import math
import os
import h5py
import matplotlib.pyplot as plt   # plots
from matplotlib.ticker import MaxNLocator
import numpy as np
import time
import warnings

from liblibra_core import *
import util.libutil as comn
from libra_py import units
import models
import libra_py.dynamics.tsh.compute as tsh_dynamics
import libra_py.dynamics.tsh.plot as tsh_dynamics_plot
import libra_py.data_savers as data_savers

from recipes import fssh, fssh2, fssh3, gfsh

import libra_py.models.GLVC as GLVC


# Suppress all warnings so the plotting cells below produce clean output.
warnings.filterwarnings('ignore')

# Named palette keyed by "<family><shade>": family 1 = reds, 2 = greens/olive,
# 3 = blue-violet/dark blue, 4 = dark slate gray.
colors = {
    "11": "#8b1a0e",  # red
    "12": "#FF4500",  # orangered
    "13": "#B22222",  # firebrick
    "14": "#DC143C",  # crimson
    "21": "#5e9c36",  # green
    "22": "#006400",  # darkgreen
    "23": "#228B22",  # forestgreen
    "24": "#808000",  # olive
    "31": "#8A2BE2",  # blueviolet
    "32": "#00008B",  # darkblue
    "41": "#2F4F4F",  # darkslategray
}

# Draw order for the palette: step across families first, then deeper shades.
clrs_index = ["11", "21", "31", "41", "12", "22", "32", "13", "23", "14", "24"]


# Compare decoherence-time models (Schwartz 1/2, EDC, Gu-Franco) for MFSD, SDM, and DISH

## Schwartz 1

In [None]:
%matplotlib inline
colors = plt.cm.tab10.colors
colors = list(colors)
colors.append('darkblue')
A_vals = [0.0001, 0.001, 0.01, 0.1, 1.0, 10.0]
# 'MFSD_SCHW2_ntraj_100_iter_47_A_0.01'
labels = [f'$A_k$={A_val}' for A_val in A_vals]
plt.rcParams.update({'font.size': 35, 'axes.linewidth': 3, 
                     'xtick.major.width': 3, 'ytick.major.width': 3, 'lines.linewidth': 6.0})

plt.figure(figsize=(3.21*3*3, 2.41*3*3))
c = 1
for dec_method in ['MFSD','SDM','DISH_REV23']:
    for method in ['FSSH','FSSH2','GFSH']:
        plt.subplot(3,3,c)
        #plt.figure(figsize=(3.21*3, 2.41*3))
        print(dec_method, method)
        for k in range(len(A_vals)):
            #print(k)
            #file = f'main/MFSD_ntraj_100_iter_{i}_A_{A_vals[k]}/mem_data.hdf'
            #if dec_method=='MFSD':
            #    file = f'all_methods/{method.lower()}_{dec_method}_SCHW1_ntraj_2000_iter_0_dt_20.0_A_{A_vals[k]}/mem_data.hdf'
            #else:
            file = f'all_methods/{method.lower()}_{dec_method}_SCHW1_ntraj_2000_iter_0_dt_20.0_A_{A_vals[k]}/mem_data.hdf'
            F = h5py.File(file)
            sh_pop = np.array(F['sh_pop_adi/data'])
            time_vec = F['time/data'][:]*units.au2fs/1000
            F.close()
            #if k==2:
            #    plt.plot(time_vec, np.average(sh_pops, axis=0)[:,1], label=labels[k], color=colors[k], ls='dashed')
            #else:
            plt.plot(time_vec, sh_pop[:,1], label=labels[k], color=colors[k])

        ml_mctdh = np.loadtxt('reference/ML-MCTDH_n32.dat')
        plt.plot(ml_mctdh[:,0]/1000, ml_mctdh[:,1], label='Ref:ML-MCTDH', ls='dashed', color='black')
        plt.xlim(0,50)
        if c==1:
            plt.legend(fontsize=25, ncol=1, loc='upper right')
        plt.ylabel('S$_1$ Population')
        plt.xlabel('Time, ps')
        name = dec_method.replace('_REV23','')
        if method=='FSSH2':
            plt.title(f'FSSH-2 with {name}', fontsize=35)
        elif dec_method=='MFSD':
            plt.title(f'MFSD', fontsize=35)
        else:
            plt.title(f'{method} with {name}', fontsize=35)
        c += 1
plt.suptitle('Schwartz 1 - No SSY', fontsize=45)
plt.tight_layout()
#plt.savefig(f'{name}_{method}_Schw1.jpg', dpi=600)
plt.savefig(f'Schw1_no_SSY.jpg', dpi=600)

## Schwartz 2

In [None]:
%matplotlib inline
colors = plt.cm.tab10.colors
colors = list(colors)
colors.append('darkblue')
A_vals = [0.0001, 0.001, 0.01, 0.1, 1.0, 10.0]
# 'MFSD_SCHW2_ntraj_100_iter_47_A_0.01'
labels = [f'$A_k$={A_val}' for A_val in A_vals]
plt.rcParams.update({'font.size': 35, 'axes.linewidth': 3, 
                     'xtick.major.width': 3, 'ytick.major.width': 3, 'lines.linewidth': 6.0})

plt.figure(figsize=(3.21*3*3, 2.41*3*3))
c = 1
for dec_method in ['MFSD','SDM','DISH_REV23']:
    for method in ['FSSH','FSSH2','GFSH']:
        plt.subplot(3,3,c)
        #plt.figure(figsize=(3.21*3, 2.41*3))
        print(dec_method, method)
        for k in range(len(A_vals)):
            try:
                #print(k)
                #file = f'main/MFSD_ntraj_100_iter_{i}_A_{A_vals[k]}/mem_data.hdf'
                #if dec_method=='MFSD':
                #    file = f'all_methods/{method.lower()}_{dec_method}_SCHW1_ntraj_2000_iter_0_dt_20.0_A_{A_vals[k]}/mem_data.hdf'
                #else:
                file = f'all_methods/{method.lower()}_{dec_method}_SCHW2_ntraj_2000_iter_0_dt_20.0_A_{A_vals[k]}/mem_data.hdf'
                F = h5py.File(file)
                sh_pop = np.array(F['sh_pop_adi/data'])
                time_vec = F['time/data'][:]*units.au2fs/1000
                F.close()
                #if k==2:
                #    plt.plot(time_vec, np.average(sh_pops, axis=0)[:,1], label=labels[k], color=colors[k], ls='dashed')
                #else:
                plt.plot(time_vec, sh_pop[:,1], label=labels[k], color=colors[k])
            except:
                print(file)

        ml_mctdh = np.loadtxt('reference/ML-MCTDH_n32.dat')
        plt.plot(ml_mctdh[:,0]/1000, ml_mctdh[:,1], label='Ref:ML-MCTDH', ls='dashed', color='black')
        plt.xlim(0,50)
        if c==1:
            plt.legend(fontsize=25, ncol=1, loc='upper right')
        plt.ylabel('S$_1$ Population')
        plt.xlabel('Time, ps')
        name = dec_method.replace('_REV23','')
        if method=='FSSH2':
            plt.title(f'FSSH-2 with {name}', fontsize=35)
        elif dec_method=='MFSD':
            plt.title(f'MFSD', fontsize=35)
        else:
            plt.title(f'{method} with {name}', fontsize=35)
        c += 1
plt.suptitle('Schwartz 2 - No SSY', fontsize=45)
plt.tight_layout()
#plt.savefig(f'{name}_{method}_Schw1.jpg', dpi=600)
plt.savefig(f'Schw2_no_SSY.jpg', dpi=600)

## EDC

In [None]:
%matplotlib inline
colors = plt.cm.tab10.colors
colors = list(colors)
colors.append('darkblue')
eps_vals = [0.01, 0.05, 0.1, 0.2, 0.4, 1.0, 5.0, 10.0, 20.0, 40.0, 80.0]
# 'MFSD_SCHW2_ntraj_100_iter_47_A_0.01'

labels = [f'$\\epsilon$={eps_val}' for eps_val in eps_vals]
plt.rcParams.update({'font.size': 35, 'axes.linewidth': 3, 
                     'xtick.major.width': 3, 'ytick.major.width': 3, 'lines.linewidth': 6.0})
plt.figure(figsize=(3.21*3*3, 2.41*3*3))
c = 1
for dec_method in ['MFSD','SDM','DISH_REV23']:
    for method in ['FSSH','FSSH2','GFSH']:
        print(dec_method, method)
        plt.subplot(3,3,c)
        #plt.figure(figsize=(3.21*3, 2.41*3))
        for k in range(len(eps_vals)):    
            file = f'all_methods/{method.lower()}_{dec_method}_EDC_ntraj_2000_iter_0_dt_20.0_eps_param_{eps_vals[k]}/mem_data.hdf'
            F = h5py.File(file)
            sh_pop = np.array(F['sh_pop_adi/data'])
            time_vec = F['time/data'][:]*units.au2fs/1000
            F.close()
            plt.plot(time_vec, sh_pop[:,1], label=labels[k], color=colors[k])
        ml_mctdh = np.loadtxt('reference/ML-MCTDH_n32.dat')
        plt.plot(ml_mctdh[:,0]/1000, ml_mctdh[:,1], label='Ref:ML-MCTDH', ls='dashed', color='black') 
        plt.xlim(0,50) 
        if c==1:
            plt.legend(fontsize=25, ncol=2, loc='upper right')
        plt.ylabel('Population')
        plt.xlabel('Time, ps')
        name = dec_method.replace('_REV23','')
        if method=='FSSH2':
            plt.title(f'FSSH-2 with {name}', fontsize=35)
        elif dec_method=='MFSD':
            plt.title(f'MFSD', fontsize=35)
        else:
            plt.title(f'{method} with {name}', fontsize=35)
        c += 1
plt.suptitle('EDC - No SSY', fontsize=45)
plt.tight_layout()
# plt.savefig(f'{name}_{method}_EDC.jpg', dpi=600)
plt.savefig(f'EDC_no_SSY.jpg', dpi=600)

## Gu-Franco decoherence time

In [None]:
%matplotlib inline
colors = plt.cm.tab10.colors
colors = list(colors)
colors.append('darkblue')
plt.rcParams.update({'font.size': 35, 'axes.linewidth': 3, 
                     'xtick.major.width': 3, 'ytick.major.width': 3, 'lines.linewidth': 6.0})
plt.figure(figsize=(3.21*3*3, 2.41*3*1))
c = 1
for dec_method in ['MFSD','SDM','DISH_REV23']:
    plt.subplot(1,3,c)
    if dec_method=='MFSD':
        print(dec_method)
        file = f'all_methods/fssh_MFSD_Gu_Franco_ntraj_2000_iter_0_dt_20.0_temperature_300.0/mem_data.hdf'
        F = h5py.File(file)
        sh_pop = np.array(F['sh_pop_adi/data'])
        print(sh_pop[:,1])
        time_vec = F['time/data'][:]*units.au2fs/1000
        F.close()
        plt.plot(time_vec, sh_pop[:,1], label='MFSD')
        plt.xlim(0,50) 
        plt.ylabel('Population')
        plt.xlabel('Time, ps')
        plt.title(f'MFSD', fontsize=35)
    else:
        for method in ['FSSH','FSSH2','GFSH']:
            print(dec_method, method)
            file = f'all_methods/{method.lower()}_{dec_method}_Gu_Franco_ntraj_2000_iter_0_dt_20.0_temperature_300.0/mem_data.hdf'
            F = h5py.File(file)
            sh_pop = np.array(F['sh_pop_adi/data'])
            time_vec = F['time/data'][:]*units.au2fs/1000
            F.close()
            if method=='FSSH2':
                label = 'FSSH-2'
            else:
                label = method
            plt.plot(time_vec, sh_pop[:,1], label=label)
            plt.xlim(0,50) 
            plt.ylabel('Population')
            plt.xlabel('Time, ps')
            name = dec_method.replace('_REV23','')
            plt.title(f'{name}', fontsize=35)
    c += 1
    ml_mctdh = np.loadtxt('reference/ML-MCTDH_n32.dat')
    plt.plot(ml_mctdh[:,0]/1000, ml_mctdh[:,1], label='Ref:ML-MCTDH', ls='dashed', color='black') 
    plt.legend(fontsize=25, ncol=1, loc='upper right')
plt.suptitle('Gu-Franco - No SSY', fontsize=45)
plt.tight_layout()
# plt.savefig(f'{name}_{method}_EDC.jpg', dpi=600)
plt.savefig(f'Gu_Franco_no_SSY.jpg', dpi=600)

# SHXF results with different wavepacket widths

In [None]:
%matplotlib inline
plt.rcParams.update({'font.size': 40, 'axes.linewidth': 3, 'lines.linewidth': 6.0})
plt.figure(figsize=(3.21*9,2.41*4))
#colors1 = []
colors = plt.cm.tab20.colors
F = h5py.File('FSSH/time.hdf')
time_vec = F['time/data'][:]*units.au2fs/1000
F.close()
# for i in [2,5,6,0,4,3,1,7,8]:
#     print(i)
#     colors1.append(colors[i])
main_fig_labels = []
main_fig_errors = []
folders = ['FSSH','FSSH2','GFSH']
for c, method in enumerate(['FSSH','FSSH2','GFSH']):
    print(method)
    plt.subplot(1,3,c+1) 
    #for wpwidth_scale in [0.05, 0.1, 0.2, 0.3, 0.4, 0.5, 1.0, 2.0, 4.0]:
    for wpwidth_scale in [0.0001, 0.001, 0.01, 0.02]:#, 0.03, 0.04, 0.05, 0.1]:
        file = f'all_methods/{method.lower()}_SHXF_ntraj_2000_iter_0_dt_20.0_wpwidth_scale_{wpwidth_scale}/mem_data.hdf'
        F = h5py.File(file)
        sh_pop = np.array(F['sh_pop_adi/data'])
        time_vec = np.array(F['time/data'])*units.au2fs/1000
        F.close()
        plt.plot(time_vec, sh_pop[:,1], label=f'{wpwidth_scale}$\\sigma_q$')
        
    ml_mctdh = np.loadtxt('reference/ML-MCTDH_n32.dat')
    plt.plot(ml_mctdh[:,0]/1000, ml_mctdh[:,1], label='ML-MCTDH', ls='dashed', color='black')
    plt.xlim(0,50)
    if c==0:
        plt.legend(fontsize=25, ncol=2, loc='upper right')
    plt.ylabel('S$_1$ Population')
    plt.xlabel('Time, ps')
    if method=='FSSH2':
        plt.title('FSSH-2')
    else:    
        plt.title(method)
plt.suptitle('No SSY')
plt.tight_layout()
#plt.savefig(f'shxf_dyn_no_ssy.jpg',dpi=600)
plt.savefig(f'shxf_dyn_no_ssy_2.jpg',dpi=600)

In [None]:
# For every simulation time step, find the index of the closest ML-MCTDH time.
# NOTE(review): relies on `sh_pop` and `ml_mctdh` left over from the previous cell;
# only the number of saved steps, sh_pop.shape[0], is used here. The factor 20 is
# presumably the time step in a.u. (dt_20.0 in the result folders), giving fs.
simulation_time_vec = np.arange(sh_pop.shape[0])*units.au2fs*20
reference_time_vec = ml_mctdh[:,0]
print(simulation_time_vec)
comparing_indices = [np.argmin(np.abs(reference_time_vec - t_sim))
                     for t_sim in simulation_time_vec]
print(len(comparing_indices))

In [None]:
# Keep only the first 103349 matched indices; the error computation compares
# exactly that many simulation steps (sh_pop[0:103349]).
n_compare = 103349
comparing_indices = comparing_indices[:n_compare]
print(ml_mctdh.shape)
ml_mctdh[comparing_indices,1]  # reference S1 population at the matched times

# Overall error measurements (main figure)

In [None]:
%matplotlib inline
plt.rcParams.update({'font.size': 35, 'axes.linewidth': 3, 'lines.linewidth': 3.0})
plt.figure(figsize=(3.21*6,2.41*4))
#colors1 = []
colors = plt.cm.tab20.colors
# for i in [2,5,6,0,4,3,1,7,8]:
#     print(i)
#     colors1.append(colors[i])
ml_mctdh = np.loadtxt('reference/ML-MCTDH_n32.dat')
main_fig_labels = []
main_fig_errors = []
folders = ['FSSH','FSSH2_new','GFSH_orig']
for c1, method in enumerate(['FSSH','FSSH2','GFSH']):
    print(method)
    for dec_method in ['','BCSH','SHXF','ID-A','MFSD','SDM','DISH']:
        errors = []
        #for i in range(0,50):
        if dec_method=='MFSD':
            file = f'all_methods/fssh2_MFSD_SCHW2_ntraj_2000_iter_0_dt_20.0_A_0.001/mem_data.hdf'
        elif dec_method=='':
            file = f'all_methods/fssh__ntraj_2000_iter_0_dt_20.0/mem_data.hdf'
        elif dec_method=='SDM':
            file = f'all_methods/{method.lower()}_SDM_SCHW1_ntraj_2000_iter_0_dt_20.0_A_0.001/mem_data.hdf'
        elif dec_method=='DISH':
            file = f'all_methods/{method.lower()}_DISH_REV23_SCHW1_ntraj_2000_iter_0_dt_20.0_A_0.001/mem_data.hdf'
        elif dec_method=='SHXF':
            if method=='GFSH':
                file = f'all_methods/{method.lower()}_SHXF_ntraj_2000_iter_0_dt_20.0_wpwidth_scale_0.1/mem_data.hdf'
            elif method=='FSSH':
                file = f'all_methods/{method.lower()}_SHXF_ntraj_2000_iter_0_dt_20.0_wpwidth_scale_0.05/mem_data.hdf'
            elif method=='FSSH2':
                file = f'all_methods/{method.lower()}_SHXF_ntraj_2000_iter_0_dt_20.0_wpwidth_scale_0.1/mem_data.hdf'
        else:
            file = f'all_methods/{method.lower()}_{dec_method}_ntraj_2000_iter_0_dt_20.0/mem_data.hdf' 
        F = h5py.File(file)
        sh_pop = np.array(F['sh_pop_adi/data'])
        F.close()
        #print(sh_pop.shape)
        #error1 = np.average(np.abs(mctdh[0:420000,1]-sh_pop[0:2100000:5,1]))
        error = np.average(np.abs(ml_mctdh[comparing_indices,1]-sh_pop[0:103349,1]))
        #errors.append(error)
        main_fig_errors.append(error)
        if dec_method=='':
            main_fig_labels.append(method)
        elif dec_method=='MFSD' and method=='FSSH2':
            main_fig_labels.append(dec_method)
        else:
            main_fig_labels.append(method+'-'+dec_method)
print(main_fig_labels)
main_fig_labels = np.array(main_fig_labels)
main_fig_errors = np.array(main_fig_errors)
mask = (main_fig_labels != 'FSSH-MFSD') & (main_fig_labels != 'GFSH-MFSD')
filtered_labels = main_fig_labels[mask]
filtered_errors = main_fig_errors[mask]
sorted_indices = np.argsort(filtered_errors)
sorted_labels = filtered_labels[sorted_indices]
sorted_errors = filtered_errors[sorted_indices]
plt.bar(sorted_labels, sorted_errors, color=colors)
# indices_1 = np.argsort(main_fig_errors)
# plt.bar(main_fig_labels[indices_1], main_fig_errors[indices_1], color=colors)
plt.ylabel('$\\epsilon_{pop}$')
plt.tick_params(axis='y', width=3)
plt.tick_params(axis='y', which='minor', width=1.5, length=4)
plt.tick_params(axis='y', which='major', width=1.5, length=7)
plt.yscale('log')
plt.xticks(rotation=90)  # or 90 for vertical
for i, value in enumerate(sorted_errors):
    if i>=16:
        plt.text(i, value - 0.5, f'{value:.4f}', ha='center', va='bottom', fontsize=30, rotation=90)
    else:
        plt.text(i, value + 0.002, f'{value:.4f}', ha='center', va='bottom', fontsize=30, rotation=90)
plt.title('No SSY')
plt.tight_layout()
plt.savefig(f'main_fig_error_no_ssy.jpg',dpi=600)

In [None]:
# List the (label, error) pairs behind the bar chart, smallest error first.
for bar_label, bar_error in zip(sorted_labels, sorted_errors):
    print(bar_label, bar_error)

In [None]:
%matplotlib inline
plt.rcParams.update({'font.size': 40, 'axes.linewidth': 3, 'lines.linewidth': 6.0})
plt.figure(figsize=(3.21*9,2.41*4))
#colors1 = []
colors = plt.cm.tab20.colors
# for i in [2,5,6,0,4,3,1,7,8]:
#     print(i)
#     colors1.append(colors[i])
main_fig_labels = []
main_fig_errors = []
for c1, method in enumerate(['FSSH','FSSH2','GFSH']):
    print(method)
    plt.subplot(1,3,c1+1) 
    c = 0
    for dec_method in ['','BCSH','SHXF','ID-A','MFSD','SDM','DISH']:
        if dec_method=='MFSD':
            file = f'all_methods/fssh_MFSD_SCHW2_ntraj_2000_iter_0_dt_20.0_A_0.001/mem_data.hdf'
        elif dec_method=='':
            file = f'all_methods/fssh__ntraj_2000_iter_0_dt_20.0/mem_data.hdf'
        elif dec_method=='SDM':
            file = f'all_methods/{method.lower()}_SDM_SCHW1_ntraj_2000_iter_0_dt_20.0_A_0.001/mem_data.hdf'
        elif dec_method=='DISH':
            file = f'all_methods/{method.lower()}_DISH_REV23_SCHW1_ntraj_2000_iter_0_dt_20.0_A_0.001/mem_data.hdf'
        elif dec_method=='SHXF':
            if method=='GFSH':
                file = f'all_methods/{method.lower()}_SHXF_ntraj_2000_iter_0_dt_20.0_wpwidth_scale_0.1/mem_data.hdf'
            elif method=='FSSH':
                file = f'all_methods/{method.lower()}_SHXF_ntraj_2000_iter_0_dt_20.0_wpwidth_scale_0.05/mem_data.hdf'
            elif method=='FSSH2':
                file = f'all_methods/{method.lower()}_SHXF_ntraj_2000_iter_0_dt_20.0_wpwidth_scale_0.1/mem_data.hdf'
        else:
            file = f'all_methods/{method.lower()}_{dec_method}_ntraj_2000_iter_0_dt_20.0/mem_data.hdf' 
        #print(file)
        F = h5py.File(file)
        sh_pop = np.array(F['sh_pop_adi/data'])
        #print(sh_pop[:,1])
        time_vec = np.array(F['time/data'])*units.au2fs/1000
        F.close()
        
        if dec_method=='':
            label=method
            ls = '-'
        elif dec_method=='MFSD':
            label='MFSD'
            ls = '-.'
        else:
            label=dec_method
            ls = '-'
        plt.plot(time_vec, sh_pop[:,1], label=label, ls=ls)
        c += 1
    ml_mctdh = np.loadtxt('reference/ML-MCTDH_n32.dat')
    plt.plot(ml_mctdh[:,0]/1000, ml_mctdh[:,1], label='ML-MCTDH', ls='dashed', color='black')
    plt.xlim(0,50)
    plt.legend(fontsize=25, ncol=2, loc='upper right')
    plt.ylabel('S$_1$ Population')
    plt.xlabel('Time, ps')
    if method=='FSSH2':
        plt.title(f'FSSH-2')
    else:
        plt.title(method) 
plt.suptitle('No SSY')
plt.tight_layout()
plt.savefig(f'main_fig_dyn_no_ssy.jpg',dpi=600)

# Time step convergence

In [None]:
%matplotlib inline
plt.rcParams.update({'font.size': 40, 'axes.linewidth': 3, 'lines.linewidth': 10.0})
plt.figure(figsize=(3.21*6,2.41*6))
colors = plt.cm.tab20.colors
for c, dt in enumerate([40.0, 20.0, 10.0]):
    print(dt)
    file = f'FSSH/FSSH_ntraj_2000_iter_0_dt_{dt}/mem_data.hdf'
    F = h5py.File(file)
    sh_pop = np.array(F['sh_pop_adi/data'])
    F.close()
    time_vec = np.arange(0,sh_pop.shape[0])*dt*units.au2fs/1000
    #plt.plot(time_vec, np.average(sh_pops, axis=0)[:,1], label=f'{method}')
    plt.plot(time_vec, sh_pop[:,1], label=f'dt={dt}')
ml_mctdh = np.loadtxt('reference/ML-MCTDH_n32.dat')
# plt.plot(ml_mctdh[:,0]/1000, ml_mctdh[:,1], label='ML-MCTDH', ls='dashed', color='black')
plt.xlim(0,50)
plt.legend(fontsize=30, ncol=2, loc='upper right')
plt.ylabel('S$_1$ Population')
plt.xlabel('Time, ps')
# plt.title('Surfce hopping methods')
plt.suptitle('FSSH, No SSY')
plt.tight_layout()
# plt.savefig(f'comparing_sh_no_dec_dyn_no_ssy_.jpg',dpi=600)