In [None]:
# Reload edited project modules automatically before each cell executes.
%load_ext autoreload
%autoreload 2
# Turn off Jedi-based tab completion in the IPython completer.
%config Completer.use_jedi = False

Helper functions and imports

In [None]:
import json
import pickle
import re

import numpy as np
import pandas as pd

In [None]:
import matplotlib.pyplot as plt
import matplotlib as mpl

In [None]:
# Global matplotlib styling shared by every figure in this notebook.
mpl.rcParams.update(
    {
        # STIX fonts for both regular text and math text.
        "mathtext.fontset": "stix",
        "font.family": "STIXGeneral",
        # Keep matplotlib's built-in text rendering (set to True to use LaTeX).
        "text.usetex": False,
        # Tick and axis label sizes.
        "xtick.labelsize": 12,
        "ytick.labelsize": 12,
        "axes.labelsize": 12,
        # Figure resolution.
        "figure.dpi": 300,
    }
)

Create an index that records, for each SMILES string, which atom types it contains, which spectra are available, and which functional groups are present.

In [None]:
def read_json(path):
    """Load and return the contents of a JSON file.

    Parameters
    ----------
    path : str or os.PathLike
        Location of the JSON file to read.

    Returns
    -------
    object
        The decoded JSON document (typically a dict or a list).

    Raises
    ------
    OSError
        If the file cannot be opened.
    json.JSONDecodeError
        If the file does not contain valid JSON.
    """

    # JSON is UTF-8 by specification; do not depend on the platform's
    # locale encoding, which would garble non-ASCII text on some systems.
    with open(path, "r", encoding="utf-8") as infile:
        return json.load(infile)

In [None]:
# Load the pickled XANES database. The context manager guarantees the file
# handle is closed once loading finishes.
# NOTE: unpickling can execute arbitrary code; only load files you trust.
with open("data/221205/221205_xanes.pkl", "rb") as infile:
    data = pickle.load(infile)

In [None]:
# Mapping whose values are the functional groups listed for each entry.
functional_group_data = read_json("data/221205/functional_groups.json")

# Every functional-group occurrence across all entries (duplicates kept).
all_functional_groups_enumerated = [
    group_name
    for group_list in functional_group_data.values()
    for group_name in group_list
]

# Distinct functional-group names in sorted order.
all_unique_functional_groups = sorted(set(all_functional_groups_enumerated))

In [None]:
# Column-oriented index: every column starts as its own empty list. The fixed
# columns come first, followed by one column per unique functional group.
index = {
    column: []
    for column in [
        "SMILES",
        "C",
        "N",
        "O",
        "C-XANES",
        "N-XANES",
        "O-XANES",
        *all_unique_functional_groups,
    ]
}

In [None]:
# One SMILES atom token: either a bracket atom (group 1 captures the element
# symbol after an optional isotope number) or an unbracketed organic-subset /
# aromatic atom (group 2). Two-letter halogens are tried before single letters
# so that "Cl" is never read as carbon.
_SMILES_ATOM = re.compile(
    r"\[\d*([A-Z][a-z]?|[a-z]{1,2})[^\]]*\]|(Br|Cl|[BCNOPSFI]|[bcnops])"
)


def _elements_in_smiles(smiles):
    """Return the set of element symbols (e.g. {"C", "O"}) found in a SMILES string."""

    return {
        (match.group(1) or match.group(2)).capitalize()
        for match in _SMILES_ATOM.finditer(smiles)
    }


for smile, dat in data["data"].items():

    # Tokenise the SMILES instead of substring-matching its lower-cased form,
    # which would report e.g. "Cl" as carbon or "Na" as nitrogen.
    elements = _elements_in_smiles(smile)

    # Set lookup avoids rescanning the molecule's group list for every column.
    groups_present = set(functional_group_data[smile])

    index["SMILES"].append(smile)

    # 1 if the molecule contains the element, else 0.
    for key in ["C", "N", "O"]:
        index[key].append(int(key in elements))

    # 1 if a spectrum of that type is stored for the molecule, else 0.
    for key in ["C-XANES", "N-XANES", "O-XANES"]:
        index[key].append(int(dat[key] is not None))

    # 1 if the molecule is annotated with the functional group, else 0.
    for fg in all_unique_functional_groups:
        index[fg].append(int(fg in groups_present))

In [None]:
# One row per SMILES string, one column per key of the index dictionary.
df = pd.DataFrame.from_dict(index, orient="columns")

In [None]:
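# Uncomment to (re)write the index to disk; left disabled so the stored file is not overwritten by accident.
# df.to_csv("data/221205/221205_index.csv")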

In [None]:
# Reload the stored index from disk, using the first CSV column as row labels.
index = pd.read_csv(
    "data/221205/221205_index.csv",
    index_col=0,
)

In [None]:
# The freshly built index must match the copy stored on disk. Unlike a bare
# `assert` on `==`, this reports which column/values differ on failure, treats
# NaN as equal to NaN, and is not stripped when Python runs with -O.
pd.testing.assert_frame_equal(df, index, check_dtype=False)

Test the usage of the index.

In [None]:
from multimodal_molecules.data import get_dataset

In [None]:
# Build the modelling dataset from the stored spectra and the index file.
# NOTE(review): `conditions` presumably restricts the data to molecules that
# have both listed spectra — confirm against get_dataset.
data = get_dataset(
    index_path="data/221205/221205_index.csv",
    xanes_path="data/221205/221205_xanes.pkl",
    conditions="C-XANES,O-XANES",
)

# Example Random Forest Classifier training

In [None]:
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import train_test_split
from sklearn.metrics import balanced_accuracy_score, accuracy_score
from sklearn.inspection import permutation_importance

In [None]:
# Features are the C-XANES entry of the dataset; the target is the "Amide"
# functional-group entry.
X, Y = data["C-XANES"], data["FG"]["Amide"]

In [None]:
# Hold out 10% of the samples for testing; the seed keeps the split fixed.
X_train, X_test, Y_train, Y_test = train_test_split(
    X,
    Y,
    test_size=0.10,
    random_state=43,
)

In [None]:
# Seeded 200-tree forest trained on 4 workers with progress output enabled.
rf = RandomForestClassifier(
    n_estimators=200,
    random_state=43,
    n_jobs=4,
    verbose=True,
)
rf.fit(X_train, Y_train)

In [None]:
# Predict labels for the held-out test samples with the fitted forest.
Y_test_pred = rf.predict(X_test)

In [None]:
# Test-set scores expressed as percentages.
accuracy = 100 * accuracy_score(Y_test, Y_test_pred)
balanced_accuracy = 100 * balanced_accuracy_score(Y_test, Y_test_pred)

In [None]:
# Report both scores, one per line, to one decimal place.
print(
    f"Accuracy is {accuracy:.01f}%",
    f"Balanced accuracy is {balanced_accuracy:.01f}%",
    sep="\n",
)

In [None]:
# Permutation importance shuffles features at random, so seed it like the
# train/test split and the forest to make the result reproducible.
p_importance = permutation_importance(
    rf,
    X_test,
    Y_test,
    n_jobs=4,
    random_state=43,
)

In [None]:
# Mean permutation importance reported for each input feature.
feature_importances = p_importance["importances_mean"]

# Grid stored alongside the C-XANES spectra; used as the x-axis when plotting.
grid = data["grid"]["C-XANES"]

Let's find out the average spectrum of each class.

In [None]:
# Positions of the test samples in each class (label 1 = has the functional
# group, label 0 = does not).
has_fg_mask = Y_test == 1
not_has_fg_mask = Y_test == 0
index_where_has_fg = np.nonzero(has_fg_mask)[0]
index_where_not_has_fg = np.nonzero(not_has_fg_mask)[0]

In [None]:
# Split the test spectra row-wise into the two classes.
spectra_where_has_fg, spectra_where_not_has_fg = (
    X_test[index_where_has_fg, :],
    X_test[index_where_not_has_fg, :],
)

In [None]:
# Per-grid-point mean and standard deviation across samples, for each class.
mu_spectra_where_has_fg = np.mean(spectra_where_has_fg, axis=0)
mu_spectra_where_not_has_fg = np.mean(spectra_where_not_has_fg, axis=0)

sd_spectra_where_has_fg = np.std(spectra_where_has_fg, axis=0)
sd_spectra_where_not_has_fg = np.std(spectra_where_not_has_fg, axis=0)

In [None]:
def _plot_mean_band(axis, x, mean, std, line_fmt, band_color):
    """Draw a mean curve with a translucent band of one standard deviation either side."""
    axis.plot(x, mean, line_fmt)
    axis.fill_between(x, mean - std, mean + std, color=band_color, alpha=0.2, linewidth=0)


# Two stacked panels sharing the x-axis.
fig, axs = plt.subplots(2, 1, figsize=(3, 2), sharex=True)

# Top panel: reserved for the feature importances (currently disabled).
ax = axs[0]
# ax.plot(grid, feature_importances)

# Bottom panel: mean spectrum of each class, red = has the functional group,
# blue = does not.
ax = axs[1]
_plot_mean_band(ax, grid, mu_spectra_where_has_fg, sd_spectra_where_has_fg, "r-", "red")
_plot_mean_band(ax, grid, mu_spectra_where_not_has_fg, sd_spectra_where_not_has_fg, "b-", "blue")

plt.show()