In [1]:
from abc import abstractmethod
from dataclasses import InitVar, dataclass, field
from pathlib import Path
from typing import Iterable, NamedTuple, Sequence, TypeAlias

from lightning import pytorch as pl
import numpy as np
import pandas as pd
from rdkit.Chem import Mol
import torch
from torch import Tensor, nn
from torch.utils.data import DataLoader, Dataset

from chemprop.data.collate import BatchMolGraph, collate_batch
from chemprop.data.datapoints import MoleculeDatapoint
from chemprop.data.datasets import Datum, MoleculeDataset, MulticomponentDataset, ReactionDataset
from chemprop.data.molgraph import MolGraph
from chemprop.data.splitting import make_split_indices, split_data_by_indices
from chemprop.featurizers import Featurizer, SimpleMoleculeMolGraphFeaturizer
from chemprop.models import MulticomponentMPNN, multi
from chemprop.nn.agg import Aggregation, MeanAggregation
from chemprop.nn.hparams import HasHParams
from chemprop.nn.message_passing import BondMessagePassing, MulticomponentMessagePassing
from chemprop.nn.metrics import ChempropMetric
from chemprop.nn.predictors import Predictor, RegressionFFN
from chemprop.nn.transforms import ScaleTransform, UnscaleTransform

We first extend the `MolGraph` class with a graph-level attribute `w_fp`: the weight applied to the molecule's learned fingerprint when the fingerprints of the mixture's components are combined. `Datum` is extended likewise so that it carries the extended graph.

In [2]:
# See also chemprop.data.molgraph.MolGraph
class ComponentMolGraph(NamedTuple):
    """The four graph arrays of a chemprop ``MolGraph`` plus a fingerprint weight ``w_fp``.

    Built in ``ComponentDataset.__getitem__`` by unpacking a cached ``MolGraph`` and attaching
    the datapoint's weight.
    """

    V: np.ndarray
    E: np.ndarray
    edge_index: np.ndarray
    rev_edge_index: np.ndarray
    # weight of this component's fingerprint when the mixture's components are combined
    w_fp: float = 1.0


# See also chemprop.data.datasets.Datum
class ComponentDatum(NamedTuple):
    """A single dataset item for one mixture component.

    Same layout as chemprop's ``Datum`` except that the graph is a ``ComponentMolGraph``,
    so it also carries the component's fingerprint weight.
    """

    mg: ComponentMolGraph  # graph + fingerprint weight
    V_d: np.ndarray | None  # optional extra descriptors (None when absent)
    x_d: np.ndarray | None  # optional extra datapoint-level descriptors
    y: np.ndarray | None  # targets, if this component carries them
    weight: float  # datapoint weight
    lt_mask: np.ndarray | None
    gt_mask: np.ndarray | None

The batched versions of `MolGraph` and `Datum` are created when datapoints are collated. We extend these as well, along with the batch type that holds all components of the mixture.

In [3]:
# See also chemprop.data.collate.BatchMolGraph
@dataclass(repr=False, eq=False, slots=True)
class BatchComponentMolGraph(BatchMolGraph):
    """A ``BatchMolGraph`` that additionally carries the fingerprint weight of every graph.

    ``w_fps`` is a float32 tensor with one entry per input graph, in the order of ``mgs``
    (assuming each ``w_fp`` is a scalar).
    """

    mgs: InitVar[Sequence[ComponentMolGraph]]
    # per-graph fingerprint weights; filled in __post_init__ from each graph's ``w_fp``
    w_fps: Tensor = field(init=False)

    def __post_init__(self, mgs: Sequence[ComponentMolGraph]) -> None:
        """Batch the graphs via the parent class, then collect each graph's ``w_fp``."""
        # The explicit two-argument super() is deliberate: dataclass(slots=True) returns a new
        # class object, and (on Python versions that do not patch the ``__class__`` cell) a
        # zero-argument super() would still refer to the original, pre-slots class.
        super(BatchComponentMolGraph, self).__post_init__(mgs)
        # NOTE(review): a w_fp of None produces an object-dtype array here, which
        # torch.from_numpy rejects -- datapoints must provide a numeric weight.
        self.w_fps = torch.from_numpy(np.array([mg.w_fp for mg in mgs])).float()

    def to(self, device: str | torch.device) -> None:
        """Move the batched graph (via the parent) and the fingerprint weights to ``device``."""
        # same two-argument super() pattern as in __post_init__
        super(BatchComponentMolGraph, self).to(device)
        self.w_fps = self.w_fps.to(device)


