In [3]:
from dataclasses import dataclass, field
import math
from pathlib import Path
import time
from typing import Callable, Union, Iterable

from lightning import pytorch as pl
import numpy as np
import pandas as pd
from rdkit import Chem
import torch

from chemprop import data, featurizers, models
from chemprop.models import MPNN

In [4]:
def make_prediction(
    models: list[MPNN],
    trainer: pl.Trainer,
    smiles: list[str],
) -> np.ndarray:
    """Predicts with every model on a list of SMILES and averages the results.

    Parameters
    ----------
    models : list of MPNN
        The models whose predictions are averaged (the ensemble).
    trainer : pl.Trainer
        The Lightning trainer used to run ``predict`` for each model.
    smiles : list of str
        The SMILES to make predictions on.

    Returns
    -------
    np.ndarray
        The element-wise mean, over all models, of the concatenated
        per-batch predictions returned by ``trainer.predict``.
    """

    datapoints = [data.MoleculeDatapoint.from_smi(smi) for smi in smiles]
    loader = data.build_dataloader(
        data.MoleculeDataset(datapoints), batch_size=1, num_workers=0, shuffle=False
    )

    with torch.inference_mode():
        # One array per model: the per-batch outputs are stacked along the first axis.
        per_model_preds = [
            torch.cat(trainer.predict(model, loader), 0).cpu().numpy() for model in models
        ]

        # Ensemble predictions: plain arithmetic mean across the models.
        total_preds = sum(per_model_preds)
        avg_preds = total_preds / len(models)

    return avg_preds

In [5]:
@dataclass
class MCTSNode:
    """A single state (candidate substructure) in the Monte Carlo Tree Search.

    Parameters
    ----------
    smiles : str
        The SMILES for the substructure at this node.
    atoms : iterable of int
        The atom indices in the substructure at this node; stored as a set.
    W : float
        The total action value, which indicates how likely the deletion will lead to a good rationale.
    N : int
        The visit count, i.e. how many times this node has been visited. It is used to balance
        exploration and exploitation.
    P : float
        The predicted property score of the subgraph obtained after the deletion, shown as R in
        the original paper.
    children : list of MCTSNode
        The nodes reachable from this node; empty until children are appended.
    """

    smiles: str
    atoms: Iterable[int]
    W: float = 0
    N: int = 0
    P: float = 0
    children: list["MCTSNode"] = field(default_factory=list)

    def __post_init__(self):
        # Normalise whatever iterable was passed in to a set of atom indices.
        self.atoms = set(self.atoms)

    def Q(self) -> float:
        """
        Returns
        -------
        float
            The mean action value of the node (``W / N``), or 0 for an unvisited node.
        """
        if self.N > 0:
            return self.W / self.N
        return 0

    def U(self, n: int, c_puct: float = 10.0) -> float:
        """
        Parameters
        ----------
        n : int
            The sum of the visit count of this node's siblings.
        c_puct : float
            A constant that controls the level of exploration.

        Returns
        -------
        float
            The exploration value of the node.
        """
        exploration = c_puct * self.P * math.sqrt(n)
        return exploration / (1 + self.N)

In [6]:
def find_clusters(mol: Chem.Mol) -> tuple[list[tuple[int, ...]], list[list[int]]]:
    """Finds clusters within the molecule.

    Jin et al. [1]_ only allow deletion of one peripheral non-aromatic bond or one peripheral
    ring from each state, so the clusters here are defined as the non-ring bonds plus the
    smallest set of smallest rings.

    Parameters
    ----------
    mol : RDKit molecule
        The molecule to find clusters in.

    Returns
    -------
    tuple
        A tuple containing:
        - list of tuples: Each tuple contains the atom indices of one cluster
          (non-ring bonds first, then rings).
        - list of lists of int: For each atom, the indices of the clusters it belongs to.

    References
    ----------
    .. [1] Jin, Wengong, Regina Barzilay, and Tommi Jaakkola. "Multi-objective molecule generation using interpretable substructures." International conference on machine learning. PMLR, 2020. https://arxiv.org/abs/2002.03244
    """

    n_atoms = mol.GetNumAtoms()
    if n_atoms == 1:  # special case: a lone atom forms its own cluster
        return [(0,)], [[0]]

    # Every bond outside a ring is a two-atom cluster.
    clusters = [
        (bond.GetBeginAtom().GetIdx(), bond.GetEndAtom().GetIdx())
        for bond in mol.GetBonds()
        if not bond.IsInRing()
    ]
    # Every ring reported by GetSymmSSSR is a cluster of its own.
    clusters += [tuple(ring) for ring in Chem.GetSymmSSSR(mol)]

    # Reverse lookup: atom index -> indices of the clusters containing it.
    atom_cls = [[] for _ in range(n_atoms)]
    for cluster_idx, members in enumerate(clusters):
        for atom_idx in members:
            atom_cls[atom_idx].append(cluster_idx)

    return clusters, atom_cls

