In [None]:
import io
import json
from collections import defaultdict

import matplotlib.pyplot as plt
import numpy as np
import qcelemental as qcel
from cinnabar.stats import bootstrap_statistic
from matplotlib import rcParams
from matplotlib.backends.backend_pdf import PdfPages
from openff.toolkit import Molecule
from openff.units import unit
from PIL import Image
from pint import UnitRegistry
from qcelemental.models import AtomicInput
from qcelemental.models.common_models import Model
from qcengine import compute
from tabulate import tabulate

from visualization import show_oemol_struc

In [None]:
def get_relative_energies(energies, conv_factor=1.0):
    """Return energies relative to their minimum, scaled by ``conv_factor``.

    Parameters
    ----------
    energies : sequence of numbers
        Absolute energies; must be non-empty because ``min`` is taken over it.
    conv_factor : float, optional
        Multiplier applied to every shifted value (e.g. a unit conversion).

    Returns
    -------
    list
        ``conv_factor * (e - min(energies))`` for each ``e``, in input order.
    """
    baseline = min(energies)
    shifted = []
    for value in energies:
        shifted.append(conv_factor * (value - baseline))
    return shifted

def return_energy(qcmol, method, program, basis, memory=96, ncores=32):
    """Run a single-point energy calculation through QCEngine.

    Parameters
    ----------
    qcmol
        Molecule passed as ``AtomicInput.molecule``.
    method, basis : str
        Level of theory forwarded to ``Model``.
    program : str
        Name of the QCEngine program that performs the calculation.
    memory : int or float, optional
        Value sent as ``local_options["memory"]`` (QCEngine documents this
        as GiB -- confirm for the installed version). Defaults to 96.
    ncores : int, optional
        Value sent as ``local_options["ncores"]``. Defaults to 32.

    Returns
    -------
    The successful result object returned by ``qcengine.compute``.

    Raises
    ------
    RuntimeError
        If the returned object does not report ``success``; without this the
        caller would only fail later when reading ``result.properties``.
    """
    # NOTE(review): the keywords dict is attached to Model rather than to
    # AtomicInput; confirm that it reaches the program as intended.
    model = Model(
        method=method,
        basis=basis,
        keywords={'fix_orientation': True, 'fix_com': True},
    )
    atomic_input = AtomicInput(molecule=qcmol, driver="energy", model=model)
    result = compute(
        input_data=atomic_input,
        program=program,
        local_options={"memory": memory, "ncores": ncores},
    )
    if not getattr(result, "success", False):
        raise RuntimeError(
            f"{program} {method}/{basis} single point failed: "
            f"{getattr(result, 'error', result)}"
        )
    return result

# Use a 12 pt base font for every matplotlib figure created afterwards.
rcParams.update({"font.size": 12})

# Reference specification; '/' is replaced by '_' to build the data file name.
REF_SPEC = 'mp2/heavy-aug-cc-pv[tq]z + d:ccsd(t)/heavy-aug-cc-pvdz'
ref_path = './data/' + REF_SPEC.replace('/', '_') + '_single_points_data.json'
with open(ref_path, 'r') as ref_handle:
    ref_dict = json.load(ref_handle)

# Torsion-drive data (geometries and metadata) used by the evaluation loop.
with open('./data/MP2_heavy-aug-cc-pVTZ_torsiondrive_data.json', 'r') as mp2_handle:
    mp2_data = json.load(mp2_handle)

In [None]:
# Pint registry shared by the unit conversions below.
ureg = UnitRegistry()

# Magnitude of one hartree expressed in kilocalorie / (avogadro_constant * mole),
# used as the multiplicative factor for relative energies.
hartree = 1 * ureg.hartree
kcal_per_particle = ureg.kilocalorie / (ureg.avogadro_constant * ureg.mole)
HARTREE_TO_KCALMOL = hartree.to(kcal_per_particle).magnitude

# One kilojoule converted to kilocalories; kept as a pint quantity because
# no ``.magnitude`` is taken here.
kj = 1 * ureg.kilojoule
kj_to_kcal = kj.to(ureg.kilocalorie)

In [None]:
# Level of theory evaluated in this notebook and the program that runs it.
method = 'b97-d3bj'
program = 'psi4'
basis = 'def2-tzvp'

# Multi-page PDF that receives one figure per molecule from the evaluation loop.
pdf = PdfPages(f"./output/torsion_profiles_{method}_{basis}.pdf")

In [None]:
# Per-molecule error metrics in kcal/mol, keyed by molecule index.
rmse = defaultdict(float)
mae = defaultdict(float)
# Pooled relative energies (kcal/mol) over all molecules, plus the
# neutral/charged subsets, for the evaluated method and for the reference.
all_energies = []
neutral_energies = []
charged_energies = []
neutral_ref = []
charged_ref = []
all_ref = []
# Raw per-molecule results (angles, energies, dipoles), keyed by molecule index.
energies_and_dipoles = defaultdict(dict)

