In [None]:
import os
import csv
import math
from tqdm import tqdm
from collections import defaultdict

import torch
from torchdrug import datasets, utils, models, tasks, core
from torchdrug.data import Molecule

from rdkit import Chem
from retflow import config

In [None]:
class USPTO(datasets.USPTO50k):
    """USPTO-50k reactions read from a local CSV file.

    ``path`` must point at an existing CSV with a
    ``reactants>reagents>production`` column; an ``id`` column, when present,
    fills ``self.ids``. After loading, every (reactant, product) pair is passed
    through ``_get_reaction_center`` (or ``_get_synthon`` when
    ``as_synthon=True``). One CSV row may become several samples. Rows that
    yield none are dropped and counted in ``valid_rate``.

    Parameters:
        path (str or path-like): CSV file to load.
        as_synthon (bool, optional): build synthon samples instead of
            reaction-center samples.
        verbose (int, optional): output verbose level.
        **kwargs: forwarded to ``load_csv`` / ``load_smiles``. The
            ``Molecule.from_molecule`` options among them (``atom_feature``,
            ``bond_feature``, ``kekulize``, ...) are also used to rebuild the
            reactant molecules, so reactants and products are featurized alike.
    """

    # ``Molecule.from_molecule`` options (besides ``atom_feature``) that must be
    # identical for reactants (rebuilt in ``_convert_reactant_molecule``) and
    # products (built during ``load_csv``). Otherwise bond types / features
    # would not line up in ``_get_difference``.
    _MOLECULE_KWARGS = ("bond_feature", "mol_feature", "with_hydrogen", "kekulize")

    def __init__(self, path, as_synthon=False, verbose=1, **kwargs):
        path = os.path.expanduser(path)
        # ``path`` is the CSV file itself. Creating it with ``os.makedirs`` when
        # missing would leave a directory named like the CSV behind, and the
        # following ``open`` would still fail -- report the real problem.
        if not os.path.isfile(path):
            raise FileNotFoundError("Can't find USPTO csv file `%s`" % path)
        self.path = path
        self.as_synthon = as_synthon

        self.load_csv(
            self.path,
            smiles_field="reactants>reagents>production",
            target_fields=self.target_fields,
            verbose=verbose,
            **kwargs
        )

        if as_synthon:
            prefix = "Computing synthons"
            process_fn = self._get_synthon
        else:
            prefix = "Computing reaction centers"
            process_fn = self._get_reaction_center

        # "default" mirrors the ``atom_feature`` default of ``Molecule.from_molecule``.
        atom_feature = kwargs.get("atom_feature", "default")
        mol_kwargs = {k: kwargs[k] for k in self._MOLECULE_KWARGS if k in kwargs}

        data = self.data
        targets = self.targets
        # NOTE(review): ``ids`` is collected once per CSV row in ``load_csv``,
        # while ``data``/``targets`` come from ``load_smiles``. ``ids[i]`` only
        # belongs to ``data[i]`` if ``load_smiles`` keeps every row -- confirm.
        ids = self.ids
        self.data = []
        self.targets = defaultdict(list)
        self.ids = []

        indexes = range(len(data))
        if verbose:
            indexes = tqdm(indexes, prefix)
        invalid = 0
        for i in indexes:
            reactant, product = data[i]
            reactant.bond_stereo[:] = 0
            product.bond_stereo[:] = 0
            reactant = self._convert_reactant_molecule(reactant, product, atom_feature, **mol_kwargs)
            reactants, products = process_fn(reactant, product)

            if not reactants:
                invalid += 1
                continue

            num_sample = len(reactants)
            self.data += zip(reactants, products)
            for k in targets:
                new_k = self.target_alias.get(k, k)
                # Labels are shifted down by one (presumably 1-based class ids -> 0-based).
                self.targets[new_k] += [targets[k][i] - 1] * num_sample
            self.targets["sample id"] += [i] * num_sample
            self.ids += [ids[i]] * num_sample

        # Fraction of loaded reactions that produced at least one sample;
        # undefined (NaN) when nothing was loaded instead of dividing by zero.
        self.valid_rate = 1 - invalid / len(data) if len(data) else math.nan

    def _convert_reactant_molecule(self, reactant, product, atom_feature="default", **mol_kwargs):
        """Rebuild ``reactant`` so its atom-map numbers fit within the product's range.

        Reactant atom-map numbers larger than the product's largest map number
        are reset to 0. ``_get_difference`` sizes its map-number lookup table by
        the product's largest map number, so larger reactant numbers would index
        past its end. The molecule is then round-tripped through canonical
        SMILES, which also reorders its atoms, and re-featurized.

        Parameters:
            reactant (Molecule): reactant graph to rebuild.
            product (Molecule): product graph whose ``atom_map`` bounds the kept
                map numbers.
            atom_feature (str or list of str, optional): atom features of the
                rebuilt molecule.
            **mol_kwargs: further ``Molecule.from_molecule`` options.

        Returns:
            Molecule
        """
        mol = reactant.to_molecule()
        max_product_map = int(product.atom_map.max())
        for atom in mol.GetAtoms():
            if atom.GetAtomMapNum() > max_product_map:
                atom.SetAtomMapNum(0)
        smiles = Chem.MolToSmiles(mol, canonical=True)
        return Molecule.from_molecule(Chem.MolFromSmiles(smiles), atom_feature=atom_feature, **mol_kwargs)

    def _get_difference(self, reactant, product):
        """Find product bonds that are missing from, or typed differently in, the reactant.

        Product atoms are matched to reactant atoms through their shared
        atom-map numbers. Every directed product edge ``(h, t, type)`` is then
        compared against all reactant edges.

        Returns:
            edge_added (Tensor): ``(h, t)`` product atom indices of bonds whose
                endpoints are not bonded in the reactant.
            edge_modified (Tensor): ``(h, t)`` product atom indices of bonds
                present in both molecules but with a different bond type.
            prod2react (Tensor): reactant atom index for each product atom.
        """
        product2id = product.atom_map
        # lookup table: atom-map number -> reactant atom index
        id2reactant = torch.zeros(product2id.max() + 1, dtype=torch.long)
        id2reactant[reactant.atom_map] = torch.arange(reactant.num_node)
        prod2react = id2reactant[product2id]

        # check edges in the product
        product = product.directed()
        # O(n^2) brute-force match is faster than O(nlogn) data.Graph.match for small molecules
        mapped_edge = product.edge_list.clone()
        mapped_edge[:, :2] = prod2react[mapped_edge[:, :2]]
        # (num_reactant_edge, num_product_edge, 3) element-wise comparison
        is_same_index = mapped_edge.unsqueeze(0) == reactant.edge_list.unsqueeze(1)
        has_typed_edge = is_same_index.all(dim=-1).any(dim=0)
        has_edge = is_same_index[:, :, :2].all(dim=-1).any(dim=0)
        is_added = ~has_edge
        is_modified = has_edge & ~has_typed_edge
        edge_added = product.edge_list[is_added, :2]
        edge_modified = product.edge_list[is_modified, :2]

        return edge_added, edge_modified, prod2react

    def load_csv(
        self, csv_file, smiles_field="smiles", target_fields=None, verbose=0, **kwargs
    ):
        """
        Load the dataset from a csv file.

        Values of an ``id`` column, if present, are collected in ``self.ids``
        (one entry per non-empty row).

        Parameters:
            csv_file (str): file name
            smiles_field (str, optional): name of the SMILES column in the table.
                Use ``None`` if there is no SMILES column.
            target_fields (list of str, optional): name of target columns in the table.
                Default is all columns other than the SMILES column.
            verbose (int, optional): output verbose level
            **kwargs: forwarded to ``load_smiles``
        """
        if target_fields is not None:
            target_fields = set(target_fields)
        self.ids = []
        # newline="" as required by the csv module, so quoted fields containing
        # newlines are parsed correctly.
        with open(csv_file, "r", newline="") as fin:
            reader = csv.reader(fin)
            if verbose:
                reader = iter(
                    tqdm(
                        reader, "Loading %s" % csv_file, utils.get_line_count(csv_file)
                    )
                )
            fields = next(reader)
            smiles = []
            targets = defaultdict(list)
            for values in reader:
                # skip blank rows
                if not any(values):
                    continue
                if smiles_field is None:
                    smiles.append("")
                for field, value in zip(fields, values):
                    if field == smiles_field:
                        smiles.append(value)
                    elif target_fields is None or field in target_fields:
                        value = utils.literal_eval(value)
                        if value == "":
                            value = math.nan
                        targets[field].append(value)
                    if field == "id":
                        self.ids.append(value)
        self.load_smiles(smiles, targets, verbose=verbose, **kwargs)

