In [1]:
!nvidia-smi

Sun May 28 18:03:03 2023       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 465.19.01    Driver Version: 465.19.01    CUDA Version: 11.3     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ...  On   | 00000000:81:00.0 Off |                  N/A |
| 39%   30C    P8    17W / 350W |      3MiB / 24268MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
|   1  NVIDIA GeForce ...  On   | 00000000:C1:00.0 Off |                  N/A |
| 39%   32C    P8    20W / 350W |      3MiB / 24268MiB |      0%      Default |
|       

In [11]:
import sys
# Make the project's `network` package importable (it provides `dynunet_block`, imported below).
# NOTE(review): this is a machine-specific absolute path; the notebook will not run elsewhere
# unless the path is made configurable or relative to the repository root.
sys.path.append('/home/sunghyunahn/anomaly_detection/avss_anomaly_detection/network')

In [12]:
import torch
import math
from torch import nn, einsum
from einops import rearrange
from pytorch_model_summary import summary
from monai.networks.layers.utils import get_norm_layer
from dynunet_block import get_conv_layer, UnetResBlock
# Select the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(device)

cuda


DSConv

In [181]:
class DSConv(nn.Module):
    """
    Depthwise separable convolution followed by batch normalisation and ReLU.

    A depthwise convolution (``groups=in_channels``) is followed by a 1x1
    pointwise convolution that maps ``in_channels`` to ``out_channels``.

    Args:
        in_channels: number of channels of the input tensor.
        out_channels: number of channels produced by the pointwise convolution.
        kernel_size, stride, padding, dilation: forwarded to the depthwise convolution.
        spatial_dims: 3 selects Conv3d/BatchNorm3d; any other value selects the 2D variants.
    """
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, spatial_dims=3):
        super().__init__()

        # PyTorch updates running statistics as
        #   running = (1 - momentum) * running + momentum * batch_stat,
        # so a TF-style decay of 0.9997 must be passed as 1 - 0.9997. Passing
        # 0.9997 directly makes the running statistics follow only the last batch.
        bn_momentum = 1.0 - 0.9997

        if spatial_dims == 3:
            conv_cls, norm_cls = nn.Conv3d, nn.BatchNorm3d
        else:
            conv_cls, norm_cls = nn.Conv2d, nn.BatchNorm2d

        self.depthwise = conv_cls(in_channels, in_channels, kernel_size, stride,
                                  padding, dilation, groups=in_channels, bias=False)
        self.pointwise = conv_cls(in_channels, out_channels, kernel_size=1,
                                  stride=1, padding=0, dilation=1, groups=1, bias=False)
        self.bn = norm_cls(out_channels, momentum=bn_momentum, eps=4e-5)

        self.act = nn.ReLU()

    def forward(self, inputs):
        """Apply depthwise conv -> pointwise conv -> batch norm -> ReLU."""
        x = self.depthwise(inputs)
        x = self.pointwise(x)
        x = self.bn(x)
        return self.act(x)

In [182]:
# Smoke test: create the input on `device` (CPU fallback) instead of hard-coding .cuda().
x = torch.zeros(1, 384, 4, 8, 8, device=device)  # [B,C,D,H,W] input: 8 x 8 x 4 x 384
model = DSConv(in_channels=x.shape[1], out_channels=x.shape[1]).to(device)

print(summary(model, x))
print('input:', x.shape)
print('output:', model(x).shape)

-------------------------------------------------------------------------
      Layer (type)          Output Shape         Param #     Tr. Param #
          Conv3d-1     [1, 384, 4, 8, 8]          10,368          10,368
          Conv3d-2     [1, 384, 4, 8, 8]         147,456         147,456
     BatchNorm3d-3     [1, 384, 4, 8, 8]             768             768
            ReLU-4     [1, 384, 4, 8, 8]               0               0
Total params: 158,592
Trainable params: 158,592
Non-trainable params: 0
-------------------------------------------------------------------------
input: torch.Size([1, 384, 4, 8, 8])
output: torch.Size([1, 384, 4, 8, 8])


Patch Embedding  
(Partition+Embedding)

In [13]:
class PatchEmbedding(nn.Module):
    """
    Patch partition + embedding: a strided convolution followed by group normalisation.

    Args:
        spatial_dims, in_channels, out_channels, kernel_size, stride: forwarded to
            ``get_conv_layer``.
        dropout: forwarded to ``get_conv_layer`` by keyword.

    The norm layer is requested as a group norm with ``num_groups=in_channels`` over
    ``out_channels`` channels.
    NOTE(review): presumably this requires ``out_channels`` to be divisible by
    ``in_channels`` - confirm.
    """
    def __init__(self, spatial_dims=3, in_channels=3, out_channels=24, kernel_size=(1,4,4), stride=(1,4,4), dropout=0.0):
        super().__init__()
        # `dropout` is bound by keyword (as in `Head`) so it cannot land in a different
        # parameter of the external helper.
        # NOTE(review): if `get_conv_layer` mirrors MONAI's helper, the 6th positional
        # parameter is `act`, not `dropout` - confirm against dynunet_block.
        self.conv = get_conv_layer(spatial_dims, in_channels, out_channels, kernel_size, stride,
                                   dropout=dropout, conv_only=True)
        self.norm = get_norm_layer(name=("group", {"num_groups": in_channels}), channels=out_channels)

    def forward(self, x):
        """Apply the convolution, then the normalisation."""
        x = self.conv(x)
        x = self.norm(x)
        return x

In [14]:
# Smoke test: create the input on `device` (CPU fallback) instead of hard-coding .cuda().
x = torch.zeros(1, 3, 4, 256, 256, device=device)  # [B,C,D,H,W] input: 256 x 256 x 4 x 3
model = PatchEmbedding().to(device)

print(summary(model, x))
print('input:', x.shape)
print('output:', model(x).shape)

--------------------------------------------------------------------------
      Layer (type)           Output Shape         Param #     Tr. Param #
          Conv3d-1     [1, 24, 4, 64, 64]           1,152           1,152
       GroupNorm-2     [1, 24, 4, 64, 64]              48              48
Total params: 1,200
Trainable params: 1,200
Non-trainable params: 0
--------------------------------------------------------------------------
input: torch.Size([1, 3, 4, 256, 256])
output: torch.Size([1, 24, 4, 64, 64])


Downsample

In [24]:
class Downsample(nn.Module):
    """
    Strided-convolution downsampling followed by group normalisation.

    Args:
        in_channels, out_channels, spatial_dims, kernel_size, stride: forwarded to
            ``get_conv_layer``.
        dropout: forwarded to ``get_conv_layer`` by keyword.

    The norm layer is requested as a group norm with ``num_groups=in_channels`` over
    ``out_channels`` channels.
    NOTE(review): presumably this requires ``out_channels`` to be divisible by
    ``in_channels`` - confirm.
    """
    def __init__(self, in_channels, out_channels, spatial_dims=3, kernel_size=(1,2,2), stride=(1,2,2), dropout=0.0):
        super().__init__()
        # `dropout` is bound by keyword (as in `Head`) so it cannot land in a different
        # parameter of the external helper.
        # NOTE(review): if `get_conv_layer` mirrors MONAI's helper, the 6th positional
        # parameter is `act`, not `dropout` - confirm against dynunet_block.
        self.conv = get_conv_layer(spatial_dims, in_channels, out_channels, kernel_size, stride,
                                   dropout=dropout, conv_only=True)
        self.norm = get_norm_layer(name=("group", {"num_groups": in_channels}), channels=out_channels)

    def forward(self, x):
        """Apply the convolution, then the normalisation."""
        x = self.conv(x)
        x = self.norm(x)
        return x

