In [1]:
%load_ext autoreload
%autoreload 2
%config Completer.use_jedi = False

Helper functions and imports

In [2]:
from collections import Counter

import json
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy as np
import pandas as pd
import pickle
from tqdm import tqdm

from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

In [3]:
from multimodal_molecules.data import get_dataset

In [4]:
labelsize = 12
plt.rcParams["figure.figsize"] = (3, 2)
mpl.rcParams['mathtext.fontset'] = 'stix'
mpl.rcParams['font.family'] = 'STIXGeneral'
mpl.rcParams['text.usetex'] = True
plt.rc('xtick', labelsize=labelsize)
plt.rc('ytick', labelsize=labelsize)
plt.rc('axes', labelsize=labelsize)
mpl.rcParams['figure.dpi'] = 250

In [6]:
with open(path, 'r') as infile:
    functional_groups = json.load(infile)

Join everything together into one massive counter. Interpretation fo the functional group labels can be found [here](https://github.com/openbabel/openbabel/blob/master/data/SMARTS_InteLigand.txt).

In [7]:
all_functional_groups_enumerated = [g for groups in functional_groups.values() for g in groups]

In [8]:
counter = Counter(all_functional_groups_enumerated)

We keep only a subset of the total functional groups according to certain heuristics:
- Must have at least 2000 occurrences in the data
- Pop `Acetal` in favor of the more general `Acetal_like`

In [9]:
counter = {key: value for key, value in dict(counter).items() if value >= 2000}
counter.pop("Acetal")
print(len(counter))

62


In [44]:
set("co") - set("n")

{'c', 'o'}

In [59]:
must_not_have = set("n")
smiles = set("cccc")

In [61]:
must_not_have - smiles != must_not_have

False

Plot where things light up after PCA-decomposing the spectra.

In [21]:
data = pickle.load(open("data/221205/qm9_molecule_xanes.pkl", "rb"))

In [30]:
data2 = get_dataset(
    spectral_data_path="data/221205/qm9_molecule_xanes.pkl",
    functional_group_data_path="data/221205/functional_groups.json",
    elements="C,O"
)

In [31]:
data2.keys()

dict_keys(['C-XANES', 'O-XANES', 'SMILES', 'nones', 'FG', 'FG-Counter'])

In [36]:
len(data2["O-XANES"])

44871

In [24]:
XANES_C = []
XANES_N = []
XANES_O = []
SMILES = []
Nones = []
FGS = []

for smile, dat in data["data"].items():
    if not ("C" in smile and "N" in smile and "O" in smile):
        continue
        
    if dat["C-XANES"] is None or dat["O-XANES"] is None or dat["N-XANES"] is None:
        Nones.append(smile)
        continue
    
    SMILES.append(smile)
    XANES_C.append(np.array(dat["C-XANES"]))
    XANES_N.append(np.array(dat["N-XANES"]))
    XANES_O.append(np.array(dat["O-XANES"]))
    FGS.append(functional_groups[smile])
    
XANES_C = np.array(XANES_C)
XANES_N = np.array(XANES_N)
XANES_O = np.array(XANES_O)

In [25]:
XANES_N.shape

(65382, 200)

In [None]:
XANES_stacked = np.concatenate([XANES_C, XANES_N, XANES_O], axis=1)

In [None]:
XANES_C_scaler = StandardScaler()
XANES_C_scaled = XANES_C_scaler.fit_transform(XANES_C)

XANES_N_scaler = StandardScaler()
XANES_N_scaled = XANES_N_scaler.fit_transform(XANES_N)

XANES_O_scaler = StandardScaler()
XANES_O_scaled = XANES_O_scaler.fit_transform(XANES_O)

XANES_stacked_scaler = StandardScaler()
XANES_stacked_scaled = XANES_stacked_scaler.fit_transform(XANES_stacked)

In [None]:
pca_C = PCA(2)
w_C = pca_C.fit_transform(XANES_C)

pca_N = PCA(2)
w_N = pca_N.fit_transform(XANES_N)

pca_O = PCA(2)
w_O = pca_O.fit_transform(XANES_O)

pca = PCA(2)
w_stacked = pca.fit_transform(XANES_stacked)

In [None]:
def get_functional_group_presence(groups=["Alkyne"], fgs=FGS):
    groups = set(groups)
    return np.array([groups.issubset(set(fg_in_molecule)) for fg_in_molecule in fgs]).astype(int)

In [None]:
unique_functional_groups = sorted(list(counter.keys()))

Save a relatively large pdf for visual inspection (this isn't really for any manuscript).

In [None]:
fig, axs_all = plt.subplots(len(unique_functional_groups), 4, figsize=(8, 2*len(unique_functional_groups)), sharey=True, sharex=True)


cc = 0
for axs, g in zip(axs_all, unique_functional_groups):
    binary = get_functional_group_presence(groups=[g])
    where_on = np.where(binary)[0]
    where_off = np.where(np.abs(binary - 1))[0]
    for ax, w in zip(axs, [w_C, w_N, w_O, w_stacked]):
        ax.scatter(w[where_off, 0], w[where_off, 1], s=0.5, c="grey", rasterized=True)
        ax.scatter(w[where_on, 0], w[where_on, 1], s=0.5, c="black", rasterized=True)
        ax.tick_params(which='both', direction='in', bottom=True, left=True, top=True, right=True)
        ax.set_xticklabels([])
        ax.set_yticklabels([])
    axs[0].set_ylabel(g)

axs_all[0, 0].set_title("C")
axs_all[0, 1].set_title("N")
axs_all[0, 2].set_title("O")
axs_all[0, 3].set_title("CNO")
    

plt.subplots_adjust(hspace=0.1, wspace=0.1)
plt.savefig("pca_functional_groups.pdf", bbox_inches="tight", dpi=300)
plt.clf()