In [None]:
import numpy as np
from glob import glob
from tqdm import tqdm
import h5py
import nrrd
import os
import pandas as pd
from dataset_split import remove_files
import SimpleITK as sitk
from skimage import transform

# Target crop size. covert_h5 only reads entries 0 and 1: they pad the x/y
# foreground crop towards this size. Entry 2 is not used there (z is not cropped).
output_size =[128, 128, 64]

def resample_image3D(
    image3D,
    newspacing=(0.3, 0.3, 3),
    newsize=None,
    method='Linear',):
    """Resample a SimpleITK image onto a new voxel grid.

    The output keeps the input's origin and direction; only the spacing (and
    therefore the size) changes.

    Args:
        image3D: SimpleITK image to resample.
        newspacing: target voxel spacing, one entry per axis in SimpleITK
            (x, y, z) order.
        newsize: target size in voxels. When falsy it is derived so that the
            physical extent is preserved:
            ``round(old_size * |old_spacing| / newspacing)``.
        method: ``'Linear'`` (linear interpolation) or ``'Nearest'``
            (nearest-neighbour interpolation).

    Returns:
        The resampled SimpleITK image.

    Raises:
        ValueError: if *method* is not ``'Linear'`` or ``'Nearest'``.
    """
    # Compare with ``==``: ``is`` tests object identity, which for strings
    # depends on interning and is not guaranteed for equal values.
    if method == 'Linear':
        interpolator = sitk.sitkLinear
    elif method == 'Nearest':
        interpolator = sitk.sitkNearestNeighbor
    else:
        raise ValueError(f"unsupported interpolation method: {method!r}")

    resample = sitk.ResampleImageFilter()
    resample.SetInterpolator(interpolator)
    resample.SetOutputDirection(image3D.GetDirection())
    resample.SetOutputOrigin(image3D.GetOrigin())
    resample.SetOutputSpacing([float(s) for s in newspacing])

    if not newsize:
        newsize = np.round(
            np.array(image3D.GetSize())
            * np.abs(image3D.GetSpacing())
            / np.array(newspacing)
        ).astype('int').tolist()

    resample.SetSize(newsize)
    return resample.Execute(image3D)

# 数组替换元素
def array_replace(array, olds, news):
    """Remap values of *array* in place: elements equal to ``olds[i]`` become ``news[i]``.

    All masks are computed from the original values before anything is
    written. The mapping is therefore applied simultaneously (e.g. swapping
    1 <-> 2 works), and elements whose value is not in *olds* are never touched,
    whatever their value. If *olds* contains duplicates, the last mapping wins.

    Args:
        array: numpy array, modified in place.
        olds: sequence of values to replace.
        news: sequence of replacement values, same length as *olds*.

    Returns:
        The same (mutated) *array*.

    Raises:
        ValueError: if *olds* and *news* differ in length.
    """
    olds = list(olds)
    news = list(news)
    if len(olds) != len(news):
        raise ValueError(
            f"olds and news must have the same length ({len(olds)} != {len(news)})"
        )
    # Snapshot every mask first. Assigning sequentially without this would let
    # a freshly written value be matched again by a later ``old``.
    masks = [array == old for old in olds]
    for mask, new in zip(masks, news):
        array[mask] = new
    return array

def _sample_name(path):
    """Return the name of the folder that contains *path* (used as the sample id)."""
    # os.path understands the native separator, so no manual '/' vs '\\'
    # switch is needed on Windows.
    return os.path.basename(os.path.dirname(path))


def _rescale_to_spacing(image, label, oldspacing, newspacing):
    """Rescale an (x, y, z) image/label pair from *oldspacing* to *newspacing*.

    The scale factor is ``oldspacing / newspacing``, which preserves the
    physical extent (the same rule resample_image3D uses to derive its output
    size). The image uses linear interpolation. The label uses
    nearest-neighbour interpolation without anti-aliasing, so no new label
    values can appear. Returns ``(image, label)``, with the label as ints.
    """
    scale = (np.asarray(oldspacing, dtype=float)
             / np.asarray(newspacing, dtype=float)).tolist()
    # ``multichannel`` is omitted: False is its default, and the keyword is
    # deprecated since scikit-image 0.19.
    image = transform.rescale(image, scale, order=1, anti_aliasing=True,
                              preserve_range=True)
    # Gaussian anti-aliasing would blend neighbouring class ids into
    # intermediate values (e.g. a 0/2 border rounding to 1), so keep it off.
    label = transform.rescale(label, scale, order=0, anti_aliasing=False,
                              preserve_range=True)
    return image, np.round(label).astype(int)


def _reorder_labels(label, label_name, target_name):
    """Renumber *label* so that class ``target_name[i]`` gets value ``i + 1``.

    Follows this pipeline's assumption that label value ``k`` (k >= 1) is the
    segment named ``label_name[k - 1]``. Background 0 stays 0.
    """
    # Value each target class currently carries, listed in target order.
    src_values = [label_name.index(name) + 1 for name in target_name]
    # Map source value -> target position. Passing these the other way round
    # is only correct when the permutation is its own inverse.
    return array_replace(label,
                         olds=[0] + src_values,
                         news=list(range(len(target_name) + 1)))


