In [1]:
import torch
from torch import nn, einsum
from einops import rearrange
from pytorch_model_summary import summary
from monai.networks.layers.utils import get_norm_layer
from unetr_plus_plus.unetr_pp.network_architecture.dynunet_block import get_conv_layer, UnetResBlock
# Prefer the GPU when CUDA is usable; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(device)

cuda


Patch Merging 테스트 (embedding 없이)

In [2]:
class PatchMerging(nn.Module):
    """Fold every 2x2x2 spatial neighbourhood of a 3-D volume into channels.

    This is the embedding-free test variant: no normalisation and no linear
    channel reduction are applied, so the output carries ``8 * C`` channels.

    Args:
        dim: Number of input channels. Stored on the module; ``forward`` reads
            the actual channel count from its input tensor.
        norm_layer: Accepted for interface compatibility only; it is not used
            because normalisation is disabled in this variant.
    """

    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        """Merge 2x2x2 patches into the channel axis.

        Args:
            x: Tensor of shape ``(B, C, D, H, W)`` with even ``D``, ``H``, ``W``.

        Returns:
            Tensor of shape ``(B, 8*C, D/2, H/2, W/2)``. For in-patch offsets
            ``(dd, dh, dw)`` in ``{0, 1}`` and input channel ``c``, the value
            ``x[b, c, 2*d + dd, 2*h + dh, 2*w + dw]`` is stored in output
            channel ``dd*4*C + dh*2*C + dw*C + c`` at position ``(d, h, w)``.

        Raises:
            ValueError: If ``x`` is not 5-D or a spatial size is odd.
        """
        if x.dim() != 5:
            raise ValueError(
                f"PatchMerging expects a 5-D tensor (B, C, D, H, W), got shape {tuple(x.shape)}"
            )
        B, C, D, H, W = x.shape
        if D % 2 or H % 2 or W % 2:
            raise ValueError(
                f"PatchMerging needs even D, H and W, got D={D}, H={H}, W={W}"
            )

        # Split each spatial axis into (coarse index, offset inside the 2-block):
        # (B, C, D/2, dd, H/2, dh, W/2, dw)
        x = x.reshape(B, C, D // 2, 2, H // 2, 2, W // 2, 2)

        # Bring the three offsets in front of the channel axis so that flattening
        # them yields the channel order (dd, dh, dw, c):
        # (B, dd, dh, dw, C, D/2, H/2, W/2)
        x = x.permute(0, 3, 5, 7, 1, 2, 4, 6)

        # One reshape replaces the former per-slice loop with repeated torch.cat.
        return x.reshape(B, 8 * C, D // 2, H // 2, W // 2)

In [4]:
# Toy volume holding the integers 1..64 in row-major order, so every output
# value can be traced back to its input position.
x = torch.arange(1, 65).reshape(4, 4, 4)
print('Input:', x, sep='\n')

# Add singleton batch and channel axes -> (B, C, D, H, W).
x = x.reshape(1, 1, 4, 4, 4)

model = PatchMerging(dim=x.shape[1])
y = model(x)

print('input:', x.shape)
print('output:', y.shape)

Input:
tensor([[[ 1,  2,  3,  4],
         [ 5,  6,  7,  8],
         [ 9, 10, 11, 12],
         [13, 14, 15, 16]],

        [[17, 18, 19, 20],
         [21, 22, 23, 24],
         [25, 26, 27, 28],
         [29, 30, 31, 32]],

        [[33, 34, 35, 36],
         [37, 38, 39, 40],
         [41, 42, 43, 44],
         [45, 46, 47, 48]],

        [[49, 50, 51, 52],
         [53, 54, 55, 56],
         [57, 58, 59, 60],
         [61, 62, 63, 64]]])
input: torch.Size([1, 1, 4, 4, 4])
output: torch.Size([1, 8, 2, 2, 2])


In [5]:
# Channels of output voxel (d=0, h=0, w=0): the eight input values taken from
# depth slices 0-1, rows 0-1, columns 0-1.
print(y[0,:,0,0,0])

tensor([ 1,  2,  5,  6, 17, 18, 21, 22])


In [6]:
# Channels of output voxel (d=1, h=0, w=0): the eight input values taken from
# depth slices 2-3, rows 0-1, columns 0-1.
print(y[0,:,1,0,0])

tensor([33, 34, 37, 38, 49, 50, 53, 54])


In [7]:
# Channels of output voxel (d=0, h=0, w=1): the eight input values taken from
# depth slices 0-1, rows 0-1, columns 2-3.
print(y[0,:,0,0,1])

tensor([ 3,  4,  7,  8, 19, 20, 23, 24])


In [8]:
# Channels of output voxel (d=1, h=0, w=1): the eight input values taken from
# depth slices 2-3, rows 0-1, columns 2-3.
print(y[0,:,1,0,1])

tensor([35, 36, 39, 40, 51, 52, 55, 56])


Patch Expanding 테스트

In [14]:
class PatchExpanding(nn.Module):
    """Unfold channels into 2x2x2 spatial neighbourhoods of a 3-D volume.

    Each group of eight channels is spread over a 2x2x2 block, so the spatial
    size doubles along D, H and W while the channel count drops to ``C // 8``.

    Args:
        dim: Number of input channels; used to size ``norm`` and ``expand``.
        norm_layer: Factory for the normalisation layer stored as ``self.norm``.

    Note:
        ``self.norm`` and ``self.expand`` are constructed so the module keeps
        its parameters, but ``forward`` does not apply either of them in this
        test variant.
    """

    def __init__(self, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.norm = norm_layer(dim // 2)
        self.expand = nn.Linear(dim, 4 * dim, bias=False)

    def forward(self, y):
        """Expand channels into space.

        Args:
            y: Tensor of shape ``(B, C, D, H, W)`` with ``C`` divisible by 8.

        Returns:
            Tensor of shape ``(B, C//8, 2*D, 2*H, 2*W)``. With ``c_out = C // 8``,
            input channel ``dd*4*c_out + dh*2*c_out + dw*c_out + c`` at position
            ``(d, h, w)`` is written to
            ``out[b, c, 2*d + dd, 2*h + dh, 2*w + dw]``.

        Raises:
            ValueError: If ``y`` is not 5-D or ``C`` is not a multiple of 8.
        """
        if y.dim() != 5:
            raise ValueError(
                f"PatchExpanding expects a 5-D tensor (B, C, D, H, W), got shape {tuple(y.shape)}"
            )
        B, C, D, H, W = y.shape
        if C % 8 != 0:
            raise ValueError(
                f"PatchExpanding needs a channel count divisible by 8, got C={C}"
            )
        c_out = C // 8

        # Split the channel axis into the three in-patch offsets plus the
        # remaining channels (previously hard-coded for C == 8 only):
        # (B, dd, dh, dw, c_out, D, H, W)
        y = y.reshape(B, 2, 2, 2, c_out, D, H, W)

        # Place every offset right after its spatial axis so that each pair
        # merges into a doubled axis: (B, c_out, D, dd, H, dh, W, dw)
        y = y.permute(0, 4, 5, 1, 6, 2, 7, 3)

        # One reshape replaces the former per-slice loop with repeated torch.cat.
        return y.reshape(B, c_out, 2 * D, 2 * H, 2 * W)

In [15]:
# y is the merged tensor from the earlier cell (printed there as [1, 8, 2, 2, 2]);
# the reshape keeps that shape and just states the expected (B, C, D, H, W) layout.
y=y.reshape(1, 8, 2, 2, 2)

# `model` is rebound here: it now refers to the expanding layer, sized from y's channel axis.
model=PatchExpanding(dim=y.shape[1])
z=model(y)


In [17]:
# View the expanded result as a plain (4, 4, 4) volume so it can be compared
# by eye with the 'Input:' tensor printed earlier.
z=z.view(4,4,4)
print('Output:',z,sep='\n')

Output:
tensor([[[ 1,  2,  3,  4],
         [ 5,  6,  7,  8],
         [ 9, 10, 11, 12],
         [13, 14, 15, 16]],

        [[17, 18, 19, 20],
         [21, 22, 23, 24],
         [25, 26, 27, 28],
         [29, 30, 31, 32]],

        [[33, 34, 35, 36],
         [37, 38, 39, 40],
         [41, 42, 43, 44],
         [45, 46, 47, 48]],

        [[49, 50, 51, 52],
         [53, 54, 55, 56],
         [57, 58, 59, 60],
         [61, 62, 63, 64]]])
