血的教训:
1. 对mask做插值(rescale,resize),一定且只能在onehot编码上操作
2. 不能用skimage,一定要用sitk

In [5]:
import numpy as np
from glob import glob
from tqdm import tqdm
import h5py
import nrrd
import os
import pandas as pd
from dataset_split import remove_files
import SimpleITK as sitk
from skimage import transform
from collections import Counter
import matplotlib.pyplot as plt
from tensorflow.keras.utils import to_categorical

# Target crop extent (x, y, z) in voxels used when cropping around the
# foreground in covert_h5; only the x and y entries are applied there
# (z padding is disabled).
output_size = [128, 128, 64]

def resample_image3D(
    image3D,
    newspacing=(0.3, 0.3, 3),
    newsize=None,
    method='Linear',
):
    """Resample a SimpleITK image onto a grid with a new voxel spacing.

    Origin and direction are copied from the input; only the spacing (and
    therefore the size) changes. Output voxels that fall outside the input
    receive ResampleImageFilter's default pixel value (not changed here).

    Args:
        image3D: SimpleITK image to resample (the callers in this file pass
            both intensity images and one-hot vector label images).
        newspacing: output spacing in SimpleITK (x, y, z) order.
        newsize: output size in voxels, (x, y, z). When None or empty it is
            derived so the physical extent is preserved:
            round(size * |spacing| / newspacing).
        method: 'Linear' for intensity images, 'Nearest' for label / one-hot
            images (label maps must not be linearly blended).

    Returns:
        The resampled SimpleITK image.

    Raises:
        ValueError: if ``method`` is neither 'Linear' nor 'Nearest'.
    """
    interpolators = {
        'Linear': sitk.sitkLinear,
        'Nearest': sitk.sitkNearestNeighbor,
    }
    if method not in interpolators:
        raise ValueError(
            "method must be 'Linear' or 'Nearest', got %r" % (method,))

    newspacing = [float(s) for s in newspacing]
    if newsize is None or len(newsize) == 0:
        # Keep the physical extent: voxels * old spacing / new spacing.
        newsize = np.round(
            np.array(image3D.GetSize()) * np.abs(image3D.GetSpacing())
            / np.array(newspacing)
        ).astype('int')
    # SimpleITK expects plain Python ints for the size vector.
    newsize = [int(s) for s in newsize]

    resample = sitk.ResampleImageFilter()
    resample.SetInterpolator(interpolators[method])
    resample.SetOutputDirection(image3D.GetDirection())
    resample.SetOutputOrigin(image3D.GetOrigin())
    resample.SetOutputSpacing(newspacing)
    resample.SetSize(newsize)
    return resample.Execute(image3D)

def sitk_onehot_transform(image, num_classes=None):
    """Convert a SimpleITK label map into a one-hot image.

    The label array is expanded with ``to_categorical`` (one channel per
    class, last axis) and wrapped back into a SimpleITK image carrying the
    input's origin, direction and spacing.

    Args:
        image: SimpleITK label image; values are assumed to be non-negative
            integer class indices (TODO confirm for every label source).
        num_classes: number of one-hot channels. None (default) lets
            ``to_categorical`` infer ``max label + 1``, so a volume missing
            its highest class yields fewer channels; pass an explicit count
            to get a fixed channel layout.

    Returns:
        SimpleITK image built from the one-hot array, with the input's
        geometry.
    """
    image_array = sitk.GetArrayFromImage(image)
    label_array_onehot = to_categorical(image_array, num_classes=num_classes)
    image_onehot = sitk.GetImageFromArray(label_array_onehot)
    # GetImageFromArray resets geometry; restore it from the source image.
    image_onehot.SetOrigin(image.GetOrigin())
    image_onehot.SetDirection(image.GetDirection())
    image_onehot.SetSpacing(image.GetSpacing())
    return image_onehot