def _crop_xy_around_foreground(image, label):
    """Crop x/y to the non-zero label bounding box plus a random 10-19 voxel margin.

    A box narrower than ``output_size`` in x or y is first padded
    symmetrically towards that size. All bounds are clipped to the volume,
    and the z axis is left untouched.
    """
    fg_x, fg_y = np.nonzero(label)[:2]
    minx, maxx = fg_x.min(), fg_x.max()
    miny, maxy = fg_y.min(), fg_y.max()
    w, h = label.shape[:2]
    px = max(output_size[0] - (maxx - minx), 0) // 2
    py = max(output_size[1] - (maxy - miny), 0) // 2
    minx = max(minx - np.random.randint(10, 20) - px, 0)
    maxx = min(maxx + np.random.randint(10, 20) + px, w)
    miny = max(miny - np.random.randint(10, 20) - py, 0)
    maxy = min(maxy + np.random.randint(10, 20) + py, h)
    return image[minx:maxx, miny:maxy], label[minx:maxx, miny:maxy]


def covert_h5(glob_str, old_replaced, new_replaced):
    """Convert labelled CTM volumes to cropped, normalised HDF5 files.

    Note: bone is not kept as a class; it is merged into the background.

    For each image matching *glob_str*, the label map
    (``Segmentation-label.nrrd``) and the segmentation
    (``Segmentation.seg.nrrd``) are read from the same path, with
    *old_replaced* swapped for those names. Each sample then goes through
    these steps:

    - rescale to 0.3 x 0.3 x 3.0 spacing;
    - z-score normalise the image;
    - relabel to 0 = background (including bone), 1 = dura, 2 = SC;
    - crop in x/y around the foreground;
    - write datasets ``image`` and ``label`` to
      ``item.replace(old_replaced, new_replaced)``.

    Returns:
        ``(error_samples, error_samples_origin)``. ``error_samples`` lists
        samples skipped after rescaling: missing classes, shape mismatch, or
        unexpected segment names. ``error_samples_origin`` lists samples
        skipped because the raw image and label shapes differ.
    """
    listt = glob(glob_str)
    error_samples = []
    error_samples_origin = []
    newspacing = [0.3, 0.3, 3.0]
    target_name = ['dura', 'bone', 'SC']  # desired class order -> values 1, 2, 3

    for item in tqdm(listt):
        sample_name = _sample_name(item)
        print(sample_name, ':')

        image = sitk.ReadImage(item)
        label = sitk.ReadImage(item.replace(old_replaced, 'Segmentation-label.nrrd'))
        seg = sitk.ReadImage(item.replace(old_replaced, 'Segmentation.seg.nrrd'))
        oldspacing = np.abs(image.GetSpacing())
        # Class order as annotated manually in the segmentation file.
        label_name = [seg.GetMetaData('Segment%d_Name' % i) for i in range(3)]

        # SimpleITK arrays are (z, y, x); transpose to (x, y, z) to match GetSpacing().
        image = sitk.GetArrayFromImage(image).transpose((2, 1, 0))
        label = np.round(sitk.GetArrayFromImage(label)).transpose((2, 1, 0))
        if image.shape != label.shape:
            error_samples_origin.append(sample_name)
            print("error sample(mismatch shape of image and label):", sample_name)
            continue

        image, label = _rescale_to_spacing(image, label, oldspacing, newspacing)

        # z-score intensity normalisation
        image = ((image - np.mean(image)) / np.std(image)).astype(np.float32)
        print("image shape:", image.shape, "label shape:", label.shape)

        # Error samples: one or more annotated classes are missing.
        print("np.unique(label):", np.unique(label))
        if np.unique(label).tolist() != [0, 1, 2, 3]:
            error_samples.append(sample_name)
            print("error sample(no dura/SC):", sample_name)
            continue

        if image.shape != label.shape:
            error_samples.append(sample_name)
            print("error sample(mismatch shape of image and label):", sample_name)
            continue

        # Record unexpected segment names instead of crashing in list.index().
        if not set(target_name) <= set(label_name):
            error_samples.append(sample_name)
            print("error sample(unexpected segment names):", sample_name, label_name)
            continue

        label = _reorder_labels(label, label_name, target_name)
        # Merge bone (2) into background, then move SC from 3 to 2.
        label[label == 2] = 0
        label[label == 3] = 2

        print("uncut image.shape:", image.shape, "uncut label.shape:", label.shape)
        assert np.unique(label).tolist() == [0, 1, 2], 'pixel classes are not [0,1,2]'

        image, label = _crop_xy_around_foreground(image, label)
        print("cut image.shape:", image.shape)
        print("cut label.shape:", label.shape)

        with h5py.File(item.replace(old_replaced, new_replaced), 'w') as f:
            f.create_dataset('image', data=image, compression="gzip")
            f.create_dataset('label', data=label, compression="gzip")
    print("total number of seg-samples:", len(listt))
    return error_samples, error_samples_origin

