In [None]:
import os
from glob import glob

import h5py
import nrrd
import numpy as np
import pandas as pd
from tqdm import tqdm

# NOTE(review): not referenced anywhere in the code visible here; presumably a
# target volume size used by later processing -- confirm before relying on it.
output_size =[128, 128, 64]

# Replace elements of an array
def array_replace(array,olds,news):
    """
    Replace values of ``array`` in place: every element equal to ``olds[i]``
    becomes ``news[i]``.

    All matches are located on the untouched input before anything is
    written, so the pairs act as one simultaneous substitution: a value
    written for one pair is never picked up again by a later pair (swaps such
    as ``olds=[1, 2], news=[2, 1]`` work), and no temporary placeholder values
    are needed that could collide with data already in the array or fall
    outside the range of its dtype.

    Args:
        array: numpy array to modify; it is changed in place.
        olds: sequence of values to look for.
        news: sequence of replacement values, same length as ``olds``.

    Returns:
        The same ``array`` object, after modification.

    Raises:
        ValueError: if ``olds`` and ``news`` do not have the same length.
    """
    olds = list(olds)
    news = list(news)
    if len(olds) != len(news):
        raise ValueError(
            "olds and news must have the same length, got %d and %d"
            % (len(olds), len(news))
        )
    # Build every mask first; writing while still searching would let an
    # already-replaced element match a later ``old`` value.
    masks = [array == old for old in olds]
    for mask, new in zip(masks, news):
        array[mask] = new
    return array
    
    
def get_stats(glob_str, old_replaced, new_replaced):
    """
    Collect voxel spacing and intensity statistics for each labelled sample.

    Note: bone is not kept as a class of its own; it is merged into the
    background class.

    For every image path matched by ``glob_str`` the companion files
    ``Segmentation-label.nrrd`` and ``Segmentation.seg.nrrd`` are read from
    the path obtained by replacing ``old_replaced`` in the image path. Image
    and label map are cropped to the extent recorded in the segmentation
    header, the label values are re-ordered to dura=1, bone=2, SC=3, bone is
    merged into background (leaving 0=background, 1=dura, 2=SC), and the mean
    and standard deviation of the image are computed over the whole crop and
    over each remaining class.

    Args:
        glob_str: glob pattern matching the image files.
        old_replaced: substring of each image path that is swapped for the
            companion file names.
        new_replaced: unused; kept so existing callers keep working.

    Returns:
        ``(error_samples, stats)``: the list of sample names that were
        skipped (missing classes or mismatching shapes), and a DataFrame with
        one row per accepted sample, indexed by sample name.

    Raises:
        ValueError: if one of the names 'dura', 'bone', 'SC' is missing from
            the first three segment names of the segmentation header.
    """
    listt = glob(glob_str)
    error_samples = []
    stats = pd.DataFrame(columns=['sample_name',
                                  'space0','space1','space2',
                                  'mean_whole',
                                  'mean_bg',
                                  'mean_dura',
                                  'mean_SC',
                                  'std_whole',
                                  'std_bg',
                                  'std_dura',
                                  'std_SC',
                                 ])
    for item in tqdm(listt):
        # The sample name is the directory holding the image; os.path copes
        # with both '/' and '\\' separators.
        sample_name = os.path.basename(os.path.dirname(item))
        print(sample_name,':')

        image, _ = nrrd.read(item)
        label, label_header = nrrd.read(item.replace(old_replaced, 'Segmentation-label.nrrd'))
        seg, seg_header = nrrd.read(item.replace(old_replaced, 'Segmentation.seg.nrrd'))

        space = np.diagonal(label_header['space directions'])

        offset = [int(k) for k in seg_header['Segmentation_ReferenceImageExtentOffset'].split()]
        print(offset)

        # Crop image and label map to the region described by the
        # segmentation header (its first axis is skipped, see below).
        sizes = seg_header['sizes'][1::]
        crop = tuple(slice(start, start + size) for start, size in zip(offset[:3], sizes[:3]))
        image = image[crop]
        label = label[crop].astype(np.uint8)

        # Error cases: one or more classes are not annotated, or the
        # annotation does not have the same size as the image.
        if not np.unique(label).tolist()==[0,1,2,3]:
            error_samples.append(sample_name)
            print("error sample(no dura/SC):",sample_name)
            continue

        # seg carries one leading axis in front of the spatial ones
        # (presumably the per-segment channel -- confirm). The label map must
        # match too, otherwise the boolean indexing below would fail.
        if not image.shape==seg.shape[1::] or image.shape != label.shape:
            error_samples.append(sample_name)
            print("error sample(shape mismatch):",sample_name)
            continue

        # Class names: wanted order vs. the order used by the annotator.
        target_name = ['dura','bone','SC']
        label_name = [
            seg_header['Segment0_Name'],
            seg_header['Segment1_Name'],
            seg_header['Segment2_Name']
            ]
        # Assumes label value k marks segment k-1 of the header -- confirm.
        # The value currently used for target_name[t] has to become t+1, so
        # the looked-up positions are the OLD values, not the new ones.
        source_values = [label_name.index(name) + 1 for name in target_name]
        label = array_replace(label, olds=[0] + source_values, news=[0,1,2,3])

        # Merge bone into the background, then close the gap: SC 3 -> 2.
        label[label==2] = 0
        label[label==3] = 2

        # Statistics over the whole crop, background, dura and SC, in the
        # same order as the mean_* / std_* columns.
        regions = [image, image[label==0], image[label==1], image[label==2]]
        means = [np.mean(region) for region in regions]
        stds = [np.std(region) for region in regions]

        stats.loc[sample_name] = [sample_name,
                                  space[0], space[1], space[2],
                                  *means,
                                  *stds]

    return error_samples,stats

