In [1]:
import torch
from torch import nn, einsum
from einops import rearrange
from pytorch_model_summary import summary
from monai.networks.layers.utils import get_norm_layer
from unetr_plus_plus.unetr_pp.network_architecture.dynunet_block import get_conv_layer
# Pick the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") 
print(device)

cuda


기본 세팅

In [2]:
# Global hyper-parameters for the experiments below.
# NOTE(review): the values look like UNETR++ encoder defaults — confirm against the upstream config.
# NOTE: `input_size` and `proj_size` are reassigned as plain ints by later demo cells.
input_size=[32 * 32 * 32, 16 * 16 * 16, 8 * 8 * 8, 4 * 4 * 4]  # presumably token count (D*H*W) per stage
dims=[32, 64, 128, 256]  # presumably channel width per stage
proj_size =[64,64,64,32]  # presumably projected key/value length per stage
depths=[3, 3, 3, 3]  # presumably number of blocks per stage
num_heads=4
spatial_dims=3
in_channels=1
dropout=0.0
transformer_dropout_rate=0.15

1-1. Downsampling (기존 버전: Conv)

In [3]:
class downsampling(nn.Module):
    """Strided-convolution down-sampling followed by group normalisation.

    Args:
        spatial_dims: number of spatial dimensions handed to ``get_conv_layer``.
        in_channels: channels of the incoming feature map; also used as the
            number of GroupNorm groups, so ``out_channels`` must be divisible
            by it.
        out_channels: channels produced by the convolution.
        kernel_size: convolution kernel size.
        stride: convolution stride.
        dropout: dropout setting forwarded to ``get_conv_layer``.
        conv_only: forwarded to ``get_conv_layer``.
    """
    def __init__(self,spatial_dims,in_channels,out_channels,kernel_size,stride,dropout,conv_only=True):
        super().__init__()
        # Pass `dropout` / `conv_only` by keyword. Passed positionally they land in
        # the two slots that follow `stride`, which in the upstream
        # `get_conv_layer` signature are `act` and `norm`
        # (NOTE(review): verify against the installed dynunet_block).
        self.downsample_layer=get_conv_layer(
            spatial_dims,
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            dropout=dropout,
            conv_only=conv_only,
        )
        # Attribute name kept as-is (it is part of the state_dict keys).
        self.get_norm_layer=get_norm_layer(name=("group", {"num_groups": in_channels}), channels=out_channels)
    
    def forward(self,x):
        """Apply the convolution and then the group normalisation to ``x``."""
        x=self.downsample_layer(x)
        x=self.get_norm_layer(x)
        return x

In [4]:
# Smoke test for the convolutional down-sampling block.
# Use `device` (not a hard-coded .cuda()) so the cell also runs on CPU-only machines.
x=torch.zeros(1,64,16,16,16).to(device) # [B,C,D,H,W] input: 16 x 16 x 16 x 64
model=downsampling(
    spatial_dims=3, in_channels=x.shape[1], out_channels=x.shape[1]*2,kernel_size=[2,2,2],stride=[2,2,2],dropout=0.0
).to(device)

print(summary(model,x))
print('input:',x.shape)
print('output:',model(x).shape)

-------------------------------------------------------------------------
      Layer (type)          Output Shape         Param #     Tr. Param #
          Conv3d-1     [1, 128, 8, 8, 8]          65,536          65,536
       GroupNorm-2     [1, 128, 8, 8, 8]             256             256
Total params: 65,792
Trainable params: 65,792
Non-trainable params: 0
-------------------------------------------------------------------------
input: torch.Size([1, 64, 16, 16, 16])
output: torch.Size([1, 128, 8, 8, 8])


1-2. Downsampling (새로운 버전: Patch Merging)