In [7]:
def extract_subgraph_from_mol(mol: Chem.Mol, selected_atoms: set[int]) -> tuple[Chem.Mol, list[int]]:
    """Extracts a subgraph from an RDKit molecule given a set of atom indices.

    Parameters
    ----------
    mol : RDKit molecule
        The molecule from which to extract a subgraph.
    selected_atoms : set of int
        The indices of atoms which form the subgraph to be extracted.

    Returns
    -------
    tuple
        A tuple containing:
        - RDKit molecule: The subgraph.
        - list of int: Root atom indices, i.e. selected atoms that have at least one
          neighbor outside the selection.
    """

    selected_atoms = set(selected_atoms)

    # A root is a kept atom that is bonded to at least one atom being discarded.
    roots = [
        idx
        for idx in selected_atoms
        if any(
            nbr.GetIdx() not in selected_atoms
            for nbr in mol.GetAtomWithIdx(idx).GetNeighbors()
        )
    ]

    editable = Chem.RWMol(mol)

    for root_idx in roots:
        root_atom = editable.GetAtomWithIdx(root_idx)
        root_atom.SetAtomMapNum(1)  # tag the attachment point with atom map number 1

        # Does the root keep any aromatic bond whose two ends are both selected?
        keeps_aromatic_bond = any(
            bond.GetBondType() == Chem.rdchem.BondType.AROMATIC
            and bond.GetBeginAtom().GetIdx() in selected_atoms
            and bond.GetEndAtom().GetIdx() in selected_atoms
            for bond in root_atom.GetBonds()
        )
        if not keeps_aromatic_bond:
            root_atom.SetIsAromatic(False)

    discarded = [
        atom.GetIdx() for atom in editable.GetAtoms() if atom.GetIdx() not in selected_atoms
    ]
    # Highest index first -- presumably so pending lower indices are not shifted by
    # earlier removals.
    for idx in sorted(discarded, reverse=True):
        editable.RemoveAtom(idx)

    return editable.GetMol(), roots

In [13]:
def extract_subgraph(smiles: str, selected_atoms: set[int]) -> tuple[str, list[int]]:
    """Extracts a subgraph from a SMILES given a set of atom indices.

    The extraction is attempted on the kekulized molecule first; if the result cannot be
    round-tripped through SMILES or is not a substructure of the original molecule, the
    extraction is repeated on the non-kekulized molecule.

    Parameters
    ----------
    smiles : str
        The SMILES string from which to extract a subgraph.
    selected_atoms : set of int
        The indices of atoms which form the subgraph to be extracted.

    Returns
    -------
    tuple
        A tuple containing:
        - str: SMILES representing the subgraph.
        - list of int: Root atom indices from the selected indices.

        ``(None, None)`` is returned when the fallback subgraph cannot be parsed back
        from its SMILES.
    """
    # First attempt: work on the kekulized molecule.
    kekulized = Chem.MolFromSmiles(smiles)
    Chem.Kekulize(kekulized)
    fragment, roots = extract_subgraph_from_mol(kekulized, selected_atoms)
    try:
        fragment = Chem.MolFromSmiles(Chem.MolToSmiles(fragment, kekuleSmiles=True))
    except Exception:
        fragment = None

    parent = Chem.MolFromSmiles(smiles)  # fresh, de-kekulized copy of the input
    if fragment is not None and parent.HasSubstructMatch(fragment):
        return Chem.MolToSmiles(fragment), roots

    # Second attempt: extract without kekulization.
    fragment, roots = extract_subgraph_from_mol(parent, selected_atoms)
    fragment = Chem.MolFromSmiles(Chem.MolToSmiles(fragment))

    if fragment is None:
        return None, None
    return Chem.MolToSmiles(fragment), roots