# See also chemprop.data.collate.TrainingBatch
class BatchComponentDatum(NamedTuple):
    """Collated batch for a single mixture component (produced by ``collate_component``).

    Same layout as chemprop's ``TrainingBatch``, but the graph batch also holds ``w_fps``.
    """

    bmg: BatchComponentMolGraph  # batched graphs + per-graph fingerprint weights
    V_d: Tensor | None
    X_d: Tensor | None
    Y: Tensor | None
    w: Tensor  # datapoint weights, one row per datapoint
    lt_mask: Tensor | None
    gt_mask: Tensor | None


# See also chemprop.data.collate.MulticomponentTrainingBatch
class MixtureBatch(NamedTuple):
    """Collated batch of a whole mixture (produced by ``collate_mixture``).

    ``bmgs`` and ``V_ds`` hold one entry per component. The remaining fields are taken
    from the first component's collated batch.
    """

    bmgs: list[BatchMolGraph | BatchComponentMolGraph]  # one batched graph per component
    V_ds: list[Tensor | None]  # one (optional) descriptor tensor per component
    X_d: Tensor | None
    Y: Tensor | None
    w: Tensor
    lt_mask: Tensor | None
    gt_mask: Tensor | None

In [4]:
# See also chemprop.data.collate.collate_batch
def collate_component(batch: Iterable[ComponentDatum]) -> BatchComponentDatum:
    """Collate ``ComponentDatum`` items into one ``BatchComponentDatum``.

    Parameters
    ----------
    batch : Iterable[ComponentDatum]
        the items of one component for every datapoint in the batch

    Returns
    -------
    BatchComponentDatum
        The batched graphs carry the fingerprint weights (``w_fps``).

    Notes
    -----
    Each optional field (``V_d``, ``x_d``, ``y``, masks) is checked on the FIRST item only.
    If that value is ``None``, the batched field is ``None``.
    """
    mgs, V_ds, x_ds, ys, weights, lt_masks, gt_masks = zip(*batch)

    return BatchComponentDatum(
        BatchComponentMolGraph(mgs),
        # V_d arrays are concatenated along axis 0; the other per-datapoint arrays are stacked
        None if V_ds[0] is None else torch.from_numpy(np.concatenate(V_ds)).float(),
        None if x_ds[0] is None else torch.from_numpy(np.array(x_ds)).float(),
        None if ys[0] is None else torch.from_numpy(np.array(ys)).float(),
        torch.tensor(weights, dtype=torch.float).unsqueeze(1),  # column vector: b x 1
        None if lt_masks[0] is None else torch.from_numpy(np.array(lt_masks)),
        None if gt_masks[0] is None else torch.from_numpy(np.array(gt_masks)),
    )


# See also chemprop.data.collate.collate_multicomponent
def collate_mixture(batches: Iterable[Iterable[ComponentDatum | Datum]]) -> MixtureBatch:
    """Collate per-datapoint lists of component items into a ``MixtureBatch``.

    ``batches`` holds one list of component items per datapoint. The items are regrouped per
    component. A component whose items are chemprop ``Datum`` objects goes through
    ``collate_batch``; any other component goes through ``collate_component``. Targets,
    weights, masks and ``X_d`` come from the first component.
    """
    collated = []
    for component_items in zip(*batches):
        if isinstance(component_items[0], Datum):
            collated.append(collate_batch(component_items))
        else:
            collated.append(collate_component(component_items))

    lead = collated[0]
    return MixtureBatch(
        bmgs=[tb.bmg for tb in collated],
        V_ds=[tb.V_d for tb in collated],
        X_d=lead.X_d,
        Y=lead.Y,
        w=lead.w,
        lt_mask=lead.lt_mask,
        gt_mask=lead.gt_mask,
    )

In [5]:
@dataclass
class ComponentDatapoint(MoleculeDatapoint):
    """A ``MoleculeDatapoint`` that also stores the weight of its fingerprint in a mixture."""

    # Scalar weight of the molecule's learned fingerprint when the mixture's components are
    # combined. Defaults to 1.0 to match ``ComponentMolGraph.w_fp``: a ``None`` weight cannot
    # be collated into a float tensor.
    w_fp: float = 1.0


