### Key Instance Detection for Conformers

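Key instance detection (KID) for conformers: multi-instance models are trained on bags of conformer descriptors, and the instance weights they assign are checked against the known key conformer(s) of each test molecule.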

In [1]:
import pickle

import numpy as np
import pandas as pd

from sklearn.metrics import r2_score, balanced_accuracy_score
from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier

In [2]:
from qsarmil.descriptor.rdkit import (RDKitGEOM, 
                                      RDKitAUTOCORR, 
                                      RDKitRDF, 
                                      RDKitMORSE, 
                                      RDKitWHIM, 
                                      RDKitGETAWAY)

from molfeat.calc import (Pharmacophore3D, 
                          USRDescriptors, 
                          ElectroShapeDescriptors)

from qsarmil.descriptor.wrapper import DescriptorWrapper

In [3]:
from milearn.wrapper import InstanceWrapper, BagWrapper

from milearn.network.regressor import InstanceNetworkRegressor, BagNetworkRegressor
from milearn.network.classifier import InstanceNetworkClassifier, BagNetworkClassifier

# MIL regressors
from milearn.network.regressor import (AttentionNetworkRegressor,
                                       TempAttentionNetworkRegressor,
                                       GatedAttentionNetworkRegressor,
                                       MultiHeadAttentionNetworkRegressor,
                                       SelfAttentionNetworkRegressor,
                                       HopfieldAttentionNetworkRegressor,
                                       DynamicPoolingNetworkRegressor)

# MIL classifiers
from milearn.network.classifier import (AttentionNetworkClassifier,
                                        TempAttentionNetworkClassifier,
                                        GatedAttentionNetworkClassifier,
                                        MultiHeadAttentionNetworkClassifier,
                                        SelfAttentionNetworkClassifier,
                                        HopfieldAttentionNetworkClassifier,
                                        DynamicPoolingNetworkClassifier)

from milearn.preprocessing import BagMinMaxScaler

  return torch._C._cuda_getDeviceCount() > 0


In [4]:
from typing import List, Tuple

def kid_accuracy(true_key_indices, predicted_weights, top_n=1):
    """Share of bags whose top-N weighted instances contain a true key instance.

    Args:
    - true_key_indices: for each bag, the positions of its key instances.
    - predicted_weights: for each bag, one weight per instance.
    - top_n: number of highest-weighted instances inspected per bag.

    Returns:
    - hit rate over all bags; 0.0 when there are no bags.

    Instances with equal weight keep their original order when ranked
    (``sorted`` is stable even with ``reverse=True``), so ties favour lower
    indices -- a perfectly flat weight vector always "selects" index 0 first.
    """
    assert len(predicted_weights) == len(true_key_indices), "Mismatched input lengths."

    n_bags = len(predicted_weights)
    if not n_bags:
        return 0.0

    n_hits = 0
    for weights, keys in zip(predicted_weights, true_key_indices):
        ranking = sorted(range(len(weights)), key=weights.__getitem__, reverse=True)
        shortlist = ranking[:top_n]
        # a bag counts once, however many of its key instances were found
        n_hits += any(key in shortlist for key in keys)

    return n_hits / n_bags

def normalized_entropy(weights, epsilon=1e-12):
    """
    Shannon entropy of a bag's instance weights, scaled by its maximum log(n).

    The weights are flattened and renormalized to sum to one before the
    entropy is computed (weights are assumed non-negative).

    Args:
    - weights: instance weights of a single bag.
    - epsilon: small constant keeping the normalization and the logarithm
      finite when the weights sum to zero or contain zeros.

    Returns:
    - norm_entropy: float in [0, 1], where 0 = sharp, 1 = flat.
      Bags with fewer than two weights return 0.0.
    """
    weights = np.asarray(weights, dtype=np.float64).ravel()
    n_weights = weights.size
    if n_weights < 2:
        # With 0 or 1 instances the maximum entropy log(n) is ~0, so the ratio
        # below would be numerical noise (possibly outside [0, 1]).
        return 0.0

    weights = weights / (weights.sum() + epsilon)  # normalize

    entropy = -np.sum(weights * np.log(weights + epsilon))
    max_entropy = np.log(n_weights)

    return entropy / max_entropy

In [5]:
from math import comb

def expected_kid_accuracy(true_key_indices, bag_sizes, top_n=1):
    """
    Expected KID accuracy of a ranking that orders instances uniformly at random.

    For a bag of B instances with K key instances, a random top-N selection
    misses every key instance with probability C(B - K, N) / C(B, N), so the
    hit probability is one minus that.

    Args:
    - true_key_indices: per-bag lists of key-instance positions.
    - bag_sizes: number of instances in each bag.
    - top_n: how many top-ranked instances count towards a hit.

    Returns:
    - mean hit probability over all bags (bags without a usable key instance
      contribute 0), or 0.0 when there are no bags.
    """
    assert len(true_key_indices) == len(bag_sizes), "Mismatched input lengths."
    if len(true_key_indices) == 0:
        return 0.0

    expected_hits = 0.0
    for key_indices, B in zip(true_key_indices, bag_sizes):
        # Count only distinct key positions that exist in the bag: duplicates
        # would inflate K, and an out-of-range index can never be selected by
        # a top-N ranking of the bag, so counting it overstates the baseline.
        K = len({k for k in key_indices if 0 <= k < B})
        N = min(top_n, B)  # top_n can't exceed bag size

        if K == 0 or B == 0:
            continue  # no key instance can be hit in this bag

        if B - K < N:
            hit_prob = 1.0  # guaranteed to pick a key instance
        else:
            hit_prob = 1 - (comb(B - K, N) / comb(B, N))

        expected_hits += hit_prob

    return expected_hits / len(true_key_indices)

### 1. Load data
### TODO: fix conformer indexing (the key-conformer indices parsed from conformer names are marked as wrong in the next cell)

In [6]:
def _load_pickle(path):
    """Unpickle and return the object stored in the file at ``path``.

    NOTE(review): unpickling can execute arbitrary code -- only load trusted files.
    """
    with open(path, "rb") as handle:
        return pickle.load(handle)


data_train = _load_pickle("kid_data/actives_dm_train_allconf.pkl")
data_test = _load_pickle("kid_data/actives_dm_test_allconf.pkl")

In [7]:
# Records are indexed positionally: [1] molecule, [2] names of the key
# conformers, [3] property value.
mol_train = [record[1] for record in data_train]
mol_test = [record[1] for record in data_test]

prop_train = [record[3].item() for record in data_train]
prop_test = [record[3].item() for record in data_test]

# Key instances: the integer suffix after the last "_" of each key-conformer
# name. NOTE(review): flagged as wrong indexing -- these numbers may not match
# the conformer's position in its bag; TODO confirm before trusting KID scores.
kid_test = [
    [int(conf_name.split("_")[-1]) for conf_name in record[2]]
    for record in data_test
]

### 2. Benchmark configuration

In [8]:
# Constructor arguments shared by every MIL network in the benchmark.
network_hparams = dict(
    hidden_layer_sizes=(256, 128, 64),
    num_epoch=300,
    batch_size=128,
    learning_rate=0.001,
    weight_decay=0.001,
    instance_weight_dropout=0.01,
    init_cuda=False,
    verbose=False,
)

In [9]:
# (name, factory) pairs: each factory builds a raw descriptor calculator,
# which is wrapped in DescriptorWrapper one entry at a time, in list order.
desc_list = [
    (name, DescriptorWrapper(make_calc()))
    for name, make_calc in (
        ("RDKitGEOM", RDKitGEOM),
        ("RDKitAUTOCORR", RDKitAUTOCORR),
        ("RDKitRDF", RDKitRDF),
        ("RDKitMORSE", RDKitMORSE),
        ("RDKitWHIM", RDKitWHIM),
        ("RDKitGETAWAY", RDKitGETAWAY),  # can be long
        ("MolFeatPmapper", lambda: Pharmacophore3D(factory='pmapper')),  # can be long
        ("MolFeatUSRD", USRDescriptors),
        ("MolFeatElectroShape", ElectroShapeDescriptors),
    )
]

In [10]:
# (display name, estimator) pairs benchmarked for each task. All estimators
# share network_hparams; two variants take one extra constructor argument.
regressor_list = [
    ("AttentionNetworkRegressor", AttentionNetworkRegressor(**network_hparams)),
    ("TempAttentionNetworkRegressor", TempAttentionNetworkRegressor(**network_hparams, tau=0.1)),
    ("GatedAttentionNetworkRegressor", GatedAttentionNetworkRegressor(**network_hparams)),
    ("MultiHeadAttentionNetworkRegressor", MultiHeadAttentionNetworkRegressor(**network_hparams, num_heads=6)),
    ("SelfAttentionNetworkRegressor", SelfAttentionNetworkRegressor(**network_hparams)),
    ("HopfieldAttentionNetworkRegressor", HopfieldAttentionNetworkRegressor(**network_hparams)),
    ("DynamicPoolingNetworkRegressor", DynamicPoolingNetworkRegressor(**network_hparams)),
]

classifier_list = [
    ("AttentionNetworkClassifier", AttentionNetworkClassifier(**network_hparams)),
    ("TempAttentionNetworkClassifier", TempAttentionNetworkClassifier(**network_hparams, tau=0.1)),
    ("GatedAttentionNetworkClassifier", GatedAttentionNetworkClassifier(**network_hparams)),
    ("MultiHeadAttentionNetworkClassifier", MultiHeadAttentionNetworkClassifier(**network_hparams, num_heads=6)),
    ("SelfAttentionNetworkClassifier", SelfAttentionNetworkClassifier(**network_hparams)),
    ("HopfieldAttentionNetworkClassifier", HopfieldAttentionNetworkClassifier(**network_hparams)),
    ("DynamicPoolingNetworkClassifier", DynamicPoolingNetworkClassifier(**network_hparams)),
]

# Backward-compatible alias for the original (misspelled) name, which the
# task-selection cell still reads.
clissifier_list = classifier_list

### 3. Benchmark models (regression or classification, selected by `task`)

In [11]:
# Learning task: "reg" regresses the raw property; "clf" classifies whether
# the property is at or above `threshold`.
task = "reg"

if task == "reg":
    estimator_list = regressor_list
    accuracy_metric = r2_score
    y_train = prop_train
    y_test = prop_test
elif task == "clf":
    threshold = 7  # property value at or above which a molecule is labelled 1
    estimator_list = clissifier_list
    accuracy_metric = balanced_accuracy_score
    y_train = [1 if i >= threshold else 0 for i in prop_train]
    y_test = [1 if i >= threshold else 0 for i in prop_test]
else:
    # Fail here rather than with a NameError in the benchmark loop.
    raise ValueError(f"Unknown task {task!r}; expected 'reg' or 'clf'.")

In [12]:
n = 0
# Number of (descriptor, model) runs for the selected task.
total_n = len(desc_list) * len(estimator_list)
# Pad every progress line to the longest label so that a "\r"-overwritten
# line fully clears the previous, possibly longer, one.
label_width = max(
    (len(f"{total_n}/{total_n} {d_name}|{m_name}")
     for d_name, _ in desc_list
     for m_name, _ in estimator_list),
    default=0,
)

res_df = pd.DataFrame()
for desc_name, desc_calc in desc_list:

    # 1. Calc descriptors
    x_train = desc_calc.transform(mol_train)
    x_test = desc_calc.transform(mol_test)

    # 2. Scale descriptors (scaler fitted on the training bags only)
    scaler = BagMinMaxScaler()
    scaler.fit(x_train)
    x_train_scaled = scaler.transform(x_train)
    x_test_scaled = scaler.transform(x_test)

    for method_name, model in estimator_list:
        row = f"{desc_name}|{method_name}"

        # 3. Train model
        model.fit(x_train_scaled, y_train)

        # 4. Get predictions with a single forward pass; for classification
        # the output is treated as a probability and thresholded at 0.5.
        y_pred = model.predict(x_test_scaled)
        if task == "clf":
            y_pred = np.where(y_pred > 0.5, 1, 0)

        res_df.loc[row, "ACC"] = accuracy_metric(y_test, y_pred)

        # 5. Key instance detection
        w_pred = model.get_instance_weights(x_test_scaled)

        # calc kid accuracy
        for top_n in [1, 2, 3]:
            res_df.loc[row, f"TOP-{top_n}"] = kid_accuracy(kid_test, w_pred, top_n=top_n)

        # calc mean normalized weight entropy (0 = sharp, 1 = flat)
        res_df.loc[row, "ENT"] = np.mean([normalized_entropy(w) for w in w_pred]).item()

        # logging
        n += 1
        print(f"{n}/{total_n} {row}".ljust(label_width), end="\r")

print()  # finish the progress line

# save results
# res_df.to_csv("kid_results.csv")

63/63 MolFeatElectroShape|DynamicPoolingNetworkClassifierierr

In [13]:
# Random-ranking KID baseline. The bag sizes come from x_test, i.e. the
# descriptors of the last descriptor set computed in the benchmark loop.
bag_sizes = list(map(len, x_test))
for top_n in (1, 2, 3):
    baseline = expected_kid_accuracy(kid_test, bag_sizes, top_n=top_n)
    print(f"Expected KID accuracy Top-{top_n} = {baseline:.2f}")

Expected KID accuracy Top-1 = 0.10
Expected KID accuracy Top-2 = 0.19
Expected KID accuracy Top-3 = 0.27


In [14]:
# Rank descriptor/model combinations by top-1 KID accuracy, best first.
res_df.sort_values("TOP-1", ascending=False)

Unnamed: 0,ACC,TOP-1,TOP-2,TOP-3,ENT
MolFeatPmapper|DynamicPoolingNetworkClassifier,0.884274,0.182584,0.353933,0.435393,0.940229
MolFeatPmapper|AttentionNetworkClassifier,0.891170,0.165730,0.283708,0.410112,0.975658
RDKitGEOM|MultiHeadAttentionNetworkClassifier,0.710397,0.165730,0.252809,0.353933,0.992118
RDKitGEOM|DynamicPoolingNetworkClassifier,0.626228,0.162921,0.255618,0.376404,0.937492
RDKitWHIM|SelfAttentionNetworkClassifier,0.874295,0.157303,0.238764,0.362360,0.754035
...,...,...,...,...,...
RDKitWHIM|TempAttentionNetworkClassifier,0.881870,0.070225,0.140449,0.238764,0.477779
MolFeatPmapper|SelfAttentionNetworkClassifier,0.933177,0.067416,0.120787,0.219101,0.997874
RDKitGEOM|SelfAttentionNetworkClassifier,0.500000,0.064607,0.112360,0.171348,0.997816
MolFeatUSRD|MultiHeadAttentionNetworkClassifier,0.806426,0.064607,0.123596,0.193820,0.998472


### 4. Best model analysis

In [15]:
# Descriptor set for the detailed analysis (the RDKitGEOM alternative is kept
# commented out for quick comparison).
desc_calc = DescriptorWrapper(Pharmacophore3D(factory='pmapper'))
# desc_calc = DescriptorWrapper(RDKitGEOM())

# 1. Calc descriptors (train first, then test)
x_train, x_test = desc_calc.transform(mol_train), desc_calc.transform(mol_test)

# 2. Scale descriptors: fit on the training bags, then apply to both splits
scaler = BagMinMaxScaler()
scaler.fit(x_train)
x_train_scaled, x_test_scaled = scaler.transform(x_train), scaler.transform(x_test)

In [16]:
# Network settings for the detailed analysis (same values as the benchmark).
network_hparams = dict(
    hidden_layer_sizes=(256, 128, 64),
    num_epoch=300,
    batch_size=128,
    learning_rate=0.001,
    weight_decay=0.001,
    instance_weight_dropout=0.01,
    init_cuda=False,
    verbose=False,
)

In [17]:
# Estimator for the detailed analysis; the attention alternative is kept
# commented out for comparison.
# model = AttentionNetworkRegressor(**network_hparams)
model = DynamicPoolingNetworkRegressor(**network_hparams)

# NOTE(review): a regressor is always used here, which matches task == "reg";
# confirm y_train holds property values (not 0/1 labels) before interpreting.
model.fit(x_train_scaled, y_train)

DynamicPoolingNetworkRegressor(
  (extractor): Sequential(
    (0): Linear(in_features=11, out_features=256, bias=True)
    (1): ReLU()
    (2): Linear(in_features=256, out_features=128, bias=True)
    (3): ReLU()
    (4): Linear(in_features=128, out_features=64, bias=True)
    (5): ReLU()
  )
  (pooling): DynamicPooling()
  (estimator): Norm()
)

In [18]:
# Test-set predictions and per-bag instance weights used for KID analysis.
y_pred = model.predict(x_test_scaled)
w_pred = model.get_instance_weights(x_test_scaled)

In [19]:
top_n = 1

# Overall fit quality, KID accuracy, and the random-ranking KID baseline.
print(r2_score(y_test, y_pred))
print(kid_accuracy(kid_test, w_pred, top_n=top_n))
print(expected_kid_accuracy(kid_test, [len(i) for i in w_pred], top_n=top_n))

# Same KID scores restricted to test molecules whose label is at least 7.
# With task == "reg" the labels are (presumably continuous) property values,
# so an exact `y == 7` test would almost never match; use `>=` instead.
# NOTE(review): with task == "clf" the labels are 0/1 and this subset is empty.
kid_test_7, w_pred_7 = [], []
for y, k, w in zip(y_test, kid_test, w_pred):
    if y >= 7:
        kid_test_7.append(k)
        w_pred_7.append(w)
print()
print(f"subset size: {len(kid_test_7)}")  # an empty subset scores 0.0 below
print(kid_accuracy(kid_test_7, w_pred_7, top_n=top_n))
print(expected_kid_accuracy(kid_test_7, [len(i) for i in w_pred_7], top_n=top_n))

0.07598596811294556
0.14887640449438203
0.09692859451562535

0.0
0.0


In [20]:
# Positions of test molecules whose label is at least 7 (an exact `== 7`
# comparison would rarely match continuous property values); show the first 10.
tmp = [n for n, y in enumerate(y_test) if y >= 7]
tmp[:10]

[]

In [21]:
# Index of the single test molecule inspected in the cells below.
N = 15

In [22]:
# Label of molecule N (property value for task "reg", 0/1 for task "clf").
y_test[N]

1

In [23]:
# Key-conformer indices parsed for molecule N.
kid_test[N]

[2]

In [24]:
# Predicted instance weights of molecule N, rounded for display.
w_pred[N].round(2)

array([0.05, 0.05, 0.05, 0.05, 0.05], dtype=float32)