In [15]:
def mcts_rollout(
    node: MCTSNode,
    state_map: dict[str, MCTSNode],
    orig_smiles: str,
    clusters: list[set[int]],
    atom_cls: list[set[int]],
    nei_cls: list[set[int]],
    scoring_function: Callable[[list[str]], list[float]],
    min_atoms: int = 15,
    c_puct: float = 10.0,
) -> float:
    """Performs one Monte Carlo Tree Search rollout starting at ``node``.

    Parameters
    ----------
    node : MCTSNode
        The MCTSNode from which to begin the rollout.
    state_map : dict
        A mapping from SMILES to MCTSNode; expanded nodes are recorded here.
    orig_smiles : str
        The original SMILES of the molecule.
    clusters : list
        For each cluster, the set of atom indices it contains.
    atom_cls : list
        For each atom, the set of indices of the clusters it belongs to.
    nei_cls : list
        For each cluster, the set of neighboring cluster indices.
    scoring_function : function
        A function for scoring subgraph SMILES using a Chemprop model.
    min_atoms : int
        The minimum number of atoms in a subgraph.
    c_puct : float
        The constant controlling the level of exploration.

    Returns
    -------
    float
        The score of this MCTS rollout.
    """
    atoms_here = node.atoms

    # Stop shrinking once the subgraph is at (or below) the minimum size.
    if len(atoms_here) <= min_atoms:
        return node.P

    # Expand a node that has no children yet.
    if not node.children:
        # Indices of the clusters whose atoms are all still present in this subgraph.
        intact = [idx for idx, members in enumerate(clusters) if members <= atoms_here]
        live_clusters = set(intact)

        for idx in live_clusters:
            # Leaf atoms belong to exactly one of the clusters still present.
            peripheral = [a for a in clusters[idx] if len(atom_cls[a] & live_clusters) == 1]

            # A cluster may be deleted when either
            # 1. it has exactly one neighbor cluster in the current subgraph (so that we
            #    don't produce unconnected graphs), or
            # 2. it is a two-atom cluster (a bond) with exactly one leaf atom.
            single_neighbor = len(nei_cls[idx] & live_clusters) == 1
            dangling_bond = len(clusters[idx]) == 2 and len(peripheral) == 1
            if not (single_neighbor or dangling_bond):
                continue

            remaining = atoms_here - set(peripheral)
            child_smiles, _ = extract_subgraph(orig_smiles, remaining)
            if not child_smiles:
                continue  # extraction produced no SMILES; nothing to add

            child = state_map.get(child_smiles)  # merge identical states
            if child is None:
                child = MCTSNode(child_smiles, remaining)
            node.children.append(child)

        state_map[node.smiles] = node
        if not node.children:
            return node.P  # cannot find leaves

        child_scores = scoring_function([child.smiles for child in node.children])
        for child, child_score in zip(node.children, child_scores):
            child.P = child_score

    sibling_visits = sum(child.N for child in node.children)

    def selection_value(candidate: MCTSNode) -> float:
        # Mean action value plus exploration bonus.
        return candidate.Q() + candidate.U(sibling_visits, c_puct=c_puct)

    best_child = max(node.children, key=selection_value)
    value = mcts_rollout(
        best_child,
        state_map,
        orig_smiles,
        clusters,
        atom_cls,
        nei_cls,
        scoring_function,
        min_atoms=min_atoms,
        c_puct=c_puct,
    )

    # Back up the rollout result into the child that was followed.
    best_child.W += value
    best_child.N += 1

    return value

