In [None]:
# Automatically reload edited project modules (e.g. multimodal_molecules) so changes are picked up live.
%load_ext autoreload
%autoreload 2
# Use IPython's own tab completer instead of jedi.
%config Completer.use_jedi = False

Helper functions and imports

In [None]:
from itertools import product
import json
from pathlib import Path

import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib.backends.backend_pdf import PdfPages
import numpy as np
import pandas as pd
import pickle
from scipy.stats import sem
import seaborn as sns

In [None]:
# Show three decimal places when rendering DataFrames through the Styler.
pd.options.styler.format.precision = 3

In [None]:
def read_json(path, encoding="utf-8"):
    """Load and return the JSON document stored at ``path``.

    Parameters
    ----------
    path : str or os.PathLike
        Location of the JSON file.
    encoding : str, optional
        Text encoding used to decode the file. Defaults to UTF-8 (the encoding
        JSON files are expected to use) instead of the platform-dependent default
        that ``open`` would otherwise pick.

    Returns
    -------
    object
        The decoded JSON value (typically a dict or list).
    """
    with open(path, "r", encoding=encoding) as infile:
        return json.load(infile)

In [None]:
from multimodal_molecules.plotting import set_defaults, set_grids

In [None]:
# Create and immediately clear a throwaway figure, then apply the project's plotting defaults.
# NOTE(review): presumably the dummy figure initialises the matplotlib backend so that
# set_defaults() takes effect — confirm. `fig`/`ax` are overwritten by later cells.
fig, ax = plt.subplots(1, 1, figsize=(3, 2))
plt.clf()
set_defaults()

In [None]:
from multimodal_molecules.models import Results, get_all_combinations, predict_rf
from multimodal_molecules.data import get_dataset

Set relevant paths.

In [None]:
# Results of the C/N/O XANES run; the file name encodes which modalities were used.
json_file = Path("results/221226/C-XANES_N-XANES_O-XANES.json")
# All dataset files live under one dated directory; derive paths from it (as Path
# objects, consistent with json_file) so the date only needs changing in one place.
data_directory = Path("data/221205")
functional_group_json_file = data_directory / "functional_groups.json"

Load the relevant data.

In [None]:
# Element symbols encoded in the results file name ("C-XANES_N-XANES_O-XANES.json" -> C, N, O).
single_modalities = [token.split("-")[0] for token in json_file.parts[-1].split("_")]
combos = get_all_combinations(len(single_modalities))
# Keep only combinations of two or more modalities, rendered as comma-joined symbols.
multi_modalities = [
    ",".join(single_modalities[idx] for idx in combo)
    for combo in combos
    if len(combo) > 1
]
print(single_modalities)
print(multi_modalities)

In [None]:
functional_group_data = read_json(functional_group_json_file)

# Flatten every entry's list of group names (duplicates kept), then derive the sorted unique set.
all_functional_groups_enumerated = []
for groups in functional_group_data.values():
    all_functional_groups_enumerated.extend(groups)
all_unique_functional_groups = sorted(set(all_functional_groups_enumerated))

In [None]:
# Load the saved run; later cells use its `models`, `get_train_test_split` and `_conditions`.
results = Results.from_file(json_file)

In [None]:
# Load spectra + index for the conditions the saved run was trained on.
# NOTE(review): relies on the private attribute Results._conditions.
data = get_dataset(
    data_directory / "221205_xanes.pkl",
    data_directory / "221205_index.csv",
    conditions=results._conditions,
)

In [None]:
conditions = "C,N,O"

conditions_list = [f"{cc}-XANES" for cc in conditions.split(",")]
# Model keys look like "C-XANES_N-XANES_O-XANES-<functional group>"; the prefix does
# not depend on the functional group, so build it once.
conditions_list_joined = "_".join(conditions_list)

tt_data = results.get_train_test_split(data, xanes=conditions)
x_test = tt_data["x_test"]

