In [1]:
%matplotlib notebook
from pathlib import Path
import nibabel as nib
import numpy as np
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm 

In [2]:
# Build the dataset paths from components rather than a Windows-style string:
# backslashes in a plain string literal are invalid escape sequences (SyntaxWarning
# on Python 3.12+) and are not path separators on Linux/macOS.
DATA_ROOT = Path('Task02_Heart', 'Task02_Heart')
root = DATA_ROOT / 'imagesTr'    # image volumes
label = DATA_ROOT / 'labelsTr'   # matching label volumes

In [3]:
def change_to_label(path):
    """Return the label-file path that pairs with an image path.

    The first ``imagesTr`` directory component of ``path`` is swapped for
    ``labelsTr``; every other component (including the file name) is kept.

    Raises:
        ValueError: if ``path`` has no ``imagesTr`` component.
    """
    components = path.parts
    pos = components.index('imagesTr')
    return Path(*components[:pos], 'labelsTr', *components[pos + 1:])

In [4]:
# Pick the sample volume deterministically: Path.glob yields entries in
# filesystem order, so sort before taking the first match.
sample_path = sorted(root.glob("la*"))[0]
sample_path


WindowsPath('Task02_Heart/Task02_Heart/imagesTr/la_003.nii.gz')

In [5]:
# Label volume that pairs with the sample image (same file name under labelsTr).
sample_path_label = change_to_label(path=sample_path)
sample_path_label

WindowsPath('Task02_Heart/Task02_Heart/labelsTr/la_003.nii.gz')

In [6]:
# Load the sample image/label pair. The label image gets its own name so it does
# not clobber the `label` labels-directory Path defined earlier.
data = nib.load(sample_path)
label_img = nib.load(sample_path_label)

mri = data.get_fdata()
mask = label_img.get_fdata().astype(np.uint8)
# Image and label must be voxel-aligned for the overlays below to make sense.
assert mri.shape == mask.shape, (mri.shape, mask.shape)

In [7]:
# Axis orientation codes of the sample image's affine; the preprocessing step expects ('R', 'A', 'S').
nib.aff2axcodes(data.affine)

('R', 'A', 'S')

In [8]:
from celluloid import Camera
from IPython.display import HTML


In [9]:
# Animate the sample volume slice by slice along its third axis (S, per the RAS
# check above), overlaying the non-zero label voxels at 50% opacity.
fig, ax = plt.subplots()
camera = Camera(fig)

for z in range(mri.shape[2]):
    label_slice = mask[:, :, z]
    ax.imshow(mri[:, :, z], cmap='bone')
    ax.imshow(np.ma.masked_where(label_slice == 0, label_slice), alpha=0.5)
    camera.snap()

anim = camera.animate()


<IPython.core.display.Javascript object>

In [10]:
# Render the animation as an embedded HTML5 video.
video_html = anim.to_html5_video()
HTML(video_html)

In [11]:
def normalize(full_volume):
    """Z-score a volume: subtract its mean and divide by its standard deviation.

    Note: despite the name, this is what is usually called standardization;
    the min-max rescaling lives in ``standardize``. The names are kept for
    existing callers.

    Args:
        full_volume: numpy array of any shape; statistics are taken over all
            elements.

    Returns:
        Array of the same shape with mean 0 and standard deviation 1. A
        constant volume (std == 0) yields all zeros instead of the NaNs that
        0/0 would produce.
    """
    mu = np.mean(full_volume)
    std = np.std(full_volume)
    centered = full_volume - mu
    if std == 0:
        return np.zeros_like(centered)
    return centered / std
def standardize(normalized):
    """Min-max rescale an array to the range [0, 1].

    Note: despite the name, this is min-max normalization; the z-scoring step
    is ``normalize``. The names are kept for existing callers.

    Args:
        normalized: numpy array of any shape.

    Returns:
        Array of the same shape where the minimum maps to 0 and the maximum to
        1. A constant input (max == min) yields all zeros instead of the NaNs
        that 0/0 would produce.
    """
    lo = np.min(normalized)
    value_range = np.max(normalized) - lo
    shifted = normalized - lo
    if value_range == 0:
        return np.zeros_like(shifted, dtype=float)
    return shifted / value_range