# Replace array elements by value.
def array_replace(array, olds, news):
    """Replace every occurrence of ``olds[i]`` in ``array`` with ``news[i]``.

    All replacements are decided from the ORIGINAL values, so swaps such as
    olds=[1, 2], news=[2, 1] work. Values not listed in ``olds`` are left
    untouched. Not meant for one-hot arrays: it maps label values, not
    channels.

    Args:
        array: numpy array; modified in place.
        olds: old value(s) to replace (scalar or sequence).
        news: new value(s), same length as ``olds``.

    Returns:
        ``array`` itself (the same object that was modified in place).

    Raises:
        ValueError: if ``olds`` and ``news`` differ in length.
    """
    olds = np.atleast_1d(np.asarray(olds))
    news = np.atleast_1d(np.asarray(news))
    if olds.shape != news.shape:
        raise ValueError(
            "olds and news must have the same length, got %d and %d"
            % (olds.size, news.size))
    # Build all masks before writing so a freshly written value can never be
    # picked up by a later replacement.
    masks = [array == old for old in olds]
    for mask, new in zip(masks, news):
        array[mask] = new
    return array

def _random_crop_slices(label, crop_size):
    """Pick a randomly jittered crop box around the foreground of ``label``.

    Args:
        label: 3-D integer label volume indexed [x, y, z]; nonzero voxels are
            foreground.
        crop_size: target extent; only the first two entries (x, y) are used.

    Returns:
        A tuple of three ``slice`` objects (x, y, z), or None when ``label``
        has no foreground voxel.

    In x and y the foreground bounding box is widened on each side by half of
    its shortfall w.r.t. ``crop_size`` plus a random 10-19 voxel margin, then
    clipped to the volume. In z the bounding box is used unchanged (z padding
    is intentionally disabled).
    """
    nonzero = np.nonzero(label)
    if nonzero[0].size == 0:
        return None
    (minx, maxx), (miny, maxy), (minz, maxz) = [
        (int(axis.min()), int(axis.max())) for axis in nonzero]
    w, h = label.shape[:2]
    px = max(crop_size[0] - (maxx - minx + 1), 0) // 2
    py = max(crop_size[1] - (maxy - miny + 1), 0) // 2
    # Random draws happen in the order x-low, x-high, y-low, y-high so that
    # seeded runs stay reproducible.
    minx = max(minx - np.random.randint(10, 20) - px, 0)
    maxx = min(maxx + np.random.randint(10, 20) + px, w - 1)
    miny = max(miny - np.random.randint(10, 20) - py, 0)
    maxy = min(maxy + np.random.randint(10, 20) + py, h - 1)
    return slice(minx, maxx + 1), slice(miny, maxy + 1), slice(minz, maxz + 1)