In [17]:
def mcts(
    smiles: str,
    scoring_function: Callable[[list[str]], list[float]],
    n_rollout: int,
    max_atoms: int,
    prop_delta: float,
    min_atoms: int = 15,
    c_puct: float = 10.0,
) -> list[MCTSNode]:
    """Runs the Monte Carlo Tree Search algorithm.

    Parameters
    ----------
    smiles : str
        The SMILES of the molecule to perform the search on.
    scoring_function : function
        A function for scoring subgraph SMILES using a Chemprop model.
    n_rollout : int
        The number of MCTS rollouts to perform.
    max_atoms : int
        The maximum number of atoms allowed in an extracted rationale.
    prop_delta : float
        The minimum required property value for a satisfactory rationale.
    min_atoms : int
        The minimum number of atoms in a subgraph.
    c_puct : float
        The constant controlling the level of exploration.

    Returns
    -------
    list
        A list of rationales each represented by a MCTSNode.

    Raises
    ------
    ValueError
        If ``smiles`` cannot be parsed into a molecule.
    """

    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        # Fail here with a clear message instead of an AttributeError further down.
        raise ValueError(f"Could not parse SMILES: {smiles!r}")

    cluster_tuples, atom_cls = find_clusters(mol)

    # For each cluster: every other cluster that shares at least one atom with it.
    nei_cls = [
        {nei for atom in members for nei in atom_cls[atom]} - {i}
        for i, members in enumerate(cluster_tuples)
    ]
    # The rollout uses set operations, so convert both lookups to sets.
    clusters = [set(members) for members in cluster_tuples]
    atom_cls = [set(owners) for owners in atom_cls]

    root = MCTSNode(smiles, set(range(mol.GetNumAtoms())))
    state_map = {smiles: root}
    for _ in range(n_rollout):
        mcts_rollout(
            root,
            state_map,
            smiles,
            clusters,
            atom_cls,
            nei_cls,
            scoring_function,
            min_atoms=min_atoms,
            c_puct=c_puct,
        )

    # Keep every recorded state that is small enough and scores high enough.
    rationales = [
        node
        for node in state_map.values()
        if len(node.atoms) <= max_atoms and node.P >= prop_delta
    ]

    return rationales

In [19]:
# Path to the model checkpoint (relative to the current working directory)
model_path = Path("results") / "RANDOM_run1" / "replicate_3" / "model_2" / "best.pt"

# Load the model. The directly imported `MPNN` class is used instead of `models.MPNN`
# so that this cell does not depend on the name `models`, which is easy to rebind
# in a notebook and would make re-running this cell fail.
mpnn = MPNN.load_from_file(model_path)

mpnn  # display the model summary (optional)

