# Import packages

In [1]:
import os
import csv
from pathlib import Path
from copy import deepcopy
import time
from tqdm.notebook import tqdm
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt

import joblib
import pickle

import torch
from lightning import pytorch as pl
from lightning.pytorch.callbacks.early_stopping import EarlyStopping
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks import Callback

import chemprop
from chemprop import data, featurizers, models, nn
from chemprop.featurizers.molecule import RDKit2DFeaturizer
from chemprop.utils import make_mol

from hyperopt import STATUS_OK, Trials, fmin, hp, tpe
from hyperopt.pyll.base import scope

import warnings
warnings.filterwarnings("ignore", ".*Consider increasing the value of the `num_workers` argument*")

import logging

# logging.getLogger('lightning').setLevel(0)

# configure logging at the root level of Lightning
# logging.getLogger("lightning.pytorch").setLevel(logging.ERROR)

# configure logging on module level, redirect to file
# logger = logging.getLogger("lightning.pytorch.core")
# logger.addHandler(logging.FileHandler("core.log"))

# logging.getLogger("lightning.pytorch.utilities.rank_zero").setLevel(logging.ERROR)
# logging.getLogger("lightning.pytorch.accelerators.cuda").setLevel(logging.ERROR)

# logging.getLogger("lightning.fabric.plugins.environments.slurm").setLevel(logging.ERROR)
# logging.getLogger("lightning.pytorch.callbacks.model_checkpoint").setLevel(logging.ERROR)

In [2]:
from sklearn.preprocessing import StandardScaler
from sklearn.pipeline import Pipeline
from sklearn.model_selection import train_test_split
from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import cross_val_score
from sklearn.model_selection import cross_validate
from sklearn.ensemble import RandomForestRegressor
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.metrics import root_mean_squared_error
from sklearn.metrics import r2_score

from sklearn.decomposition import TruncatedSVD
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from sklearn.feature_selection import SelectKBest
from sklearn.feature_selection import RFE
from sklearn.feature_selection import SelectFromModel

# from xgboost import XGBRegressor

# from skopt import BayesSearchCV
# from skopt.space import Real, Categorical, Integer
# from skopt.plots import plot_objective
# from skopt.plots import plot_convergence

## Load data

In [3]:
# Training split CSV, given relative to the notebook's working directory.
data_path = '../train_split_fluor.csv'

# One SMILES column per model component, and the regression target column.
# 'log_q_yield' is created by make_log_q_yield() from the 'Quantum yield' column.
smiles_columns = ['Chromophore', 'Solvent']
target_columns = ['log_q_yield']

In [4]:
# Data-loading settings shared by the train/validation/test loaders.
num_workers = 0 # number of workers for dataloader. 0 means using main process for data loading
batch_size = 512  # samples per mini-batch

In [5]:
# Read the raw training split; preprocessing below derives new frames from it.
data_df = pd.read_csv(data_path)

## Preprocess data

In [10]:
def drop_extra(df, columns):
    """Return only the requested `columns` of `df`, in the order given."""
    selected = df[columns]
    return selected

def dropna(df, columns):
    """Drop the rows of `df` in which every one of `columns` is missing.

    Rows with at least one non-missing value among `columns` are kept.
    """
    return df.dropna(how='all', subset=columns)

def replace_gas(df):
    """Replace the placeholder solvent 'gas' with the row's chromophore SMILES.

    Rows whose 'Solvent' equals the literal string 'gas' get the value of their
    own 'Chromophore' column as solvent. All other rows are left unchanged.

    The input frame is not modified: the substitution is applied to a copy,
    which also avoids pandas' SettingWithCopyWarning when `df` is a slice of
    another frame.

    Returns:
        A new DataFrame with the substituted 'Solvent' column.
    """
    result = df.copy()
    is_gas = result['Solvent'] == 'gas'
    result.loc[is_gas, 'Solvent'] = result.loc[is_gas, 'Chromophore']
    return result

def remove_neg_shift(df):
    """Drop rows with a negative 'Stokes shift'; rows with a missing shift are kept."""
    shift = df['Stokes shift']
    keep = shift.isna() | (shift >= 0.0)
    return df[keep]

