In [14]:
# Convert the full C2FViT model to ONNX
import torch
import torch.nn as nn
import torch.nn.functional as F
import nibabel as nib
import numpy as np
from C2FViT_model import C2F_ViT_stage, AffineCOMTransform, Center_of_mass_initial_pairwise
from Functions import min_max_norm, pad_to_shape

class FullModel(nn.Module):
    """End-to-end C2FViT affine registration wrapper used for ONNX export.

    Chains three components so the whole pipeline is traced as one graph:
    center-of-mass initialization, the core network run on downsampled
    volumes, and the affine warp of the (initialized) full-resolution
    moving image.

    Args:
        model: Core network, called as ``model(X_down, Y_down)``. It must
            return a 3-tuple whose last element is a sequence of affine
            parameter tensors; only the final entry is used.
        affine_transform: Called as ``affine_transform(image, affine_params)``;
            must return ``(warped_image, affine_matrix)``.
        init_center: Called as ``init_center(moving, fixed)``; must return
            ``(initialized_moving, init_flow)``. Only the image is used.
        downsample_scale: Scale factor applied (trilinear, align_corners=True)
            to both images before the core network. Defaults to 0.5, the
            previously hard-coded value.
    """

    def __init__(self, model, affine_transform, init_center, downsample_scale=0.5):
        super(FullModel, self).__init__()
        self.model = model
        self.affine_transform = affine_transform
        self.init_center = init_center
        self.downsample_scale = downsample_scale

    def forward(self, moving_img, fixed_img):
        """Register ``moving_img`` to ``fixed_img``.

        Args:
            moving_img: Moving volume. Must be 5-D (N, C, D, H, W), because
                ``F.interpolate`` with mode "trilinear" requires it.
            fixed_img: Fixed / template volume, also 5-D.

        Returns:
            Tuple ``(moved, affine_matrix)`` as produced by ``affine_transform``
            at full resolution. For the notebook's 1x1x256^3 inputs the
            exported outputs have shapes (1, 1, 256, 256, 256) and (1, 3, 4).
        """
        # Center-of-mass initialization; the returned flow is not needed here.
        moving_img, _init_flow = self.init_center(moving_img, fixed_img)

        # The core network operates on reduced-resolution copies of both images.
        X_down = F.interpolate(moving_img, scale_factor=self.downsample_scale,
                               mode="trilinear", align_corners=True)
        Y_down = F.interpolate(fixed_img, scale_factor=self.downsample_scale,
                               mode="trilinear", align_corners=True)

        # Only the affine parameters from the last stage are consumed.
        _warped_x_list, _y_list, affine_para_list = self.model(X_down, Y_down)

        # Warp the full-resolution (initialized) moving image with those parameters.
        X_Y, affine_matrix = self.affine_transform(moving_img, affine_para_list[-1])
        return X_Y, affine_matrix

# Select the compute device; fall back to CPU when CUDA is unavailable.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Define the core C2FViT network (hyper-parameters must match the checkpoint,
# otherwise the strict load_state_dict below fails).
model = C2F_ViT_stage(img_size=128, patch_size=[3, 7, 15], stride=[2, 4, 8], num_classes=12,
                      embed_dims=[256, 256, 256], num_heads=[2, 2, 2], mlp_ratios=[2, 2, 2], qkv_bias=False,
                      qk_scale=None, drop_rate=0., attn_drop_rate=0., norm_layer=nn.Identity,
                      depths=[4, 4, 4], sr_ratios=[1, 1, 1], num_stages=3, linear=False).to(device)

# Load the pre-trained weights.
model_path = '../Model/C2FViT_affine_COM_template_matching_tigerdata_RAS/C2FViT_affine_COM_template_matching_tigerdata_RAS_stagelvl3_249000.pth'
print(f"Loading model weight {model_path} ...")
# map_location: lets a GPU-saved checkpoint load on a CPU-only machine (otherwise
# the CPU fallback above is useless). weights_only=True: a plain state_dict needs
# no arbitrary unpickling, and this silences torch's pickle FutureWarning.
state_dict = torch.load(model_path, map_location=device, weights_only=True)
model.load_state_dict(state_dict)
model.eval()

# Define the affine warp and the center-of-mass initializer.
affine_transform = AffineCOMTransform().to(device)
init_center = Center_of_mass_initial_pairwise()

# Wrap the core network and both transforms into one exportable module.
full_model = FullModel(model, affine_transform, init_center).to(device)
# Put every submodule (not only the core network) into inference mode before tracing.
full_model.eval()

# Load the fixed (template) image used as the example input for tracing.
# NOTE(review): the inference cell uses the *_RAS variant of this template and clips
# it at 2500 before normalization; confirm the tracing input need not match.
fixed_path = '../Data/MNI152_T1_1mm_brain_pad_RSP.nii.gz'
fixed_img_nii = nib.load(fixed_path)
fixed_img = fixed_img_nii.get_fdata()

# Make sure the volume is 256x256x256 (the export below uses static shapes).
target_shape = (256, 256, 256)
if fixed_img.shape != target_shape:
    fixed_img = pad_to_shape(fixed_img, target_shape)