@dataclass(repr=False, eq=False)
class ComponentDataset(MoleculeDataset, Dataset[ComponentDatum]):
    """A ``MoleculeDataset`` whose items also carry each molecule's fingerprint weight.

    A missing weight (``None``, or a datapoint without a ``w_fp`` attribute) is treated as
    ``1.0``, the ``ComponentMolGraph`` default. This guarantees that collation always receives
    numeric weights. ``repr``/``eq`` generation is disabled, as for ``MixtureDataset``, so that
    the dataset is not compared or printed field by field.
    """

    data: list[ComponentDatapoint]

    @staticmethod
    def _resolve_w_fp(d: ComponentDatapoint) -> float:
        """Return the datapoint's fingerprint weight, falling back to 1.0 when it is missing."""
        w_fp = getattr(d, "w_fp", None)
        return 1.0 if w_fp is None else w_fp

    @property
    def w_fps(self) -> np.ndarray:
        """The fingerprint weight of every datapoint, with missing weights resolved to 1.0."""
        return np.array([self._resolve_w_fp(d) for d in self.data])

    def __getitem__(self, idx: int) -> ComponentDatum:
        d = self.data[idx]
        # extend the cached MolGraph's arrays with this datapoint's fingerprint weight
        mg = ComponentMolGraph(*self.mg_cache[idx], w_fp=self._resolve_w_fp(d))

        return ComponentDatum(
            mg, self.V_ds[idx], self.X_d[idx], self.Y[idx], d.weight, d.lt_mask, d.gt_mask
        )


@dataclass(repr=False, eq=False)
class MixtureDataset(MulticomponentDataset):
    """A multicomponent dataset whose components may also be ``ComponentDataset`` objects."""

    datasets: list[MoleculeDataset | ReactionDataset | ComponentDataset]

    def __getitem__(self, idx: int) -> list[ComponentDatum | Datum]:
        # one item per component dataset, in component order
        items: list[ComponentDatum | Datum] = []
        for component_dset in self.datasets:
            items.append(component_dset[idx])
        return items

In [6]:
class MixtureAggregation(nn.Module, HasHParams):
    """Base class for modules that combine per-component representations into one vector.

    Parameters
    ----------
    graph_agg : Aggregation
        graph-level aggregation that subclasses apply to each component's node-level
        hidden states
    groups : Sequence[Sequence[int]]
        component indices grouped into mixtures; a group of length 1 is a standalone component
    fp_dims : Sequence[int]
        fingerprint dimension of each component

    Notes
    -----
    ``hparams["graph_agg"]`` stores the *hyperparameters* of ``graph_agg`` (a dict that
    includes its ``"cls"``) rather than the module itself. ``MixtureMPNN._load`` rebuilds it
    with ``graph_agg_hparams.pop("cls")(**graph_agg_hparams)``, the same way it rebuilds
    message-passing blocks, and that call needs a dict.
    """

    output_dim: int

    def __init__(
        self, graph_agg: Aggregation, groups: Sequence[Sequence[int]], fp_dims: Sequence[int]
    ):
        super().__init__()
        self.hparams = {
            "cls": self.__class__,
            "graph_agg": graph_agg.hparams,
            "groups": groups,
            "fp_dims": fp_dims,
        }
        self.graph_agg = graph_agg
        self.groups = groups
        self.fp_dims = fp_dims

    @abstractmethod
    def forward(
        self, Hs: list[Tensor], bmgs: list[BatchComponentMolGraph | BatchMolGraph]
    ) -> Tensor:
        """Aggregate component representations into a mixture representation"""


class WeightedSumAggregation(MixtureAggregation):
    """Combine each group of components into one fingerprint with a ``w_fp``-weighted sum.

    A singleton group passes its component's fingerprint through unchanged. The per-group
    fingerprints are concatenated along the feature dimension. The weights are used as
    given; they are not renormalized.
    """

    @property
    def output_dim(self) -> int:
        # one fingerprint per group, whose width is that of the group's first member
        group_widths = [self.fp_dims[group[0]] for group in self.groups]
        return sum(group_widths)

    def forward(
        self, H_vs: list[Tensor], bmgs: list[BatchComponentMolGraph | BatchMolGraph]
    ) -> Tensor:
        fingerprints = [self.graph_agg(H_v, bmg.batch) for H_v, bmg in zip(H_vs, bmgs)]

        def combine(group: Sequence[int]) -> Tensor:
            if len(group) == 1:
                return fingerprints[group[0]]
            # n: components in this group, b: datapoints in the batch, d: fingerprint width
            stacked_fps = torch.stack([fingerprints[i] for i in group])  # n x b x d
            stacked_weights = torch.stack([bmgs[i].w_fps for i in group])  # n x b
            return torch.einsum("nb,nbd->bd", stacked_weights, stacked_fps)

        return torch.cat([combine(group) for group in self.groups], 1)


