In [None]:
# Reload edited project modules automatically before each cell executes.
%load_ext autoreload
%autoreload 2
# Turn off jedi-based tab completion for this session.
%config Completer.use_jedi = False

In [None]:
import sys
from pathlib import Path
from pprint import pprint

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from IPython.display import clear_output
from sklearn.metrics import balanced_accuracy_score, accuracy_score

# Global figure styling: STIX fonts (no external LaTeX), 12 pt tick/axis
# labels, and 300 dpi figures.
mpl.rcParams.update({
    'mathtext.fontset': 'stix',
    'font.family': 'STIXGeneral',
    'text.usetex': False,
    'xtick.labelsize': 12,
    'ytick.labelsize': 12,
    'axes.labelsize': 12,
    'figure.dpi': 300,
})

In [None]:
# Show the interpreter this kernel is actually running. `!which python` only
# reports the first `python` on the shell PATH, which may be a different
# environment, and `which` is not available on every platform.
sys.executable

In [None]:
# Score columns produced by single-element models and by the combined model.
single_modalities = ["C-N-O_only_C", "C-N-O_only_N", "C-N-O_only_O"]
multi_modalities = ["C-N-O"]

DATA_DIR = Path("data/23-05-13-multimodal-single-estimator")


def load_advantage_table(csv_path):
    """Read one results CSV and append the multimodal advantage column ``d``.

    ``d`` is, per row, the best score among ``multi_modalities`` minus the best
    score among ``single_modalities``.

    Parameters
    ----------
    csv_path : str or Path
        CSV file whose first column is used as the index and which contains
        every column named in ``single_modalities`` and ``multi_modalities``.

    Returns
    -------
    pandas.DataFrame
        The CSV contents plus the column ``d``.
    """
    table = pd.read_csv(csv_path, index_col=0)
    best_single = table[single_modalities].max(axis=1)
    best_multi = table[multi_modalities].max(axis=1)
    # assign() adds "d" directly, so no temporary "Best SM"/"Best MM" columns
    # have to be created and dropped again.
    return table.assign(d=best_multi - best_single)


# Core data
df = load_advantage_table(DATA_DIR / "C-XANES_N-XANES_O-XANES.csv")

# CUTOFF8 data
df_CUTOFF8 = load_advantage_table(DATA_DIR / "C-XANES_N-XANES_O-XANES-CUTOFF8.csv")

In [None]:
# Order both tables by multimodal advantage "d", largest first.
df, df_CUTOFF8 = (
    frame.sort_values(by=["d"], ascending=False)
    for frame in (df, df_CUTOFF8)
)

In [None]:
# Render the CUTOFF8 table with a viridis background gradient for quick inspection.
df_plot = (
    df_CUTOFF8
    .style
    .background_gradient(cmap="viridis")
)
df_plot

In [None]:
# Functional groups treated as nitrogen-containing; make_plot fills the "N"
# marker for these groups and leaves it hollow for all others.
N_containing_functional_groups = [
    "1,2-Aminoalcohol",
    "Lactam",
    "Amide",
    "Imidolactone",
    "Heterocyclic",
    "Hetero_N_basic_H",
    "Amine",
    "Hetero_N_nonbasic",
    "Primary_arom_amine",
    "Tertiary_aliph_amine",
    "NH_aziridine",
    "Heteroaromatic",
    "Secondary_aliph_amine",
    "Nitrile",
]

In [None]:
# Functional groups treated as oxygen-containing; make_plot fills the "O"
# marker for these groups and leaves it hollow for all others.
O_containing_functional_groups = [
    "1,2-Aminoalcohol",
    "Secondary_alcohol",
    "Carbonic_acid_derivatives",
    "Tertiary_alcohol",
    "Lactam",
    "Primary_alcohol",
    "Aldehyde",
    "Ketone",
    "Carboxylic_acid_derivative",
    "Epoxide",
    "Imidolactone",
    "Heterocyclic",
    "Dialkylether",
    "Phenol",
    "Heteroaromatic",
    "Hetero_O",
    "Alcohol",
]

In [None]:
def _draw_column_guides(ax, positions):
    """Draw a faint vertical guide line behind the data at every x position."""
    for pos in positions:
        ax.axvline(pos, zorder=-1, linewidth=0.5, color="black", alpha=0.2)


def _scatter_element(ax, positions, scores, groups, element_groups, color, label, size):
    """Scatter one single-element score series with a single ``scatter`` call.

    A marker is filled with ``color`` when its functional group is listed in
    ``element_groups`` and is left white (outline only) otherwise.
    """
    fills = [color if group in element_groups else "white" for group in groups]
    ax.scatter(
        positions, scores,
        s=size, label=label, facecolors=fills, edgecolors=color, linewidth=.5,
    )


