In [None]:
# Load IPython's autoreload extension and set it to mode 2, which re-imports
# all modules before each cell runs, so edits to the project packages are
# picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Turn off Jedi-based tab completion in the IPython completer.
%config Completer.use_jedi = False

In [None]:
import json
from pathlib import Path
import sys

from IPython.display import clear_output
import matplotlib.pyplot as plt
import matplotlib as mpl
import torch

from crescendo.analysis import Ensemble as CrescendoEnsemble
from crescendo.analysis import HPTunedSet as CrescendoHPTunedSet
from multimodal_molecules.core import _torch_models_from_Crescendo, Ensemble, scaler_from_estimator, save_json

# Create torch model ensembles

Here we abstract away most external dependencies and create standalone `torch` models for people to use. We do this for every element combination for which we have a trained ensemble (C, N, O, CN, CO, NO and CNO).

In [None]:
# Convert every Crescendo ensemble into models stored under
# data/23-12-06_torch_models/. Each entry spells out its absorbing elements
# one character at a time (e.g. "CN" means C and N).
for element in ("C", "N", "O", "CN", "CO", "NO", "CNO"):
    # "CN" -> "C-N" (ensemble and output folder name)
    element_dir = "-".join(element)
    # "CN" -> "C-XANES_N-XANES" (ML data folder name)
    xanes_dir = "_".join(f"{e}-XANES" for e in element)
    print(xanes_dir, element_dir)

    ensemble = CrescendoEnsemble.from_root(
        f"data/23-05-05-ensembles/{element_dir}",
        data_dir=f"data/23-04-26-ml-data/{xanes_dir}",
    )
    root = Path(f"data/23-12-06_torch_models/{element_dir}")
    # NOTE(review): presumably writes the converted models below `root`;
    # confirm against multimodal_molecules.core.
    _torch_models_from_Crescendo(root, ensemble)

    # Drop whatever this iteration wrote to the cell output.
    clear_output()

Do the same thing for the cutoff-8 data (only the C, N, O and CNO ensembles).

In [None]:
# Convert the CUTOFF8 Crescendo ensembles into models stored under
# data/23-12-06_torch_models/cutoff8/. Each entry spells out its absorbing
# elements one character at a time (e.g. "CNO" means C, N and O).
for element in ("C", "N", "O", "CNO"):
    # "CNO" -> "C-N-O" (ensemble and output folder name)
    element_dir = "-".join(element)
    # "CNO" -> "C-XANES_N-XANES_O-XANES" (ML data folder name)
    xanes_dir = "_".join(f"{e}-XANES" for e in element)
    print(xanes_dir, element_dir)

    ensemble = CrescendoEnsemble.from_root(
        f"data/23-05-13-ensembles-CUTOFF8/{element_dir}",
        data_dir=f"data/23-05-11-ml-data-CUTOFF8/{xanes_dir}",
    )
    root = Path(f"data/23-12-06_torch_models/cutoff8/{element_dir}")
    # NOTE(review): presumably writes the converted models below `root`;
    # confirm against multimodal_molecules.core.
    _torch_models_from_Crescendo(root, ensemble)

    # Drop whatever this iteration wrote to the cell output.
    clear_output()

# Find the best model for some other cases

Note: this sometimes causes the kernel to crash, probably because too much data is loaded at once. That is why we save everything as standalone `torch` models (plus the scaler parameters as JSON), which makes things much easier for everyone, including me!

In [None]:
# Settings read by the best-estimator export loop: `data_dir`,
# `models_path_root` and `model_signatures`.
#
# Alternative: multi-modal models (uncomment to use instead of the block below)
# data_dir = "data/23-04-26-ml-data/C-XANES_N-XANES_O-XANES"
# models_path_root = "data/23-05-03-hp"
# model_signatures = ["C-N-O_only_C", "C-N-O_only_N", "C-N-O_only_O", "C-N-O"]

# Active: single-element ("X-only") model for the element named in ELEMENT
ELEMENT = "O"
models_path_root = "data/23-05-03-hp"
data_dir = f"data/23-04-26-ml-data/{ELEMENT}-XANES"
model_signatures = [ELEMENT]

In [None]:
# For each model signature: load the hyperparameter-tuned set, take the first
# entry returned by `get_best_estimator` on the validation arrays, and write
# its scaler (JSON) and model (torch) to disk.
for sig in model_signatures:
    hp_model_path = Path(models_path_root, sig)
    hptuned_set = CrescendoHPTunedSet.from_root(hp_model_path, data_dir=data_dir)
    best_estimator = hptuned_set.get_best_estimator(
        hptuned_set.X_val,
        hptuned_set.Y_val,
    )[0]

    # Output folder: data/23-12-06_torch_models/solo_estimators/<sig>
    target_path = Path("data/23-12-06_torch_models/solo_estimators", sig)
    target_path.mkdir(parents=True, exist_ok=True)

    # Only the first element of the returned pair is saved.
    d_scaler, _ = scaler_from_estimator(best_estimator)
    model = best_estimator.get_model()

    save_json(d_scaler, target_path / "scaler.json")
    torch.save(model, target_path / "model.pt")

    # Unbind the large objects before the next iteration loads a new set.
    del hptuned_set, best_estimator, model
    clear_output()

Repeat the best-estimator export for the cutoff-8 data.

In [None]:
# Settings read by the cutoff-8 best-estimator export loop: `data_dir`,
# `models_path_root` and `model_signatures`.
#
# Alternative: multi-modal models (uncomment to use instead of the block below)
# data_dir = "data/23-05-11-ml-data-CUTOFF8/C-XANES_N-XANES_O-XANES"
# models_path_root = "data/23-05-13-hp-CUTOFF8"
# model_signatures = ["C-N-O_only_C", "C-N-O_only_N", "C-N-O_only_O", "C-N-O"]

# Active: single-element ("X-only") model for the element named in ELEMENT
ELEMENT = "O"
models_path_root = "data/23-05-13-hp-CUTOFF8"
data_dir = f"data/23-05-11-ml-data-CUTOFF8/{ELEMENT}-XANES"
model_signatures = [ELEMENT]

In [None]:
# For each model signature: load the hyperparameter-tuned set, take the first
# entry returned by `get_best_estimator` on the validation arrays, and write
# its scaler (JSON) and model (torch) to the cutoff-8 output folder.
for sig in model_signatures:
    hp_model_path = Path(models_path_root, sig)
    hptuned_set = CrescendoHPTunedSet.from_root(hp_model_path, data_dir=data_dir)
    best_estimator = hptuned_set.get_best_estimator(
        hptuned_set.X_val,
        hptuned_set.Y_val,
    )[0]

    # Output folder: data/23-12-06_torch_models/solo_estimators_cutoff8/<sig>
    target_path = Path("data/23-12-06_torch_models/solo_estimators_cutoff8", sig)
    target_path.mkdir(parents=True, exist_ok=True)

    # Only the first element of the returned pair is saved.
    d_scaler, _ = scaler_from_estimator(best_estimator)
    model = best_estimator.get_model()

    save_json(d_scaler, target_path / "scaler.json")
    torch.save(model, target_path / "model.pt")

    # Unbind the large objects before the next iteration loads a new set.
    del hptuned_set, best_estimator, model
    clear_output()