class ConcatAggregation(MixtureAggregation):
    """Concatenate every component's fingerprint, then append the mixture weights.

    A ``w_fp`` column is appended (in ascending component order) for each component that
    belongs to a group with more than one member. When every group is a singleton, only the
    fingerprints are concatenated.
    """

    @property
    def components_in_mixture(self) -> set[int]:
        """Indices of components that belong to a multi-member group."""
        return {idx for group in self.groups if len(group) > 1 for idx in group}

    @property
    def output_dim(self) -> int:
        return sum(self.fp_dims) + len(self.components_in_mixture)

    def forward(
        self, H_vs: list[Tensor], bmgs: list[BatchComponentMolGraph | BatchMolGraph]
    ) -> Tensor:
        Hs = [self.graph_agg(H_v, bmg.batch) for H_v, bmg in zip(H_vs, bmgs)]
        # sort so the weight columns have an explicit order instead of set iteration order
        mixture_idxs = sorted(self.components_in_mixture)
        if not mixture_idxs:
            # torch.stack on an empty list raises, so skip the weight columns entirely
            return torch.cat(Hs, 1)
        w_fps = torch.stack([bmgs[idx].w_fps for idx in mixture_idxs], dim=1)  # b x n_mix
        return torch.cat(Hs + [w_fps], 1)

In [7]:
class MixtureMPNN(MulticomponentMPNN):
    """A ``MulticomponentMPNN`` whose aggregation step is a ``MixtureAggregation``.

    The aggregation receives both the per-component hidden states and the batched graphs.
    This lets it use per-component fingerprint weights (``w_fps``) when combining components.
    """

    def __init__(
        self,
        message_passing: MulticomponentMessagePassing,
        agg: MixtureAggregation,
        predictor: Predictor,
        batch_norm: bool = False,
        metrics: Iterable[ChempropMetric] | None = None,
        warmup_epochs: int = 2,
        init_lr: float = 1e-4,
        max_lr: float = 1e-3,
        final_lr: float = 1e-4,
        X_d_transform: ScaleTransform | None = None,
    ):
        super().__init__(
            message_passing,
            agg,
            predictor,
            batch_norm,
            metrics,
            warmup_epochs,
            init_lr,
            max_lr,
            final_lr,
            X_d_transform,
        )
        self.agg: MixtureAggregation

    def fingerprint(
        self,
        bmgs: Iterable[BatchComponentMolGraph | BatchMolGraph],
        V_ds: Iterable[Tensor],
        X_d: Tensor | None = None,
    ) -> Tensor:
        """Compute the mixture representation, optionally concatenated with transformed ``X_d``.

        Parameters
        ----------
        bmgs : Iterable[BatchComponentMolGraph | BatchMolGraph]
            one batched graph per component
        V_ds : Iterable[Tensor]
            one (optional) descriptor tensor per component, passed to message passing
        X_d : Tensor | None
            extra datapoint-level descriptors appended after aggregation and batch norm
        """
        # materialize once: the graphs are consumed by message passing AND by the aggregation
        bmgs = list(bmgs)
        H_vs: list[Tensor] = self.message_passing(bmgs, V_ds)
        H = self.agg(H_vs, bmgs)
        H = self.bn(H)
        return H if X_d is None else torch.cat((H, self.X_d_transform(X_d)), 1)

    @classmethod
    def _load(cls, path, map_location, **submodules):
        """Load a checkpoint and rebuild every submodule the caller did not supply.

        Returns
        -------
        tuple
            ``(submodules, state_dict, hparams)``

        Raises
        ------
        KeyError
            if the checkpoint lacks ``"hyper_parameters"`` or ``"state_dict"``
        """
        # SECURITY: weights_only=False unpickles arbitrary objects -- only load trusted files.
        d = torch.load(path, map_location, weights_only=False)

        try:
            hparams = d["hyper_parameters"]
            state_dict = d["state_dict"]
        except KeyError as e:
            raise KeyError(f"Could not find hyper parameters and/or state dict in {path}.") from e

        # only rebuild what the caller did not pass in explicitly
        if "message_passing" not in submodules:
            hparams["message_passing"]["blocks"] = [
                block_hparams.pop("cls")(**block_hparams)
                for block_hparams in hparams["message_passing"]["blocks"]
            ]
        if "agg" not in submodules:
            graph_agg = hparams["agg"]["graph_agg"]
            # Rebuild from an hparams dict. If the module itself was pickled into the
            # hyperparameters instead, it is used as-is.
            if isinstance(graph_agg, dict):
                graph_agg = graph_agg.pop("cls")(**graph_agg)
            hparams["agg"]["graph_agg"] = graph_agg
        submodules |= {
            key: hparams[key].pop("cls")(**hparams[key])
            for key in ("message_passing", "agg", "predictor")
            if key not in submodules
        }

        # re-create a criterion that was unpickled without its default state
        if not hasattr(submodules["predictor"].criterion, "_defaults"):
            submodules["predictor"].criterion = submodules["predictor"].criterion.__class__(
                task_weights=submodules["predictor"].criterion.task_weights
            )

        return submodules, state_dict, hparams

