# 前處理

In [1]:
import SimpleITK as sitk
from antspynet.utilities import brain_extraction
import ants
import numpy as np
import os
from tqdm import tqdm
def ants_2_itk(image):
    """Convert a 3-D ANTs image into a SimpleITK image.

    Args:
        image: ANTs image exposing ``numpy()``, ``origin``, ``spacing`` and a
            ``direction`` matrix that can be reshaped to 9 values (i.e. 3x3).

    Returns:
        A SimpleITK image carrying the same voxel data, origin, spacing and
        direction cosines.
    """
    # Reverse the axis order of the voxel array before handing it to SimpleITK
    # (presumably because SimpleITK indexes arrays as z, y, x — confirm).
    voxels = image.numpy().T
    converted = sitk.GetImageFromArray(voxels)

    # Copy the physical-space metadata; the direction matrix is passed as a flat
    # sequence of 9 values.
    flat_direction = image.direction.reshape(9)
    converted.SetOrigin(image.origin)
    converted.SetSpacing(image.spacing)
    converted.SetDirection(flat_direction)
    return converted

def itk_2_ants(image):
    """Convert a 3-D SimpleITK image into an ANTs image.

    Args:
        image: SimpleITK image whose direction holds 9 values (3x3 matrix).

    Returns:
        An ANTs image built from the same voxel data with origin, spacing and
        direction copied over.
    """
    # Reverse the axis order of the SimpleITK array (mirror of ants_2_itk).
    voxels = sitk.GetArrayFromImage(image).T
    # SimpleITK returns the direction as a flat sequence; restore the 3x3 matrix.
    direction_matrix = np.array(image.GetDirection()).reshape(3, 3)
    return ants.from_numpy(
        voxels,
        origin=image.GetOrigin(),
        spacing=image.GetSpacing(),
        direction=direction_matrix,
    )

In [2]:
##################### 去顱骨 #####################
def ants_skull_stripping(input_image):
    """Remove non-brain tissue from a T1 image.

    Args:
        input_image: ANTs image passed to ``brain_extraction`` with
            ``modality='t1'``.

    Returns:
        ``input_image`` masked by the thresholded brain-extraction result.
    """
    # Run the ANTsPyNet brain extraction (a U-net, per the original note) to get
    # a brain probability map.
    brain_probability = brain_extraction(input_image, modality='t1', verbose=False)
    # Turn the probability map into a mask using 0.5 as the lower threshold.
    binary_mask = ants.get_mask(brain_probability, low_thresh=0.5)
    # Apply the mask to the input image.
    return ants.mask_image(input_image, binary_mask)
##################### 去顱骨 #####################

In [3]:
#################### MNI配準 #####################
def sitk_mni_registration(input_image,template_image):
    """Rigidly register ``input_image`` to ``template_image`` and resample it.

    Note the argument order: the moving image comes first, the fixed template
    second.

    Args:
        input_image: SimpleITK moving image.
        template_image: SimpleITK fixed image; also defines the output grid.

    Returns:
        ``input_image`` resampled onto the grid of ``template_image`` with the
        optimised transform, keeping the pixel type of ``input_image``.
    """
    # Starting point: a rigid (Euler 3D) transform initialised in GEOMETRY mode
    # from the fixed/moving image pair.
    starting_transform = sitk.CenteredTransformInitializer(
        template_image,
        input_image,
        sitk.Euler3DTransform(),
        sitk.CenteredTransformInitializerFilter.GEOMETRY,
    )

    reg = sitk.ImageRegistrationMethod()
    reg.SetInitialTransform(starting_transform, inPlace=False)

    # Similarity metric: Mattes mutual information with 50 histogram bins,
    # evaluated on a random 1% sample.
    reg.SetMetricAsMattesMutualInformation(numberOfHistogramBins=50)
    reg.SetMetricSamplingStrategy(reg.RANDOM)
    reg.SetMetricSamplingPercentage(0.01)

    # Linear interpolation while evaluating the metric.
    reg.SetInterpolator(sitk.sitkLinear)

    # Optimiser: gradient descent. The original author flagged the
    # physical-shift scaling of the optimiser parameters as especially important.
    reg.SetOptimizerAsGradientDescent(
        learningRate=1.0,
        numberOfIterations=100,
        convergenceMinimumValue=1e-6,
        convergenceWindowSize=10,
    )
    reg.SetOptimizerScalesFromPhysicalShift()

    # Run the registration (fixed = template, moving = input).
    optimized_transform = reg.Execute(template_image, input_image)

    # Resample the moving image onto the template grid; voxels outside the
    # moving image are filled with 0.0.
    return sitk.Resample(
        input_image,
        template_image,
        optimized_transform,
        sitk.sitkLinear,
        0.0,
        input_image.GetPixelID(),
    )