# The cells below analyse a single functional group, so only that group's model is
# evaluated. Defaults to the last key of data["FG"]; assign a specific name to `fg`
# to analyse a different group.
fg = list(data["FG"].keys())[-1]
y_test = tt_data["y_test"][fg]
model = results.models[f"{conditions_list_joined}-{fg}"]

# Get the predictions. NOTE(review): assumes predict_rf returns one column per tree,
# so the mean/std over axis 1 are the ensemble prediction and its spread — confirm.
y_test_preds = predict_rf(model, x_test)
mu_pred = y_test_preds.mean(axis=1)
mu_std = y_test_preds.std(axis=1)
# Round the spread so numerically near-identical values group together below.
mu_std_rounded = np.round(mu_std, 5)
unique_std = np.unique(mu_std_rounded)

- If the error (`errors = round(mu_pred) - y_test`) is 1, the model predicted 1 but `y_test` is 0: a false positive.
- If the error is -1, the model predicted 0 but `y_test` is 1: a false negative.

In [None]:
# Signed classification error: rounded ensemble prediction minus true label.
errors = np.round(mu_pred, 0) - y_test

Plot the mean absolute error against the ensemble variance for this functional group.

In [None]:
# Group the absolute errors by the (rounded) ensemble spread of each sample. Then take
# each group's mean and standard error of the mean. scipy's sem uses ddof=1, so a
# group with a single sample yields NaN.
abs_errors_by_std = [
    np.abs(errors[np.where(mu_std_rounded == sd_value)[0]])
    for sd_value in unique_std
]
mu_errors = np.array([group.mean() for group in abs_errors_by_std])
sd_errors = np.array([sem(group) for group in abs_errors_by_std])  # standard errors

In [None]:
plot_kwargs = {
    'linewidth': 0.0,
    'marker': 's',
    'ms': 1.0,
    'capthick': 0.3,
    'capsize': 2.0,
    'elinewidth': 0.3
}

fig, ax = plt.subplots(1, 1, figsize=(3, 2))

# Mean |error| (± standard error) per distinct rounded ensemble variance (std squared).
ax.errorbar(unique_std**2, mu_errors, yerr=sd_errors, **plot_kwargs)

ax.set_yscale("log")
ax.set_ylim(10**-4.5, 10**0.5)
ax.set_xlabel(r"Ensemble variance $\sigma^2$")
ax.set_ylabel("Mean |error|")

plt.show()

In [None]:
# Percentage of ALL test samples that are false positives / false negatives.
# NOTE: these are shares of the whole test set, not the conventional
# FPR = FP / (FP + TN) and FNR = FN / (FN + TP).
false_positive_rate = np.count_nonzero(errors == 1) / len(errors) * 100.0
print(false_positive_rate)

false_negative_rate = np.count_nonzero(errors == -1) / len(errors) * 100.0
print(false_negative_rate)

In [None]:
# Order test samples by increasing ensemble spread so misclassifications can be
# inspected as a function of uncertainty. A stable sort keeps ties in original order.
indexes = np.argsort(mu_std, kind="stable")
argsorted_mu_std = mu_std[indexes]
# asarray so positional indexing works even if `errors` is a pandas Series.
argsorted_error = np.asarray(errors)[indexes]
argsorted_error_false_positives = argsorted_error == 1
argsorted_error_false_negatives = argsorted_error == -1

In [None]:
# Count misclassifications in consecutive, equal-width bins of the uncertainty-sorted samples.
bins = 80  # target number of bins
duration = len(argsorted_error)
# At least 1, so range() always gets a positive step (duration / bins can round to 0).
bin_width = max(1, int(round(duration / bins)))
bin_starts = range(0, duration, bin_width)
binned_argsorted_error_false_positives = np.array([argsorted_error_false_positives[i:i + bin_width].sum() for i in bin_starts])
binned_argsorted_error_false_negatives = np.array([argsorted_error_false_negatives[i:i + bin_width].sum() for i in bin_starts])
# One bar centre per bin actually produced above. Because bin_width is rounded, the
# number of bins is generally not bins + 1. Deriving centres from bin_starts keeps
# bar positions and heights the same length.
binned_grid = np.array([start + bin_width / 2 for start in bin_starts])

