In [23]:
# Adjusted SliceDataset class: instead of taking the directory of the
# fastMRI dataset as input, as the original implementation does, this
# class takes a list of paths to individual fastMRI files and builds a
# SliceDataset from exactly those files.
import fastmri
import pandas as pd
import h5py
import numpy as np
import torch
from pathlib import Path
import xml.etree.ElementTree as etree


from typing import (
    Any,
    Callable,
    Dict,
    List,
    NamedTuple,
    Optional,
    Sequence,
    Tuple,
    Union,
)

def et_query(
    root: etree.Element,
    qlist: Sequence[str],
    namespace: str = "http://www.ismrm.org/ISMRMRD",
) -> str:
    """
    Look up a (possibly nested) XML element and return its text.

    Each tag in ``qlist`` is namespace-qualified and joined with the ``//``
    descendant axis, so ["encoding", "matrixSize"] becomes
    ``.//ns:encoding//ns:matrixSize``. The first match wins.

    Args:
        root: Element to start the search from.
        qlist: Tag names to descend through, outermost first, e.g.
            ["Encoding", "matrixSize"].
        namespace: Optional; XML namespace applied to every tag in the query.
    Returns:
        ``str()`` of the matched element's text (the string "None" if the
        element has no text).
    Raises:
        RuntimeError: If no element matches the query.
    """
    ns_alias = "ismrmrd_namespace"
    descent = "".join(f"//{ns_alias}:{tag}" for tag in qlist)

    match = root.find("." + descent, {ns_alias: namespace})
    if match is None:
        raise RuntimeError("Element not found")
    return str(match.text)

class FastMRIRawDataSample(NamedTuple):
    """Index entry for one slice: its source file, slice index and file metadata."""

    # path to the HDF5 file that holds the slice
    fname: Path
    # index of the slice within that file
    slice_ind: int
    # metadata dict for the file (the same object is shared by all its slices)
    metadata: Dict[str, Any]