In [27]:
# Smoke test: create the input on `device` (CPU fallback) instead of hard-coding .cuda().
x = torch.zeros(1, 24, 4, 64, 64, device=device)  # [B,C,D,H,W] input: 64 x 64 x 4 x 24
model = Downsample(in_channels=x.shape[1], out_channels=x.shape[1]*2).to(device)

print(summary(model, x))
print('input:', x.shape)
print('output:', model(x).shape)

--------------------------------------------------------------------------
      Layer (type)           Output Shape         Param #     Tr. Param #
          Conv3d-1     [1, 48, 4, 32, 32]           4,608           4,608
       GroupNorm-2     [1, 48, 4, 32, 32]              96              96
Total params: 4,704
Trainable params: 4,704
Non-trainable params: 0
--------------------------------------------------------------------------
input: torch.Size([1, 24, 4, 64, 64])
output: torch.Size([1, 48, 4, 32, 32])


Upsample

In [29]:
class Upsample(nn.Module):
    """
    Upsampling via a transposed convolution built with ``get_conv_layer``.

    Args:
        in_channels, out_channels, spatial_dims, kernel_size, stride: forwarded to
            ``get_conv_layer`` (with ``is_transposed=True``).
        dropout: forwarded to ``get_conv_layer`` by keyword.
    """
    def __init__(self, in_channels, out_channels, spatial_dims=2, kernel_size=(2,2), stride=(2,2), dropout=0.0):
        super().__init__()
        # `dropout` is bound by keyword (as in `Head`) so it cannot land in a different
        # parameter of the external helper.
        # NOTE(review): if `get_conv_layer` mirrors MONAI's helper, the 6th positional
        # parameter is `act`, not `dropout` - confirm against dynunet_block.
        self.deconv = get_conv_layer(spatial_dims, in_channels, out_channels, kernel_size, stride,
                                     dropout=dropout, conv_only=True, is_transposed=True)

    def forward(self, x):
        """Apply the transposed convolution."""
        x = self.deconv(x)
        return x