In [8]:
# assumes the notebook is run from a direct subdirectory of the chemprop repo (e.g. examples/)
chemprop_dir = Path.cwd().parent
# CSV file with one SMILES column per component and the target values
input_path = chemprop_dir.joinpath("tests", "data", "regression", "mol+mol", "mol+mol.csv")
smiles_columns = ["smiles", "solvent"]  # names of the columns holding SMILES strings
target_columns = ["peakwavs_max"]  # names of the target columns

df_input = pd.read_csv(input_path)
smiss = df_input.loc[:, smiles_columns].to_numpy()
ys = df_input.loc[:, target_columns].to_numpy()

In [9]:
# fingerprint weights of the two mixture members (column "smiles" and column "solvent")
W_FP_SOLUTE, W_FP_SOLVENT = 0.1, 0.9

all_data = [
    # component 0: the labeled molecule; the batch targets are taken from this component
    [MoleculeDatapoint.from_smi(smis[0], y) for smis, y in zip(smiss, ys)],
    # components 1 and 2: the weighted members of the mixture
    [ComponentDatapoint.from_smi(smis[0], w_fp=W_FP_SOLUTE) for smis in smiss],
    [ComponentDatapoint.from_smi(smis[1], w_fp=W_FP_SOLVENT) for smis in smiss],
]

In [10]:
component_to_split_by = 0  # index of the component whose structures drive the split
mols = [d.mol for d in all_data[component_to_split_by]]
train_indices, val_indices, test_indices = make_split_indices(mols, "random", (0.8, 0.1, 0.1))

# Each returned split is a list of replicates (see the v2.1 note printed below); keep the first.
train_data, val_data, test_data = (
    replicates[0]
    for replicates in split_data_by_indices(all_data, train_indices, val_indices, test_indices)
)

The return type of make_split_indices has changed in v2.1 - see help(make_split_indices)


In [11]:
def build_split_datasets(split_data):
    """Wrap one split's per-component datapoints into datasets.

    Component 0 (the labeled molecule) becomes a ``MoleculeDataset``. Every further component
    becomes a ``ComponentDataset``, so that its fingerprint weights are carried along.
    """
    labeled, *components = split_data
    return [MoleculeDataset(labeled), *(ComponentDataset(data) for data in components)]


train_datasets = build_split_datasets(train_data)
val_datasets = build_split_datasets(val_data)
test_datasets = build_split_datasets(test_data)

In [12]:
train_mcdset, val_mcdset, test_mcdset = (
    MixtureDataset(dsets) for dsets in (train_datasets, val_datasets, test_datasets)
)

# Fit the target scaler on the training split only and apply it to validation. The test
# split keeps its original units.
scaler = train_mcdset.normalize_targets()
val_mcdset.normalize_targets(scaler)

In [13]:
BATCH_SIZE = 10
loader_kwargs = dict(batch_size=BATCH_SIZE, collate_fn=collate_mixture)

# only the training loader is shuffled
train_loader = DataLoader(train_mcdset, shuffle=True, **loader_kwargs)
val_loader = DataLoader(val_mcdset, shuffle=False, **loader_kwargs)
test_loader = DataLoader(test_mcdset, shuffle=False, **loader_kwargs)

In [14]:
mcmp = MulticomponentMessagePassing(blocks=[BondMessagePassing()], n_components=3, shared=True)

graph_agg = MeanAggregation()
groups = [[0], [1, 2]]  # component 0 stands alone; components 1 and 2 form the mixture
fp_dims = [block.output_dim for block in mcmp.blocks]  # one entry per component

# How component fingerprints are combined: "weighted_sum" or "concat". Previously both
# aggregations were built and the first was silently overwritten by the second.
MIXTURE_AGGREGATION = "concat"
mixture_agg_classes = {"weighted_sum": WeightedSumAggregation, "concat": ConcatAggregation}
mixagg = mixture_agg_classes[MIXTURE_AGGREGATION](
    graph_agg=graph_agg, groups=groups, fp_dims=fp_dims
)