#################### MNI配準 #####################

In [14]:
def preprocessing(template_image,input_image,output_dir,output_name=None):
    """Run the full preprocessing pipeline on one scan and write the result.

    Steps: skull stripping -> rigid registration to the template -> intensity
    truncation and normalisation -> rescale to uint8 -> write to disk.

    Args:
        template_image: SimpleITK fixed image used as the registration target.
        input_image: ANTs image of the raw scan.
        output_dir: Base output folder; results go to ``<output_dir>/ppResult``.
        output_name: File name suffix appended to ``'pp007_result_'``. When
            omitted, the notebook-level variable ``item`` is used, which is
            what the original implementation always did.

    Returns:
        Path of the written image.
    """
    if output_name is None:
        # Backward compatibility: fall back to the notebook-level loop variable.
        output_name = item

    # Skull stripping operates on the ANTs image.
    brain_image = ants_skull_stripping(input_image)

    # Registration operates on SimpleITK images. The moving image is cast to
    # float32 to match the template, which is read as sitkFloat32 in this notebook.
    # Note that sitk_mni_registration takes (moving, fixed) in that order.
    moving_image = sitk.Cast(ants_2_itk(brain_image), sitk.sitkFloat32)
    registered_image = sitk_mni_registration(moving_image, template_image)

    # Intensity normalisation is done with ANTs again.
    registered_ants = itk_2_ants(registered_image)
    # Clip intensities to the 0.05 - 0.95 quantile range.
    truncated_img = ants.iMath_truncate_intensity(registered_ants, 0.05, 0.95)
    # Normalise to [0, 1] (float32, per the original note).
    normalized_img = ants.iMath_normalize(truncated_img)
    # Min-max rescale to [0, 255] and store as uint8.
    img_uint8 = ((normalized_img - normalized_img.min()) / (normalized_img.max() - normalized_img.min()) * 255).astype('uint8')

    # Make sure the destination folder exists before writing.
    result_dir = os.path.join(output_dir,'ppResult')
    os.makedirs(result_dir, exist_ok=True)
    path = os.path.join(result_dir,'pp007_result_'+output_name)
    ants.image_write(img_uint8, path)
    return path

In [15]:
# Batch driver: preprocess every scan in every dataset folder.
# NOTE(review): the folder list is reconstructed from the paths printed in this
# cell's recorded output — confirm it matches the intended inputs.
template_image = sitk.ReadImage('mni152-s.nii',sitk.sitkFloat32)
output_dir = r'C:\Users\user\Desktop\brainModel\pp007'
input_dirs = [
    r'C:\Users\user\Desktop\brainModel\raw_data\ABIDE\collection',
    r'C:\Users\user\Desktop\brainModel\raw_data\ADNI\sc\CN\collection',
    r'C:\Users\user\Desktop\brainModel\raw_data\IXI\collection',
    r'C:\Users\user\Desktop\brainModel\raw_data\camcan\collection',
]
for directory in input_dirs:
    print(directory)
    # The loop variable must be called `item`: the output file name is derived
    # from the notebook-level name `item`.
    for item in tqdm(sorted(os.listdir(directory))):
        input_file = os.path.join(directory, item)
        input_image = ants.image_read(input_file)
        preprocessing(template_image,input_image,output_dir)