def make_log_q_yield(df, eps=1e-5):
    """Return a copy of `df` with a natural-log quantum-yield column added.

    Quantum yields that are exactly 0.0 are replaced by `eps` (in the returned
    copy only) so the logarithm stays finite. The result is stored in a
    'log_q_yield' column. The input frame is left untouched.
    """
    result = df.copy()
    zero_yield = result['Quantum yield'] == 0.0
    result.loc[zero_yield, 'Quantum yield'] = eps
    result.loc[:, 'log_q_yield'] = np.log(result['Quantum yield'])
    return result

def delete_outliers(df, columns):
    """Filter `df` with the 1.5 * IQR rule, one column after another.

    For each column the quartiles are computed on the frame as already filtered
    by the preceding columns. Rows strictly inside
    (q1 - 1.5 * iqr, q3 + 1.5 * iqr) are kept, as are rows where the column is
    missing. The column name and both bounds are printed for inspection.
    """
    separator = "=" * 100
    for column in columns:
        print(column)
        values = df[column]
        q1 = values.quantile(0.25)
        q3 = values.quantile(0.75)
        iqr = q3 - q1
        lower = q1 - 1.5 * iqr
        upper = q3 + 1.5 * iqr

        inside = (values > lower) & (values < upper)
        df = df[inside | values.isna()]

        print("left", lower)
        print("right", upper)
        print(separator)
    return df

def _preprocess(df):
    """Shared cleaning pipeline behind preprocess_train and preprocess_test.

    Steps, in order:
      1. keep the SMILES columns plus the raw optical-property columns;
      2. drop rows with a missing 'Quantum yield';
      3. replace the 'gas' solvent placeholder with the chromophore SMILES;
      4. drop rows with a negative 'Stokes shift' (missing shifts are kept);
      5. add the 'log_q_yield' target;
      6. keep only the SMILES and target columns.
    """
    raw_columns = smiles_columns + [
        'Absorption max (nm)',
        'Emission max (nm)',
        'Stokes shift',
        'Quantum yield',
    ]
    df = drop_extra(df, raw_columns)
    df = dropna(df, ['Quantum yield'])
    df = replace_gas(df)
    df = remove_neg_shift(df)
    df = make_log_q_yield(df)
    df = drop_extra(df, smiles_columns + target_columns)
    return df

def preprocess_train(df):
    """Clean the training split; returns a frame with SMILES and target columns only."""
    return _preprocess(df)

def preprocess_test(df):
    """Clean the test split with exactly the same pipeline as the training split."""
    return _preprocess(df)

def rmsd(pred, target):
    """Root-mean-squared deviation between `pred` and `target`.

    Positions where either array is NaN are excluded before the metric is computed.
    """
    valid = ~(np.isnan(pred) | np.isnan(target))
    return root_mean_squared_error(pred[valid], target[valid])

def r2(pred, target):
    """Squared Pearson correlation between `pred` and `target`.

    Positions where either array is NaN are excluded. Note that this is the
    square of the correlation coefficient from np.corrcoef, not the
    coefficient of determination computed by sklearn's r2_score.
    """
    valid = ~(np.isnan(pred) | np.isnan(target))
    correlation = np.corrcoef(pred[valid], target[valid])
    pearson_r = correlation[0, 1]
    return pearson_r**2

In [11]:
# Clean the training split; the result keeps only the SMILES and target columns.
data_clean = preprocess_train(data_df)
data_clean.shape

(12515, 3)

In [12]:
# Held-out test split, stored next to the training CSV.
test_data_df = pd.read_csv('../test_split_fluor.csv')
test_data_df.shape

(1850, 16)

In [13]:
# Apply the test-time cleaning pipeline to the held-out split.
test_data_clean = preprocess_test(test_data_df)
test_data_clean.shape

(1296, 3)

### Prepare data for training

In [14]:
# 2-D arrays: one row per sample, SMILES strings in `smiss`, regression targets in `ys`.
smiss = data_clean[smiles_columns].to_numpy()
ys = data_clean[target_columns].to_numpy()

