In [1]:
from src import *
from data.scanner.CAVA_V1 import *

In [21]:
class SparseCartesianDataset():
    """
    Cartesian dataset stored without the zero entries of the k-space: instead of a sampling mask,
    it keeps a tensor of the measured k-space coordinates (trajectory).

    attributes:
    - `kspace`: shape (Nk, Nc, Nl, Nr, 2), complex k-space measurements (real and imaginary part in the last dimension)
    - `trajectory`: shape (Nk, Nl, Nr, 3), k-space coordinates of the measurements. Order of the last dimension: z y x.
      (the 2D loader below stores pi * ky / Ny and pi * kx / Nx, with z = 0)
    - `mask`: shape (Nk, Nl, Nr, 3), integer indices of the measured coordinates on the Cartesian grid
      (same information as `trajectory`, but in a more accessible format). Order of the last dimension: z y x.
    - `line_indices`: shape (Nk, Nl), position of every stored line among all measured lines (acquisition order)
    - `smaps`: 4-D tensor of coil sensitivities; its last three axes are read as (Nz, Ny, Nx)
    - `reference`: optional reference data indexed by frame along its first axis, or None
    - `transform`: optional function: dict -> dict, applied to every sample returned by `__getitem__`
    - `additional_data`: arbitrary data

    naming:
    - `Nk`: number of frames
    - `Nl`: number of measured k-space lines per frame
    - `Nr`: number of measurements in the read-out direction (x-direction)
    - `Nc`: number of receiver coils
    - `Nz`, `Ny`, `Nx`: resolution in z, y, and x-direction
    """

    def __init__(self, kspace, trajectory, mask, line_indices, smaps, reference=None, transform=None, additional_data=None):

        self.kspace = kspace # shape (Nk, Nc, Nl, Nr, 2) torch.tensor float32
        self.trajectory = trajectory # shape (Nk, Nl, Nr, 3) torch.tensor float32
        self.mask = mask # shape (Nk, Nl, Nr, 3) LongTensor
        self.line_indices = line_indices # shape (Nk, Nl) LongTensor
        self.smaps = smaps # 4-D tensor, see NOTE below
        self.transform = transform # function: transform(dict: sample) -> dict
        self.reference = reference
        self.additional_data = additional_data # arbitrary data

        assert kspace.ndim == 5
        assert trajectory.ndim == 4
        assert mask.ndim == 4
        assert line_indices.ndim == 2
        assert smaps.ndim == 4
        assert kspace.shape[-1] == 2
        assert trajectory.shape[-1] == 3
        assert mask.shape[-1] == 3

        self.Nk, self.Nc, self.Nl, self.Nr, _ = kspace.shape
        # NOTE(review): the 2D loader passes smaps = to_tensor(<complex (Nc, Ny, Nx) array>). If to_tensor
        # splits real/imag into a trailing axis, the unpacking below yields Nz=Ny, Ny=Nx, Nx=2 -- confirm the
        # intended smaps layout before relying on Nz/Ny/Nx.
        _, self.Nz, self.Ny, self.Nx = smaps.shape

    @classmethod
    def from_sparse_matfile2d_extract_validation_dataset_rebin(cls, matfile_path, listfile_path,
                                                             transform=None,
                                                             shift=False,
                                                             remove_padding=False,
                                                             set_smaps_outside_to_one=False,
                                                             validation_percentage=0.,
                                                             number_of_lines_per_frame=6,
                                                             max_Nk=-1,
                                                             seed=1998
                                                             ):
        """
        Loads `matfile_path`, which stores measurement data without zeros in the k-space (sparse representation of data).
        Requires `listfile_path`, which contains information about the order and position of the measured k-space lines.
        Randomly extracts `validation_percentage` percent of the k-space lines for validation.
        Bins the remaining lines into frames with `number_of_lines_per_frame` lines each. If `max_Nk` is specified,
        the number of frames is reduced and excess frames are discarded.

        Arguments:
        - `transform`: An optional function that is applied to every sample loaded from the dataset with `__getitem__`.
        - `shift`: If true, the image is shifted by Ny/4. This is necessary if the k-space data does not match the smaps otherwise.
        - `remove_padding`: By default, the k-space data of the scanner is zero-padded and the smaps are computed on a larger grid.
          If `remove_padding` is true, the padding is removed and the smaps are cropped in the Fourier domain.
        - `set_smaps_outside_to_one`: By default, the smaps estimated by the scanner have zero entries outside the human body
          (they cannot be estimated outside, as there is no signal). Thus, the reconstructions can take arbitrary values outside
          the body without affecting the reconstruction loss. If `set_smaps_outside_to_one` is true, the zero-sensitivities are
          set to 1.0, which forces the reconstructions to zero outside the human body.
        - `validation_percentage`: Percentage of the k-space lines that are used for validation. Default: 0.
        - `number_of_lines_per_frame`: Number of k-space lines per frame. Default: 6.
        - `max_Nk`: If not -1, the number of frames is limited to at most `max_Nk`.
        - `seed`: seed for the random extraction of k-space lines. The global numpy RNG is not modified.

        Returns:
        - `(dataset, validation_dataset)`: the dataset, and a list of dicts with keys "line_index", "k", "kspace",
          "trajectory" and "mask" for every validation line. Only validation lines that precede the last binned
          training line are collected.

        Raises:
        - `ValueError` if `remove_padding` is set and the KyRange of the encoding parameters is not supported.
        """

        # Private RNG: reproduces the permutation of np.random.seed(seed) + np.random.shuffle without touching
        # (and possibly leaking, if loading fails) the global RNG state.
        rng = np.random.RandomState(seed)

        list_data = ListData(file_name=listfile_path)

        # detect datasets that are binned by cardiac phases -> handle them separately as they probably use the SENSE pattern
        assert list_data.Nk_card == 1, "ECG-binned data needs to be loaded differently (Nk_card > 1, Nk == 1). Not implemented yet for sparse datasets."

        # load matrices from the .mat file
        with h5py.File(matfile_path, 'r', rdcc_nbytes=1024**3, rdcc_w0=1, rdcc_nslots=1024) as f:
            raw_smaps = h5py2Complex(f["smaps"], load_in_chunks=False)
            raw_kspace = h5py2Complex(f["kspace"], load_in_chunks=False)
            reference = np.array(f["reference"])
            encoding_pars = parse_struct(f["encoding_pars"])

        # if required, set the zero pixels of the smaps (outside the body) to 1.0
        if set_smaps_outside_to_one:
            raw_smaps[raw_smaps == 0.] = 1.

        if remove_padding:
            # resolution in x-direction without padding
            Nx = int(encoding_pars["KxRange"][1] - encoding_pars["KxRange"][0] + 1)
            Nc, Nyold, Nxold = raw_smaps.shape

            # resolution in y-direction without padding
            ky_min, ky_max = encoding_pars["KyRange"][0], encoding_pars["KyRange"][1]
            if -ky_min == ky_max + 1: # standard case
                Ny = int(ky_max - ky_min + 1)
            elif -ky_min < ky_max: # probably partial-Fourier
                Ny = int(2 * ky_max)
            else:
                raise ValueError(f"unsupported KyRange {encoding_pars['KyRange']} for smaps of shape {raw_smaps.shape}")

            # truncate the smaps in the Fourier domain (centered crop to Ny x Nx)
            Nystart, Nxstart = int((Nyold - Ny) / 2), int((Nxold - Nx) / 2)
            smaps_fft = fft2(to_tensor(raw_smaps))
            smaps_fft = smaps_fft[:, Nystart:(Nystart+Ny), Nxstart:(Nxstart+Nx), :]
            smaps = ifft2(smaps_fft)
        else: # keep the padding
            Nc, Ny, Nx = raw_smaps.shape
            smaps = to_tensor(raw_smaps)

        # for every ky-line in `raw_kspace`: index of the dynamic, index of the coil, ky index (shifted)
        dynamics, coil_indices, ky_indices = list_data.get_dynamics_channel_indices_and_kyindices()
        num_lines_all_coils, Nr = raw_kspace.shape
        assert num_lines_all_coils % Nc == 0
        Nl = num_lines_all_coils // Nc

        # random subset of measured lines used for validation (set -> O(1) membership in the binning loop)
        permutation = np.arange(Nl)
        rng.shuffle(permutation)
        validation_indices = set(permutation[0:int(Nl * validation_percentage / 100)].tolist())

        Nk = (Nl - len(validation_indices)) // number_of_lines_per_frame
        if max_Nk != -1:
            Nk = min(Nk, max_Nk)

        # matrices the sparse data is filled into
        kspace = np.zeros((Nk, Nc, number_of_lines_per_frame, Nr), dtype=np.csingle)
        mask = np.zeros((Nk, number_of_lines_per_frame, Nr, 3), dtype=np.int64)
        trajectory = np.zeros((Nk, number_of_lines_per_frame, Nr, 3), dtype=np.float32)
        line_indices = np.zeros((Nk, number_of_lines_per_frame), dtype=np.int64)

        # zero index of the k-space matrix in ky-direction
        ky_zero_index = int(Ny / 2)
        # first kx index that is filled with data; the kx layout is identical for every line, so compute it once
        kx_shift = int(Nx / 2 - Nr / 2)
        kx_indices = kx_shift + np.arange(Nr)
        kx_coordinates = np.pi * (-int(Nr/2) + np.arange(Nr)) / Nx
        zeros_r = np.zeros(Nr)
        ones_r = np.ones(Nr)

        def line_generator():
            """Yields (kspace (Nc, Nr), trajectory (Nr, 3) float32, mask (Nr, 3) int64) for every measured line, in order."""
            for l in range(Nl):
                first = l * Nc
                # stack the measurements of all coils belonging to line l
                line_kspace = np.zeros((Nc, Nr), dtype=np.csingle)
                for c in range(Nc):
                    i = first + c
                    assert dynamics[i] == dynamics[first]
                    assert ky_indices[i] == ky_indices[first]
                    line_kspace[coil_indices[i], :] = raw_kspace[i, :]

                ky = ky_indices[first]
                ky_index = ky_zero_index + ky
                ky_coordinate = np.pi * ky / Ny

                if shift:
                    line_kspace *= np.exp(np.pi*1j*(ky_index - Ny/2))

                # same dtypes as the training arrays, so training and validation lines are consistent
                line_mask = np.stack((zeros_r, ones_r*ky_index, kx_indices), axis=-1).astype(np.int64)
                line_trajectory = np.stack((zeros_r, ones_r*ky_coordinate, kx_coordinates), axis=-1).astype(np.float32)

                yield line_kspace, line_trajectory, line_mask

        lines = line_generator()
        j = 0 # index of the next raw line
        validation_dataset = []
        for k in range(Nk):
            for l in range(number_of_lines_per_frame):
                while j in validation_indices: # put lines in the validation dataset
                    line_kspace, line_trajectory, line_mask = next(lines)
                    validation_dataset.append({
                        "line_index": j,
                        "k": k,
                        "kspace": to_tensor(line_kspace),
                        "trajectory": to_tensor(line_trajectory),
                        "mask": to_tensor(line_mask)
                    })
                    j += 1

                # put the line in the training dataset
                line_kspace, line_trajectory, line_mask = next(lines)
                kspace[k, :, l, :] = line_kspace
                trajectory[k, l, :, :] = line_trajectory
                mask[k, l, :, :] = line_mask
                line_indices[k, l] = j

                j += 1

        # this method handles 2D datasets with a single slice: the z-axis only appears as the (zero) first
        # column of trajectory and mask; no axis is inserted into kspace or smaps
        kspace = to_tensor(kspace)
        trajectory = to_tensor(trajectory)
        mask = to_tensor(mask)
        line_indices = to_tensor(line_indices) # LongTensor, as documented in __init__

        # if reference data is available, the reference matrix has at least 3 dimensions
        if reference.ndim < 3:
            reference = None
        else:
            reference = to_tensor(np.expand_dims(reference, axis=1))

        return cls(kspace, trajectory, mask, line_indices, smaps, reference=reference, transform=transform), validation_dataset

    @classmethod
    def from_sparse_cartesian_dataset(cls, dataset, transform=None):
        """
        Creates a shallow copy (underlying data remains identical).

        Use this method to initialize many different subclasses from the same data, without loading the data into memory multiple times.
        """
        return cls(dataset.kspace, dataset.trajectory, dataset.mask, dataset.line_indices, dataset.smaps, reference=dataset.reference, additional_data=dataset.additional_data, transform=transform)

    def shape(self):
        """
        Returns (Nk, Nc, Nz, Ny, Nx).
        """
        return (self.Nk, self.Nc, self.Nz, self.Ny, self.Nx)

    def subset(self, subset_indices):
        """Returns a `Subset` view of this dataset restricted to `subset_indices`."""
        return Subset(self, subset_indices)

    def __len__(self):
        return self.Nk

    def __getitem__(self, k):
        """
        Returns the frames selected by `k` as a dict (passed through `self.transform` if set).

        `k` may be an integer (including numpy integers), a slice (None/negative/out-of-range bounds are resolved
        like for Python sequences) or an iterable of frame indices. The frame axis is always kept, i.e. an integer
        index yields a leading dimension of size 1.
        """
        if isinstance(k, (int, np.integer)):
            indices = [int(k)]
        elif isinstance(k, slice):
            indices = range(*k.indices(self.Nk))
        elif isinstance(k, Iterable):
            indices = k
        else:
            raise TypeError("invalid index format")

        # advanced indexing with an explicit list is supported by both torch tensors and numpy arrays
        selection = list(indices) if isinstance(indices, range) else indices

        # reference frames are selected along the first axis of the stored reference
        # NOTE(review): assumes the reference has at least as many frames as the k-space -- confirm
        reference = None
        if self.reference is not None:
            reference = self.reference[selection]

        sample = {
            'indices': indices,
            'kspace': self.kspace[selection, :, :, :, :],
            'trajectory': self.trajectory[selection, :, :, :],
            'mask': self.mask[selection, :, :, :],
            'line_indices': self.line_indices[selection, :],
            'smaps': self.smaps,
            'reference': reference,
        }

        if self.transform:
            return self.transform(sample)
        return sample

In [25]:
# load dataset no. 10 of the CAVA_V1 scanner collection without zero-padding and hold out 5 % of the lines for validation
dataset_info = datasets_cava_v1[10]
dataset, validation_dataset = SparseCartesianDataset.from_sparse_matfile2d_extract_validation_dataset_rebin(
    dataset_info["matfile_path"],
    dataset_info["listfile_path"],
    remove_padding=True,
    validation_percentage=5,
)