In [None]:
import os
from tqdm import tqdm
import pickle
import json
import numpy as np
import scipy as sp
import pandas as pd
from sklearn.metrics.pairwise import cosine_similarity, euclidean_distances
from scipy.optimize import linear_sum_assignment
from scipy.constants import c, pi
import matplotlib.pyplot as plt
import seaborn as sns
import ase, ase.io
from ase.units import kcal, mol, Ang, second, eV, invcm
import ase.units as units
from ase.data import atomic_masses
from ase.vibrations.data import VibrationsData
from ase.thermochemistry import IdealGasThermo
from typing import Any, Dict, Iterator, List, Sequence, Tuple, TypeVar, Union
# Reaction indices grouped by how many "ends" of each reaction match the
# training set (presumably an end is the reactant or the product — confirm).
#   rxns[0]: no match, rxns[1]: one end matches, rxns[2]: both ends match.
# Reactions in rxns[2] are skipped by the frequency parity plot in a later cell.
rxns = [
    # no match with training set
    [2, 5, 7, 16, 18, 21, 25, 26, 27, 29, 30, 31, 32, 34, 36, 38, 40, 41, 44, 47, 48, 49, 51, 52, 54, 55, 58, 62, 67, 68, 69, 70, 73, 74, 76, 78, 81, 82, 86, 87, 89, 93, 94, 95, 97, 98, 99, 101, 103, 105, 107, 110, 113, 114, 116, 117, 118, 122, 123, 125, 126, 132, 138, 141, 143, 145, 150, 152, 154, 157, 158, 160, 161, 162, 165, 167, 173, 174, 181, 182, 183, 187, 188, 191, 193, 194, 196, 197, 198, 209, 210, 211, 214, 218, 219, 221, 223, 226, 227, 231, 232, 234, 236, 237, 238, 240, 242, 243, 244, 248, 249, 250, 252, 254, 261, 262], 
    # 1-end match with training set
    [0, 1, 3, 4, 6, 8, 9, 10, 11, 12, 13, 14, 19, 20, 23, 24, 28, 33, 35, 37, 39, 42, 45, 46, 56, 57, 59, 60, 63, 64, 65, 66, 71, 77, 79, 83, 84, 85, 88, 90, 91, 92, 96, 100, 104, 106, 108, 109, 111, 112, 115, 119, 120, 121, 124, 127, 128, 129, 131, 133, 134, 135, 137, 139, 140, 146, 148, 149, 151, 153, 155, 156, 159, 163, 164, 166, 169, 170, 171, 172, 175, 176, 177, 178, 179, 180, 184, 185, 186, 189, 190, 192, 195, 199, 200, 201, 202, 203, 204, 205, 206, 207, 212, 213, 215, 220, 222, 224, 225, 229, 235, 239, 245, 246, 247, 251, 253, 255, 256, 257, 258, 259, 260, 263, 264], 
    # 2-end match with training set
    [15, 17, 22, 43, 50, 53, 61, 72, 75, 80, 102, 130, 136, 142, 144, 147, 168, 208, 216, 217, 228, 230, 233, 241],
    ]

