### Key Instance Detection for Conformers

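Key instance detection (KID) for conformers: multiple-instance regressors are trained on bags of 3D conformer descriptors. The notebook then checks whether each model's learned instance weights rank the known key conformers at the top, compared with a random-ranking baseline.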

In [1]:
import logging
import warnings
warnings.filterwarnings("ignore")
logging.getLogger("pytorch_lightning").setLevel(logging.ERROR)
logging.getLogger("lightning").setLevel(logging.ERROR)

import time
import torch
import pickle
import random

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt

# Preprocessing
from milearn.preprocessing import BagMinMaxScaler

# Network hparams
from milearn.network.module.hopt import DEFAULT_PARAM_GRID

# MIL wrappers
from milearn.network.regressor import BagWrapperMLPNetworkRegressor, InstanceWrapperMLPNetworkRegressor

# MIL networks
from milearn.network.regressor import (InstanceNetworkRegressor,
                                       BagNetworkRegressor,
                                       AdditiveAttentionNetworkRegressor,
                                       SelfAttentionNetworkRegressor,
                                       HopfieldAttentionNetworkRegressor,
                                       DynamicPoolingNetworkRegressor)

# Utils
from sklearn.metrics import r2_score, accuracy_score
from sklearn.model_selection import train_test_split

# 3D descriptors
from qsarmil.descriptor.rdkit import (RDKitGEOM, 
                                      RDKitAUTOCORR, 
                                      RDKitRDF, 
                                      RDKitMORSE, 
                                      RDKitWHIM, 
                                      RDKitGETAWAY)
from molfeat.calc import Pharmacophore3D, USRDescriptors, ElectroShapeDescriptors
from qsarmil.descriptor.wrapper import DescriptorWrapper

In [2]:
from rdkit import Chem
from rdkit.Chem import AllChem
import py3Dmol
import numpy as np
from IPython.display import display, HTML

def visualize_conformers_grid(mol, weights, key_conformers, top_n=5,
                              style="stick", n_cols=3, width=250, height=250,
                              show_all=False, sort_by_weight=True):
    """Display a grid of 3D conformer viewers highlighting true vs. predicted key conformers.

    Conformers are addressed by 0-based *position* in ``mol.GetConformers()``,
    not by RDKit conformer ID. The grid therefore works whether the IDs start at
    0 or 1, or are non-contiguous. The position order is assumed to match the
    order of ``weights`` and of ``key_conformers``, i.e. the order in which
    descriptors were computed (TODO: confirm for the descriptor wrapper in use).

    Parameters
    ----------
    mol : rdkit.Chem.Mol
        Molecule carrying one conformer per instance.
    weights : sequence of float
        Predicted instance weight for each conformer, in conformer order.
    key_conformers : iterable of int
        0-based positions of the ground-truth key conformers.
    top_n : int, default 5
        Number of highest-weight conformers tagged ``[PRED]``; values <= 0 tag none.
    style : str, default "stick"
        py3Dmol style name applied to every viewer.
    n_cols : int, default 3
        Number of viewers per grid row.
    width, height : int, default 250
        Size of each viewer in pixels.
    show_all : bool, default False
        Show every conformer instead of only the key and top-N conformers.
    sort_by_weight : bool, default True
        Order viewers by descending weight (otherwise by position).

    Raises
    ------
    ValueError
        If the number of weights differs from the number of conformers.
    """
    conformers = list(mol.GetConformers())
    num_confs = len(conformers)
    if num_confs != len(weights):
        raise ValueError("Number of weights must equal number of conformers")

    # Top-N predicted positions. Reverse first, then slice: ``argsort(...)[-top_n:]``
    # would return *every* index when top_n == 0.
    ranked = np.argsort(weights)[::-1]
    top_indices = {int(i) for i in ranked[:max(top_n, 0)]}
    key_set = {int(i) for i in key_conformers}

    if show_all:
        conf_indices = list(range(num_confs))
    else:
        conf_indices = sorted(key_set | top_indices)

    # sort conformers by weight if requested
    if sort_by_weight:
        conf_indices = sorted(conf_indices, key=lambda i: weights[i], reverse=True)

    viewers_html = []
    for i in conf_indices:
        # Look the conformer up by position; its real ID may be 0- or 1-based.
        block = Chem.MolToMolBlock(mol, confId=conformers[i].GetId())

        is_key = i in key_set
        is_pred = i in top_indices

        # Colour priority: ground truth (red) > predicted only (blue) > other (grey).
        if is_key:
            color = "0xFF0000"
        elif is_pred:
            color = "0x0000FF"
        else:
            color = "0xAAAAAA"

        # Carry both tags so that correct detections (true AND predicted) are visible.
        label = f"Conf {i} (w={weights[i]:.2f})"
        if is_key:
            label += " [TRUE]"
        if is_pred:
            label += " [PRED]"

        viewer = py3Dmol.view(width=width, height=height)
        viewer.addModel(block, "sdf")
        viewer.setStyle({style: {"color": color}})
        viewer.zoomTo()

        # NOTE: _make_html is a private py3Dmol helper and may change between versions.
        html = viewer._make_html()
        viewers_html.append(f"<div style='display:inline-block; text-align:center;'>{html}<br>{label}</div>")

    # arrange into grid
    rows = []
    for start in range(0, len(viewers_html), n_cols):
        row_html = "".join(viewers_html[start:start + n_cols])
        rows.append(f"<div style='margin-bottom:20px'>{row_html}</div>")

    # add legend
    legend_html = """
    <div style='margin:10px 0;'>
      <b>Legend:</b>
      <span style='color:red;'>[TRUE]=Ground truth</span> |
      <span style='color:blue;'>[PRED]=Top predicted</span> |
      <span style='color:gray;'>Others</span>
      (a conformer tagged both [TRUE] and [PRED] is a correct detection)
    </div>
    """

    display(HTML(legend_html + "".join(rows)))

In [3]:
from math import comb

def kid_accuracy(true_key_indices, predicted_weights, top_n=1):
    """Fraction of bags whose ``top_n`` highest-weighted instances include a key instance.

    Parameters
    ----------
    true_key_indices : sequence of sequences of int
        For each bag, the 0-based positions of its ground-truth key instances.
    predicted_weights : sequence of sequences of float
        For each bag, one predicted weight per instance.
    top_n : int, default 1
        Number of highest-weight instances per bag that count as "detected".
        Ties are broken in favour of the lower instance position.

    Returns
    -------
    float
        Hit rate in [0, 1]; 0.0 when there are no bags. Bags without key
        instances can never be hits but still count in the denominator.

    Raises
    ------
    AssertionError
        If the two sequences have different lengths.
    ValueError
        If ``top_n`` is smaller than 1.
    """
    assert len(predicted_weights) == len(true_key_indices), "Mismatched input lengths."
    if top_n < 1:
        # A negative slice bound would silently select "all but the last k" instances.
        raise ValueError(f"top_n must be >= 1, got {top_n}")

    total = len(predicted_weights)
    if total == 0:
        return 0.0

    hits = 0
    for bag_weights, key_indices in zip(predicted_weights, true_key_indices):
        # sorted() stays stable with reverse=True, so equal weights keep ascending position order.
        ranked = sorted(range(len(bag_weights)), key=lambda i: bag_weights[i], reverse=True)
        top_set = set(ranked[:top_n])
        if any(idx in top_set for idx in key_indices):
            hits += 1

    return hits / total

def expected_kid_accuracy(true_key_indices, bag_sizes, top_n=1):
    """Expected KID accuracy of a random instance ranking (chance baseline for ``kid_accuracy``).

    For a bag of size B with K distinct, valid key instances, drawing
    N = min(top_n, B) instances uniformly at random misses every key instance
    with probability C(B-K, N) / C(B, N). The hit probability is therefore
    1 - C(B-K, N) / C(B, N), and 1.0 when B - K < N.

    Parameters
    ----------
    true_key_indices : sequence of sequences of int
        For each bag, the 0-based positions of its ground-truth key instances.
    bag_sizes : sequence of int
        Number of instances in each bag.
    top_n : int, default 1
        Number of instances picked per bag.

    Returns
    -------
    float
        Mean hit probability over all bags; 0.0 when there are no bags. Bags
        with no reachable key instance contribute 0 but still count.

    Raises
    ------
    AssertionError
        If the two sequences have different lengths.
    ValueError
        If ``top_n`` is smaller than 1.
    """
    assert len(true_key_indices) == len(bag_sizes), "Mismatched input lengths."
    if top_n < 1:
        raise ValueError(f"top_n must be >= 1, got {top_n}")

    n_bags = len(true_key_indices)  # len() instead of truthiness: works for arrays too
    if n_bags == 0:
        return 0.0

    expected_hits = 0.0
    for key_indices, B in zip(true_key_indices, bag_sizes):
        # Count distinct in-range keys only. kid_accuracy can never "hit" a duplicate
        # or an out-of-range position, so those must not inflate the baseline.
        K = len({int(k) for k in key_indices if 0 <= k < B})
        N = min(top_n, B)  # top_n can't exceed bag size

        if K == 0 or B == 0:
            continue  # no reachable key instance -> contributes 0

        if B - K < N:
            hit_prob = 1.0  # guaranteed to pick a key instance
        else:
            hit_prob = 1 - (comb(B - K, N) / comb(B, N))

        expected_hits += hit_prob

    return expected_hits / n_bags

### 1. Load data

In [4]:
# Train/test records, pickled together with their conformer molecules.
# NOTE: pickle.load can execute arbitrary code -- only load files from a trusted source.
with open("new_data/actives_dm_train_allconf.pkl", "rb") as train_fh:
    data_train = pickle.load(train_fh)

with open("new_data/actives_dm_test_allconf.pkl", "rb") as test_fh:
    data_test = pickle.load(test_fh)

In [5]:
# Records are indexed positionally: [1] molecule, [2] key-instance (conformer)
# indices, [3] property value.

# training split
mol_train = [record[1] for record in data_train]
key_train = [record[2] for record in data_train]
y_train = [record[3] for record in data_train]

# test split
mol_test = [record[1] for record in data_test]
key_test = [record[2] for record in data_test]
y_test = [record[3] for record in data_test]

### 2. Build model

In [6]:
# 3D pharmacophore descriptors (pmapper factory). Presumably the wrapper yields one
# vector per conformer, so each molecule becomes a bag of instances -- TODO confirm.
desc_calc = DescriptorWrapper(Pharmacophore3D(factory='pmapper'))

# 1. Calc descriptors (train split first, then test)
x_train, x_test = [desc_calc.transform(mols) for mols in (mol_train, mol_test)]

# 2. Scale descriptors: fit the min/max statistics on the training bags only,
#    then apply the same transform to both splits.
scaler = BagMinMaxScaler()
scaler.fit(x_train)
x_train_scaled, x_test_scaled = [scaler.transform(bags) for bags in (x_train, x_test)]

In [7]:
# Seed every RNG the training stack may draw from, so that hyper-parameter search
# and training give the same model under Restart & Run All.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

model = DynamicPoolingNetworkRegressor()

# Search the default grid on the training bags (the log reports the best value per
# hyper-parameter), then fit the model. Presumably hopt keeps the selected values
# on the model -- the printed architecture reflects them.
model.hopt(x_train_scaled, y_train, param_grid=DEFAULT_PARAM_GRID, verbose=True)
model.fit(x_train_scaled, y_train)

Optimizing hyperparameter: hidden_layer_sizes (3 options)
[1/28 |  3.6% |  1.0 min] Value: (2048, 1024, 512, 256, 128, 64), Epochs: 35, Loss: 0.0111
[2/28 |  7.1% |  0.2 min] Value: (256, 128, 64), Epochs: 21, Loss: 0.0115
[3/28 | 10.7% |  0.3 min] Value: (128,), Epochs: 42, Loss: 0.0352
Best hidden_layer_sizes = (2048, 1024, 512, 256, 128, 64), val_loss = 0.0111
Optimizing hyperparameter: activation (5 options)
[4/28 | 14.3% |  1.4 min] Value: relu, Epochs: 19, Loss: 0.0107
[5/28 | 17.9% |  2.4 min] Value: leakyrelu, Epochs: 31, Loss: 0.0088
[6/28 | 21.4% |  2.0 min] Value: gelu, Epochs: 25, Loss: 0.0141
[7/28 | 25.0% |  2.9 min] Value: elu, Epochs: 42, Loss: 0.0120
[8/28 | 28.6% |  2.8 min] Value: silu, Epochs: 41, Loss: 0.0110
Best activation = leakyrelu, val_loss = 0.0088
Optimizing hyperparameter: learning_rate (2 options)
[9/28 | 32.1% |  1.1 min] Value: 0.0001, Epochs: 35, Loss: 0.0136
[10/28 | 35.7% |  1.0 min] Value: 0.001, Epochs: 29, Loss: 0.0112
Best learning_rate = 0.001, 

DynamicPoolingNetworkRegressor(
  (instance_transformer): Sequential(
    (0): Linear(in_features=2048, out_features=2048, bias=True)
    (1): LeakyReLU(negative_slope=0.01)
    (2): Linear(in_features=2048, out_features=1024, bias=True)
    (3): LeakyReLU(negative_slope=0.01)
    (4): Linear(in_features=1024, out_features=512, bias=True)
    (5): LeakyReLU(negative_slope=0.01)
    (6): Linear(in_features=512, out_features=256, bias=True)
    (7): LeakyReLU(negative_slope=0.01)
    (8): Linear(in_features=256, out_features=128, bias=True)
    (9): LeakyReLU(negative_slope=0.01)
    (10): Linear(in_features=128, out_features=64, bias=True)
    (11): LeakyReLU(negative_slope=0.01)
  )
  (bag_estimator): Norm()
  (dynamic_pooling): DynamicPooling()
)

In [8]:
# Bag-level property predictions for the test set.
y_pred = model.predict(x_test_scaled)

# Per-instance (conformer) weights, flattened to 1-D per bag so they can be ranked
# against the key-conformer indices.
w_pred = [bag_w.flatten() for bag_w in model.get_instance_weights(x_test_scaled)]

In [9]:
top_n = 1
ACTIVE_VALUE = 7  # property value that marks the "active" molecules in this dataset


def print_kid_report(key_indices, bag_weights, top_n):
    """Print observed KID accuracy next to the random-ranking baseline for the given bags."""
    bag_sizes = [len(w) for w in bag_weights]
    print(f"KID prediction accuracy: {kid_accuracy(key_indices, bag_weights, top_n=top_n):.2f}")
    print(f"KID baseline accuracy: {expected_kid_accuracy(key_indices, bag_sizes, top_n=top_n):.2f}")


# r2_score is the coefficient of determination, not a classification accuracy.
print(f"All molecules: {len(y_test)}")
print(f"Prediction R2: {r2_score(y_test, y_pred):.2f}")
print_kid_report(key_test, w_pred, top_n)

# Restrict the KID evaluation to active molecules only.
idx_7 = [i for i, y in enumerate(y_test) if y == ACTIVE_VALUE]
key_test_7 = [key_test[i] for i in idx_7]
w_pred_7 = [w_pred[i] for i in idx_7]

print(f"\nActive molecules: {len(idx_7)}")
print_kid_report(key_test_7, w_pred_7, top_n)

All molecules: 356
Prediction accuracy: 0.88
KID prediction accuracy: 0.23
KID baseline accuracy: 0.11

Active molecules: 66
KID prediction accuracy: 0.42
KID baseline accuracy: 0.16


In [10]:
# First 15 test-set positions of active molecules -- candidates for the conformer viewer.
idx_7[0:15]

[3, 9, 15, 16, 29, 39, 45, 47, 48, 52, 55, 56, 63, 77, 83]

In [15]:
# Test-set position of the molecule to inspect (45 is one of the active molecules).
# NOTE: uses w_pred from the single-model cells; re-running cells out of order can change it.
N = 45

visualize_conformers_grid(
    mol_test[N],
    w_pred[N],
    key_test[N],
    top_n=5,
    style="stick",
    n_cols=4,
    width=250,
    height=250,
    show_all=False,
    sort_by_weight=True,
)

### 3. Mini-benchmark

In [12]:
# (name, descriptor calculator) pairs to benchmark.
desc_list = [
    # 3D pharmacophore descriptors; can be long to compute
    ("MolFeatPmapper", DescriptorWrapper(Pharmacophore3D(factory='pmapper'))),
]

In [13]:
# (name, estimator) pairs; the estimator instances are tuned and fitted in place.
regressor_list = [
    # attention-based MIL networks
    ("AdditiveAttentionNetworkRegressor", AdditiveAttentionNetworkRegressor()),
    ("SelfAttentionNetworkRegressor", SelfAttentionNetworkRegressor()),
    ("HopfieldAttentionNetworkRegressor", HopfieldAttentionNetworkRegressor()),
    # dynamic-pooling MIL network
    ("DynamicPoolingNetworkRegressor", DynamicPoolingNetworkRegressor()),
]

In [16]:
# Benchmark every (descriptor, MIL regressor) pair on the same train/test split.
# Loop-local names carry a `bench_` prefix so this cell does not overwrite the
# x_train / model / w_pred / top_n state that the single-model cells above use.
SEED = 42  # reseeded before each model so scores do not depend on benchmark order

n = 0
total_n = len(desc_list) * len(regressor_list)
bench_rows, bench_labels = [], []
status_len = 0  # length of the previous progress line (needed to clear it on "\r")

for desc_name, bench_desc_calc in desc_list:

    # 1. Calc descriptors
    bench_x_train = bench_desc_calc.transform(mol_train)
    bench_x_test = bench_desc_calc.transform(mol_test)

    # 2. Scale descriptors (statistics fitted on the training bags only)
    bench_scaler = BagMinMaxScaler()
    bench_scaler.fit(bench_x_train)
    bench_x_train_scaled = bench_scaler.transform(bench_x_train)
    bench_x_test_scaled = bench_scaler.transform(bench_x_test)

    for method_name, bench_model in regressor_list:
        random.seed(SEED)
        np.random.seed(SEED)
        torch.manual_seed(SEED)

        # 3. Train model
        bench_model.hopt(bench_x_train_scaled, y_train, param_grid=DEFAULT_PARAM_GRID, verbose=False)
        bench_model.fit(bench_x_train_scaled, y_train)

        # 4. Get predictions (r2_score is R^2, not a classification accuracy)
        bench_y_pred = bench_model.predict(bench_x_test_scaled)
        row = {"R2": r2_score(y_test, bench_y_pred)}

        # 5. Key instance detection
        bench_w_pred = [w.flatten() for w in bench_model.get_instance_weights(bench_x_test_scaled)]
        for k in (1, 2, 3):
            row[f"TOP-{k}"] = kid_accuracy(key_test, bench_w_pred, top_n=k)

        bench_labels.append(f"{desc_name}|{method_name}")
        bench_rows.append(row)

        # logging: pad to the previous line's length so "\r" leaves no stale characters
        n += 1
        status = f"{n}/{total_n} {desc_name}|{method_name}"
        print(status.ljust(status_len), end="\r")
        status_len = len(status)

print()  # terminate the progress line

# Build the results table once, rather than growing it cell-by-cell with .loc.
res_df = pd.DataFrame(bench_rows, index=bench_labels)

4/4 MolFeatPmapper|DynamicPoolingNetworkRegressorsor

In [17]:
# Benchmark table rounded for display (the underlying res_df keeps full precision).
res_df.round(decimals=2)

Unnamed: 0,ACC,TOP-1,TOP-2,TOP-3
MolFeatPmapper|AdditiveAttentionNetworkRegressor,0.91,0.15,0.24,0.31
MolFeatPmapper|SelfAttentionNetworkRegressor,0.91,0.17,0.29,0.38
MolFeatPmapper|HopfieldAttentionNetworkRegressor,0.9,0.17,0.26,0.32
MolFeatPmapper|DynamicPoolingNetworkRegressor,0.9,0.21,0.33,0.42
