In [1]:
#model轉onnx(C2FViT)
import torch
import torch.nn as nn
import torch.nn.functional as F
import nibabel as nib
import numpy as np
from C2FViT_model import C2F_ViT_stage, AffineCOMTransform, CustomAffineCOMTransform, Center_of_mass_initial_pairwise, CustomCenter_of_mass_initial_pairwise
from Functions import min_max_norm, pad_to_shape

class FullModel(nn.Module):
    def __init__(self, model, affine_transform, init_center):
        super(FullModel, self).__init__()
        self.model = model
        self.affine_transform = affine_transform
        self.init_center = init_center

    def forward(self, moving_img, fixed_img):
        # Center of mass initialization
        moving_img, init_flow = self.init_center(moving_img, fixed_img)
        
        # Downsample the images
        X_down = F.interpolate(moving_img, scale_factor=0.5, mode="trilinear", align_corners=True)
        Y_down = F.interpolate(fixed_img, scale_factor=0.5, mode="trilinear", align_corners=True)
        
        # Run the core model
        warpped_x_list, y_list, affine_para_list = self.model(X_down, Y_down)
        #rigid
        #affine_para_list[-1][0, 11] = 0
        #affine_para_list[-1][0, 10] = 0
        #affine_para_list[-1][0, 9] = 0
        #affine_para_list[-1][0, 8] = 0
        #affine_para_list[-1][0, 7] = 0
        #affine_para_list[-1][0, 6] = 0
        # Apply the affine transformation
        X_Y, affine_matrix = self.affine_transform(moving_img, affine_para_list[-1])
        
        return X_Y, affine_matrix, init_flow

# 設定裝置
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# 定義模型
model = C2F_ViT_stage(img_size=128, patch_size=[3, 7, 15], stride=[2, 4, 8], num_classes=12,
                      embed_dims=[256, 256, 256], num_heads=[2, 2, 2], mlp_ratios=[2, 2, 2], qkv_bias=False,
                      qk_scale=None, drop_rate=0., attn_drop_rate=0., norm_layer=nn.Identity,
                      depths=[4, 4, 4], sr_ratios=[1, 1, 1], num_stages=3, linear=False).to(device)

# 加載預訓練模型權重
model_path = '../Model/C2FViT_affine_COM_template_matching_tigerdata_RAS/C2FViT_affine_COM_template_matching_tigerdata_RAS_stagelvl3_249000.pth'
print(f"Loading model weight {model_path} ...")
model.load_state_dict(torch.load(model_path))
model.eval()

# 定義轉換器
affine_transform = CustomAffineCOMTransform().to(device)
init_center = CustomCenter_of_mass_initial_pairwise()

# 將核心模型和轉換器打包成完整模型
full_model = FullModel(model, affine_transform, init_center).to(device)

# 加載固定影像
fixed_path = '../Data/MNI152_T1_1mm_brain_pad_RSP.nii.gz'
fixed_img_nii = nib.load(fixed_path)
fixed_img = fixed_img_nii.get_fdata()

# 確保影像尺寸是 256x256x256
target_shape = (256, 256, 256)
if fixed_img.shape != target_shape:
    fixed_img = pad_to_shape(fixed_img, target_shape)

fixed_img = min_max_norm(fixed_img)
fixed_img = torch.from_numpy(fixed_img).float().to(device).unsqueeze(0).unsqueeze(0)

# Dummy moving image for ONNX conversion
dummy_moving_img = torch.randn(1, 1, 256, 256, 256).to(device)

# 將完整模型轉換成 ONNX 格式
onnx_path = "mprage_affine_v001_trainttttttttttt.onnx"
torch.onnx.export(full_model, 
                  (dummy_moving_img, fixed_img), 
                  onnx_path, 
                  export_params=True, 
                  opset_version=19, 
                  do_constant_folding=True, 
                  input_names=['moving_img', 'fixed_img'], 
                  output_names=['moved', 'affine_matrix', 'init_flow'])

print(f"Full model has been converted to {onnx_path}")


  warn(
  _torch_pytree._register_pytree_node(


Loading model weight ../Model/C2FViT_affine_COM_template_matching_tigerdata_RAS/C2FViT_affine_COM_template_matching_tigerdata_RAS_stagelvl3_249000.pth ...


  assert n == gn


UnsupportedOperatorError: Exporting the operator 'aten::affine_grid_generator' to ONNX opset version 19 is not supported. Please feel free to request support or submit a pull request on PyTorch GitHub: https://github.com/pytorch/pytorch/issues.