In [1]:
import nibabel as nib
import numpy as np
import matplotlib.pyplot as plt
import os

from keras.utils import to_categorical
from sklearn import preprocessing

from PIL import Image
import itertools
from tqdm import tqdm
from dltk.io.preprocessing import *

  from ._conv import register_converters as _register_converters
Using TensorFlow backend.


In [2]:
root = '../dataset/'
ct_set = os.path.join(root,'ct_train_test/ct_train/')
mr_set = os.path.join(root,'mr_train_test/mr_train/')

class_num=7+1 # background

In [3]:
ct_list = os.listdir(ct_set)
ct_images = list()
ct_labels = list()
for ct_l in ct_list:
    if 'image' in ct_l:
        file_path = os.path.join(ct_set, ct_l)
        fn = os.listdir(file_path)
        ct_images.append(nib.load(file_path + '/' + fn[0]))
    elif 'label' in ct_l:
        file_path = os.path.join(ct_set, ct_l)
        fn = os.listdir(file_path)
        ct_labels.append(nib.load(file_path + '/' + fn[0]))

In [4]:
mr_list = os.listdir(mr_set)
mr_images = list()
mr_labels = list()
for mr_l in mr_list:
    if 'image' in mr_l:
        file_path = os.path.join(mr_set, mr_l)
        fn = os.listdir(file_path)
        mr_images.append(nib.load(file_path + '/' + fn[0]))
    elif 'label' in mr_l:
        file_path = os.path.join(mr_set, mr_l)
        fn = os.listdir(file_path)
        mr_labels.append(nib.load(file_path + '/' + fn[0]))

In [5]:
ct_cnt = len(ct_images)
mr_cnt = len(mr_images)

# image shape

In [6]:
ct_size = list()
mr_size = list()
for ct_image in ct_images:
    ct_size.append(ct_image.shape)    
for mr_image in mr_images:
    mr_size.append(mr_image.shape)

In [7]:
ct_size

[(512, 512, 363),
 (512, 512, 239),
 (512, 512, 298),
 (512, 512, 200),
 (512, 512, 177),
 (512, 512, 248),
 (512, 512, 243),
 (512, 512, 222),
 (512, 512, 293),
 (512, 512, 274),
 (512, 512, 239),
 (512, 512, 177),
 (512, 512, 211),
 (512, 512, 358),
 (512, 512, 300),
 (512, 512, 333),
 (512, 512, 283),
 (512, 512, 187),
 (512, 512, 297),
 (512, 512, 363)]

In [8]:
mr_size

[(512, 512, 160),
 (512, 512, 128),
 (288, 288, 160),
 (288, 288, 120),
 (288, 288, 130),
 (256, 256, 160),
 (288, 288, 180),
 (288, 288, 130),
 (512, 512, 120),
 (288, 288, 160),
 (288, 288, 160),
 (512, 512, 128),
 (512, 512, 112),
 (512, 512, 160),
 (340, 340, 200),
 (288, 288, 130),
 (288, 288, 140),
 (288, 288, 150),
 (288, 288, 135),
 (288, 288, 135)]

# image resize

In [9]:
ct_pad_images = np.zeros((len(ct_images),256,256,256,1))
for i in tqdm(range(len(ct_images))):
    img = ct_images[i].get_data()
    ct_pad_images[i,:,:,:,:] =  resize_image_with_crop_or_pad(img, [256, 256, 256], mode='symmetric').reshape(256,256,256,1)
del(ct_pad_images)

100%|██████████████████████████████████████████████████████████████████████████████████| 20/20 [00:12<00:00,  1.63it/s]


In [10]:
ct_pad_labels = np.zeros((len(ct_labels),256,256,256,8))
label_encoder = preprocessing.LabelEncoder()
for i in tqdm(range(len(ct_labels))):
    img = ct_labels[i].get_data()    
    img = resize_image_with_crop_or_pad(img, [256, 256, 256], mode='symmetric') 
    
    # encoder
    raw_shape = img.shape
    img = img.reshape(-1)
    img = label_encoder.fit_transform(img)
    img = to_categorical(img, class_num)
    
    # reshape to raw shape
    img = img.reshape((1,) + raw_shape + (class_num,))
    
    ct_pad_labels[i,:,:,:,:] =  img