def covert_h5_unseg(glob_str, old_replaced, new_replaced):
    """Convert unlabelled CTM volumes to normalised HDF5 files.

    Each file matching *glob_str* goes through these steps:

    - resample to 0.3 x 0.3 x 3.0 spacing with linear interpolation;
    - transpose to (x, y, z);
    - z-score normalise;
    - write as dataset ``image`` to ``item.replace(old_replaced, new_replaced)``.
    """
    listt = glob(glob_str)
    newspacing = [0.3, 0.3, 3.0]
    for item in tqdm(listt):
        # Parent folder name; os.path handles the native separator on Windows too.
        sample_name = os.path.basename(os.path.dirname(item))
        print(sample_name, ':')

        image = sitk.ReadImage(item)
        image = resample_image3D(image, newspacing, method='Linear')
        # SimpleITK arrays are (z, y, x); transpose to (x, y, z).
        image = sitk.GetArrayFromImage(image).transpose((2, 1, 0))

        # z-score intensity normalisation
        image = ((image - np.mean(image)) / np.std(image)).astype(np.float32)
        print("image shape:", image.shape)

        # The context manager closes the file even if writing raises.
        with h5py.File(item.replace(old_replaced, new_replaced), 'w') as f:
            f.create_dataset('image', data=image, compression="gzip")
    print("total number of unseg-samples:", len(listt))

    
if __name__ == '__main__':
    # ---- Labelled data ----
    print('seg dataset:')
    # Optional first step (disabled): remove previously generated files, e.g.
    #     dataset_dir = '../../data/CTM_dataset/Segmented'
    #     remove_files(re=os.path.join(dataset_dir, '*/mri_norm2.h5'))
    # Then (re)generate the HDF5 files.
    glob_str = '../../data/CTM_dataset/Segmented/*/CTM.nrrd'
    error_samples, error_samples_origin = covert_h5(
        glob_str=glob_str,
        old_replaced='CTM.nrrd',
        new_replaced='mri_norm2.h5',
    )

    # ---- Unlabelled data (disabled) ----
    # print('unseg dataset:')
    # dataset_dir = '../../data/CTM_dataset/unSegmented'
    # remove_files(re=os.path.join(dataset_dir, '*/mri_norm2.h5'))
    # covert_h5_unseg('../../data/CTM_dataset/unSegmented/*/CTM.nrrd',
    #                 'CTM.nrrd', 'mri_norm2.h5')
    # covert_h5_unseg('../../data/CTM_dataset/unSegmented/*/CT-vol.nrrd',
    #                 'CT-vol.nrrd', 'mri_norm2.h5')


  0%|          | 0/181 [00:00<?, ?it/s]

seg dataset:
845929 :
> <ipython-input-1-eebeb4f94cf2>(105)covert_h5()
-> ratio = (np.array(newspacing)/np.array(oldspacing)).tolist()
(Pdb) image.shape
(512, 512, 92)
(Pdb) label.shape
(512, 512, 92)
(Pdb) n
> <ipython-input-1-eebeb4f94cf2>(106)covert_h5()
-> image = transform.rescale(image,ratio,order=1,anti_aliasing=True,preserve_range=True,multichannel=False)
(Pdb) n
> <ipython-input-1-eebeb4f94cf2>(107)covert_h5()
-> label = np.round(transform.rescale(label,ratio,order=0,anti_aliasing=True,preserve_range=True,multichannel=False)).astype(int)
(Pdb) n
> <ipython-input-1-eebeb4f94cf2>(113)covert_h5()
-> image = (image - np.mean(image)) / np.std(image)
(Pdb) index_cut=np.where(label==1)
(Pdb) index_cut[0].max(),index_cut[0].min()
(327, 0)
(Pdb) index_cut[1].max(),index_cut[1].min()
(324, 81)
(Pdb) index_cut[2].max(),index_cut[2].min()
(91, 0)
(Pdb) help(transform.rescale)
*** No help for '(transform.rescale)'
(Pdb) ratio
[0.641025641025641, 0.641025641025641, 1.0]
(Pdb) index_cut
(arr

# 数据集划分&生成*.list文件

In [None]:
from dataset_split import dataset_split, make_dataset_list

save_dir = '../../data/CTM_dataset'

# Labelled data: split into two lists, unpacked here as train/test.
# Bind the name `list_train`, which the later inspection cell refers to.
dataset_dir = '../../data/CTM_dataset/Segmented'
list_train, list_test = dataset_split(path=dataset_dir, save_dir=save_dir)

# Unlabelled data: build its dataset list in save_dir.
dataset_dir = '../../data/CTM_dataset/unSegmented'
make_dataset_list(path=dataset_dir, save_dir=save_dir)

In [None]:
# Show the samples that covert_h5 rejected after rescaling (first value it returns).
error_samples

In [None]:
# Show the training split; this name is expected to be bound by the dataset-split cell.
list_train

In [None]:
#845929 cut前后的shape变化不大