# Pre-process the ATM22 dataset

In [27]:
import os
import numpy as np
from skimage import io

from preprocessing import crop_CT_3D_image
import SimpleITK as sitk

## 1. Specify the path where the ATM22 dataset resides

你可以根据ATM22数据集的实际存储位置修改 `ATM22_dataset_path` 的值，建议使用相对路径。注意路径末尾必须保留 `/`，因为后续代码通过字符串直接拼接子目录名（如 `imagesTr/`）。

In [14]:
# Root directory of the ATM22 dataset (a relative path is preferred).
# The trailing "/" matters: sub-directories such as "imagesTr/" are appended
# to this value by plain string concatenation.
ATM22_dataset_path = "/".join(("..", "Dataset", "ATM22", ""))

## 2. Gather the raw CT scans and labels from the ATM22 dataset

收集ATM22数据集中的所有CT扫描图像文件和标注文件，排序后检查图像与标注是否一一对应。

In [15]:
# List the raw images for the train set. os.listdir returns entries in an
# arbitrary order, so sort them to pair images with labels by position.
# Hidden entries (e.g. ".DS_Store") are skipped: they have no label
# counterpart and would break the pairing check and case-name slicing.
trainset_raw_images = sorted(
    name for name in os.listdir(ATM22_dataset_path + "imagesTr/")
    if not name.startswith(".")
)

# trainset_raw_images

In [16]:
# List the labels for the train set, sorted so they line up index-by-index
# with trainset_raw_images. Hidden entries (e.g. ".DS_Store") are skipped
# because they have no image counterpart and would break that pairing.
trainset_labels = sorted(
    name for name in os.listdir(ATM22_dataset_path + "labelsTr/")
    if not name.startswith(".")
)

# trainset_labels

In [17]:
# List the raw images for the validation set, in sorted order. Hidden
# entries (e.g. ".DS_Store") are skipped so they are not treated as cases.
validateset_raw_images = sorted(
    name for name in os.listdir(ATM22_dataset_path + "imagesVal/")
    if not name.startswith(".")
)

# validateset_raw_images

In [18]:
# Sanity-check the image/label pairing. After sorting, the i-th raw image and
# the i-th label must share the same file stem (text before the first ".").
# Explicit raises are used instead of `assert` so the check still runs when
# Python is started with -O.
if len(trainset_raw_images) != len(trainset_labels):
    raise ValueError(
        "Found {0} raw CT images but {1} labels".format(
            len(trainset_raw_images), len(trainset_labels)
        )
    )

for index, (item, label) in enumerate(zip(trainset_raw_images, trainset_labels)):
    if item.split(".")[0] != label.split(".")[0]:
        raise ValueError(
            "Raw_CT_image[{0}] does not correspond to CT_image_label[{0}] "
            "({1!r} vs {2!r})".format(index, item, label)
        )

## 3. Collect the case names for the training and validation sets, respectively

In [19]:
# Derive each training case name from its image file name: take the stem
# before the first "." and drop its last 5 characters (presumably a suffix
# such as "_0000" - TODO confirm against the actual file names).
trainset_case_names = [
    raw_name.split(".")[0][:-5] for raw_name in trainset_raw_images
]

len(trainset_case_names)

300

In [20]:
# Same derivation as for the training set: the stem before the first "."
# with its last 5 characters removed (presumably a "_0000"-style suffix -
# TODO confirm).
validate_case_names = [
    raw_name.split(".")[0][:-5] for raw_name in validateset_raw_images
]

len(validate_case_names)

50

## 4.1. Crop each 3D CT image in the training set into cubes

将训练集中的每一个3D CT图像按照 `crop_cube_size`（立方体尺寸）和 `stride`（步长）切割为多个小立方体，并按顺序保存为 NumPy `.npy` 文件。

In [28]:
# Cube size and stride passed to crop_CT_3D_image in section 4. With the
# stride equal to the cube size the volume is presumably tiled into
# non-overlapping cubes - confirm against crop_CT_3D_image.
crop_cube_size = (256, 256, 256)
stride = (256, 256, 256)

# Output directories for the cropped .npy cubes.
current_trainset_path = "./ATM22_train/"
current_labelset_path = "./ATM22_label/"
current_validateset_path = "./ATM22_validate/"

# os.makedirs(exist_ok=True) replaces the exists()/mkdir() pair. It cannot
# race between the check and the creation, it creates missing parent
# directories, and it raises if the path exists but is not a directory.
for output_dir in (current_trainset_path, current_labelset_path,
                   current_validateset_path):
    os.makedirs(output_dir, exist_ok=True)