In [None]:
fig, ax = plt.subplots(1, 1, figsize=(3, 1))

# x positions for the uncertainty curve: one per test sample, in sorted order.
error_grid = np.arange(len(mu_pred))

# Bars: log10(count**2) of misclassifications per bin; each label matches the data it plots.
ax.bar(binned_grid, np.log10(binned_argsorted_error_false_negatives**2), width=bin_width, color="red", label="False Negatives")
ax.bar(binned_grid, np.log10(binned_argsorted_error_false_positives**2), width=bin_width, color="blue", label="False Positives")
ax.legend(frameon=False, fontsize=8)
set_grids(ax)

# Secondary axis: the sorted ensemble variance (std squared) of every test sample.
ax2 = ax.twinx()
ax2.plot(error_grid, argsorted_mu_std**2, color="black", linewidth=1, label=r"$\sigma^2$")
ax2.legend(frameon=False, loc="lower left", fontsize=8)
set_grids(ax2, left=False)

ax2.set_ylabel("Uncertainty")
ax.set_ylabel(r"$\log_{10}(\mathrm{Counts}^2)$")
ax.set_xlabel("Index (sorted by uncertainty)")

plt.show()

In [None]:
# Scratch check: evaluate 1 - p*log2(p) at p = 1/2.
# NOTE(review): this is not the binary entropy -p*log2(p) - (1-p)*log2(1-p);
# confirm which quantity was intended.
p = 1 / 2
1.0 - np.log2(p) * p

In [None]:
# Seed NumPy's global RNG, then draw 10 distinct indices from range(100) in ascending order.
np.random.seed(123)
test_indexes = sorted(map(int, np.random.choice(100, size=10, replace=False)))

In [None]:
# Display the sampled indices.
test_indexes

In [None]:
# NOTE(review): `load_results` is not defined or imported anywhere in the visible notebook
# (only Results, get_all_combinations and predict_rf are imported from
# multimodal_molecules.models), so this raises NameError on Restart & Run All unless it is
# defined elsewhere. It also rebinds `results`, replacing the Results object loaded above.
results = load_results(xanes_path="data/221205/xanes.pkl", index_path="data/221205/index.csv", conditions="O-XANES", root="results")

In [None]:
# Fetch the stored entry whose key appears to combine the condition and functional group
# ("O-XANES" + "1,2-Aminoalcohol") — TODO confirm the key format against load_results.
report = results["results"]["O-XANES_1,2-Aminoalcohol"]

In [None]:
# Reload the dataset with only the O- and N-XANES spectra; this rebinds `data`.
# NOTE(review): these file names differ from the earlier load (221205_xanes.pkl /
# 221205_index.csv) — confirm which files exist.
data = get_dataset(conditions="O-XANES,N-XANES", xanes_path="data/221205/xanes.pkl", index_path="data/221205/index.csv")

In [None]:
# List the entries available in the loaded dataset mapping.
data.keys()

In [None]:
# Shape of the O-XANES spectra array (rows presumably index molecules — confirm).
data["O-XANES"].shape

In [None]:
# Shape of the N-XANES spectra array; its row count must match O-XANES for the concatenation below.
data["N-XANES"].shape

In [None]:
# Place the O- and N-XANES spectra side by side (column-wise) in one feature matrix.
new_arr = np.concatenate([data[key] for key in ("O-XANES", "N-XANES")], axis=1)
new_arr.shape

In [None]:
# Display the index entry (presumably loaded from index.csv — confirm).
data["index"]

In [None]:
from sklearn.metrics import accuracy_score, balanced_accuracy_score
from sklearn.model_selection import train_test_split
from sklearn.ensemble import RandomForestClassifier
from sklearn.inspection import permutation_importance

In [None]:
fg = "Amide"
# First spectral column to keep; None keeps the full spectrum.
offset = None

xanes = data["O-XANES"][:, offset:]
grid = data["grid"]["O-XANES"][offset:]
binary_targets = data["FG"][fg]

