In [128]:
import pandas as pd
import itertools
from pathlib import Path
from monai.data import Dataset
from pytorch_lightning import LightningDataModule
from sklearn.model_selection import train_test_split
from torch.utils.data import DataLoader
from monai.transforms import (
    Compose,
    LoadImaged,
    EnsureChannelFirstd,
    CenterSpatialCropd,
    ToTensord
)


class DataModule:
    def __init__(self, path_to_center_folders, prediction_target_file_path, target, test_center, batch_size=32):
        self.root = Path(path_to_center_folders)
        self.target = pd.read_csv(prediction_target_file_path).set_index('lesion')[target]
        self.test_center = test_center
        self.batch_size=batch_size
        self.train_transform = Compose([
            LoadImaged(keys=['img']),
            EnsureChannelFirstd(keys=['img']),
            CenterSpatialCropd(keys=['img'], roi_size=(96,96,96)),
            ToTensord(keys=['img','label'])
        ])
        self.val_transform = self.train_transform
        self.test_transform = self.train_transform
        self.centers = [c.name for c in self.root.iterdir()]
    
    def setup(self):
        dev_centers = [c for c in self.centers if not c == self.test_center]

        # development data
        dev_data = list(itertools.chain(*[
            self.data_dir_to_dict(self.root / c) for c in dev_centers
        ]))
        train_data, val_data = train_test_split(dev_data, test_size=0.75)
        self.train_dataset = Dataset(train_data, self.train_transform)
        self.val_dataset = Dataset(val_data, self.val_transform)

        # test data
        test_data = self.data_dir_to_dict(self.root / self.test_center)
        self.test_dataset = Dataset(test_data, self.test_transform)

    def data_dir_to_dict(self, dir):
        return [{'img':str(lesion_path),'label':self.target.loc[lesion_path.name]}
            for lesion_path in dir.iterdir() if lesion_path.name in self.target.index
        ]

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, num_workers=1)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False, num_workers=1)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False, num_workers=1)

    
dm = DataModule(
    r'C:\Users\user\data\dl_radiomics\preprocessed_3d',
    r'C:\Users\user\data\tables\lesion_followup_curated_v4.csv',
    'liver',
    'amphia'
)


In [129]:
dm.setup()

In [133]:
dm.train_dataset[2]['img'].shape

(1, 96, 96, 46)

In [131]:
x, y = next(iter(dm.train_dataloader()))

RuntimeError: Caught RuntimeError in DataLoader worker process 0.
Original Traceback (most recent call last):
  File "c:\Users\user\anaconda3\envs\rob\lib\site-packages\torch\utils\data\_utils\worker.py", line 302, in _worker_loop
    data = fetcher.fetch(index)
  File "c:\Users\user\anaconda3\envs\rob\lib\site-packages\torch\utils\data\_utils\fetch.py", line 52, in fetch
    return self.collate_fn(data)
  File "c:\Users\user\anaconda3\envs\rob\lib\site-packages\torch\utils\data\_utils\collate.py", line 160, in default_collate
    return elem_type({key: default_collate([d[key] for d in batch]) for key in elem})
  File "c:\Users\user\anaconda3\envs\rob\lib\site-packages\torch\utils\data\_utils\collate.py", line 160, in <dictcomp>
    return elem_type({key: default_collate([d[key] for d in batch]) for key in elem})
  File "c:\Users\user\anaconda3\envs\rob\lib\site-packages\torch\utils\data\_utils\collate.py", line 141, in default_collate
    return torch.stack(batch, 0, out=out)
  File "c:\Users\user\anaconda3\envs\rob\lib\site-packages\monai\data\meta_tensor.py", line 249, in __torch_function__
    ret = super().__torch_function__(func, types, args, kwargs)
  File "c:\Users\user\anaconda3\envs\rob\lib\site-packages\torch\_tensor.py", line 1121, in __torch_function__
    ret = func(*args, **kwargs)
RuntimeError: stack expects each tensor to be equal size, but got [1, 96, 96, 47] at entry 0 and [1, 96, 96, 28] at entry 4
