In [None]:
%matplotlib inline

import numpy as np
import scipy.sparse as sp
import h5py
from scipy.optimize import curve_fit
import matplotlib.pyplot as plt
from scipy.constants import physical_constants
from libra_py import units
from libra_py.workflows.nbra import lz, step4
from libra_py import data_outs, data_conv
import warnings

# Unit conversion factors taken from SciPy's CODATA table (entry [0] is the value).
# The atomic unit of time is tabulated in seconds, so it is scaled by 1e15 to give fs.
au2ev, au2fs = (
    physical_constants[entry][0] * scale
    for entry, scale in (
        ('hartree-electron volt relationship', 1.0),
        ('atomic unit of time', 1e15),
    )
)

def single_exp_fixed_E0(t, A, tau, E0):
    """Single-exponential decay that starts at E0 and levels off at A.

    Evaluates E(t) = (E0 - A) * exp(-t / tau) + A, so E(0) = E0 and
    E(t) approaches A as t grows (for tau > 0).

    Args:
        t: Time, scalar or NumPy array.
        A: Asymptotic value reached at long times.
        tau: Decay time constant, in the same units as ``t``.
        E0: Value at t = 0; callers keep it fixed while fitting A and tau.

    Returns:
        The model value(s) at ``t``, with the same shape as ``t``.
    """
    decay = np.exp(-t / tau)
    amplitude = E0 - A
    return amplitude * decay + A

def double_exp_func(t, A1, tau1, A2, tau2, c):
    """Sum of two exponential decays plus a constant offset.

    Evaluates A1 * exp(-t / tau1) + A2 * exp(-t / tau2) + c.

    Args:
        t: Time, scalar or NumPy array.
        A1, tau1: Amplitude and time constant of the first component.
        A2, tau2: Amplitude and time constant of the second component.
        c: Constant offset.

    Returns:
        The model value(s) at ``t``, with the same shape as ``t``.
    """
    first = A1 * np.exp(-t / tau1)
    second = A2 * np.exp(-t / tau2)
    return first + second + c

# Read the real part of every Hvib snapshot (file indices 1000..4993) and keep
# only its diagonal, giving one row per snapshot and one column per state.
hvib_re_pattern = '../../step3_data_9_10_2025/c20/res-mb-sd-c20/Hvib_ci_{}_re.npz'
energies = np.array([
    np.diag(sp.load_npz(hvib_re_pattern.format(step)).todense().real)
    for step in range(1000, 4994)
]) * au2ev   # eV
# Second name for the same array; the plotting code refers to it as e2.
e2 = energies

# Method labels; each one selects a set of run directories and gets its own panel.
methods = ['FSSH', 'IDA', 'MSDM', 'FSSH2', 'GFSH', 'MSDM_GFSH']

# 4x2 grid of panels: 6 methods + BLLZ
fig, axs = plt.subplots(nrows=4, ncols=2, figsize=(12, 18))
fig.subplots_adjust(hspace=0.5, wspace=0.5)

def _style_energy_axis(axis, method_name):
    """Apply the title, labels, tick size and 0-2000 fs x-range shared by all method panels."""
    axis.set_title(f'C$_{{20}}$ {method_name}', fontsize=20)
    axis.set_xlabel('Time (fs)', fontsize=16)
    axis.set_ylabel('Excess Energy (eV)', fontsize=16)
    axis.tick_params(axis='both', which='major', labelsize=16)
    axis.set_xlim(0, 2000)
    axis.set_xticks([0, 500, 1000, 1500, 2000])


def _mean_and_ci(samples, z_score):
    """Return (mean, z_score * standard error) of *samples*, ignoring NaN entries."""
    samples = np.asarray(samples, dtype=float)
    # Count only the entries that nanmean/nanstd actually use, so the standard
    # error is not shrunk by NaN placeholders.
    n_valid = int(np.count_nonzero(~np.isnan(samples)))
    mean = np.nanmean(samples)
    # The sample standard deviation (ddof=1) needs at least two valid points.
    spread = np.nanstd(samples, ddof=1) if n_valid > 1 else 0.0
    return mean, z_score * spread / np.sqrt(max(1, n_valid))


N_STEPS = 3994        # rows taken from e2 and from each run's populations/time
ICOND_STRIDE = 200    # spacing between the initial conditions that are read
Z = 1.96              # z-score for the reported 95% confidence half-width