In [15]:
# Same layout as `smiss` / `ys`, taken from the cleaned held-out split.
test_smiss = test_data_clean[smiles_columns].to_numpy()
test_ys = test_data_clean[target_columns].to_numpy()

In [16]:
# Sanity check: (number of samples, number of SMILES columns).
smiss.shape

(12515, 2)

In [17]:
# Default per-molecule descriptor featurizer used by generate_mol_features().
molecule_featurizer = RDKit2DFeaturizer()

def generate_mol_features(smiss, save_file, molecule_featurizer=molecule_featurizer):
    """Compute molecule-level descriptors for a sequence of SMILES and cache them.

    Each SMILES is converted with make_mol (keep_h=False, add_h=False) and then
    passed through `molecule_featurizer`. The list of descriptor vectors is
    written to `save_file` with np.savez as a single positional array (so it is
    stored under the key 'arr_0') and is also returned.
    """
    mols = []
    for smis in smiss:
        mols.append(make_mol(smis, keep_h=False, add_h=False))

    extra_datapoint_descriptors = []
    for mol in tqdm(mols):
        extra_datapoint_descriptors.append(molecule_featurizer(mol))

    np.savez(save_file, extra_datapoint_descriptors)
    return extra_datapoint_descriptors

def load_mol_features(save_file):
    """Load the descriptor matrix cached by generate_mol_features().

    generate_mol_features() stores its descriptors as the single positional
    array of an .npz archive, i.e. under the key 'arr_0'. Only that entry is
    read, and the archive is closed as soon as the array has been loaded, so
    no file handle is left open.

    Args:
        save_file: path to the .npz archive.

    Returns:
        The array stored under 'arr_0'.

    Raises:
        KeyError: if the archive has no 'arr_0' entry.
    """
    with np.load(save_file) as archive:
        return archive["arr_0"]



In [18]:
# save_file = 'rdkit_feats_1.npz'
# extra_datapoint_descriptors_1 = generate_mol_features(smiss[:, 0], save_file)

In [19]:
# Cached descriptors for the first SMILES column (training split).
save_file = '../gnn_3/rdkit_feats_1.npz'
extra_datapoint_descriptors_1 = load_mol_features(save_file)

In [20]:
# save_file = 'rdkit_feats_2.npz'
# extra_datapoint_descriptors_2 = generate_mol_features(smiss[:, 1], save_file)

In [21]:
# Cached descriptors for the second SMILES column (training split).
save_file = '../gnn_3/rdkit_feats_2.npz'
extra_datapoint_descriptors_2 = load_mol_features(save_file)

In [22]:
# save_file = 'rdkit_feats_1_test.npz'
# extra_datapoint_descriptors_1_test = generate_mol_features(test_smiss[:, 0], save_file)

In [23]:
# Cached descriptors for the first SMILES column (test split).
save_file = '../gnn_3/rdkit_feats_1_test.npz'
extra_datapoint_descriptors_1_test = load_mol_features(save_file)

In [24]:
# save_file = 'rdkit_feats_2_test.npz'
# extra_datapoint_descriptors_2_test = generate_mol_features(test_smiss[:, 1], save_file)

In [25]:
# Cached descriptors for the second SMILES column (test split).
save_file = '../gnn_3/rdkit_feats_2_test.npz'
extra_datapoint_descriptors_2_test = load_mol_features(save_file)

In [29]:
# Column selections read from text files in the working directory.
# get_mol_descriptors() applies `good_indeces` to the concatenated descriptor
# matrix first, then `ind_important_feats` to the columns that remain.
good_indeces = np.loadtxt('good_indeces.txt').astype(np.int64)
ind_important_feats = np.loadtxt('ind_important_features_lqy.txt').astype(np.int64)

In [30]:
# ind_important_feats = list(range(420))

