In [None]:
%load_ext autoreload
%autoreload 2
%config Completer.use_jedi = False

In [None]:
from functools import cache

from pathlib import Path
from pprint import pprint

import matplotlib.pyplot as plt
import matplotlib as mpl
import pandas as pd
import numpy as np
from scipy.stats import sem
from sklearn.metrics import balanced_accuracy_score, accuracy_score
import torch
from IPython.display import clear_output
from sklearn.decomposition import PCA
from sklearn.preprocessing import StandardScaler

from rdkit import Chem

# Global matplotlib styling for every figure in this notebook: STIX fonts for
# text and math (rendered by mathtext, not an external LaTeX install) and
# 300 dpi figures.
mpl.rcParams.update({
    "mathtext.fontset": "stix",
    "font.family": "STIXGeneral",
    "text.usetex": False,
    "figure.dpi": 300,
})

# Same label size for x tick labels, y tick labels and axis labels.
for rc_group in ("xtick", "ytick", "axes"):
    plt.rc(rc_group, labelsize=12)

In [None]:
from crescendo.analysis import Ensemble
from multimodal_molecules.data import get_dataset

In [None]:
# Only the "grid" entry of the raw dataset is kept (it is used below as the
# x-axis of every spectrum plot); the rest of the dataset is not retained.
grids = get_dataset(
    "data/22-12-05-data/xanes.pkl",
    "data/22-12-05-data/index.csv",
    conditions="C-XANES,N-XANES,O-XANES",
)["grid"]

In [None]:
def load(
    data,
    data_dir="/Users/mc/GitHub/AIMM/multimodal-molecules/data/23-04-26-ml-data",
):
    """Read one pre-split dataset folder and merge its train and validation parts.

    Parameters
    ----------
    data : str
        Name of the sub-folder of ``data_dir`` to read.
    data_dir : str or Path
        Parent directory that contains the dataset folders.
        NOTE(review): the default is a machine-specific absolute path; pass
        ``data_dir`` explicitly when running elsewhere.

    Returns
    -------
    tuple
        ``(smiles, smiles_test, X, X_test, Y, Y_test, functional_groups)``.
        ``smiles``, ``X`` and ``Y`` hold the training entries followed by the
        validation entries; the ``*_test`` values are the test split as stored
        on disk; ``functional_groups`` is the list of stripped lines of
        ``functional_groups.txt``.
    """
    folder = Path(data_dir) / data

    def read_stripped_lines(filename):
        # One entry per line, with surrounding whitespace/newlines removed.
        with open(folder / filename, "r") as handle:
            return [line.strip() for line in handle.readlines()]

    functional_groups = read_stripped_lines("functional_groups.txt")

    train_smiles = read_stripped_lines("smiles_train.txt")
    val_smiles = read_stripped_lines("smiles_val.txt")
    smiles_test = read_stripped_lines("smiles_test.txt")
    smiles = [*train_smiles, *val_smiles]

    # Features: train and validation rows are stacked along the sample axis.
    train_X, val_X, X_test = (
        np.load(folder / name)
        for name in ("X_train.npy", "X_val.npy", "X_test.npy")
    )
    X = np.concatenate((train_X, val_X), axis=0)

    # Targets: stacked in the same train-then-validation order as the features.
    train_Y, val_Y, Y_test = (
        np.load(folder / name)
        for name in ("Y_train.npy", "Y_val.npy", "Y_test.npy")
    )
    Y = np.concatenate((train_Y, val_Y), axis=0)

    return smiles, smiles_test, X, X_test, Y, Y_test, functional_groups

In [None]:
# Load the "C-XANES_N-XANES_O-XANES" dataset folder. `smiles`, `X` and `Y`
# hold train + validation merged; the `*_test` variables hold the test split.
smiles, smiles_test, X, X_test, Y, Y_test, functional_groups = load("C-XANES_N-XANES_O-XANES")