# NOTE(review): dataset size and scan length are hard-coded; confirm that they
# match the number of entries / grid points in the loaded JSON files.
N_MOLECULES = 59
N_GRID_POINTS = 24
# Scale applied to stored coordinates before they are tagged as angstrom
# (0.529177 is the bohr radius in angstrom; presumably the stored geometries
# are in bohr -- TODO confirm).
BOHR_TO_ANGSTROM = 0.529177

for i in range(N_MOLECULES):
    print(i, flush=True)
    ref_entry = ref_dict[str(i)]
    mp2_entry = mp2_data[str(i)]
    metadata = mp2_entry['metadata']

    ref_energies = np.array(get_relative_energies(ref_entry['total energies'], HARTREE_TO_KCALMOL))
    ref_angles = ref_entry['angles']
    mapped_smiles = metadata['mapped_smiles']
    dihedrals = metadata['dihedral scanned'][0]
    mol_charge = metadata['mol_charge']
    offmol = Molecule.from_mapped_smiles(mapped_smiles, allow_undefined_stereo=True)
    energies_and_dipoles[i]['angles'] = ref_angles

    # A molecule counts as neutral only if its net charge is zero AND the
    # SMILES carries no '+' (i.e. no positively charged atom is written).
    flag_neutral = mol_charge == 0 and '+' not in mapped_smiles

    all_ref.extend(ref_energies)
    if flag_neutral:
        neutral_ref.extend(ref_energies)
    else:
        charged_ref.extend(ref_energies)

    # Single points on every stored geometry of the scan.
    method_energies = []
    disp_energies = []
    func_energies = []
    dipoles = []
    for j in range(N_GRID_POINTS):
        positions = np.array(mp2_entry['final_geometries'][j]) * BOHR_TO_ANGSTROM
        offmol._conformers = [positions * unit.angstrom]
        # Rebuild the QC molecule with a fixed frame so that the geometry is
        # not re-centred or re-oriented.
        qc_dict = offmol.to_qcschema().dict()
        qc_dict.update({'fix_com': True, 'fix_orientation': True})
        qcmol = qcel.models.Molecule.from_data(qc_dict, validate=True)
        result = return_energy(qcmol, method=method, program=program, basis=basis)
        props = result.properties
        method_energies.append(props.return_energy)
        disp_energies.append(props.scf_dispersion_correction_energy)
        # Functional-only energy: nuclear repulsion + one-electron +
        # two-electron + exchange-correlation terms.
        func_energies.append(
            props.nuclear_repulsion_energy
            + props.scf_one_electron_energy
            + props.scf_two_electron_energy
            + props.scf_xc_energy
        )
        dipoles.append(list(props.scf_dipole_moment))
    energies_and_dipoles[i]['total energies'] = list(method_energies)
    energies_and_dipoles[i]["dispersion energies"] = list(disp_energies)
    energies_and_dipoles[i]['dft energies'] = list(func_energies)
    energies_and_dipoles[i]['dipoles'] = list(dipoles)
    energies = np.array(get_relative_energies(method_energies, HARTREE_TO_KCALMOL))

    all_energies.extend(energies)
    if flag_neutral:
        neutral_energies.extend(energies)
    else:
        charged_energies.extend(energies)

    deviation = energies - ref_energies
    rmse[i] = np.sqrt(np.mean(deviation ** 2))
    mae[i] = np.mean(np.abs(deviation))

    # Torsion profile: reference versus evaluated method.
    fig, ax = plt.subplots(figsize=[12, 10])
    ax.plot(
        ref_angles,
        ref_energies,
        "-D",
        label="REFERENCE",
        linewidth=3.0,
        c="k",
        markersize=10,
    )
    # NOTE(review): the legend label 'esp' does not name the evaluated
    # method/basis -- confirm whether it is intentional.
    ax.plot(
        ref_angles,
        energies,
        "-v",
        label='esp',
        linewidth=2.0,
        markersize=10,
    )
    ax.set_xlabel("Dihedral angle in degrees")
    ax.set_ylabel("Relative energies in kcal/mol")
    ax.legend(loc="lower left", bbox_to_anchor=(1.04, 0), fontsize=12)

    # Inset with the 2D depiction of the molecule and the scanned torsion.
    oemol = offmol.to_openeye()
    image = show_oemol_struc(
        oemol, torsions=True, atom_indices=dihedrals, width=600, height=500
    )
    im_arr = np.asarray(Image.open(io.BytesIO(image.data)))
    newax = fig.add_axes([0.9, 0.6, 0.35, 0.35], anchor="SW", zorder=-1)
    newax.imshow(im_arr)
    newax.axis("off")
    # The title goes on the inset axes: that is where the previous pyplot
    # call placed it, because add_axes makes the new axes current.
    newax.set_title('Neutral molecule' if flag_neutral else 'Charged molecule')

    # Save before showing, then close, so that the page is written from a live
    # figure and figures do not accumulate over the loop.
    pdf.savefig(fig, dpi=600, bbox_inches="tight")
    plt.show()
    plt.close(fig)

