In [2]:
import sys
sys.path.append("..")

from rdkit.Chem.Draw import IPythonConsole, MolsToGridImage

In [5]:
from typing import List, Optional

import lightning as L
from torch.utils.data import DataLoader

from src.data.dataset import CRBBDataset
from src.utils.dtypes import BuildingBlock as BB
from src.utils.dtypes import CustomReaction as CR

In [None]:
class CRBBDataModule(L.LightningDataModule):
    """Lightning data module that builds ``CRBBDataset`` splits.

    Reactions and reactant building blocks are shared by every split; only
    the ligands (given as objects and/or file paths) differ between the
    train, validation and test datasets.

    Args:
        root: Forwarded to ``CRBBDataset`` as ``root``.
        batch_size: Batch size used by the training dataloader.
        rxns: Reaction objects, forwarded as ``rxns``.
        rxn_files: Reaction file paths, forwarded as ``rxn_files``.
        reactant_bbs: Building-block objects, forwarded as ``bbs``.
        reactant_sdfs: Building-block file paths, forwarded as ``bb_files``.
        train_ligs: Ligand objects for the training split.
        val_ligs: Ligand objects for the validation split.
        test_ligs: Ligand objects for the test split.
        train_ligsdfs: Ligand file paths for the training split.
        val_ligsdfs: Ligand file paths for the validation split.
        test_ligsdfs: Ligand file paths for the test split.

    Every list argument may be omitted (or passed as ``None``), in which
    case a fresh empty list is stored on the instance.
    """

    def __init__(self,
                 root: str = 'crbb',
                 batch_size: int = 32,
                 rxns: Optional[List[CR]] = None,
                 rxn_files: Optional[List[str]] = None,
                 reactant_bbs: Optional[List[BB]] = None,
                 reactant_sdfs: Optional[List[str]] = None,
                 train_ligs: Optional[List[BB]] = None,
                 val_ligs: Optional[List[BB]] = None,
                 test_ligs: Optional[List[BB]] = None,
                 train_ligsdfs: Optional[List[str]] = None,
                 val_ligsdfs: Optional[List[str]] = None,
                 test_ligsdfs: Optional[List[str]] = None,):
        super().__init__()
        self.root = root
        self.batch_size = batch_size

        # Defaults are None rather than list(): a list() default is created
        # once at definition time and would be shared by every instance.
        self.rxns = rxns if rxns is not None else []
        self.rxn_files = rxn_files if rxn_files is not None else []

        self.reactant_bbs = reactant_bbs if reactant_bbs is not None else []
        self.reactant_sdfs = (
            reactant_sdfs if reactant_sdfs is not None else []
        )

        self.train_ligs = train_ligs if train_ligs is not None else []
        self.train_ligsdfs = (
            train_ligsdfs if train_ligsdfs is not None else []
        )

        self.val_ligs = val_ligs if val_ligs is not None else []
        self.val_ligsdfs = val_ligsdfs if val_ligsdfs is not None else []

        self.test_ligs = test_ligs if test_ligs is not None else []
        self.test_ligsdfs = test_ligsdfs if test_ligsdfs is not None else []

    def _build_dataset(self, ligs, lig_files):
        """Create a ``CRBBDataset`` for one split from the shared inputs."""
        return CRBBDataset(root=self.root,
                           rxns=self.rxns,
                           rxn_files=self.rxn_files,
                           bbs=self.reactant_bbs,
                           bb_files=self.reactant_sdfs,
                           ligs=ligs,
                           lig_files=lig_files)

    def setup(self, stage=None):
        """Build the datasets required for ``stage``.

        Args:
            stage: ``'fit'`` (or ``None``) builds the train and validation
                datasets, ``'validate'`` builds only the validation dataset
                and ``'test'`` builds only the test dataset. Any other
                value builds nothing.
        """
        # None is this method's own default, so a manual ``dm.setup()`` must
        # prepare the fit splits instead of silently doing nothing. The test
        # split is only built on explicit request.
        if stage in (None, 'fit'):
            self.train_dataset = self._build_dataset(self.train_ligs,
                                                     self.train_ligsdfs)
        if stage in (None, 'fit', 'validate'):
            self.val_dataset = self._build_dataset(self.val_ligs,
                                                   self.val_ligsdfs)
        if stage == 'test':
            self.test_dataset = self._build_dataset(self.test_ligs,
                                                    self.test_ligsdfs)

    def train_dataloader(self):
        """Return a dataloader over the training dataset.

        Requires ``setup`` to have been run for the fit stage first.
        """
        # The ``lightning`` package does not export a DataLoader; the
        # original ``L.DataLoader`` raised AttributeError.
        # NOTE(review): if CRBBDataset yields graph objects that need a
        # custom collate function, swap in the matching loader - confirm.
        return DataLoader(self.train_dataset, batch_size=self.batch_size)