## 1. Load dataset

The example datasets contain molecule structure (SMILES) and measured bioactivity (pKi or IC50) – the higher the better. Each SMILES is converted to a Mol object in RDKit.

In [1]:
import numpy as np
import pandas as pd
from rdkit import Chem

from sklearn.metrics import r2_score, accuracy_score

In [2]:
def reg_to_clf(y):
    return np.where(np.array(y) > 6, 1, 0)

def accuracy_metric(y_true, y_pred, task=None):
    if task == "classification":
        return accuracy_score(y_true, y_pred)
    elif task == "regression":
        return r2_score(y_true, y_pred)

In [3]:
# TASK = "regression"
TASK = "classification"

In [4]:
data_train = pd.read_csv('data/CHEMBL1824/train.csv', header=None)
data_test = pd.read_csv('data/CHEMBL1824/test.csv', header=None)

In [5]:
smi_train, prop_train = data_train[0].to_list(), data_train[1].to_list()
smi_test, prop_test = data_test[0].to_list(), data_test[1].to_list()

if TASK == "classification":
    prop_train, prop_test = reg_to_clf(prop_train), reg_to_clf(prop_test)

In [6]:
mols_train, y_train = [], []
for smi, prop in zip(smi_train, prop_train):
    mol = Chem.MolFromSmiles(smi)
    if mol:
        mols_train.append(mol)
        y_train.append(prop)

In [7]:
mols_test, y_test = [], []
for smi, prop in zip(smi_test, prop_test):
    mol = Chem.MolFromSmiles(smi)
    if mol:
        mols_test.append(mol)
        y_test.append(prop)

## 1.5 Reduce the dataset size for faster pipeline (for playing around)

In [8]:
# mols_train, y_train = mols_train[:80], y_train[:80]
# mols_test, y_test = mols_test[:20], y_test[:20]

## 2. Conformer generation

For each molecule, an ensemble of conformers is generated. Then, molecules for which conformer generation failed are filtered out from both, the training and test set. Generated conformers can be accessed by mol.GetConformers(confID=0).

In [9]:
from qsarmil.conformer import RDKitConformerGenerator

from qsarmil.utils.logging import FailedConformer, FailedDescriptor

In [10]:
conf_gen = RDKitConformerGenerator(num_conf=10, num_cpu=40)

In [11]:
confs_train = conf_gen.run(mols_train)

tmp = [(c, y) for c, y in zip(confs_train, y_train) if not isinstance(c, FailedConformer)]
confs_train, y_train = zip(*tmp) 
confs_train, y_train = list(confs_train), list(y_train)

Generating conformers: 100%|████████████████████████████████████████████████████████| 1667/1667 [01:10<00:00, 23.77it/s]


In [12]:
confs_test = conf_gen.run(mols_test)

tmp = [(c, y) for c, y in zip(confs_test, y_test) if not isinstance(c, FailedConformer)]
confs_test, y_test = zip(*tmp) 
confs_test, y_test = list(confs_test), list(y_test)

Generating conformers: 100%|██████████████████████████████████████████████████████████| 556/556 [00:28<00:00, 19.59it/s]


## 3. Descriptor calculation

Then, for each molecule with associated conformers 3D descriptors are calculated. Here, a descriptor wrapper is used, which is designed to apply descriptor calculators from external packages. The resulting descriptors are a list of 2D arrays (bags). Also, the resulting descriptors are scaled.

In [13]:
from qsarmil.descriptor.rdkit import (RDKitGEOM, 
                                      RDKitAUTOCORR, 
                                      RDKitRDF, 
                                      RDKitMORSE, 
                                      RDKitWHIM, 
                                      RDKitGETAWAY)

from molfeat.calc import Pharmacophore3D, USRDescriptors, ElectroShapeDescriptors

from qsarmil.descriptor.wrapper import DescriptorWrapper

from milearn.preprocessing import BagMinMaxScaler

In [14]:
desc_calc = DescriptorWrapper(RDKitRDF())

In [15]:
x_train = desc_calc.transform(confs_train)
x_test = desc_calc.transform(confs_test)

In [16]:
scaler = BagMinMaxScaler()

scaler.fit(x_train)

x_train_scaled = scaler.transform(x_train)
x_test_scaled = scaler.transform(x_test)

## 4. Model training

In [17]:
from sklearn.metrics import r2_score, accuracy_score

from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier

from milearn.wrapper import InstanceWrapper, BagWrapper

from milearn.network.regressor import InstanceNetworkRegressor, BagNetworkRegressor
from milearn.network.classifier import InstanceNetworkClassifier, BagNetworkClassifier

# MIL regressors
from milearn.network.regressor import (AttentionNetworkRegressor,
                                       TempAttentionNetworkRegressor,
                                       GatedAttentionNetworkRegressor,
                                       MultiHeadAttentionNetworkRegressor,
                                       SelfAttentionNetworkRegressor,
                                       HopfieldAttentionNetworkRegressor,
                                       DynamicPoolingNetworkRegressor)

# MIL classifiers
from milearn.network.classifier import (AttentionNetworkClassifier,
                                        TempAttentionNetworkClassifier,
                                        GatedAttentionNetworkClassifier,
                                        MultiHeadAttentionNetworkClassifier,
                                        SelfAttentionNetworkClassifier,
                                        HopfieldAttentionNetworkClassifier,
                                        DynamicPoolingNetworkClassifier)

  return torch._C._cuda_getDeviceCount() > 0