def covert_h5(glob_str, old_replaced, new_replaced):
    """Convert labelled volumes into cropped HDF5 training samples.

    Note: bone is not kept as a class; it is merged into the background.

    For every image path matched by ``glob_str``:
      * reads the image, the segmentation header ``Segmentation.seg.nrrd``
        (for the segment names) and the label map
        ``Segmentation-label.nrrd``, located by replacing ``old_replaced`` in
        the image path;
      * resamples the image (linear) and the one-hot label (nearest) to
        0.3 x 0.3 x 3.0 spacing;
      * rejects the sample (records its name and skips it) when shapes
        disagree, no voxel is labelled, a voxel belongs to several classes,
        there are not exactly 4 one-hot channels, one-hot values are not
        binary, one of the 'bg'/'df'/'pf'/'fra' names is missing, or nothing
        but background remains after merging;
      * reorders channels to bg, df, pf, fra and merges 'df' (the class the
        original notes call bone) into background, leaving bg, pf, fra;
      * z-scores the image over the whole volume, crops image and one-hot
        label with the same jittered box around the foreground, writes
        preview PNGs, and saves 'image' (float32) and 'label' (one-hot) to
        the path obtained by replacing ``old_replaced`` with ``new_replaced``.

    Returns:
        (error_samples, error_samples_origin): names of rejected samples, and
        a list kept for interface compatibility that is always empty.
    """
    listt = glob(glob_str)
    error_samples = []
    error_samples_origin = []  # never filled; kept for the return signature
    newspacing = [0.3, 0.3, 3.0]
    target_name = ['bg', 'df', 'pf', 'fra']  # required channel order
    for item in tqdm(listt):
        # Sample id = name of the directory holding the image (portable split).
        sample_name = os.path.basename(os.path.dirname(item))
        print(sample_name, '||')

        # read image, segment names and label map
        image = sitk.ReadImage(item)
        seg = sitk.ReadImage(item.replace(old_replaced, 'Segmentation.seg.nrrd'))
        label = sitk.ReadImage(item.replace(old_replaced, 'Segmentation-label.nrrd'))
        label_onehot = sitk_onehot_transform(label)
        # class order as annotated: background + the three labelled segments
        label_name = [
            'bg',
            seg.GetMetaData('Segment0_Name'),
            seg.GetMetaData('Segment1_Name'),
            seg.GetMetaData('Segment2_Name'),
        ]

        # Resample with SimpleITK; the mask is interpolated in one-hot form.
        image = resample_image3D(image, newspacing, method='Linear')
        label_onehot = resample_image3D(label_onehot, newspacing, method='Nearest')

        # SimpleITK arrays are (z, y, x[, C]); transpose to (x, y, z[, C]) so
        # they match the GetSize() order.
        image = sitk.GetArrayFromImage(image).transpose((2, 1, 0))
        label_onehot = np.round(
            sitk.GetArrayFromImage(label_onehot)).transpose((2, 1, 0, 3))
        label = np.argmax(label_onehot, axis=-1)

        if image.shape != label_onehot.shape[:-1]:
            error_samples.append(sample_name)
            print("error sample(mismatch shape of image and label):", sample_name)
            continue

        fg_z = np.nonzero(label)[2]
        if fg_z.size == 0:
            error_samples.append(sample_name)
            print("error sample(label has no foreground voxel):", sample_name)
            continue
        plot_slice_sample(image, label, fg_z.max(),
                          item.replace(old_replaced, 'slice_sample_origin.png'))

        if label_onehot.sum(axis=-1).max() != 1:
            error_samples.append(sample_name)
            print("error sample(some pixels in seg are multi-category at the same time):", sample_name)
            continue

        # to_categorical infers the channel count from the largest label, so
        # fewer than 4 channels means a class is absent from the label map.
        if label_onehot.shape[-1] != 4:
            error_samples.append(sample_name)
            print("error sample(no df/pf/fra):", sample_name)
            continue
        if not np.array_equal(np.unique(label_onehot), [0, 1]):
            error_samples.append(sample_name)
            print("error sample label file error:", sample_name)
            continue
        missing = [name for name in target_name if name not in label_name]
        if missing:
            error_samples.append(sample_name)
            print("error sample(missing segment names %s):" % missing, sample_name)
            continue

        # Reorder channels into the target order (still one-hot).
        idx = [label_name.index(name) for name in target_name]
        label_onehot = label_onehot[:, :, :, idx]

        # Merge 'bg' and 'df' into one background channel. It is computed as
        # the complement of the kept classes so voxels left without any class
        # by resampling (all-zero one-hot) also become background, keeping
        # every voxel a valid one-hot vector.
        kept = label_onehot[:, :, :, 2:]
        bg = 1 - kept.sum(axis=-1, keepdims=True)
        label_onehot = np.concatenate((bg, kept), axis=-1)
        assert np.array_equal(np.unique(label_onehot), [0, 1]), "1: pixel class error"
        # non-one-hot copy for plotting and cropping
        label = np.argmax(label_onehot, axis=-1)

        crop = _random_crop_slices(label, output_size)
        if crop is None:
            error_samples.append(sample_name)
            print("error sample(no pf/fra voxel left after merging bone into background):", sample_name)
            continue

        # z-score over the whole (uncropped) volume
        image = ((image - np.mean(image)) / np.std(image)).astype(np.float32)

        # Crop image, label and one-hot label with the same box so the saved
        # 'image' and 'label' datasets stay voxel-aligned.
        image = image[crop]
        label = label[crop]
        label_onehot = label_onehot[crop]
        print("cut image.shape:", image.shape, "cut label.shape:", label.shape)
        plot_slice_sample(image, label, image.shape[2] - 1,
                          item.replace(old_replaced, 'slice_sample.png'))

        # save files (closed even if a write fails)
        with h5py.File(item.replace(old_replaced, new_replaced), 'w') as f:
            f.create_dataset('image', data=image, compression="gzip")
            f.create_dataset('label', data=label_onehot, compression="gzip")
    print("total number of samples:", len(listt))
    return error_samples, error_samples_origin