class VibrationsData(VibrationsData):
    """ASE ``VibrationsData`` variant that projects rigid-body motion out of the Hessian.

    The parent class name is shadowed on purpose so that ``vib_from_output``
    (and any later cell) builds this variant. Re-running this cell subclasses
    the previous definition again; that is harmless because only
    ``_energies_and_modes`` is overridden.

    The constructor is inherited unchanged, so every positional/keyword form
    accepted by ASE's ``VibrationsData`` works here too.
    """

    def _energies_and_modes(self) -> Tuple[np.ndarray, np.ndarray]:
        """Diagonalise the projected, mass-weighted Hessian to obtain harmonic modes.

        Overrides the parent's internal implementation behind
        ``get_energies_and_modes()``. Before diagonalisation, three
        translations and three rotations (about the principal axes of inertia,
        with positions taken relative to the centre of mass) are projected out
        of the mass-weighted Hessian, so those six directions come out with
        (numerically) zero eigenvalues.

        Returns
        -------
        energies : np.ndarray
            Complex mode energies, in ascending order of the Hessian eigenvalue
            (``np.linalg.eigh`` ordering); negative eigenvalues give imaginary
            energies. Uses the same unit conversion as ASE's implementation.
        modes : np.ndarray
            Array of shape ``(3 * n_atoms, n_atoms, 3)``. Unlike ASE's
            implementation, the eigenvectors are NOT un-weighted by
            ``1 / sqrt(mass)`` — they stay mass-weighted.

        Raises
        ------
        ValueError
            If any active atom has zero mass.

        NOTE(review): six rigid-body vectors are always built, which assumes a
        nonlinear geometry. For a linear molecule one rotation vector is zero
        and the QR step may then project out a genuine vibration — TODO confirm.
        """
        active_atoms = self._atoms[self.get_mask()]
        n_atoms = len(active_atoms)
        masses = active_atoms.get_masses()

        if not np.all(masses):
            raise ValueError('Zero mass encountered in one or more of '
                             'the vibrated atoms. Use Atoms.set_masses()'
                             ' to set all masses to non-zero values.')
        # 1/sqrt(m) for every Cartesian coordinate (x, y, z per atom).
        mass_weights = np.repeat(masses**-0.5, 3)

        # Rigid-body displacement vectors in Cartesian coordinates.
        positions = active_atoms.get_positions() - active_atoms.get_center_of_mass()
        _, vectors_inertia = active_atoms.get_moments_of_inertia(vectors=True)
        vectors_transrot = np.zeros((6, n_atoms, 3))
        # Translations along x, y, z.
        vectors_transrot[0, :, 0] = 1
        vectors_transrot[1, :, 1] = 1
        vectors_transrot[2, :, 2] = 1
        # Rotations about principal axis a_k: (r·a_i) a_j - (r·a_j) a_i for the
        # other two axes (i, j), i.e. the in-plane rotation of each atom.
        vectors_transrot[3] = positions @ vectors_inertia[[1]].T @ vectors_inertia[[2]] - positions @ vectors_inertia[[2]].T @ vectors_inertia[[1]]
        vectors_transrot[4] = positions @ vectors_inertia[[2]].T @ vectors_inertia[[0]] - positions @ vectors_inertia[[0]].T @ vectors_inertia[[2]]
        vectors_transrot[5] = positions @ vectors_inertia[[0]].T @ vectors_inertia[[1]] - positions @ vectors_inertia[[1]].T @ vectors_inertia[[0]]
        vectors_transrot = vectors_transrot.reshape((6, n_atoms * 3))
        # Dividing by 1/sqrt(m) multiplies by sqrt(m): express the rigid-body
        # vectors in mass-weighted coordinates, matching the Hessian below.
        vectors_transrot = vectors_transrot / mass_weights
        # Orthonormalise the six vectors (rows after the transpose back).
        vectors_transrot, _ = np.linalg.qr(vectors_transrot.T)
        vectors_transrot = vectors_transrot.T
        # Projector onto the complement of the rigid-body subspace.
        proj = np.eye(n_atoms * 3) - vectors_transrot.T @ vectors_transrot

        # Mass-weighted Hessian H_ij / sqrt(m_i m_j), sandwiched by the projector.
        omega2, vectors = np.linalg.eigh(
            proj.T @ (
                mass_weights
                * self.get_hessian_2d()
                * mass_weights[:, np.newaxis])
            @ proj)

        # Same conversion factor ASE's VibrationsData uses; the complex cast
        # turns negative eigenvalues into imaginary energies.
        unit_conversion = units._hbar * units.m / np.sqrt(units._e * units._amu)
        energies = unit_conversion * omega2.astype(complex)**0.5

        modes = vectors.T.reshape(n_atoms * 3, n_atoms, 3)
        # Intentionally left mass-weighted (ASE would un-weight here):
        # modes = modes * masses[np.newaxis, :, np.newaxis]**-0.5

        return (energies, modes)

def thermo_from_output(output):
    """Build an ASE ``IdealGasThermo`` object from a saved calculation output.

    Parameters
    ----------
    output : dict
        Parsed output JSON. Must contain the keys read by ``vib_from_output``
        plus ``output['results']['energy']``, used as the potential energy.

    Returns
    -------
    IdealGasThermo
        Thermochemistry for a 'nonlinear' species with symmetry number 1 and
        spin 0; imaginary modes are ignored.
    """
    # vib_from_output returns (VibrationsData, Atoms): unpack it rather than
    # treating the tuple as the VibrationsData object (VibrationsData also has
    # no `.molecule` attribute, so the Atoms must come from the tuple).
    vib, molecule = vib_from_output(output)
    # NOTE(review): the energies come from the rigid-body-projected Hessian and
    # include ~zero modes; confirm which modes IdealGasThermo keeps here
    # (especially for transition states).
    return IdealGasThermo(
        vib.get_energies(),
        'nonlinear',
        potentialenergy=output['results']['energy'],
        atoms=molecule,
        symmetrynumber=1,
        spin=0,
        ignore_imag_modes=True,
    )