In [5]:
class PatchMerging(nn.Module):
    """3-D patch merging: fold each 2x2x2 neighbourhood into the channel axis.

    The eight voxels of every 2x2x2 neighbourhood are concatenated along the
    channel axis (8*C), layer-normalised and linearly reduced to 2*C.

    Args:
        dim: number of input channels C.
        norm_layer: normalisation class applied to the 8*C merged features.
    """
    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.norm = norm_layer(8 * dim)
        self.reduction = nn.Linear(8 * dim, 2 * dim, bias=False)

    def forward(self, x):
        """
        x: B,C,D,H,W  ->  returns B, 2*C, D/2, H/2, W/2

        Raises:
            ValueError: if D, H or W is odd.
        """
        x = x.permute(0, 3, 4, 2, 1)  # [B,H,W,D,C]
        B, H, W, D, C = x.shape
        if H % 2 or W % 2 or D % 2:
            raise ValueError(
                f"PatchMerging needs even D, H and W, got D={D}, H={H}, W={W}"
            )

        # Split every spatial axis into (coarse index, parity):
        # [B, H/2, ph, W/2, pw, D/2, pd, C]
        y = x.reshape(B, H // 2, 2, W // 2, 2, D // 2, 2, C)

        # Move the parities next to the channels in the order (pd, ph, pw, c).
        # This is the channel layout produced by concatenating the
        # (even,even), (even,odd), (odd,even), (odd,odd) H/W sub-grids and then
        # stacking the two depth slices, but done in one shot instead of
        # growing the result with torch.cat once per depth pair.
        y = y.permute(0, 1, 3, 5, 6, 2, 4, 7).reshape(B, H // 2, W // 2, D // 2, 8 * C)

        # normalization
        y = self.norm(y)  # B, H/2, W/2, D/2, 8*C

        # embedding
        y = self.reduction(y)  # B, H/2, W/2, D/2, 2*C

        return y.permute(0, 4, 3, 1, 2)  # B, 2*C, D/2, H/2, W/2

In [6]:
# Smoke test: print the layer summary and the input/output shapes of PatchMerging.
x=torch.zeros(1,64,16,16,16) # [B,C,D,H,W] input: 16 x 16 x 16 x 64
model=PatchMerging(dim=x.shape[1])

print(summary(model,x))
print('input:',x.shape)
print('output:',model(x).shape)

-------------------------------------------------------------------------
      Layer (type)          Output Shape         Param #     Tr. Param #
       LayerNorm-1     [1, 8, 8, 8, 512]           1,024           1,024
          Linear-2     [1, 8, 8, 8, 128]          65,536          65,536
Total params: 66,560
Trainable params: 66,560
Non-trainable params: 0
-------------------------------------------------------------------------
input: torch.Size([1, 64, 16, 16, 16])
output: torch.Size([1, 128, 8, 8, 8])


2-1. Upsampling (기존 버전: TransposedConv)

In [10]:
class upsampling(nn.Module):
    """Transposed-convolution up-sampling built with ``get_conv_layer``.

    Args:
        spatial_dims: number of spatial dimensions handed to ``get_conv_layer``.
        in_channels: channels of the incoming feature map.
        out_channels: channels produced by the transposed convolution.
        upsample_kernel_size: kernel size of the transposed convolution.
        upsample_stride: stride of the transposed convolution.
        dropout: dropout setting forwarded to ``get_conv_layer``.
        conv_only: forwarded to ``get_conv_layer`` (default ``True``).
    """
    def __init__(self,spatial_dims,in_channels,out_channels,upsample_kernel_size,upsample_stride,dropout,conv_only=True):
        super().__init__()
        # `dropout` and `conv_only` used to be accepted but ignored (conv_only was
        # hard-coded to True); they are now forwarded. With the default
        # conv_only=True the constructed layer is the same as before.
        self.transp_conv = get_conv_layer(
            spatial_dims,
            in_channels,
            out_channels,
            kernel_size=upsample_kernel_size,
            stride=upsample_stride,
            dropout=dropout,
            conv_only=conv_only,
            is_transposed=True,
        )
    
    def forward(self,y):
        """Apply the transposed convolution to ``y``."""
        return self.transp_conv(y)

In [11]:
# Smoke test: print the layer summary and the input/output shapes of the
# transposed-convolution up-sampling block.
y=torch.zeros(1,128,8,8,8) # [B,C,D,H,W] input: 8 x 8 x 8 x 128
model=upsampling(
    spatial_dims=3, in_channels=y.shape[1], out_channels=y.shape[1]//2,upsample_kernel_size=[2,2,2],upsample_stride=[2,2,2],dropout=0.0
)

print(summary(model,y))
print('input:',y.shape)
print('output:',model(y).shape)

-----------------------------------------------------------------------------
        Layer (type)            Output Shape         Param #     Tr. Param #
   ConvTranspose3d-1     [1, 64, 16, 16, 16]          65,536          65,536
Total params: 65,536
Trainable params: 65,536
Non-trainable params: 0
-----------------------------------------------------------------------------
input: torch.Size([1, 128, 8, 8, 8])
output: torch.Size([1, 64, 16, 16, 16])


2-2.Upsampling (새로운 버전: Patch Expanding)

In [12]:
class PatchExpanding(nn.Module):
    """3-D patch expanding: unfold channels into a 2x2x2 spatial neighbourhood.

    The C input channels are linearly expanded to 4*C, which are then
    redistributed as 8 voxels (2 along each of D, H, W) of C//2 channels each,
    followed by a normalisation over the C//2 channels.

    Args:
        dim: number of input channels C (must be even).
        norm_layer: normalisation class applied to the C//2 output channels.

    Raises:
        ValueError: if ``dim`` is odd (4*C could not be split into 8 groups of C//2).
    """
    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        if dim % 2:
            raise ValueError(f"PatchExpanding needs an even dim, got {dim}")
        self.dim = dim
        self.norm = norm_layer(dim//2)
        self.expand = nn.Linear(dim, 4 * dim, bias=False)

    def forward(self, y):
        """
        y: B,C,D,H,W  ->  returns B, C//2, 2*D, 2*H, 2*W
        """
        y = y.permute(0, 3, 4, 2, 1)  # [B,H,W,D,C]
        B, H, W, D, C = y.shape

        # channel expand
        y = self.expand(y)  # B, H, W, D, 4*C

        # Interpret the 4*C channels as (depth offset, height offset, width offset, c)
        # with c = C//2: [B, H, W, D, dd, p1, p2, C//2]
        x = y.reshape(B, H, W, D, 2, 2, 2, C // 2)

        # Interleave every offset with its spatial axis. This is the layout the
        # per-depth-slice view + rearrange + torch.cat procedure produces, but
        # without re-copying the growing result once per slice.
        x = x.permute(0, 1, 5, 2, 6, 3, 4, 7).reshape(B, 2 * H, 2 * W, 2 * D, C // 2)

        # normalization
        x = self.norm(x)  # B, 2*H, 2*W, 2*D, C//2

        return x.permute(0, 4, 3, 1, 2)  # B, C//2, 2*D, 2*H, 2*W

In [14]:
# Smoke test: print the layer summary and the input/output shapes of PatchExpanding.
y=torch.zeros(1,128,8,8,8) # [B,C,D,H,W] input: 8 x 8 x 8 x 128
model=PatchExpanding(dim=y.shape[1])

print(summary(model,y))
print('input:',y.shape)
print('output:',model(y).shape)

---------------------------------------------------------------------------
      Layer (type)            Output Shape         Param #     Tr. Param #
          Linear-1       [1, 8, 8, 8, 512]          65,536          65,536
       LayerNorm-2     [1, 16, 16, 16, 64]             128             128
Total params: 65,664
Trainable params: 65,664
Non-trainable params: 0
---------------------------------------------------------------------------
input: torch.Size([1, 128, 8, 8, 8])
output: torch.Size([1, 64, 16, 16, 16])


[Check] x-> (Patch Merging) -> y -> (Patch Expanding) -> z

In [16]:
# Round-trip shape check: PatchMerging followed by PatchExpanding should give
# back the input shape at every stage.
# stage1 input: x = 128 x 128 x 64 x 1 (H x W x D x 1)
# stage2 input: x = 32 x 32 x 32 x 32 (H/4 x W/4 x D/2 x C)
# stage3 input: x = 16 x 16 x 16 x 64 (H/8 x W/8 x D/4 x 2C)
# stage4 input: x = 8 x 8 x 8 x 128 (H/16 x W/16 x D/8 x 4C)

# Named `volume` rather than `input` so the builtin input() is not shadowed
# for the rest of the kernel session.
volume=torch.zeros(1,1,64,128,128) # [B, C, D, H, W]
# Patch embedding is skipped -> start from stage 2
C=32
for i in range(0,3):
    print(f'====stage{i+2}====')
    x=torch.zeros(volume.shape[0],C*2**i,volume.shape[2]//(2*2**i),volume.shape[3]//(4*2**i),volume.shape[4]//(4*2**i))
    print('x.shape:',x.shape)

    down=PatchMerging(dim=x.shape[1])
    y=down(x)
    print('y.shape:',y.shape,'<-- downsampling(patch merging)')

    up=PatchExpanding(dim=y.shape[1])
    z=up(y)
    print('z.shape:',z.shape,'<-- upsampling(patch expanding)')

====stage2====
x.shape: torch.Size([1, 32, 32, 32, 32])
y.shape: torch.Size([1, 64, 16, 16, 16]) <-- downsampling(patch merging)
z.shape: torch.Size([1, 32, 32, 32, 32]) <-- upsampling(patch expanding)
====stage3====
x.shape: torch.Size([1, 64, 16, 16, 16])
y.shape: torch.Size([1, 128, 8, 8, 8]) <-- downsampling(patch merging)
z.shape: torch.Size([1, 64, 16, 16, 16]) <-- upsampling(patch expanding)
====stage4====
x.shape: torch.Size([1, 128, 8, 8, 8])
y.shape: torch.Size([1, 256, 4, 4, 4]) <-- downsampling(patch merging)
z.shape: torch.Size([1, 128, 8, 8, 8]) <-- upsampling(patch expanding)


3-1. EPA 모듈 (기존 버전: 병렬)

In [17]:
class EPA(nn.Module):
    """
        Efficient Paired Attention Block, based on: "Shaker et al.,
        UNETR++: Delving into Efficient and Accurate 3D Medical Image Segmentation"
        """
    def __init__(self, input_size, hidden_size, proj_size, num_heads=4, qkv_bias=False,
                 channel_attn_drop=0.1, spatial_attn_drop=0.1):
        super().__init__()
        self.num_heads = num_heads
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
        self.temperature2 = nn.Parameter(torch.ones(num_heads, 1, 1))

        # qkvv are 4 linear layers (query_shared, key_shared, value_spatial, value_channel)
        self.qkvv = nn.Linear(hidden_size, hidden_size * 4, bias=qkv_bias)

        # E and F are projection matrices with shared weights used in spatial attention module to project
        # keys and values from HWD-dimension to P-dimension
        self.E = self.F = nn.Linear(input_size, proj_size)

        self.attn_drop = nn.Dropout(channel_attn_drop)
        self.attn_drop_2 = nn.Dropout(spatial_attn_drop)

        self.out_proj = nn.Linear(hidden_size, int(hidden_size // 2))
        self.out_proj2 = nn.Linear(hidden_size, int(hidden_size // 2))

    def forward(self, x):
        B, N, C = x.shape

        qkvv = self.qkvv(x).reshape(B, N, 4, self.num_heads, C // self.num_heads)

        qkvv = qkvv.permute(2, 0, 3, 1, 4)

        q_shared, k_shared, v_CA, v_SA = qkvv[0], qkvv[1], qkvv[2], qkvv[3]

        q_shared = q_shared.transpose(-2, -1)
        k_shared = k_shared.transpose(-2, -1)
        v_CA = v_CA.transpose(-2, -1)
        v_SA = v_SA.transpose(-2, -1)

        k_shared_projected = self.E(k_shared)

        v_SA_projected = self.F(v_SA)

        q_shared = torch.nn.functional.normalize(q_shared, dim=-1)
        k_shared = torch.nn.functional.normalize(k_shared, dim=-1)

        attn_CA = (q_shared @ k_shared.transpose(-2, -1)) * self.temperature

        attn_CA = attn_CA.softmax(dim=-1)
        attn_CA = self.attn_drop(attn_CA)

        x_CA = (attn_CA @ v_CA).permute(0, 3, 1, 2).reshape(B, N, C)

        attn_SA = (q_shared.permute(0, 1, 3, 2) @ k_shared_projected) * self.temperature2

        attn_SA = attn_SA.softmax(dim=-1)
        attn_SA = self.attn_drop_2(attn_SA)

        x_SA = (attn_SA @ v_SA_projected.transpose(-2, -1)).permute(0, 3, 1, 2).reshape(B, N, C)

        # Concat fusion
        x_SA = self.out_proj(x_SA)
        x_CA = self.out_proj2(x_CA)
        x = torch.cat((x_SA, x_CA), dim=-1)
        return x

In [71]:
# Smoke test for EPA: flatten a 5-D volume into tokens and run the block.
x=torch.zeros(1,32,32,32,32) # B, C, H, W, D
B, C, H, W, D = x.shape
x = x.reshape(B, C, H * W * D).permute(0, 2, 1) # B, H*W*D, C

# NOTE: these assignments replace the `input_size` / `proj_size` lists from the settings cell.
input_size=x.shape[1]
hidden_size=32
proj_size=64

model=EPA(input_size, hidden_size, proj_size, num_heads=4)
print(summary(model,x))
print('input:',x.shape)
print('output:',model(x).shape)

-------------------------------------------------------------------------
      Layer (type)          Output Shape         Param #     Tr. Param #
          Linear-1       [1, 32768, 128]           4,096           4,096
          Linear-2         [1, 4, 8, 64]       2,097,216       2,097,216
         Dropout-3          [1, 4, 8, 8]               0               0
         Dropout-4     [1, 4, 32768, 64]               0               0
          Linear-5        [1, 32768, 16]             528             528
          Linear-6        [1, 32768, 16]             528             528
Total params: 2,102,368
Trainable params: 2,102,368
Non-trainable params: 0
-------------------------------------------------------------------------
input: torch.Size([1, 32768, 32])
output: torch.Size([1, 32768, 32])


3-2. EPA 모듈 (새로운 버전: 직렬)

In [29]:
class My_EPA(nn.Module):
    """Serial variant of Efficient Paired Attention.

    Channel attention is applied first; its output is fed to a second q/k/v
    projection that drives the spatial attention. The spatial branch output is
    returned directly (there is no output projection).

    Args:
        input_size: number of tokens N (in-features of the shared token projection).
        hidden_size: channel width C of the tokens.
        proj_size: length p the keys/values are projected to in the spatial branch.
        num_heads: number of attention heads h.
        qkv_bias: whether the q/k/v linear layers have a bias.
        channel_attn_drop: dropout rate on the channel attention map.
        spatial_attn_drop: dropout rate on the spatial attention map.
    """
    def __init__(self, input_size, hidden_size, proj_size, num_heads=4, qkv_bias=False,
                 channel_attn_drop=0.1, spatial_attn_drop=0.1):
        super().__init__()
        self.num_heads = num_heads
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1)) # for channel attention
        self.temperature2 = nn.Parameter(torch.ones(num_heads, 1, 1)) # for spatial attention

        # qkv are 3 linear layers (query, key, value), one fused layer per branch
        self.qkv = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias)
        self.qkv2 = nn.Linear(hidden_size, hidden_size * 3, bias=qkv_bias)

        # proj_k and proj_v are the same module (shared weights); they map the
        # token axis of keys / values from N down to proj_size.
        self.proj_k = self.proj_v = nn.Linear(input_size, proj_size)

        self.attn_drop = nn.Dropout(channel_attn_drop) 
        self.attn_drop_2 = nn.Dropout(spatial_attn_drop)
    
    def forward(self, x):
        """x: (B, N, C) tokens with N = H*W*D -> (B, N, C)."""
        B, N, C = x.shape # N=HWD
        head_dim = C // self.num_heads
        normalize = torch.nn.functional.normalize

        # ---- Channel attention: [ Q_T x K ] ----
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, head_dim) # B x N x 3 x h x C/h
        qkv = qkv.permute(2, 0, 3, 1, 4) # 3 x B x h x N x C/h
        q, k, v = qkv[0], qkv[1], qkv[2] # each B x h x N x C/h

        # Unit-normalise along the token axis N.
        q_t = normalize(q.transpose(-2, -1), dim=-1) # B x h x C/h x N
        k_t = normalize(k.transpose(-2, -1), dim=-1) # B x h x C/h x N

        # (C/h x N) @ (N x C/h) -> channel attention map B x h x C/h x C/h
        attn_CA = (q_t @ k_t.transpose(-2, -1)) * self.temperature
        attn_CA = attn_CA.softmax(dim=-1)
        attn_CA = self.attn_drop(attn_CA)

        # [V x Channel Attn Map] -> B x h x N x C/h.
        # NOTE(review): permute(0, 3, 1, 2) gives B x C/h x h x N, so the reshape
        # to (B, N, C) regroups elements instead of restoring a token-major
        # layout. The spatial branch of `EPA` does the same, so it is kept
        # unchanged here - confirm it is intended.
        x_CA = (v @ attn_CA).permute(0, 3, 1, 2).reshape(B, N, C)

        # ---- Spatial attention: K -> K(p), V -> V(p), [ Q x K_T(p) ] ----
        qkv2 = self.qkv2(x_CA).reshape(B, N, 3, self.num_heads, head_dim) # B x N x 3 x h x C/h
        qkv2 = qkv2.permute(2, 0, 3, 1, 4) # 3 x B x h x N x C/h
        q2, k2, v2 = qkv2[0], qkv2[1], qkv2[2] # each B x h x N x C/h

        # Keys / values are projected un-normalised, so there is nothing to
        # normalise on the key side of this branch.
        k2_t_projected = self.proj_k(k2.transpose(-2, -1)) # B x h x C/h x p
        v2_t_projected = self.proj_v(v2.transpose(-2, -1)) # B x h x C/h x p

        # Query is unit-normalised along N, then used token-major.
        q2 = normalize(q2.transpose(-2, -1), dim=-1).transpose(-2, -1) # B x h x N x C/h

        attn_SA = (q2 @ k2_t_projected) * self.temperature2 # B x h x N x p
        attn_SA = attn_SA.softmax(dim=-1)
        attn_SA = self.attn_drop_2(attn_SA)

        # [Spatial Attn Map x V(p)] -> B x h x N x C/h, then the same
        # permute/reshape as above (see NOTE(review)).
        x_SA = (attn_SA @ v2_t_projected.transpose(-2, -1)).permute(0, 3, 1, 2).reshape(B, N, C)

        return x_SA

In [30]:
# Smoke test for My_EPA: flatten a 5-D volume into tokens and run the block.
x=torch.zeros(1,32,32,32,32) # B, C, H, W, D
B, C, H, W, D = x.shape
x = x.reshape(B, C, H * W * D).permute(0, 2, 1) # B, H*W*D, C

# NOTE: these assignments replace the `input_size` / `proj_size` lists from the settings cell.
input_size=x.shape[1]
hidden_size=32
proj_size=64

model=My_EPA(input_size, hidden_size, proj_size, num_heads=4)
print(summary(model,x))
print('input:',x.shape)
print('output:',model(x).shape)

-------------------------------------------------------------------------
      Layer (type)          Output Shape         Param #     Tr. Param #
          Linear-1        [1, 32768, 96]           3,072           3,072
         Dropout-2          [1, 4, 8, 8]               0               0
          Linear-3        [1, 32768, 96]           3,072           3,072
          Linear-4         [1, 4, 8, 64]       2,097,216       2,097,216
         Dropout-5     [1, 4, 32768, 64]               0               0
Total params: 2,103,360
Trainable params: 2,103,360
Non-trainable params: 0
-------------------------------------------------------------------------
input: torch.Size([1, 32768, 32])
output: torch.Size([1, 32768, 32])


4-1. TIF 모듈 (기존 버전: 2D)

In [2]:
class Conv_block(nn.Module):
    """Two stacked stages of 3x3 Conv2d -> GroupNorm -> ReLU.

    Args:
        in_ch: input channels of the first convolution.
        out_ch: output channels of both convolutions.
        groups: number of GroupNorm groups (``out_ch`` must be divisible by it).
    """
    def __init__(self, in_ch, out_ch, groups):
        super().__init__()
        stages = []
        # First stage maps in_ch -> out_ch, second stage keeps out_ch.
        for stage_in in (in_ch, out_ch):
            stages += [
                nn.Conv2d(stage_in, out_ch, kernel_size=3, stride=1, padding=1, bias=True),
                nn.GroupNorm(num_channels=out_ch, num_groups=groups),
                nn.ReLU(inplace=True),
            ]
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run both conv stages on ``x`` (B, in_ch, H, W) -> (B, out_ch, H, W)."""
        return self.conv(x)

class PreNorm(nn.Module):
    """Wrap a callable so that its input is LayerNorm-ed first.

    Args:
        dim: size of the last dimension that is normalised.
        fn: callable applied to the normalised input; extra keyword arguments
            of ``forward`` are passed through to it.
    """
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn

    def forward(self, x, **kwargs):
        """Return ``fn(LayerNorm(x), **kwargs)``."""
        normalised = self.norm(x)
        return self.fn(normalised, **kwargs)

class FeedForward(nn.Module):
    """Token-wise MLP: Linear -> GELU -> Dropout -> Linear -> Dropout.

    Args:
        dim: input and output feature size.
        hidden_dim: feature size between the two linear layers.
        dropout: dropout probability used after the activation and after the
            second linear layer.
    """
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        mlp_layers = [
            nn.Linear(dim, hidden_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*mlp_layers)

    def forward(self, x):
        """Apply the MLP to the last dimension of ``x``."""
        out = self.net(x)
        return out

class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention over (B, N, dim) tokens.

    Args:
        dim: token feature size.
        heads: number of attention heads.
        dim_head: feature size per head.
        dropout: dropout probability applied after the output projection.
    """
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head * heads

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        # A single head whose width already equals `dim` needs no output projection.
        if heads == 1 and dim_head == dim:
            self.to_out = nn.Identity()
        else:
            self.to_out = nn.Sequential(
                nn.Linear(inner_dim, dim),
                nn.Dropout(dropout)
            )

    def forward(self, x):
        """x: (B, N, dim) -> (B, N, dim)."""
        # Unpacking also rejects inputs that are not 3-D.
        _batch, _tokens, _width, heads = *x.shape, self.heads

        q, k, v = [
            rearrange(part, 'b n (h d) -> b h n d', h = heads)
            for part in self.to_qkv(x).chunk(3, dim = -1)
        ]

        scores = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        weights = scores.softmax(dim=-1)

        mixed = einsum('b h i j, b h j d -> b h i d', weights, v)
        merged = rearrange(mixed, 'b h n d -> b n (h d)')
        return self.to_out(merged)

class Transformer(nn.Module):
    """Stack of pre-norm (attention, feed-forward) layers with residual connections.

    Args:
        dim: token feature size.
        depth: number of (attention, feed-forward) pairs.
        heads: attention heads per layer.
        dim_head: feature size per head.
        mlp_dim: hidden size of the feed-forward block.
        dropout: dropout probability passed to both sub-blocks.
    """
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        blocks = [
            nn.ModuleList([
                PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)),
            ])
            for _ in range(depth)
        ]
        self.layers = nn.ModuleList(blocks)

    def forward(self, x):
        """Apply every layer in turn; each sub-block output is added to its input."""
        for attention, feed_forward in self.layers:
            x = attention(x) + x
            x = feed_forward(x) + x
        return x

class Cross_Att(nn.Module):
    """Exchange information between a local and a global 2-D feature map.

    Each map is flattened to tokens. The other map's layer-normed, average-pooled
    and linearly projected summary is prepended as one extra token, a one-layer
    transformer is run, and the extra token is dropped again.

    Args:
        dim_e: channels of the local feature map ``e``.
        dim_r: channels of the global feature map ``r``.
    """
    def __init__(self, dim_e, dim_r):
        super().__init__()
        # heads / dim_head chosen to match UNETR++ (local branch)
        self.transformer_e = Transformer(dim=dim_e, depth=1, heads=4, dim_head=dim_e//4, mlp_dim=128)
        # heads / dim_head chosen to match UNETR++ (global branch)
        self.transformer_r = Transformer(dim=dim_r, depth=1, heads=4, dim_head=dim_r//4, mlp_dim=256)
        self.norm_e = nn.LayerNorm(dim_e)
        self.norm_r = nn.LayerNorm(dim_r)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.linear_e = nn.Linear(dim_e, dim_r)
        self.linear_r = nn.Linear(dim_r, dim_e)

    def forward(self, e, r):
        """e: (B, dim_e, H, W), r: (B, dim_r, H', W') -> tensors of the same two shapes."""
        b_e, c_e, h_e, w_e = e.shape
        e_tokens = e.reshape(b_e, c_e, -1).transpose(1, 2)  # B, N_e, dim_e
        b_r, c_r, h_r, w_r = r.shape
        r_tokens = r.reshape(b_r, c_r, -1).transpose(1, 2)  # B, N_r, dim_r

        # One summary vector per map: LayerNorm, then mean over the token axis.
        e_summary = self.avgpool(self.norm_e(e_tokens).transpose(1, 2)).flatten(1)
        r_summary = self.avgpool(self.norm_r(r_tokens).transpose(1, 2)).flatten(1)

        # Project each summary to the other branch's width and make it a single token.
        e_token = self.linear_e(e_summary).unsqueeze(1)  # B, 1, dim_r
        r_token = self.linear_r(r_summary).unsqueeze(1)  # B, 1, dim_e

        # Prepend the foreign token, attend, then drop it (index 0).
        r_tokens = self.transformer_r(torch.cat([e_token, r_tokens], dim=1))[:, 1:, :]
        e_tokens = self.transformer_e(torch.cat([r_token, e_tokens], dim=1))[:, 1:, :]

        e_out = e_tokens.transpose(1, 2).reshape(b_e, c_e, h_e, w_e)
        r_out = r_tokens.transpose(1, 2).reshape(b_r, c_r, h_r, w_r)
        return e_out, r_out

In [23]:
class TIF(nn.Module):
    """2-D fusion of a local and a global feature map.

    Cross-attention between the two maps, x2 up-sampling of the global map,
    channel concatenation, then a Conv_block back to ``dim_e`` channels.

    Args:
        dim_e: channels of the local feature map (32, 64, 128, 256).
        dim_r: channels of the global feature map (64, 128, 256, 512).
    """
    def __init__(self,dim_e,dim_r):
        super().__init__()
        self.dim_e, self.dim_r = dim_e, dim_r
        self.cross_attn = Cross_Att(dim_e, dim_r)
        self.up = nn.Upsample(scale_factor=2)
        self.conv = Conv_block(in_ch=dim_e + dim_r, out_ch=dim_e, groups=32)

    def forward(self,e,r):
        '''
        e: local feature (H x W x C)
        r: global feature (H/2 x W/2 x 2C)
        '''
        local_feat, global_feat = self.cross_attn(e, r)  # [B,C,H,W], [B,2C,H/2,W/2]
        # Concatenate along channels: dim_e + dim_r channels at resolution H x W.
        fused = torch.cat([local_feat, self.up(global_feat)], 1)
        return self.conv(fused)  # B,C,H,W

In [25]:
# Smoke test for the 2-D TIF module: print the layer summary and the shapes.
e=torch.zeros(1,32,32,32) # [B,C,H,W] Local 
r=torch.zeros(1,64,16,16) # [B,C,H,W] Global

model=TIF(e.shape[1],r.shape[1]) # dim_e=32, dim_r=64

print(summary(model,e,r))
print('input(e,r):',e.shape, r.shape)
print('output:',model(e,r).shape)

----------------------------------------------------------------------------------------
      Layer (type)                         Output Shape         Param #     Tr. Param #
       Cross_Att-1     [1, 32, 32, 32], [1, 64, 16, 16]          66,784          66,784
        Upsample-2                      [1, 64, 32, 32]               0               0
      Conv_block-3                      [1, 32, 32, 32]          37,056          37,056
Total params: 103,840
Trainable params: 103,840
Non-trainable params: 0
----------------------------------------------------------------------------------------
input(e,r): torch.Size([1, 32, 32, 32]) torch.Size([1, 64, 16, 16])
output: torch.Size([1, 32, 32, 32])


4-2. TIF 모듈 (새로운 버전: 3D)

In [3]:
class My_Conv_block(nn.Module):
    """3-D counterpart of Conv_block: two stages of 3x3x3 Conv3d -> GroupNorm -> ReLU.

    Args:
        in_ch: input channels of the first convolution.
        out_ch: output channels of both convolutions.
        groups: number of GroupNorm groups (``out_ch`` must be divisible by it).
    """
    def __init__(self, in_ch, out_ch, groups):
        super().__init__()
        stages = []
        # First stage maps in_ch -> out_ch, second stage keeps out_ch.
        for stage_in in (in_ch, out_ch):
            stages += [
                nn.Conv3d(stage_in, out_ch, kernel_size=3, stride=1, padding=1, bias=True),  # modified: Conv2d -> Conv3d
                nn.GroupNorm(num_channels=out_ch, num_groups=groups),
                nn.ReLU(inplace=True),
            ]
        self.conv = nn.Sequential(*stages)

    def forward(self, x):
        """Run both conv stages on ``x`` (B, in_ch, D, H, W) -> (B, out_ch, D, H, W)."""
        return self.conv(x)

class My_Attention(nn.Module):
    def __init__(self, input_size, proj_size, dim, heads, dim_head, dropout = 0.): # 수정
        super().__init__()
        inner_dim = dim_head *  heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False) 

        self.E = self.F = nn.Linear(input_size+1, proj_size) # 수정

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        b, n, _, h = *x.shape, self.heads
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
        '''
        k, v projection (차원 축소하지 않으면 연산량이 너무 커짐)
        '''
        k=self.E(k.transpose(-2, -1)).transpose(-2,-1) # 수정
        v=self.F(v.transpose(-2, -1)).transpose(-2,-1) # 수정

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        attn = dots.softmax(dim=-1)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        out =  self.to_out(out)
        return out

class My_Transformer(nn.Module):
    """Stack of pre-norm (My_Attention, FeedForward) layers with residual connections.

    Args:
        input_size: token count forwarded to ``My_Attention``.
        proj_size: projected key/value token count forwarded to ``My_Attention``.
        dim: token feature size.
        depth: number of (attention, feed-forward) pairs.
        heads: attention heads per layer.
        dim_head: feature size per head.
        mlp_dim: hidden size of the feed-forward block.
        dropout: dropout probability passed to both sub-blocks.
    """
    def __init__(self, input_size, proj_size, dim, depth, heads, dim_head, mlp_dim, dropout = 0.): # modified
        super().__init__()
        blocks = [
            nn.ModuleList([
                PreNorm(dim, My_Attention(input_size, proj_size, dim, heads = heads, dim_head = dim_head, dropout = dropout)),
                PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)),
            ])
            for _ in range(depth)
        ]
        self.layers = nn.ModuleList(blocks)

    def forward(self, x):
        """Apply every layer in turn; each sub-block output is added to its input."""
        for attention, feed_forward in self.layers:
            x = attention(x) + x
            x = feed_forward(x) + x
        return x

class My_Cross_Att(nn.Module):
    """3-D version of Cross_Att using token-projected attention.

    Each volume is flattened to tokens. The other volume's layer-normed,
    average-pooled and linearly projected summary is prepended as one extra
    token, a one-layer ``My_Transformer`` is run, and the extra token is dropped.

    Args:
        HWD_e: number of voxels (tokens) in the local feature map.
        HWD_r: number of voxels (tokens) in the global feature map.
        proj_size: projected key/value token count inside the attention.
        dim_e: channels of the local feature map.
        dim_r: channels of the global feature map.
    """
    def __init__(self, HWD_e, HWD_r, proj_size, dim_e, dim_r): # modified
        super().__init__()
        # heads / dim_head chosen to match UNETR++ (local branch)
        self.transformer_e = My_Transformer(HWD_e, proj_size, dim=dim_e, depth=1, heads=4, dim_head=dim_e//4, mlp_dim=128)
        # heads / dim_head chosen to match UNETR++ (global branch)
        self.transformer_r = My_Transformer(HWD_r, proj_size, dim=dim_r, depth=1, heads=4, dim_head=dim_r//4, mlp_dim=256)
        self.norm_e = nn.LayerNorm(dim_e)
        self.norm_r = nn.LayerNorm(dim_r)
        self.avgpool = nn.AdaptiveAvgPool1d(1)
        self.linear_e = nn.Linear(dim_e, dim_r)
        self.linear_r = nn.Linear(dim_r, dim_e)

    def forward(self, e, r):
        """e: (B, dim_e, D, H, W), r: (B, dim_r, D', H', W') -> tensors of the same two shapes."""
        b_e, c_e, d_e, h_e, w_e = e.shape
        e_tokens = e.reshape(b_e, c_e, -1).transpose(1, 2)  # B, N_e, dim_e
        b_r, c_r, d_r, h_r, w_r = r.shape
        r_tokens = r.reshape(b_r, c_r, -1).transpose(1, 2)  # B, N_r, dim_r

        # One summary vector per volume: LayerNorm, then mean over the token axis.
        e_summary = self.avgpool(self.norm_e(e_tokens).transpose(1, 2)).flatten(1)
        r_summary = self.avgpool(self.norm_r(r_tokens).transpose(1, 2)).flatten(1)

        # Project each summary to the other branch's width and make it a single token.
        e_token = self.linear_e(e_summary).unsqueeze(1)  # B, 1, dim_r
        r_token = self.linear_r(r_summary).unsqueeze(1)  # B, 1, dim_e

        # Prepend the foreign token, attend, then drop it (index 0).
        r_tokens = self.transformer_r(torch.cat([e_token, r_tokens], dim=1))[:, 1:, :]
        e_tokens = self.transformer_e(torch.cat([r_token, e_tokens], dim=1))[:, 1:, :]

        e_out = e_tokens.transpose(1, 2).reshape(b_e, c_e, d_e, h_e, w_e)
        r_out = r_tokens.transpose(1, 2).reshape(b_r, c_r, d_r, h_r, w_r)
        return e_out, r_out

In [10]:
class My_TIF(nn.Module):
    """3-D fusion of a local and a global feature map with a residual connection.

    Cross-attention between the two volumes, x2 up-sampling of the global
    volume, channel concatenation, a My_Conv_block back to ``dim_e`` channels,
    and finally the original local input is added.

    Args:
        HWD_e: number of voxels in the local feature map.
        HWD_r: number of voxels in the global feature map.
        proj_size: projected key/value token count inside the attention.
        dim_e: channels of the local feature map (32, 64, 128, 256).
        dim_r: channels of the global feature map (64, 128, 256, 512).
    """
    def __init__(self,HWD_e,HWD_r,proj_size,dim_e,dim_r): 
        super().__init__()
        self.cross_attn = My_Cross_Att(HWD_e, HWD_r, proj_size, dim_e, dim_r)
        self.up = nn.Upsample(scale_factor=2)
        self.conv = My_Conv_block(in_ch=dim_e + dim_r, out_ch=dim_e, groups=16)

    def forward(self,e,r):
        '''
        e: local feature (H x W x D x C)
        r: global feature (H/2 x W/2 x D/2 x 2C)
        '''
        local_feat, global_feat = self.cross_attn(e, r)  # [B,C,D,H,W], [B,2C,D/2,H/2,W/2]
        # Concatenate along channels: dim_e + dim_r channels at full resolution.
        fused = torch.cat([local_feat, self.up(global_feat)], 1)  # B,3C,D,H,W
        refined = self.conv(fused)  # B,C,D,H,W
        return e + refined  # skip connection from the untouched input

In [13]:
# Smoke test for the 3-D My_TIF module: print the layer summary and the shapes.
e=torch.zeros(1,32,32,32,32) # [B,C,D,H,W] Local  
r=torch.zeros(1,64,16,16,16) # [B,C,D,H,W] Global

# Token counts (number of voxels) of each feature map.
HWD_e=e.shape[2]*e.shape[3]*e.shape[4]
HWD_r=r.shape[2]*r.shape[3]*r.shape[4]
proj_size=64
dim_e=e.shape[1]
dim_r=r.shape[1]

model=My_TIF(HWD_e,HWD_r,proj_size,dim_e,dim_r)

print(summary(model,e,r))
print('input(e,r):',e.shape, r.shape)
print('output:',model(e,r).shape)

------------------------------------------------------------------------------------------------
      Layer (type)                                 Output Shape         Param #     Tr. Param #
    My_Cross_Att-1     [1, 32, 32, 32, 32], [1, 64, 16, 16, 16]       2,426,336       2,426,336
        Upsample-2                          [1, 64, 32, 32, 32]               0               0
   My_Conv_block-3                          [1, 32, 32, 32, 32]         110,784         110,784
Total params: 2,537,120
Trainable params: 2,537,120
Non-trainable params: 0
------------------------------------------------------------------------------------------------
input(e,r): torch.Size([1, 32, 32, 32, 32]) torch.Size([1, 64, 16, 16, 16])
output: torch.Size([1, 32, 32, 32, 32])


5. Fusion 모듈  
(channel reduction -> concat -> EPA -> Conv2d -> Skip Connection)

In [27]:
class Fusion(nn.Module):
    """Fuse a local and a global feature volume through My_EPA.

    Pipeline: reduce the global channels to ``dim_e``, concatenate local and
    global tokens, run My_EPA, shrink the token axis back to ``HWD_e`` with a
    (k x 1) Conv2d, add the local input (skip connection) and apply ReLU.

    Args:
        HWD_e: number of voxels (tokens) in the local feature map.
        HWD_r: number of voxels (tokens) in the global feature map.
        proj_size: projected key/value token count inside My_EPA.
        dim_e: channels of the local feature map (32, 64, 128).
        dim_r: channels of the global feature map (64, 128, 256).
    """
    def __init__(self,HWD_e,HWD_r,proj_size,dim_e,dim_r): 
        super().__init__()
        self.HWD_e, self.HWD_r= HWD_e,HWD_r
        self.HWD=self.HWD_e+self.HWD_r
        self.reduction_r=nn.Linear(dim_r,dim_e) 
        self.dim=dim_e
        self.EPA=My_EPA(input_size=self.HWD,hidden_size=self.dim,proj_size=proj_size)
        # Token-axis reduction with an unpadded Conv2d: an input of HWD_e + HWD_r
        # rows and a kernel of HWD_r + 1 rows leave exactly HWD_e rows.
        # This equals the former int(1 + HWD_e/8) whenever HWD_r == HWD_e/8, and
        # stays correct for any other ratio of the two maps.
        ks_h, ks_w=HWD_r+1,1
        self.conv=nn.Conv2d(in_channels=1,out_channels=1,kernel_size=(ks_h,ks_w))
        self.relu=nn.ReLU()

    def forward(self,e,r):
        '''
        e: local feature (H x W x D x C)
        r: global feature (H/2 x W/2 x D/2 x 2C)
        '''
        skip=e # e: [B, C, D, H, W]
        b_e, c_e, d_e, h_e, w_e = e.shape[0],e.shape[1],e.shape[2],e.shape[3],e.shape[4]

        e=e.reshape(e.shape[0],e.shape[1],e.shape[2]*e.shape[3]*e.shape[4]).permute(0,2,1) # e: [B, HWD_e, C]
        r=r.reshape(r.shape[0],r.shape[1],r.shape[2]*r.shape[3]*r.shape[4]).permute(0,2,1) # r: [B, HWD_r, 2C]

        r=self.reduction_r(r) # [B, HWD_r, C]
        x=torch.cat([e,r],1) # [B, HWD_e+HWD_r, C]
        x=self.EPA(x) # [B, HWD_e+HWD_r, C]
        x=x.unsqueeze(1) # [B, 1, HWD_e+HWD_r, C]
        x=self.conv(x) # [B, 1, HWD_e, C]
        x=x.squeeze(1).permute(0,2,1).reshape(b_e,c_e,d_e,h_e,w_e) # [B,C,D,H,W]
        x=x+skip # skip connection
        x=self.relu(x)

        return x

In [28]:
# Smoke test for the Fusion module: print the layer summary and the shapes.
e=torch.zeros(1,32,32,32,32) # [B,C,D,H,W] Local 
r=torch.zeros(1,64,16,16,16) # [B,C,D,H,W] Global

# Token counts (number of voxels) of each feature map.
HWD_e=e.shape[2]*e.shape[3]*e.shape[4]
HWD_r=r.shape[2]*r.shape[3]*r.shape[4]
proj_size=64
dim_e=e.shape[1]
dim_r=r.shape[1]

model=Fusion(HWD_e,HWD_r,proj_size,dim_e,dim_r)

print(summary(model,e,r))
print('input(e,r):',e.shape, r.shape)
print('output:',model(e,r).shape)

---------------------------------------------------------------------------
      Layer (type)            Output Shape         Param #     Tr. Param #
          Linear-1           [1, 4096, 32]           2,080           2,080
          My_EPA-2          [1, 36864, 32]       2,362,440       2,362,440
          Conv2d-3       [1, 1, 32768, 32]           4,098           4,098
            ReLU-4     [1, 32, 32, 32, 32]               0               0
Total params: 2,368,618
Trainable params: 2,368,618
Non-trainable params: 0
---------------------------------------------------------------------------
input(e,r): torch.Size([1, 32, 32, 32, 32]) torch.Size([1, 64, 16, 16, 16])
output: torch.Size([1, 32, 32, 32, 32])


6. NFCE 모듈

In [6]:
class My_NFCE(nn.Module):
    """Residual bottleneck block built from 1x1x1 and depthwise-separable 3-D convs.

    Layout: 1x1x1 conv (C -> C//4) -> BN -> ReLU -> depthwise 3x3x3 + pointwise
    1x1x1 conv -> BN -> ReLU -> 1x1x1 conv (C//4 -> C) -> BN -> add input -> ReLU.

    Args:
        in_dim: number of input (and output) channels C.
    """
    def __init__(self,in_dim): 
        super().__init__()
        bottleneck = in_dim//4

        # Squeeze: Conv 1x1x1
        self.conv1 = nn.Conv3d(in_channels=in_dim, out_channels=bottleneck, kernel_size=1, bias=False)
        self.norm1 = nn.BatchNorm3d(bottleneck)

        # Depthwise separable convolution
        depthwise = nn.Conv3d(in_channels=bottleneck, out_channels=bottleneck, kernel_size=3,
                              padding=1, bias=False, groups=bottleneck)  # Depth-wise Conv 3x3x3
        pointwise = nn.Conv3d(in_channels=bottleneck, out_channels=bottleneck, kernel_size=1,
                              bias=False)  # Point-wise Conv 1x1x1
        self.dsconv = nn.Sequential(depthwise, pointwise)
        self.norm2 = nn.BatchNorm3d(bottleneck)

        # Expand: Conv 1x1x1
        self.conv3 = nn.Conv3d(in_channels=bottleneck, out_channels=in_dim, kernel_size=1, bias=False)
        self.norm3 = nn.BatchNorm3d(in_dim)

        self.relu = nn.ReLU()

    def forward(self,x):
        '''
        x: feature (H x W x D x C)
        ex) 16 x 16 x 16 x 64
        '''
        residual = x  # [B, C, D, H, W]

        # 1x1x1 conv -> [B, C//4, D, H, W]
        out = self.relu(self.norm1(self.conv1(x)))

        # depthwise separable conv -> [B, C//4, D, H, W]
        out = self.relu(self.norm2(self.dsconv(out)))

        # 1x1x1 conv -> [B, C, D, H, W]
        out = self.norm3(self.conv3(out))

        # skip connection -> [B, C, D, H, W]
        return self.relu(out + residual)

In [4]:
# Smoke test for My_NFCE: print the layer summary and the input/output shapes.
x=torch.zeros(1, 32, 32, 32, 32)
in_dim=x.shape[1]

model=My_NFCE(in_dim)
print(summary(model,x))
print('Input:',x.shape)
print('Output:',model(x).shape)

---------------------------------------------------------------------------
      Layer (type)            Output Shape         Param #     Tr. Param #
          Conv3d-1      [1, 8, 32, 32, 32]             256             256
     BatchNorm3d-2      [1, 8, 32, 32, 32]              16              16
            ReLU-3      [1, 8, 32, 32, 32]               0               0
          Conv3d-4      [1, 8, 32, 32, 32]             216             216
          Conv3d-5      [1, 8, 32, 32, 32]              64              64
     BatchNorm3d-6      [1, 8, 32, 32, 32]              16              16
          Conv3d-7     [1, 32, 32, 32, 32]             256             256
     BatchNorm3d-8     [1, 32, 32, 32, 32]              64              64
Total params: 888
Trainable params: 888
Non-trainable params: 0
---------------------------------------------------------------------------
Input: torch.Size([1, 32, 32, 32, 32])
Output: torch.Size([1, 32, 32, 32, 32])
