In [1]:
import os
import glob
import random
import shutil
import numpy as np
from PIL import Image
import nibabel as nib
import SimpleITK as sitk
from scipy import ndimage
import matplotlib.pyplot as plt

print (os.path.abspath('.'))

/Users/lyraaaaa/Desktop


**t1 / t2 / flair / t1ce**  = four modalities of MRI images

Each of them are a list of t1/t2/flair/t1ce MRI images (210 HGG, 75 LGG)

each image has a shape of (240, 240, 150) as (x, y, z)

x & y 组成大脑的水平横切面，z 代表横切面的 slice

sitk 读出来的 output 是 (z, x, y) 形式

In [None]:
t1 = glob.glob(r'./BraTS_2018/*/*/*t1.nii.gz')
t2 = glob.glob(r'./BraTS_2018/*/*/*t2.nii.gz')
flair = glob.glob(r'./BraTS_2018/*/*/*flair.nii.gz')
t1ce = glob.glob(r'./BraTS_2018/*/*/*t1ce.nii.gz')
seg = glob.glob(r'./BraTS_2018/*/*/*seg.nii.gz')
csv = glob.glob(r'./BraTS_2018/*.csv')

def read_img(img_path):
    return sitk.GetArrayFromImage(sitk.ReadImage(img_path))

print('No. of images =',len(seg))
print('Shape of each image =',(read_img(seg[0]).shape))

plt.figure (figsize=(12,12))

plt.subplot(1,2,1)
img = (read_img(t1[0])[60]).astype(np.uint8)
plt.imshow(img)

plt.subplot(1,2,2)
img = (read_img(seg[0])[60]).astype(np.uint8)
plt.imshow(img)

In [None]:
###### Training sample 1 ----- Slice every training MRI

dataset_path = 'sample_dataset/train/'

def process(img_p, label_p):
    data = nib.load(img_p).get_fdata()
    labels = nib.load(label_p).get_fdata()
    data = data/np.amax(data)
    data = data.transpose(2, 0, 1)[:,8:232,8:232]   # crop img to (224,224)
    labels = labels.transpose(2, 0, 1)[:,8:232,8:232]
    labels[np.where(labels==4)]=3
    img_name = (img_p.split('/')[-1]).split('_')
    num = img_name[2]
    modality = img_name[-1].split('.')[0]

    print(labels.shape)
    count = 0

    for s in range(len(data)):          # for each slice in a 3D MRI
        image = data[s]
        if np.any(image):    # if the slice has content (≠ 0)
            label = labels[s]       # get coresponding lebal for the slice
            count += 1

            name =  "{}_{}_{}".format(num, modality, str(s))
            save_dir = os.path.join('sample_dataset/train_npz/', name + '.npz') 
            np.savez(save_dir,image=image,label=label)       # save as npz
            with open('sample_dataset/train.txt', 'a') as f:
                f.write(name + '\n')
    print(count)


for patient in os.listdir(dataset_path):
    if not patient.startswith('Brats'):
        continue
    path = os.path.join(dataset_path, patient)
    mri_paths, label_path = [],[]
    for modality in os.listdir(path):
        if modality.endswith('seg.nii.gz'):
            label_path = os.path.join(path, modality)
        elif modality.endswith('nii.gz'):
            mri_paths.append(os.path.join(path, modality))

    for mri in mri_paths:
        process(mri,label_path)
        print(patient)


In [None]:
###### Training sample 2 ----- Increase proportion of positive cases

dataset_path = 'sample_dataset/train/'

def write(case, modality, slice, image, label):
    name =  "{}_{}_{}".format(case, modality, str(slice))
    save_dir = os.path.join('sample_dataset/train_npz/', name + '.npz') 
    np.savez(save_dir,image=image,label=label)       # save as npz
    with open('sample_dataset/train.txt', 'a') as f:
        f.write(name + '\n')
            