del(ct_pad_labels)

100%|██████████████████████████████████████████████████████████████████████████████████| 20/20 [00:29<00:00,  1.49s/it]


In [11]:
mr_pad_images = np.zeros((len(mr_images),256,256,256,1))
for i in tqdm(range(len(mr_images))):
    img = mr_images[i].get_data()
    mr_pad_images[i,:,:,:,:] =  resize_image_with_crop_or_pad(img, [256, 256, 256], mode='symmetric').reshape(256,256,256,1)
del(mr_pad_images)

100%|██████████████████████████████████████████████████████████████████████████████████| 20/20 [00:03<00:00,  6.53it/s]


In [12]:
mr_pad_labels = np.zeros((len(mr_labels),256,256,256,8))
label_encoder = preprocessing.LabelEncoder()
for i in tqdm(range(len(mr_labels))):
    img = mr_labels[i].get_data()
    img = resize_image_with_crop_or_pad(img, [256, 256, 256], mode='symmetric')
    
    # encoder
    raw_shape = img.shape
    img = img.reshape(-1)
    
    img = label_encoder.fit_transform(img)
    print(img.shape)
    label, cnt = np.unique(img,return_counts=True)
    print('label: ',label)
    print('cnt: ',cnt)
    img = to_categorical(img, class_num)
    
    # reshape to raw shape
    img = img.reshape((1,) + raw_shape + (class_num,))
    
    mr_pad_labels[i,:,:,:,:] =  img
del(mr_pad_labels)

  0%|                                                                                           | 0/20 [00:00<?, ?it/s]

(16777216,)
label:  [0 1 2 3 4 5 6 7]
cnt:  [16050881   199469   100824    61215   101819   116835    97177    48996]


  5%|████▏                                                                              | 1/20 [00:01<00:31,  1.65s/it]

(16777216,)
label:  [0 1 2 3 4 5 6 7]
cnt:  [15346084   217992   194124   143424   316988   186434   164110   208060]


 10%|████████▎                                                                          | 2/20 [00:03<00:29,  1.65s/it]

(16777216,)
label:  [0 1 2 3 4 5 6 7]
cnt:  [15657616   201691    52021   242805   188901   224576   156875    52731]


 15%|████████████▍                                                                      | 3/20 [00:04<00:28,  1.66s/it]

(16777216,)
label:  [0 1 2 3 4 5 6 7]
cnt:  [15527664   234139    98360   287138   132616   322280   106781    68238]


 20%|████████████████▌                                                                  | 4/20 [00:06<00:26,  1.64s/it]

(16777216,)
label:  [0 1 2 3 4 5 6 7]
cnt:  [15902370   232146   113617   206335   100830   114424    46456    61038]


 25%|████████████████████▊                                                              | 5/20 [00:08<00:24,  1.64s/it]

(16777216,)
label:  [0 1 2 3 4 5 6 7]
cnt:  [15902184   145043    50585   165202   151884   267266    66484    28568]


 30%|████████████████████████▉                                                          | 6/20 [00:09<00:23,  1.66s/it]

(16777216,)
label:  [0 1 2 3 4 5 6 7]
cnt:  [15513695   139244    41710   149944   401785   446601    41378    42859]


 35%|█████████████████████████████                                                      | 7/20 [00:11<00:21,  1.67s/it]

(16777216,)
label:  [0 1 2 3 4 5 6 7]
cnt:  [15390971   263428   132771   281612   230446   271909   150644    55435]


 40%|█████████████████████████████████▏                                                 | 8/20 [00:13<00:20,  1.68s/it]

(16777216,)
label:  [0 1 2 3 4 5 6 7]
cnt:  [14765033   285064   274279   276514   630170   227551   149118   169487]


 45%|█████████████████████████████████████▎                                             | 9/20 [00:15<00:18,  1.68s/it]

(16777216,)
label:  [0 1 2 3 4 5 6 7 8]
cnt:  [15494952   312719   110267    11667   245762   190593   187564   162053
    61639]


IndexError: index 8 is out of bounds for axis 1 with size 8

In [25]:
mr_label = resize_image_with_crop_or_pad(mr_labels[8].get_data(), [256, 256, 256], mode='symmetric')
mr_label.shape

(256, 256, 256)