In [31]:
def get_mol_descriptors(d_1, d_2, good_indeces, inds):
    """Join two descriptor matrices column-wise and keep the selected features.

    `d_1` and `d_2` are concatenated along axis 1. `good_indeces` then selects
    columns of the joined matrix, and `inds` selects columns of that
    intermediate result (so `inds` is relative to the `good_indeces` subset).
    """
    joined = np.concatenate((d_1, d_2), axis=1)
    usable = joined[:, good_indeces]
    return usable[:, inds]

In [32]:
# Training descriptors: both components joined, then reduced to the selected features.
extra_datapoint_descriptors = get_mol_descriptors(
    extra_datapoint_descriptors_1,
    extra_datapoint_descriptors_2,
    good_indeces,
    ind_important_feats
)

# Standardise the descriptors; the scaler is fitted on the training split only
# and reused unchanged for the test split.
mol_feats_scaler = StandardScaler()
extra_datapoint_descriptors = mol_feats_scaler.fit_transform(extra_datapoint_descriptors)

In [33]:
# Test descriptors: same feature selection as for the training split.
extra_datapoint_descriptors_test = get_mol_descriptors(
    extra_datapoint_descriptors_1_test,
    extra_datapoint_descriptors_2_test,
    good_indeces,
    ind_important_feats
)

# Transform only (no refit): reuse the statistics fitted on the training descriptors.
extra_datapoint_descriptors_test = mol_feats_scaler.transform(extra_datapoint_descriptors_test)

In [34]:
# One list of datapoints per component. Only the first component (column 0 of
# `smiss`) carries the target and the extra descriptors; the second component
# (column 1) contributes its molecule only.
all_data = [
    [
        data.MoleculeDatapoint.from_smi(smis[0], y, x_d=X_d)
        for smis, y, X_d in zip(smiss, ys, extra_datapoint_descriptors)
    ],
    [
        data.MoleculeDatapoint.from_smi(smis[1])
        for smis in smiss
    ],
]

In [35]:
# Same per-component layout as `all_data`, built from the held-out split.
test_data = [
    [
        data.MoleculeDatapoint.from_smi(smis[0], y, x_d=X_d)
        for smis, y, X_d in zip(test_smiss, test_ys, extra_datapoint_descriptors_test)
    ],
    [
        data.MoleculeDatapoint.from_smi(smis[1])
        for smis in test_smiss
    ],
]

In [36]:
# Split on the molecules of component 1, i.e. the second SMILES column.
component_to_split_by = 1
mols = [d.mol for d in all_data[component_to_split_by]]
# Random 90/5/5 split of the training CSV.
# NOTE(review): no seed is passed here, so reproducibility relies on the library default — confirm.
train_indices, val_indices, test_indices = data.make_split_indices(mols, "random", (0.9, 0.05, 0.05))
# Fold the 5% "test" part into validation: the real test set comes from a separate CSV.
val_indices += test_indices
# The third returned split is discarded for the same reason.
train_data, val_data, _ = data.split_data_by_indices(
    all_data, train_indices, val_indices, test_indices
)

In [37]:
# Sanity check: train + validation sizes should add up to the number of cleaned rows.
len(train_data[0]) + len(val_data[0])

12515

In [38]:
# A single molecular-graph featurizer instance is shared by every dataset below.
featurizer = featurizers.SimpleMoleculeMolGraphFeaturizer()

# One MoleculeDataset per component (one per entry of `smiles_columns`) for each split.
train_datasets = [data.MoleculeDataset(train_data[i], featurizer) for i in range(len(smiles_columns))]
val_datasets = [data.MoleculeDataset(val_data[i], featurizer) for i in range(len(smiles_columns))]
test_datasets = [data.MoleculeDataset(test_data[i], featurizer) for i in range(len(smiles_columns))]

In [39]:
train_mcdset = data.MulticomponentDataset(train_datasets)

# Normalise the training targets and keep the fitted scaler; it is reused for
# the validation targets and for the model's output un-scaling transform.
scaler = train_mcdset.normalize_targets()
# X_d normalisation inside the dataset is disabled: the descriptors were already
# standardised with `mol_feats_scaler` before the datapoints were built.
# extra_datapoint_descriptors_scaler = train_mcdset.normalize_inputs("X_d")

