Now we need to specify how the Data objects in the created Dataset are split into training, validation and test sets. This is where PyTorch Lightning's DataModule comes in. A DataModule is a class that encapsulates the logic for loading, splitting and batching the data. It keeps data handling separate from the model and training logic. That separation makes the code more modular and easier to maintain. It also makes it easy to switch between different datasets and data-loading strategies.

In [None]:
import lightning as L
import torch
from torch.utils.data import random_split
from torch_geometric.loader import DataLoader

from src import dataloader

class ProteinGraphDataModule(L.LightningDataModule):
    """Lightning DataModule providing train/val/test loaders for protein graphs.

    Protein names are read from ``dataset_file`` (one per line) and passed to
    ``dataloader.ProteinDataset`` together with ``root`` and the transforms.
    ``setup`` splits the resulting dataset 80/10/10 into train/val/test.

    Args:
        root: Root directory passed to ``ProteinDataset``.
        dataset_file: Text file with one protein name per line; blank lines
            are ignored.
        pre_transform: Passed through unchanged to ``ProteinDataset``.
        transform: Passed through unchanged to ``ProteinDataset``.
        batch_size: Graphs per batch in every DataLoader (default 1, the
            PyG ``DataLoader`` default).
        num_workers: Worker processes per DataLoader (default 0 = load in
            the main process).
        seed: Seed for the train/val/test split. A fixed seed makes the split
            identical on every DDP process and across repeated ``setup``
            calls (e.g. ``fit`` then ``test``), so validation/test proteins
            never end up in the training set. ``None`` falls back to torch's
            global RNG (non-reproducible).
        max_proteins: If not None, only the first ``max_proteins`` names are
            used (a small subset for quick testing). ``None`` uses all names.
    """

    def __init__(self, root, dataset_file, pre_transform, transform,
                 batch_size=1, num_workers=0, seed=42, max_proteins=10):
        super().__init__()
        self.root = root
        self.dataset_file = dataset_file
        with open(dataset_file, encoding="utf-8") as f:
            # Drop blank lines (e.g. a trailing newline) so no empty name
            # is handed to the dataset.
            self.protein_names = [name for name in (line.strip() for line in f) if name]
        if max_proteins is not None:
            # Default of 10 keeps the small testing subset used so far.
            self.protein_names = self.protein_names[:max_proteins]
        self.pre_transform = pre_transform
        self.transform = transform
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.seed = seed

    def _make_dataset(self):
        """Construct a ``ProteinDataset`` from this module's settings."""
        return dataloader.ProteinDataset(root=self.root,
                                         protein_names=self.protein_names,
                                         pre_transform=self.pre_transform,
                                         transform=self.transform)

    def prepare_data(self):
        """Build the dataset once so its graphs are created/saved under ``root``.

        Lightning calls this on a single process in distributed settings,
        so no state is assigned here.
        """
        self._make_dataset()

    def setup(self, stage=None):
        """Load the dataset and split it 80/10/10 into train/val/test.

        Lightning calls this on every process (DDP) and once per stage. The
        seeded generator therefore matters: without it each call would draw
        a different split.
        """
        dataset = self._make_dataset()
        if self.seed is None:
            generator = torch.default_generator
        else:
            generator = torch.Generator().manual_seed(self.seed)
        train_idx, val_idx, test_idx = random_split(
            range(len(dataset)), [0.8, 0.1, 0.1], generator=generator
        )
        self.train = dataset[list(train_idx)]
        self.val = dataset[list(val_idx)]
        self.test = dataset[list(test_idx)]

    def train_dataloader(self):
        # Reshuffle the training graphs every epoch.
        return DataLoader(self.train, batch_size=self.batch_size,
                          shuffle=True, num_workers=self.num_workers)

    def val_dataloader(self):
        # Evaluation loaders keep a fixed order.
        return DataLoader(self.val, batch_size=self.batch_size,
                          num_workers=self.num_workers)

    def test_dataloader(self):
        return DataLoader(self.test, batch_size=self.batch_size,
                          num_workers=self.num_workers)

TODO: explain dataloaders, batching etc.
TODO: add batch_size num_workers
TODO: how would it change if you had a predefined train/val/test split
TODO: make a datamodule and loop through the dataloaders to show how it works