In [26]:
label, cnt = np.unique(mr_label.reshape(-1),return_counts=True)
print('label: ',label)
print('cnt: ',cnt)

label:  [  0 205 420 500 550 600 820 850]
cnt:  [14765033   285064   274279   276514   630170   227551   149118   169487]


In [21]:
# encoder
raw_shape = mr_label.shape
mr_label = img.reshape(-1)
label_encoder = preprocessing.LabelEncoder()
mr_label = label_encoder.fit_transform(mr_label)
print(mr_label.shape)
label, cnt = np.unique(mr_label,return_counts=True)
print('label: ',label)
print('cnt: ',cnt)
img = to_categorical(mr_label, class_num)

(16777216,)
label:  [0 1 2 3 4 5 6 7 8]
cnt:  [15494952   312719   110267    11667   245762   190593   187564   162053
    61639]


IndexError: index 8 is out of bounds for axis 1 with size 8

### 아래 이미지는 channel 차원을 추가하기 전 이미지

In [None]:
f, ax = plt.subplots(5,8, figsize=(25,20))
img_idx = 0
pad_idx = 0
for i in range(5*8):
    if (i%8)%2==0:
        ax[i//8,i%8].imshow(ct_images[img_idx].get_data()[256,:,:])
        ax[i//8,i%8].axis('off')
        img_idx+=1
    else:
        mask = np.argmax(ct_pad_images[pad_idx], axis=-1)
        ax[i//8,i%8].imshow(mask[128,:,:])
        ax[i//8,i%8].axis('off')
        pad_idx+=1

In [45]:
mask = ct_labels[0].get_data()
mask.shape

(512, 512, 363)

In [46]:
resized_mask = resize_image_with_crop_or_pad(mask, [256, 256, 256], mode='symmetric')

  res = super(memmap, self).__getitem__(index)


In [None]:
print(ct_pad_images[0].shape)
img = nib.Nifti1Image(ct_pad_images[0], affine=np.eye(4))
nib.save(img,'../dataset/ct_imgs.nii')

In [None]:
mask = ct_labels[0].get_data()
mask.shape

In [None]:
resized_mask = resize_image_with_crop_or_pad(mask, [256, 256, 256], mode='symmetric')

In [None]:
print(resized_mask.shape)
img = nib.Nifti1Image(resized_mask, affine=np.eye(4))
nib.save(img,'../dataset/mr_imgs.nii')

In [None]:
print(ct_pad_labels[0].shape)
img = nib.Nifti1Image(ct_pad_labels[0], affine=np.eye(4))
nib.save(img,'../dataset/ct_labels.nii')

In [None]:
print(mr_pad_labels[0].shape)
img = nib.Nifti1Image(mr_pad_labels[0], affine=np.eye(4))
nib.save(img,'../dataset/mr_labels.nii')

# Test image

In [None]:
ct_test_dir = './dataset/ct_train_test/ct_test/'
ct_test_list = os.listdir(ct_test_dir)
ct_test_pad_images = list()

mr_test_dir = './dataset/mr_train_test/mr_test/'
mr_test_list = os.listdir(mr_test_dir)
mr_test_pad_images = list()

In [None]:
ct_test_images = list()
for fn in ct_test_list:
    img_dir = os.path.join(ct_test_dir,fn)
    img_fn = os.listdir(img_dir)[0]
    im = nib.load(os.path.join(img_dir, img_fn))
    ct_test_images.append(im)
    
mr_test_images = list()
for fn in mr_test_list:
    img_dir = os.path.join(mr_test_dir,fn)
    img_fn = os.listdir(img_dir)[0]
    im = nib.load(os.path.join(img_dir, img_fn))
    mr_test_images.append(im)

In [None]:
print('='*100)
print('CT')
print('='*100)
for img in ct_test_images:
    print(img.shape)
    
print('='*100)
print('MR')
print('='*100)
for img in mr_test_images:
    print(img.shape)

In [None]:
f, ax = plt.subplots(1,5,figsize=(30,10))
for i in range(5):
    ax[i].imshow(ct_test_images[i].get_data()[256,:,:])
    ax[i].axis('off')

In [None]:
ct_test_pad_images = image_preprocess(ct_test_images)
mr_test_pad_images = mask_preprocess(mr_test_images)