MPNN(
  (message_passing): BondMessagePassing(
    (W_i): Linear(in_features=86, out_features=300, bias=False)
    (W_h): Linear(in_features=300, out_features=300, bias=False)
    (W_o): Linear(in_features=372, out_features=300, bias=True)
    (dropout): Dropout(p=0.3, inplace=False)
    (tau): ReLU()
    (V_d_transform): Identity()
    (graph_transform): GraphTransform(
      (V_transform): Identity()
      (E_transform): Identity()
    )
  )
  (agg): NormAggregation()
  (bn): Identity()
  (predictor): BinaryClassificationFFN(
    (ffn): MLP(
      (0): Sequential(
        (0): Linear(in_features=300, out_features=300, bias=True)
      )
      (1): Sequential(
        (0): ReLU()
        (1): Dropout(p=0.3, inplace=False)
        (2): Linear(in_features=300, out_features=300, bias=True)
      )
      (2): Sequential(
        (0): ReLU()
        (1): Dropout(p=0.3, inplace=False)
        (2): Linear(in_features=300, out_features=1, bias=True)
      )
    )
    (criterion): BCELoss(task

In [43]:
# Path to the CSV file with the molecules to interpret.
# NOTE(review): this is an absolute, machine-specific path; anyone else running the
# notebook has to point it at their own copy of the file.
test_path = Path(r"D:\working_example\oseltamivir.csv")  # raw string so the backslashes are taken literally
smiles_column = "SMILES"  # name of the CSV column that holds the SMILES strings

# Read the CSV file
df_test = pd.read_csv(test_path)

# Collect the SMILES strings as a list
smiles_list = df_test[smiles_column].tolist()

# Show the data
df_test



Unnamed: 0,SMILES,ACTIVITY
0,CC(N[C@@H]1[C@@H](N)C=C(O[C@H]1[C@H](OC(C)=O)[...,0.6


In [45]:
# Prediction-only trainer running on a single CPU device.
# `logger=False` disables experiment logging; `logger=None` would fall back to
# Lightning's default logger rather than turning logging off.
trainer = pl.Trainer(logger=False, enable_progress_bar=False, accelerator="cpu", devices=1)

You are using the plain ModelCheckpoint callback. Consider using LitModelCheckpoint which with seamless uploading to Model registry.
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [47]:
# MCTS options
rollout = 10  # number of MCTS rollouts to perform. If mol.GetNumAtoms() > 50, consider setting n_rollout = 1 to avoid long computation time

c_puct = 10.0  # constant that controls the level of exploration

max_atoms = 20  # maximum number of atoms allowed in an extracted rationale

min_atoms = 8  # minimum number of atoms in an extracted rationale

prop_delta = 0.5  # Minimum score to count as positive.
# In this algorithm, if the predicted property of the substructure is larger than prop_delta, the substructure is considered satisfactory.
# This value depends on the property you want to interpret. 0.5 is a dummy value for demonstration purposes.

num_rationales_to_keep = 5  # number of rationales to keep for each molecule

In [49]:
# Define the scoring function. "Score" for a substructure is the predicted property value of the substructure.

# The ensemble is deliberately NOT called `models`: that name is the imported
# `chemprop.models` module, and rebinding it to a list would break any cell that
# uses the module when it is re-run.
ensemble_models = [mpnn]

property_for_interpretation = "ACTIVITY"

property_id = (
    df_test.columns.get_loc(property_for_interpretation) - 1
)  # property index in the dataset; -1 for the SMILES column


def scoring_function(smiles: list[str]) -> list[float]:
    """Scores each SMILES with the ensemble's averaged prediction for the chosen property.

    Parameters
    ----------
    smiles : list of str
        The (sub)structure SMILES to score.

    Returns
    -------
    The column ``property_id`` of the averaged predictions, one value per SMILES.
    """
    return make_prediction(
        models=ensemble_models,
        trainer=trainer,
        smiles=smiles,
    )[:, property_id]

In [51]:
# Interpret every SMILES in the test set (no subsetting is applied here)
all_smiles = df_test[smiles_column].tolist()
all_smiles

['CC(N[C@@H]1[C@@H](N)C=C(O[C@H]1[C@H](OC(C)=O)[C@H](OC(C)=O)COC(C)=O)C(OC)=O)=O']

In [53]:
%%time

results_df = {"smiles": [], property_for_interpretation: []}

for i in range(num_rationales_to_keep):
    results_df[f"rationale_{i}"] = []
    results_df[f"rationale_{i}_score"] = []

for smiles in all_smiles:
    score = scoring_function([smiles])[0]
    if score > prop_delta:
        rationales = mcts(
            smiles=smiles,
            scoring_function=scoring_function,
            n_rollout=rollout,
            max_atoms=max_atoms,
            prop_delta=prop_delta,
            min_atoms=min_atoms,
            c_puct=c_puct,
        )
    else:
        rationales = []

    results_df["smiles"].append(smiles)
    results_df[property_for_interpretation].append(score)

    if len(rationales) == 0:
        for i in range(num_rationales_to_keep):
            results_df[f"rationale_{i}"].append(None)
            results_df[f"rationale_{i}_score"].append(None)
    else:
        min_size = min(len(x.atoms) for x in rationales)
        min_rationales = [x for x in rationales if len(x.atoms) == min_size]
        rats = sorted(min_rationales, key=lambda x: x.P, reverse=True)

        for i in range(num_rationales_to_keep):
            if i < len(rats):
                results_df[f"rationale_{i}"].append(rats[i].smiles)
                results_df[f"rationale_{i}_score"].append(rats[i].P)
            else:
                results_df[f"rationale_{i}"].append(None)
                results_df[f"rationale_{i}_score"].append(None)

CPU times: total: 93.8 ms
Wall time: 19.1 ms


D:\Anaconda\envs\chemprop\Lib\site-packages\lightning\pytorch\core\saving.py:363: Skipping 'metrics' parameter because it is not possible to safely dump to YAML.
D:\Anaconda\envs\chemprop\Lib\site-packages\lightning\pytorch\trainer\connectors\data_connector.py:425: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=7` in the `DataLoader` to improve performance.


In [55]:
# Turn the collected result columns into a DataFrame (rebinding `results_df`) and display it.
results_df = pd.DataFrame(results_df)
results_df

Unnamed: 0,smiles,ACTIVITY,rationale_0,rationale_0_score,rationale_1,rationale_1_score,rationale_2,rationale_2_score,rationale_3,rationale_3_score,rationale_4,rationale_4_score
0,CC(N[C@@H]1[C@@H](N)C=C(O[C@H]1[C@H](OC(C)=O)[...,0.215576,,,,,,,,,,