In [12]:
# Raw USPTO-50k split files and the featurization shared by all three splits.
USPTO_RAW_DIR = config.get_dataset_directory() / "USPTO" / "raw"
DATASET_KWARGS = dict(kekulize=False, atom_feature="center_identification")


def load_uspto_split(split):
    """Load ``uspto50k_<split>.csv`` from ``USPTO_RAW_DIR`` as a ``USPTO`` dataset."""
    return USPTO(USPTO_RAW_DIR / f"uspto50k_{split}.csv", **DATASET_KWARGS)


train_dataset = load_uspto_split("train")
val_dataset = load_uspto_split("val")
test_dataset = load_uspto_split("test")

Loading /bigdata/robiny/retro_workspace/datasets/USPTO/raw/uspto50k_train.csv:  51%|█████▏    | 20568/40009 [00:00<00:00, 205672.35it/s]

Loading /bigdata/robiny/retro_workspace/datasets/USPTO/raw/uspto50k_train.csv: 100%|██████████| 40009/40009 [00:00<00:00, 207974.07it/s]
Constructing molecules from SMILES: 100%|██████████| 40008/40008 [01:10<00:00, 569.04it/s]
Computing reaction centers: 100%|██████████| 40008/40008 [01:44<00:00, 381.24it/s]
Loading /bigdata/robiny/retro_workspace/datasets/USPTO/raw/uspto50k_val.csv: 100%|██████████| 5002/5002 [00:00<00:00, 30937.29it/s]
Constructing molecules from SMILES: 100%|██████████| 5001/5001 [00:08<00:00, 574.14it/s]
Computing reaction centers: 100%|██████████| 5001/5001 [00:13<00:00, 378.75it/s]
Loading /bigdata/robiny/retro_workspace/datasets/USPTO/raw/uspto50k_test.csv: 100%|██████████| 5008/5008 [00:00<00:00, 72886.71it/s]
Constructing molecules from SMILES: 100%|██████████| 5007/5007 [00:10<00:00, 483.13it/s]
Computing reaction centers: 100%|██████████| 5007/5007 [00:13<00:00, 380.16it/s]