def process(img_p, label_p):
    data = nib.load(img_p).get_fdata()
    labels = nib.load(label_p).get_fdata()
    data = data/np.amax(data)
    data = data.transpose(2, 0, 1)[:,8:232,8:232]   # crop img to (224,224)
    labels = labels.transpose(2, 0, 1)[:,8:232,8:232]
    labels[np.where(labels==4)]=3
    img_name = (img_p.split('/')[-1]).split('_')
    case = img_name[2]
    modality = img_name[-1].split('.')[0]

    #print(data.shape)
    count = 0

    for s in range(len(labels)):          # for each slice in a 3D MRI
        label = labels[s]
        image = data[s]
        if np.any(label):                 # if the label is not background (class 0)
            write(case, modality, s, image, label)
            count += 1
        elif np.any(image):               # if the label is background but image has content (class 0)
            if random.random() < 0.55:
                write(case, modality, s, image, label)
                count += 1
    print(count)

for patient in os.listdir(dataset_path):
    if not patient.startswith('Brats'):
        continue
    path = os.path.join(dataset_path, patient)
    mri_paths, label_path = [],[]
    for modality in os.listdir(path):
        if modality.endswith('seg.nii.gz'):
            label_path = os.path.join(path, modality)
        elif modality.endswith('nii.gz'):
            mri_paths.append(os.path.join(path, modality))

    for mri in mri_paths:
        process(mri,label_path)
    print(patient)

In [None]:
###### Training sample 3 ----- increase dataset

dataset_path = 'sample_dataset/train/'
count = [0,0]

def write(case, modality, slice, image, label):
    name =  "{}_{}_{}".format(case, modality, str(slice))
    save_dir = os.path.join('sample_dataset/train_npz/', name + '.npz') 
    np.savez(save_dir,image=image,label=label)       # save as npz
    with open('sample_dataset/train.txt', 'a') as f:
        f.write(name + '\n')
            
def process(img_p, label_p, count):
    data = nib.load(img_p).get_fdata()
    labels = nib.load(label_p).get_fdata()
    data = data.transpose(2, 0, 1)[:,8:232,8:232]           # crop img to (224,224)
    labels = labels.transpose(2, 0, 1)[:,8:232,8:232]
    labels[np.where(labels==4)]=3
    # data = data/np.amax(data)                             # min-max normalisation
    data = (data - np.mean(data)) / np.std(data)            # z-score normalisation
    data[np.where(data<0)]=0
    img_name = (img_p.split('/')[-1]).split('_')
    case = img_name[2]
    modality = img_name[-1].split('.')[0]
    # tumor area: ED > ET >> NET
    
    for s in range(len(labels)):          # for each slice in a 3D MRI
        label = labels[s]
        image = data[s]
        if np.any(label):                 # if the label is not background (class 0)
            if random.random() < 0.55:
                write(case, modality, s, image, label)
                count[0] += 1
        elif np.any(image):               # if the label is background but image has content (class 0)
            if random.random() < 0.22:
                write(case, modality, s, image, label)
                count[1] += 1
    return count

for patient in os.listdir(dataset_path):
    if not patient.startswith('Brats'):
        continue
    path = os.path.join(dataset_path, patient)
    mri_paths, label_path = [],[]
    for modality in os.listdir(path):
        if modality.endswith('seg.nii.gz'):
            label_path = os.path.join(path, modality)
        elif modality.endswith('nii.gz'):
            mri_paths.append(os.path.join(path, modality))

    for mri in mri_paths:
        count = process(mri,label_path, count)

    print(patient)
print(count)

In [2]:
###### Convert testing samples

dataset_path = 'sample_dataset/test/'

def test_process(img_p, label_p):
    data = nib.load(img_p).get_fdata()
    labels = nib.load(label_p).get_fdata()
    data = data/np.amax(data)
    data = data.transpose(2, 0, 1)[:,8:232,8:232]   # crop img to (240,240)
    labels = labels.transpose(2, 0, 1)[:,8:232,8:232]
    labels[np.where(labels==4)]=3
    img_name = (img_p.split('/')[-1]).split('_')
    num = img_name[2]
    modality = img_name[-1].split('.')[0]

    print(labels.shape)

    for s in range(len(data)-1, -1, -1):          # for each slice in a 3D MRI
        image = data[s]
        if not np.any(image):           # if the slice has no content (≠ 0)
            data = np.delete(data, s, axis=0)
            labels = np.delete(labels, s, axis=0)
            print(data.shape)
                
    name =  "{}_{}".format(num, modality)
    save_dir = os.path.join('sample_dataset/test_vol_h5/', name + '.npz') 
    np.savez(save_dir,image=data,label=labels)       # save as npz
    with open('sample_dataset/test_vol.txt', 'a') as f:
        f.write(name + '\n')
    