class SliceDataset(torch.utils.data.Dataset):
    """
    A PyTorch Dataset that provides access to MR image slices.

    Unlike the original fastMRI implementation, which scans a dataset
    directory, this variant is built from an explicit list of HDF5 files.
    Every slice of every listed file becomes one dataset item, ordered by
    sorted file path and then by slice index.
    """

    def __init__(
        self,
        files,
        challenge: str = 'multicoil',
        transform: Optional[Callable] = None,
    ):
        """
        Args:
            files: Iterable of file paths (``str`` or ``Path``). A single
                path is also accepted and treated as a one-element list.
            challenge: "singlecoil" or "multicoil" depending on which challenge
                to use. Selects the reconstruction dataset used as target
                ("reconstruction_esc" or "reconstruction_rss").
            transform: Optional; A callable object that pre-processes the raw
                data into appropriate form. It is called as
                ``transform(kspace, mask, target, attributes, filename, slice)``.
                'mask' is None when the file has no "mask" dataset and
                'target' is None when the file has no reconstruction dataset
                (e.g. test data).
        Raises:
            ValueError: If ``challenge`` is not "singlecoil" or "multicoil".
        """
        if challenge not in ("singlecoil", "multicoil"):
            raise ValueError('challenge should be either "singlecoil" or "multicoil"')

        self.transform = transform
        self.recons_key = (
            "reconstruction_esc" if challenge == "singlecoil" else "reconstruction_rss"
        )

        if isinstance(files, (str, Path)):
            # A lone path would otherwise be iterated character by character.
            files = [files]

        self.raw_samples = []
        for fname in sorted(Path(f) for f in files):
            metadata, num_slices = self._retrieve_metadata(fname)
            # All slices of one file share the same metadata dict.
            self.raw_samples.extend(
                FastMRIRawDataSample(fname, slice_ind, metadata)
                for slice_ind in range(num_slices)
            )

    def _retrieve_metadata(self, fname: Path):
        """
        Read per-file metadata and the slice count from a fastMRI HDF5 file.

        Args:
            fname: Path to the HDF5 file.
        Returns:
            ``(metadata, num_slices)``. ``metadata`` holds "padding_left",
            "padding_right", "encoding_size" and "recon_size", all derived
            from the XML in the "ismrmrd_header" dataset. It is then merged
            with the file's HDF5 attributes, which win on key collisions.
            ``num_slices`` is the length of the first axis of "kspace".
        Raises:
            RuntimeError: (from et_query) if a required header element is missing.
        """
        with h5py.File(fname, "r") as hf:
            et_root = etree.fromstring(hf["ismrmrd_header"][()])

            enc = ["encoding", "encodedSpace", "matrixSize"]
            enc_size = (
                int(et_query(et_root, enc + ["x"])),
                int(et_query(et_root, enc + ["y"])),
                int(et_query(et_root, enc + ["z"])),
            )
            rec = ["encoding", "reconSpace", "matrixSize"]
            recon_size = (
                int(et_query(et_root, rec + ["x"])),
                int(et_query(et_root, rec + ["y"])),
                int(et_query(et_root, rec + ["z"])),
            )

            lims = ["encoding", "encodingLimits", "kspace_encoding_step_1"]
            enc_limits_center = int(et_query(et_root, lims + ["center"]))
            enc_limits_max = int(et_query(et_root, lims + ["maximum"])) + 1

            # NOTE(review): presumably the column range of the acquired
            # phase-encoding lines within the encoded width — confirm with
            # the transform that consumes these values.
            padding_left = enc_size[1] // 2 - enc_limits_center
            padding_right = padding_left + enc_limits_max

            num_slices = hf["kspace"].shape[0]

            metadata = {
                "padding_left": padding_left,
                "padding_right": padding_right,
                "encoding_size": enc_size,
                "recon_size": recon_size,
                **hf.attrs,
            }

        return metadata, num_slices

    def __len__(self):
        """Total number of slices across all files."""
        return len(self.raw_samples)

    def __getitem__(self, i: int):
        """
        Load slice ``i`` from disk.

        Returns:
            ``(kspace, mask, target, attrs, filename, slice_ind)`` when no
            transform is set, otherwise whatever ``self.transform`` returns
            for those same arguments. ``attrs`` is a fresh dict, so the
            transform may modify it without affecting other slices.
        """
        fname, dataslice, metadata = self.raw_samples[i]

        with h5py.File(fname, "r") as hf:
            kspace = hf["kspace"][dataslice]
            mask = np.asarray(hf["mask"]) if "mask" in hf else None
            target = hf[self.recons_key][dataslice] if self.recons_key in hf else None

            attrs = dict(hf.attrs)
            attrs.update(metadata)

        # Everything above has been read into memory, so the file is closed
        # before the (possibly expensive) transform runs.
        if self.transform is None:
            return kspace, mask, target, attrs, fname.name, dataslice
        return self.transform(kspace, mask, target, attrs, fname.name, dataslice)

In [24]:
# How to use
import pandas as pd
df_dataset = pd.read_csv('knee_trainset.csv')

# FILTER DATAFRAME HERE
# For example, filter by k-space shape (encodeX x encodeY):
df_filtered = df_dataset[(df_dataset['encodeX'] == 640) & (df_dataset['encodeY'] == 368)]

# Create list of file paths
path_to_fastmri = '/media/hdd1/fastMRIdata/knee_singlecoil_train/'  # Adjust this
# Join with Path so a missing trailing slash in path_to_fastmri cannot glue the
# directory and file names together.
files = [Path(path_to_fastmri) / fname for fname in df_filtered['filename']]
assert files, "The filter matched no files; check the filter conditions and the CSV."

# Slice dataset
slice_dataset = SliceDataset(files, challenge='singlecoil', transform=None)

# Analogous for knee_valset, and brain dataset