In [None]:
# Reload imported modules automatically before each cell is executed, so that
# edits to the project package are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Disable jedi-based tab completion for this IPython session.
%config Completer.use_jedi = False

In [None]:
from collections import Counter
import json
from pathlib import Path
import string
import sys

from IPython.display import clear_output
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from rdkit import Chem
from rdkit.Chem import Draw
from scipy.spatial import distance_matrix
from sklearn.decomposition import PCA
from sklearn.metrics import balanced_accuracy_score, accuracy_score, f1_score, recall_score
from sklearn.preprocessing import StandardScaler
import torch
import tqdm

from multimodal_molecules.core import Ensemble, Estimator, get_data
from multimodal_molecules.plotting import set_defaults, set_grids, density_scatter, remove_axis_spines

In [None]:
# Apply the plotting defaults defined in multimodal_molecules.plotting
# (presumably matplotlib rc settings — confirm in that module).
set_defaults()

# Helper functions

In [None]:
def get_deviating_estimators(preds):
    """Count how many estimators disagree with the majority vote.

    ``preds`` holds one slice per estimator along axis 0. Every value is
    rounded to the nearest integer and the rounded votes are summed over the
    estimators. Wherever that sum exceeds half the number of estimators it is
    replaced by ``N - sum``, so each entry of the result is the size of the
    smaller voting side.

    An ``AssertionError`` is raised when the number of estimators is odd.
    """
    n_estimators = preds.shape[0]
    assert n_estimators % 2 == 0

    votes = preds.round().sum(axis=0)

    # Where the positive votes are the majority, the deviating estimators are
    # the ones that voted the other way.
    positive_majority = votes > n_estimators // 2
    votes[positive_majority] = n_estimators - votes[positive_majority]
    return votes

# Constants

In [None]:
# Energy grids (x-axis values of the C, N and O spectra), read from plain-text files.
C_grid = np.loadtxt("data/c_grid.txt")
N_grid = np.loadtxt("data/n_grid.txt")
O_grid = np.loadtxt("data/o_grid.txt")

# Multimodal PCA

In [None]:
# Load the combined C+N+O dataset. get_data is project code; the result is
# indexed below with keys such as "X_test" and "smiles_test".
data = get_data(elements="CNO")

In [None]:
# Two-component PCA fitted on the standardised test spectra; the scaler is
# fitted on the same data["X_test"] array.
pca_CNO = PCA(2)
scaler = StandardScaler()
pca_CNO.fit(scaler.fit_transform(data["X_test"]))
# Unscaled alternative:
# pca_CNO.fit(data["X_test"])

In [None]:
# Coordinates of every test spectrum in the two principal components.
w_CNO = pca_CNO.transform(scaler.transform(data["X_test"]))
# Unscaled alternative:
# w_CNO = pca_CNO.transform(data["X_test"])

In [None]:
def get_example(p, w=w_CNO, keep=10):
    """Return the indices of the ``keep`` rows of ``w`` nearest to ``p``.

    Nearness is the sum of absolute coordinate differences (L1 distance)
    between each row of ``w`` and the point ``p``. Indices are ordered from
    the nearest row outwards. ``w`` defaults to the PCA coordinates that were
    bound to ``w_CNO`` when this function was defined.
    """
    target = np.array(p)
    l1_distance = np.sum(np.abs(w - target), axis=1)
    return np.argsort(l1_distance)[:keep]

In [None]:
# Eight-entry viridis colormap; the next cell draws one colour per selected PCA point from it.
cmap = mpl.colormaps["viridis"].resampled(8)

In [None]:
# Edge points
p1 = get_example([-2, 29])
p2 = get_example([20, 5])
p3 = get_example([5, -20])
p4 = get_example([-22, -3])
p5 = get_example([-16, 5])
p6 = get_example([-1.5, 1])
p7 = get_example([-2.5, -5])
p8 = get_example([12.5, -9.5])
selected_points = [p1, p2, p3, p4, p5, p6, p7, p8]
colors = [cmap(ii) for ii in range(len(selected_points))]
smiles = [[data["smiles_test"][xx] for xx in yy] for yy in selected_points]

In [None]:
# Panel (a): density scatter of all test spectra in PCA space, with the
# closest match of every selected location marked by an open circle.
fig, ax = plt.subplots(1, 1, figsize=(2, 2), sharey=True)

# remove_axis_spines(ax, visible=True)

density_scatter(w_CNO[:, 0], w_CNO[:, 1], ax, bins=20, s=1, cmap="binary", rasterized=True)

# point[0] is the nearest test sample for that location (get_example sorts by distance).
for point, color in zip(selected_points, colors):
    ax.scatter(w_CNO[point[0], 0], w_CNO[point[0], 1], color=color, marker="o", s=20, facecolors='none', linewidth=1)

ax.set_xlabel(r"$z_1$ [a.u.]")
ax.set_ylabel(r"$z_2$ [a.u.]")
ax.set_xticks([])
ax.set_yticks([])
ax.text(0.9, 0.9, "(a)", ha="center", va="center", transform=ax.transAxes)


plt.show()
# plt.savefig("figures/fig_pca_1/fig_pca_1.svg", bbox_inches="tight", dpi=300)

In [None]:
# Roman-numeral row labels, zipped with selected_points in the spectra figure below.
letters = ["i", "ii", "iii", "iv", "v", "vi", "vii", "viii"]

In [None]:
# One row per selected PCA location, three panels per row (C, N and O energy
# windows). Every panel receives all three spectra; the x-limits set further
# down restrict each panel to its own element, giving a broken-axis layout.
lw = 0.5
offset = 1
C_limit = 160
N_limit = 160
O_limit = 120

# (energy grid, first column of that element in data["X_test"], points shown).
# The column offsets 0 / 200 / 400 are the ones this cell has always used.
element_windows = [
    (C_grid, 0, C_limit),
    (N_grid, 200, N_limit),
    (O_grid, 400, O_limit),
]

# Slanted tick drawn at the panel edges to mark the break in the energy axis.
# Defined once because it does not depend on the row.
d = 1.5  # proportion of vertical to horizontal extent of the slanted line
kwargs = dict(marker=[(-1, -d), (1, d)], markersize=4, linestyle="none", color='k', mec='k', mew=0.75, clip_on=False)

fig, all_axs = plt.subplots(8, 3, figsize=(5, 8), sharey=True)

for ii, (points, smile, color, letter, axs) in enumerate(zip(selected_points, smiles, colors, letters, all_axs)):

    # fig, axs = plt.subplots(1, 3, figsize=(6, 1), sharey=True)

    axs[0].text(0.1, 0.9, f"({letter})", ha="center", va="center", transform=axs[0].transAxes)

    for ax in axs:

        for element_grid, first_column, limit in element_windows:
            for p in points:
                ax.plot(
                    element_grid[:limit],
                    data["X_test"][p, first_column:first_column + limit] + offset,
                    color=color,
                    linewidth=lw,
                )

        ax.set_yticks([])
        ax.spines[["right", "top"]].set_visible(False)
        ax.tick_params(
            which='both', direction='in', bottom=True, left=True
        )

    axs[0].set_xlim(C_grid[0], C_grid[C_limit])
    axs[0].set_xticks([280, 315])
    axs[1].set_xlim(N_grid[0], N_grid[N_limit])
    axs[1].set_xticks([400, 432])
    axs[2].set_xlim(O_grid[0], O_grid[O_limit])
    axs[2].set_xticks([533, 558])

    # Hide specific spines
    axs[1].spines.left.set_visible(False)
    axs[2].spines.left.set_visible(False)

    # Axis labels and x tick labels only on the last row of each half of the figure.
    if ii == len(all_axs) - 1 or ii == len(all_axs) // 2 - 1:
        axs[0].set_ylabel(r"$\mu(E)$ [a.u.]")
        # Unit spelled "eV" as in the other figures of this notebook (was "e.V.").
        axs[1].set_xlabel(r"$E$ [eV]")
    else:
        for ax in axs:
            ax.set_xticklabels([])

    # Axis-break markers on the inner panel edges.
    axs[0].plot([1, 1], [0, 0], transform=axs[0].transAxes, **kwargs)
    axs[1].plot([0, 0], [0, 0], transform=axs[1].transAxes, **kwargs)
    axs[1].plot([1, 1], [0, 0], transform=axs[1].transAxes, **kwargs)
    axs[2].plot([0, 0], [0, 0], transform=axs[2].transAxes, **kwargs)

    # Get the smiles...
    # svgs = Chem.Draw.MolsToGridImage([Chem.MolFromSmiles(s) for s in smile], useSVG=True)
    # with open(f"figures/fig_pca_1/example_structures_{letter}.svg", "w") as f:
    #     f.write(svgs.data)

plt.subplots_adjust(wspace=0.05, hspace=0.2)
# plt.savefig("figures/fig_pca_1/example_spectra.svg", bbox_inches="tight", dpi=300)
plt.show()

In [None]:
# Draw the molecules behind the third selected PCA location.
# The original cell iterated over `indexes`, a name that is not defined anywhere
# in this notebook, and stopped after the first item of `indexes[2:]`.
# `selected_points` is the list of test-set index arrays built above, so its
# third entry is used directly.
# NOTE(review): confirm that `selected_points` is what `indexes` referred to.
example_indices = selected_points[2]
# Kept under its own name so the nested `smiles` list built earlier is not overwritten.
example_smiles = [data["smiles_test"][ii] for ii in example_indices]
svgs = Draw.MolsToGridImage([Chem.MolFromSmiles(s) for s in example_smiles], useSVG=True)

In [None]:
# Display the molecule grid built in the previous cell.
svgs

# Carbon exploration and validation

In [None]:
# Element whose model directory, dataset and grid file are loaded in the next cell.
ELEMENT = "C"

In [None]:
# Load the trained ensemble, the dataset and the energy grid for ELEMENT.
# NOTE: this rebinds the notebook-wide `data` (previously the CNO dataset).
ensemble = Ensemble.from_path(f"data/23-12-06_torch_models/{ELEMENT}")
data = get_data(elements=ELEMENT)
grid = np.loadtxt(f"data/{ELEMENT.lower()}_grid.txt")

In [None]:
# Print the first grid value at which the mean training spectrum exceeds 1.
# If no point exceeds 1 the loop runs to the end and the last grid value is printed.
mean_spectrum = data["X_train"].mean(axis=0).squeeze()
for ii, (e, s) in enumerate(zip(grid, mean_spectrum)):
    if s > 1:
        break
print(grid[ii])

## PCA

We make some nice plots here by doing the following:
1. Decompose the spectral data into 2 dimensions
2. Select interesting functional groups
3. Plot the first two principal components of the data against labeled data
4. Plot using 2d density histograms

In [None]:
# Two-component PCA of the unscaled test spectra; `w` holds the projected coordinates.
pca = PCA(2)
X = data["X_test"]
w = pca.fit_transform(X)

In [None]:
# Functional groups highlighted in the carbon figures; each name is looked up
# in data["functional_groups"] by the plotting cells below.
interesting_functional_groups = [
    # None,
    "Quaternary_carbon",
    "Alkene",
    "Alkyne",
    "Aromatic",
]
# Number of panels in the PCA figure.
N = len(interesting_functional_groups)

In [None]:
# SMILES strings drawn as motif images in the next cell — presumably one
# example per functional group listed above, in the same order (confirm).
smiles = ["CC(C)(C)(C)", "C(C)(C)=C(C)(C)", "C(C)#C(C)", "c1ccccc1"]

In [None]:
# Render the motif SMILES as a single SVG grid image.
svgs = Chem.Draw.MolsToGridImage([Chem.MolFromSmiles(s) for s in smiles], useSVG=True)

In [None]:
# Display the motif grid.
svgs

In [None]:
# Write the SVG markup to disk; open() does not create the target directory,
# so figures/fig_pca_2-carbon must already exist.
with open("figures/fig_pca_2-carbon/motifs.svg", "w") as f:
    f.write(svgs.data)

In [None]:
# One PCA panel per functional group: all test spectra in grey ("binary"),
# overlaid with the spectra labelled 1 for that group (viridis density).
chars = ["i", "ii", "iii", "iv"]
fig, axs = plt.subplots(1, N, figsize=(2 * N, 2), sharey=True)

for fg, ax, label in zip(interesting_functional_groups, axs, chars):
    remove_axis_spines(ax, visible=True)

    density_scatter(w[:, 0], w[:, 1], ax, bins=20, s=0.2, cmap="binary", rasterized=True)

    if fg is not None:
        # Column of Y_test that belongs to this functional group.
        index = data["functional_groups"].index(fg)
        condition = (data["Y_test"][:, index] == 1)
        where = np.where(condition)[0]
        density_scatter(w[where, 0], w[where, 1], ax, s=0.3, cmap="viridis", rasterized=True)

    ax.text(0.05, 0.95, f"({label})", ha="left", va="top", transform=ax.transAxes)

# Axis labels only on the first panel.
axs[0].set_xlabel(r"$z_1$~[a.u.]")
axs[0].set_ylabel(r"$z_2$~[a.u.]")

plt.subplots_adjust(wspace=0.1)
# plt.savefig("figures/fig_pca_2-carbon/pca.svg", bbox_inches='tight', dpi=300)
plt.show()

In [None]:
# Mean spectrum (plus ten individual examples) per functional group, each
# group shifted down by ii * mult so the curves do not overlap.
mult = 2

fig, ax = plt.subplots(1, 1, figsize=(1.5, 2), sharey=True)

# remove_axis_spines(ax, visible=True)

# NOTE: rebinds the notebook-wide `cmap` that an earlier cell used to build `colors`.
cmap = mpl.colormaps["rainbow"].resampled(len(interesting_functional_groups))

for ii, (fg, label) in enumerate(zip(interesting_functional_groups, chars)):

    index = data["functional_groups"].index(fg)
    condition = (data["Y_test"][:, index] == 1)
    where = np.where(condition)[0]
    X_where = X[where, :]
    mu = X_where.mean(axis=0)
    sd = X_where.std(axis=0)
    ax.plot(grid, mu - ii*mult, color=cmap(ii))
    # NOTE(review): indexes the first ten matches directly; raises IndexError
    # if a group has fewer than ten test spectra.
    for jj in range(10):
        ax.plot(grid, X_where[jj, :] - ii * mult, color=cmap(ii), alpha = 0.5, linewidth=0.5)
    # ax.fill_between(grid, mu + ii*mult - sd, mu + ii*mult + sd, linewidth=0, color=cmap(ii), alpha=0.2)

# ax.legend()
ax.set_yticks([])
set_grids(ax, grid=False)
ax.set_ylabel(r"$\mu(E)$~[a.u.]")
ax.set_xlabel(r"$E$~[eV]")
ax.set_xticks([280, 290, 300, 310])
ax.set_xlim([275, 315])

# plt.subplots_adjust(wspace=0.1)
# plt.savefig("figures/fig_pca_2-carbon/spectra.svg", bbox_inches='tight', dpi=300)
plt.show()

# Nitrogen exploration and validation

In [None]:
# Element whose model directory, dataset and grid file are loaded in the next cell.
ELEMENT = "N"

In [None]:
# Load the trained ensemble, the dataset and the energy grid for ELEMENT.
# NOTE: rebinds the notebook-wide `ensemble`, `data`, `grid` and `X`.
ensemble = Ensemble.from_path(f"data/23-12-06_torch_models/{ELEMENT}")
data = get_data(elements=ELEMENT)
grid = np.loadtxt(f"data/{ELEMENT.lower()}_grid.txt")
X = data["X_test"]

In [None]:
# Print the first grid value at which the mean training spectrum exceeds 1.
# If no point exceeds 1 the loop runs to the end and the last grid value is printed.
mean_spectrum = data["X_train"].mean(axis=0).squeeze()
for ii, (e, s) in enumerate(zip(grid, mean_spectrum)):
    if s > 1:
        break
print(grid[ii])

## PCA

We make some nice plots here by doing the following:
1. Decompose the spectral data into 2 dimensions
2. Select interesting functional groups
3. Plot the first two principal components of the data against labeled data
4. Plot using 2d density histograms

In [None]:
# Two-component PCA of the unscaled nitrogen test spectra. This `pca` object is
# the one used further down to project the experimental spectra.
pca = PCA(2)
w = pca.fit_transform(X)

In [None]:
# Show the functional-group names available for this dataset.
data["functional_groups"]

In [None]:
# Functional groups highlighted in the nitrogen figures; each name is looked up
# in data["functional_groups"] by the plotting cells below.
interesting_functional_groups = [
    # None,
    "Amine",
    "Amide",
    "Nitrile",
    "Heteroaromatic",
    # "Secondary_aliph_amine",
    # "Tertiary_aliph_amine",
]
# Number of panels in the PCA figure.
N = len(interesting_functional_groups)

In [None]:
# One PCA panel per functional group: all test spectra in blue, overlaid with
# the spectra labelled 1 for that group whose SMILES contains exactly one "n"
# character (case-insensitive count).
fig, axs = plt.subplots(1, N, figsize=(2 * N, 2), sharey=True)

for ax in axs:
    ax.set_aspect("equal")

for fg, ax in zip(interesting_functional_groups, axs):
    remove_axis_spines(ax, visible=True)

    density_scatter(w[:, 0], w[:, 1], ax, bins=20, s=0.2, cmap="Blues", rasterized=True)

    if fg is not None:
        # Column of Y_test that belongs to this functional group.
        index = data["functional_groups"].index(fg)
        condition = (data["Y_test"][:, index] == 1)

        # Plot where one nitrogen
        condition_only_one = [Counter(xx.lower())["n"] == 1 for xx in data["smiles_test"]]
        where = np.where(condition & condition_only_one)[0]
        density_scatter(w[where, 0], w[where, 1], ax, s=0.3, cmap="viridis", rasterized=True)

        # Plot where two nitrogen
        # condition_only_one = [Counter(xx.lower())["n"] == 2 for xx in data["smiles_test"]]
        # where = np.where(condition & condition_only_one)[0]
        # density_scatter(w[where, 0], w[where, 1], ax, s=0.3, cmap="viridis", rasterized=True)


    # ax.text(0.05, 0.95, f"({label})", ha="left", va="top", transform=ax.transAxes)

# Axis labels only on the first panel.
axs[0].set_xlabel(r"$z_1$~[a.u.]")
axs[0].set_ylabel(r"$z_2$~[a.u.]")

plt.subplots_adjust(wspace=0.1)
plt.show()
# plt.savefig("figures/fig_pca_2-nitrogen/pca.svg", bbox_inches="tight", dpi=300)

### Plots of actual spectra examples

In [None]:
# Mean spectrum (plus ten individual examples) per functional group for
# single-"n" SMILES, each group shifted down by ii * mult.
mult = 4

fig, ax = plt.subplots(1, 1, figsize=(2, 2), sharey=True)

# remove_axis_spines(ax, visible=True)

# NOTE: rebinds the notebook-wide `cmap`.
cmap = mpl.colormaps["rainbow"].resampled(len(interesting_functional_groups))

for ii, (fg,) in enumerate(zip(interesting_functional_groups)):

    index = data["functional_groups"].index(fg)
    condition = (data["Y_test"][:, index] == 1)
    # Same list for every group (it does not depend on fg).
    condition_only_one = [Counter(xx.lower())["n"] == 1 for xx in data["smiles_test"]]

    where = np.where(condition & condition_only_one)[0]

    X_where = X[where, :]
    mu = X_where.mean(axis=0)
    sd = X_where.std(axis=0)
    ax.plot(grid, mu - ii*mult, color=cmap(ii), zorder=-ii)
    # NOTE(review): rows 100-109 of the subset are indexed directly; raises
    # IndexError if a group has fewer than 110 matching spectra.
    for jj in range(100, 110):
        ax.plot(grid, X_where[jj, :] - ii * mult, color=cmap(ii), alpha = 0.5, linewidth=0.5)
    # ax.fill_between(grid, mu + ii*mult - sd, mu + ii*mult + sd, linewidth=0, color=cmap(ii), alpha=0.2)

ax.set_yticks([])
set_grids(ax, grid=False)
ax.set_ylabel(r"$\mu(E)$~[a.u.]")
ax.set_xlabel(r"$E$~[eV]")
ax.set_xticks([400, 410, 420, 430])
ax.set_xlim([397, 431])

# plt.subplots_adjust(wspace=0.1)
# NOTE(review): this path is also written by the combined N/O spectra cell
# later in the notebook, which overwrites this nitrogen-only figure.
plt.savefig("figures/fig_pca_3-nitrogen-oxygen/spectra.svg", bbox_inches='tight', dpi=300)
# plt.show()

# Testing on experiment

In [None]:
# Load the experimental nitrogen data; the cells below read its "data",
# "errors" and "feff_grid" entries.
with open("experiment/n_exp.json", "r") as f:
    exp_data = json.load(f)

Get the infilled spectra from the loaded dictionary.

In [None]:
# Map SMILES -> infilled spectrum, skipping every entry whose key is listed
# under exp_data["errors"].
spectra_from_smiles = {
    value["smiles"]: value["infilled_spectrum"]
    for key, value in exp_data["data"].items()
    if key not in exp_data["errors"]
}

In [None]:
# Convert every spectrum to a numpy array and set negative values to zero.
# Only values are replaced (no keys added or removed), so updating the dict
# while iterating over it is safe.
for key, value in spectra_from_smiles.items():
    value = np.array(value)
    value[value < 0] = 0
    spectra_from_smiles[key] = value

The index specifies the mapping between the SMILES strings and the targets. `get_index` can also screen out data we are not going to use on the functional-group side of things. For example, passing `require_qm9=True` drops any molecule that does not conform to the "qm9" standard (only C, N, O, F and H, and no charges).

In [None]:
def _qm9_condition(smile):
    if "+" in smile or "-" in smile:
        return False
    alphanumeric = [s for s in smile if s.isalpha()]
    return all([xx.lower() in ["c", "n", "o", "f", "h"] for xx in alphanumeric])

def _error_free_condition(smiles, s=spectra_from_smiles):
    """Return True when ``smiles`` is a key of ``s``.

    ``s`` defaults to the ``spectra_from_smiles`` dict that existed when this
    function was defined; that dict only holds molecules not listed under
    ``exp_data["errors"]``.
    """
    return smiles in s

def get_index(
    index=None,
    require_error_free=True,
    require_qm9=False,
    require_functional_group=None,
    feff_functional_groups=data["functional_groups"]
):
    """Load and filter the experimental index table.

    Parameters
    ----------
    index : pandas.DataFrame, optional
        Table with a ``"SMILES"`` column followed by functional-group columns.
        Read from ``experiment/index.csv`` (first CSV column as row index)
        when omitted.
    require_error_free : bool
        Keep only rows whose SMILES passes ``_error_free_condition``.
    require_qm9 : bool
        Keep only rows whose SMILES passes ``_qm9_condition``.
    require_functional_group : str, optional
        Keep only rows whose value in this column equals 1.
    feff_functional_groups : list of str, optional
        Functional groups known to the simulated data. Defaults to the
        ``data["functional_groups"]`` list that was bound when this function
        was defined. Pass ``None`` to keep all columns.

    Returns
    -------
    (pandas.DataFrame, numpy.ndarray or None)
        The filtered table and, unless ``feff_functional_groups`` is ``None``,
        the positions inside ``feff_functional_groups`` of the groups shared
        with the table.
    """
    if index is None:
        index = pd.read_csv("experiment/index.csv", index_col=0)

    if feff_functional_groups is None:
        feff_mask = None
    else:
        # Restrict the table to the functional groups present on both sides.
        experimental_fg = set(index.columns[1:])
        common_fg = sorted(set(feff_functional_groups) & experimental_fg)
        index = index[["SMILES"] + common_fg]

        # Positions of the shared groups inside the simulation-side list.
        feff_mask = np.array(
            [pos for pos, name in enumerate(feff_functional_groups) if name in common_fg]
        )
        # Sanity check: mask order must match the (sorted) column order.
        for pos, name in zip(feff_mask, index.columns[1:]):
            assert feff_functional_groups[pos] == name

    if require_qm9:
        index = index[index["SMILES"].apply(_qm9_condition)]

    if require_error_free:
        index = index[index["SMILES"].apply(_error_free_condition)]

    if require_functional_group is not None:
        index = index[index[require_functional_group] == 1]

    return index, feff_mask

In [None]:
# Functional group used to filter the experimental index and to pick the
# prediction column in the cells below.
FUNCTIONAL_GROUP = "Amine"

In [None]:
# Experimental rows labelled 1 for FUNCTIONAL_GROUP (error-free entries only,
# the get_index default), plus the matching simulation-side column positions.
index, feff_mask = get_index(require_functional_group=FUNCTIONAL_GROUP)

Construct the new testing data (experimental results).

In [None]:
# Build the experimental inputs/targets row by row, in the order of `index`.
X_exp = []
Y_exp = []
for ii, row in index.iterrows():
    smiles = row["SMILES"]
    # Everything after the first ("SMILES") entry is a functional-group label.
    Y_exp.append(row[1:].to_numpy())
    X_exp.append(spectra_from_smiles[smiles])
X_exp = np.array(X_exp)
Y_exp = np.array(Y_exp, dtype=int)

In [None]:
# Position of FUNCTIONAL_GROUP in the simulation-side functional-group list.
feff_FG_index = data["functional_groups"].index(FUNCTIONAL_GROUP)

In [None]:
# Per-estimator predictions on the experimental spectra (estimators along axis 0).
preds = ensemble.predict(torch.FloatTensor(X_exp))
# Majority vote: round each estimator's output, average over estimators, round again.
mu = np.array(preds.round().mean(axis=0).round(), dtype=int)

This logic gets the number of deviating estimators, a proxy for uncertainty.

In [None]:
# Size of the minority vote for every (sample, functional group) prediction.
deviating_estimators = get_deviating_estimators(preds)

## Compare Amines from experiment and theory

First, extract the relevant information from the FEFF (simulated) data.

In [None]:
# Redundant: Counter is already imported in the import cell at the top of the notebook.
from collections import Counter

In [None]:
feff_smiles_test = data["smiles_test"]
# 1 where the lower-cased SMILES contains exactly one "n" character, else 0.
where_one_n = np.array([1 if Counter(smile.lower())["n"] == 1 else 0 for smile in feff_smiles_test])

In [None]:
# Test-set rows labelled 1 for FUNCTIONAL_GROUP whose SMILES has a single "n".
locations = np.where(
    (data["Y_test"][:, feff_FG_index] == 1) & (where_one_n == 1)
)[0] # Get all the Amines

In [None]:
# Simulated spectra of the rows selected above.
kept_feff_spectra = data["X_test"][locations, :]

In [None]:
# Column position of FUNCTIONAL_GROUP among the experimental label columns
# (i.e. within Y_exp, which excludes the leading "SMILES" column).
istar_exp = index.columns[1:].tolist()
istar_exp = istar_exp.index(FUNCTIONAL_GROUP)

In [None]:
def get_closest_spectrum(x1, X, k=5):
    """Return the indices of the ``k`` rows of ``X`` closest to ``x1``.

    Parameters
    ----------
    x1 : array-like
        A single spectrum; it is reshaped to one row before comparison.
    X : array-like, shape (n_spectra, n_points)
        Candidate spectra, one per row.
    k : int, optional
        Maximum number of indices returned. Defaults to 5, the value that was
        previously hard-coded.

    Returns
    -------
    numpy.ndarray
        1-D integer array with at most ``k`` row indices of ``X``, ordered from
        the smallest to the largest distance as computed by
        ``scipy.spatial.distance_matrix`` (Euclidean with its default ``p=2``).
    """
    d = distance_matrix(np.asarray(x1).reshape(1, -1), X)
    # `d` has a single row. Index it rather than calling squeeze(): squeeze()
    # collapses the result to a 0-d array when X has exactly one row, and a
    # 0-d array cannot be sliced.
    return np.argsort(d[0])[:k]

In [None]:
# Train, validation and test spectra / SMILES stacked in the same order, so
# that a position in all_smiles indexes the matching row of all_inputs.
all_inputs = np.concatenate([data["X_train"], data["X_val"], data["X_test"]], axis=0)
all_smiles = data["smiles_train"] + data["smiles_val"] + data["smiles_test"]

In [None]:
# Canonical form of every simulated SMILES, used to look up experimental molecules.
all_canon_smiles = [Chem.CanonSmiles(smile) for smile in all_smiles]

In [None]:
# One panel per experimental molecule: the experimental spectrum (black), the
# simulated spectrum of the same molecule if it exists in the dataset (blue),
# and the five closest simulated test spectra (red), annotated with the
# ensemble prediction, the target and the number of deviating estimators.
fs=4

# squeeze=False keeps `axs` two-dimensional, so indexing also works when
# `index` holds a single molecule (plt.subplots otherwise returns a bare Axes).
fig, axs = plt.subplots(len(index), 1, figsize=(3, 2 * len(index)), squeeze=False)

for cc, (_, row) in enumerate(index.iterrows()):

    ax = axs[cc, 0]

    smiles = row["SMILES"]
    y = row[1:].to_numpy()
    x = spectra_from_smiles[smiles]
    pred = mu[cc, feff_FG_index].item()
    dev = deviating_estimators[cc, feff_FG_index].item()
    target = Y_exp[cc, istar_exp].item()
    ax.plot(exp_data["feff_grid"], x, color="black", linewidth=0.5)

    # See if smiles exists in training set
    # Only the lookup is guarded: list.index raises ValueError when the
    # molecule is absent. The plotting calls live in the `else` branch so that
    # a ValueError raised by matplotlib is no longer silently swallowed.
    try:
        s = Chem.CanonSmiles(smiles)
        istar = all_canon_smiles.index(s)
    except ValueError:
        pass
    else:
        x_star = all_inputs[istar]
        ax.plot(exp_data["feff_grid"], x_star, color="blue", linewidth=0.5)
        ax.text(0.99, 0.75, f"smile in FEFF data: {s}", ha="right", va="top", transform=ax.transAxes, fontsize=fs, color="blue")

    closest_x = get_closest_spectrum(np.array(x), data["X_test"]).squeeze()
    for ii in range(len(closest_x)):
        closest_smiles = data["smiles_test"][closest_x[ii]]
        # Escape "#" for the text renderer; "\\#" is the same two-character
        # string the original spelled with the invalid escape sequence "\#".
        closest_smiles = closest_smiles.replace("#", "\\#")
        ax.plot(
            exp_data["feff_grid"], data["X_test"][closest_x[ii], :],
            color="red", linewidth=0.5
        )
        ax.text(
            0.99, 0.7 - ii*0.05, "%s" % closest_smiles,
            ha="right", va="top", transform=ax.transAxes, fontsize=fs, color="red"
        )
    # ax.plot(exp_data["feff_grid"], kept_feff_spectra.mean(axis=0), color="red")

    ax.text(0.99, 0.95, f"exp: {smiles}", ha="right", va="top", transform=ax.transAxes, fontsize=fs)
    ax.text(0.99, 0.9, f"pred={pred}", ha="right", va="top", transform=ax.transAxes, fontsize=fs)
    ax.text(0.99, 0.85, f"target={target}", ha="right", va="top", transform=ax.transAxes, fontsize=fs)
    ax.text(0.99, 0.8, f"dev={dev}", ha="right", va="top", transform=ax.transAxes, fontsize=fs)

plt.show()

In [None]:
# Inspect the value left over from the last iteration of the plotting loop above.
closest_x

In [None]:
# Inspect the value left over from the last iteration of the plotting loop above.
closest_smiles

### Experiment PCA

In [None]:
# Project the experimental and the kept simulated spectra with the object
# currently bound to `pca`.
# NOTE(review): `pca` is rebound by several cells (nitrogen section, N/O figure
# cell); in a top-to-bottom run it is the nitrogen PCA fitted earlier.
exp_w = pca.transform(X_exp)
kept_feff_spectra_w = pca.transform(kept_feff_spectra)

In [None]:
# Simulated nitrogen spectra in PCA space (all in blue, amines in viridis)
# with the experimental spectra overlaid: "x" where the ensemble predicts the
# functional group (pred == 1), "o" otherwise.
fig, ax = plt.subplots(1, 1, figsize=(2, 2), sharey=True)

remove_axis_spines(ax, visible=True)

density_scatter(w[:, 0], w[:, 1], ax, bins=20, s=0.2, cmap="Blues", rasterized=True)

istar = data["functional_groups"].index("Amine")
condition = (data["Y_test"][:, istar] == 1)
where = np.where(condition)[0]
density_scatter(w[where, 0], w[where, 1], ax, s=0.3, cmap="viridis", rasterized=True)

# Rows of exp_w are in the same order as the rows of mu (both built from X_exp).
for ii, point in enumerate(exp_w):
    pred = mu[ii, feff_FG_index].item()
    print(pred)
    if pred == 1:
        ax.scatter(point[0], point[1], c="black", marker="x", s=5)
    else:
        ax.scatter(point[0], point[1], c="black", marker="o", s=5)

# density_scatter(kept_feff_spectra_w[:, 0], kept_feff_spectra_w[:, 1], ax, s=0.3, cmap="magma")

# ax.text(0.05, 0.95, f"({label})", ha="left", va="top", transform=ax.transAxes)

# plt.subplots_adjust(wspace=0.1)
plt.show()
# plt.savefig("figures/n_pca.svg", bbox_inches="tight", dpi=300)

# Oxygen exploration and validation

In [None]:
# Element whose model directory, dataset and grid file are loaded in the next cell.
ELEMENT = "O"

In [None]:
# Load the trained ensemble, the dataset and the energy grid for ELEMENT.
# NOTE: rebinds the notebook-wide `ensemble`, `data` and `grid`.
ensemble = Ensemble.from_path(f"data/23-12-06_torch_models/{ELEMENT}")
data = get_data(elements=ELEMENT)
grid = np.loadtxt(f"data/{ELEMENT.lower()}_grid.txt")

In [None]:
# Print the first grid value at which the mean training spectrum exceeds 1.
# If no point exceeds 1 the loop runs to the end and the last grid value is printed.
mean_spectrum = data["X_train"].mean(axis=0).squeeze()
for ii, (e, s) in enumerate(zip(grid, mean_spectrum)):
    if s > 1:
        break
print(grid[ii])

## PCA

See `Nitrogen-Oxygen preliminary analysis figures` section.

# Nitrogen-Oxygen preliminary analysis figures

In [None]:
# Prepare everything the combined nitrogen/oxygen figures need: datasets,
# grids, functional-group lists, PCA coordinates and motif drawings.
# Number of panels per row in the PCA figure.
N = 4

ELEMENT = "N"
data_N = get_data(elements=ELEMENT)
grid_N = np.loadtxt(f"data/{ELEMENT.lower()}_grid.txt")
interesting_functional_groups_N = [
    # None,
    "Amine",
    "Amide",
    "Nitrile",
    "Heteroaromatic",
    # "Secondary_aliph_amine",
    # "Tertiary_aliph_amine",
]
# NOTE: `pca` is rebound here and again for oxygen below; after this cell it
# refers to the oxygen PCA.
pca = PCA(2)
w_N = pca.fit_transform(data_N["X_test"])

smiles = ["N(C)(C)(C)", "CC(=O)N(C)(C)", "CC#N", "c1ccccn1"]
svgs = Chem.Draw.MolsToGridImage([Chem.MolFromSmiles(s) for s in smiles], useSVG=True)
# The target directory must already exist; open() does not create it.
with open("figures/fig_pca_3-nitrogen-oxygen/n_motifs.svg", "w") as f:
    f.write(svgs.data)


ELEMENT = "O"
data_O = get_data(elements=ELEMENT)
grid_O = np.loadtxt(f"data/{ELEMENT.lower()}_grid.txt")
interesting_functional_groups_O = [
    # None,
    "Alcohol",
    "Epoxide",
    "Carboxylic_acid_derivative",
    "Ketone",
]
pca = PCA(2)
w_O = pca.fit_transform(data_O["X_test"])

smiles = ["CO",  "C1CO1", "CC(=O)O", "CC(=O)C"]
svgs = Chem.Draw.MolsToGridImage([Chem.MolFromSmiles(s) for s in smiles], useSVG=True)
with open("figures/fig_pca_3-nitrogen-oxygen/o_motifs.svg", "w") as f:
    f.write(svgs.data)

In [None]:
# Two rows of PCA panels: nitrogen functional groups on top (blue background),
# oxygen functional groups below (red background). `cc` runs through `labels`
# across both rows.
labels = ["i", "ii", "iii", "iv", "v", "vi", "vii", "viii"]
fig, axs_all = plt.subplots(2, N, figsize=(2 * N, 4), sharey=True, sharex=True)

cc = 0

# --- nitrogen row ---
# NOTE: `w`, `data` and `ifg` are rebound for each row and stay bound to the
# oxygen values after this cell.
axs = axs_all[0]
w = w_N
data = data_N
ifg = interesting_functional_groups_N
for fg, ax in zip(ifg, axs):
    remove_axis_spines(ax, visible=True)
    density_scatter(w[:, 0], w[:, 1], ax, bins=20, s=0.2, cmap="Blues", rasterized=True)
    if fg is not None:
        index = data["functional_groups"].index(fg)
        condition = (data["Y_test"][:, index] == 1)
        # Plot where one nitrogen
        condition_only_one = [Counter(xx.lower())["n"] == 1 for xx in data["smiles_test"]]
        where = np.where(condition & condition_only_one)[0]
        density_scatter(w[where, 0], w[where, 1], ax, s=0.3, cmap="viridis", rasterized=True)
        ax.text(0.05, 0.95, f"({labels[cc]})", ha="left", va="top", transform=ax.transAxes)
        cc += 1

# --- oxygen row ---
axs = axs_all[1]
w = w_O
data = data_O
ifg = interesting_functional_groups_O
for fg, ax in zip(ifg, axs):
    remove_axis_spines(ax, visible=True)
    density_scatter(w[:, 0], w[:, 1], ax, bins=20, s=0.2, cmap="Reds", rasterized=True)
    if fg is not None:
        index = data["functional_groups"].index(fg)
        condition = (data["Y_test"][:, index] == 1)
        # Plot where one oxygen
        # (exactly two "o" characters are required for Carboxylic_acid_derivative)
        if fg == "Carboxylic_acid_derivative":
            condition_only_one = [Counter(xx.lower())["o"] == 2 for xx in data["smiles_test"]]
        else:
            condition_only_one = [Counter(xx.lower())["o"] == 1 for xx in data["smiles_test"]]
        where = np.where(condition & condition_only_one)[0]
        density_scatter(w[where, 0], w[where, 1], ax, s=0.3, cmap="viridis", rasterized=True)
        ax.text(0.05, 0.95, f"({labels[cc]})", ha="left", va="top", transform=ax.transAxes)
        cc += 1

# `axs` is the bottom row here, so the labels go on the lower-left panel.
axs[0].set_xlabel(r"$z_1$~[a.u.]")
axs[0].set_ylabel(r"$z_2$~[a.u.]")

plt.subplots_adjust(wspace=0.1, hspace=0.1)
plt.show()
# plt.savefig("figures/fig_pca_3-nitrogen-oxygen/pca.svg", bbox_inches="tight", dpi=300)

## Spectra

In [None]:
# Mean spectrum (plus up to ten individual examples) per functional group:
# nitrogen on the top panel, oxygen on the bottom panel. Within a panel each
# group is shifted down by ii * mult so the curves do not overlap.
mult = 4

fig, axs = plt.subplots(2, 1, figsize=(1.5, 4), sharey=True)

# remove_axis_spines(ax, visible=True)

# One colour per functional group. Sized from the two lists plotted in this
# cell; the original used `interesting_functional_groups`, which is left over
# from a different section of the notebook.
cmap = mpl.colormaps["rainbow"].resampled(
    max(len(interesting_functional_groups_N), len(interesting_functional_groups_O))
)

# --- nitrogen panel ---
ax = axs[0]
data = data_N
ifg = interesting_functional_groups_N
X = data_N["X_test"]
# Reuse the grid loaded together with data_N. The original re-read
# "data/N_grid.txt"; every other cell uses the lower-case file name, so the
# upper-case path fails on case-sensitive file systems.
grid = grid_N

# Does not depend on the functional group, so it is computed once.
condition_only_one = [Counter(xx.lower())["n"] == 1 for xx in data["smiles_test"]]

for ii, fg in enumerate(ifg):

    index = data["functional_groups"].index(fg)
    condition = (data["Y_test"][:, index] == 1)

    where = np.where(condition & condition_only_one)[0]

    X_where = X[where, :]
    mu = X_where.mean(axis=0)
    sd = X_where.std(axis=0)
    ax.plot(grid, mu - ii*mult, color=cmap(ii), zorder=-ii)
    # Rows 100-109 of the subset serve as examples; the upper bound is clipped
    # so that a subset with fewer rows does not raise IndexError.
    for jj in range(100, min(110, len(X_where))):
        ax.plot(grid, X_where[jj, :] - ii * mult, color=cmap(ii), alpha = 0.5, linewidth=0.5)
    # ax.fill_between(grid, mu + ii*mult - sd, mu + ii*mult + sd, linewidth=0, color=cmap(ii), alpha=0.2)

ax.set_yticks([])
set_grids(ax, grid=False)
ax.set_ylabel(r"$\mu(E)$~[a.u.]")
# ax.set_xlabel(r"$E$~[eV]")
ax.set_xticks([400, 410, 420, 430])
ax.set_xlim([397, 431])

# --- oxygen panel ---
ax = axs[1]
data = data_O
ifg = interesting_functional_groups_O
X = data_O["X_test"]
# Same reasoning as for the nitrogen grid above.
grid = grid_O

# NOTE(review): the PCA figure requires two "o" characters for
# Carboxylic_acid_derivative, whereas this panel requires exactly one for every
# group — confirm which subset is intended.
condition_only_one = [Counter(xx.lower())["o"] == 1 for xx in data["smiles_test"]]

for ii, fg in enumerate(ifg):

    index = data["functional_groups"].index(fg)
    condition = (data["Y_test"][:, index] == 1)

    where = np.where(condition & condition_only_one)[0]

    X_where = X[where, :]
    mu = X_where.mean(axis=0)
    sd = X_where.std(axis=0)
    ax.plot(grid, mu - ii*mult, color=cmap(ii), zorder=-ii)
    for jj in range(100, min(110, len(X_where))):
        ax.plot(grid, X_where[jj, :] - ii * mult, color=cmap(ii), alpha = 0.5, linewidth=0.5)
    # ax.fill_between(grid, mu + ii*mult - sd, mu + ii*mult + sd, linewidth=0, color=cmap(ii), alpha=0.2)

ax.set_yticks([])
set_grids(ax, grid=False)
# ax.set_ylabel(r"$\mu(E)$~[a.u.]")
ax.set_xlabel(r"$E$~[eV]")
ax.set_xticks([530, 540, 550, 560])
ax.set_xlim([529, 561])

plt.subplots_adjust(wspace=0.1, hspace=0.15)
plt.savefig("figures/fig_pca_3-nitrogen-oxygen/spectra.svg", bbox_inches='tight', dpi=300)
# plt.show()

## Analyze specific regions

### Secondary/Tertiary amines

In [None]:
# Shape of the nitrogen test-spectra array.
data_N["X_test"].shape

In [None]:
# Number of projected nitrogen test spectra.
len(w_N)

In [None]:
# Oxygen PCA with the Carboxylic_acid_derivative spectra (SMILES with exactly
# two "o" characters) that lie at z1 > 4.5 highlighted.
fig, ax = plt.subplots(1, 1, figsize=(2, 2))

w = w_O
data = data_O
ifg = interesting_functional_groups_O
fg = "Carboxylic_acid_derivative"

density_scatter(w[:, 0], w[:, 1], ax, bins=20, s=0.2, cmap="Reds", rasterized=True)

index = data["functional_groups"].index(fg)
condition = (data["Y_test"][:, index] == 1)
condition_only_one = [Counter(xx.lower())["o"] == 2 for xx in data["smiles_test"]]

# Region of interest along the first principal component.
# condition_location = (w[:, 0] > 0) & (w[:, 0] < 4.5)
condition_location = w[:, 0] > 4.5

where = np.where(condition & condition_only_one & condition_location)[0]

density_scatter(w[where, 0], w[where, 1], ax, s=0.3, cmap="viridis", rasterized=True)

plt.show()

In [None]:
# Draw the first ten molecules of the region highlighted in the previous cell.
# The three conditions repeat that cell's selection and rely on the `data`,
# `index` and `w` values it left behind.
condition = (data["Y_test"][:, index] == 1)
condition_only_one = [Counter(xx.lower())["o"] == 2 for xx in data["smiles_test"]]
condition_location = w[:, 0] > 4.5
where = np.where(condition & condition_only_one & condition_location)[0]
smiles = [data["smiles_test"][ii] for ii in where][:10]
svgs = Chem.Draw.MolsToGridImage([Chem.MolFromSmiles(s) for s in smiles], useSVG=True)
svgs

# Model results and UQ

In [None]:
# Functional groups that get_statistics keeps when element == "N".
N_containing_functional_groups = [
    '1,2-Aminoalcohol',
    'Lactam',
    'Amide',
    'Imidolactone',
    'Heterocyclic',
    'Hetero_N_basic_H',
    'Amine',
    'Hetero_N_nonbasic',
    'Primary_arom_amine',
    'Tertiary_aliph_amine',
    'NH_aziridine',
    'Heteroaromatic',
    'Secondary_aliph_amine',
    'Nitrile',
]

# Functional groups that get_statistics keeps when element == "O".
O_containing_functional_groups = [
    '1,2-Aminoalcohol',
    'Secondary_alcohol',
    'Carbonic_acid_derivatives',
    'Tertiary_alcohol',
    'Lactam',
    'Primary_alcohol',
    'Aldehyde',
    'Ketone',
    'Carboxylic_acid_derivative',
    'Epoxide',
    'Imidolactone',
    'Heterocyclic',
    'Dialkylether',
    'Phenol',
    'Heteroaromatic',
    'Hetero_O',
    'Alcohol',
]

def get_statistics(
    deviating_estimators,
    functional_groups,
    truth,
    mu,
    element,
    max_deviating=10,
    min_samples=10,
):
    """Summarise balanced accuracy as a function of the number of deviating estimators.

    For every value ``n_dev`` in ``0..max_deviating`` and every functional
    group, the samples whose prediction for that group has exactly ``n_dev``
    deviating estimators are selected and their balanced accuracy is computed
    with ``sklearn.metrics.balanced_accuracy_score``. The per-group scores are
    then combined into a sample-count-weighted mean and variance, plus the
    minimum and maximum.

    Parameters
    ----------
    deviating_estimators : array, indexed [sample, functional group]
        Number of deviating estimators per prediction.
    functional_groups : sequence of str
        Name of each column of ``deviating_estimators`` / ``truth`` / ``mu``.
    truth, mu : arrays, indexed [sample, functional group]
        Target labels and ensemble predictions.
    element : str
        For ``"N"`` / ``"O"`` only the groups listed in
        ``N_containing_functional_groups`` / ``O_containing_functional_groups``
        are used. For ``"N"`` the ``"1,2-Aminoalcohol"`` group is additionally
        kept out of the aggregate and reported separately.
    max_deviating : int, optional
        Largest ``n_dev`` evaluated (default 10, the previous hard-coded range).
    min_samples : int, optional
        Groups with fewer selected samples than this are skipped for that
        ``n_dev`` (default 10, the previous hard-coded threshold).

    Returns
    -------
    dict
        ``"mean"``, ``"var"``, ``"min"``, ``"max"`` and ``"counts"`` arrays with
        one entry per ``n_dev``, plus a ``"1,2-Aminoalcohol"`` list. Entries for
        which no functional group had enough samples are ``NaN`` (the original
        raised inside ``np.average`` / ``np.min`` in that case). For
        ``element == "N"`` the ``"1,2-Aminoalcohol"`` list is padded with
        ``NaN`` where that group had too few samples, so that it stays aligned
        with ``n_dev``; for other elements it is empty.
    """

    cba_mean = []
    cba_var = []
    cba_min = []
    cba_max = []
    cba_counts = []

    special_fg = "1,2-Aminoalcohol"
    special = {special_fg: []}

    for n_dev in range(0, max_deviating + 1):
        tmp_cba = []
        tmp_weights = []
        # Total number of predictions (over all groups) with this n_dev.
        where = np.where(deviating_estimators.flatten() == n_dev)[0]
        cba_counts.append(len(where))

        for jj, fg in enumerate(functional_groups):
            if element == "O" and fg not in O_containing_functional_groups:
                continue
            if element == "N" and fg not in N_containing_functional_groups:
                continue

            is_special = element == "N" and fg == special_fg

            where = np.where(deviating_estimators[:, jj] == n_dev)[0]
            if len(where) < min_samples:
                if is_special:
                    # Placeholder keeps the list index equal to n_dev.
                    special[special_fg].append(np.nan)
                continue

            acc = balanced_accuracy_score(truth[where, jj].flatten(), mu[where, jj].flatten())

            if is_special:
                special[special_fg].append(acc)
                continue

            tmp_weights.append(len(where))
            tmp_cba.append(acc)

        if not tmp_cba:
            # No functional group had enough samples at this n_dev.
            cba_mean.append(np.nan)
            cba_var.append(np.nan)
            cba_min.append(np.nan)
            cba_max.append(np.nan)
            continue

        mean = np.average(tmp_cba, weights=tmp_weights)
        cba_mean.append(mean)
        var = np.average((mean - np.array(tmp_cba))**2, weights=tmp_weights)
        cba_var.append(var)
        cba_min.append(np.min(tmp_cba))
        cba_max.append(np.max(tmp_cba))

    return {
        "mean": np.array(cba_mean),
        "var": np.array(cba_var),
        "min": np.array(cba_min),
        "max": np.array(cba_max),
        "counts": np.array(cba_counts),
        **special
    }

In [None]:
# data = get_data(elements="O")
# ensemble = Ensemble.from_path("data/23-12-06_torch_models/O")
# preds = ensemble.predict(data["X_test"])
# mu = preds.round().mean(axis=0).round()
# deviating_estimators = get_deviating_estimators(preds)
# truth = data["Y_test"]
# functional_groups = data["functional_groups"]
# jj = functional_groups.index("1,2-Aminoalcohol")
# balanced_accuracy_score(truth[:, jj], mu[:, jj])

In [None]:
# Filled by the loop below: one get_statistics result per entry of ELEMENTS.
stats = {}
ELEMENTS = ["C", "N", "O", "CNO"]

In [None]:
# Dataset directory passed to get_data and directory holding the trained ensembles.
# use this one for main results
target = "data/23-04-26-ml-data"
estimator_directory = Path("data/23-12-06_torch_models")

# use this one for the cutoff-8 results
# target = "data/23-05-11-ml-data-CUTOFF8"
# estimator_directory = Path("data/23-12-06_torch_models/cutoff8")

In [None]:
# For every element combination: load data and ensemble, predict on the test
# set and store the accuracy-vs-deviation statistics under stats[elements].
for elements in tqdm.tqdm(ELEMENTS):

    data = get_data(target=target, elements=elements)
    # Ensemble sub-directory: the element letters joined by "-" (e.g. "CNO" -> "C-N-O").
    p = "-".join([e for e in elements])
    ensemble = Ensemble.from_path(estimator_directory / p)

    preds = ensemble.predict(data["X_test"])
    # Majority vote: round each estimator's output, average over estimators, round again.
    mu = preds.round().mean(axis=0).round()
    deviating_estimators = get_deviating_estimators(preds)
    truth = data["Y_test"]
    functional_groups = data["functional_groups"]
    print(elements, functional_groups)

    stats[elements] = get_statistics(
        deviating_estimators,
        functional_groups,
        truth,
        mu,
        elements,
    )

In [None]:
# Shared styling for the errorbar curves in the figure below.
plot_kwargs = {
    'linewidth': 1.0,
    'marker': 's',
    'ms': 2.0,
    'capthick': 0.3,
    'capsize': 2.0,
    'elinewidth': 0.3
}
# One colour per panel, zipped with the entries of `stats`.
COLORS = ["black", "blue", "red", "grey"]

In [None]:
# One panel per entry of `stats`: weighted-mean balanced accuracy versus the
# number of deviating estimators (error bars span min..max over functional
# groups), with the log10 sample counts as bars on a twin axis.
fig, axs = plt.subplots(1, 4, figsize=(6, 2), sharey=True)

# Number of panels actually drawn (zip stops at the shortest input).
n_panels = min(len(axs), len(COLORS), len(stats))

for ii, (ax, color, (key, value)) in enumerate(zip(axs, COLORS, stats.items())):

    cba_mean = value["mean"]
    cba_var = value["var"]
    cba_std = np.sqrt(cba_var)
    cba_min = value["min"]
    cba_max = value["max"]
    cba_counts = value["counts"]

    ax.errorbar(
        range(0, 11),
        cba_mean,
        yerr=[cba_mean-cba_min, cba_max-cba_mean],
        color=color,
        label=key,
        **plot_kwargs
    )

    # The separate 1,2-Aminoalcohol curve only exists for nitrogen. Selecting
    # the panel by key instead of by position (ii == 1) keeps this correct
    # even if `stats` was filled in a different order.
    if key == "N":
        ax.errorbar(
            range(0, 11),
            value["1,2-Aminoalcohol"],
            color="purple",
            alpha=1,
            label="AA",
            **plot_kwargs
        )

    ax.legend(fancybox=True, fontsize=10)

    ax2 = ax.twinx()
    ax2.bar(range(0, 11), np.log10(cba_counts), color=color, alpha=0.2)
    ax2.set_ylim(bottom=3.4, top=7.1)
    ax2.set_yticks([4, 5, 6, 7])
    # Only the last twin axis keeps its tick labels. The original achieved this
    # by clearing the labels of the previous iteration's `ax2`.
    if ii < n_panels - 1:
        ax2.set_yticklabels([])
    ax.tick_params(which="both", direction="in")
    ax2.tick_params(which="both", direction="in")

    if ii == 3:
        ax2.set_ylabel(r"$\log_{10}$ Counts")
    if ii == 0:
        ax.set_ylabel(r"CBA")

# `ax` is the last panel drawn by the loop.
ax.set_xlabel(r"$N_\mathrm{d}$")

plt.show()

# plt.savefig("figures/fig_nd/fig_nd_2.svg", dpi=300, bbox_inches="tight")
# plt.show()

# Multimodal advantage

In [None]:
# Dataset directory and single-estimator model directory for this section.
# NOTE: rebinds `target` and `estimator_directory` from the previous section;
# the cutoff-8 variant is the active one here.
# use this one for main results
# target = "data/23-04-26-ml-data"
# estimator_directory = Path("data/23-12-06_torch_models/solo_estimators")

# use this one for the cutoff-8 results
target = "data/23-05-11-ml-data-CUTOFF8"
estimator_directory = Path("data/23-12-06_torch_models/solo_estimators_cutoff8")

In [None]:
# Datasets for the combined and the single-element inputs, all from `target`.
# NOTE: data_N and data_O rebind the names used by the N/O figure section.
data_CNO = get_data(target=target, elements="CNO")
data_C = get_data(target=target, elements="C")
data_N = get_data(target=target, elements="N")
data_O = get_data(target=target, elements="O")

In [None]:
def index_map(input_data, element, n_points=200):
    """Select the columns of a concatenated input that belong to one element.

    The columns are taken to be laid out as three consecutive blocks in the
    order C, N, O. The C and N blocks are ``n_points`` columns wide and the O
    block is everything after them.

    Parameters
    ----------
    input_data : array-like supporting 2D slicing (``input_data[:, a:b]``)
        The concatenated input.
    element : str
        One of ``"C"``, ``"N"``, ``"O"`` or ``"CNO"``.
    n_points : int, optional
        Width of the C and N blocks. Defaults to 200, the value that was
        previously hard-coded.

    Returns
    -------
    The column slice for the requested element, or ``input_data`` itself
    (unsliced) for ``"CNO"``.

    Raises
    ------
    ValueError
        If ``element`` is not one of the four supported values.
    """
    if element == "CNO":
        return input_data

    bounds = {
        "C": (0, n_points),
        "N": (n_points, 2 * n_points),
        "O": (2 * n_points, None),
    }
    try:
        start, stop = bounds[element]
    except (KeyError, TypeError):
        # TypeError covers unhashable values, which were also rejected with
        # a ValueError by the original if/elif chain.
        raise ValueError(f"invalid element {element}") from None
    return input_data[:, start:stop]

In [None]:
# Load the seven estimators from sub-directories of `estimator_directory`.
# The names on the left line up, in order, with the directory names below.
(
    estimator_CNO_C,
    estimator_CNO_N,
    estimator_CNO_O,
    estimator_CNO,
    estimator_C,
    estimator_N,
    estimator_O,
) = (
    Estimator.from_path(estimator_directory / name)
    for name in (
        "C-N-O_only_C",
        "C-N-O_only_N",
        "C-N-O_only_O",
        "C-N-O",
        "C",
        "N",
        "O",
    )
)

In [None]:
# Predictions on the combined CNO test inputs: each "only_X" estimator gets
# the column block selected by `index_map`, the full estimator gets all columns.
pred_CNO_C, pred_CNO_N, pred_CNO_O, pred_CNO = (
    estimator.predict(index_map(data_CNO["X_test"], element=element))
    for estimator, element in (
        (estimator_CNO_C, "C"),
        (estimator_CNO_N, "N"),
        (estimator_CNO_O, "O"),
        (estimator_CNO, "CNO"),
    )
)

# Predictions of the single-element estimators on their own test inputs.
pred_C, pred_N, pred_O = (
    estimator.predict(dataset["X_test"])
    for estimator, dataset in (
        (estimator_C, data_C),
        (estimator_N, data_N),
        (estimator_O, data_O),
    )
)

First, it's useful to have the overall CBA score.

In [None]:
# Balanced accuracy over all (sample, functional group) entries at once:
# targets and rounded predictions are flattened to 1D before scoring.
Y_test = data_CNO["Y_test"]
balanced_accuracy_score(
    y_true=Y_test.flatten(),
    y_pred=pred_CNO.round().flatten(),
)

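Next, we compute the balanced accuracy separately for each functional group, comparing the estimators that see only one element's columns of the CNO test set against the estimator that sees the full CNO input.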

In [None]:
# Functional groups drawn with a filled (rather than hollow) N marker in
# `make_plot`.
N_containing_functional_groups = [
    "1,2-Aminoalcohol",
    "Lactam",
    "Amide",
    "Imidolactone",
    "Heterocyclic",
    "Hetero_N_basic_H",
    "Amine",
    "Hetero_N_nonbasic",
    "Primary_arom_amine",
    "Tertiary_aliph_amine",
    "NH_aziridine",
    "Heteroaromatic",
    "Secondary_aliph_amine",
    "Nitrile",
]

# Functional groups drawn with a filled (rather than hollow) O marker in
# `make_plot`.
O_containing_functional_groups = [
    "1,2-Aminoalcohol",
    "Secondary_alcohol",
    "Carbonic_acid_derivatives",
    "Tertiary_alcohol",
    "Lactam",
    "Primary_alcohol",
    "Aldehyde",
    "Ketone",
    "Carboxylic_acid_derivative",
    "Epoxide",
    "Imidolactone",
    "Heterocyclic",
    "Dialkylether",
    "Phenol",
    "Heteroaromatic",
    "Hetero_O",
    "Alcohol",
]

In [None]:
# Container for the per-functional-group scores: the group names plus one
# (initially empty) list per input variant.
# Currently disabled variants: "CNO_C", "CNO_N", "CNO_O".
results = {"fg": data_CNO["functional_groups"]}
results.update({key: [] for key in ("C", "N", "O", "CNO")})

In [None]:
# Balanced accuracy of each estimator, one value per functional group
# (column ii of the targets vs. column ii of the rounded predictions).
#
# The lists are *assigned* rather than appended to, so re-running this cell
# does not grow them past the length of results["fg"] (which would make
# `pd.DataFrame(results)` fail). Rounding is done once per estimator instead
# of once per functional group.
Y_test = data_CNO["Y_test"]
n_groups = len(data_CNO["functional_groups"])

for key, pred in (
    ("C", pred_CNO_C),
    ("N", pred_CNO_N),
    ("O", pred_CNO_O),
    ("CNO", pred_CNO),
):
    rounded = pred.round()
    results[key] = [
        balanced_accuracy_score(Y_test[:, ii], rounded[:, ii])
        for ii in range(n_groups)
    ]

In [None]:
# One row per functional group. "Best SM" is the row-wise maximum of the
# C / N / O scores, "d" is the CNO score minus that maximum, and rows are
# ordered by decreasing "d".
df = (
    pd.DataFrame(results)
    .assign(**{"Best SM": lambda frame: frame[["C", "N", "O"]].max(axis=1)})
    .assign(d=lambda frame: frame["CNO"] - frame["Best SM"])
    .sort_values("d", ascending=False)
)

In [None]:
def make_plot(df):
    """Plot the per-functional-group CBA of each input and the CNO advantage.

    Parameters
    ----------
    df : pandas.DataFrame
        One row per functional group with columns ``"fg"``, ``"C"``, ``"N"``,
        ``"O"``, ``"CNO"`` and ``"d"``. Rows are drawn left to right in the
        order in which they appear in ``df``.

    Returns
    -------
    tuple
        ``(ax, ax2)``: the left axis (CBA scatter) and its twin right axis,
        which shows ``df["d"] * 100`` as green triangles.

    Notes
    -----
    N (blue) and O (red) markers are filled when the functional group is
    listed in ``N_containing_functional_groups`` /
    ``O_containing_functional_groups`` and hollow (white face) otherwise.
    """
    fig, ax = plt.subplots(1, 1, figsize=(7, 2))

    x = list(range(len(df)))
    marker_size = 10

    ax.scatter(x, df["CNO"], color="grey", label=r"$\mathcal{D}_\mathrm{CNO}$")
    ax.scatter(
        x, df["C"], color="black", s=marker_size,
        label=r"$\mathcal{D}_\mathrm{CNO}^\mathrm{[C]}$", linewidth=0.5,
    )

    # One vectorized scatter per element with a per-point face color, instead
    # of one scatter call (and one boolean-mask lookup) per functional group.
    per_element = (
        ("N", "blue", N_containing_functional_groups,
         r"$\mathcal{D}_\mathrm{CNO}^\mathrm{[N]}$"),
        ("O", "red", O_containing_functional_groups,
         r"$\mathcal{D}_\mathrm{CNO}^\mathrm{[O]}$"),
    )
    for column, color, containing, label in per_element:
        face_colors = [color if fg in containing else "white" for fg in df["fg"]]
        ax.scatter(
            x, df[column],
            s=marker_size, label=label,
            facecolors=face_colors, edgecolors=color,
            linewidth=.5,
        )

    labels = [name.replace("_", " ") for name in df["fg"]]
    ax.set_xticks(x)
    ax.set_xticklabels(labels, rotation=45, ha='right', fontsize=8)

    ax.tick_params(which='both', direction='in', bottom=True, left=True, top=False, right=False)
    # Faint vertical guide behind each functional group.
    for xx in x:
        ax.axvline(xx, zorder=-1, linewidth=0.5, color="black", alpha=0.2)

    ax.set_ylabel("CBA", fontsize=10)

    # ax.legend(frameon=True, ncol=4, fontsize=6, loc="lower center", bbox_to_anchor=(1, -1))

    ax2 = ax.twinx()

    ax2.scatter(x, df["d"] * 100, zorder=-1, marker="^", color="green", s=5)
    # Raw string: "\%" is not a valid Python escape sequence. The resulting
    # text is unchanged (backslash followed by a percent sign).
    ax2.set_ylabel(r"Adv (\%)", fontsize=10, color="green")
    ax2.tick_params(which='both', direction='in', bottom=False, left=False, top=True, right=True)

    return ax, ax2

In [None]:
# axs[0] is the CBA axis, axs[1] the twin axis with the advantage in percent.
axs = make_plot(df)

# CBA axis: fixed ticks, with the limits padded by 0.02 beyond the outer ticks.
axs[0].set_yticks([0.5, 0.6, 0.7, 0.8, 0.9, 1.0])
yticks = axs[0].get_yticks()
axs[0].set_ylim(bottom=yticks[0] - 0.02, top=yticks[-1] + 0.02)
axs[1].set_yticks([0, 3, 6])

# Alternative, tighter CBA range:
# axs[0].set_yticks([0.7, 0.8, 0.9, 1.0])
# axs[0].set_ylim(0.68, 1.02)
# axs[1].set_yticks([0, 3, 6])

# NOTE(review): the output name carries a "-cutoff8" suffix, matching the
# cutoff-8 `target` selected earlier - adjust it when switching targets.
plt.savefig(
    "figures/fig_multimodal_advantage/multimodal_advantage-cutoff8.svg",
    dpi=300,
    bbox_inches="tight",
)
# plt.show()

In [None]:
# Scratch check (unrelated to the analysis above): Python's % takes the sign
# of the divisor, so this evaluates to 8.
-2 % 10

In [None]:
# Scratch check (unrelated to the analysis above): prepend 0 to a small list.
ddd = list(range(1, 4))
ddd[:0] = [0]

In [None]:
# Show the scratch list; [0, 1, 2, 3] when the cells are run in order.
ddd

In [None]:
# Scratch check: remove the element at index 1. Not idempotent - every
# re-run of this cell deletes another element.
del ddd[1]

In [None]:
# Show the scratch list; [0, 2, 3] when the cells are run once, in order.
ddd