# predictions are mapped back to the original target units with the training-set scaler
output_transform = UnscaleTransform.from_standard_scaler(scaler)
ffn = RegressionFFN(input_dim=mixagg.output_dim, output_transform=output_transform)

In [15]:
# assemble message passing, mixture aggregation and the regression head into one model
mcmpnn = MixtureMPNN(message_passing=mcmp, agg=mixagg, predictor=ffn)
mcmpnn

MixtureMPNN(
  (message_passing): MulticomponentMessagePassing(
    (blocks): ModuleList(
      (0-2): 3 x BondMessagePassing(
        (W_i): Linear(in_features=86, out_features=300, bias=False)
        (W_h): Linear(in_features=300, out_features=300, bias=False)
        (W_o): Linear(in_features=372, out_features=300, bias=True)
        (dropout): Dropout(p=0.0, inplace=False)
        (tau): ReLU()
        (V_d_transform): Identity()
        (graph_transform): Identity()
      )
    )
  )
  (agg): ConcatAggregation(
    (graph_agg): MeanAggregation()
  )
  (bn): Identity()
  (predictor): RegressionFFN(
    (ffn): MLP(
      (0): Sequential(
        (0): Linear(in_features=902, out_features=300, bias=True)
      )
      (1): Sequential(
        (0): ReLU()
        (1): Dropout(p=0.0, inplace=False)
        (2): Linear(in_features=300, out_features=1, bias=True)
      )
    )
    (criterion): MSE(task_weights=[[1.0]])
    (output_transform): UnscaleTransform()
  )
  (X_d_transform): Ide

In [None]:
# NOTE(review): the saved training log below reports `max_epochs=100`; re-run to refresh it.
MAX_EPOCHS = 50

trainer = pl.Trainer(
    accelerator="auto",
    devices=1,
    max_epochs=MAX_EPOCHS,
    logger=False,
    enable_checkpointing=True,
    enable_progress_bar=True,
)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs


In [17]:
trainer.fit(mcmpnn, train_dataloaders=train_loader, val_dataloaders=val_loader)

/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory /home/knathan/chemprop/examples/checkpoints exists and is not empty.
Loading `train_dataloader` to estimate number of stepping batches.
/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.

  | Name            | Type                         | Params | Mode 
-------------------------------------------------------------------------
0 | message_passing | MulticomponentMessagePassing | 227 K  | train
1 | agg             | ConcatAggregation            | 0      | train
2 | bn              | Identity                     | 0      | train
3 | predictor       | RegressionFFN            

Sanity Checking DataLoader 0:   0%|          | 0/1 [00:00<?, ?it/s]

/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Epoch 99: 100%|██████████| 8/8 [00:01<00:00,  7.02it/s, train_loss_step=0.0073, val_loss=0.588, train_loss_epoch=0.00598]  

`Trainer.fit` stopped: `max_epochs=100` reached.


Epoch 99: 100%|██████████| 8/8 [00:01<00:00,  6.83it/s, train_loss_step=0.0073, val_loss=0.588, train_loss_epoch=0.00598]


In [18]:
# keep the test metrics under their own name; the predict cell below reuses `results`
test_metrics = trainer.test(mcmpnn, dataloaders=test_loader)

/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Testing DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 29.05it/s]


In [19]:
# one prediction tensor per test batch
results = trainer.predict(mcmpnn, dataloaders=test_loader)

/home/knathan/anaconda3/envs/chemprop/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:424: The 'predict_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=11` in the `DataLoader` to improve performance.


Predicting DataLoader 0: 100%|██████████| 1/1 [00:00<00:00, 31.32it/s]


In [20]:
# Predictions from trainer.predict: a list with one tensor per test batch. The values are in
# the original target units, because the model's output_transform is an UnscaleTransform.
results

[tensor([[398.7043],
         [650.7003],
         [377.3484],
         [395.7841],
         [389.0737],
         [438.4164],
         [362.5379],
         [378.0801],
         [441.8935],
         [462.4413]])]

In [21]:
# ground-truth targets of the (never normalized) test split, to compare with the predictions
test_targets = test_mcdset.datasets[0].Y
test_targets

array([[384. ],
       [553. ],
       [394. ],
       [428.2],
       [386. ],
       [369. ],
       [520. ],
       [515. ],
       [313. ],
       [480. ]])