test_process(mri,label_path)

(155, 224, 224)
(154, 224, 224)
(153, 224, 224)
(152, 224, 224)
(151, 224, 224)
(150, 224, 224)
(149, 224, 224)
(148, 224, 224)
(147, 224, 224)
(146, 224, 224)
(145, 224, 224)
(144, 224, 224)
(143, 224, 224)
(142, 224, 224)
(141, 224, 224)
(140, 224, 224)
(139, 224, 224)
(138, 224, 224)
(137, 224, 224)
Brats18_TCIA09_312_1
(155, 224, 224)
(154, 224, 224)
(153, 224, 224)
(152, 224, 224)
(151, 224, 224)
(150, 224, 224)
(149, 224, 224)
(148, 224, 224)
(147, 224, 224)
(146, 224, 224)
(145, 224, 224)
(144, 224, 224)
(143, 224, 224)
(142, 224, 224)
(141, 224, 224)
(140, 224, 224)
(139, 224, 224)
(138, 224, 224)
(137, 224, 224)
Brats18_TCIA09_312_1
(155, 224, 224)
(154, 224, 224)
(153, 224, 224)
(152, 224, 224)
(151, 224, 224)
(150, 224, 224)
(149, 224, 224)
(148, 224, 224)
(147, 224, 224)
(146, 224, 224)
(145, 224, 224)
(144, 224, 224)
(143, 224, 224)
(142, 224, 224)
(141, 224, 224)
(140, 224, 224)
(139, 224, 224)
(138, 224, 224)
(137, 224, 224)
Brats18_TCIA09_312_1
(155, 224, 224)
(154, 224

In [138]:
def random_rot_flip(image):
    angle = np.random.randint(-25, 25)
    image = ndimage.rotate(image, angle, order=0, reshape=False)
    k = np.random.randint(0, 4)
    image = np.rot90(image, k).copy()
    axis = np.random.randint(0, 2)
    image = np.flip(image, axis=axis).copy()
    return image

In [146]:
data_path = 'sample_dataset/train/Brats18_TCIA05_444_1/Brats18_TCIA05_444_1_t1.nii.gz'
labels = nib.load(data_path).get_fdata()
labels = labels.transpose(2, 0, 1)

p = labels[:,8:232,8:232]
print(labels.size())
img = ndimage.rotate(p[75], 25, order=0, reshape=False)
img_roate90 = np.rot90(p[75], 3)
img_flip = np.flip(p[75], axis=1).copy()
after = random_rot_flip(p[75])
print(p.shape)

plt.figure(figsize=(15,15))
plt.subplot(1,4,1)
plt.imshow(p[75])
plt.subplot(1,4,2)
plt.imshow(img)
plt.subplot(1,4,3)
plt.imshow(img_roate90)
plt.subplot(1,4,4)
plt.imshow(after)

TypeError: 'int' object is not callable

In [13]:
import numpy as np
a=np.array([[1,2],[3,4]])
b=np.arange(4)
c=np.arange(5)
np.savez('array_save.npz',img=a,seg = b,c_array=c)
A=np.load('array_save.npz')
print(A['img'])
print(A['seg'])
print(A['c_array'])

In [None]:
# Split training and validation samples

path = 'sample_dataset/'
traget_path = 'target/'

def slice_MRI(dataset_path):
    t1 = glob.glob(r'./sample_dataset/*/*t1.nii.gz')
    read_img(t1)

    file_no = len(os.listdir(dataset_path))
    for i in range(file_no):
        file = os.listdir(dataset_path)[i]
        if file != '.DS_Store':
            old_path = os.path.join(dataset_path, file)
            if i < file_no*split_rate:
                with open('Splited BraTS/valid.txt', 'a') as f:
                    f.write(file + '\n')
                new_path = os.path.join(traget_path + 'Valid/', file + '/') 
            else:
                with open('Splited BraTS/train.txt', 'a') as f:
                    f.write(file + '\n')
                new_path = os.path.join(traget_path + 'Train/', file + '/') 

            shutil.copytree(old_path,new_path)

train_valid_split(path_HGG)
train_valid_split(path_LGG)