In [None]:
%load_ext autoreload
%autoreload 2
%config Completer.use_jedi = False

## Imports and helper functions

In [None]:
from itertools import combinations
from pathlib import Path

import numpy as np
from sklearn.model_selection import train_test_split

In [None]:
from multimodal_molecules.data import get_dataset

In [None]:
def screening(fg_dict, low=0.05, high=0.95):
    """Keep only functional groups whose mean value lies strictly in (low, high).

    Both tails are removed: groups whose mean is <= ``low`` (too rare) and
    groups whose mean is >= ``high`` (too common), assuming the values are
    0/1 presence labels (TODO confirm against ``get_dataset``).

    Parameters
    ----------
    fg_dict : dict
        Maps functional-group name -> per-sample values; each value must
        provide a ``.mean()`` method (e.g. a numpy array).
    low, high : float
        Exclusive lower / upper bounds on the mean.

    Returns
    -------
    dict
        A new dict holding the surviving entries in their original order.
        The value objects are not copied.
    """
    return {
        name: labels
        for name, labels in fg_dict.items()
        if low < labels.mean() < high
    }

def concatenate(d):
    """Join every XANES entry of ``d`` column-wise into one feature matrix.

    Only keys containing the substring ``"XANES"`` are used; any other keys
    are ignored. The keys are ordered alphabetically, so the column layout
    does not depend on dict insertion order. The arrays are concatenated
    along axis 1, so they must share their first dimension.
    """
    xanes_keys = sorted(key for key in d if "XANES" in key)
    return np.concatenate([d[key] for key in xanes_keys], axis=1)

In [None]:
# Elements for which a "<element>-XANES" spectrum key is requested.
XANES_ELEMENTS = ("C", "N", "O")

# Every non-empty combination of the elements above, ordered by combination
# size (singles, then pairs, then the triple) and then by element order.
# Each entry is a comma-separated condition string passed to `get_dataset`.
conditions = [
    ",".join(f"{element}-XANES" for element in combo)
    for size in range(1, len(XANES_ELEMENTS) + 1)
    for combo in combinations(XANES_ELEMENTS, size)
]

In [None]:
def _write_lines(path, lines):
    """Write each item of ``lines`` to ``path`` (UTF-8), one per line."""
    with open(path, "w", encoding="utf-8") as f:
        f.writelines(f"{line}\n" for line in lines)


# Feature columns contributed by each XANES spectrum in a condition; the
# export asserts this per-edge width (previously a bare `200`).
POINTS_PER_SPECTRUM = 200

# For each condition: load the spectra, keep informative functional groups,
# split train/val/test (~72.25% / 12.75% / 15%) and write everything to disk.
for condition in conditions:

    d = Path("data") / "23-05-03-ml-data" / condition.replace(",", "_")
    d.mkdir(exist_ok=True, parents=True)

    data = get_dataset(
        xanes_path="data/221205/xanes.pkl",
        index_path="data/221205/index.csv",
        conditions=condition
    )
    # Spectra are stacked column-wise in alphabetical key order.
    X = concatenate(data)
    assert X.shape[1] == len(condition.split(",")) * POINTS_PER_SPECTRUM

    new_fg = screening(data["FG"])
    if not new_fg:
        # Without this, Y would be empty and train_test_split would fail with
        # an unhelpful "inconsistent number of samples" error.
        raise ValueError(f"No functional groups survived screening for {condition!r}")
    columns = list(new_fg.keys())
    # One row per sample, one column per surviving functional group.
    Y = np.stack(list(new_fg.values()), axis=1)

    # Same seed for both splits keeps the export reproducible.
    X_train_val, X_test, y_train_val, y_test, smiles_train_val, smiles_test = train_test_split(
        X, Y, data["index"]["SMILES"], test_size=0.15, random_state=42
    )
    X_train, X_val, y_train, y_val, smiles_train, smiles_val = train_test_split(
        X_train_val, y_train_val, smiles_train_val, test_size=0.15, random_state=42
    )

    assert X_train.shape[1] == X_val.shape[1] == X_test.shape[1] == X.shape[1]
    assert y_train.shape[1] == y_val.shape[1] == y_test.shape[1] == Y.shape[1] == len(columns)

    splits = {
        "train": (X_train, y_train, smiles_train),
        "val": (X_val, y_val, smiles_val),
        "test": (X_test, y_test, smiles_test),
    }
    for split, (X_split, y_split, smiles_split) in splits.items():
        np.save(d / f"X_{split}.npy", X_split)
        np.save(d / f"Y_{split}.npy", y_split)
        _write_lines(d / f"smiles_{split}.txt", smiles_split.to_list())

    # Column names of Y, in column order.
    _write_lines(d / "functional_groups.txt", columns)

    print(X_train.shape, X_val.shape, X_test.shape, y_train.shape, y_val.shape, y_test.shape)

In [None]:
# Sanity check: the X splits written above must match the earlier
# 23-04-26 export exactly, for every condition directory.
d = Path("data") / "23-05-03-ml-data"
d2 = Path("data") / "23-04-26-ml-data"

n_compared = 0
for p in sorted(d.rglob("Y_train.npy")):
    parent = p.parent
    # `.name` is the full directory name; `.stem` would drop anything after a dot.
    parent2 = d2 / parent.name

    print(parent)
    print(parent2)

    for split in ("train", "val", "test"):
        X_new = np.load(parent / f"X_{split}.npy")
        X_old = np.load(parent2 / f"X_{split}.npy")
        # array_equal checks shapes first; `np.all(a == b)` would broadcast
        # mismatched shapes (or pass vacuously on empty arrays).
        assert np.array_equal(X_new, X_old), f"X_{split} differs: {parent} vs {parent2}"
    n_compared += 1

# Guard against silently "passing" when nothing was found to compare.
assert n_compared > 0, f"No Y_train.npy files found under {d}"
    