C:\Users\user\Desktop\brainModel\raw_data\ABIDE\collection


100%|██████████████████████████████████████████████████████████████████████████████| 561/561 [1:40:59<00:00, 10.80s/it]


C:\Users\user\Desktop\brainModel\raw_data\ADNI\sc\CN\collection


100%|██████████████████████████████████████████████████████████████████████████████| 562/562 [1:34:45<00:00, 10.12s/it]


C:\Users\user\Desktop\brainModel\raw_data\IXI\collection


100%|██████████████████████████████████████████████████████████████████████████████| 564/564 [1:27:15<00:00,  9.28s/it]


C:\Users\user\Desktop\brainModel\raw_data\camcan\collection


100%|██████████████████████████████████████████████████████████████████████████████| 653/653 [2:03:16<00:00, 11.33s/it]


# 腦齡預測

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch
from torch.utils.data import DataLoader, TensorDataset, ConcatDataset
import numpy as np
import os
from tqdm import tqdm
import torch.nn.functional as F

In [None]:
# Basic building block of the SFCN network.
class BasicBlock(nn.Module):
    """Conv3d -> BatchNorm3d -> MaxPool3d -> ReLU.

    Args:
        in_planes: Number of input channels.
        out_planes: Number of output channels.
        stride: Stride of the 3x3x3 convolution (default 1).
    """

    # Channel expansion factor read by SFCN._make_layer.
    expansion = 1

    def __init__(self, in_planes, out_planes, stride=1):
        super().__init__()
        # 3x3x3 convolution with padding 1.
        self.conv1 = nn.Conv3d(
            in_planes,
            out_planes,
            kernel_size=3,
            stride=stride,
            padding=1,
        )
        self.bn1 = nn.BatchNorm3d(out_planes)
        # 2x2x2 max pooling with stride 2.
        self.maxpool1 = nn.MaxPool3d(kernel_size=2, stride=2)

    def forward(self, x):
        """Apply convolution, batch norm, max pooling and ReLU, in that order."""
        pooled = self.maxpool1(self.bn1(self.conv1(x)))
        return F.relu(pooled)

In [None]:
# Full SFCN network assembled from BasicBlock stages.
class SFCN(nn.Module):
    """SFCN model producing one scalar per sample.

    The network ends in a 50-way softmax; the returned value is the
    probability-weighted mean of the bin indices 1..50, multiplied by 2
    (presumably 2-unit-wide age bins — confirm against the training code).

    Args:
        num_classes: Accepted for interface compatibility; not used by the
            layers defined here.
    """

    def __init__(self, num_classes=1):
        super().__init__()

        # Running channel count, consumed and updated by _make_layer.
        self.in_planes = 1

        # Stage 1: five blocks, channels 1 -> 32 -> 64 -> 128 -> 256 -> 256.
        self.layer1 = self._make_layer(BasicBlock, 32, num_blocks=1, stride=1)
        self.layer2 = self._make_layer(BasicBlock, 64, num_blocks=1, stride=1)
        self.layer3 = self._make_layer(BasicBlock, 128, num_blocks=1, stride=1)
        self.layer4 = self._make_layer(BasicBlock, 256, num_blocks=1, stride=1)
        self.layer5 = self._make_layer(BasicBlock, 256, num_blocks=1, stride=1)

        # Stage 2: 1x1x1 convolution -> batch norm (ReLU is applied in forward).
        self.conv1 = nn.Conv3d(256, 64, kernel_size=1, stride=1, padding=0)
        self.bn1 = nn.BatchNorm3d(64)

        # Stage 3: (average pool in forward) -> dropout -> 1x1x1 convolution to
        # 50 channels -> softmax over those channels.
        self.dropout = nn.Dropout(p=0.2)
        self.conv2 = nn.Conv3d(64, 50, kernel_size=1, stride=1, padding=0)
        self.softmax = nn.Softmax(dim=1)

    def _make_layer(self, block, out_planes, num_blocks, stride):
        """Stack ``num_blocks`` blocks; only the first one receives ``stride``."""
        head = block(self.in_planes, out_planes, stride)
        # Later blocks (and later layers) take the expanded channel count as input.
        self.in_planes = out_planes * block.expansion
        tail = [block(self.in_planes, out_planes) for _ in range(1, num_blocks)]
        return nn.Sequential(head, *tail)

    def forward(self, x):
        """Return a ``(batch, 1)`` tensor of predictions."""
        # Stage 1
        features = x
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4, self.layer5):
            features = stage(features)

        # Stage 2
        features = torch.relu(self.bn1(self.conv1(features)))

        # Stage 3: global average pool to 1x1x1, dropout, project to 50 channels.
        pooled = F.adaptive_avg_pool3d(features, output_size=1)
        logits = self.conv2(self.dropout(pooled))
        probs = self.softmax(logits.view(logits.shape[0], 50))

        # Probability-weighted average of the bin indices 1..50, then doubled.
        bin_index = torch.arange(1, 51, dtype=probs.dtype, device=probs.device).view(1, -1)
        expected_bin = torch.sum(probs * bin_index, dim=1, keepdim=True)
        return expected_bin * 2