val_mcdset = data.MulticomponentDataset(val_datasets)
val_mcdset.normalize_targets(scaler)
# val_mcdset.normalize_inputs("X_d", extra_datapoint_descriptors_scaler)

# Test targets are left un-normalised; predictions are compared against `test_ys` directly.
test_mcdset = data.MulticomponentDataset(test_datasets)
# tmp
# test_mcdset.normalize_inputs("X_d", extra_datapoint_descriptors_scaler)

In [40]:
# Build the loaders with the notebook-level `batch_size` and `num_workers`
# settings, so that changing `num_workers` in the configuration cell actually
# takes effect. Only the training loader uses build_dataloader's default
# shuffling; validation and test keep the dataset order so predictions line up
# with `test_ys`.
train_loader = data.build_dataloader(train_mcdset, batch_size=batch_size, num_workers=num_workers)
val_loader = data.build_dataloader(val_mcdset, shuffle=False, batch_size=batch_size, num_workers=num_workers)
test_loader = data.build_dataloader(test_mcdset, shuffle=False, batch_size=batch_size, num_workers=num_workers)

## Training

In [61]:
def get_module_dict(model_pt, name):
    """Extract the state-dict entries of one sub-module.

    Selects the keys of `model_pt` that start with ``"{name}."`` and returns
    them with exactly that leading prefix removed, so the result can be passed
    to the sub-module's own ``load_state_dict``.

    Only the leading prefix is stripped. Later occurrences of ``"{name}."``
    inside a key (e.g. in nested modules that reuse the same attribute name)
    are preserved.

    Args:
        model_pt: mapping from parameter names to values (a model state dict).
        name: attribute name of the sub-module, without the trailing dot.

    Returns:
        A new dict with the prefix-stripped keys and the original values.
    """
    prefix = f'{name}.'
    return {
        key[len(prefix):]: value
        for key, value in model_pt.items()
        if key.startswith(prefix)
    }

def _build_mpnn(max_lr):
    """Build a freshly initialised multicomponent MPNN for this notebook.

    The architecture is shared by load_model() and load_pretrained_backbone():
    one BondMessagePassing block per SMILES column, mean aggregation, batch
    norm, and a 4-layer regression FFN whose input is the message-passing
    output concatenated with the extra datapoint descriptors. Predictions are
    un-scaled with the target `scaler` fitted on the training set.

    The extra descriptors are standardised before the datapoints are built,
    so no X_d_transform is attached to the model.

    Args:
        max_lr: peak learning rate of the schedule.
    """
    mcmp = nn.MulticomponentMessagePassing(
        blocks=[
            nn.BondMessagePassing(
                d_h=512,
                dropout=0.1,
                depth=3,
                bias=True
            )
            for _ in range(len(smiles_columns))
        ],
        n_components=len(smiles_columns),
    )

    output_transform = nn.UnscaleTransform.from_standard_scaler(scaler)

    # The FFN sees the learned representation plus the extra descriptors.
    ffn_input_dim = mcmp.output_dim + extra_datapoint_descriptors.shape[1]

    ffn = nn.RegressionFFN(
        n_tasks=len(target_columns),
        output_transform=output_transform,
        input_dim=ffn_input_dim,
        hidden_dim=512,
        n_layers=4,
        dropout=0.5,
        activation="relu"
    )

    return models.MulticomponentMPNN(
        mcmp,
        nn.MeanAggregation(),
        ffn,
        batch_norm=True,
        warmup_epochs=5,
        max_lr=max_lr,
        final_lr=5 * 1e-5,
        metrics=[nn.metrics.RMSEMetric()],
    )

def load_model(ckpt_path=None):
    """Return a model to train: a fresh one, or a full checkpoint restore.

    Args:
        ckpt_path: path to a Lightning checkpoint, or None.

    Returns:
        With `ckpt_path` set, the model restored by
        ``MulticomponentMPNN.load_from_checkpoint``. Otherwise a newly
        initialised model with max_lr = 2e-4.
    """
    if ckpt_path is not None:
        # load_from_checkpoint is a classmethod that builds its own model, so
        # call it on the class instead of first constructing (and discarding)
        # a randomly initialised instance.
        return models.MulticomponentMPNN.load_from_checkpoint(ckpt_path)

    return _build_mpnn(max_lr=2 * 1e-4)