In [12]:
# Sorted so the index-based train/test split during preprocessing is the same on
# every machine (Path.glob yields entries in filesystem order).
all_files = sorted(root.glob('la*'))


In [13]:
# Number of image volumes matched by 'la*' in imagesTr.
len(all_files)

20

In [14]:
# Export the raw NIfTI volumes as per-slice .npy files under Preprocessed/.
# Guarded so that Restart & Run All regenerates the data on a fresh checkout
# but skips the slow export when the output directory already exists.
SAVE_ROOT = Path('Preprocessed')
N_TRAIN_SUBJECTS = 17  # first 17 volumes (sorted by path) -> train, the rest -> test
CROP = 32              # voxels trimmed from both ends of the first two axes


def _crop_xy(volume, crop):
    """Trim ``crop`` voxels from both ends of the first two axes (crop=0 keeps all)."""
    h, w = volume.shape[:2]
    return volume[crop:h - crop, crop:w - crop]


def preprocess_dataset(mri_files, save_root=SAVE_ROOT, n_train=N_TRAIN_SUBJECTS, crop=CROP):
    """Export every volume in ``mri_files`` as 2D slices for training.

    Volumes are processed in sorted path order so the train/test split is
    reproducible. For each volume:
      * require RAS orientation,
      * crop ``crop`` voxels from both ends of the first two axes of image and label,
      * z-score then min-max scale the image (``normalize`` -> ``standardize``),
      * save slice ``i`` (last axis) to ``save_root/<split>/<counter>/data/i.npy``
        and the matching label slice to ``.../mask/i.npy``.

    Raises:
        ValueError: if a volume's affine is not RAS-oriented.
    """
    for counter, path_to_mri in enumerate(tqdm(sorted(mri_files))):
        mri_img = nib.load(path_to_mri)
        axcodes = nib.aff2axcodes(mri_img.affine)
        # Explicit check instead of `assert`, which is skipped under `python -O`.
        if axcodes != ('R', 'A', 'S'):
            raise ValueError(f'{path_to_mri}: expected RAS orientation, got {axcodes}')

        mri_data = _crop_xy(mri_img.get_fdata(), crop)
        label_raw = nib.load(change_to_label(path_to_mri)).get_fdata().astype(np.uint8)
        label_data = _crop_xy(label_raw, crop)
        scaled_mri = standardize(normalize(mri_data))

        split = 'train' if counter < n_train else 'test'
        subject_dir = save_root / split / str(counter)
        data_dir = subject_dir / 'data'
        mask_dir = subject_dir / 'mask'
        data_dir.mkdir(parents=True, exist_ok=True)
        mask_dir.mkdir(parents=True, exist_ok=True)

        for i in range(scaled_mri.shape[-1]):
            np.save(data_dir / str(i), scaled_mri[:, :, i])
            np.save(mask_dir / str(i), label_data[:, :, i])


if not SAVE_ROOT.exists():
    preprocess_dataset(all_files)
    

In [15]:
# Directory of the first preprocessed training subject.
path = Path('Preprocessed', 'train', '0')

In [16]:
# Load one preprocessed image slice and its label slice.
# NOTE(review): `slice` shadows the Python builtin in the notebook namespace.
file = '6.npy'
slice, mask = (np.load(path / subdir / file) for subdir in ('data', 'mask'))

In [17]:
# Sample slice with the non-zero label pixels overlaid at 50% opacity.
fig, ax = plt.subplots()
ax.imshow(slice, cmap='bone')
ax.imshow(np.ma.masked_where(mask == 0, mask), alpha=0.5)
plt.show()

<IPython.core.display.Javascript object>

In [18]:
import torch
import imgaug
from imgaug.augmentables.segmaps import SegmentationMapsOnImage
import imgaug.augmenters as iaa

