## 配准
每个病例文件夹下有ORIGIN、ROI两个子文件夹。其中：<br>

    ORIGIN文件夹里边具有多个期的图像;<br>
    ROI文件夹里边只有一个期的人工标注。<br>
    
以带有ROI标注的期相（不一定是20 MIN）作为fix图像，将其它期相的图像配准到fix图像上；

    step1、建立所有fix图像与待配准（moving）图像的索引表（dataframe）
    1-1：找出数据集中每一个病例的标注文件，并获取其标注的期相，作为fix；
    1-2：找出其余所有需要配准的期相（记为moving图像）；
    step2、分别将moving图像，往fix图像上配准，并保存
    2-1：读图，获得fixed和moving图像的sitk对象；
    2-2：判断sitk对象是否包含多个通道，若是，则挑选出我们真正想要的那个期相。
    2-3：配准。

注意：图像配准过程中，sitk会生成比较大的过程文件，程序运行后记得清理一下电脑的存储。

In [None]:
import os
from glob import glob
from tqdm import tqdm
import pandas as pd
import SimpleITK as sitk
from utils import util
import re
import shutil
import numpy as np

from parfor import parfor
from multiprocessing import Pool
from multiprocessing.dummy import Pool as ThreadPool

In [None]:
# Global configuration for the registration pipeline.
params = dict(
    data_dir='../Out_source/Need Reg/',  # root of the raw (unregistered) dataset
    save_dir='../Registrated_all4/Registrated_20221117/',  # root directory for registered output
    temp_dir='../temp',  # scratch directory whose files are deleted after each case
    phases=['T1', 'T2', 'AP', 'PP', '20 MIN'],  # every imaging phase to look for
)

In [None]:
data_dir = params['data_dir']  # root of the raw dataset
phases = params['phases']
dir_reg = params['save_dir']  # root directory for the registered output

# NOTE: the expected layout is  data_dir / hospital / lesion type / ID / ROI / phase.nrrd
# and the images of the same case live in the sibling folder  .../ID/ORIGIN/.
pat = os.path.join(data_dir, r'*/*/*/ROI/*.nrrd')  # glob pattern for the manual annotation files
label_filepath_list = glob(pat)  # every annotation file found
print('%d label nrrd file in total in %s.' % (len(label_filepath_list), data_dir))

# One row per annotation file: 'fix' holds the fixed-image path and each
# phase column holds that phase's image path (NaN when not found).
filepaths_to_registrated = pd.DataFrame(
    index=label_filepath_list,
    columns=['fix'] + phases)
filepaths_to_registrated.index.name = 'label'

# Visit every case.
k = 0  # number of phase names matched in annotation filenames
for label_filepath in tqdm(filepaths_to_registrated.index):
    label_filename = os.path.basename(label_filepath)
    # The image folder is the ORIGIN sibling of the ROI folder.  It is built
    # structurally rather than with str.replace('ROI', 'ORIGIN'), which would
    # also rewrite any other 'ROI' substring in the path (e.g. in a case ID).
    case_dir = os.path.dirname(os.path.dirname(label_filepath))
    image_dir = os.path.join(case_dir, 'ORIGIN')

    # Locate every phase's image (these include the future moving images).
    # Matches are sorted so the pick is deterministic when several files match.
    for phase in phases:
        matches = sorted(glob(os.path.join(image_dir, '*{}*.nrrd'.format(phase))))
        if matches:
            filepaths_to_registrated.loc[label_filepath, phase] = matches[0]

    # The fixed image is the phase named in the annotation filename.  If
    # several phase names occur, the last one in `phases` wins.
    fix_phase = None
    for phase in phases:
        if phase.upper() in label_filename.upper():
            fix_phase = phase
            k += 1

    if fix_phase:
        filepaths_to_registrated.loc[label_filepath, 'fix'] = \
            filepaths_to_registrated.loc[label_filepath, fix_phase]
    else:
        err_msg = 'Error: fixed image not found: {}'.format(label_filepath)
        print(err_msg)

# Drop the cases whose fixed image is missing (unknown phase or no image file).
missing_fix = filepaths_to_registrated['fix'].isna()
print('{} samples lack of fix image'.format(missing_fix.sum()))
samples_noFixImage = filepaths_to_registrated[missing_fix]
print('Samples missing fixed image:')
display(samples_noFixImage)
filepaths_to_registrated.drop(index=samples_noFixImage.index, inplace=True)
print('Fine samples:')
display(filepaths_to_registrated)