fixed_img = min_max_norm(fixed_img)
# Add batch and channel axes -> (1, 1, D, H, W).
fixed_img = torch.from_numpy(fixed_img).float().to(device).unsqueeze(0).unsqueeze(0)

# Dummy moving image for ONNX tracing; only its shape/dtype define the graph inputs.
dummy_moving_img = torch.randn(1, 1, *target_shape).to(device)

# Export the full model. No dynamic_axes are declared, so the ONNX graph only
# accepts inputs of exactly (1, 1, 256, 256, 256). Tracing needs no autograd
# graph, and no_grad avoids keeping activations for 256^3 volumes.
onnx_path = "C2FViT_full_model.onnx"
with torch.no_grad():
    torch.onnx.export(full_model,
                      (dummy_moving_img, fixed_img),
                      onnx_path,
                      export_params=True,
                      opset_version=20,
                      do_constant_folding=True,
                      input_names=['moving_img', 'fixed_img'],
                      output_names=['moved', 'affine_matrix'])

print(f"Full model has been converted to {onnx_path}")


Loading model weight ../Model/C2FViT_affine_COM_template_matching_tigerdata_RAS/C2FViT_affine_COM_template_matching_tigerdata_RAS_stagelvl3_249000.pth ...


  model.load_state_dict(torch.load(model_path))


Full model has been converted to C2FViT_full_model.onnx


In [15]:
# Load the data and run prediction with the exported ONNX model
import onnxruntime as ort
import numpy as np
import nibabel as nib
from Functions import min_max_norm, pad_to_shape, reorient_image

    
# Configuration (same values as before, now named).
# TODO(review): absolute path; the export cell uses the relative '../Data' instead.
DATA_DIR = '/NFS/PeiMao/GitHub/C2FViT_Medical_Image/Data'
ONNX_PATH = "C2FViT_full_model.onnx"
# The exported graph has static inputs of shape (1, 1, 256, 256, 256).
TARGET_SHAPE = (256, 256, 256)
# Lower intensity bound applied to the fixed template before normalization.
# NOTE(review): presumably matches the training preprocessing; confirm.
FIXED_CLIP_MIN = 2500

moving_nii = nib.load(f'{DATA_DIR}/ABIDE_0050003_tbet.nii.gz')
fixed_nii = nib.load(f'{DATA_DIR}/MNI152_T1_1mm_brain_pad_RSP_RAS.nii.gz')
fixed_affine = fixed_nii.affine
fixed_header = fixed_nii.header

# Bring the moving image into RAS orientation to match the template.
moving_nii = reorient_image(moving_nii, ('R', 'A', 'S'))
moving_data = moving_nii.get_fdata().astype(np.float32)
fixed_data = fixed_nii.get_fdata().astype(np.float32)
moving_data = pad_to_shape(moving_data, TARGET_SHAPE)

# The result is saved with fixed_affine, so the template itself must already have
# the graph's static shape (padding it here would invalidate that affine).
if fixed_data.shape != TARGET_SHAPE:
    raise ValueError(
        f"Fixed image shape {fixed_data.shape} does not match the ONNX input shape {TARGET_SHAPE}"
    )

fixed_data = np.clip(fixed_data, a_min=FIXED_CLIP_MIN, a_max=np.max(fixed_data))

# Add batch and channel axes -> (1, 1, D, H, W).
moving = np.expand_dims(np.expand_dims(moving_data, axis=0), axis=1)
fixed = np.expand_dims(np.expand_dims(fixed_data, axis=0), axis=1)

# Normalize, then pin the dtype: the graph was traced with float32 inputs and
# ONNX Runtime rejects float64 tensors for float inputs.
moving = np.ascontiguousarray(min_max_norm(moving), dtype=np.float32)
fixed = np.ascontiguousarray(min_max_norm(fixed), dtype=np.float32)

# Create the ONNX Runtime session. Since ORT 1.9, builds that include non-CPU
# execution providers require `providers` to be passed explicitly.
available_providers = ort.get_available_providers()
providers = [p for p in ("CUDAExecutionProvider", "CPUExecutionProvider") if p in available_providers]
session = ort.InferenceSession(ONNX_PATH, providers=providers)

# Input/output names as recorded in the exported graph.
input_names = [inp.name for inp in session.get_inputs()]
output_names = [out.name for out in session.get_outputs()]

# Inputs are fed in export order: (moving_img, fixed_img).
inputs = {input_names[0]: moving, input_names[1]: fixed}

# Run inference; None requests every output.
outputs = session.run(None, inputs)

# Print the raw output shapes.
print(outputs[0].shape)
print(outputs[1].shape)

# Drop the batch (and channel) axes explicitly rather than squeezing every size-1 axis.
moved = outputs[0][0, 0]
affine_matrix = outputs[1][0]
print(affine_matrix)

# The moved image lives in the fixed image's space, so reuse its affine.
moved_nii = nib.Nifti1Image(moved, fixed_affine)

# Save the moved image as a .nii.gz file.
nib.save(moved_nii, 'onnx_output_image.nii.gz')

(1, 1, 256, 256, 256)
(1, 3, 4)
[[ 0.77673465  0.3878609  -0.04437534  0.01298795]
 [-0.36949837  0.8408904   0.01499392 -0.02714885]
 [ 0.05178259 -0.01343449  0.8940104   0.00496839]]