In [None]:
# Load the five trained SFCN checkpoints used for ensemble prediction.
# `device` is not defined anywhere else in the notebook, so it is set here:
# use the GPU when available, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model_dir = r'models'
# Sort the listing so the same checkpoints are picked on every run
# (os.listdir returns entries in arbitrary order).
pt_ls = sorted(os.listdir(model_dir))
if len(pt_ls) < 5:
    raise FileNotFoundError(
        'expected at least 5 checkpoint files in %r, found %d' % (model_dir, len(pt_ls))
    )

cnns = []
for i in range(5):
    model = SFCN(num_classes=1)
    path = os.path.join(model_dir, pt_ls[i])
    # map_location lets checkpoints saved on a GPU load on a CPU-only machine.
    model.load_state_dict(torch.load(path, map_location=device))
    # Inference mode: disables dropout and uses the stored batch-norm statistics.
    model.eval()
    model.to(device)
    cnns.append(model)

In [None]:
# Load one preprocessed scan and turn it into a uint8 tensor.
import nibabel as nib  # used below but never imported elsewhere in the notebook

# Location of the preprocessed scan.
path = r'ppResult/'
if os.path.isdir(path):
    # nib.load needs a file, not a folder.
    # NOTE(review): picking the first NIfTI file (sorted by name) is a guess at
    # the intent — point `path` at the specific scan if that is wrong.
    nifti_files = sorted(f for f in os.listdir(path) if f.endswith(('.nii', '.nii.gz')))
    if not nifti_files:
        raise FileNotFoundError('no NIfTI file found in %r' % path)
    path = os.path.join(path, nifti_files[0])

# Read the voxel data as a numpy array.
data = nib.load(path)
image = data.get_fdata()
# Add the leading channel axis when the volume is 3-D; the placeholder this
# cell used to allocate had shape (1, 128, 192, 128), i.e. channel first.
if image.ndim == 3:
    image = image[np.newaxis, ...]

# Build the torch tensor, stored as uint8.
image_tensor = torch.from_numpy(image).to(torch.uint8)

In [None]:
# Ensemble inference: average the predictions of all loaded models for one scan.
# Scale the uint8 intensities to [0, 1].
inputs = image_tensor.to(torch.float32) / 255.0
# The BatchNorm3d layers in the network only accept 5-D input
# (batch, channel, D, H, W), so prepend whichever leading axes are missing.
while inputs.dim() < 5:
    inputs = inputs.unsqueeze(0)
# Run on the same device the models were moved to.
inputs = inputs.to(next(cnns[0].parameters()).device)

# Equal-weight ensemble (0.2 each for five models).
weight = 1.0 / len(cnns)
ensemble_output = 0.0
# No gradients are needed for prediction.
with torch.no_grad():
    for model in cnns:
        model.eval()
        outputs = model(inputs)
        ensemble_output += outputs.item()*weight
print(ensemble_output)