In [None]:
# main: register every moving phase of each case onto that case's fixed image.
for label_filepath in tqdm(filepaths_to_registrated.index[:]):
    # Path of the fixed image.  Missing entries are NaN (a float), so anything
    # that is not a path to an existing file is skipped instead of being
    # handed to sitk.ReadImage.
    fixed_filepath = filepaths_to_registrated.loc[label_filepath, 'fix']
    if not isinstance(fixed_filepath, str) or not os.path.isfile(fixed_filepath):
        continue

    # Read the fixed image.  Some files pack several phases as channels; their
    # basename starts with "n-m", where n is the 1-based channel holding the
    # wanted phase and m is the total number of channels.  Only channel n is kept.
    fixedImage = sitk.ReadImage(fixed_filepath)
    fixed_is_multichannel = fixedImage.GetNumberOfComponentsPerPixel() > 1
    if fixed_is_multichannel:
        fixedImage = sitk.Cast(fixedImage, sitk.sitkVectorFloat32)
        # (\d+) so multi-digit channel numbers are parsed correctly.
        match_result = re.match(r'(\d+)-(\d+)', os.path.basename(fixed_filepath))
        n = int(match_result.group(1)) if match_result else 0
        if not 1 <= n <= fixedImage.GetNumberOfComponentsPerPixel():
            # Without a valid channel index there is nothing sensible to register to.
            print('Error: cannot determine the channel to use for {}'.format(fixed_filepath))
            continue
        fixedImage_arr = sitk.GetArrayFromImage(fixedImage)
        fixedImage_corrected = sitk.GetImageFromArray(fixedImage_arr[..., n - 1])
        fixedImage_corrected.CopyInformation(fixedImage)
    else:
        fixedImage_corrected = sitk.Cast(fixedImage, sitk.sitkFloat32)

    # Output folder mirrors the input layout: dir_reg/<hospital>/<subset>/<case>/<image subdir>.
    # Derived relative to data_dir instead of unpacking a fixed number of path parts.
    rel_image_dir = os.path.relpath(os.path.dirname(fixed_filepath), data_dir)
    sample_save_dir = os.path.join(dir_reg, rel_image_dir)
    os.makedirs(sample_save_dir, exist_ok=True)

    # Save the fixed image; the "n-m" channel tag is stripped from the filename.
    copy_fixed_filepath = os.path.join(sample_save_dir, re.sub(
        r'\s*\d+-\d+\s*', '', os.path.basename(fixed_filepath)))
    if fixed_is_multichannel:
        sitk.WriteImage(fixedImage_corrected, copy_fixed_filepath)
    else:
        shutil.copy(src=fixed_filepath, dst=copy_fixed_filepath)

    # Save the ROI into the ROI sibling of the image output folder.  A path
    # join is used because str.replace would also rewrite any other
    # occurrence of the image-subdir name elsewhere in the path.
    copy_label_path = os.path.join(os.path.dirname(sample_save_dir), 'ROI')
    os.makedirs(copy_label_path, exist_ok=True)
    copy_label_filepath = os.path.join(copy_label_path, re.sub(
        r'\s*\d+-\d+\s*', '', os.path.basename(label_filepath)))
    shutil.copy(src=label_filepath, dst=copy_label_filepath)

    # Register every other phase.  The parfor decorator runs `fun` over all
    # phase columns as soon as it is defined.
    @parfor(filepaths_to_registrated.columns.drop('fix'), bar=False, qbar=False, rP=1, serial=2)
    def fun(colname):
        # Path of the moving image; skip the fixed phase itself and missing
        # (NaN) or nonexistent files.
        moving_filepath = filepaths_to_registrated.loc[label_filepath, colname]
        if (not isinstance(moving_filepath, str)
                or moving_filepath == fixed_filepath
                or not os.path.isfile(moving_filepath)):
            return

        movingImage = sitk.ReadImage(moving_filepath)
        # Same "n-m" multi-channel convention as for the fixed image.
        if movingImage.GetNumberOfComponentsPerPixel() > 1:
            movingImage = sitk.Cast(movingImage, sitk.sitkVectorFloat32)
            moving_basename = os.path.basename(moving_filepath)
            print(moving_basename)
            # (\d+) rather than (\d)+: a repeated single-digit group captures
            # only its LAST digit, so "12-13" would have yielded channel 2.
            result = re.match(r'(\d+)-(\d+)', moving_basename)
            n = int(result.group(1)) if result else 0
            if not 1 <= n <= movingImage.GetNumberOfComponentsPerPixel():
                print('Error: cannot determine the channel to use for {}'.format(moving_filepath))
                return
            movingImage_arr = sitk.GetArrayFromImage(movingImage)
            movingImage_corrected = sitk.GetImageFromArray(movingImage_arr[..., n - 1])
            movingImage_corrected.CopyInformation(movingImage)
        else:
            movingImage_corrected = sitk.Cast(movingImage, sitk.sitkFloat32)

        # Register; a failure on one phase must not abort the remaining phases.
        try:
            resultImage = util.Image_Registration_Method4(
                fixedImage_corrected, movingImage_corrected, DefaultPixelValue=0.0)
        except Exception as err:
            print('Error while trying to registrated {}: {}'.format(moving_filepath, err))
            return

        # Save the registered moving image without the "n-m" channel tag.
        resultImage_filepath = os.path.join(sample_save_dir, re.sub(
            r'\s*\d+-\d+\s*', '', os.path.basename(moving_filepath)))
        sitk.WriteImage(resultImage, resultImage_filepath)

    # Best-effort cleanup of the scratch directory after each case.
    temp_dir = params['temp_dir']
    if os.path.isdir(temp_dir):
        for tmp_name in os.listdir(temp_dir):
            try:
                os.remove(os.path.join(temp_dir, tmp_name))
            except OSError as err:
                print('Could not remove temp file {}: {}'.format(tmp_name, err))