def vib_from_output(output):
    """Rebuild the structure and its ``VibrationsData`` from a saved output.

    Parameters
    ----------
    output : dict
        Parsed output JSON. Reads ``output['atoms']['atoms_json']`` (a JSON
        string containing ``positions`` and ``numbers`` entries) and
        ``output['results']['hessian']`` (reshaped here to
        ``(n_atoms, 3, n_atoms, 3)``).

    Returns
    -------
    tuple
        ``(vib, molecule)``: the ``VibrationsData`` built from the Hessian and
        the ``ase.Atoms`` object it describes.
    """
    # Decode the embedded atoms JSON once and read both fields from it.
    atoms_dict = json.loads(output['atoms']['atoms_json'])
    # The last element of each '__ndarray__' entry is taken as the flat data
    # (presumably ASE's [shape, dtype, data] encoding — confirm).
    positions = np.array(atoms_dict['positions']['__ndarray__'][-1]).reshape(-1, 3)
    numbers = atoms_dict['numbers']['__ndarray__'][-1]
    molecule = ase.Atoms(numbers, positions)
    n_atoms = len(molecule)
    hessian = np.array(output['results']['hessian']).reshape(n_atoms, 3, n_atoms, 3)
    vib = VibrationsData(atoms=molecule, hessian=hessian)
    return vib, molecule

In [None]:
# Reaction indices kept for each TS-optimization method, indexed in the same
# order as the methods iterated by the parity plot below ('nn1', then 'nn0').
# Presumably these are reactions whose optimized TS is "reactive" (connects
# the intended reactant and product) — confirm the selection criterion.
rxn_reactive = [
    [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 11, 12, 13, 14, 16, 18, 19, 20, 21, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 34, 35, 36, 38, 39, 41, 42, 44, 45, 46, 47, 48, 49, 51, 52, 56, 57, 58, 59, 63, 66, 68, 69, 70, 71, 73, 74, 77, 78, 79, 81, 82, 83, 85, 87, 88, 89, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 103, 104, 105, 106, 108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 132, 133, 134, 135, 138, 139, 140, 141, 143, 145, 146, 148, 150, 151, 152, 153, 154, 155, 156, 157, 159, 160, 161, 162, 163, 164, 167, 169, 170, 171, 172, 173, 174, 175, 176, 177, 178, 180, 181, 182, 183, 185, 186, 187, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 209, 210, 211, 213, 214, 215, 218, 219, 220, 221, 224, 225, 226, 227, 229, 231, 232, 234, 235, 236, 239, 240, 242, 243, 244, 246, 247, 248, 249, 250, 252, 254, 255, 256, 257, 258, 259, 260, 261, 262, 264],
    [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 11, 12, 13, 14, 16, 18, 19, 20, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 41, 42, 44, 45, 46, 47, 48, 49, 51, 52, 55, 56, 57, 58, 59, 60, 63, 65, 66, 68, 69, 70, 71, 73, 74, 77, 78, 79, 81, 82, 83, 85, 87, 88, 90, 91, 92, 93, 94, 95, 96, 97, 98, 99, 100, 101, 103, 104, 105, 106, 107, 108, 109, 110, 111, 112, 113, 114, 115, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 132, 133, 134, 135, 137, 138, 139, 140, 141, 145, 146, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 159, 160, 161, 162, 163, 164, 165, 166, 167, 169, 170, 171, 172, 173, 174, 175, 176, 178, 180, 181, 182, 183, 184, 185, 186, 187, 189, 190, 191, 192, 193, 194, 195, 196, 197, 198, 199, 200, 201, 202, 203, 204, 205, 206, 207, 209, 210, 211, 213, 214, 215, 218, 219, 220, 221, 222, 224, 225, 226, 227, 229, 232, 234, 235, 236, 237, 239, 240, 242, 243, 244, 245, 246, 247, 248, 249, 250, 251, 253, 254, 255, 256, 257, 258, 259, 260, 261, 262, 264],
]
# Marker colour per TS method, same order as rxn_reactive.
colors = ['tab:green', 'tab:orange']

