In [None]:
# Reload imported modules automatically before each cell runs, so edits to
# the project package are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Turn off Jedi for tab completion in this kernel.
%config Completer.use_jedi = False

In [None]:
from collections import Counter
import json
from pathlib import Path
import string
import sys

from rdkit import Chem
from IPython.display import clear_output
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
import pandas as pd
from scipy.spatial import distance_matrix
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler
import tqdm
import torch
from sklearn.metrics import balanced_accuracy_score, accuracy_score, f1_score, recall_score

from multimodal_molecules.core import Ensemble, Estimator, get_data
from multimodal_molecules.plotting import set_defaults, set_grids, density_scatter, remove_axis_spines

In [None]:
class NumpyEncoder(json.JSONEncoder):
    """JSON encoder that also understands NumPy arrays and NumPy scalars.

    Pass it as ``cls=NumpyEncoder`` to ``json.dump`` / ``json.dumps``.
    Arrays are written as (nested) lists and NumPy integer, floating and
    boolean scalars as the matching plain Python value.
    """

    def default(self, obj):
        """Return a JSON-serializable replacement for ``obj``.

        Parameters
        ----------
        obj : object
            A value the standard encoder could not serialize on its own.

        Returns
        -------
        list, int, float or bool
            The plain-Python equivalent of a NumPy array or scalar.

        Raises
        ------
        TypeError
            Raised by the base class for any other unsupported type.
        """
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        # NumPy scalars such as np.int64 or np.float32 are not subclasses of
        # the Python built-ins, so the stock encoder rejects them.
        if isinstance(obj, (np.integer, np.floating, np.bool_)):
            return obj.item()
        return super().default(obj)

In [None]:
# Apply the project's plotting defaults (imported from
# multimodal_molecules.plotting) before any figure is created.
# NOTE(review): presumably this configures matplotlib styling - confirm in that module.
set_defaults()

In [None]:
def load_qm9_xyz(path):
    """Read a QM9 ``.xyz`` file and pull out one SMILES string from it.

    Parameters
    ----------
    path : str or path-like
        Location of the file to read.

    Returns
    -------
    tuple
        ``(lines, smiles)`` where ``lines`` is the list of raw lines of the
        file (line endings kept) and ``smiles`` is the second
        whitespace-separated token on the second-to-last line.
    """
    with open(path, "r") as handle:
        contents = list(handle)
    # The second-to-last line carries the SMILES entries; keep the second one.
    smiles_fields = contents[-2].strip().split()
    return contents, smiles_fields[1]

In [None]:
# Lookup table keyed by the SMILES string read from each QM9 .xyz file; each
# value holds the file's raw lines ("xyz") and its file name ("fname").
# It is filled in by the loop in the next cell.
smiles_to_info = {}

In [None]:
# Index every QM9 structure by the SMILES string stored in its .xyz file.
# The files are visited in sorted order because directory listings come back
# in arbitrary order: if two files share a SMILES string, the later one
# overwrites the earlier one, and sorting makes that outcome reproducible.
# NOTE(review): the data directory is a hardcoded absolute local path -
# consider moving it to a configurable constant.
for fname in tqdm.tqdm(sorted(Path("/Users/mc/Data/QM9").glob("*.xyz"))):
    lines, smiles = load_qm9_xyz(fname)
    smiles_to_info[smiles] = {"xyz": lines, "fname": fname.name}

# Compare experimental XANES (for the element selected below) to molecules in QM9

In [None]:
# Element symbol that selects the data set for the rest of the notebook: it is
# passed to get_data(), and used to build the experiment file name and the
# output directory.
element = "O"

In [None]:
# Fetch the FEFF data for the chosen element and merge its train/val/test
# parts into single collections, so a position in `all_smiles` refers to the
# same molecule as that row of `all_spectra`.
feff_data = get_data(elements=element)
spectra_parts = [feff_data["X_train"], feff_data["X_val"], feff_data["X_test"]]
all_spectra = np.concatenate(spectra_parts, axis=0)
all_smiles = (
    feff_data["smiles_train"]
    + feff_data["smiles_val"]
    + feff_data["smiles_test"]
)
# Canonical form of every SMILES, used later as a fallback when matching.
all_smiles_canonicalized = list(map(Chem.CanonSmiles, all_smiles))

In [None]:
# Read the experimental data for the chosen element; the file also stores the
# grid ("feff_grid") used as the x-axis when plotting below.
exp_path = f"experiment/{element.lower()}_exp.json"
with open(exp_path, "r") as handle:
    exp = json.load(handle)
grid = np.array(exp["feff_grid"])

In [None]:
# Collects one dict per experimental spectrum that could be paired with both
# a QM9 structure and a FEFF spectrum (filled in by the loop in the next cell).
new_data = []

