In [None]:
# Automatically re-import edited project modules (e.g. crescendo) before each cell runs.
%load_ext autoreload
%autoreload 2
# Use IPython's built-in completer instead of jedi (presumably to avoid slow/hanging tab completion).
%config Completer.use_jedi = False

In [None]:
import os
import sys
from functools import cache
from pathlib import Path
from pprint import pprint

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from IPython.display import clear_output
from scipy.stats import sem
from sklearn.metrics import accuracy_score, balanced_accuracy_score

# Figure styling: STIX fonts for text and math, 12 pt tick/axis labels, 300 dpi.
mpl.rcParams.update({
    'mathtext.fontset': 'stix',
    'font.family': 'STIXGeneral',
    'text.usetex': False,
    'xtick.labelsize': 12,
    'ytick.labelsize': 12,
    'axes.labelsize': 12,
    'figure.dpi': 300,
})

In [None]:
# Show which interpreter runs this kernel. (`!which python` would report the first
# `python` on the shell PATH, which need not be the kernel's interpreter.)
sys.executable

In [None]:
from crescendo.analysis import Ensemble

In [None]:
# Nominal number of estimators per ensemble, and the possible numbers of
# estimators disagreeing with the majority vote: 0, 1, ..., N_est // 2.
N_est = 20
jj_vals = [n_deviating for n_deviating in range(N_est // 2 + 1)]

def load(
    data,
    data_dir="/Users/mc/GitHub/AIMM/multimodal-molecules/data/23-04-26-ml-data",
    ensemble_dir="/Users/mc/GitHub/AIMM/multimodal-molecules/data/23-05-05-ensembles"
):
    """Load a trained ensemble and its metadata for one dataset.

    Parameters
    ----------
    data : str
        Dataset name, e.g. ``"C-XANES"`` or ``"C-XANES_N-XANES_O-XANES"``;
        the subdirectory of ``data_dir`` holding that dataset.
    data_dir : str or Path
        Root directory containing one subdirectory per dataset.
    ensemble_dir : str or Path
        Root directory containing one ensemble subdirectory per dataset.

    Returns
    -------
    ensemble : Ensemble
        Result of ``Ensemble.from_root`` for this dataset.
    functional_groups : list of str
        One entry per line of ``functional_groups.txt``, whitespace-stripped.
    smiles_test : list of str
        One entry per line of ``smiles_test.txt``, whitespace-stripped.
    """
    dataset_dir = Path(data_dir) / data
    with open(dataset_dir / "functional_groups.txt", "r") as f:
        functional_groups = [line.strip() for line in f]
    # Ensemble directories drop the "-XANES" suffixes and use "-" separators,
    # e.g. "C-XANES_N-XANES_O-XANES" -> "C-N-O".
    ensemble_name = data.replace("-XANES", "").replace("_", "-")
    ensemble = Ensemble.from_root(Path(ensemble_dir) / ensemble_name, data_dir=dataset_dir)
    with open(dataset_dir / "smiles_test.txt", "r") as f:
        # Strip like functional_groups; readlines() used to keep the trailing "\n".
        smiles_test = [line.strip() for line in f]
    return ensemble, functional_groups, smiles_test

def err_std2(ensemble):
    """Score the ensemble's test-set predictions by majority vote.

    Every estimator's raw prediction is rounded to a hard label (so each one
    acts as a classifier); the ensemble label is the rounded mean of those
    labels across estimators (axis 0).

    Parameters
    ----------
    ensemble : Ensemble
        Object exposing ``predict``, ``X_test`` and ``Y_test``.

    Returns
    -------
    err : absolute difference between ``Y_test`` and the ensemble label.
    std : standard deviation of the estimators' hard labels (over axis 0).
    correct : boolean array, True where the ensemble label equals ``Y_test``.
    Y_test_pred : the rounded per-estimator predictions.
    """
    votes = ensemble.predict(ensemble.X_test).round()
    # ndarray.round rounds halves to even, so an exact tie in the vote gives 0.
    consensus = votes.mean(axis=0).round()
    truth = ensemble.Y_test
    abs_error = np.abs(truth - consensus)
    spread = np.std(votes, axis=0)
    is_correct = (consensus - truth) == 0
    return abs_error, spread, is_correct, votes

def get_error_statistics(err, test_pred, n_functional_groups):
    """Summarize prediction error as a function of ensemble disagreement.

    Entries are binned by ``jj``, the number of estimators disagreeing with the
    majority: an entry whose vote total (sum over estimators) is ``jj`` or
    ``N_est - jj`` falls in bin ``jj``, for ``jj = 0 .. N_est // 2``.

    Parameters
    ----------
    err : ndarray
        2-D absolute errors indexed ``[sample, functional group]`` (presumably
        the ``err`` returned by ``err_std2``).
    test_pred : ndarray
        Rounded per-estimator predictions with estimators on axis 0, as
        returned by ``err_std2``; ``test_pred.sum(axis=0)`` matches ``err``'s shape.
    n_functional_groups : int
        Number of leading functional-group columns summarized individually.

    Returns
    -------
    all_errors, all_errors_std, all_lengths : ndarray
        Shape ``(n_functional_groups, N_est // 2 + 1)``: per group and bin,
        the mean error, its standard deviation, and the number of entries.
    everything_errors, everything_errors_sem, everything_lengths : ndarray
        Shape ``(N_est // 2 + 1,)``: the same pooled over all entries, with
        the standard error of the mean instead of the standard deviation.
    jj_vals : list of int
        The bin labels ``0 .. N_est // 2``.
    """
    N_est = test_pred.shape[0]
    pred_sum = test_pred.sum(axis=0)
    # Derive the bins from the actual ensemble size instead of the notebook-level
    # `jj_vals` (built for a hard-coded N_est = 20), so other sizes bin correctly.
    jj_vals = list(range(N_est // 2 + 1))
    # masks[k] selects the entries that belong in bin jj_vals[k].
    masks = [(pred_sum == jj) | (pred_sum == N_est - jj) for jj in jj_vals]

    all_errors = []
    all_errors_std = []
    all_lengths = []
    for ii in range(n_functional_groups):
        column = err[:, ii]
        errors = []
        errors_std = []
        lengths = []
        for mask in masks:
            selected = column[mask[:, ii]]
            lengths.append(len(selected))
            errors.append(selected.mean())
            errors_std.append(selected.std())
        all_errors.append(errors)
        all_errors_std.append(errors_std)
        all_lengths.append(lengths)

    # Pool over the entire dataset (every sample and functional group).
    everything_errors = np.array([np.nanmean(err[mask]) for mask in masks])
    everything_errors_sem = np.array([sem(err[mask], nan_policy="omit") for mask in masks])
    everything_lengths = np.array([int(mask.sum()) for mask in masks])

    return (
        np.array(all_errors),
        np.array(all_errors_std),
        np.array(all_lengths),
        everything_errors,
        everything_errors_sem,
        everything_lengths,
        jj_vals,
    )



def _balanced_accuracy_or_nan(y_true, y_pred):
    """Balanced accuracy of ``y_pred`` vs ``y_true``, or NaN when the bin is empty."""
    if len(y_true) == 0:
        return np.nan
    return balanced_accuracy_score(y_true, y_pred)


def get_error_statistics_CBA(test_truth, test_pred, n_functional_groups):
    """Class-balanced accuracy (CBA) as a function of ensemble disagreement.

    Entries are binned by ``jj``, the number of estimators disagreeing with the
    majority: an entry whose vote total (sum over estimators) is ``jj`` or
    ``N_est - jj`` falls in bin ``jj``, for ``jj = 0 .. N_est // 2``. Within
    each bin the ensemble label (rounded mean over estimators) is scored
    against ``test_truth`` with ``balanced_accuracy_score``.

    Parameters
    ----------
    test_truth : ndarray
        2-D ground-truth labels indexed ``[sample, functional group]``
        (presumably ``ensemble.Y_test``).
    test_pred : ndarray
        Rounded per-estimator predictions with estimators on axis 0, as
        returned by ``err_std2``.
    n_functional_groups : int
        Number of leading functional-group columns scored individually.

    Returns
    -------
    all_cba, all_lengths : ndarray
        Shape ``(n_functional_groups, N_est // 2 + 1)``: per group and bin,
        the CBA (NaN for an empty bin) and the number of entries.
    everything_cba, everything_lengths : ndarray
        Shape ``(N_est // 2 + 1,)``: the same pooled over all entries.
    jj_vals : list of int
        The bin labels ``0 .. N_est // 2``.
    """
    N_est = test_pred.shape[0]
    pred_sum = test_pred.sum(axis=0)
    # Ensemble label: majority vote via the rounded mean (ties round to 0).
    pred_mu = test_pred.mean(axis=0).round()
    # Derive the bins from the actual ensemble size instead of the notebook-level
    # `jj_vals` (built for a hard-coded N_est = 20), so other sizes bin correctly.
    jj_vals = list(range(N_est // 2 + 1))
    # masks[k] selects the entries that belong in bin jj_vals[k].
    masks = [(pred_sum == jj) | (pred_sum == N_est - jj) for jj in jj_vals]

    all_cba = []
    all_lengths = []
    for ii in range(n_functional_groups):
        truth_col = test_truth[:, ii]
        pred_col = pred_mu[:, ii]
        cba = []
        lengths = []
        for mask in masks:
            in_bin = mask[:, ii]
            lengths.append(int(in_bin.sum()))
            cba.append(_balanced_accuracy_or_nan(truth_col[in_bin], pred_col[in_bin]))
        all_cba.append(cba)
        all_lengths.append(lengths)

    # Pool over the entire dataset (every sample and functional group).
    everything_cba = np.array(
        [_balanced_accuracy_or_nan(test_truth[mask], pred_mu[mask]) for mask in masks]
    )
    everything_lengths = np.array([int(mask.sum()) for mask in masks])

    return np.array(all_cba), np.array(all_lengths), everything_cba, everything_lengths, jj_vals


def get_everything(data, data_dir, ensemble_dir, min_count=10):
    """Load one dataset's ensemble and compute its CBA-vs-disagreement statistics.

    Parameters
    ----------
    data : str
        Dataset name passed to ``load`` (e.g. ``"C-XANES"``).
    data_dir, ensemble_dir : str or Path
        Root directories passed to ``load``.
    min_count : int, optional
        Per-functional-group bins with fewer than this many entries are set to
        NaN in ``all_cba``. Defaults to 10.

    Returns
    -------
    dict
        Keys ``data``, ``data_dir``, ``ensemble_dir`` (the arguments);
        ``ensemble``, ``functional_groups``, ``smiles_test`` (from ``load``);
        ``err``, ``std``, ``correct``, ``test_pred`` (from ``err_std2``); and
        ``all_cba``, ``all_lengths``, ``everything_cba``, ``everything_lengths``,
        ``jj_vals`` (from ``get_error_statistics_CBA``).
    """
    ensemble, functional_groups, smiles_test = load(data, data_dir, ensemble_dir)
    err, std, correct, test_pred = err_std2(ensemble)
    all_cba, all_lengths, everything_cba, everything_lengths, jj_vals = get_error_statistics_CBA(
        ensemble.Y_test, test_pred, len(functional_groups)
    )
    # Mask sparsely populated bins so the per-group curves skip them.
    all_cba[all_lengths < min_count] = np.nan
    # Clear whatever output loading/scoring produced in this cell so far.
    clear_output()
    # An explicit dict rather than `locals()`: the result keys cannot change
    # silently when a temporary variable is added to this function.
    return {
        "data": data,
        "data_dir": data_dir,
        "ensemble_dir": ensemble_dir,
        "ensemble": ensemble,
        "functional_groups": functional_groups,
        "smiles_test": smiles_test,
        "err": err,
        "std": std,
        "correct": correct,
        "test_pred": test_pred,
        "all_cba": all_cba,
        "all_lengths": all_lengths,
        "everything_cba": everything_cba,
        "everything_lengths": everything_lengths,
        "jj_vals": jj_vals,
    }

# Plots

In [None]:
# Root of the multimodal-molecules data tree. Set the MULTIMODAL_MOLECULES_DATA
# environment variable to run this notebook on a machine with a different layout.
DATA_ROOT = Path(os.environ.get("MULTIMODAL_MOLECULES_DATA", "/Users/mc/GitHub/AIMM/multimodal-molecules/data"))
data_dir = str(DATA_ROOT / "23-04-26-ml-data")
ensemble_dir = str(DATA_ROOT / "23-05-05-ensembles")

In [None]:
# One result dict per spectroscopy dataset: C, N, O K-edges and their concatenation.
results_C, results_N, results_O, results_CNO = (
    get_everything(dataset, data_dir, ensemble_dir)
    for dataset in ("C-XANES", "N-XANES", "O-XANES", "C-XANES_N-XANES_O-XANES")
)

In [None]:
# Panel order shared by the figure below: C, N, O, combined CNO.
results = [results_C, results_N, results_O, results_CNO]
colors = ["black", "blue", "red", "grey"]
# Panel tags for the top row (CBA) and the bottom row (counts).
letters1 = list("abcd")
letters2 = list("efgh")

In [None]:
fig, axs_arr = plt.subplots(2, 4, figsize=(4, 2), sharey=False, sharex=True, gridspec_kw={'height_ratios': [2, 1]})

y_range_top = [0.4, 1.02]
y_range_bottom = [1.7, 6.3]
# Dataset label drawn in the lower-left corner of each top-row panel (same order as `results`).
panel_labels = ["C", "N", "O", "CNO"]
SAVE_FIGURE = False  # set True to write figures/uq.pdf instead of showing the figure

# Top row: CBA vs. number of deviating estimators. Faint lines are individual
# functional groups; the line with markers pools all functional groups.
for ii, (ax, result, color, letter, label) in enumerate(
    zip(axs_arr[0], results, colors, letters1, panel_labels)
):
    for jj in range(len(result["functional_groups"])):
        ax.plot(result["jj_vals"], result["all_cba"][jj, :], color=color, linewidth=0.3, alpha=0.3)
    ax.plot(result["jj_vals"], result["everything_cba"], color=color, marker="o", markersize=2, linewidth=1)
    ax.tick_params(which="both", direction="in")
    ax.set_xticks([0, 5, 10])
    ax.set_yticks([0.5, 1.0])
    ax.set_ylim(*y_range_top)
    if ii != 0:
        ax.set_yticklabels([])
    ax.text(0.95, 0.95, f"({letter})", ha="right", va="top", transform=ax.transAxes)
    ax.text(0.1, 0.05, label, ha="left", va="bottom", transform=ax.transAxes, color=color)
axs_arr[0, 0].set_ylabel("CBA")

# Bottom row: log10 of the number of (sample, functional group) entries per bin.
for ii, (ax, result, color, letter) in enumerate(zip(axs_arr[1], results, colors, letters2)):
    ax.tick_params(which="both", direction="in")
    ax.bar(result["jj_vals"], np.log10(result["everything_lengths"]), color=color)
    ax.set_yticks([2, 4, 6])
    ax.set_ylim(*y_range_bottom)
    if ii != 0:
        ax.set_yticklabels([])
    ax.text(0.95, 0.95, f"({letter})", ha="right", va="top", transform=ax.transAxes)
# Raw string: "\l" is not a valid Python escape sequence.
axs_arr[1, 0].set_ylabel(r"$\log_{10} N$")

# Invisible full-figure axes that only carry the shared x-axis label.
label_ax = fig.add_subplot(111, frameon=False)
label_ax.tick_params(labelcolor='none', top=False, bottom=False, left=False, right=False)
label_ax.set_xticks([])
label_ax.set_yticks([])
label_ax.set_xlabel("Deviating Estimators", labelpad=20)

fig.subplots_adjust(hspace=0.05, wspace=0.1)

if SAVE_FIGURE:
    # savefig does not create missing directories.
    Path("figures").mkdir(exist_ok=True)
    fig.savefig("figures/uq.pdf", bbox_inches="tight", dpi=300)
else:
    plt.show()

In [None]:
# N-XANES functional groups whose CBA in bin 0 (all estimators agree) is below 0.9.
# NaN bins (too few entries) compare False and are therefore skipped.
low_cba_rows = np.flatnonzero(results_N["all_cba"][:, 0] < 0.9)
for row in low_cba_rows:
    group_name = results_N["functional_groups"][row]
    print(group_name.replace("_", "-"))

In [None]:
# O-XANES functional groups whose CBA in bin 0 (all estimators agree) is below 0.9.
# NaN bins (too few entries) compare False and are therefore skipped.
low_cba_rows = np.flatnonzero(results_O["all_cba"][:, 0] < 0.9)
for row in low_cba_rows:
    group_name = results_O["functional_groups"][row]
    print(group_name.replace("_", "-"))

# Plots CUTOFF8

In [None]:
# Root of the multimodal-molecules data tree. Set the MULTIMODAL_MOLECULES_DATA
# environment variable to run this notebook on a machine with a different layout.
DATA_ROOT = Path(os.environ.get("MULTIMODAL_MOLECULES_DATA", "/Users/mc/GitHub/AIMM/multimodal-molecules/data"))
data_dir = str(DATA_ROOT / "23-05-11-ml-data-CUTOFF8")
ensemble_dir = str(DATA_ROOT / "23-05-13-ensembles-CUTOFF8")

In [None]:
# One result dict per spectroscopy dataset (CUTOFF8 data): C, N, O and their concatenation.
results_C, results_N, results_O, results_CNO = (
    get_everything(dataset, data_dir, ensemble_dir)
    for dataset in ("C-XANES", "N-XANES", "O-XANES", "C-XANES_N-XANES_O-XANES")
)

In [None]:
# Panel order shared by the figure below: C, N, O, combined CNO.
results = [results_C, results_N, results_O, results_CNO]
colors = ["black", "blue", "red", "grey"]
# Panel tags for the top row (CBA) and the bottom row (counts).
letters1 = list("abcd")
letters2 = list("efgh")

In [None]:
fig, axs_arr = plt.subplots(2, 4, figsize=(4, 2), sharey=False, sharex=True, gridspec_kw={'height_ratios': [2, 1]})

y_range_top = [0.4, 1.02]
y_range_bottom = [2.7, 7.3]
# Dataset label drawn in the lower-left corner of each top-row panel (same order as `results`).
panel_labels = ["C", "N", "O", "CNO"]
SAVE_FIGURE = True  # set False to show the figure inline instead of writing figures/uq_CUTOFF8.pdf

# Top row: CBA vs. number of deviating estimators. Faint lines are individual
# functional groups; the line with markers pools all functional groups.
for ii, (ax, result, color, letter, label) in enumerate(
    zip(axs_arr[0], results, colors, letters1, panel_labels)
):
    for jj in range(len(result["functional_groups"])):
        ax.plot(result["jj_vals"], result["all_cba"][jj, :], color=color, linewidth=0.3, alpha=0.3)
    ax.plot(result["jj_vals"], result["everything_cba"], color=color, marker="o", markersize=2, linewidth=1)
    ax.tick_params(which="both", direction="in")
    ax.set_xticks([0, 5, 10])
    ax.set_yticks([0.5, 1.0])
    ax.set_ylim(*y_range_top)
    if ii != 0:
        ax.set_yticklabels([])
    ax.text(0.95, 0.95, f"({letter})", ha="right", va="top", transform=ax.transAxes)
    ax.text(0.1, 0.05, label, ha="left", va="bottom", transform=ax.transAxes, color=color)
axs_arr[0, 0].set_ylabel("CBA")

# Bottom row: log10 of the number of (sample, functional group) entries per bin.
for ii, (ax, result, color, letter) in enumerate(zip(axs_arr[1], results, colors, letters2)):
    ax.tick_params(which="both", direction="in")
    ax.bar(result["jj_vals"], np.log10(result["everything_lengths"]), color=color)
    ax.set_yticks([3, 5, 7])
    ax.set_ylim(*y_range_bottom)
    if ii != 0:
        ax.set_yticklabels([])
    ax.text(0.95, 0.95, f"({letter})", ha="right", va="top", transform=ax.transAxes)
# Raw string: "\l" is not a valid Python escape sequence.
axs_arr[1, 0].set_ylabel(r"$\log_{10} N$")

# Invisible full-figure axes that only carry the shared x-axis label.
label_ax = fig.add_subplot(111, frameon=False)
label_ax.tick_params(labelcolor='none', top=False, bottom=False, left=False, right=False)
label_ax.set_xticks([])
label_ax.set_yticks([])
label_ax.set_xlabel("Deviating Estimators", labelpad=20)

fig.subplots_adjust(hspace=0.05, wspace=0.1)

if SAVE_FIGURE:
    # savefig does not create missing directories.
    Path("figures").mkdir(exist_ok=True)
    fig.savefig("figures/uq_CUTOFF8.pdf", bbox_inches="tight", dpi=300)
else:
    plt.show()

In [None]:
# N-XANES functional groups whose CBA in bin 0 (all estimators agree) is below 0.9.
# NaN bins (too few entries) compare False and are therefore skipped.
low_cba_rows = np.flatnonzero(results_N["all_cba"][:, 0] < 0.9)
for row in low_cba_rows:
    group_name = results_N["functional_groups"][row]
    print(group_name.replace("_", "-"))

In [None]:
# O-XANES functional groups whose CBA in bin 0 (all estimators agree) is below 0.9.
# NaN bins (too few entries) compare False and are therefore skipped.
low_cba_rows = np.flatnonzero(results_O["all_cba"][:, 0] < 0.9)
for row in low_cba_rows:
    group_name = results_O["functional_groups"][row]
    print(group_name.replace("_", "-"))