In [1]:
# Automatically reload edited project modules before each cell is executed.
%load_ext autoreload
%autoreload 2
# Disable Jedi-based tab completion.
%config Completer.use_jedi = False

In [2]:
from datetime import datetime
from pathlib import Path
import sys
import pickle

import numpy as np
from pymatgen.core.structure import Molecule
from scipy.interpolate import InterpolatedUnivariateSpline
from tqdm import tqdm

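Append the project's home directory (the parent of the current working directory) to `sys.path` so that project-local modules can be imported.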

In [3]:
# Put the project root (one directory above the working directory) on the import path.
sys.path += [str(Path.cwd().parent)]

# Introduction

This notebook contains the final post-processing scripts for preparing the machine learning (ML) inputs for the QM9 datasets. If you want to generate all files from scratch, you'll need to run the `XAS-NNE/data/qm9/00_process_qm9_data.py` script from within the `XAS-NNE/data/qm9/` directory.

However, we have constructed the ML inputs already and stored them in the GitHub repository. These inputs are all derived directly from the final output `pickle` file from `XAS-NNE/data/qm9/00_process_qm9_data.py`.

# Construct the ACSF feature vectors

We use the Atom-centered Symmetry Functions (ACSF) feature vectors as inputs for the ML models. See [here](https://singroup.github.io/dscribe/latest/tutorials/descriptors/acsf.html) for the `Dscribe` library implementation docs. This is the original paper:

> Jörg Behler. Atom-centered symmetry functions for constructing high-dimensional neural network potentials. J. Chem. Phys., 134(7):074106, 2011.

In [4]:
from dscribe.descriptors import ACSF

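Load the final output of the processing script described above. This is a large `pickle` file (~20 GB), so make sure enough memory is available before running this cell. Only unpickle files from a trusted source, since loading a pickle can execute arbitrary code.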

In [5]:
# Load the processed QM9 XANES results. The pickle is large (~20 GB), so expect high
# memory use. Only unpickle files you trust: loading a pickle can execute arbitrary code.
# NOTE(review): this path is relative to the working directory, while the ACSF output is
# written under "../data/qm9/" — confirm both resolve to the intended data directory.
path = Path("data/qm9/XANES-220622-C-N-O.pkl")
print(path.exists())
with open(path, "rb") as fh:  # context manager guarantees the handle is closed
    data = pickle.load(fh)

True


Set up a common energy grid for each absorbing element (C, N and O) with `N` evenly spaced points; every spectrum is interpolated onto the grid of its element.

In [6]:
N = 200  # number of points on every energy grid
# One common grid per absorbing element; each spans a window of width 54 starting at the
# given value (O: 528-582, N: 395-449, C: 275-329).
grids = {
    element: np.linspace(start, start + 54, N)
    for element, start in (("O", 528), ("N", 395), ("C", 275))
}

## Construct a lookup table for the molecular spectra if desired

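Note: this table isn't needed for the rest of this work; we create it for convenience. Each molecular spectrum is the average of the site spectra of every atom of that element in the molecule, with negative values clipped to zero. A molecule gets no spectrum for an element if any of its sites of that element fails interpolation or the outlier screening.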

In [None]:
# One entry per SMILES string, holding its QM9 id and placeholder (None) fields that the
# next cell fills in.
# NOTE(review): if two QM9 ids share a SMILES string, the later id silently replaces the
# earlier one here — confirm SMILES are unique in `data`.
lookup_table_data = {
    entry["smiles"]: {
        "qm9id": qm9id,
        **dict.fromkeys(("canon_smiles", "C-XANES", "N-XANES", "O-XANES")),
    }
    for qm9id, entry in data.items()
}

In [None]:
# For every molecule, average the site spectra of each absorbing element (C, N, O) into a
# single molecular spectrum on that element's common grid. If any site of an element has
# an unusable spectrum or fails the outlier screening, that element's molecular spectrum
# is not written for the molecule.
for datum in tqdm(data.values()):
    smiles = datum["smiles"]
    lookup_table_data[smiles]["canon_smiles"] = datum["canon_smiles"]
    molecule = Molecule.from_dict(datum["molecule"])
    atoms = [site.specie.symbol for site in molecule]

    for central_atom in ["C", "N", "O"]:
        central_atom_indexes = [ii for ii, atom_type in enumerate(atoms) if atom_type == central_atom]
        if not central_atom_indexes:
            continue

        # Each absorbing site of this element contributes one spectrum.
        molecular_spectrum = []
        molecule_valid = True
        for ii in central_atom_indexes:
            key = f"{ii}_{central_atom}"
            s = np.array(datum["xanes"][key]["spectrum"])

            # Column 3 is interpolated as a function of column 0 (presumably energy and
            # intensity — confirm). An IndexError means the stored spectrum is empty,
            # not 2-D, or has fewer than 4 columns, so the molecule is unusable.
            try:
                spline = InterpolatedUnivariateSpline(s[:, 0], s[:, 3])
            except IndexError:
                molecule_valid = False
                break

            res = spline(grids[central_atom])

            # Oxygen and carbon screening for unphysical/outlier results. This must test the
            # loop variable `central_atom`: the notebook-level CENTRAL_ATOM is only defined
            # in a later cell and would apply one element's rule to every element.
            if central_atom == "O" and np.any(res[:10] > 5.0):
                molecule_valid = False
                break
            if central_atom == "C" and np.any(res[:35] > 10.0):
                molecule_valid = False
                break

            molecular_spectrum.append(res)

        if molecule_valid:
            molecular_spectrum = np.array(molecular_spectrum).mean(axis=0)
            # Clip negative intensities to zero.
            molecular_spectrum[molecular_spectrum < 0.0] = 0.0
            lookup_table_data[smiles][f"{central_atom}-XANES"] = molecular_spectrum.tolist()


In [None]:
# Bundle the per-SMILES table with the energy grids its spectra are sampled on.
final_lookup_table_data = dict(data=lookup_table_data, grids=grids)

In [None]:
# Count, per absorbing element, how many molecules ended up without a molecular spectrum.
nones = {
    element: sum(
        entry[f"{element}-XANES"] is None
        for entry in final_lookup_table_data["data"].values()
    )
    for element in ("C", "N", "O")
}
print(nones)

In [None]:
# Write the lookup table; the context manager ensures the file is flushed and closed.
with open("qm9_molecule_xanes.pkl", "wb") as fh:
    pickle.dump(final_lookup_table_data, fh, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
# Spot-check one entry (SMILES "CCCC") to confirm it carries the expected fields.
final_lookup_table_data["data"]["CCCC"].keys()

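## Construct the site-level ML inputs for each absorbing atom type

Set `CENTRAL_ATOM` below to `"C"`, `"N"` or `"O"` and re-run this section once per element.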

In [None]:
# Absorbing element to featurize in the cells below: one of "C", "N" or "O" (the keys of
# `grids`). Change it and re-run this section to build the inputs for another element.
CENTRAL_ATOM = "C"
grid = grids[CENTRAL_ATOM]  # common grid that the site spectra are interpolated onto

In [None]:
# Chemical elements the descriptor accounts for.
species = ["H", "C", "O", "N", "F"]
# Cutoff radius (angstroms in dscribe).
# NOTE(review): newer dscribe releases rename `rcut` to `r_cut` — pin the dscribe version.
rcut = 6.0
# G2 radial terms, each [eta, R_s].
g2_params = [
    [1.0, 0],
    [0.1, 0],
    [0.01, 0],
]
# G4 angular terms, each [eta, zeta, lambda], following the aenet paper.
g4_params = [
    [eta, zeta, -1.0]
    for eta, zetas in (
        (0.001, (1.0, 2.0, 4.0)),
        (0.01, (1.0, 2.0, 4.0)),
        # NOTE(review): zeta=3.0 breaks the 1/2/4 pattern above — confirm this is intended.
        (0.1, (1.0, 2.0, 3.0)),
    )
    for zeta in zetas
]
acsf = ACSF(species=species, rcut=rcut, g2_params=g2_params, g4_params=g4_params)

In [None]:
from ase import Atom, Atoms

In [None]:
# Build the site-level ML dataset for CENTRAL_ATOM. There is one row per absorbing site
# whose spectrum could be interpolated and that passed the outlier screening. The four
# lists below stay aligned row-by-row.
origin_smiles = []        # SMILES of the molecule each row came from
molecule_site_pairs = []  # "<qm9id>_<site index>" identifier of each row
acsf_array = []           # ACSF feature vector of the absorbing site (model input "x")
spectra = []              # site spectrum interpolated onto `grid` (model target "y")

for qm9id, datum in tqdm(data.items()):
    molecule = Molecule.from_dict(datum["molecule"])
    atoms = []
    central_atom_indexes = []
    for ii, site in enumerate(molecule):
        atoms.append(Atom(site.specie.symbol, site.coords))
        if site.specie.symbol == CENTRAL_ATOM:
            central_atom_indexes.append(ii)
    atoms = Atoms(atoms)

    if not central_atom_indexes:
        continue

    # ACSF features for the absorbing sites. Rows are indexed below by position in
    # central_atom_indexes (presumably one row per requested site — confirm for your
    # dscribe version).
    site_features = acsf.create(atoms, positions=central_atom_indexes)

    for idx, ii in enumerate(central_atom_indexes):
        key = f"{ii}_{CENTRAL_ATOM}"
        s = np.array(datum["xanes"][key]["spectrum"])

        # An IndexError means the stored spectrum is empty, not 2-D, or has fewer than
        # 4 columns; skip just this site.
        try:
            spline = InterpolatedUnivariateSpline(s[:, 0], s[:, 3])
        except IndexError:
            continue

        res = spline(grid)

        # Oxygen and carbon screening for unphysical/outlier results.
        if CENTRAL_ATOM == "O" and np.any(res[:10] > 5.0):
            continue
        if CENTRAL_ATOM == "C" and np.any(res[:35] > 10.0):
            continue

        origin_smiles.append(datum["smiles"])
        spectra.append(res)
        acsf_array.append(site_features[idx])
        molecule_site_pairs.append(f"{qm9id}_{ii}")

acsf_array = np.array(acsf_array)
spectra = np.array(spectra)

And finally save to disk.

In [None]:
# Date-stamp (YYMMDD) the output file, which is named after the absorbing element.
# NOTE(review): this path is relative to the parent directory ("../data/qm9/"), whereas the
# input pickle was read from "data/qm9/" under the working directory — confirm both point
# at the intended location.
now = f"{datetime.now():%y%m%d}"
fname = f"../data/qm9/XANES-{now}-ACSF-{CENTRAL_ATOM}.pkl"
print(fname)

We take the convention that `"x"` is the input and `"y"` is the output. These are the only two required keys for the ML pipeline. The rest is considered metadata.

In [None]:
# Persist the dataset. "x" (ACSF features) and "y" (spectra on `grid`) are the keys the ML
# pipeline requires; the remaining keys are metadata. The context manager ensures the file
# is flushed and closed.
with open(fname, "wb") as fh:
    pickle.dump(
        {"grid": grid, "y": spectra, "x": acsf_array, "names": molecule_site_pairs, "origin_smiles": origin_smiles},
        fh,
        protocol=pickle.HIGHEST_PROTOCOL,
    )