def load_pretrained_backbone(ckpt_path=None):
    """Return a fresh model whose message-passing weights come from a checkpoint.

    Only the ``message_passing`` sub-module is initialised from the
    checkpoint's state dict. Every other part of the model keeps its fresh
    initialisation. With `ckpt_path` None a completely fresh model
    (max_lr = 1e-4) is returned.

    Args:
        ckpt_path: path to a Lightning checkpoint, or None.
    """
    mpnn = _build_mpnn(max_lr=1e-4)

    if ckpt_path is not None:
        # map_location="cpu" lets a checkpoint saved on a GPU load on any machine;
        # load_state_dict then copies the tensors onto the model's own device.
        # weights_only=False is explicit because the checkpoint holds more than
        # plain tensors. SECURITY: this unpickles arbitrary objects, so only
        # load checkpoints from a trusted source.
        checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)['state_dict']
        mp_statedict = get_module_dict(checkpoint, 'message_passing')
        mpnn.message_passing.load_state_dict(mp_statedict)

    return mpnn

def freeze_backbone(model):
    """Disable gradients for every parameter of `model.message_passing`.

    Parameters outside the message-passing sub-module are left untouched.
    """
    model.message_passing.requires_grad_(False)

In [72]:
# Option A: start from a message-passing backbone pretrained in another run and
# keep it frozen, so only the remaining layers are trained.
# NOTE(review): a later cell assigns `mpnn` again via load_model(); whichever of
# the two cells ran last decides which model is trained.
ckpt_path = '../gnn_4/model_2/epoch=060-val_loss=0.345.ckpt'

mpnn = load_pretrained_backbone(ckpt_path)
freeze_backbone(mpnn)

  checkpoint = torch.load(ckpt_path)['state_dict']


In [83]:
# Option B: manual toggle — index 0 selects None (fresh model), index 1 resumes
# from the full checkpoint of the previous run.
ckpt_path = [None, 'model_lqy_1/epoch=147-val_loss=0.420.ckpt'][1]

mpnn = load_model(ckpt_path)

  hparams = torch.load(checkpoint_path)["hyper_parameters"]


In [90]:
# Lower the learning-rate settings for this fine-tuning run.
# NOTE(review): presumably these attributes are read when the optimizer and
# scheduler are configured at fit time — confirm against chemprop.
mpnn.max_lr = 1e-4
mpnn.final_lr = 1e-5

In [91]:
# Count only parameters with requires_grad=True, i.e. the trainable ones
# (frozen parameters are excluded despite the variable's name).
pytorch_total_params = sum(p.numel() for p in mpnn.parameters() if p.requires_grad)
print(f'Trainable params {pytorch_total_params}')

Trainable params 2539009


### Set up trainer

In [92]:
device = 0  # index of the CUDA device handed to the Trainer
model_dir = 'model_lqy_2'  # directory where ModelCheckpoint writes this run's checkpoints

In [93]:
# Keep the three checkpoints with the lowest validation loss in `model_dir`;
# the file name encodes the epoch and the validation loss.
checkpoint_cb = ModelCheckpoint(
    save_top_k=3,
    monitor="val_loss",
    mode="min",
    dirpath=model_dir,
    filename="{epoch:03d}-{val_loss:.3f}"
)

# Stop once the validation loss has not improved for 20 consecutive checks.
earlystopping_cb = EarlyStopping(
    monitor="val_loss",
    mode="min",
    patience=20,
    min_delta=0.0
)

# Single-GPU trainer: at least 5 and at most 200 epochs, with both callbacks attached.
trainer = pl.Trainer(
    logger=True,
    enable_progress_bar=True,
    accelerator="cuda",
    devices=[device],
    min_epochs=5,
    max_epochs=200,
    callbacks=[checkpoint_cb, earlystopping_cb],
)