In [18]:
network_hparams = {'hidden_layer_sizes':(256, 128, 64),
                   'num_epoch':300,
                   'batch_size':128,
                   'learning_rate':0.001,
                   'weight_decay':0.001,
                   'instance_weight_dropout':0.01,
                   'init_cuda':False,
                   'verbose':False}

In [19]:
regressor_list = [
                   ("MeanInstanceWrapperRegressor", InstanceWrapper(estimator=RandomForestRegressor(), pool="mean")), 
                   ("MaxInstanceWrapperRegressor", InstanceWrapper(RandomForestRegressor(), pool="max")), 
                   ("MeanBagWrapperRegressor", BagWrapper(RandomForestRegressor(), pool="mean")), 
                   ("MaxBagWrapperRegressor", BagWrapper(RandomForestRegressor(), pool="max")), 
                   ("MinBagWrapperRegressor", BagWrapper(RandomForestRegressor(), pool="min")), 
                   ("ExtremeBagWrapperRegressor", BagWrapper(RandomForestRegressor(), pool="extreme")),
                   ("MeanInstanceNetworkRegressor", InstanceNetworkRegressor(**network_hparams, pool="mean")),
                   ("MaxInstanceNetworkRegressor", InstanceNetworkRegressor(**network_hparams, pool="max")),
                   ("MeanBagNetworkRegressor", BagNetworkRegressor(**network_hparams, pool="mean")),
                   ("MaxBagNetworkRegressor", BagNetworkRegressor(**network_hparams, pool="max")),
                   ("AttentionNetworkRegressor", AttentionNetworkRegressor(**network_hparams)),
                   ("TempAttentionNetworkRegressor", AttentionNetworkRegressor(**network_hparams)),
                   ("GatedAttentionNetworkRegressor", GatedAttentionNetworkRegressor(**network_hparams)),
                   ("MultiHeadAttentionNetworkRegressor", SelfAttentionNetworkRegressor(**network_hparams)),
                   ("SelfAttentionNetworkRegressor", SelfAttentionNetworkRegressor(**network_hparams)),
                   ("HopfieldAttentionNetworkRegressor", HopfieldAttentionNetworkRegressor(**network_hparams)),
                   ("DynamicPoolingNetworkRegressor", DynamicPoolingNetworkRegressor(**network_hparams))
              ]

classifier_list = [
                   ("MeanInstanceWrapperClassifier", InstanceWrapper(estimator=RandomForestClassifier(), pool="mean")), 
                   ("MaxInstanceWrapperClassifier", InstanceWrapper(RandomForestClassifier(), pool="max")), 
                   ("MeanBagWrapperClassifier", BagWrapper(RandomForestClassifier(), pool="mean")), 
                   ("MaxBagWrapperClassifier", BagWrapper(RandomForestClassifier(), pool="max")), 
                   ("MinBagWrapperClassifier", BagWrapper(RandomForestClassifier(), pool="min")), 
                   ("ExtremeBagWrapperClassifier", BagWrapper(RandomForestClassifier(), pool="extreme")),
                   ("MeanInstanceNetworkClassifier", InstanceNetworkClassifier(**network_hparams, pool="mean")),
                   ("MaxInstanceNetworkClassifier", InstanceNetworkClassifier(**network_hparams, pool="max")),
                   ("MeanBagNetworkClassifier", BagNetworkClassifier(**network_hparams, pool="mean")),
                   ("MaxBagNetworkClassifier", BagNetworkClassifier(**network_hparams, pool="max")),
                   ("AttentionNetworkClassifier", AttentionNetworkClassifier(**network_hparams)),
                   ("TempAttentionNetworkClassifier", AttentionNetworkClassifier(**network_hparams)),
                   ("GatedAttentionNetworkClassifier", GatedAttentionNetworkClassifier(**network_hparams)),
                   ("MultiHeadAttentionNetworkClassifier", SelfAttentionNetworkClassifier(**network_hparams)),
                   ("SelfAttentionNetworkClassifier", SelfAttentionNetworkClassifier(**network_hparams)),
                   ("HopfieldAttentionNetworkClassifier", HopfieldAttentionNetworkClassifier(**network_hparams)),
                   ("DynamicPoolingNetworkClassifier", DynamicPoolingNetworkClassifier(**network_hparams))
              ]

In [20]:
if TASK == "regression":
    method_list = regressor_list
elif TASK == "classification":
    method_list = classifier_list

res_df = pd.DataFrame()
for method_name, model in method_list:
    model.fit(x_train_scaled, y_train)

    if TASK == "regression":
        y_pred = model.predict(x_test_scaled)
    elif TASK == "classification":
        y_prob = model.predict(x_test_scaled)
        y_pred = np.where(y_prob > 0.5, 1, 0)
    
    res_df.loc[method_name, "ACC"] = accuracy_metric(y_test, y_pred, task=TASK)

In [21]:
res_df.sort_values(by="ACC", ascending=False)

Unnamed: 0,ACC
MeanInstanceNetworkClassifier,0.897482
TempAttentionNetworkClassifier,0.881295
MultiHeadAttentionNetworkClassifier,0.875899
MeanBagNetworkClassifier,0.875899
GatedAttentionNetworkClassifier,0.874101
SelfAttentionNetworkClassifier,0.870504
AttentionNetworkClassifier,0.870504
HopfieldAttentionNetworkClassifier,0.870504
MeanInstanceWrapperClassifier,0.854317
DynamicPoolingNetworkClassifier,0.854317