# One panel per method: overlay the population-weighted energy of every initial
# condition, fit each trace, and annotate the panel with the averaged time constants.
for c, method in enumerate(methods):
    row, col = divmod(c, 2)
    ax = axs[row, col]
    wants_double_exp = method in ('FSSH', 'FSSH2')

    taus_single = []         # tau from single-exp
    As_single = []           # A from single-exp
    E0_list = []             # initial E(0)
    tau1_list = []           # fast tau from double-exp (only stored for FSSH/FSSH2)
    avg_tau_list = []        # amplitude-weighted avg time from double-exp (for reference)

    for icond in range(1, N_STEPS, ICOND_STRIDE):
        with h5py.File(f'{method}_latestnewNBRA_icond_{icond}/mem_data.hdf', 'r') as F:
            sh_pop = np.array(F['sh_pop_adi/data'])
            time_fs = np.array(F['time/data'])[0:N_STEPS] * au2fs

        # Rotate the energy rows so that row `icond` lines up with step 0 of this
        # run (rows shifted off the front wrap around to the end).
        shifted_energies = np.roll(e2[0:N_STEPS, :], -icond, axis=0)
        # Population-weighted sum over states at every time step.
        tmp = np.sum(sh_pop[0:N_STEPS, :] * shifted_energies, axis=1)

        ax.plot(time_fs, tmp, color='gray', alpha=0.55)

        E0 = float(tmp[0])
        E_inf_guess = float(np.median(tmp[-200:]))
        tau_guess = 500.0

        # Bracket the plateau A generously (1 eV beyond the observed range) so
        # the initial guess is always feasible.
        lo, hi = float(np.min(tmp)), float(np.max(tmp))
        A_lower = min(lo, E0, E_inf_guess) - 1.0
        A_upper = max(hi, E0, E_inf_guess) + 1.0

        # Best-effort: a failed fit is reported and the initial condition skipped.
        try:
            popt_s, pcov_s = curve_fit(
                # E0 is pinned to the first data point; only A and tau are fitted.
                lambda t, A, tau: single_exp_fixed_E0(t, A, tau, E0),
                time_fs, tmp,
                p0=(E_inf_guess, tau_guess),
                bounds=([A_lower, 1e-6], [A_upper, np.inf]),
                maxfev=20000
            )
        except Exception as err:
            warnings.warn(f"{method} icond {icond}: single-exp fit failed ({err})")
            continue

        A_fit, tau_fit = popt_s
        taus_single.append(tau_fit)
        As_single.append(A_fit)
        E0_list.append(E0)

        # Double exponential fit (only to extract the fast tau for FSSH/FSSH2)
        if not wants_double_exp:
            continue

        amp_cap = max(hi, E0)
        p0_double = [E0 / 3, 100.0, E0 / 3, 1000.0, E0 / 3]
        lb_double = [0, 0, 0, 0, 0]
        ub_double = [amp_cap, np.inf, amp_cap, np.inf, amp_cap]
        try:
            popt_d, _ = curve_fit(
                double_exp_func, time_fs, tmp,
                p0=p0_double, bounds=(lb_double, ub_double),
                maxfev=20000
            )
        except Exception as err:
            warnings.warn(f"{method} icond {icond}: double-exp fit failed ({err})")
            continue

        A1, tau1, A2, tau2, c0 = popt_d
        # The two components are interchangeable to the optimizer, so it may
        # return them in either order; swap so (A1, tau1) is always the fast one.
        if tau1 > tau2:
            A1, tau1, A2, tau2 = A2, tau2, A1, tau1
        tau1_list.append(tau1)

        if (A1 + A2) > 0:
            avg_tau = (A1 * tau1 + A2 * tau2) / (A1 + A2)
        else:
            avg_tau = np.nan
        avg_tau_list.append(avg_tau)

    taus_single = np.array(taus_single, dtype=float)
    As_single = np.array(As_single, dtype=float)
    E0_list = np.array(E0_list, dtype=float)

    if taus_single.size == 0:
        ax.text(0.05, 0.9, "No successful fits", transform=ax.transAxes, fontsize=14, bbox=dict(facecolor='white', alpha=0.6))
        _style_energy_axis(ax, method)
        continue

    mean_tau, ci_tau = _mean_and_ci(taus_single, Z)

    # Representative curve built from the averaged fit parameters.
    E0_rep = np.nanmean(E0_list)
    A_rep = np.nanmean(As_single)
    tau_rep = mean_tau

    t_fit = np.linspace(0.0, min(2000.0, float(time_fs[-1])), 1000)
    fit_line = single_exp_fixed_E0(t_fit, A_rep, tau_rep, E0_rep)
    ax.plot(t_fit, fit_line, 'b--', linewidth=4, label='Single Exp Fit')

    txt = f'Single Exp τ = {int(round(mean_tau))} ± {int(round(ci_tau))} fs'
    if wants_double_exp and tau1_list:
        mean_tau1, ci_tau1 = _mean_and_ci(tau1_list, Z)
        txt += f'\n(Double Exp) τ₁ = {int(round(mean_tau1))} ± {int(round(ci_tau1))} fs'

    ax.text(0.03, 0.95, txt, transform=ax.transAxes, fontsize=16,
            va='top', bbox=dict(facecolor='white', alpha=0.6))

    _style_energy_axis(ax, method)