In [13]:
# Seed before building the model so weight initialisation (and the later
# batch shuffling) is reproducible across runs.
torch.manual_seed(0)

# Dimensions are taken from the data instead of being hard-coded (previously
# input_dim=43, num_relation=4). A featurization change would otherwise only
# surface as a shape error.
reaction_model = models.RGCN(
    input_dim=train_dataset.node_feature_dim,
    hidden_dims=[256, 256, 256, 256],
    num_relation=train_dataset.num_bond_type,
    concat_hidden=True,
    short_cut=True,
)
reaction_task = tasks.CenterIdentification(reaction_model, feature=("graph", "atom", "bond"))
reaction_task.preprocess(train_dataset, val_dataset, test_dataset)
reaction_task = reaction_task.to("cuda")

In [None]:
# Optimisation settings for the center-identification model.
LEARNING_RATE = 1e-4
BATCH_SIZE = 128
NUM_EPOCH = 100

reaction_optimizer = torch.optim.AdamW(reaction_task.parameters(), lr=LEARNING_RATE)
reaction_solver = core.Engine(
    reaction_task,
    train_dataset,
    val_dataset,
    test_dataset,
    reaction_optimizer,
    gpus=[0],
    batch_size=BATCH_SIZE,
)
reaction_solver.train(num_epoch=NUM_EPOCH)
# Metrics on the test split; as the last expression they are the cell's output.
reaction_solver.evaluate("test")

In [None]:
checkpoint_path = config.get_models_directory() / "g2g_center.pth"
# Create the target directory first so saving after a long training run
# cannot fail on a missing folder.
checkpoint_path.parent.mkdir(parents=True, exist_ok=True)
reaction_solver.save(checkpoint_path)