# Pre-process the ATM22 dataset

In [None]:
import os
import numpy as np
from skimage import io

from preprocessing import crop_CT_3D_image
import SimpleITK as sitk

## 1. Specify the path where the ATM22 dataset resides

你可以根据ATM22 dataset所在的存储位置来更改 `ATM22_dataset_path` 的值，最好使用相对路径。注意路径须以 `/` 结尾，因为后续代码直接将子目录名（如 `imagesTr/`）拼接在其后。

In [None]:
# Root directory of the ATM22 dataset. Sub-folder names ("imagesTr/",
# "labelsTr/", "imagesVal/") are appended to it by plain string
# concatenation, so this value must end with "/".
ATM22_dataset_path = "../Dataset/ATM22/"

## 2. Gather the raw CT scans and labels from the ATM22 dataset

收集ATM22数据集中的所有CT扫描图片文件和标注图片文件，整理排序并检查。

In [None]:
# Sorted file names of the raw training CT scans (imagesTr/). Sorting lets
# them be matched index-by-index with the sorted label file names.
trainset_raw_images = sorted(os.listdir(ATM22_dataset_path + "imagesTr/"))

In [None]:
# Sorted file names of the training labels (labelsTr/). Sorting aligns them
# index-by-index with trainset_raw_images.
trainset_labels = sorted(os.listdir(ATM22_dataset_path + "labelsTr/"))

In [None]:
# Sorted file names of the raw validation CT scans (imagesVal/).
validateset_raw_images = sorted(os.listdir(ATM22_dataset_path + "imagesVal/"))

In [None]:
# Sanity-check the training set before any processing:
#   1. there are as many label files as raw CT scans, and
#   2. after sorting, the i-th image and the i-th label share the same stem
#      (the part of the file name before the first ".").
assert len(trainset_raw_images) == len(trainset_labels), \
    "Found {0} raw CT images but {1} label files".format(
        len(trainset_raw_images), len(trainset_labels))

for index, (item, label_item) in enumerate(zip(trainset_raw_images, trainset_labels)):
    assert (item.split(".")[0] == label_item.split(".")[0]), "Raw_CT_image[{0}] does not correspond to CT_image_label[{0}]".format(index)

# Load every image/label pair and verify their arrays have identical shapes.
# The label path is built from the paired *label* file name (validated above)
# rather than from the image file name, so a label stored under a different
# extension is still found.
for ct_filename, label_filename in zip(trainset_raw_images, trainset_labels):
    ct_image_file = ATM22_dataset_path + "imagesTr/" + ct_filename
    ct_label_file = ATM22_dataset_path + "labelsTr/" + label_filename

    ct_3d_image = io.imread(ct_image_file, plugin='simpleitk')
    ct_3d_label = io.imread(ct_label_file, plugin='simpleitk')
    print("{0}.shape = {1}, \n{2}.shape = {3}\n"
          .format(ct_image_file, ct_3d_image.shape, ct_label_file, ct_3d_label.shape))

    assert ct_3d_image.shape == ct_3d_label.shape, \
        "Shape mismatch between {0} and {1}".format(ct_image_file, ct_label_file)

    # Release the volumes before loading the next pair to keep peak memory low.
    del ct_3d_image
    del ct_3d_label

## 3. Collect the case names for the train set and the validation set, respectively.

In [None]:
# Case name = file stem (text before the first ".") with its last 5
# characters dropped (presumably a modality suffix such as "_0000" — confirm
# against the actual file names).
trainset_case_names = [raw_name.split(".")[0][:-5] for raw_name in trainset_raw_images]

len(trainset_case_names)

In [None]:
# Same naming rule as for the train set: stem before the first "." minus its
# last 5 characters.
validate_case_names = [raw_name.split(".")[0][:-5] for raw_name in validateset_raw_images]

len(validate_case_names)

## 4.1. Crop each CT 3D image in the trainset into cubes

将训练集中的每一个CT 3D image按照`crop_cube_size`切割为众多小的立方体，有序保存为numpy .npy格式。

