In [1]:
import numpy as np
import os
import nibabel as nib
from scipy.ndimage import zoom

import random
import torch
from torch.utils.data import Dataset

from preprocessing_utils import calc_mean_std_tensor, preprocess_scale_epi

# Model
from pytorch3dunet.unet3d import model


In [None]:
# Prefer the GPU whenever PyTorch can see one; otherwise run on the CPU.
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)
print(f'Device: {device}')

# Floating-point tensors created from here on default to 32-bit precision.
torch.set_default_dtype(torch.float32)

Device: cuda


In [None]:
def set_seed(seed=42):
    """
    Seed every random number generator used in this notebook for reproducibility.

    Parameters:
    - seed: Integer seed applied to Python's `random`, NumPy and PyTorch.

    Returns:
    None
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Seed all visible GPUs rather than only the current one.
    # The call is silently ignored when CUDA is unavailable.
    torch.cuda.manual_seed_all(seed)

set_seed(42)

### Loading the model with the weights from the best 3D model

The training data for the 4D-UNet are the predictions of the best-performing 3D-UNet. We therefore start by initializing the 3D-UNet with the weights saved during training and setting the model to evaluation mode before generating predictions.

In [None]:
# Build the 3D-UNet with the constructor arguments below; load_state_dict (strict by default)
# requires the resulting parameter names and shapes to match the checkpoint.
best_model = model.UNet3D(in_channels=2, out_channels=1, num_groups=8, is_segmentation=False)
# Checkpoint file of the selected 3D model (path is relative to this notebook's working directory).
best_model_weights = '../3D_models/3D_model_weights/lr5e-4_wd1e-3_do04/model_lr0.0005_wd0.001_do0.4_epoch30.pth'

# map_location remaps stored tensors onto the active device; weights_only=True restricts
# unpickling to tensors and plain containers.
checkpoint = torch.load(best_model_weights, map_location=device, weights_only=True)

# Only load the model state dict from the checkpoint
# (any other entries stored in the checkpoint are not used here).
best_model.load_state_dict(checkpoint['model_state_dict'])

# Set model in evaluation mode. As the last expression of the cell, this also
# displays the module structure in the cell output.
best_model.eval()

UNet3D(
  (encoders): ModuleList(
    (0): Encoder(
      (basic_module): DoubleConv(
        (SingleConv1): SingleConv(
          (groupnorm): GroupNorm(1, 2, eps=1e-05, affine=True)
          (conv): Conv3d(2, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
          (ReLU): ReLU(inplace=True)
        )
        (SingleConv2): SingleConv(
          (groupnorm): GroupNorm(8, 32, eps=1e-05, affine=True)
          (conv): Conv3d(32, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
          (ReLU): ReLU(inplace=True)
        )
      )
    )
    (1): Encoder(
      (pooling): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (basic_module): DoubleConv(
        (SingleConv1): SingleConv(
          (groupnorm): GroupNorm(8, 64, eps=1e-05, affine=True)
          (conv): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
          (ReLU): ReLU(inplace=True)
        )
     

## Preparing data from 3D model to 4D model

Determining indices for data splitting into training, validation and test data. Using the same participants as used in 3D preprocessing. The k-space data is already divided into appropriate folders, but the indices will be used for splitting EPI data.

In [None]:
# Re-seed Python's RNG right before drawing so the two indices below are reproducible
# regardless of any earlier use of `random`.
random.seed(42)
# Index of the held-out test participant within the full participant list (inclusive bounds 0..9).
TEST_PARTICIPANT = random.randint(0,9)
# Index of the validation participant within the list that remains AFTER the test
# participant has been removed (inclusive bounds 0..8), which is how it is applied
# when the EPI data is split further down.
VAL_PARTICIPANT = random.randint(0,8)
print(f'Test patient index: {TEST_PARTICIPANT}')
print(f'Val patient index: {VAL_PARTICIPANT}')

Test patient index: 1
Val patient index: 0


In [None]:
class LoadData(Dataset):
    """
    Custom Dataset for loading preprocessed k-space and image pairs.

    Each sample is expected to be stored as a pair of '.pt' files:
    - 'kspace{idx}.pt' in 'kspace_folder'
    - 'img{idx}.pt' in 'image_folder'

    Indexing follows:
    - Training set: indexes 0 to 363
    - Validation/Test set: indexes 0 to 51

    Parameters:
    - kspace_folder: Path to directory containing k-space files.
    - image_folder: Path to directory containing image files.
    - datatype: Type of dataset: either 'train', 'val', or 'test'.
      Any value other than 'train' is treated as a 52-sample split.
    """
    # Number of stored samples per split (see the index ranges in the class docstring).
    TRAIN_SAMPLES = 364
    EVAL_SAMPLES = 52

    def __init__(self, kspace_folder, image_folder, datatype):
        self.kspace_folder = kspace_folder
        self.image_folder = image_folder
        self.datatype = datatype

    def __len__(self):
        """
        Returns the number of samples in the dataset.
        """
        if self.datatype == 'train':
            return self.TRAIN_SAMPLES
        return self.EVAL_SAMPLES

    def __getitem__(self, idx):
        """
        Loads and returns a single sample (k-space, image) pair.

        Raises:
        - IndexError: If idx is outside [0, len(self)). Raising IndexError (rather than
          failing later with FileNotFoundError) lets plain `for sample in dataset`
          iteration terminate cleanly.
        """
        if not 0 <= idx < len(self):
            raise IndexError(f'Index {idx} is out of range for a dataset of size {len(self)}')

        image_path = os.path.join(self.image_folder, f'img{idx}.pt')
        kspace_path = os.path.join(self.kspace_folder, f'kspace{idx}.pt')
        # NOTE(review): torch.load unpickles the files; consider weights_only=True
        # if both files are known to contain plain tensors.
        image = torch.load(image_path)
        kspace = torch.load(kspace_path)
        return kspace, image

### Defining paths for k-space data

In [7]:
# Folders holding the preprocessed image / k-space '.pt' files of each split.
# All paths are relative to the notebook's working directory.
train_img_path = './preprocessed_img_train'
train_kspace_path = './preprocessed_kspace_train'

val_img_path = './preprocessed_img_val'
val_kspace_path = './preprocessed_kspace_val'

test_img_path = './preprocessed_img_test'
test_kspace_path = './preprocessed_kspace_test'

In [None]:
def resize_image_space(image_data):
    """
    Halve the last three axes of a 5D volume with linear interpolation.

    The first two axes are left untouched (zoom factor 1); the remaining three are
    scaled by 64/128, e.g. [1, 1, 128, 128, 128] -> [1, 1, 64, 64, 64].

    Parameters:
    - image_data: 5D array-like volume to be resized.

    Returns:
    - resized_img: Resized volume as returned by scipy.ndimage.zoom.
    """
    spatial_scale = 64 / 128
    # Factor 1 for the two leading axes, the same down-scaling for the three trailing ones.
    factors = (1, 1) + (spatial_scale,) * 3
    # order=1 selects linear interpolation.
    return zoom(image_data, factors, order=1)

In [None]:
def preprocess_ls(no_participants, data_type, save_file_path):
    """
    Acquiring predictions from the best performing 3D-UNet and assembling into 4D volume.

    For every participant, 52 consecutive k-space files ('kspace{index}.pt') are passed
    through `best_model`, the predictions are stacked along a new axis, resized with
    `resize_image_space` and saved as 'preprocessed{i}.pt'.

    Parameters:
    - no_participants: Number of participants in the dataset (7 for train, 1 for validation, 1 for test)
    - data_type: Determining the type of dataset ('train', 'val' or 'test'). Selects the k-space input folder.
    - save_file_path: Directory to save the processed data. Created if it does not exist.

    Returns:
    None

    Raises:
    - ValueError: If data_type is not one of 'train', 'val' or 'test'.
    """
    # Resolve the input folder first so that an invalid data_type fails immediately
    # instead of printing a hint and crashing later on an undefined path.
    if data_type == 'train':
        file_path = train_kspace_path
    elif data_type == 'val':
        file_path = val_kspace_path
    elif data_type == 'test':
        file_path = test_kspace_path
    else:
        raise ValueError(f"Unknown data_type {data_type!r}; expected 'train', 'val' or 'test'")

    # torch.save does not create missing directories.
    os.makedirs(save_file_path, exist_ok=True)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    best_model.to(device=device)
    best_model.eval()

    # Each participant contributes 52 consecutive files; training data is stored
    # with running indices (0 to 363), hence the i * 52 offset below.
    time_points = 52

    for i in range(no_participants):
        patient_image = []
        for j in range(time_points):
            index = i * time_points + j

            # Printing participant number, time slice number and index number to ensure correct preprocessing
            print(f'Participant no: {i}, time slice no: {j}, index: {index}')

            # Loading data from the correct path. The stored object is used as a tensor
            # below, so the safer weights_only loader is sufficient.
            kspace_path = os.path.join(file_path, f'kspace{index}.pt')
            current_kspace = torch.load(kspace_path, weights_only=True)

            # Unsqueezing to add batch dimension
            current_kspace = current_kspace.unsqueeze(0).to(device=device)

            with torch.no_grad():
                # Getting predictions from the 3D-UNet
                pred = best_model(current_kspace)

                # Dropping axis 1 of the prediction and moving it to the CPU.
                # NOTE(review): assumes the model output is [1, 1, D, H, W], so the
                # leading singleton axis is kept — confirm.
                pred = pred.squeeze(1).cpu()

            patient_image.append(pred)

        # Assembling each 3D data point into a 4D volume: stack along a new leading
        # (time) axis, then swap it with the retained singleton axis.
        recon_img = torch.stack(patient_image).float()
        recon_img = recon_img.permute(1, 0, 2, 3, 4)

        # Resizing predictions to match the 4D-UNet requirements
        recon_img_resized = resize_image_space(recon_img.numpy())

        current_filename = os.path.join(save_file_path, f'preprocessed{i}.pt')

        # Saving the processed files with the correct file name.
        # NOTE(review): scipy's zoom returns a NumPy array, so an ndarray (not a tensor)
        # is written here — confirm that the 4D data loader expects this.
        torch.save(recon_img_resized, current_filename)
        print(f'Finished processing participant {i}. Shape: {recon_img_resized.shape}')


In [10]:
# Output directory for the assembled 4D training volumes.
train_filename = '../preprocessed_4d/ls_train'
# The keyword must be `no_participants`, matching the signature of preprocess_ls.
preprocess_ls(no_participants=7, data_type='train', save_file_path=train_filename)

Participant no: 0, time slice no: 0, index: 0


  current_kspace = torch.load(kspace_path)


Participant no: 0, time slice no: 1, index: 1
Participant no: 0, time slice no: 2, index: 2
Participant no: 0, time slice no: 3, index: 3
Participant no: 0, time slice no: 4, index: 4
Participant no: 0, time slice no: 5, index: 5
Participant no: 0, time slice no: 6, index: 6
Participant no: 0, time slice no: 7, index: 7
Participant no: 0, time slice no: 8, index: 8
Participant no: 0, time slice no: 9, index: 9
Participant no: 0, time slice no: 10, index: 10
Participant no: 0, time slice no: 11, index: 11
Participant no: 0, time slice no: 12, index: 12
Participant no: 0, time slice no: 13, index: 13
Participant no: 0, time slice no: 14, index: 14
Participant no: 0, time slice no: 15, index: 15
Participant no: 0, time slice no: 16, index: 16
Participant no: 0, time slice no: 17, index: 17
Participant no: 0, time slice no: 18, index: 18
Participant no: 0, time slice no: 19, index: 19
Participant no: 0, time slice no: 20, index: 20
Participant no: 0, time slice no: 21, index: 21
Participan

In [11]:
# Output directory for the assembled 4D validation volume.
val_filename = '../preprocessed_4d/ls_val'
# The keyword must be `no_participants`, matching the signature of preprocess_ls.
preprocess_ls(no_participants=1, data_type='val', save_file_path=val_filename)

Participant no: 0, time slice no: 0, index: 0


  current_kspace = torch.load(kspace_path)


Participant no: 0, time slice no: 1, index: 1
Participant no: 0, time slice no: 2, index: 2
Participant no: 0, time slice no: 3, index: 3
Participant no: 0, time slice no: 4, index: 4
Participant no: 0, time slice no: 5, index: 5
Participant no: 0, time slice no: 6, index: 6
Participant no: 0, time slice no: 7, index: 7
Participant no: 0, time slice no: 8, index: 8
Participant no: 0, time slice no: 9, index: 9
Participant no: 0, time slice no: 10, index: 10
Participant no: 0, time slice no: 11, index: 11
Participant no: 0, time slice no: 12, index: 12
Participant no: 0, time slice no: 13, index: 13
Participant no: 0, time slice no: 14, index: 14
Participant no: 0, time slice no: 15, index: 15
Participant no: 0, time slice no: 16, index: 16
Participant no: 0, time slice no: 17, index: 17
Participant no: 0, time slice no: 18, index: 18
Participant no: 0, time slice no: 19, index: 19
Participant no: 0, time slice no: 20, index: 20
Participant no: 0, time slice no: 21, index: 21
Participan

In [12]:
# Output directory for the assembled 4D test volume.
test_filename = '../preprocessed_4d/ls_test'
# The keyword must be `no_participants`, matching the signature of preprocess_ls.
preprocess_ls(no_participants=1, data_type='test', save_file_path=test_filename)

Participant no: 0, time slice no: 0, index: 0


  current_kspace = torch.load(kspace_path)


Participant no: 0, time slice no: 1, index: 1
Participant no: 0, time slice no: 2, index: 2
Participant no: 0, time slice no: 3, index: 3
Participant no: 0, time slice no: 4, index: 4
Participant no: 0, time slice no: 5, index: 5
Participant no: 0, time slice no: 6, index: 6
Participant no: 0, time slice no: 7, index: 7
Participant no: 0, time slice no: 8, index: 8
Participant no: 0, time slice no: 9, index: 9
Participant no: 0, time slice no: 10, index: 10
Participant no: 0, time slice no: 11, index: 11
Participant no: 0, time slice no: 12, index: 12
Participant no: 0, time slice no: 13, index: 13
Participant no: 0, time slice no: 14, index: 14
Participant no: 0, time slice no: 15, index: 15
Participant no: 0, time slice no: 16, index: 16
Participant no: 0, time slice no: 17, index: 17
Participant no: 0, time slice no: 18, index: 18
Participant no: 0, time slice no: 19, index: 19
Participant no: 0, time slice no: 20, index: 20
Participant no: 0, time slice no: 21, index: 21
Participan

## Loading EPI data

The EPI data is preprocessed to match the format of the reconstructed LS volumes. This includes temporal subsampling and normalization.

In [4]:
epi_path = "../../data/KSPACE/epi"

# Collect every NIfTI file in the EPI folder.
# NOTE(review): os.listdir returns entries in arbitrary order, and the participant
# indices used for the train/val/test split depend on this order — verify that it
# matches the order used when the k-space data was split.
epi_files = [name for name in os.listdir(epi_path) if name.endswith(".nii")]

# Load each volume fully into memory as a floating-point array.
epi_images = [nib.load(os.path.join(epi_path, name)).get_fdata() for name in epi_files]

### Subsampling the correct time points according to the LS data

In [5]:
# Keep time points 90 to 141 of the fourth axis (52 frames) of every EPI volume.
LS_TIME_WINDOW = slice(90, 142)
epi_images_sliced = [image[:, :, :, LS_TIME_WINDOW] for image in epi_images]

### Dividing into training-, validation-, and test-data

In [None]:
# Hold out the test participant, keeping everyone else in their original order.
epi_images_sliced_test = [epi_images_sliced[TEST_PARTICIPANT]]
epi_images_sliced_train_val = (
    epi_images_sliced[:TEST_PARTICIPANT] + epi_images_sliced[TEST_PARTICIPANT + 1:]
)

# VAL_PARTICIPANT indexes into the list that no longer contains the test participant.
epi_images_sliced_val = [epi_images_sliced_train_val[VAL_PARTICIPANT]]
epi_images_sliced_train = (
    epi_images_sliced_train_val[:VAL_PARTICIPANT]
    + epi_images_sliced_train_val[VAL_PARTICIPANT + 1:]
)

### Padding images

The EPI volumes contain a varying number of slices in the third dimension. The volumes are therefore zero-padded along that dimension to match the LS volumes.

In [10]:
# Inspect the shape of the first training volume before padding.
epi_images_sliced_train[0].shape

(64, 64, 42, 52)

In [None]:
def resize_label(image, target_depth=64):
    """
    Resizing the label to match dimensions of the input data. Resized by zero-padding
    the end of the third dimension, flipping the second dimension and reordering the
    axes so that the last (time) axis comes first.

    Parameters:
    - image: 4D NumPy array to be resized.
    - target_depth: Size of the third dimension after padding (default 64).

    Returns:
    - resized_image: float tensor of shape [1, image.shape[3], image.shape[0],
      image.shape[1], target_depth]

    Raises:
    - ValueError: If image is not 4D or its third dimension exceeds target_depth.
    """
    if image.ndim != 4:
        raise ValueError(f'Expected a 4D image, got {image.ndim} dimensions')

    current_depth = image.shape[2]
    if current_depth > target_depth:
        raise ValueError(
            f'Third dimension ({current_depth}) is larger than target_depth ({target_depth})'
        )

    # Zeros are appended only after the existing slices of the third axis.
    pad_after = target_depth - current_depth
    padded_image = np.pad(image, ((0, 0), (0, 0), (0, pad_after), (0, 0)), mode='constant')

    # np.fliplr reverses axis 1 and returns a negatively strided view;
    # copy it so that torch can wrap the data.
    flipped_image = np.fliplr(padded_image).copy()
    resized_image = torch.Tensor(flipped_image)

    # Reshape to be the same layout as the input image: last axis first,
    # then add a leading singleton (channel) dimension.
    return resized_image.permute(3, 0, 1, 2).unsqueeze(0)

In [None]:
# Pad, flip and convert every EPI volume of each split with resize_label.
epi_images_resized_train = list(map(resize_label, epi_images_sliced_train))
epi_images_resized_val = list(map(resize_label, epi_images_sliced_val))
epi_images_resized_test = list(map(resize_label, epi_images_sliced_test))

In [None]:
# Calculating mean and std on the training EPI volumes only; the same two statistics
# are later passed to the preprocessing of the training, validation and test volumes.

mean_epi, std_epi = calc_mean_std_tensor(epi_images_resized_train)

print(f'Mean X data: {mean_epi}, Std X data: {std_epi}')

Mean X data: 284.45147705078125, Std X data: 710.0986328125


In [None]:
def preprocess_and_save_epi(images, save_path):
    """
    Preprocessing and saving EPI images.

    Every image is passed through `preprocess_scale_epi` together with the module-level
    `mean_epi` / `std_epi` statistics and written to 'preprocessed{i}.pt', where i is
    the position of the image in `images`.

    Parameters:
    - images: EPI images to be preprocessed
    - save_path: Directory to save the preprocessed images. Created if it does not exist.

    Returns:
    None. The images get saved to the specified folder.
    """
    # torch.save does not create missing directories.
    os.makedirs(save_path, exist_ok=True)

    for i, image in enumerate(images):
        preprocessed_image = preprocess_scale_epi(image, mean_epi, std_epi, normalize=True)
        file_path = os.path.join(save_path, f'preprocessed{i}.pt')
        torch.save(preprocessed_image, file_path)


In [None]:
# Normalizing the padded training EPI volumes and saving them to the training folder

epi_train_path = '../preprocessed_4d/epi_train'
preprocess_and_save_epi(epi_images_resized_train, epi_train_path)

In [None]:
# Normalizing the padded validation EPI volume and saving it to the validation folder
epi_val_path = '../preprocessed_4d/epi_val'
preprocess_and_save_epi(epi_images_resized_val, epi_val_path)

In [None]:
# Normalizing the padded test EPI volume and saving it to the test folder
epi_test_path = '../preprocessed_4d/epi_test'
preprocess_and_save_epi(epi_images_resized_test, epi_test_path)