# BLLZ panel: bottom-left slot of the grid
bllz_ax = axs[3, 0]

# Settings passed to step4.get_Hvib_scipy to read the Hvib files.
# NOTE(review): this is an absolute path, while the energies loaded earlier in
# this file come from a relative path to a similarly named directory — confirm
# both point at the same data set.
params = {
    "data_set_paths": [
        "/projects/academic/alexeyak/kosar/cp2k/fullerenes/step3_data_9_10_2025/c20/res-mb-sd-c20/",
    ],
    "Hvib_re_prefix": "Hvib_ci_",
    "Hvib_re_suffix": "_re",
    "Hvib_im_prefix": "Hvib_ci_",
    "Hvib_im_suffix": "_im",
    "init_times": 1000,
    "nfiles": 3994,
    "nstates": 40,
    "active_space": list(range(40)),
}
Hvibs = step4.get_Hvib_scipy(params)

# Settings passed to lz.run (the name `params` is rebound on purpose; the
# reader settings above are no longer needed once Hvibs is built).
params = {
    "dt": 0.5 * units.fs2au,
    "ntraj": 1,
    "nsteps": 3994,
    "istate": 6,
    "Boltz_opt_BL": 1,
    "Boltz_opt": 1,
    "T": 300.0,
    "do_output": True,
    "outfile": "BL.txt",
    "do_return": True,
    "evolve_Markov": True,
    "evolve_TSH": False,
    "extend_md": False,
    "extend_md_time": 3994,
    "detect_SD_difference": False,
    "return_probabilities": False,
    "init_times": [0],
    "gap_min_exception": 0,
    "target_space": 0,
}

res, _ = lz.run(Hvibs, params)
# Convert the returned table to a NumPy float array for slicing below.
res0 = data_conv.MATRIX2nparray(res, _dtype=float)

def E_function(t, E0, Einf, tau):
    """Single-exponential decay from E0 at t = 0 toward Einf.

    Evaluates E(t) = (E0 - Einf) * exp(-t / tau) + Einf.

    Args:
        t: Time, scalar or NumPy array.
        E0: Value at t = 0.
        Einf: Asymptotic value at long times.
        tau: Decay time constant, in the same units as ``t``.

    Returns:
        The model value(s) at ``t``, with the same shape as ``t``.
    """
    drop = E0 - Einf
    decay = np.exp(-t / tau)
    return drop * decay + Einf

# Column 0 of the result table is converted to fs and column 124 to eV.
# NOTE(review): column 124 is presumably the SE-averaged energy for 40 states —
# confirm against the lz.run output layout.
t_bdata = units.au2fs * res0[:, 0]
E_SE_bdata = units.au2ev * res0[:, 124]
# NOTE(review): the fit below uses the unrounded times, while both curves are
# drawn against the rounded ones — confirm the rounding is intended.
t_bdata_rounded = np.round(t_bdata)

# Start from the first/last data points and tau = 100; only tau is bounded (> 0).
p0 = (E_SE_bdata[0], E_SE_bdata[-1], 100.0)
lower_bounds = [-np.inf, -np.inf, 1e-6]
upper_bounds = [np.inf, np.inf, np.inf]
popt, pcov = curve_fit(
    E_function,
    t_bdata,
    E_SE_bdata,
    p0=p0,
    bounds=(lower_bounds, upper_bounds),
)
E0_fit, Einf_fit, tau_fit = popt

# Data trace first, then the fitted curve on top of it.
bllz_ax.plot(
    t_bdata_rounded,
    E_SE_bdata,
    color="gray",
    label="BLLZ",
    linewidth=2,
)
fitted_curve = E_function(t_bdata_rounded, *popt)
bllz_ax.plot(
    t_bdata_rounded,
    fitted_curve,
    color="blue",
    linestyle="--",
    label=f"Single Exp Fit: τ={int(round(tau_fit))} fs",
    linewidth=4,
)

# Panel cosmetics.
bllz_ax.set_title('C$_{20}$ BLLZ', fontsize=20)
bllz_ax.set_xlabel('Time (fs)', fontsize=16)
bllz_ax.set_ylabel('Energy (eV)', fontsize=16)
bllz_ax.tick_params(axis='both', which='major', labelsize=16)
bllz_ax.set_xlim(0, 2000)
bllz_ax.set_xticks([0, 500, 1000, 1500, 2000])
bllz_ax.legend(loc='upper left', fontsize=16)

# The eighth grid slot has no data; remove it before laying out and saving.
fig.delaxes(axs[3, 1])

plt.tight_layout(rect=[0.05, 0, 1, 0.97])
plt.savefig('C20_singleExp.png', dpi=600, bbox_inches='tight')
plt.show()