### My_UNETR++ (Synapse)

모델 구조  
(unetr_pp/training/network_training/unetr_pp_trainer_synapse.py)

In [1]:
import sys

# Put the local repo checkout on the import path before anything is imported from it.
# NOTE(review): presumably needed because modules inside the checkout import `unetr_pp.*`
# absolutely — confirm.
sys.path.append('my_unetr_plus_plus')

import torch
import torch.nn as nn
from pytorch_model_summary import summary
from my_unetr_plus_plus.unetr_pp.network_architecture.synapse.unetr_pp_synapse import UNETR_PP

# Use the GPU when torch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print(device)

cuda


In [2]:
# Experiment configuration shared by the model cells.
# input_channels: channels per input volume; num_classes: segmentation output channels
# (presumably 13 Synapse organs + background — confirm).
input_channels, num_classes = 1, 14
# Spatial crop size in [D, H, W] order (the model input is [B, C, D, H, W]).
crop_size = [64, 128, 128]

전체 모델 (SAB + CrossMFA 적용)

In [3]:
# Network: full UNETR++ variant (SAB + CrossMFA), built from the config cell.
network = UNETR_PP(in_channels=input_channels,
                   out_channels=num_classes,
                   img_size=crop_size,
                   feature_size=16,
                   num_heads=4,
                   depths=[3, 3, 3, 3],
                   dims=[32, 64, 128, 256],
                   do_ds=True)
network = network.to(device)

# Model summary and shape smoke test.
# Dummy volume of shape [B, C, D, H, W], derived from the config cell and created directly on
# `device` (a hard-coded `.cuda()` would crash on a CPU-only kernel even though `device` falls
# back to CPU).
# NOTE(review): `input`/`output` shadow builtins; names kept in case later cells use them.
input = torch.zeros(1, input_channels, *crop_size, device=device)
print(summary(network, input, show_input=True))  # per-layer input shapes and parameter counts

# Forward pass only to inspect output shapes: no_grad avoids building an autograd graph
# that the global `output` would otherwise keep alive.
with torch.no_grad():
    output = network(input)
print('output shape:')
# With do_ds=True the result is iterated as a sequence of tensors
# (presumably one per deep-supervision scale — confirm).
for i, out in enumerate(output):
    print(i, ':', out.shape)

---------------------------------------------------------------------------------------------------
       Layer (type)                                    Input Shape         Param #     Tr. Param #
   UnetrPPEncoder-1                           [1, 1, 64, 128, 128]      22,821,472      22,821,472
     UnetResBlock-2                           [1, 1, 64, 128, 128]           7,360           7,360
         CrossMFA-3           [1, 128, 8, 8, 8], [1, 256, 4, 4, 4]       1,166,800       1,166,800
         CrossMFA-4         [1, 64, 16, 16, 16], [1, 128, 8, 8, 8]       2,244,560       2,244,560
         CrossMFA-5       [1, 32, 32, 32, 32], [1, 64, 16, 16, 16]      13,953,296      13,953,296
     UnetrUpBlock-6           [1, 256, 4, 4, 4], [1, 128, 8, 8, 8]       1,431,256       1,431,256
     UnetrUpBlock-7         [1, 128, 8, 8, 8], [1, 64, 16, 16, 16]       1,891,288       1,891,288
     UnetrUpBlock-8       [1, 64, 16, 16, 16], [1, 32, 32, 32, 32]       9,534,040       9,534,040
     Unet

proj_size: [128, 96, 64, 32]

In [4]:
from my_unetr_plus_plus.unetr_pp.network_architecture.synapse.model_components import UnetrPPEncoder

# Network: the encoder alone, with the same dims/depths/heads passed to the full model.
unetr_pp_encoder = UnetrPPEncoder(dims=[32, 64, 128, 256], depths=[3, 3, 3, 3], num_heads=4)

# Test with a dummy CPU volume of shape [B, C, D, H, W].
# The literal shape is kept: the encoder is built only with dims/depths/num_heads, so any
# input-size/channel settings come from its own defaults, not from the config cell.
input = torch.zeros(1, 1, 64, 128, 128)

# Shape inspection only: no_grad avoids building an autograd graph over the full-resolution
# encoder that the globals `output`/`hidden_states` would otherwise keep alive.
with torch.no_grad():
    output, hidden_states = unetr_pp_encoder(input)
print(summary(unetr_pp_encoder, input))

print('output shape:', output.shape)  # final encoder output
print('hidden_states:')
# Per-stage feature maps returned alongside the final output.
for i, hd in enumerate(hidden_states):
    print(i, ':', hd.shape)

-------------------------------------------------------------------------------
          Layer (type)            Output Shape         Param #     Tr. Param #
              Conv3d-1     [1, 32, 32, 32, 32]           1,024           1,024
           GroupNorm-2     [1, 32, 32, 32, 32]              64              64
    TransformerBlock-3     [1, 32, 32, 32, 32]       5,269,768       5,269,768
    TransformerBlock-4     [1, 32, 32, 32, 32]       5,269,768       5,269,768
    TransformerBlock-5     [1, 32, 32, 32, 32]       5,269,768       5,269,768
     My_PatchMerging-6     [1, 64, 16, 16, 16]          16,512          16,512
    TransformerBlock-7     [1, 64, 16, 16, 16]         739,688         739,688
    TransformerBlock-8     [1, 64, 16, 16, 16]         739,688         739,688
    TransformerBlock-9     [1, 64, 16, 16, 16]         739,688         739,688
    My_PatchMerging-10       [1, 128, 8, 8, 8]          65,792          65,792
   TransformerBlock-11       [1, 128, 8, 8, 8]     

모델 구조

In [5]:
# Show the module tree. Evaluating `network.parameters` (without a call) only displays the
# bound method object; evaluating the module itself renders its structure directly.
network

<bound method Module.parameters of UNETR_PP(
  (unetr_pp_encoder): UnetrPPEncoder(
    (downsample_layers): ModuleList(
      (0): Sequential(
        (0): Convolution(
          (conv): Conv3d(1, 32, kernel_size=(2, 4, 4), stride=(2, 4, 4), bias=False)
        )
        (1): GroupNorm(1, 32, eps=1e-05, affine=True)
      )
      (1): Sequential(
        (0): My_PatchMerging(
          (reduction): Linear(in_features=256, out_features=64, bias=False)
          (gnorm): GroupNorm(32, 64, eps=1e-05, affine=True)
        )
      )
      (2): Sequential(
        (0): My_PatchMerging(
          (reduction): Linear(in_features=512, out_features=128, bias=False)
          (gnorm): GroupNorm(64, 128, eps=1e-05, affine=True)
        )
      )
      (3): Sequential(
        (0): My_PatchMerging(
          (reduction): Linear(in_features=1024, out_features=256, bias=False)
          (gnorm): GroupNorm(128, 256, eps=1e-05, affine=True)
        )
      )
    )
    (stages): ModuleList(
      (0): S