def make_plot(df):
    """Plot the multimodal advantage and per-modality scores per functional group.

    The figure has two stacked axes sharing the x axis (one position per row of
    ``df``, in row order):

    * top: grey bars of ``df["d"] * 100``, y-label "Adv (%)";
    * bottom: scatter of the "C-N-O", "C-N-O_only_C", "C-N-O_only_N" and
      "C-N-O_only_O" columns, y-label "CBA". N and O markers are filled only
      for groups listed in ``N_containing_functional_groups`` /
      ``O_containing_functional_groups`` respectively.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the columns "functional_group", "d", "C-N-O",
        "C-N-O_only_C", "C-N-O_only_N" and "C-N-O_only_O".

    Returns
    -------
    numpy.ndarray
        The two Axes created by ``plt.subplots`` (top, bottom).
    """
    _, axs = plt.subplots(2, 1, figsize=(8, 2), gridspec_kw={'height_ratios': [1, 3]}, sharex=True)

    x = list(range(len(df.index)))
    groups = df["functional_group"].tolist()

    # --- top panel: advantage bars -------------------------------------------
    ax = axs[0]
    ax.bar(x, df["d"] * 100, color="grey")
    ax.set_ylabel("Adv (%)", fontsize=10)
    ax.tick_params(which='both', direction='in', bottom=False, left=True, top=True, right=True)
    _draw_column_guides(ax, x)

    # --- bottom panel: per-modality scores -----------------------------------
    ax = axs[1]
    ax.scatter(x, df["C-N-O"], color="grey", label="C$+$N$+$O")

    s = 10
    ax.scatter(x, df["C-N-O_only_C"], color="black", s=s, label="C")

    # Rows are plotted positionally, so no per-group filtering of the frame is
    # needed; this also keeps x and y the same length when a functional group
    # name occurs more than once.
    _scatter_element(ax, x, df["C-N-O_only_N"], groups, N_containing_functional_groups, "blue", "N", s)
    _scatter_element(ax, x, df["C-N-O_only_O"], groups, O_containing_functional_groups, "red", "O", s)

    labels = [group.replace("_", " ") for group in groups]
    ax.set_xticks(x)
    ax.set_xticklabels(labels, rotation=45, ha='right', fontsize=8)

    ax.tick_params(which='both', direction='in', bottom=True, left=True, top=False, right=True)
    _draw_column_guides(ax, x)

    ax.set_ylabel("CBA", fontsize=10)

    ax.legend(frameon=True, ncol=4, fontsize=6, loc="lower center")

    return axs

In [None]:
# Advantage figure for the core data set.
axs = make_plot(df)

axs[1].set_yticks([0.7, 0.8, 0.9, 1.0])
axs[1].set_ylim(0.68, 1.02)

# savefig does not create missing directories, so make sure the target exists.
Path("figures").mkdir(parents=True, exist_ok=True)
# Save through the figure that owns these axes instead of pyplot's implicit
# "current figure".
axs[0].figure.savefig("figures/multimodal_advantage.pdf", bbox_inches="tight", dpi=300)

In [None]:
# Advantage figure for the CUTOFF8 data set.
axs = make_plot(df_CUTOFF8)

axs[1].set_yticks([0.6, 0.8, 1.0])
axs[1].set_ylim(0.53, 1.02)

# savefig does not create missing directories, so make sure the target exists.
Path("figures").mkdir(parents=True, exist_ok=True)
# Save through the figure that owns these axes instead of pyplot's implicit
# "current figure".
axs[0].figure.savefig("figures/multimodal_advantage_CUTOFF8.pdf", bbox_inches="tight", dpi=300)

In [None]:
# Mean multimodal advantage on the core data, in percent.
100 * df["d"].mean()

In [None]:
# Mean multimodal advantage on the CUTOFF8 data, in percent.
# No cell in this notebook creates a "d8" column on `df`; the CUTOFF8
# advantage is stored as column "d" of `df_CUTOFF8`.
df_CUTOFF8["d"].mean() * 100

In [None]:
# Mean and standard deviation of the combined C-N-O score on the core data.
mu, sd = df["C-N-O"].agg(["mean", "std"])
print(f"{mu:.02f} +/- {sd:.02f}")

In [None]:
# Mean and standard deviation of the combined C-N-O score on the CUTOFF8 data.
mu, sd = df_CUTOFF8["C-N-O"].agg(["mean", "std"])
print(f"{mu:.02f} +/- {sd:.02f}")