/home/kashurin/soft/miniconda3/envs/chemprop/lib/python3.11/site-packages/lightning/fabric/plugins/environments/slurm.py:204: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /home/kashurin/soft/miniconda3/envs/chemprop/lib/pyt ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


### Start training

In [94]:
# Train on the training loader, validating on the validation loader; the
# checkpoint and early-stopping callbacks are attached to `trainer`.
trainer.fit(mpnn, train_loader, val_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
Loading `train_dataloader` to estimate number of stepping batches.
/home/kashurin/soft/miniconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py:298: The number of training batches (22) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.

  | Name            | Type                         | Params | Mode 
-------------------------------------------------------------------------
0 | message_passing | MulticomponentMessagePassing | 1.2 M  | train
1 | agg             | MeanAggregation              | 0      | train
2 | bn              | BatchNorm1d                  | 2.0 K  | train
3 | predictor       | RegressionFFN                | 1.3 M  | train
4 | X_d_transform   | Identity                     | 0      | train
-------------------------------------------------------------------------
2.5 M     Trainable p

Sanity Checking: |                                                                                           |…

Training: |                                                                                                  |…

Validation: |                                                                                                |…

Validation: |                                                                                                |…

Validation: |                                                                                                |…

Validation: |                                                                                                |…

Validation: |                                                                                                |…

Validation: |                                                                                                |…

Validation: |                                                                                                |…

Validation: |                                                                                                |…

Validation: |                                                                                                |…

Validation: |                                                                                                |…

Validation: |                                                                                                |…

Validation: |                                                                                                |…

Validation: |                                                                                                |…

Validation: |                                                                                                |…

Validation: |                                                                                                |…

Validation: |                                                                                                |…

Validation: |                                                                                                |…

Validation: |                                                                                                |…

Validation: |                                                                                                |…

Validation: |                                                                                                |…

Validation: |                                                                                                |…

Validation: |                                                                                                |…

Validation: |                                                                                                |…

In [95]:
# Best monitored value ("val_loss") seen by the early-stopping callback.
earlystopping_cb.best_score

tensor(0.4263, device='cuda:0')

### Test results

In [78]:
# Path of the best checkpoint recorded by the ModelCheckpoint callback.
# NOTE(review): the recorded output points into 'model_lqy_1' while `model_dir`
# is 'model_lqy_2', so this output predates the latest training run — re-run
# the notebook top to bottom to refresh it.
checkpoint_cb.best_model_path

'/home/kashurin/gnn_5/model_lqy_1/epoch=147-val_loss=0.420.ckpt'

In [79]:
# Evaluate the best checkpoint found during training.
ckpt_path = checkpoint_cb.best_model_path
# ckpt_path = '/home/kashurin/gnn_4/model_1/epoch=191-val_loss=0.353.ckpt'

# The in-memory model only provides the architecture here: passing `ckpt_path`
# to trainer.predict() restores the checkpoint's weights before predicting.
mpnn_predict = mpnn

with torch.inference_mode():
    # A separate Trainer without logger or callbacks, used for prediction only.
    trainer = pl.Trainer(
        logger=None,
        enable_progress_bar=True,
        devices=[device]
    )
    testing_preds = trainer.predict(mpnn_predict, test_loader, ckpt_path=ckpt_path)

# trainer.predict returns one entry per batch; stack them along the sample axis.
testing_preds = np.concatenate(testing_preds, axis=0)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
Restoring states from the checkpoint path at /home/kashurin/gnn_5/model_lqy_1/epoch=147-val_loss=0.420.ckpt
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]
Loaded model weights from the checkpoint at /home/kashurin/gnn_5/model_lqy_1/epoch=147-val_loss=0.420.ckpt


Predicting: |                                                                                                |…

In [81]:
# Test-set RMSD between predictions and the log quantum yield targets.
print(f"RMSD Log quantum yield: {rmsd(testing_preds, test_ys)}")

RMSD Log quantum yield: 1.0311842150866124


In [82]:
# r2() is the squared Pearson correlation (see its definition), not sklearn's r2_score.
print(f"R2 Log quantum yield: {r2(testing_preds, test_ys)}")

R2 Log quantum yield: 0.7201068450450531