In [None]:
# Write the collected single-point results to JSON; '/' in the method name is
# replaced by '_' when building the file name.
results_path = './data/' + method.replace('/', '_') + '_' + basis + '_single_points_data.json'
with open(results_path, 'w') as outfile:
    json.dump(energies_and_dipoles, outfile)

In [None]:
# Seed NumPy's global RNG so the bootstrap confidence intervals are repeatable
# between runs (presumably bootstrap_statistic samples through np.random --
# confirm against the installed cinnabar version).
BOOTSTRAP_SEED = 0
np.random.seed(BOOTSTRAP_SEED)

table = []
all_ref = np.array(all_ref)
method_energies = np.array(all_energies)

# Statistics over every scan point of every molecule.
rmse_stats = bootstrap_statistic(y_true=all_ref, y_pred=method_energies, statistic='RMSE')
mue_stats = bootstrap_statistic(y_true=all_ref, y_pred=method_energies, statistic='MUE')

# Neutral / charged subsets; inputs are converted to arrays so every call
# receives the same container type.
neutral_stats = bootstrap_statistic(
    y_true=np.array(neutral_ref), y_pred=np.array(neutral_energies), statistic='RMSE'
)
charged_stats = bootstrap_statistic(
    y_true=np.array(charged_ref), y_pred=np.array(charged_energies), statistic='RMSE'
)

# Points whose reference relative energy is below 5 (kcal/mol) in magnitude.
lt_five_mask = np.abs(all_ref) < 5
lt_five_rmse_stats = bootstrap_statistic(
    y_true=all_ref[lt_five_mask], y_pred=method_energies[lt_five_mask], statistic='RMSE'
)
lt_five_mue_stats = bootstrap_statistic(
    y_true=all_ref[lt_five_mask], y_pred=method_energies[lt_five_mask], statistic='MUE'
)

# One row: specification followed by (mle, low, high) for each reported
# statistic. lt_five_mue_stats is computed but not part of the row.
table.append(
    [f'{method}/{basis}']
    + [
        "%.4f" % summary[bound]
        for summary in (rmse_stats, mue_stats, neutral_stats, charged_stats, lt_five_rmse_stats)
        for bound in ('mle', 'low', 'high')
    ]
)

In [None]:
# Finish writing the multi-page PDF of torsion profiles.
pdf.close()

# Column titles: the specification, then value / low / high for each statistic.
summary_headers = [
    "Specification",
    "RMSE in kcal/mol", "low 95% ci", "high 95% ci",
    "MAE in kcal/mol", "low 95% ci", "high 95% ci",
    "Neutral RMSE in kcal/mol", "low 95% ci", "high 95% ci",
    "Charged RMSE in kcal/mol", "low 95% ci", "high 95% ci",
    "lt_five_RMSE in kcal/mol", "low 95% ci", "high 95% ci",
]
summary_text = tabulate(table, headers=summary_headers, tablefmt="orgtbl")
print(summary_text)
print("* closer to zero the better")

In [None]:
# For latex table entries: each cell is the value with the low bound as
# subscript and the high bound as superscript; the triples start at columns
# 1 (RMSE), 4 (MUE) and 13 (lt_five RMSE) of every table row.
print("RMSE MUE TRR")
for row in table:
    cells = [
        f"${{{row[start]}}}_{{{row[start + 1]}}}^{{{row[start + 2]}}}$"
        for start in (1, 4, 13)
    ]
    print(f"{row[0].upper()} & " + " & ".join(cells) + "\\\\ \\vspace{2mm}")


In [None]:
# For latex table entries: overall, neutral and charged RMSE. Each cell is the
# value with the low bound as subscript and the high bound as superscript
# (columns start+1 and start+2 of the row hold 'low' and 'high' respectively),
# matching the layout used for the RMSE/MUE/TRR rows.
print("RMSE Neutral Charged")
for row in table:
    cells = [
        f"${{{row[start]}}}_{{{row[start + 1]}}}^{{{row[start + 2]}}}$"
        for start in (1, 7, 10)
    ]
    print(f"{row[0].upper()} & " + " & ".join(cells) + "\\\\ \\vspace{2mm}")