### Key Instance Detection for Conformers

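Key instance detection (KID) benchmark for multi-conformer molecules: each molecule is a bag of conformers, and we check whether the instance weights of trained MIL models rank the known key conformers at the top.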

In [1]:
import pickle

import numpy as np
import pandas as pd

from sklearn.metrics import r2_score, balanced_accuracy_score
from sklearn.ensemble import RandomForestRegressor, RandomForestClassifier

In [2]:
from qsarmil.descriptor.rdkit import (RDKitGEOM, 
                                      RDKitAUTOCORR, 
                                      RDKitRDF, 
                                      RDKitMORSE, 
                                      RDKitWHIM, 
                                      RDKitGETAWAY)

from molfeat.calc import (Pharmacophore3D, 
                          USRDescriptors, 
                          ElectroShapeDescriptors)

from qsarmil.descriptor.wrapper import DescriptorWrapper

ModuleNotFoundError: No module named 'molfeat'

In [None]:
import logging
import warnings
warnings.filterwarnings("ignore")
logging.getLogger("pytorch_lightning").setLevel(logging.ERROR)
logging.getLogger("lightning").setLevel(logging.ERROR)

import time
import torch
import random

import numpy as np
import pandas as pd

import matplotlib.pyplot as plt

# MNIST dataset creation
from milearn.data.mnist import load_mnist, create_bags_or, create_bags_and, create_bags_xor, create_bags_reg

# Preprocessing
from milearn.preprocessing import BagMinMaxScaler

# Network hparams
from milearn.network.module.hopt import DEFAULT_PARAM_GRID

# MIL wrappers
from milearn.network.regressor import BagWrapperMLPNetworkRegressor, InstanceWrapperMLPNetworkRegressor
from milearn.network.classifier import BagWrapperMLPNetworkClassifier, InstanceWrapperMLPNetworkClassifier

# MIL networks
from milearn.network.regressor import (InstanceNetworkRegressor,
                                       BagNetworkRegressor,
                                       AdditiveAttentionNetworkRegressor,
                                       SelfAttentionNetworkRegressor,
                                       HopfieldAttentionNetworkRegressor,
                                       DynamicPoolingNetworkRegressor)

from milearn.network.classifier import (InstanceNetworkClassifier,
                                        BagNetworkClassifier,
                                        AdditiveAttentionNetworkClassifier,
                                        SelfAttentionNetworkClassifier,
                                        HopfieldAttentionNetworkClassifier,
                                        DynamicPoolingNetworkClassifier)

# Utils
from sklearn.metrics import r2_score, accuracy_score
from sklearn.model_selection import train_test_split

In [None]:
from typing import List, Tuple

def kid_accuracy(true_key_indices, predicted_weights, top_n=1):
    """Fraction of bags whose top-N weighted instances contain a true key instance.

    Parameters
    ----------
    true_key_indices : sequence of sequences of int
        For every bag, the positions of its key instances.
    predicted_weights : sequence of array-like
        For every bag, one weight per instance. Column vectors such as
        ``(B, 1)`` arrays are flattened before ranking.
    top_n : int, default 1
        Number of highest-weighted instances counted as "detected".

    Returns
    -------
    float
        Hit rate in [0, 1]; 0.0 when there are no bags.

    Raises
    ------
    ValueError
        If the two inputs contain a different number of bags.
    """
    if len(predicted_weights) != len(true_key_indices):
        raise ValueError("Mismatched input lengths.")

    total = len(predicted_weights)
    if total == 0:
        return 0.0

    hits = 0
    for bag_weights, key_indices in zip(predicted_weights, true_key_indices):
        weights = np.ravel(np.asarray(bag_weights, dtype=np.float64))
        # Stable sort on negated weights: highest first, ties keep their
        # original order (same ordering as sorted(..., reverse=True)).
        top_n_indices = set(np.argsort(-weights, kind="stable")[:top_n].tolist())
        # A bag is a hit if any true key instance is among the top-N.
        if any(idx in top_n_indices for idx in key_indices):
            hits += 1

    return hits / total

def normalized_entropy(weights, epsilon=1e-12):
    """Shannon entropy of a bag's instance weights, scaled to [0, 1] by log(n).

    Parameters
    ----------
    weights : array-like
        Non-negative instance weights of one bag. They are flattened and then
        renormalised to sum to one.
    epsilon : float, default 1e-12
        Small constant guarding the normalisation and ``log(0)``.

    Returns
    -------
    float
        Normalised entropy in [0, 1], where 0 = sharp (all mass on one
        instance) and 1 = flat (uniform). Bags with fewer than two instances
        return 0.0, because log(1) = 0 leaves the ratio undefined.
    """
    weights = np.ravel(np.asarray(weights, dtype=np.float64))
    n = weights.size
    if n < 2:
        return 0.0

    weights = weights / (weights.sum() + epsilon)  # normalize
    entropy = -np.sum(weights * np.log(weights + epsilon))
    # Use log(n) exactly: adding epsilon inside the log would turn the
    # denominator into ~1e-12 for tiny bags and blow the ratio up.
    max_entropy = np.log(n)

    # epsilon can push the ratio a hair outside [0, 1]; clamp to the contract.
    return float(np.clip(entropy / max_entropy, 0.0, 1.0))

In [None]:
from math import comb

def expected_kid_accuracy(true_key_indices, bag_sizes, top_n=1):
    """Chance-level KID accuracy of a uniformly random instance ranking.

    For a bag of size B with K distinct key positions, a random choice of
    N = min(top_n, B) instances hits at least one key with probability
    ``1 - C(B - K, N) / C(B, N)``. The result is the mean over all bags, using
    the same hit rule as ``kid_accuracy``.

    Parameters
    ----------
    true_key_indices : sequence of sequences of int
        For every bag, the positions of its key instances. Duplicates count
        once; positions outside ``[0, B)`` can never be ranked and are ignored.
    bag_sizes : sequence of int
        Number of instances B in each bag.
    top_n : int, default 1
        Number of top-ranked instances counted as "detected".

    Returns
    -------
    float
        Expected hit rate in [0, 1]; 0.0 when there are no bags.

    Raises
    ------
    ValueError
        If the two inputs contain a different number of bags.
    """
    if len(true_key_indices) != len(bag_sizes):
        raise ValueError("Mismatched input lengths.")
    if len(true_key_indices) == 0:
        return 0.0

    expected_hits = 0.0
    for key_indices, B in zip(true_key_indices, bag_sizes):
        # Only distinct positions that exist in the bag can ever be hit.
        K = len({idx for idx in key_indices if 0 <= idx < B})
        N = max(0, min(top_n, B))  # top_n can't exceed bag size
        if K == 0 or N == 0:
            continue  # no hit possible; the bag still counts in the mean

        # math.comb returns 0 when N > B - K, giving probability 1 (guaranteed hit).
        expected_hits += 1 - comb(B - K, N) / comb(B, N)

    return expected_hits / len(true_key_indices)

### 1. Load data
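### TODO: fix conformer indexing (the key-conformer indices parsed from conformer names may not match the conformers' positions within the bags)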

In [None]:
def _load_pickle(path):
    """Unpickle and return the object stored at *path* (trusted local files only)."""
    with open(path, "rb") as handle:
        return pickle.load(handle)


data_train = _load_pickle("kid_data/actives_dm_train_allconf.pkl")
data_test = _load_pickle("kid_data/actives_dm_test_allconf.pkl")

In [None]:
# Records are accessed positionally: [1] = molecule, [2] = key conformer
# names, [3] = property value (converted with .item()).
mol_train = [record[1] for record in data_train]
mol_test = [record[1] for record in data_test]

prop_train = [record[3].item() for record in data_train]
prop_test = [record[3].item() for record in data_test]

# Key-instance indices parsed from the numeric suffix after the last "_" in
# each key conformer name.
# NOTE(review): flagged as wrong indexing — the suffix may not equal the
# conformer's position inside the bag; confirm before trusting KID scores.
kid_test = [
    [int(conf_name.rsplit("_", 1)[-1]) for conf_name in record[2]]
    for record in data_test
]

### 2. Benchmark configuration

In [None]:
# (display name, calculator) pairs; every calculator is wrapped in
# DescriptorWrapper before it is used in the benchmark.
desc_list = [
    (name, DescriptorWrapper(calculator))
    for name, calculator in (
        ("RDKitGEOM", RDKitGEOM()),
        ("RDKitAUTOCORR", RDKitAUTOCORR()),
        ("RDKitRDF", RDKitRDF()),
        ("RDKitMORSE", RDKitMORSE()),
        ("RDKitWHIM", RDKitWHIM()),
        ("RDKitGETAWAY", RDKitGETAWAY()),  # can be long
        ("MolFeatUSRD", USRDescriptors()),
        ("MolFeatElectroShape", ElectroShapeDescriptors()),
        ("MolFeatPmapper", Pharmacophore3D(factory='pmapper')),  # can be long
    )
]

In [None]:
# (display name, estimator) pairs; each estimator is built with its default
# constructor arguments.
regressor_list = [
    (name, factory())
    for name, factory in (
        # attention mil networks
        ("AdditiveAttentionNetworkRegressor", AdditiveAttentionNetworkRegressor),
        ("SelfAttentionNetworkRegressor", SelfAttentionNetworkRegressor),
        ("HopfieldAttentionNetworkRegressor", HopfieldAttentionNetworkRegressor),
        # other mil networks
        ("DynamicPoolingNetworkRegressor", DynamicPoolingNetworkRegressor),
    )
]

classifier_list = [
    (name, factory())
    for name, factory in (
        # attention mil networks
        ("AdditiveAttentionNetworkClassifier", AdditiveAttentionNetworkClassifier),
        ("SelfAttentionNetworkClassifier", SelfAttentionNetworkClassifier),
        ("HopfieldAttentionNetworkClassifier", HopfieldAttentionNetworkClassifier),
        # other mil networks
        ("DynamicPoolingNetworkClassifier", DynamicPoolingNetworkClassifier),
    )
]

### 3. Benchmark MIL models (regression or classification, selected by `task`)

In [None]:
# "reg": regress the raw property; "clf": classify property >= threshold as 1.
task = "reg"

if task == "reg":
    estimator_list = regressor_list
    accuracy_metric = r2_score
    y_train = prop_train
    y_test = prop_test
elif task == "clf":
    threshold = 7  # property value at/above which the label is 1
    estimator_list = classifier_list
    accuracy_metric = balanced_accuracy_score
    y_train = [1 if i >= threshold else 0 for i in prop_train]
    y_test = [1 if i >= threshold else 0 for i in prop_test]
else:
    raise ValueError(f"Unknown task {task!r}; expected 'reg' or 'clf'.")

In [None]:
n = 0
# Count what the inner loop actually iterates (estimator_list, not regressor_list).
total_n = len(desc_list) * len(estimator_list)

res_df = pd.DataFrame()
for desc_name, desc_calc in desc_list:

    # 1. Calc descriptors
    x_train = desc_calc.transform(mol_train)
    x_test = desc_calc.transform(mol_test)

    # 2. Scale descriptors (scaler fitted on the training bags only)
    scaler = BagMinMaxScaler()
    scaler.fit(x_train)
    x_train_scaled = scaler.transform(x_train)
    x_test_scaled = scaler.transform(x_test)

    for method_name, model in estimator_list:
        row = f"{desc_name}|{method_name}"

        # 3. Train model (the same estimator instance is refitted for each descriptor set)
        model.fit(x_train_scaled, y_train)

        # 4. Get predictions with a single forward pass; for classification the
        #    predict() output is thresholded at 0.5.
        y_out = model.predict(x_test_scaled)
        if task == "clf":
            y_prob = y_out
            y_pred = np.where(y_prob > 0.5, 1, 0)
        else:
            y_pred = y_out

        res_df.loc[row, "ACC"] = accuracy_metric(y_test, y_pred)

        # 5. Key instance detection: flatten each bag's weights to 1-D
        w_pred = [w.flatten() for w in model.get_instance_weights(x_test_scaled)]

        # calc kid accuracy
        for top_n in [1, 2, 3]:
            res_df.loc[row, f"TOP-{top_n}"] = kid_accuracy(kid_test, w_pred, top_n=top_n)

        # calc mean normalized weight entropy (0 = sharp, 1 = flat)
        res_df.loc[row, "ENT"] = np.mean([normalized_entropy(w) for w in w_pred]).item()

        # logging
        n += 1
        print(f"{n}/{total_n} {row}", end="\r")

print()  # terminate the carriage-return progress line

# save results
res_df.to_csv("kid_results.csv")

In [None]:
# Chance-level KID accuracy for random instance rankings.
# NOTE(review): x_test is left over from the last benchmark iteration; this
# presumes bag sizes are identical across descriptor sets — confirm.
bag_sizes = list(map(len, x_test))
for top_n in range(1, 4):
    baseline = expected_kid_accuracy(kid_test, bag_sizes, top_n=top_n)
    print(f"Expected KID accuracy Top-{top_n} = {baseline:.2f}")

In [None]:
# Rank descriptor|model combinations by top-1 key-instance detection accuracy.
res_df.sort_values("TOP-1", ascending=False)

### 4. Best model analysis

In [None]:
# Descriptor set for the detailed analysis.
# Alternative: desc_calc = DescriptorWrapper(RDKitGEOM())
desc_calc = DescriptorWrapper(Pharmacophore3D(factory='pmapper'))

# 1. Calc descriptors for both splits
x_train, x_test = (desc_calc.transform(mols) for mols in (mol_train, mol_test))

# 2. Scale descriptors (scaler fitted on the training bags only)
scaler = BagMinMaxScaler()
scaler.fit(x_train)
x_train_scaled, x_test_scaled = (scaler.transform(bags) for bags in (x_train, x_test))

In [None]:
# "clf": label property >= threshold as 1; "reg": use the raw property.
task = "clf"

if task == "clf":
    threshold = 7
    accuracy_metric = balanced_accuracy_score
    y_train = [int(value >= threshold) for value in prop_train]
    y_test = [int(value >= threshold) for value in prop_test]
elif task == "reg":
    accuracy_metric = r2_score
    y_train = prop_train
    y_test = prop_test

In [None]:
# Constructor keyword arguments for the single-model analysis below.
# init_cuda=False presumably keeps training on the CPU — confirm in milearn.
network_hparams = dict(
    hidden_layer_sizes=(256, 128, 64),
    num_epoch=300,
    batch_size=128,
    learning_rate=0.001,
    weight_decay=0.001,
    instance_weight_dropout=0.01,
    init_cuda=False,
    verbose=False,
)

In [None]:
# Pick the estimator class to match `task`, e.g. DynamicPoolingNetworkRegressor
# or AdditiveAttentionNetworkRegressor when task == "reg".
model_cls = DynamicPoolingNetworkClassifier
model = model_cls(**network_hparams)
model.fit(x_train_scaled, y_train)

In [None]:
# Single forward pass; for classification the predict() output is
# thresholded at 0.5 to obtain hard labels.
y_out = model.predict(x_test_scaled)
if task == "clf":
    y_prob = y_out
    y_pred = np.where(y_prob > 0.5, 1, 0)
else:
    y_pred = y_out

# Instance weights for each test bag, used for key instance detection below.
w_pred = model.get_instance_weights(x_test_scaled)

In [None]:
top_n = 3

# Whole test set: prediction score, KID accuracy, random-ranking baseline.
print(accuracy_metric(y_test, y_pred))
print(kid_accuracy(kid_test, w_pred, top_n=top_n))
print(expected_kid_accuracy(kid_test, list(map(len, w_pred)), top_n=top_n))

# Same KID metrics restricted to test bags labelled 1.
positives = [
    (keys, weights)
    for label, keys, weights in zip(y_test, kid_test, w_pred)
    if label == 1
]
kid_test_7 = [keys for keys, _ in positives]
w_pred_7 = [weights for _, weights in positives]
print()
print(kid_accuracy(kid_test_7, w_pred_7, top_n=top_n))
print(expected_kid_accuracy(kid_test_7, list(map(len, w_pred_7)), top_n=top_n))

In [None]:
# Indices of test bags labelled 1; the first few are candidates for N below.
# Labels are 0/1 when task == "clf", so compare with 1, not the threshold 7.
tmp = [idx for idx, y in enumerate(y_test) if y == 1]
tmp[:10]

In [None]:
# Position of the test bag inspected in the following cells.
N = 15

In [None]:
y_test[N]  # label of the inspected test bag

In [None]:
kid_test[N]  # key-instance indices parsed for the inspected bag

In [None]:
w_pred[N].round(2)  # predicted instance weights of the inspected bag, rounded to 2 decimals