# Parity plot: NewtonNet "leftmost" TS frequency (lowest mode, imaginary shown
# as negative) vs. the DFT frequency of the mode whose displacement vector best
# matches it, for each TS-optimization method. Skips reactions in rxns[2],
# reactions not in rxn_reactive for the method, and reactions without outputs.
N_RXNS = 265  # number of reaction indices scanned (0 .. N_RXNS - 1)
TS_LABELS = {
    'nn1': 'QN Hessian\n(NewtonNet)',
    'nn0': 'Full-Hessian\n(NewtonNet)',
}

fig, ax = plt.subplots(figsize=(3, 3))
excluded_rxns = set(rxns[2])
# Points pooled over all methods so the shared axis limits cover every series.
all_res_dft, all_res_nn = [], []

for ts_method_, ts_method in enumerate(['nn1', 'nn0']):
    res_nn, res_dft = [], []
    reactive_rxns = set(rxn_reactive[ts_method_])

    for rxn in tqdm(range(N_RXNS)):
        # Plain `if`, not assert + bare except: asserts are stripped under
        # `python -O`, and a bare except would also swallow KeyboardInterrupt.
        if rxn in excluded_rxns or rxn not in reactive_rxns:
            continue
        path_nn = f'Outputs/{rxn:03}noise00_TS_{ts_method}.json'
        path_dft = f'Outputs/{rxn:03}noise00_TS_{ts_method}_freq_dft1.json'
        try:
            with open(path_nn, 'r') as f:
                output_nn = json.load(f)
            with open(path_dft, 'r') as f:
                output_dft = json.load(f)
        except FileNotFoundError:
            continue
        n_atoms = output_nn['natoms']

        vib_nn, molecule_nn = vib_from_output(output_nn)
        vib_dft, molecule_dft = vib_from_output(output_dft)
        # Map imaginary frequencies to negative reals so they sort to the left.
        freqs_nn = vib_nn.get_frequencies()
        freqs_nn = freqs_nn.real - freqs_nn.imag
        freqs_dft = vib_dft.get_frequencies()
        freqs_dft = freqs_dft.real - freqs_dft.imag
        modes_nn = vib_nn.get_modes().reshape(-1, n_atoms * 3)
        modes_dft = vib_dft.get_modes().reshape(-1, n_atoms * 3)
        # |cosine| because a mode's overall sign is arbitrary. With a single row
        # (the NN's first mode) the assignment simply picks the DFT mode with
        # the highest similarity.
        similarity_matrix = np.abs(cosine_similarity(modes_nn[[0]], modes_dft))
        row_ind, col_ind = linear_sum_assignment(-similarity_matrix)

        res_nn.append(freqs_nn[row_ind][0])
        res_dft.append(freqs_dft[col_ind][0])

    ax.plot(res_dft, res_nn, '.', color=colors[ts_method_], markerfacecolor='none',
            label=TS_LABELS[ts_method])
    all_res_dft.extend(res_dft)
    all_res_nn.extend(res_nn)

if not all_res_dft:
    raise ValueError('No reaction passed the filters with both output files present; nothing to plot.')

# Square, 5%-padded limits shared by both axes.
xymin = min(min(all_res_dft), min(all_res_nn))
xymax = max(max(all_res_dft), max(all_res_nn))
pad = 0.05 * (xymax - xymin)
xymin, xymax = xymin - pad, xymax + pad
ax.plot([xymin, xymax], [xymin, xymax], 'k:', zorder=0)  # parity line
ax.plot([xymin, xymax], [0, 0], 'k:', zorder=0)
ax.plot([0, 0], [xymin, xymax], 'k:', zorder=0)
ax.set_xlim([xymin, xymax])
ax.set_ylim([xymin, xymax])
# Both axes share the same limits, so reuse every other inner y tick for both.
ax.set_xticks(ax.get_yticks()[1:-1:2])
ax.set_yticks(ax.get_yticks()[1:-1:2])
ax.legend(title='Optimized TS', loc='upper center', bbox_to_anchor=(0.5, 4/3), framealpha=1)
ax.set_xlabel('DFT corresponding frequency (cm$^{-1}$)')
ax.set_ylabel('NewtonNet leftmost frequency (cm$^{-1}$)')
fig.savefig('Freq.pdf', bbox_inches='tight', transparent=True)
plt.show()