In [None]:
# Standardize each feature column of the train+validation matrix `X`, then fit
# a 2-component PCA on the standardized data.
#
# `random_state` is pinned because scikit-learn's default `svd_solver="auto"`
# policy may select the randomized solver depending on the input shape; an
# unseeded fit would then give slightly different projections (and possibly
# different nearest-point selections downstream) on every kernel restart.
scaler = StandardScaler()
pca_CNO = PCA(2, random_state=0)
pca_CNO.fit(scaler.fit_transform(X))

In [None]:
# Project the test split into the 2-D PCA space, reusing the scaler and PCA
# that were fitted on `X` (the test data is only transformed, never fitted).
w_CNO = pca_CNO.transform(scaler.transform(X_test))

In [None]:
# Sanity check: one row per test sample, one column per PCA component (2).
w_CNO.shape

In [None]:
def get_example(p, w=w_CNO):
    """Return the row index of ``w`` that lies closest to the point ``p``.

    Closeness is measured with the L1 (Manhattan) distance, i.e. the sum of
    absolute coordinate differences along axis 1.

    Parameters
    ----------
    p : array-like
        Target coordinates; must broadcast against the rows of ``w``.
    w : array-like
        Candidate points, one per row. The default is the ``w_CNO`` object
        that existed when this function was defined, so re-run this cell
        after recomputing ``w_CNO``.

    Returns
    -------
    int
        Index of the first row with the smallest L1 distance to ``p``.
    """
    target = np.array(p)
    l1_distance = np.sum(np.abs(w - target), axis=1)
    return np.argmin(l1_distance)

In [None]:
# Hand-picked target locations in the PCA plane; for each one, take the test
# sample whose projection is nearest to it (see `get_example`).
pca_targets = [[0, 40], [25, 0], [0, -20], [-25, 0], [0, 0]]
p1, p2, p3, p4, p5 = (get_example(target) for target in pca_targets)
selected_points = [p1, p2, p3, p4, p5]

# One plotting color per selected point, in the same order as `selected_points`.
colors = ["red", "green", "blue", "orange", "purple"]

In [None]:
# Panel (d): the test set projected onto the two PCA components, with the
# selected example points highlighted.
fig, ax = plt.subplots(1, 1, figsize=(2, 2), sharey=True)

# Background cloud of all test samples; rasterized so the saved SVG does not
# carry one vector marker per sample.
pc1, pc2 = w_CNO[:, 0], w_CNO[:, 1]
ax.scatter(pc1, pc2, s=0.5, color="black", alpha=0.5, rasterized=True)

# Overlay every selected example as a cross in its assigned color.
for point, color in zip(selected_points, colors):
    ax.scatter(pc1[point], pc2[point], color=color, marker="x")

# Axis labels only; tick marks are removed on both axes.
ax.set(
    xlabel=r"$z_1$ [a.u.]",
    ylabel=r"$z_2$ [a.u.]",
    xticks=[],
    yticks=[],
)
ax.text(0.9, 0.9, "(d)", ha="center", va="center", transform=ax.transAxes)

plt.savefig("figures/pca.svg", dpi=300, bbox_inches="tight")

In [None]:
# Panels (a)-(c): spectra of the selected test examples, one panel per
# absorption edge, each curve drawn in the color assigned to its example.
fig, axs = plt.subplots(3, 1, figsize=(2, 2), sharey=False)

# Each entry: (key into `grids`, column slice of the feature vector, panel tag).
# NOTE(review): the 0-200 / 200-400 / 400-end split assumes the C, N and O
# spectra are concatenated in that order with 200 points each - confirm
# against the code that builds `X_test`.
panels = [
    ("C-XANES", slice(None, 200), "(a)"),
    ("N-XANES", slice(200, 400), "(b)"),
    ("O-XANES", slice(400, None), "(c)"),
]