In [None]:
# Pair every usable experimental spectrum with its QM9 structure and its FEFF
# spectrum. An entry is kept only if both can be found.

# First-occurrence lookup tables, built once so each experimental entry is
# matched with a dictionary lookup instead of repeated list scans.
# setdefault keeps the first index, which is what list.index() would return.
_feff_index_by_smiles = {}
for _index, _smile in enumerate(all_smiles):
    _feff_index_by_smiles.setdefault(_smile, _index)
_feff_index_by_canonical_smiles = {}
for _index, _smile in enumerate(all_smiles_canonicalized):
    _feff_index_by_canonical_smiles.setdefault(_smile, _index)


def _find_feff_index(key):
    """Return the row of ``all_spectra`` matching ``key``, or None if absent.

    The SMILES strings as stored in the FEFF data are tried first, then
    their canonicalized forms.
    """
    if key in _feff_index_by_smiles:
        return _feff_index_by_smiles[key]
    return _feff_index_by_canonical_smiles.get(key)


# Start from an empty list so re-running this cell does not append duplicates.
new_data = []

for filename, value in exp["data"].items():
    # Skip experimental files that were flagged as problematic.
    if filename in exp["errors"]:
        continue

    # Two candidate keys: the SMILES as given (key2) and its canonical form
    # (key1). key1 is None when the SMILES cannot be canonicalized.
    key2 = value["smiles"]
    try:
        key1 = Chem.CanonSmiles(key2)
    except Exception:
        # Not a bare except, so KeyboardInterrupt / SystemExit still propagate.
        key1 = None

    qm9_data1 = smiles_to_info.get(key1)
    qm9_data2 = smiles_to_info.get(key2)
    if qm9_data1 is None and qm9_data2 is None:
        continue

    # We have the QM9 structure; now look for the FEFF spectrum. The
    # canonical key is only tried when QM9 has an entry for it, then the
    # as-given key is used as a fallback.
    feff_index = _find_feff_index(key1) if qm9_data1 is not None else None
    if feff_index is None:
        feff_index = _find_feff_index(key2)
    if feff_index is None:
        continue

    # Prefer the QM9 entry found through the canonical key.
    if qm9_data1 is not None:
        qm9_data, smiles = qm9_data1, key1
    else:
        qm9_data, smiles = qm9_data2, key2

    # Both the structure and the FEFF spectrum exist: record the pairing.
    new_data.append(
        {
            "feff_spectrum": all_spectra[feff_index, :],
            "exp_spectrum": np.array(value["infilled_spectrum"]),
            "key1": key1,
            "key2": key2,
            "exp_filename": filename,
            "smiles": smiles,
            "qm9_data": qm9_data,
        }
    )

In [None]:
# Number of experiment / FEFF / QM9 pairings that were found.
len(new_data)

In [None]:
# One panel per pairing, overlaying the FEFF spectrum (red) and the
# experimental spectrum (blue) on the shared grid.
if new_data:
    # squeeze=False always returns a 2D array of axes; flattening it keeps the
    # loop working when there is exactly one pairing (plt.subplots would
    # otherwise return a single Axes object, which cannot be zipped over).
    fig, axs = plt.subplots(
        len(new_data),
        1,
        figsize=(3, 2 * len(new_data)),
        sharex=True,
        squeeze=False,
    )
    axs = axs.ravel()

    for d, ax in zip(new_data, axs):
        ax.plot(grid, d["feff_spectrum"], color="red", label="FEFF")
        ax.plot(grid, d["exp_spectrum"], color="blue", label="experiment")
    # A single legend on the top panel is enough; the colors repeat below.
    axs[0].legend()
else:
    # plt.subplots cannot create a figure with zero rows.
    print("No experiment/FEFF pairings found; nothing to plot.")

In [None]:
# Output directory for this element's experiment/FEFF pairings; data.json and
# one sub-folder per experimental file are written below it.
root = Path(f"experiment/experiment_feff_pairings/{element}")

In [None]:
# Save all pairings to a single JSON file; NumpyEncoder converts the spectra
# arrays to lists.
# Create the output directory first: opening the file for writing fails with
# FileNotFoundError if the directory does not exist yet (e.g. on a fresh run).
root.mkdir(parents=True, exist_ok=True)
with open(root / "data.json", "w") as f:
    json.dump(new_data, f, cls=NumpyEncoder)

In [None]:
# For every pairing, create a folder named after the experimental file and
# write the matching QM9 .xyz lines into it unchanged.
for pairing in new_data:
    out_dir = root / pairing["exp_filename"]
    out_dir.mkdir(parents=True, exist_ok=True)
    with open(out_dir / "structure.xyz", "w") as handle:
        handle.writelines(pairing["qm9_data"]["xyz"])