In [1]:
import logging
import os
import pickle
import random
import time
from pathlib import Path
from typing import Callable, Optional, Tuple, Union

import h5py
import nibabel as nib
import numpy as np
import pandas as pd
import torch
import torchvision
from einops import rearrange
from torch.utils.data import DataLoader

%matplotlib inline

In [2]:
# Directory of qDESS HDF5 files read by SliceDataset. Set the QDESS_DATA_ROOT
# environment variable to point at another copy without editing the notebook.
mri_data_root = os.environ.get(
    "QDESS_DATA_ROOT",
    "/data/projects/recon/data/public/qdess/v1-release/files_recon_calib-24/",
)

In [3]:
class SliceDataset(torch.utils.data.Dataset):
    """
    A PyTorch Dataset that provides access to MR image slices.

    Each example is a ``(file_path, slice_index)`` pair. The number of slices per
    file is read from the file's ``"target"`` HDF5 dataset, and ``__getitem__``
    reads one slice of that dataset and passes it through ``transform``.
    """

    def __init__(
        self,
        root: Union[str, Path, os.PathLike],
        transform: Optional[Callable] = None,
        use_dataset_cache: bool = False,
        sample_rate: Optional[float] = None,
        volume_sample_rate: Optional[float] = None,
        dataset_cache_file: Union[str, Path, os.PathLike] = "dataset_cache.pkl",
    ):
        """
        Args:
            root: Directory containing the HDF5 volumes. Every entry directly
                under it is opened as an HDF5 file. ``root`` is also the key into
                the dataset cache, so pass it the same way (same str / Path value)
                each time to get cache hits.
            transform: Optional; A callable that pre-processes one slice. It is
                called as ``transform(target, filename, slice_index)`` where
                ``target`` is the array read from the file's "target" dataset,
                ``filename`` is the file's base name and ``slice_index`` is an
                int. If None, ``__getitem__`` returns that triple unchanged.
            use_dataset_cache: Whether to read/write the list of examples from
                ``dataset_cache_file`` instead of re-opening every file.
            sample_rate: Optional; A float between 0 and 1. This controls what fraction
                of the slices should be loaded. Defaults to 1 if no value is given.
                When creating a sampled dataset either set sample_rate (sample by slices)
                or volume_sample_rate (sample by volumes) but not both.
            volume_sample_rate: Optional; A float between 0 and 1. This controls what fraction
                of the volumes should be loaded. Defaults to 1 if no value is given.
                When creating a sampled dataset either set sample_rate (sample by slices)
                or volume_sample_rate (sample by volumes) but not both.
            dataset_cache_file: Optional; A pickle file in which to cache the
                example list for faster load times.

        Raises:
            ValueError: If both ``sample_rate`` and ``volume_sample_rate`` are given.
        """
        if sample_rate is not None and volume_sample_rate is not None:
            raise ValueError(
                "either set sample_rate (sample by slices) or volume_sample_rate (sample by volumes) but not both"
            )

        self.dataset_cache_file = Path(dataset_cache_file)
        self.transform = transform
        self.examples = []

        # set default sampling mode if none given
        if sample_rate is None:
            sample_rate = 1.0
        if volume_sample_rate is None:
            volume_sample_rate = 1.0

        # Load the cache only if the user asked for it and it exists.
        # NOTE(review): pickle.load can execute arbitrary code; only point this
        # at cache files you created yourself.
        if use_dataset_cache and self.dataset_cache_file.exists():
            with open(self.dataset_cache_file, "rb") as f:
                dataset_cache = pickle.load(f)
        else:
            dataset_cache = {}

        if use_dataset_cache and dataset_cache.get(root) is not None:
            # Cache hit: reuse the stored example list (copied so later
            # shuffling does not touch the loaded cache dict).
            logging.info(f"Using dataset cache from {self.dataset_cache_file}.")
            self.examples = list(dataset_cache[root])
        else:
            # Cache miss or caching disabled: open every file to count slices.
            for fname in sorted(Path(root).iterdir()):
                num_slices = self._retrieve_metadata(fname)
                self.examples.extend(
                    (fname, slice_ind) for slice_ind in range(num_slices)
                )

            if use_dataset_cache:
                # Store the full (unsampled) list; subsampling happens below.
                dataset_cache[root] = self.examples
                logging.info(f"Saving dataset cache to {self.dataset_cache_file}.")
                with open(self.dataset_cache_file, "wb") as f:
                    pickle.dump(dataset_cache, f)

        # Subsample if desired. Uses the global `random` state, so seed
        # `random` beforehand for a reproducible subset.
        if sample_rate < 1.0:  # sample by slice
            random.shuffle(self.examples)
            num_examples = round(len(self.examples) * sample_rate)
            self.examples = self.examples[:num_examples]
        elif volume_sample_rate < 1.0:  # sample by volume
            vol_names = sorted({example[0].stem for example in self.examples})
            random.shuffle(vol_names)
            num_volumes = round(len(vol_names) * volume_sample_rate)
            # set for O(1) membership tests in the filter below
            sampled_vols = set(vol_names[:num_volumes])
            self.examples = [
                example for example in self.examples if example[0].stem in sampled_vols
            ]

    @staticmethod
    def _retrieve_metadata(fname) -> int:
        """Return the length of axis 2 of ``fname``'s "target" dataset (used as the slice count)."""
        with h5py.File(fname, "r") as hf:
            # NOTE(review): the slice count comes from axis 2, but __getitem__
            # indexes axis 0 (hf["target"][dataslice]). Unless those axes are
            # meant to have the same length, only the first shape[2] positions
            # along axis 0 are ever read. Confirm the intended slice axis and
            # make the two agree.
            return hf["target"].shape[2]

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, i: int):
        """Read slice ``i`` from disk and apply ``transform`` (if any) to it."""
        fname, dataslice = self.examples[i]

        with h5py.File(fname, "r") as hf:
            # Indexes axis 0 of "target"; see the NOTE in _retrieve_metadata.
            target = hf["target"][dataslice]

        # The slice is already an in-memory array, so the file can be closed
        # before the (possibly expensive) transform runs.
        if self.transform is None:
            return target, fname.name, dataslice
        return self.transform(target, fname.name, dataslice)