In [None]:
# Cube size and sliding-window stride for cube cropping (in this notebook they
# are referenced only by the commented-out cropping cells further below).
crop_cube_size = (128, 128, 128)
stride = (64, 64, 64)

# Output directories for processed train images, train labels and validation images.
current_trainset_path = "./ATM22_train/"
current_labelset_path = "./ATM22_label/"
current_validateset_path = "./ATM22_validate/"

# os.makedirs(exist_ok=True) replaces the exists()/mkdir() pairs: re-running
# the cell is a no-op, and any missing parent directories are created too.
for output_dir in (current_trainset_path, current_labelset_path, current_validateset_path):
    os.makedirs(output_dir, exist_ok=True)

- 根据对应的`trainset_label`修剪每个`trainset_raw_image`：剔除首尾不含Label标记的切片，并在有标记的切片范围两端保留若干切片作为边距。

In [None]:
def labeled_slice_bounds(label_volume, margin=5):
    """Return the axis-0 ``(start, stop)`` range covering every labeled slice.

    A slice ``label_volume[i]`` counts as labeled when the sum of its voxels is
    positive. The range is widened by ``margin`` slices on both sides and
    clamped to the volume, so it can be used directly as
    ``volume[start:stop]``.

    Args:
        label_volume: numpy array whose first axis indexes slices.
        margin: number of extra slices kept before the first and after the
            last labeled slice.

    Returns:
        ``(start, stop)`` as Python ints, or ``None`` when no slice is labeled
        (including an empty volume).
    """
    label_volume = np.asarray(label_volume)
    n_slices = label_volume.shape[0]
    if n_slices == 0:
        return None

    per_slice_sum = label_volume.reshape(n_slices, -1).sum(axis=1)
    labeled = np.flatnonzero(per_slice_sum > 0)
    if labeled.size == 0:
        return None

    # Clamp at 0: a negative start would wrap around in Python slicing.
    start = max(int(labeled[0]) - margin, 0)
    # +1 because ``stop`` is exclusive; this keeps ``margin`` slices after the
    # last labeled slice, symmetric with the leading margin.
    stop = min(int(labeled[-1]) + margin + 1, n_slices)
    return start, stop


def TrimCTRaw3DImage(trainset_ct_file_name, margin=5,
                     output_image_dir="ATM22_train/",
                     output_label_dir="ATM22_label/"):
    """Trim a training CT volume and its label to the slab of labeled slices.

    Reads ``imagesTr/<name>`` and ``labelsTr/<name>`` under
    ``ATM22_dataset_path``. It drops the leading and trailing slices (first
    numpy axis) whose label has no positive voxels, keeping ``margin`` extra
    slices on each side. Both trimmed volumes are then written with
    ``sitk.WriteImage``. If the label has no positive voxel at all, the full
    volume is written unchanged.

    Args:
        trainset_ct_file_name: file name of the case inside ``imagesTr/``;
            the label is read from ``labelsTr/`` under the same name.
        margin: number of unlabeled slices kept on each side of the labeled slab.
        output_image_dir: path prefix for the trimmed image.
        output_label_dir: path prefix for the trimmed label.

    Raises:
        ValueError: if the image and label sizes differ.
    """
    image = sitk.ReadImage(ATM22_dataset_path + "imagesTr/" + trainset_ct_file_name)
    label = sitk.ReadImage(ATM22_dataset_path + "labelsTr/" + trainset_ct_file_name)

    if image.GetSize() != label.GetSize():
        raise ValueError("Image size {0} != label size {1} for {2}".format(
            image.GetSize(), label.GetSize(), trainset_ct_file_name))

    bounds = labeled_slice_bounds(sitk.GetArrayFromImage(label), margin)
    if bounds is None:
        # No labeled slice: keep the whole volume instead of writing an empty one.
        start, stop = 0, label.GetSize()[2]
    else:
        start, stop = bounds

    # SimpleITK indexes images as (x, y, z) while GetArrayFromImage returns
    # (z, y, x), so the numpy slice axis is the *last* SimpleITK index.
    # Slicing the sitk.Image (rather than a numpy array) keeps spacing and
    # direction and moves the origin to the first kept slice.
    trimmed_image = image[:, :, start:stop]
    trimmed_label = label[:, :, start:stop]

    # NOTE(review): if trainset_ct_file_name already ends in ".nii.gz", the
    # written files end in ".nii.gz.nii.gz". Kept as-is because downstream
    # code may rely on these names; confirm before changing.
    sitk.WriteImage(trimmed_image,
                    "{0}.nii.gz".format(output_image_dir + trainset_ct_file_name))

    sitk.WriteImage(trimmed_label,
                    "{0}.nii.gz".format(output_label_dir + trainset_ct_file_name))