In [19]:
class CardiacDataset(torch.utils.data.Dataset):
    """Preprocessed 2D cardiac slices laid out as ``<root>/<subject>/data/<i>.npy``,
    with the matching label slice at ``<root>/<subject>/mask/<i>.npy``.

    Args:
        root: directory (Path or str) holding one sub-directory per subject.
        augment_params: an imgaug augmenter (e.g. ``iaa.Sequential``) applied
            jointly to image and mask, or a falsy value to disable augmentation.
    """

    def __init__(self, root, augment_params):
        self.all_files = self.extract_files(root)
        self.augment_params = augment_params

    @staticmethod
    def extract_files(root):
        """Return every slice path under ``root/*/data``, sorted so the
        index -> file mapping is identical on every machine (glob order is
        filesystem-dependent)."""
        return sorted(Path(root).glob('*/data/*.npy'))

    @staticmethod
    def change_to_label(path):
        """Map ``.../data/<i>.npy`` to ``.../mask/<i>.npy``.

        Only the LAST ``data`` component is replaced, so a ``data`` directory
        higher up in the root path is left intact.

        Raises:
            ValueError: if ``path`` has no ``data`` component.
        """
        parts = list(path.parts)
        idx = len(parts) - 1 - parts[::-1].index('data')
        parts[idx] = 'mask'
        return Path(*parts)

    def augment(self, slice, mask):
        """Apply ``self.augment_params`` to an image slice and its mask with the
        same random transform; returns ``(augmented_slice, augmented_mask)``."""
        # Seed imgaug's global RNG from torch's RNG, presumably so torch seeding
        # (including per-worker seeding in a DataLoader) drives the augmentation.
        random_seed = torch.randint(0, 100000, (1,)).item()
        imgaug.seed(random_seed)

        segmap = SegmentationMapsOnImage(mask, shape=slice.shape)
        # Use the augmenter passed to __init__ (there is no `self.seq` attribute).
        slice_aug, segmap_aug = self.augment_params(image=slice, segmentation_maps=segmap)
        return slice_aug, segmap_aug.get_arr()

    def __len__(self):
        return len(self.all_files)

    def __getitem__(self, idx):
        """Return ``(image, mask)`` for slice ``idx``, each with a new leading
        channel axis (``(1, H, W)`` for 2D slices); the image is float32."""
        file_path = self.all_files[idx]
        image = np.load(file_path).astype(np.float32)
        mask = np.load(self.change_to_label(file_path))

        if self.augment_params:
            image, mask = self.augment(image, mask)
        return np.expand_dims(image, 0), np.expand_dims(mask, 0)


In [21]:
# Augmentation pipeline applied jointly to image and mask, in order:
# random zoom, random rotation (imgaug's Affine takes degrees), elastic deformation
# with imgaug's default parameters.
# NOTE(review): two separate Affine steps interpolate twice; a single
# iaa.Affine(scale=..., rotate=...) would interpolate once (changes sampling).
SCALE_RANGE = (0.85, 1.15)
ROTATE_RANGE = (-45, 45)
augmenters = [
    iaa.Affine(scale=SCALE_RANGE),
    iaa.Affine(rotate=ROTATE_RANGE),
    iaa.ElasticTransformation(),
]
seq = iaa.Sequential(augmenters)

In [22]:
# Training split with on-the-fly augmentation.
path = Path('Preprocessed') / 'train'
dataset = CardiacDataset(root=path, augment_params=seq)

In [None]:
# Show 9 random augmentations of the same training slice: every __getitem__ call
# reseeds imgaug, so each panel gets a fresh transform.
# The mask overlay uses alpha=0.5, like the other overlays in this notebook, so the
# anatomy underneath stays visible.
SAMPLE_IDX = 200
fig, axis = plt.subplots(3, 3, figsize=(9, 9))
for ax in axis.flat:
    aug_slice, aug_mask = dataset[SAMPLE_IDX]
    ax.imshow(aug_slice[0], cmap='bone')
    ax.imshow(np.ma.masked_where(aug_mask[0] == 0, aug_mask[0]), alpha=0.5)
    ax.axis('off')
plt.show()