# Percentage of samples that contain the functional group.
pc = binary_targets.sum() / len(binary_targets) * 100
print(f"Occurrence of {fg}: {pc:.01f}%")

In [None]:
# Hold out 90% of the samples for testing; the model is fitted on the remaining 10%.
# NOTE(review): no stratify= — for a rare functional group the small training split may
# contain very few positives; confirm this split is intended.
x_train, x_test, y_train, y_test = train_test_split(
    xanes,
    binary_targets,
    test_size=0.9,
    random_state=42,
)

Train a basic random forest classifier.

In [None]:
# Fixed seed for reproducibility; n_jobs=16 parallel jobs.
rf = RandomForestClassifier(random_state=42, n_jobs=16)
rf.fit(x_train, y_train)

In [None]:
def rf_classifier_predict(rf, x):
    """Return every individual tree's predictions and feature importances.

    Parameters
    ----------
    rf : fitted forest exposing ``estimators_`` (e.g. a RandomForestClassifier)
    x : array-like of samples accepted by each tree's ``predict``

    Returns
    -------
    preds : np.ndarray
        Row ``i`` holds the ``i``-th tree's predictions for every sample in ``x``.
    feature_importances : np.ndarray
        Row ``i`` holds the ``i``-th tree's ``feature_importances_``.
    """
    per_tree_preds = []
    per_tree_importances = []
    for estimator in rf.estimators_:
        per_tree_preds.append(estimator.predict(x))
        per_tree_importances.append(estimator.feature_importances_)
    return np.array(per_tree_preds), np.array(per_tree_importances)

In [None]:
# Per-tree test-set predictions (one row per tree) and per-tree feature importances (one row per tree).
y_pred, x_feature_importances = rf_classifier_predict(rf, x_test)

In [None]:
# Permutation feature importance of each spectral point on the test split.
# NOTE(review): no random_state is passed, unlike the split/forest cells (random_state=42);
# consider passing one so the importances are reproducible independent of global RNG state.
x_permutation_importances = permutation_importance(rf, x_test, y_test, n_jobs=16)

In [None]:
# Ensemble label = per-sample mean of the trees' predictions, rounded (NumPy rounds an
# exact 0.5 tie to 0). Computed once and reused for both metrics.
y_pred_label = y_pred.mean(axis=0).round(0)
print("Accuracy          ", round(accuracy_score(y_test, y_pred_label), 5))
print("Balanced accuracy ", round(balanced_accuracy_score(y_test, y_pred_label), 5))

In [None]:
# Row indices of samples with (`where`) and without (`not_where`) the functional group `fg`.
membership = data["FG"][fg]
where = np.flatnonzero(membership == 1)
not_where = np.flatnonzero(membership == 0)

In [None]:
fig, axs = plt.subplots(3, 1, figsize=(3, 2), sharex=True)

# Top panel: permutation feature importance, mean with a ±3 std band.
ax = axs[0]
mu = x_permutation_importances.importances_mean
sd = x_permutation_importances.importances_std
ax.plot(grid, mu, color="blue")
ax.fill_between(grid, mu - sd * 3, mu + sd * 3, color="blue", alpha=0.5, linewidth=0)
ax.text(0.95, 0.95, "Permutation feature importance", ha="right", va="top", transform=ax.transAxes, color="blue")
ax.set_yticks([])

# Lower panels: mean ±1 std spectrum of samples with / without the functional group.
panel_specs = zip(axs[1:], (where, not_where), ("red", "black"), ("With", "Without"))
for ax, rows, color, label in panel_specs:
    mu = xanes[rows, :].mean(axis=0)
    sd = xanes[rows, :].std(axis=0)
    ax.plot(grid, mu, color=color)
    ax.fill_between(grid, mu - sd, mu + sd, color=color, alpha=0.5, linewidth=0)
    ax.text(0.95, 0.95, label, ha="right", va="top", transform=ax.transAxes, color=color)

plt.show()

Error correlation