In [None]:
# Trim every training CT volume (together with its label) to its labeled slab.
for ct_name in trainset_raw_images:
    print("Trim CT file " + ct_name)
    TrimCTRaw3DImage(ct_name)

**注：** 以下代码会将trainset、labelset和validateset中的每个3D Image .nii.gz文件，按照`crop_cube_size`（当前为(128, 128, 128)）并以`stride`为步长切割成小立方体，保存为Numpy .npy格式。

由于切割后的文件总体积急剧膨胀，且文件数量庞大，存储效率很差，故暂时取消这种做法（下方代码均已注释掉）。

In [None]:
# for raw_image_name in trainset_raw_images:
#     each_CT_3d_img_filename = ATM22_dataset_path + "imagesTr/" + raw_image_name
#     # print(each_CT_3d_img_filename)
    
#     case_name = (each_CT_3d_img_filename.split("/")[-1]).split(".")[0][:-4]
#     # print(case_name)
    
#     ct_3d_img = io.imread(each_CT_3d_img_filename, plugin='simpleitk')
#     cropped_cube_image_list = crop_CT_3D_image(ct_3d_img, crop_cube_size, stride)
    
#     for index, crop_cube in enumerate(cropped_cube_image_list):
#         print("Save file {0}{1}{2:04}.npy".format(current_trainset_path, case_name, index+1))
#         np.save("{0}{1}{2:04}.npy".format(current_trainset_path, case_name, index+1),
#                 crop_cube)
        
#         # sitk.WriteImage(sitk.GetImageFromArray(crop_cube), 
#         #                 "{0}{1}{2:04}.nii.gz".format(current_trainset_path, case_name, index+1))
        

## 4.2. Crop each CT 3D label in the labelset into cubes

In [None]:
# for raw_label_name in trainset_labels:
#     each_CT_label_filename = ATM22_dataset_path + "labelsTr/" + raw_label_name
#     # print(each_CT_label_filename)
    
#     case_name = (each_CT_label_filename.split("/")[-1]).split(".")[0][:-4]
#     # print(case_name)
    
#     ct_3d_label = io.imread(each_CT_label_filename, plugin='simpleitk')
#     cropped_cube_label_list = crop_CT_3D_image(ct_3d_label, crop_cube_size, stride)
    
#     for index, crop_cube in enumerate(cropped_cube_label_list):
#         print("Save file {0}{1}{2:04}.npy".format(current_labelset_path, case_name, index+1))
#         np.save("{0}{1}{2:04}.npy".format(current_labelset_path, case_name, index+1), 
#                 crop_cube)

## 4.3. Crop each 3D image in the validateset into cubes

In [None]:
# for raw_image_name in validateset_raw_images:
#     each_CT_3d_img_filename = ATM22_dataset_path + "imagesVal/" + raw_image_name
#     # print(each_CT_3d_img_filename)
    
#     case_name = (each_CT_3d_img_filename.split("/")[-1]).split(".")[0][:-4]
#     # print(case_name)
    
#     ct_3d_img = io.imread(each_CT_3d_img_filename, plugin='simpleitk')
#     cropped_cube_img_list = crop_CT_3D_image(ct_3d_img, crop_cube_size, stride)
    
#     for index, crop_cube in enumerate(cropped_cube_img_list):
#         print("Save file {0}{1}{2:04}.npy".format(current_validateset_path, case_name, index+1))
#         np.save("{0}{1}{2:04}.npy".format(current_validateset_path, case_name, index+1), 
#                 crop_cube)