In [1]:
# NOTE: this cell is disabled (fully commented out). When enabled, it reads
# each training CT volume with skimage's io.imread(plugin='simpleitk') and
# splits it with crop_CT_3D_image(..., crop_cube_size, stride). Each cube is
# saved as "<current_trainset_path><case_name><NNNN>.npy", where NNNN is a
# 4-digit index starting at 0001.
# NOTE(review): case_name here drops 4 characters of the stem ([:-4]), while
# section 3 drops 5 ([:-5]). Presumably deliberate, so a trailing "_"
# separates the case name from the cube index - confirm.
# for raw_image_name in trainset_raw_images:
#     each_CT_3d_img_filename = ATM22_dataset_path + "imagesTr/" + raw_image_name
#     # print(each_CT_3d_img_filename)
    
#     case_name = (each_CT_3d_img_filename.split("/")[-1]).split(".")[0][:-4]
#     # print(case_name)
    
#     ct_3d_img = io.imread(each_CT_3d_img_filename, plugin='simpleitk')
#     cropped_cube_image_list = crop_CT_3D_image(ct_3d_img, crop_cube_size, stride)
    
#     for index, crop_cube in enumerate(cropped_cube_image_list):
#         print("Save file {0}{1}{2:04}.npy".format(current_trainset_path, case_name, index+1))
#         np.save("{0}{1}{2:04}.npy".format(current_trainset_path, case_name, index+1),
#                 crop_cube)
        
#         # sitk.WriteImage(sitk.GetImageFromArray(crop_cube), 
#         #                 "{0}{1}{2:04}.nii.gz".format(current_trainset_path, case_name, index+1))
        

## 4.2. Crop each 3D CT label volume in the training label set into cubes

In [23]:
# NOTE: this cell is disabled (fully commented out). When enabled, it crops
# each training label volume exactly like the images in section 4.1, using
# the same crop_cube_size and stride. Cubes are saved to current_labelset_path
# under the same "<case_name><NNNN>.npy" naming, so image and label cubes can
# be matched by file name.
# NOTE(review): case_name uses [:-4] here, versus [:-5] in section 3 - confirm
# this is intended.
# for raw_label_name in trainset_labels:
#     each_CT_label_filename = ATM22_dataset_path + "labelsTr/" + raw_label_name
#     # print(each_CT_label_filename)
    
#     case_name = (each_CT_label_filename.split("/")[-1]).split(".")[0][:-4]
#     # print(case_name)
    
#     ct_3d_label = io.imread(each_CT_label_filename, plugin='simpleitk')
#     cropped_cube_label_list = crop_CT_3D_image(ct_3d_label, crop_cube_size, stride)
    
#     for index, crop_cube in enumerate(cropped_cube_label_list):
#         print("Save file {0}{1}{2:04}.npy".format(current_labelset_path, case_name, index+1))
#         np.save("{0}{1}{2:04}.npy".format(current_labelset_path, case_name, index+1), 
#                 crop_cube)

## 4.3. Crop each 3D CT image in the validation set into cubes

In [25]:
# NOTE: this cell is disabled (fully commented out). When enabled, it crops
# each validation CT volume with crop_CT_3D_image(..., crop_cube_size, stride).
# Each cube is saved as "<current_validateset_path><case_name><NNNN>.npy",
# with NNNN starting at 0001.
# NOTE(review): case_name uses [:-4] here, versus [:-5] in section 3 - confirm
# this is intended.
# for raw_image_name in validateset_raw_images:
#     each_CT_3d_img_filename = ATM22_dataset_path + "imagesVal/" + raw_image_name
#     # print(each_CT_3d_img_filename)
    
#     case_name = (each_CT_3d_img_filename.split("/")[-1]).split(".")[0][:-4]
#     # print(case_name)
    
#     ct_3d_img = io.imread(each_CT_3d_img_filename, plugin='simpleitk')
#     cropped_cube_img_list = crop_CT_3D_image(ct_3d_img, crop_cube_size, stride)
    
#     for index, crop_cube in enumerate(cropped_cube_img_list):
#         print("Save file {0}{1}{2:04}.npy".format(current_validateset_path, case_name, index+1))
#         np.save("{0}{1}{2:04}.npy".format(current_validateset_path, case_name, index+1), 
#                 crop_cube)