for ax, (edge, columns, tag) in zip(axs, panels):
    for p, c in zip(selected_points, colors):
        # Curves are labelled with the test-set row index; the labels are only
        # shown where a legend is requested (bottom panel, below).
        ax.plot(grids[edge], X_test[p, columns], color=c, label=p)
    ax.text(0.1, 0.9, tag, ha="center", va="center", transform=ax.transAxes)

    # Intensities are in arbitrary units, so y ticks are dropped; the top and
    # right spines are hidden and the remaining ticks point inward.
    ax.set_yticks([])
    ax.spines[["right", "top"]].set_visible(False)
    ax.tick_params(
        which='both', direction='in', bottom=True, left=True
    )

# Single legend, attached to the bottom panel at axes coordinates (1, 1).
axs[2].legend(frameon=False, loc=(1, 1))

axs[1].set_ylabel(r"$\mu(E)$ [a.u.]")
# The unit symbol for electronvolt is "eV" (no periods).
axs[2].set_xlabel(r"$E$ [eV]")

plt.subplots_adjust(hspace=1)

plt.savefig("figures/example_spectra.svg", dpi=300, bbox_inches="tight")

In [None]:
# Show the test-set row indices of the selected examples (these are also the
# legend labels in the example-spectra figure).
selected_points

In [None]:
# Parse the SMILES of every selected test example into an RDKit molecule and
# draw them together as one SVG grid.
selected_mols = [Chem.MolFromSmiles(smiles_test[idx]) for idx in selected_points]
svgs = Chem.Draw.MolsToGridImage(selected_mols, useSVG=True)

In [None]:
# Render the molecule grid inline for a visual check before saving it.
svgs

In [None]:
# Save the SVG markup of the molecule grid next to the other figure files.
# NOTE(review): `.data` assumes `svgs` is an IPython SVG display object (as
# returned inside a notebook), not a plain string - confirm if run elsewhere.
fig2_molecules_path = Path("figures/fig2_molecules.svg")
with fig2_molecules_path.open("w") as svg_file:
    svg_file.write(svgs.data)

Paper workflow figure

In [None]:
# Workflow figure: the three spectra of one selected example, stacked
# vertically with a shared y-axis.
fig, axs = plt.subplots(3, 1, figsize=(2, 3), sharey=True)

# Position within `selected_points` of the example to show.
ii = 3
workflow_row = X_test[selected_points[ii]]

# One panel per edge: (key into `grids`, column slice, line color).
# NOTE(review): the slices assume 200 points per edge in C, N, O order -
# confirm against the code that builds `X_test`.
workflow_panels = (
    ("C-XANES", slice(None, 200), "black"),
    ("N-XANES", slice(200, 400), "blue"),
    ("O-XANES", slice(400, None), "red"),
)
for ax, (edge, columns, line_color) in zip(axs, workflow_panels):
    ax.plot(grids[edge], workflow_row[columns], color=line_color)

plt.subplots_adjust(hspace=0.5)

plt.savefig("figures/workflow_spectra.svg", bbox_inches="tight", dpi=300)

In [None]:
# Draw the single molecule used in the workflow figure (position 3 of
# `selected_points`, the same example as the workflow spectra) as an SVG.
workflow_smiles = smiles_test[selected_points[3]]
workflow_mol = Chem.MolFromSmiles(workflow_smiles)
svgs = Chem.Draw.MolsToGridImage([workflow_mol], useSVG=True)

In [None]:
# Show the SMILES string of the molecule drawn for the workflow figure.
smiles_test[selected_points[3]]

In [None]:
# Render the workflow molecule inline for a visual check before saving it.
svgs

In [None]:
# Save the SVG markup of the workflow molecule next to the other figure files.
# NOTE(review): `.data` assumes `svgs` is an IPython SVG display object (as
# returned inside a notebook), not a plain string - confirm if run elsewhere.
workflow_molecule_path = Path("figures/workflow_molecule.svg")
with workflow_molecule_path.open("w") as svg_file:
    svg_file.write(svgs.data)