In [4]:
class DataTransform:
    """
    Turns one slice read by SliceDataset into a magnitude image stack.

    The slice must be 4-D. Labelling its axes (X, Y, Z, E), the last two axes are
    merged into one leading channel axis and the magnitude (``np.abs``) is taken,
    so the result is a real array of shape (Z * E, X, Y).
    """

    def __init__(self, use_seed: bool = True):
        """
        Args:
            use_seed: Stored as ``self.use_seed``. It is not read by
                ``__call__``; kept for interface compatibility.
        """
        self.use_seed = use_seed

    def __call__(
        self,
        target: np.ndarray,
        fname: str,
        slice_num: int,
    ) -> Tuple[np.ndarray, str, int]:
        """
        Args:
            target: 4-D slice array, axes labelled (X, Y, Z, E); may be complex.
            fname: File name the slice came from; returned unchanged.
            slice_num: Slice index; returned unchanged.

        Returns:
            Tuple of (magnitude array of shape (Z * E, X, Y), fname, slice_num).
        """
        # '(z e)' flattens the last two axes into a single channel axis, with z
        # as the outer (slower-varying) index.
        channels_first = rearrange(target, 'x y z e -> (z e) x y')
        return np.abs(channels_first), fname, slice_num

In [5]:
def create_training_loaders(
    data_path: Union[str, Path, os.PathLike],
    sample_rate: float,
    batch_size: int,
    num_workers: int,
    TrainingTransform: Callable,
) -> torch.utils.data.DataLoader:
    """
    Build a shuffling DataLoader over a SliceDataset rooted at ``data_path``.

    Args:
        data_path: Directory of HDF5 volumes; forwarded to SliceDataset as ``root``.
        sample_rate: Fraction of slices to keep; forwarded to SliceDataset.
        batch_size: Number of slices per batch.
        num_workers: Number of DataLoader worker processes.
        TrainingTransform: Callable applied to every slice as
            ``transform(target, filename, slice_index)`` (e.g. a DataTransform
            instance; it does not need to be a torchvision Compose).

    Returns:
        A DataLoader with ``shuffle=True`` and ``pin_memory=False``.
    """
    dataset = SliceDataset(
        root=data_path,
        transform=TrainingTransform,
        sample_rate=sample_rate,
    )
    return DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=False,
    )

In [6]:
# Full dataset (sample_rate=1.0), one slice per batch, four loader workers.
train_loader = create_training_loaders(
    data_path=mri_data_root,
    sample_rate=1.0,
    batch_size=1,
    num_workers=4,
    TrainingTransform=DataTransform(),
)

In [7]:
# Time one full pass over the loader, counting slices rather than batches.
init_start = time.perf_counter()
num_slices = 0
for target, fname, slice_num in train_loader:
    # slice_num is collated into one entry per slice in the batch, so this is
    # correct for any batch_size (it equals 1 when batch_size=1).
    num_slices += len(slice_num)
elapsed = time.perf_counter() - init_start
print(f"Parsed slices {num_slices}. Time taken {np.round(elapsed, 2)}s")

Parsed slices 24780. Time taken 20.35s
