In [1]:
import re
import matplotlib.pyplot as plt
import numpy as np
import cv2
from glob import glob
import os
import nibabel as nib
import torch
import torch.nn as nn
from skimage.transform import resize
import torch.nn.functional as F

In [2]:
# Dataset location: image volumes live in img/, label volumes in label/.
root = "../data_for_training/Synapse/RawData/Training"

root_images, root_labels = (os.path.join(root, subdir) for subdir in ("img", "label"))

print(root_images, root_labels)

../data_for_training/Synapse/RawData/Training\img ../data_for_training/Synapse/RawData/Training\label


In [3]:
def get_all_files(folder_path, formatted_extension):
    """List the files in ``folder_path`` whose names match a glob pattern.

    Args:
        folder_path: directory to search.
        formatted_extension: glob pattern joined onto ``folder_path``,
            e.g. ``"*.nii.gz"``.

    Returns:
        List of matching paths, sorted lexicographically so that image and
        label lists that share a naming scheme line up index-by-index.
    """
    matches = glob(os.path.join(folder_path, formatted_extension))
    matches.sort()
    return matches

In [5]:
# Collect the sorted image and label volume paths with the same glob pattern.
images, labels = (get_all_files(folder, "*.nii.gz") for folder in (root_images, root_labels))

print(len(images), len(labels))

30 30


In [6]:
# zip() stops at the shorter list, and sorting alone cannot detect a missing
# case, so check that each image is paired with the label of the same case id
# (e.g. img0001 <-> label0001) before relying on index-wise pairing.
assert len(images) == len(labels), f"{len(images)} images vs {len(labels)} labels"
for x, y in zip(images, labels):
    x_id = re.search(r"\d+", os.path.basename(x))
    y_id = re.search(r"\d+", os.path.basename(y))
    assert x_id and y_id and x_id.group() == y_id.group(), f"mismatched pair: {x} / {y}"
    print(x, y)

../data_for_training/Synapse/RawData/Training\img\img0001.nii.gz ../data_for_training/Synapse/RawData/Training\label\label0001.nii.gz
../data_for_training/Synapse/RawData/Training\img\img0002.nii.gz ../data_for_training/Synapse/RawData/Training\label\label0002.nii.gz
../data_for_training/Synapse/RawData/Training\img\img0003.nii.gz ../data_for_training/Synapse/RawData/Training\label\label0003.nii.gz
../data_for_training/Synapse/RawData/Training\img\img0004.nii.gz ../data_for_training/Synapse/RawData/Training\label\label0004.nii.gz
../data_for_training/Synapse/RawData/Training\img\img0005.nii.gz ../data_for_training/Synapse/RawData/Training\label\label0005.nii.gz
../data_for_training/Synapse/RawData/Training\img\img0006.nii.gz ../data_for_training/Synapse/RawData/Training\label\label0006.nii.gz
../data_for_training/Synapse/RawData/Training\img\img0007.nii.gz ../data_for_training/Synapse/RawData/Training\label\label0007.nii.gz
../data_for_training/Synapse/RawData/Training\img\img0008.nii.

In [8]:
def get_slice_from_volumetric_data(image_volume, mask_volume, start_idx, num_slice=8,
                                   size=256, num_classes=14):
    """Extract ``num_slice`` consecutive slices and resize them to ``size`` x ``size``.

    Slices are taken along the third axis, i.e. ``volume[:, :, i]`` for
    ``i in range(start_idx, start_idx + num_slice)``.

    Args:
        image_volume: 3-D array indexed as ``[:, :, slice]`` (e.g. the result of
            ``nib.load(...).get_fdata()``).
        mask_volume: 3-D label array with the same slice layout as
            ``image_volume``; values must be integers in ``[0, num_classes)``.
        start_idx: index of the first slice to extract; must be >= 0.
        num_slice: number of consecutive slices to extract.
        size: output height and width in pixels (default 256).
        num_classes: number of label classes used for the one-hot resize
            (default 14).

    Returns:
        ``(images, masks)``: two tensors of shape ``(num_slice, 1, size, size)``
        allocated with ``torch.empty`` (torch's default float dtype). ``masks``
        holds class indices stored as floats; call ``.long()`` before using
        them as classification targets.

    Raises:
        IndexError: if the requested slice range falls outside either volume.
    """
    end_idx = start_idx + num_slice
    depth = min(image_volume.shape[2], mask_volume.shape[2])
    # Without this check a negative start_idx silently reads slices from the
    # end of the volume, and an over-long range fails deep inside numpy.
    if start_idx < 0 or end_idx > depth:
        raise IndexError(
            f"slice range [{start_idx}, {end_idx}) is outside the volume depth {depth}"
        )

    images = torch.empty(num_slice, 1, size, size)
    masks = torch.empty(num_slice, 1, size, size)

    for k, i in enumerate(range(start_idx, end_idx)):
        # cv2.resize takes dsize as (width, height); default interpolation.
        image = cv2.resize(image_volume[:, :, i], (size, size))
        images[k, 0] = torch.from_numpy(image)

        # Resize labels through a one-hot stack followed by argmax: interpolating
        # the class indices directly would invent labels between neighbouring ids.
        label = torch.from_numpy(mask_volume[:, :, i]).long()
        one_hot = F.one_hot(label, num_classes=num_classes).numpy()
        one_hot = resize(one_hot, (size, size, num_classes),
                         preserve_range=True, anti_aliasing=True)
        masks[k, 0] = torch.argmax(torch.from_numpy(one_hot), dim=-1)

    return images, masks

In [9]:
# Load training case 0001 (image + label) as floating-point arrays.
# Paths are built with os.path.join from the configured roots: hard-coded
# backslash separators ("\i", "\l") are invalid string escapes (SyntaxWarning
# on Python 3.12+) and only resolve on Windows.
image_volume = nib.load(os.path.join(root_images, "img0001.nii.gz")).get_fdata()

mask_volume = nib.load(os.path.join(root_labels, "label0001.nii.gz")).get_fdata()

# NOTE: this rebinds `images` (previously the list of image paths); re-running
# the path-listing/pairing cells after this one requires re-running them in order.
images, masks = get_slice_from_volumetric_data(image_volume, mask_volume, 0, num_slice=9)

print(images.shape, masks.shape)

torch.Size([9, 1, 256, 256]) torch.Size([9, 1, 256, 256])


In [10]:
# Label values that occur in the loaded mask volume.
present_labels = np.unique(mask_volume)
print(present_labels)

[ 0.  1.  2.  3.  4.  5.  6.  7.  8.  9. 10. 11. 12. 13.]
