In [1]:
import sys
sys.path.append("..")

from rdkit.Chem.Draw import IPythonConsole, MolsToGridImage

In [2]:
import lightning as L

import os

from typing import List
from src.utils.dtypes import BuildingBlock as BB
from src.utils.dtypes import CustomReaction as CR
from src.data.dataset import CRBBDataset, CRBBOutput


from torch.utils.data import DataLoader

In [3]:
bbs = BB.read_from_sdf("../data/raw/EnamineHighFidFrags.sdf")
rxns = CR.parse_txt("../data/raw/rxn.txt")

100%|██████████| 1920/1920 [00:03<00:00, 566.45it/s]


In [None]:
def make_dir(path: str):
    if not os.path.exists(path):
        os.mkdir(path)
    
class CRBBDataModule(L.LightningDataModule):
    def __init__(self, 
                 root: str = 'crbb', 
                 batch_size=32,
                 rxns: List[CR] = list(),
                 rxn_files: List[str] = list(),
                 reactant_bbs: List[BB] = list(),
                 reactant_sdfs: List[str] = list(),
                 train_ligs: List[BB] = list(),
                 val_ligs: List[BB] = list(),
                 test_ligs: List[BB] = list(),
                 train_ligsdfs: List[str] = list(),
                 val_ligsdfs: List[str] = list(),
                 test_ligsdfs: List[str] = list(),
                 force_reload: bool = False):
        
        super().__init__()
        self.root = root
        self.train_root = os.path.join(root, 'train')
        self.val_root = os.path.join(root, 'val')
        self.test_root = os.path.join(root, 'test')
        [make_dir(r) for r in [self.root, self.train_root, 
                               self.val_root, self.test_root]]
        
        self.batch_size = batch_size
        self.force_reload = force_reload

        self.rxns = rxns if rxns is not None else []
        self.rxn_files = rxn_files

        self.reactant_bbs = reactant_bbs
        self.reactant_sdfs = reactant_sdfs

        self.train_ligs = train_ligs
        self.train_ligsdfs = train_ligsdfs

        self.val_ligs = val_ligs
        self.val_ligsdfs = val_ligsdfs

        self.test_ligs = test_ligs
        self.test_ligsdfs = test_ligsdfs

    def setup(self, stage=None):
        if stage == 'fit' or stage is None:
            self.train_dataset = CRBBDataset(root=self.train_root,
                                             rxns=self.rxns,
                                             rxn_files=self.rxn_files,
                                             bbs=self.reactant_bbs,
                                             bb_files=self.reactant_sdfs,
                                             ligs=self.train_ligs,
                                             lig_files=self.train_ligsdfs,
                                             force_reload=self.force_reload)
            self.val_dataset = CRBBDataset(root=self.val_root,
                                            rxns=self.rxns,
                                            rxn_files=self.rxn_files,
                                            bbs=self.reactant_bbs,
                                            bb_files=self.reactant_sdfs,
                                            ligs=self.val_ligs,
                                            lig_files=self.val_ligsdfs,
                                            force_reload=self.force_reload)
        
        if stage == 'test' or stage is None:
            self.test_dataset = CRBBDataset(root=self.test_root,
                                            rxns=self.rxns,
                                            rxn_files=self.rxn_files,
                                            bbs=self.reactant_bbs,
                                            bb_files=self.reactant_sdfs,
                                            ligs=self.test_ligs,
                                            lig_files=self.test_ligsdfs,
                                            force_reload=self.force_reload)
    
    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, 
                          shuffle=True, collate_fn=CRBBOutput.collate_fn)
    
    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size,
                          shuffle=False, collate_fn=CRBBOutput.collate_fn)


datamodule = CRBBDataModule(rxns=rxns,
                            reactant_bbs=bbs[:100],
                            train_ligs=bbs[101:200],
                            val_ligs=bbs[201:300],
                            force_reload=True)

In [6]:
datamodule.setup()

Creating building block dataset...


Processing...


Loading building blocks...
Creating conformers for building blocks...


100%|██████████| 100/100 [00:00<00:00, 473.19it/s]


Featurizing building blocks...