def covert_h5_unseg(glob_str, old_replaced, new_replaced):
    """
    Convert unlabelled NRRD volumes to HDF5 files.

    Every file matched by ``glob_str`` is read and written, unchanged, as the
    gzip-compressed dataset ``'image'`` of a new HDF5 file. No downsampling
    or normalization is applied here.

    Args:
        glob_str: glob pattern matching the NRRD files to convert.
        old_replaced: substring of each input path to replace.
        new_replaced: replacement text; the resulting path is the output file.

    Raises:
        ValueError: if ``old_replaced`` does not occur in an input path, so
            that the output path would equal the input path and opening it
            for writing would destroy the source file.
    """
    listt = glob(glob_str)
    for item in tqdm(listt):
        # Name of the directory holding the file, independent of the
        # platform's path separator.
        print(os.path.basename(os.path.dirname(item)),':')

        out_path = item.replace(old_replaced, new_replaced)
        if out_path == item:
            raise ValueError(
                "output path equals input path %r; refusing to overwrite the source"
                % (item,)
            )

        # Read the original image.
        image, _ = nrrd.read(item)

        # The context manager closes the file even if writing fails.
        with h5py.File(out_path, 'w') as f:
            f.create_dataset('image', data=image, compression="gzip")
    print("total number of unseg-samples:", len(listt))

    
if __name__ == '__main__':
    # Labelled data: gather the per-sample statistics, show the table and
    # store it as CSV next to the dataset.
    glob_str = '../../data/CTM_dataset/Segmented/*/CTM.nrrd'
    print('seg dataset:')
    error_samples, stats = get_stats(
        glob_str=glob_str,
        old_replaced='CTM.nrrd',
        new_replaced='mri_norm2.h5',
    )
    display(stats)
    stats.to_csv('../../data/CTM_dataset/Segmented/stats.csv')
#     # Unlabelled data
#     print('unseg dataset:')
#     glob_str = '../../data/CTM_dataset/unSegmented/*/CTM.nrrd'
#     covert_h5_unseg(glob_str,'CTM.nrrd','mri_norm2.h5')
#     glob_str = '../../data/CTM_dataset/unSegmented/*/CT-vol.nrrd'
#     covert_h5_unseg(glob_str,'CT-vol.nrrd','mri_norm2.h5')      
#     error_samples


  0%|          | 0/181 [00:00<?, ?it/s]

seg dataset:
845929 :
> <ipython-input-1-bcc3d39615b1>(56)get_stats()
-> space = np.diagonal(label_header['space directions'])
(Pdb) c
[0, 126, 0]


  1%|          | 1/181 [00:04<14:41,  4.90s/it]

1559027 :
> <ipython-input-1-bcc3d39615b1>(52)get_stats()
-> pdb.set_trace()
(Pdb) c
[0, 32, 0]


  1%|          | 2/181 [00:08<13:09,  4.41s/it]

1433611 :
> <ipython-input-1-bcc3d39615b1>(56)get_stats()
-> space = np.diagonal(label_header['space directions'])
(Pdb) c
[0, 127, 0]


  2%|▏         | 3/181 [00:10<11:11,  3.77s/it]

B1027338 :
> <ipython-input-1-bcc3d39615b1>(52)get_stats()
-> pdb.set_trace()
(Pdb) c
[0, 15, 0]


  2%|▏         | 4/181 [00:14<11:17,  3.82s/it]

936932-dingzi :
> <ipython-input-1-bcc3d39615b1>(56)get_stats()
-> space = np.diagonal(label_header['space directions'])


In [None]:
# Ad-hoc check: open one sample's image with SimpleITK and look at the
# spacing it reports.
import SimpleITK
# NOTE(review): not the last expression of the cell, so its result is
# presumably not displayed -- confirm whether this line is still needed.
dir(SimpleITK)
asd = SimpleITK.ReadImage('../../data/CTM_dataset/Segmented/845929/CTM.nrrd')
# dir(asd)
asd.GetSpacing()

In [None]:
# Same check for the label map of that sample, read with SimpleITK instead of
# the pynrrd call below.
# nrrd.read(item.replace(old_replaced, 'Segmentation-label.nrrd'))
# Rebinds the notebook-level name `label`.
label = SimpleITK.ReadImage('../../data/CTM_dataset/Segmented/845929/Segmentation-label.nrrd')
label.GetSpacing()
dir(label)

In [None]:
# Same check for the segmentation file of that sample; the last line lists
# the metadata keys SimpleITK exposes for it.
# Rebinds the notebook-level name `seg`.
seg = SimpleITK.ReadImage('../../data/CTM_dataset/Segmented/845929/Segmentation.seg.nrrd')
seg.GetSpacing()
dir(seg)
seg.GetMetaDataKeys()

In [None]:
# Divides each sample's std_SC by the mean of the std_whole column:
# `.mean()` binds to stats['std_whole'] only, not to the quotient.
# NOTE(review): if the mean of the per-sample ratio was intended, the
# division needs parentheses -- confirm.
stats['std_SC']/stats['std_whole'].mean()