def plot_slice_sample(image, label, d, fn):
    """Save and show one slice of an image next to its label.

    Args:
        image: 3-D array; the slice is taken along the last axis
            (``image[:, :, d]``), i.e. z for the [x, y, z] arrays built in
            covert_h5.
        label: integer label volume with the same layout; its colour limits
            are fixed to [0, 3] so class colours are comparable across
            samples.
        d: index along the last axis.
        fn: output path of the saved figure.

    The figure is closed after saving/showing so that looping over many
    samples does not accumulate open figures.
    """
    fig, (ax_img, ax_lbl) = plt.subplots(1, 2)

    img_artist = ax_img.imshow(image[:, :, d].squeeze())
    ax_img.set_title('image')
    fig.colorbar(img_artist, ax=ax_img, orientation='horizontal')

    lbl_artist = ax_lbl.imshow(label[:, :, d].squeeze())
    lbl_artist.set_clim(0.0, 3.0)
    ax_lbl.set_title('label')
    fig.colorbar(lbl_artist, ax=ax_lbl, orientation='horizontal')

    fig.savefig(fn)
    plt.show()
    plt.close(fig)
    
def covert_h5_unseg(glob_str, old_replaced, new_replaced):
    """Convert unlabelled volumes into HDF5 files holding only an image.

    For every path matched by ``glob_str`` the image is resampled (linear)
    to 0.3 x 0.3 x 3.0 spacing, transposed to (x, y, z), z-scored over the
    whole volume, cast to float32 and written as dataset 'image' to the path
    obtained by replacing ``old_replaced`` with ``new_replaced``.
    """
    listt = glob(glob_str)
    newspacing = [0.3, 0.3, 3.0]
    for item in tqdm(listt):
        # Sample id = name of the directory holding the image (portable split).
        sample_name = os.path.basename(os.path.dirname(item))
        print(sample_name, '||')
        image = sitk.ReadImage(item)

        # resample, then transpose (z, y, x) -> (x, y, z) to match GetSize()
        image = resample_image3D(image, newspacing, method='Linear')
        image = sitk.GetArrayFromImage(image).transpose((2, 1, 0))

        # z-score normalisation
        image = ((image - np.mean(image)) / np.std(image)).astype(np.float32)
        print("image shape:", image.shape)

        # the file is closed even if the write fails
        with h5py.File(item.replace(old_replaced, new_replaced), 'w') as f:
            f.create_dataset('image', data=image, compression="gzip")
    print("total number of unseg-samples:", len(listt))

    
if __name__ == '__main__':
    # Labelled data: delete stale outputs first, then regenerate them.
    print('seg dataset:')
    dataset_dir = '../../data/gz_dataset/segmented'
    pattern = os.path.join(dataset_dir, '*/mri_norm2.h5')
    remove_files(re=pattern)
    glob_str = '../../data/gz_dataset/segmented/*/CT.nrrd'
    error_samples, error_samples_origin = covert_h5(glob_str, 'CT.nrrd', 'mri_norm2.h5')

    # Unlabelled data.
    print('unseg dataset:')
    dataset_dir = '../../data/gz_dataset/unSegmented'
    pattern = os.path.join(dataset_dir, '*/mri_norm2.h5')
    # Disabled by default; set to True to clean and regenerate the
    # unlabelled files.
    # NOTE(review): the cleanup pattern points at gz_dataset/unSegmented while
    # the conversion reads CTM_dataset/unSegmented -- confirm which dataset is
    # intended before enabling.
    RUN_UNSEGMENTED = False
    if RUN_UNSEGMENTED:
        remove_files(re=pattern)
        glob_str = '../../data/CTM_dataset/unSegmented/*/CTM.nrrd'
        covert_h5_unseg(glob_str, 'CTM.nrrd', 'mri_norm2.h5')
        glob_str = '../../data/CTM_dataset/unSegmented/*/CT-vol.nrrd'
        covert_h5_unseg(glob_str, 'CT-vol.nrrd', 'mri_norm2.h5')


ModuleNotFoundError: No module named 'nrrd'

# 数据集划分&生成*.list文件

In [6]:
from dataset_split import dataset_split, make_dataset_list

# All *.list files are written under this directory.
save_dir = '../../data/CTM_dataset'

# Labelled data: split into train/validation and test subsets and write
# their list files.
dataset_dir = '../../data/CTM_dataset/Segmented'
list_train_validatioin, list_test = dataset_split(
    path=dataset_dir,
    save_dir=save_dir,
)

# Unlabelled data: no split is needed, only the list file is generated.
dataset_dir = '../../data/CTM_dataset/unSegmented'
make_dataset_list(
    path=dataset_dir,
    save_dir=save_dir,
)

In [3]:
# Display the names of the samples that covert_h5 rejected during conversion.
error_samples

['B1709234', '1409022-no SC', 'B1409022', 'B1334915-need revision', '1709234']