100%|██████████| 100/100 [00:00<00:00, 586.24it/s]

Featurizing building block conformers...



100%|██████████| 100/100 [00:00<00:00, 151.37it/s]
Done!
Processing...


Creating reaction dataset...
Loading reactions...
Featurizing reactions...


100%|██████████| 2/2 [00:00<00:00, 4488.29it/s]

Creating ligand dataset...



Done!
Processing...


Loading building blocks...
Creating conformers for building blocks...


100%|██████████| 99/99 [00:00<00:00, 542.64it/s]


Featurizing building blocks...


100%|██████████| 99/99 [00:00<00:00, 715.57it/s]

Featurizing building block conformers...



100%|██████████| 99/99 [00:00<00:00, 189.81it/s]
Done!
Processing...


Matching building blocks (and ligs) and reactions...


100%|██████████| 2/2 [00:00<00:00, 26.10it/s]
100%|██████████| 100/100 [00:00<00:00, 5294.03it/s]
100%|██████████| 99/99 [00:00<00:00, 4841.95it/s]


Constructing searcher for every rxn reactant...


100%|██████████| 1/1 [00:00<00:00, 215.17it/s]


Constructing conformer searcher for every building block...


100%|██████████| 100/100 [00:00<00:00, 3248.48it/s]
Done!


Creating building block dataset...


Processing...


Loading building blocks...
Creating conformers for building blocks...


100%|██████████| 100/100 [00:00<00:00, 509.74it/s]


Featurizing building blocks...


100%|██████████| 100/100 [00:00<00:00, 597.26it/s]


Featurizing building block conformers...


100%|██████████| 100/100 [00:00<00:00, 209.89it/s]
Done!
Processing...


Creating reaction dataset...
Loading reactions...
Featurizing reactions...


100%|██████████| 2/2 [00:00<00:00, 3855.06it/s]

Creating ligand dataset...



Done!
Processing...


Loading building blocks...
Creating conformers for building blocks...


100%|██████████| 99/99 [00:00<00:00, 345.88it/s]


Featurizing building blocks...


100%|██████████| 99/99 [00:00<00:00, 688.71it/s]


Featurizing building block conformers...


100%|██████████| 99/99 [00:00<00:00, 164.17it/s]
Done!
Processing...


Matching building blocks (and ligs) and reactions...


100%|██████████| 2/2 [00:00<00:00, 17.16it/s]
100%|██████████| 100/100 [00:00<00:00, 3646.14it/s]
100%|██████████| 99/99 [00:00<00:00, 3337.78it/s]


Constructing searcher for every rxn reactant...


100%|██████████| 1/1 [00:00<00:00, 302.64it/s]


Constructing conformer searcher for every building block...


100%|██████████| 100/100 [00:00<00:00, 2980.69it/s]
Done!


In [7]:
loader = datamodule.train_dataloader()

In [8]:
next(iter(loader))

CRBBOutput(lig0_feats=(DataBatch(x=[703, 6], edge_index=[2, 1428], y=[32, 12], batch=[703], ptr=[33]), DataBatch(y=[32, 11], pos=[703, 3], batch=[703], ptr=[33])), rsfeats_2d=[DataBatch(x=[758, 6], edge_index=[2, 1534], y=[32, 12], batch=[758], ptr=[33]), DataBatch(x=[696, 6], edge_index=[2, 1416], y=[32, 12], batch=[696], ptr=[33]), DataBatch(x=[758, 6], edge_index=[2, 1534], y=[32, 12], batch=[758], ptr=[33])], rsfeats_3d=[DataBatch(y=[32, 11], pos=[758, 3], batch=[758], ptr=[33]), DataBatch(y=[32, 11], pos=[696, 3], batch=[696], ptr=[33]), DataBatch(y=[32, 11], pos=[758, 3], batch=[758], ptr=[33])], rslengths=tensor([2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2, 2, 2]), terminate=tensor([[False,  True, False],
        [ True, False, False],
        [False,  True, False],
        [ True, False, False],
        [False,  True, False],
        [False,  True, False],
        [False,  True, False],
        [ True, False, False],
        