In [30]:
# Smoke test: create the input on `device` (CPU fallback) instead of hard-coding .cuda().
x = torch.zeros(1, 48, 32, 32, device=device)  # [B,C,H,W] input: 32 x 32 x 48
model = Upsample(in_channels=x.shape[1], out_channels=x.shape[1]//2).to(device)

print(summary(model, x))
print('input:', x.shape)
print('output:', model(x).shape)

-------------------------------------------------------------------------
        Layer (type)        Output Shape         Param #     Tr. Param #
   ConvTranspose2d-1     [1, 24, 64, 64]           4,608           4,608
Total params: 4,608
Trainable params: 4,608
Non-trainable params: 0
-------------------------------------------------------------------------
input: torch.Size([1, 48, 32, 32])
output: torch.Size([1, 24, 64, 64])


Video Patch Merging (1,2,2)

In [80]:
class VideoPatchMerging(nn.Module):
    """
    Merge each 2x2 spatial neighbourhood (per depth slice) into one token.

    The four neighbours are concatenated along the channel axis (4C) and reduced
    to 2C by a bias-free linear layer, so H and W are halved while D is kept.
    """
    def __init__(self, in_channels):
        '''
        in_channels: number of input channels C; out_channels = 2 * in_channels.
        '''
        super().__init__()
        self.dim = in_channels
        self.reduction = nn.Linear(4 * in_channels, 2 * in_channels, bias=False)

    def forward(self, x):
        '''
        x: B,C,D,H,W  ->  B,2C,D,ceil(H/2),ceil(W/2)

        Odd H or W are zero-padded by one row/column at the end, so the four
        strided slices always have matching shapes. D may be any size.
        '''
        x = x.permute(0,2,3,4,1) # [B, D, H, W, C]

        pad_h, pad_w = x.shape[2] % 2, x.shape[3] % 2
        if pad_h or pad_w:
            # Pad spec runs from the last axis backwards: (C, C, W, W, H, H).
            x = torch.nn.functional.pad(x, (0, 0, 0, pad_w, 0, pad_h))

        x0 = x[:, :, 0::2, 0::2, :]  # even rows, even cols  [B, D, H/2, W/2, C]
        x1 = x[:, :, 1::2, 0::2, :]  # odd rows,  even cols
        x2 = x[:, :, 0::2, 1::2, :]  # even rows, odd cols
        x3 = x[:, :, 1::2, 1::2, :]  # odd rows,  odd cols
        x = torch.cat([x0, x1, x2, x3], -1)  # [B, D, H/2, W/2, 4C]

        x = self.reduction(x) # [B, D, H/2, W/2, 2C]
        x = x.permute(0, 4, 1, 2, 3) # [B, 2C, D, H/2, W/2]

        return x

In [81]:
# Smoke test: create the input on `device` (CPU fallback) instead of hard-coding .cuda().
x = torch.zeros(1, 24, 4, 64, 64, device=device)  # [B,C,D,H,W] input: 64 x 64 x 4 x 24
model = VideoPatchMerging(in_channels=x.shape[1]).to(device)

print(summary(model, x))
print('input:', x.shape)
print('output:', model(x).shape)

--------------------------------------------------------------------------
      Layer (type)           Output Shape         Param #     Tr. Param #
          Linear-1     [1, 4, 32, 32, 48]           4,608           4,608
Total params: 4,608
Trainable params: 4,608
Non-trainable params: 0
--------------------------------------------------------------------------
input: torch.Size([1, 24, 4, 64, 64])
output: torch.Size([1, 48, 4, 32, 32])


Patch Expanding (2,2)

In [18]:
class PatchExpanding(nn.Module):
    """
    Double the spatial resolution and halve the channels of a [B, C, H, W] tensor.

    A bias-free linear layer expands C to 2C; the 2C features of every position
    are then redistributed as a 2x2 patch with C//2 channels each.

    Raises:
        ValueError: if ``in_channels`` is odd (2C cannot be split into 2*2*(C//2)).
    """
    def __init__(self, in_channels):
        super().__init__()
        # Fail at construction instead of inside `rearrange` during the first forward pass.
        if in_channels % 2 != 0:
            raise ValueError(f"in_channels must be even, got {in_channels}")
        self.expand = nn.Linear(in_channels, 2 * in_channels, bias=False)

    def forward(self, y):
        """
        y: B,C,H,W  ->  B,C//2,2*H,2*W
        """
        y = y.permute(0,2,3,1) # [B, H, W, C]
        C = y.shape[-1]

        y = self.expand(y) # B, H, W, 2*C

        y = rearrange(y,'b h w (p1 p2 c)-> b (h p1) (w p2) c', p1=2, p2=2, c=C//2) # B, 2*H, 2*W, C//2

        y = y.permute(0,3,1,2) # B, C//2, 2*H, 2*W

        return y

In [19]:
# Smoke test: create the input on `device` (CPU fallback) instead of hard-coding .cuda().
x = torch.zeros(1, 48, 32, 32, device=device)  # [B,C,H,W] input: 32 x 32 x 48
model = PatchExpanding(in_channels=x.shape[1]).to(device)

print(summary(model, x))
print('input:', x.shape)
print('output:', model(x).shape)

-----------------------------------------------------------------------
      Layer (type)        Output Shape         Param #     Tr. Param #
          Linear-1     [1, 32, 32, 96]           4,608           4,608
Total params: 4,608
Trainable params: 4,608
Non-trainable params: 0
-----------------------------------------------------------------------
input: torch.Size([1, 48, 32, 32])
output: torch.Size([1, 24, 64, 64])


Global Max Pooling

In [36]:
class Globalpool(nn.Module):
    """
    Collapse the depth axis of a [B, C, D, H, W] tensor with adaptive max pooling.

    The pooling target is (1, height, width); the singleton depth axis is then removed.
    """
    def __init__(self, height, width):
        super().__init__()
        target_size = (1, height, width)
        self.pool = nn.AdaptiveMaxPool3d(output_size=target_size)

    def forward(self, x):
        """x: [B, C, D, H, W] -> [B, C, height, width]"""
        pooled = self.pool(x)          # [B, C, 1, height, width]
        return pooled.squeeze(dim=2)   # drop the pooled depth axis

In [37]:
# Smoke test: create the input on `device` (CPU fallback) instead of hard-coding .cuda().
x = torch.zeros(1, 24, 4, 64, 64, device=device)  # [B,C,D,H,W] input: 64 x 64 x 4 x 24
model = Globalpool(height=x.shape[3], width=x.shape[4]).to(device)

print(summary(model, x))
print('input:', x.shape)
print('output:', model(x).shape)

------------------------------------------------------------------------------
          Layer (type)           Output Shape         Param #     Tr. Param #
   AdaptiveMaxPool3d-1     [1, 24, 1, 64, 64]               0               0
Total params: 0
Trainable params: 0
Non-trainable params: 0
------------------------------------------------------------------------------
input: torch.Size([1, 24, 4, 64, 64])
output: torch.Size([1, 24, 64, 64])


ResBlock

In [84]:
class ResBlock(nn.Module):
    """
    A stack of `depth` UnetResBlock modules applied one after another.

    The first block maps in_channels -> out_channels; every following block
    maps out_channels -> out_channels. All other arguments are forwarded
    unchanged to each UnetResBlock.
    """
    def __init__(self, in_channels, out_channels, spatial_dims=2, kernel_size=3, stride=1, norm_name="instance",depth=2):
        super().__init__()

        self.depth = depth
        # Input width of each stage: only the first stage sees `in_channels`.
        stage_inputs = [in_channels if idx == 0 else out_channels for idx in range(depth)]
        self.resblock_set = nn.ModuleList([
            UnetResBlock(spatial_dims=spatial_dims, in_channels=width, out_channels=out_channels,
                         kernel_size=kernel_size, stride=stride, norm_name=norm_name)
            for width in stage_inputs
        ])

    def forward(self,x):
        """Run x through every stage in order."""
        for stage in self.resblock_set:
            x = stage(x)
        return x

In [87]:
# Smoke test: create the input on `device` (CPU fallback) instead of hard-coding .cuda().
x = torch.zeros(1, 24, 64, 64, device=device)  # [B,C,H,W] input: 64 x 64 x 24
model = ResBlock(in_channels=x.shape[1], out_channels=x.shape[1]//2, depth=2).to(device)

print(summary(model, x))
print('input:', x.shape)
print('output:', model(x).shape)

-----------------------------------------------------------------------
      Layer (type)        Output Shape         Param #     Tr. Param #
    UnetResBlock-1     [1, 12, 64, 64]           4,176           4,176
    UnetResBlock-2     [1, 12, 64, 64]           2,592           2,592
Total params: 6,768
Trainable params: 6,768
Non-trainable params: 0
-----------------------------------------------------------------------
input: torch.Size([1, 24, 64, 64])
output: torch.Size([1, 12, 64, 64])


ConcatConv

In [105]:
class ConcatConv(nn.Module):
    """
    Fuse two feature maps: concatenate along channels, then reduce with a ResBlock.

    The ResBlock maps 2 * in_channels back to in_channels and has `depth` stages.
    """
    def __init__(self, in_channels, depth=1):
        super().__init__()
        fused_channels = in_channels * 2
        self.conv = ResBlock(in_channels=fused_channels, out_channels=in_channels, depth=depth)

    def forward(self, x1, x2):
        '''
        x1, x2: [B, C, H, W]
        '''
        stacked = torch.cat((x1, x2), dim=1)  # channel-wise concatenation
        return self.conv(stacked)

In [106]:
# Smoke test: create the inputs on `device` (CPU fallback) instead of hard-coding .cuda().
x1 = torch.zeros(1, 24, 64, 64, device=device)  # [B,C,H,W] input: 64 x 64 x 24
x2 = torch.zeros(1, 24, 64, 64, device=device)

model = ConcatConv(in_channels=24).to(device)

print(summary(model, x1, x2))
print('input:', x1.shape, x2.shape)
print('output:', model(x1, x2).shape)

-----------------------------------------------------------------------
      Layer (type)        Output Shape         Param #     Tr. Param #
        ResBlock-1     [1, 24, 64, 64]          16,704          16,704
Total params: 16,704
Trainable params: 16,704
Non-trainable params: 0
-----------------------------------------------------------------------
input: torch.Size([1, 24, 64, 64]) torch.Size([1, 24, 64, 64])
output: torch.Size([1, 24, 64, 64])


Head

In [189]:
class Head(nn.Module):
    """
    Output head: a channel-preserving ResBlock followed by a 1x1 convolution (with bias).

    Args:
        in_channels: channels of the incoming feature map.
        out_channels: channels of the prediction.
        spatial_dims: dimensionality used for BOTH the ResBlock and the 1x1 convolution.
        dropout: forwarded to ``get_conv_layer``.
    """
    def __init__(self, in_channels, out_channels=3, spatial_dims=2, dropout=0.0):
        super().__init__()
        # Forward `spatial_dims` so the ResBlock matches the head convolution; previously
        # the ResBlock always used its own default (2) regardless of this argument.
        self.conv = ResBlock(in_channels=in_channels, out_channels=in_channels, spatial_dims=spatial_dims)
        self.head = get_conv_layer(spatial_dims, in_channels, out_channels, kernel_size=1, dropout=dropout, bias=True, conv_only=True)

    def forward(self,x):
        """Refine features, then project to `out_channels`."""
        x = self.conv(x)
        x = self.head(x)
        return x

In [190]:
# Smoke test: create the input on `device` (CPU fallback) instead of hard-coding .cuda().
x = torch.zeros(1, 12, 256, 256, device=device)  # [B,C,H,W] input: 256 x 256 x 12
model = Head(in_channels=x.shape[1]).to(device)

print(summary(model, x))
print('input:', x.shape)
print('output:', model(x).shape)

-------------------------------------------------------------------------
      Layer (type)          Output Shape         Param #     Tr. Param #
        ResBlock-1     [1, 12, 256, 256]           5,184           5,184
          Conv2d-2      [1, 3, 256, 256]              39              39
Total params: 5,223
Trainable params: 5,223
Non-trainable params: 0
-------------------------------------------------------------------------
input: torch.Size([1, 12, 256, 256])
output: torch.Size([1, 3, 256, 256])


[Encoder] Spatial Attention (3D)  
-> Q, K, V : HW x DC

In [174]:
class SpatialAttn(nn.Module):
    """
    Multi-head self-attention over spatial tokens, without any low-rank projection.

    Input and output are [B, HW, DC]: one token per spatial position, each with a
    feature vector of length `dim`. Queries and keys are L2-normalised along the
    token axis, and the attention logits are scaled by a learnable per-head
    temperature.
    """
    def __init__(self, input_size, dim, num_heads=4, qkv_bias=False, attn_drop=0.1, proj_drop=0.1):
        '''
        input_size: resolution (H*W); accepted for a uniform signature, not used here
        dim: channel * depth (C*D)
        '''
        super().__init__()
        self.num_heads = num_heads
        # One learnable scale per head, broadcast over that head's attention map.
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))

        # A single linear layer produces query, key and value together.
        self.qkv = nn.Linear(dim, 3 * dim, bias=qkv_bias)

        self.attn_drop = nn.Dropout(attn_drop)

        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        '''
        x: [B, HW, DC]  ->  [B, HW, DC]
        '''
        batch, tokens, feat = x.shape
        heads = self.num_heads

        # [B, HW, 3*DC] -> [3, B, h, HW, DC/h] -> three tensors of [B, h, HW, DC/h]
        fused = self.qkv(x).reshape(batch, tokens, 3, heads, feat // heads)
        query, key, value = fused.permute(2, 0, 3, 1, 4).unbind(dim=0)

        # L2-normalise along the token axis.
        query = torch.nn.functional.normalize(query, dim=-2)
        key = torch.nn.functional.normalize(key, dim=-2)

        # Attention map: [B, h, HW, HW]
        scores = torch.matmul(query, key.transpose(-2, -1)) * self.temperature
        weights = self.attn_drop(scores.softmax(dim=-1))

        # [B, h, HW, DC/h] -> [B, HW, h, DC/h] -> [B, HW, DC]
        mixed = torch.matmul(weights, value).permute(0, 2, 1, 3).reshape(batch, tokens, feat)

        # Output projection shared by all heads.
        return self.proj_drop(self.proj(mixed))

In [175]:
class SpatialAttnBlock(nn.Module):
    """
    Spatial-attention block: attention + LayerNorm with a residual connection,
    followed by a depthwise-separable convolution with a second residual connection.
    """
    def __init__(self, conv_hidden, input_size, dim, num_heads=4, qkv_bias=False, attn_drop=0.1, proj_drop=0.1,is_pos_embed=False):
        '''
        conv_hidden: channel count (C) seen by the DSConv branch
        input_size: resolution (H*W)
        dim: channel * depth (C*D)
        is_pos_embed: add the learnable positional embedding before attention
        '''
        super().__init__()

        self.norm = nn.LayerNorm(dim)
        self.is_pos_embed = is_pos_embed
        # Always registered (zeros-initialised); only added when is_pos_embed is True.
        self.pos_embed = nn.Parameter(torch.zeros(1, input_size, dim))
        self.spatial_attn = SpatialAttn(input_size, dim, num_heads, qkv_bias, attn_drop, proj_drop)
        self.dsconv = DSConv(in_channels=conv_hidden, out_channels=conv_hidden)

    def forward(self,x):
        '''
        x: [B, C, D, H, W]  ->  [B, C, D, H, W]
        '''
        B, C, D, H, W = x.shape
        save = x

        x = rearrange(x,'b c d h w-> b (h w) (c d)', b=B, c=C, d=D, h=H, w=W) # [B,HW,DC]
        if self.is_pos_embed:
            x = x + self.pos_embed

        # spatial attn -> norm, then residual with the block input
        x = self.norm(self.spatial_attn(x))
        x = rearrange(x,'b (h w) (c d)-> b c d h w', b=B, c=C, d=D, h=H, w=W) # [B,C,D,H,W]
        x = x + save

        # DSConv branch with residual. This must be out-of-place: the convolution keeps
        # its input for the backward pass, so `x += self.dsconv(x)` would overwrite a
        # tensor autograd still needs and raise an in-place modification error.
        x = x + self.dsconv(x)

        return x

In [177]:
# Smoke test: create the input on `device` (CPU fallback) instead of hard-coding .cuda().
x = torch.zeros(1, 24, 4, 64, 64, device=device)  # [B,C,D,H,W] input: 64 x 64 x 4 x 24
model = SpatialAttnBlock(conv_hidden=x.shape[1], input_size=x.shape[3]*x.shape[4], dim=x.shape[1]*x.shape[2], is_pos_embed=True).to(device)

print(summary(model, x))
print('input:', x.shape)
print('output:', model(x).shape)

--------------------------------------------------------------------------
      Layer (type)           Output Shape         Param #     Tr. Param #
     SpatialAttn-1          [1, 4096, 96]          36,964          36,964
       LayerNorm-2          [1, 4096, 96]             192             192
          DSConv-3     [1, 24, 4, 64, 64]           1,272           1,272
Total params: 38,428
Trainable params: 38,428
Non-trainable params: 0
--------------------------------------------------------------------------
input: torch.Size([1, 24, 4, 64, 64])
output: torch.Size([1, 24, 4, 64, 64])


[Encoder] Temporal Attention (3D)  
-> Q, K, V : D x HWC

In [27]:
class TemporalAttn(nn.Module):
    """
    Multi-head attention across the depth (temporal) axis.

    Input and output are [B, D, HWC]. Query and key are projected to `proj_size`
    features by one shared linear layer before the D x D attention map is built;
    the value keeps its full width. Both the qkv mapping and the output projection
    are chains of four linear layers that pass through a `squeeze`-wide bottleneck.
    """
    def __init__(self, input_size, dim, proj_size, squeeze=8, num_heads=4, qkv_bias=False, attn_drop=0.1, proj_drop=0.1):
        '''
        input_size: depth (D); accepted for a uniform signature, not used here
        dim: resolution * channel (H*W*C)
        '''
        super().__init__()
        self.num_heads = num_heads
        self.dim = dim
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))

        def bottleneck(out_features):
            # dim -> squeeze -> proj_size -> squeeze -> out_features, linear layers only.
            widths = (dim, squeeze, proj_size, squeeze, out_features)
            layers = [nn.Linear(a, b, bias=qkv_bias) for a, b in zip(widths[:-1], widths[1:])]
            return nn.Sequential(*layers)

        # Joint query/key/value mapping through the bottleneck.
        self.qkv = bottleneck(3 * dim)

        # Query and key share one projection layer (same module object).
        self.proj_size = proj_size
        shared_projection = nn.Linear(dim, proj_size)
        self.proj_q = shared_projection
        self.proj_k = shared_projection

        self.attn_drop = nn.Dropout(attn_drop)

        self.proj = bottleneck(dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        '''
        x: [B, D, HWC]  ->  [B, D, HWC]
        '''
        batch, frames, feat = x.shape
        heads = self.num_heads

        # [B, D, 3*HWC] -> [3, B, D, HWC] -> three tensors of [B, D, HWC]
        query, key, value = self.qkv(x).reshape(batch, frames, 3, feat).permute(2, 0, 1, 3).unbind(dim=0)

        def split_heads(t, width):
            # [B, D, width] -> [B, h, D, width/h]
            return t.reshape(batch, frames, heads, width // heads).permute(0, 2, 1, 3)

        q_proj = split_heads(self.proj_q(query), self.proj_size)
        k_proj = split_heads(self.proj_k(key), self.proj_size)
        value = split_heads(value, self.dim)

        # L2-normalise along the depth axis.
        q_proj = torch.nn.functional.normalize(q_proj, dim=-2)
        k_proj = torch.nn.functional.normalize(k_proj, dim=-2)

        # Attention map: [B, h, D, D]
        scores = torch.matmul(q_proj, k_proj.transpose(-2, -1)) * self.temperature
        weights = self.attn_drop(scores.softmax(dim=-1))

        # [B, h, D, HWC/h] -> [B, D, h, HWC/h] -> [B, D, HWC]
        mixed = torch.matmul(weights, value).permute(0, 2, 1, 3).reshape(batch, frames, feat)

        return self.proj_drop(self.proj(mixed))

In [170]:
class TemporalAttnBlock(nn.Module):
    """
    Temporal-attention block: attention + LayerNorm with a residual connection,
    followed by a depthwise-separable convolution with a second residual connection.
    """
    def __init__(self, conv_hidden, input_size, dim, proj_size=64, squeeze=8, num_heads=4, qkv_bias=False, attn_drop=0.1, proj_drop=0.1, is_pos_embed=False):
        '''
        conv_hidden: channel count (C) seen by the DSConv branch
        input_size: depth (D)
        dim: resolution * channel (H*W*C)
        is_pos_embed: add the learnable positional embedding before attention
        '''
        super().__init__()

        self.norm = nn.LayerNorm(dim)
        self.is_pos_embed = is_pos_embed
        # Always registered (zeros-initialised); only added when is_pos_embed is True.
        self.pos_embed = nn.Parameter(torch.zeros(1, input_size, dim))
        self.temporal_attn = TemporalAttn(input_size, dim, proj_size, squeeze, num_heads, qkv_bias, attn_drop, proj_drop)
        self.dsconv = DSConv(in_channels=conv_hidden, out_channels=conv_hidden)

    def forward(self,x):
        '''
        x: [B, C, D, H, W]  ->  [B, C, D, H, W]
        '''
        B, C, D, H, W = x.shape
        save = x

        x = rearrange(x,'b c d h w-> b d (h w c)', b=B, c=C, d=D, h=H, w=W) # [B,D,HWC]
        if self.is_pos_embed:
            x = x + self.pos_embed

        # temporal attn -> norm, then residual with the block input
        x = self.norm(self.temporal_attn(x))
        x = rearrange(x,'b d (h w c)-> b c d h w', b=B, c=C, d=D, h=H, w=W) # [B,C,D,H,W]
        x = x + save

        # DSConv branch with residual. This must be out-of-place: the convolution keeps
        # its input for the backward pass, so `x += self.dsconv(x)` would overwrite a
        # tensor autograd still needs and raise an in-place modification error.
        x = x + self.dsconv(x)

        return x

In [172]:
# Smoke test: create the input on `device` (CPU fallback) instead of hard-coding .cuda().
x = torch.zeros(1, 24, 4, 64, 64, device=device)  # [B,C,D,H,W] input: 64 x 64 x 4 x 24
model = TemporalAttnBlock(conv_hidden=x.shape[1], input_size=x.shape[2], dim=x.shape[1]*x.shape[3]*x.shape[4], is_pos_embed=True).to(device)

print(summary(model, x))
print('input:', x.shape)
print('output:', model(x).shape)

--------------------------------------------------------------------------
      Layer (type)           Output Shape         Param #     Tr. Param #
    TemporalAttn-1          [1, 4, 98304]      11,012,164      11,012,164
       LayerNorm-2          [1, 4, 98304]         196,608         196,608
          DSConv-3     [1, 24, 4, 64, 64]           1,272           1,272
Total params: 11,210,044
Trainable params: 11,210,044
Non-trainable params: 0
--------------------------------------------------------------------------
input: torch.Size([1, 24, 4, 64, 64])
output: torch.Size([1, 24, 4, 64, 64])


[Encoder] Spatio-Temporal Attention (3D)  
-> Q, K, V : HWD x C

In [137]:
class SpatioTemporalAttn(nn.Module):
    """
    Multi-head attention over all H*W*D tokens with projected keys and values.

    Input and output are [B, HWD, C]. Keys and values are compressed along the
    token axis from `input_size` to `proj_size` by one shared linear layer, so the
    attention map is HWD x proj_size instead of HWD x HWD.
    """
    def __init__(self, input_size, dim, proj_size, num_heads=4, qkv_bias=False, attn_drop=0., proj_drop=0.1):
        '''
        input_size: resolution * depth (H*W*D)
        dim: channel (C)
        '''
        super().__init__()
        self.num_heads = num_heads
        self.dim = dim
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))

        # A single linear layer produces query, key and value together.
        self.qkv = nn.Linear(dim, 3 * dim, bias=qkv_bias)

        # Keys and values share one token-axis projection (same module object).
        token_projection = nn.Linear(input_size, proj_size)
        self.proj_k = token_projection
        self.proj_v = token_projection
        self.attn_drop = nn.Dropout(attn_drop)

        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        '''
        x: [B, HWD, C]  ->  [B, HWD, C]
        '''
        batch, tokens, channels = x.shape
        heads = self.num_heads

        # [B, HWD, 3*C] -> [3, B, h, HWD, C/h] -> three tensors of [B, h, HWD, C/h]
        fused = self.qkv(x).reshape(batch, tokens, 3, heads, channels // heads)
        query, key, value = fused.permute(2, 0, 3, 1, 4).unbind(dim=0)

        # L2-normalise along the token axis.
        query = torch.nn.functional.normalize(query, dim=-2)
        key = torch.nn.functional.normalize(key, dim=-2)

        # Compress the token axis: [B, h, C/h, HWD] -> [B, h, C/h, p]
        key_small = self.proj_k(key.transpose(-2, -1))
        value_small = self.proj_v(value.transpose(-2, -1))

        # Attention map: [B, h, HWD, p]
        scores = torch.matmul(query, key_small) * self.temperature
        weights = self.attn_drop(scores.softmax(dim=-1))

        # [B, h, HWD, p] @ [B, h, p, C/h] -> [B, h, HWD, C/h] -> [B, HWD, C]
        mixed = torch.matmul(weights, value_small.transpose(-2, -1))
        mixed = mixed.permute(0, 2, 1, 3).reshape(batch, tokens, channels)

        return self.proj_drop(self.proj(mixed))

In [167]:
class SpatioTemporalAttnBlock(nn.Module):
    """
    Spatio-temporal attention block: attention + LayerNorm with a residual connection,
    followed by a depthwise-separable convolution with a second residual connection.
    """
    def __init__(self, conv_hidden, input_size, dim, proj_size=64, num_heads=4, qkv_bias=False, attn_drop=0.1, proj_drop=0.1,is_pos_embed=False):
        '''
        conv_hidden: channel count (C) seen by the DSConv branch
        input_size: resolution * depth (H*W*D)
        dim: channel (C)
        is_pos_embed: add the learnable positional embedding before attention
        '''
        super().__init__()

        self.norm = nn.LayerNorm(dim)
        self.is_pos_embed = is_pos_embed
        # Always registered (zeros-initialised); only added when is_pos_embed is True.
        self.pos_embed = nn.Parameter(torch.zeros(1, input_size, dim))
        self.spatio_temporal_attn = SpatioTemporalAttn(input_size, dim, proj_size, num_heads, qkv_bias, attn_drop, proj_drop)
        self.dsconv = DSConv(in_channels=conv_hidden, out_channels=conv_hidden)

    def forward(self,x):
        '''
        x: [B, C, D, H, W]  ->  [B, C, D, H, W]
        '''
        B, C, D, H, W = x.shape
        save = x

        x = rearrange(x,'b c d h w-> b (h w d) c', b=B, c=C, d=D, h=H, w=W) # [B,HWD,C]
        if self.is_pos_embed:
            x = x + self.pos_embed

        # spatio temporal attn -> norm, then residual with the block input
        x = self.norm(self.spatio_temporal_attn(x))
        x = rearrange(x,'b (h w d) c-> b c d h w', b=B, c=C, d=D, h=H, w=W) # [B,C,D,H,W]
        x = x + save

        # DSConv branch with residual. This must be out-of-place: the convolution keeps
        # its input for the backward pass, so `x += self.dsconv(x)` would overwrite a
        # tensor autograd still needs and raise an in-place modification error.
        x = x + self.dsconv(x)

        return x

In [169]:
# Smoke test: create the input on `device` (CPU fallback) instead of hard-coding .cuda().
x = torch.zeros(2, 24, 4, 64, 64, device=device)  # [B,C,D,H,W] input: 64 x 64 x 4 x 24
model = SpatioTemporalAttnBlock(conv_hidden=x.shape[1], input_size=x.shape[2]*x.shape[3]*x.shape[4], dim=x.shape[1], is_pos_embed=True).to(device)

print(summary(model, x))
print('input:', x.shape)
print('output:', model(x).shape)

-------------------------------------------------------------------------------
           Layer (type)           Output Shape         Param #     Tr. Param #
   SpatioTemporalAttn-1         [2, 16384, 24]       1,050,972       1,050,972
            LayerNorm-2         [2, 16384, 24]              48              48
               DSConv-3     [2, 24, 4, 64, 64]           1,272           1,272
Total params: 1,052,292
Trainable params: 1,052,292
Non-trainable params: 0
-------------------------------------------------------------------------------
input: torch.Size([2, 24, 4, 64, 64])
output: torch.Size([2, 24, 4, 64, 64])


[Encoder] Channel Attention (3D)  
-> Q, K, V : C x HWD

In [126]:
class ChannelAttn(nn.Module):
    """
    Multi-head attention across channels.

    Input and output are [B, C, HWD]: one token per channel, each with a feature
    vector of length `dim`. Query and key are projected to `proj_size` features by
    one shared linear layer before the C x C attention map is built; the value
    keeps its full width. The qkv mapping and the output projection are two-layer
    linear chains that pass through a `proj_size`-wide bottleneck.
    """
    def __init__(self, input_size, dim, proj_size, num_heads=4, qkv_bias=False, attn_drop=0.1, proj_drop=0.1):
        '''
        input_size: channel (C); accepted for a uniform signature, not used here
        dim: resolution * Depth (H*W*D)
        '''
        super().__init__()
        self.num_heads = num_heads
        self.dim = dim
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))

        # Joint query/key/value mapping through the bottleneck.
        squeeze_qkv = nn.Linear(dim, proj_size, bias=qkv_bias)
        expand_qkv = nn.Linear(proj_size, 3 * dim, bias=qkv_bias)
        self.qkv = nn.Sequential(squeeze_qkv, expand_qkv)

        # Query and key share one projection layer (same module object).
        self.proj_size = proj_size
        shared_projection = nn.Linear(dim, proj_size)
        self.proj_q = shared_projection
        self.proj_k = shared_projection

        self.attn_drop = nn.Dropout(attn_drop)

        squeeze_out = nn.Linear(dim, proj_size, bias=qkv_bias)
        expand_out = nn.Linear(proj_size, dim, bias=qkv_bias)
        self.proj = nn.Sequential(squeeze_out, expand_out)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        '''
        x: [B, C, HWD]  ->  [B, C, HWD]
        '''
        batch, channels, feat = x.shape
        heads = self.num_heads

        # [B, C, 3*HWD] -> [3, B, C, HWD] -> three tensors of [B, C, HWD]
        query, key, value = self.qkv(x).reshape(batch, channels, 3, feat).permute(2, 0, 1, 3).unbind(dim=0)

        def split_heads(t, width):
            # [B, C, width] -> [B, h, C, width/h]
            return t.reshape(batch, channels, heads, width // heads).permute(0, 2, 1, 3)

        q_proj = split_heads(self.proj_q(query), self.proj_size)
        k_proj = split_heads(self.proj_k(key), self.proj_size)
        value = split_heads(value, self.dim)

        # L2-normalise along the channel-token axis.
        q_proj = torch.nn.functional.normalize(q_proj, dim=-2)
        k_proj = torch.nn.functional.normalize(k_proj, dim=-2)

        # Attention map: [B, h, C, C]
        scores = torch.matmul(q_proj, k_proj.transpose(-2, -1))
        scores = scores * self.temperature
        weights = self.attn_drop(scores.softmax(dim=-1))

        # [B, h, C, HWD/h] -> [B, C, h, HWD/h] -> [B, C, HWD]
        mixed = torch.matmul(weights, value).permute(0, 2, 1, 3).reshape(batch, channels, feat)

        return self.proj_drop(self.proj(mixed))

In [165]:
class ChannelAttnBlock(nn.Module):
    """
    Residual channel-attention block for 5-D feature maps.

    Pipeline: flatten (D, H, W) -> optional positional embedding ->
    ChannelAttn -> LayerNorm -> residual add -> DSConv residual add.
    """
    def __init__(self, conv_hidden, input_size, dim, proj_size=64, num_heads=4, qkv_bias=False, attn_drop=0.1, proj_drop=0.1, is_pos_embed=False):
        '''
        conv_hidden: in/out channels of the depthwise separable conv
                     (applied to the [B, C, D, H, W] tensor, so it must equal C)
        input_size: channel (C)
        dim: resolution * Depth (H*W*D)
        proj_size, num_heads, qkv_bias, attn_drop, proj_drop: forwarded to ChannelAttn
        is_pos_embed: add the learnable positional embedding before attention
        '''
        super().__init__()

        self.norm = nn.LayerNorm(dim)
        self.is_pos_embed = is_pos_embed
        # Registered unconditionally; only used in forward when is_pos_embed is True.
        self.pos_embed = nn.Parameter(torch.zeros(1, input_size, dim))
        self.channel_attn = ChannelAttn(input_size, dim, proj_size, num_heads, qkv_bias, attn_drop, proj_drop)
        self.dsconv = DSConv(in_channels=conv_hidden, out_channels=conv_hidden)

    def forward(self, x):
        '''
        x: [B, C, D, H, W]
        returns: tensor of the same shape as x
        '''
        B, C, D, H, W = x.shape
        save = x

        x = rearrange(x, 'b c d h w-> b c (h w d)', b=B, c=C, d=D, h=H, w=W) # [B,C,HWD]
        if self.is_pos_embed:
            x = x + self.pos_embed

        # channel attn -> norm
        x = self.norm(self.channel_attn(x))
        x = rearrange(x, 'b c (h w d)-> b c d h w', b=B, c=C, d=D, h=H, w=W) # [B,C,D,H,W]
        # Out-of-place residual add (keeps the block free of in-place edits on views).
        x = x + save

        # conv residual.
        # Must NOT be `x += self.dsconv(x)`: the in-place add would overwrite the
        # very tensor self.dsconv received as input, which autograd still needs for
        # the backward pass ("modified by an inplace operation" RuntimeError).
        x = x + self.dsconv(x)

        return x

In [166]:
# Smoke test for ChannelAttnBlock: print the layer summary and check that the
# output shape equals the input shape.
# The input is created directly on `device` (the cuda-or-cpu device chosen in the
# import cell) instead of calling .cuda(), so the cell also runs without a GPU.
x = torch.zeros(1, 24, 4, 64, 64, device=device) # [B,C,D,H,W] input: 64 x 64 x 4 x 24 (H x W x D x C)
model = ChannelAttnBlock(conv_hidden=x.shape[1], input_size=x.shape[1], dim=x.shape[2]*x.shape[3]*x.shape[4], is_pos_embed=True).to(device)

print(summary(model,x))
print('input:',x.shape)
print('output:',model(x).shape)

--------------------------------------------------------------------------
      Layer (type)           Output Shape         Param #     Tr. Param #
     ChannelAttn-1         [1, 24, 16384]       7,340,100       7,340,100
       LayerNorm-2         [1, 24, 16384]          32,768          32,768
          DSConv-3     [1, 24, 4, 64, 64]           1,272           1,272
Total params: 7,374,140
Trainable params: 7,374,140
Non-trainable params: 0
--------------------------------------------------------------------------
input: torch.Size([1, 24, 4, 64, 64])
output: torch.Size([1, 24, 4, 64, 64])


[Encoder] Fusion Module

In [178]:
class AttentionFusion(nn.Module):
    """
    Fuse two or three [B, C, D, H, W] feature maps into a single 2D feature map.

    Each input is projected channel-wise (nn.Linear on a channel-last view) from
    hidden_size channels down to its share of hidden_size, the projections are
    concatenated back to hidden_size channels, the depth axis is folded into the
    channel axis ([B, C*D, H, W]) and the result goes through a 2D ResBlock.
    """
    def __init__(self,in_depths,hidden_size,is_three,norm_name="instance",depth=1):
        '''
        in_depths: depth D of the incoming feature maps (conv2d expects hidden_size*in_depths channels)
        hidden_size: channel count C of every input, and of the fused output
        is_three: True -> fuse three inputs, False -> fuse two
        norm_name, depth: forwarded to ResBlock
        '''
        super().__init__()

        self.is_three=is_three

        # Split hidden_size over the inputs. Plain floor division would drop the
        # remainder when hidden_size is not a multiple of the number of inputs, and
        # the concatenated tensor would then no longer match conv2d's in_channels.
        # The remainder is handed to the first projections so the widths always sum
        # to hidden_size (for divisible sizes every width is hidden_size // n).
        n_inputs = 3 if is_three else 2
        base, extra = divmod(hidden_size, n_inputs)
        widths = [base + (1 if i < extra else 0) for i in range(n_inputs)]

        # Attribute names are kept as-is so existing checkpoints still load.
        self.out_proj = nn.Linear(hidden_size, widths[0])
        self.out_proj2 = nn.Linear(hidden_size, widths[1])
        if is_three:
            self.out_proj3 = nn.Linear(hidden_size, widths[2])

        self.conv2d = ResBlock(spatial_dims=2, in_channels=hidden_size*in_depths, out_channels=hidden_size, norm_name=norm_name, depth=depth)

    def forward(self,x1,x2,x3=None):
        '''
        x1, x2, x3: [B, C, D, H, W] (x3 is required when is_three=True, ignored otherwise)
        returns: output of the 2D ResBlock applied to the fused [B, C*D, H, W] tensor
        raises: ValueError if the module was built with is_three=True and x3 is None
        '''
        if self.is_three:
            if x3 is None:
                raise ValueError("AttentionFusion was built with is_three=True, so x3 must be provided")
            pairs = ((self.out_proj, x1), (self.out_proj2, x2), (self.out_proj3, x3))
        else:
            pairs = ((self.out_proj, x1), (self.out_proj2, x2))

        # Channel-last view so nn.Linear acts on the channel axis of every voxel.
        projected = [proj(rearrange(feat, "b c d h w -> b d h w c")) for proj, feat in pairs]
        x = torch.cat(projected, dim=-1)

        # Back to channel-first and fold depth into channels in one step: [B,CD,H,W]
        x = rearrange(x, "b d h w c -> b (c d) h w")
        x = self.conv2d(x)

        return x

In [179]:
# Smoke test for AttentionFusion with three inputs.
# Inputs are created directly on `device` (the cuda-or-cpu device chosen in the
# import cell) instead of calling .cuda(), so the cell also runs without a GPU.
x1 = torch.zeros(1, 24, 4, 64, 64, device=device) # [B,C,D,H,W] input: 64 x 64 x 4 x 24 (H x W x D x C)
x2 = torch.zeros(1, 24, 4, 64, 64, device=device)
x3 = torch.zeros(1, 24, 4, 64, 64, device=device)

model = AttentionFusion(in_depths=4,hidden_size=24,is_three=True).to(device)

print(summary(model,x1,x2,x3))
print('input:',x1.shape, x2.shape, x3.shape)
print('output:',model(x1,x2,x3).shape)

-------------------------------------------------------------------------
      Layer (type)          Output Shape         Param #     Tr. Param #
          Linear-1     [1, 4, 64, 64, 8]             200             200
          Linear-2     [1, 4, 64, 64, 8]             200             200
          Linear-3     [1, 4, 64, 64, 8]             200             200
        ResBlock-4       [1, 24, 64, 64]          28,224          28,224
Total params: 28,824
Trainable params: 28,824
Non-trainable params: 0
-------------------------------------------------------------------------
input: torch.Size([1, 24, 4, 64, 64]) torch.Size([1, 24, 4, 64, 64]) torch.Size([1, 24, 4, 64, 64])
output: torch.Size([1, 24, 64, 64])


In [180]:
# Smoke test for AttentionFusion with two inputs.
# Inputs are created directly on `device` (the cuda-or-cpu device chosen in the
# import cell) instead of calling .cuda(), so the cell also runs without a GPU.
x1 = torch.zeros(1, 24, 4, 64, 64, device=device) # [B,C,D,H,W] input: 64 x 64 x 4 x 24 (H x W x D x C)
x2 = torch.zeros(1, 24, 4, 64, 64, device=device)

model = AttentionFusion(in_depths=4,hidden_size=24,is_three=False).to(device)

print(summary(model,x1,x2))
print('input:',x1.shape, x2.shape)
print('output:',model(x1,x2).shape)

--------------------------------------------------------------------------
      Layer (type)           Output Shape         Param #     Tr. Param #
          Linear-1     [1, 4, 64, 64, 12]             300             300
          Linear-2     [1, 4, 64, 64, 12]             300             300
        ResBlock-3        [1, 24, 64, 64]          28,224          28,224
Total params: 28,824
Trainable params: 28,824
Non-trainable params: 0
--------------------------------------------------------------------------
input: torch.Size([1, 24, 4, 64, 64]) torch.Size([1, 24, 4, 64, 64])
output: torch.Size([1, 24, 64, 64])


[Decoder] Spatial Attention (2D)  
-> Q, K, V: HW x C

In [183]:
class SpatialAttn2D(nn.Module):
    """
    Multi-head spatial self-attention over the HW positions of a 2D feature map.

    Q, K and V come from a single linear layer, Q and K are L2-normalised, and the
    HW x HW attention map is scaled by a learnable per-head temperature.
    """
    def __init__(self, input_size, dim, num_heads=4, qkv_bias=False, attn_drop=0.1, proj_drop=0.1):
        '''
        input_size: resolution (H*W); accepted for interface symmetry, not used by this module
        dim: channel (C); must be divisible by num_heads
        num_heads: number of attention heads
        qkv_bias: add a bias to the fused qkv linear layer
        attn_drop: dropout probability on the attention map
        proj_drop: dropout probability after the output projection

        raises: ValueError if dim is not divisible by num_heads
        '''
        super().__init__()

        # Fail fast: otherwise the per-head reshape in forward() breaks with an
        # opaque shape error on the first call.
        if dim % num_heads != 0:
            raise ValueError(f"dim ({dim}) must be divisible by num_heads ({num_heads})")

        self.num_heads = num_heads
        self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))

        # qkv are 3 linear layers (query, key, value)
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)

        self.attn_drop = nn.Dropout(attn_drop)

        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        '''
        Spatial Attention
        : no projection

        x: [B, HW, C]
        returns: [B, HW, C]
        '''
        B, HW, C = x.shape

        qkv = self.qkv(x).reshape(B, HW, 3, self.num_heads, C // self.num_heads) # B x HW x 3 x h x C/h
        qkv = qkv.permute(2, 0, 3, 1, 4) # 3 x B x h x HW x C/h
        q, k, v = qkv[0], qkv[1], qkv[2] # B x h x HW x C/h

        # L2-normalise Q and K along dim=-2 (the HW axis)
        q = torch.nn.functional.normalize(q, dim=-2)
        k = torch.nn.functional.normalize(k, dim=-2)
        k_t = k.transpose(-2, -1) # K_T : B x h x C/h x HW

        attn_SA = (q @ k_t) * self.temperature  # [Q x K_T] B x h x HW x HW

        attn_SA = attn_SA.softmax(dim=-1)
        attn_SA = self.attn_drop(attn_SA) # [Spatial Attn Map] B x h x HW x HW

        # [Spatial Attn Map x V] B x h x HW x C/h -> B x HW x h x C/h -> B x HW x C
        x_SA = (attn_SA @ v).permute(0, 2, 1, 3).reshape(B, HW, C)

        # linear projection for msa
        x = self.proj(x_SA)
        x = self.proj_drop(x)

        return x

In [184]:
class SpatialAttn2DBlock(nn.Module):
    """
    Residual spatial-attention block for 4-D feature maps.

    Pipeline: flatten (H, W) -> optional positional embedding ->
    SpatialAttn2D -> LayerNorm -> residual add -> DSConv (2D) residual add.
    """
    def __init__(self, conv_hidden, input_size, dim, num_heads=4, qkv_bias=False, attn_drop=0.1, proj_drop=0.1, is_pos_embed=False):
        '''
        conv_hidden: in/out channels of the depthwise separable conv
                     (applied to the [B, C, H, W] tensor, so it must equal C)
        input_size: resolution (H*W)
        dim: channel (C)
        num_heads, qkv_bias, attn_drop, proj_drop: forwarded to SpatialAttn2D
        is_pos_embed: add the learnable positional embedding before attention
        '''
        super().__init__()

        self.norm = nn.LayerNorm(dim)
        self.is_pos_embed = is_pos_embed
        # Registered unconditionally; only used in forward when is_pos_embed is True.
        self.pos_embed = nn.Parameter(torch.zeros(1, input_size, dim))
        self.spatial_attn_2d = SpatialAttn2D(input_size, dim, num_heads, qkv_bias, attn_drop, proj_drop)
        self.dsconv = DSConv(in_channels=conv_hidden, out_channels=conv_hidden, spatial_dims=2)

    def forward(self, x):
        '''
        x: [B, C, H, W]
        returns: tensor of the same shape as x
        '''
        B, C, H, W = x.shape
        save = x

        x = rearrange(x, 'b c h w-> b (h w) c', b=B, c=C, h=H, w=W) # [B,HW,C]
        if self.is_pos_embed:
            x = x + self.pos_embed

        # spatial attn -> norm
        x = self.norm(self.spatial_attn_2d(x))
        x = rearrange(x, 'b (h w) c-> b c h w', b=B, c=C, h=H, w=W) # [B,C,H,W]
        # Out-of-place residual add (keeps the block free of in-place edits on views).
        x = x + save

        # conv residual.
        # Must NOT be `x += self.dsconv(x)`: the in-place add would overwrite the
        # very tensor self.dsconv received as input, which autograd still needs for
        # the backward pass ("modified by an inplace operation" RuntimeError).
        x = x + self.dsconv(x)

        return x

In [187]:
# Smoke test for SpatialAttn2DBlock: print the layer summary and check that the
# output shape equals the input shape.
# The input is created directly on `device` (the cuda-or-cpu device chosen in the
# import cell) instead of calling .cuda(), so the cell also runs without a GPU.
x = torch.zeros(1, 96, 16, 16, device=device) # [B,C,H,W] input: 16 x 16 x 96 (H x W x C)
model = SpatialAttn2DBlock(conv_hidden=x.shape[1], input_size=x.shape[2]*x.shape[3], dim=x.shape[1], is_pos_embed=True).to(device)

print(summary(model,x))
print('input:',x.shape)
print('output:',model(x).shape)

-----------------------------------------------------------------------
      Layer (type)        Output Shape         Param #     Tr. Param #
   SpatialAttn2D-1        [1, 256, 96]          36,964          36,964
       LayerNorm-2        [1, 256, 96]             192             192
          DSConv-3     [1, 96, 16, 16]          10,272          10,272
Total params: 47,428
Trainable params: 47,428
Non-trainable params: 0
-----------------------------------------------------------------------
input: torch.Size([1, 96, 16, 16])
output: torch.Size([1, 96, 16, 16])
