In [1]:
import torch
import torch
import torch.nn as nn
import torch.nn.functional as F
import einops

In [60]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

# Define the CPA (Channel Pixel Attention) block
class CPA(nn.Module):
    """
    *same=False:
    This scenario can be easily embedded after any CNNs, if size is same.
    x (OG) ---------------
    |                    |
    sc_x (from CNNs)     CPA(x)
    |                    |
    out + <---------------
    
    *same=True:
    This can be embedded after the CNNs where the size are different.
    x (OG) ---------------
    |                    |
    sc_x (from CNNs)     |
    |                    CPA(x)
    CPA(sc_x)            |
    |                    |
    out + <---------------
        
    *sc_x=False
    This operation can be seen a channel embedding with CPA
    EX: x (3, 32, 32) => (16, 32, 32)
    x (OG) 
    |      
    CPA(x)
    |    
    out 
    """
     
    # Constructor
    def __init__(self, in_dimension, out_dimension, stride=1, same=False, sc_x=True):
    
        # Call parent constructor
        super(CPA, self).__init__()

        # Initialize parameters
        self.in_dimension = in_dimension
        self.out_dimension = out_dimension
        self.stride = stride
        self.same = same
        self.sc_x = sc_x

        # Define the CP_FFC layer
        self.CP_FFC = nn.Linear(in_dimension, out_dimension)
        self.BN = nn.BatchNorm3d(out_dimension)

        # If the stride is 2 or they're the same
        if stride == 2 or same:
            # If sc_x is true
            if sc_x:
                # Define the CP_FFC layer
                self.CP_FFC_sc = nn.Linear(in_dimension, out_dimension)
                self.BN_sc = nn.BatchNorm3d(out_dimension)

            # If it's just that the stride is 2
            if stride == 2:
                # Define the average pooling layer
                self.avg_pool = nn.AvgPool3d((2, 2, 2))

    # Forward function
    def forward(self, x, sc_x):
     
        # Get the shape of the input
        b, c, h, w, d = x.shape

        # Rearrange the input
        out = rearrange(x, 'b c h w d -> b h w d c', c=c, h=h, w=w, d=d)
        out = self.CP_FFC(out)
        out = rearrange(out, 'b h w d c -> b c h w d', c=self.out_dimension, h=h, w=w, d=d)
        out = self.BN(out)

        # If they have the same shape
        if out.shape == sc_x.shape:
            # If sc_x is true
            if self.sc_x:
                # Add the two
                out = out + sc_x
            # Layer norm
            out = F.layer_norm(out, out.size()[1:])
        
        # If they're not the same shape
        else:
            # Layer norm
            out = F.layer_norm(out, out.size()[1:])
            # If sc_x is true
            if self.sc_x:
                # Set x to sc_x
                x = sc_x
        
        # If the stride is 2
        if self.stride == 2 or self.same:
            # If sc_x is true
            if self.sc_x:
                # Get the shape of the input
                _, c, h, w, d = x.shape
                # Rearrange the input
                x = rearrange(x, 'b c h w d -> b (h w d) c')
                x = self.CP_FFC_sc(x)
                x = rearrange(x, 'b (h w d) c -> b c h w d', h=h, w=w, d=d)
                x = self.BN_sc(x)
                out = out + x   

            # If they're the same
            if self.same:
                # Return out
                return out

            # Average pool
            out = self.avg_pool(out)

        # Return out
        return out  

# Define the spatial pixel attention
class SPA(nn.Module):

    # Constructor
    def __init__(self, img, out=1):

        # Call parent constructor
        super(SPA, self).__init__()

        # Initialize parameters
        self.SP_FFC = nn.Sequential(
            nn.Linear(img**3, out**3),
        )

    # Forward function
    def forward(self, x):

        # Get the shape of the input
        b, c, h, w, d = x.shape

        # Rearrange the input
        x = rearrange(x, 'b c h w d -> b c (h w d)', c=c, w=w, h=h, d=d)
        x = self.SP_FFC(x)
        # Get the shape of x
        print("x.shape", x.shape)
        _, c, l = x.shape
        # Rearrange the input
        out = rearrange(x, 'b c (h w d) -> b c h w d', c=c, w=int(l**0.5), h=int(l**0.5), d=int(l**0.5))

        # Return out
        return out

# UPA (Universal Pixel Attention) block
class UPA_Block(nn.Module):

    # Constructor
    def __init__(self, in_channels, out_channels, stride=1, cat=False, same=False, w=2, l=2):

        # Call parent constructor
        super(UPA_Block, self).__init__()

        # Initialize parameters
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.stride = stride
        self.cat = cat
        self.same = same

        # Create the first convolutional layer
        self.CNN = nn.Sequential(
            nn.Conv3d(in_channels, int(out_channels * w), kernel_size=3, padding=1, bias=False),
            nn.BatchNorm3d(int(out_channels * w)),
            nn.ReLU(inplace=True),
            nn.Conv3d(int(out_channels * w), out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm3d(out_channels),
            nn.ReLU(inplace=True)
        )

        # Define the CNN depending on the number of layers
        if l == 1:
            w = 1
            self.CNN = nn.Sequential(
                 nn.Conv3d(in_channels, int(out_channels * w), kernel_size=3, padding=1, bias=False),
                nn.BatchNorm3d(int(out_channels * w)),
                nn.ReLU(inplace=True),
            )

        # Define the attention layer
        self.attention = CPA(in_channels, out_channels, stride, same=same)

    # Forward function
    def forward(self, x):

        # Apply the CNN
        out = self.CNN(x)

        # Apply the attention
        out = self.attention(x, out)

        # If cat is true
        if self.cat:
            # Concatenate
            out = torch.cat([x, out], 1)

        # Return out
        return out
    
# Define the UPA Net
class upanets(nn.Module):

    # Constructor
    def __init__(self, block, num_blocks, filter_nums, output_nc=3, img=32):

        # Call parent constructor
        super(upanets, self).__init__()

        # Initialize parameters
        self.in_channels = filter_nums
        self.filters = filter_nums
        w = 2

        # Define the first convolutional layer
        self.root = nn.Sequential(
            nn.Conv3d(3, int(self.in_channels * w), kernel_size=3, padding=1, bias=False),
            nn.BatchNorm3d(int(self.in_channels * w)),
            nn.ReLU(inplace=True),
            nn.Conv3d(int(self.in_channels * w), self.in_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm3d(self.in_channels),
            nn.ReLU(inplace=True)
        )

        # Define the first embedding
        self.embedding = CPA(3, self.in_channels, same=True)

        # Define the layers
        self.layer1 = self._make_layer(block, int(self.filters * 1), num_blocks[0], 1)
        self.layer2 = self._make_layer(block, int(self.filters * 2), num_blocks[1], 2)
        self.layer3 = self._make_layer(block, int(self.filters * 4), num_blocks[2], 2)
        self.layer4 = self._make_layer(block, int(self.filters * 8), num_blocks[3], 2)

        # Define the SPA layers
        self.SPA0 = SPA(img)
        self.SPA1 = SPA(img)
        self.SPA2 = SPA(int(img * 0.5))
        self.SPA3 = SPA(int(img * 0.25))
        self.SPA4 = SPA(int(img * 0.125))

        # Define the linear layer
        self.linear = nn.Linear(int(self.filters * 31), output_nc)

        # Define the batchnorm
        self.BN = nn.BatchNorm1d(int(self.filters * 31))

    # Make layer function
    def _make_layer(self, block, out_channels, num_blocks, stride):

        # Initialize parameters
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        self.out_channels = out_channels
        out_channels = out_channels // num_blocks

        # For every stride
        for i, stride in enumerate(strides):

            # If first stride and stride is 1
            if i == 0 and stride == 1:
                layers.append(block(self.out_channels, self.out_channels, stride, same=True))
                strides.append(1)
                self.in_channels = self.out_channels
                
            elif i != 0 and stride == 1:
                layers.append(block(self.in_channels, out_channels, stride, cat=True))                
                self.in_channels = self.in_channels + out_channels 
                    
            else:   
                layers.append(block(self.in_channels, self.out_channels, stride))
                strides.append(1)
                self.in_channels = self.out_channels
                
        return nn.Sequential(*layers)
    
    # Forward function
    def forward(self, x):

        out01 = self.root(x)
        print("out01 shape: ", out01.shape)
        out0 = self.embedding(x, out01)
        print("out0 shape: ", out0.shape)
        
        out1 = self.layer1(out0)
        print("out1 shape: ", out1.shape)
        out2 = self.layer2(out1)
        print("out2 shape: ", out2.shape)
        out3 = self.layer3(out2)
        print("out3 shape: ", out3.shape)
        out4 = self.layer4(out3)
        print("out4 shape: ", out4.shape)

        out0_spa = self.SPA0(out0)
        print("out0_spa shape: ", out0_spa.shape)
        out1_spa = self.SPA1(out1)
        print("out1_spa shape: ", out1_spa.shape)
        out2_spa = self.SPA2(out2)
        print("out2_spa shape: ", out2_spa.shape)
        out3_spa = self.SPA3(out3)
        print("out3_spa shape: ", out3_spa.shape)
        out4_spa = self.SPA4(out4)
        print("out4_spa shape: ", out4_spa.shape)
        
        out0_gap = F.avg_pool3d(out0, out0.size()[2:])
        print("out0_gap shape: ", out0_gap.shape)
        out1_gap = F.avg_pool3d(out1, out1.size()[2:])
        print("out1_gap shape: ", out1_gap.shape)
        out2_gap = F.avg_pool3d(out2, out2.size()[2:])
        print("out2_gap shape: ", out2_gap.shape)
        out3_gap = F.avg_pool3d(out3, out3.size()[2:])
        print("out3_gap shape: ", out3_gap.shape)
        out4_gap = F.avg_pool3d(out4, out4.size()[2:])
        print("out4_gap shape: ", out4_gap.shape)
      
        out0 = out0_gap + out0_spa
        print("out0 shape: ", out0.shape)
        out1 = out1_gap + out1_spa
        print("out1 shape: ", out1.shape)
        out2 = out2_gap + out2_spa
        print("out2 shape: ", out2.shape)
        out3 = out3_gap + out3_spa
        print("out3 shape: ", out3.shape)
        out4 = out4_gap + out4_spa
        print("out4 shape: ", out4.shape)
        
        out0 = F.layer_norm(out0, out0.size()[1:])
        print("out0 shape: ", out0.shape)
        out1 = F.layer_norm(out1, out1.size()[1:])
        print("out1 shape: ", out1.shape)
        out2 = F.layer_norm(out2, out2.size()[1:])
        print("out2 shape: ", out2.shape)
        out3 = F.layer_norm(out3, out3.size()[1:])
        print("out3 shape: ", out3.shape)
        out4 = F.layer_norm(out4, out4.size()[1:])
        print("out4 shape: ", out4.shape)
        
        out = torch.cat([out4, out3, out2, out1, out0], 1)
        
        out = out.view(out.size(0), -1)
        # out = self.BN(out) # please exclude when using the test function
        out = self.linear(out)

        return out        
    
# Define the actual UPA nets
def UPANets(input_nc, output_nc=3, num_blocks=1, img_size=32):
    
    # Return the architecture
    return upanets(block=UPA_Block,
                   num_blocks=[int(4*num_blocks), int(4*num_blocks), int(4*num_blocks), int(4*num_blocks)],
                   filter_nums=input_nc,
                   output_nc=output_nc,
                   img=img_size)
                   


In [61]:
def test():
    
    net = UPANets(16, 10, 1, 64)
    y = net(torch.randn(1, 3, 64, 64, 64))
    print(y.size())

test()

out01 shape:  torch.Size([1, 16, 64, 64, 64])
out0 shape:  torch.Size([1, 16, 64, 64, 64])
out1 shape:  torch.Size([1, 32, 64, 64, 64])
out2 shape:  torch.Size([1, 64, 32, 32, 32])
out3 shape:  torch.Size([1, 128, 16, 16, 16])
out4 shape:  torch.Size([1, 256, 8, 8, 8])
x.shape torch.Size([1, 16, 1])
out0_spa shape:  torch.Size([1, 16, 1, 1, 1])
x.shape torch.Size([1, 32, 1])
out1_spa shape:  torch.Size([1, 32, 1, 1, 1])
x.shape torch.Size([1, 64, 1])
out2_spa shape:  torch.Size([1, 64, 1, 1, 1])
x.shape torch.Size([1, 128, 1])
out3_spa shape:  torch.Size([1, 128, 1, 1, 1])
x.shape torch.Size([1, 256, 1])
out4_spa shape:  torch.Size([1, 256, 1, 1, 1])
out0_gap shape:  torch.Size([1, 16, 1, 1, 1])
out1_gap shape:  torch.Size([1, 32, 1, 1, 1])
out2_gap shape:  torch.Size([1, 64, 1, 1, 1])
out3_gap shape:  torch.Size([1, 128, 1, 1, 1])
out4_gap shape:  torch.Size([1, 256, 1, 1, 1])
out0 shape:  torch.Size([1, 16, 1, 1, 1])
out1 shape:  torch.Size([1, 32, 1, 1, 1])
out2 shape:  torch.Size([

In [62]:
'''
UPANets in PyTorch.
by Ching-Hsun Tseng and Jia-Nan Feng
'''
#%%
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

class upa_block(nn.Module):
    
    def __init__(self, in_planes, planes, stride=1, cat=False, same=False, w=2, l=2):
        
        super(upa_block, self).__init__()
        
        self.cat = cat
        self.stride = stride
        self.planes = planes
        self.same = same
        self.cnn = nn.Sequential(
            nn.Conv2d(in_planes, int(planes * w), kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(int(planes * w)),
            nn.ReLU(),
            nn.Conv2d(int(planes * w), planes, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(planes),
            nn.ReLU()
            )
        if l == 1:
            w = 1
            self.cnn = nn.Sequential(
                nn.Conv2d(in_planes, int(planes * w), kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(int(planes * w)),
                nn.ReLU(),
                )
        
        self.att = CPA(in_planes, planes, stride, same=same)
            
    def forward(self, x):

        out = self.cnn(x)
        out = self.att(x, out)

        if self.cat == True:
            out = torch.cat([x, out], 1)
            
        return out

class CPA(nn.Module):
    '''Channel Pixel Attention'''
    
#      *same=False:
#       This scenario can be easily embedded after any CNNs, if size is same.
#        x (OG) ---------------
#        |                    |
#        sc_x (from CNNs)     CPA(x)
#        |                    |
#        out + <---------------
#        
#      *same=True:
#       This can be embedded after the CNNs where the size are different.
#        x (OG) ---------------
#        |                    |
#        sc_x (from CNNs)     |
#        |                    CPA(x)
#        CPA(sc_x)            |
#        |                    |
#        out + <---------------
#           
#      *sc_x=False
#       This operation can be seen a channel embedding with CPA
#       EX: x (3, 32, 32) => (16, 32, 32)
#        x (OG) 
#        |      
#        CPA(x)
#        |    
#        out 

    def __init__(self, in_dim, dim, stride=1, same=False, sc_x=True):
        
        super(CPA, self).__init__()
            
        self.dim = dim
        self.stride = stride
        self.same = same
        self.sc_x = sc_x
        
        self.cp_ffc = nn.Linear(in_dim, dim)
        self.bn = nn.BatchNorm2d(dim)

        if self.stride == 2 or self.same == True:
            if sc_x == True:
                self.cp_ffc_sc = nn.Linear(in_dim, dim)
                self.bn_sc = nn.BatchNorm2d(dim)
            
            if self.stride == 2:
                self.avgpool = nn.AvgPool2d(2)
            
    def forward(self, x, sc_x):    
       
        _, c, w, h = x.shape
        out = rearrange(x, 'b c w h -> b w h c', c=c, w=w, h=h)
        out = self.cp_ffc(out)
        out = rearrange(out, 'b w h c-> b c w h', c=self.dim, w=w, h=h)
        out = self.bn(out)  
       
        if out.shape == sc_x.shape:
            if self.sc_x == True:
                out = sc_x + out
            out = F.layer_norm(out, out.size()[1:])
            
        else:
            out = F.layer_norm(out, out.size()[1:])
            if self.sc_x == True:
                x = sc_x
            
        if self.stride == 2 or self.same == True:
            if self.sc_x == True:
                _, c, w, h = x.shape
                x = rearrange(x, 'b c w h -> b w h c', c=c, w=w, h=h)
                x = self.cp_ffc_sc(x)
                x = rearrange(x, 'b w h c-> b c w h', c=self.dim, w=w, h=h)
                x = self.bn_sc(x)
                out = out + x 
            
            if self.same == True:
                return out
            
            out = self.avgpool(out)
           
        return out

   
class SPA(nn.Module):
    '''Spatial Pixel Attention'''

    def __init__(self, img, out=1):
        
        super(SPA, self).__init__()
        
        self.sp_ffc = nn.Sequential(
            nn.Linear(img**2, out**2)
            )   
        
    def forward(self, x):
        
        _, c, w, h = x.shape          
        x = rearrange(x, 'b c w h -> b c (w h)', c=c, w=w, h=h)
        x = self.sp_ffc(x)
        _, c, l = x.shape        
        out = rearrange(x, 'b c (w h) -> b c w h', c=c, w=int(l**0.5), h=int(l**0.5))

        return out
    
class upanets(nn.Module):
    def __init__(self, block, num_blocks, filter_nums, num_classes=100, img=32):
        
        super(upanets, self).__init__()
        
        self.in_planes = filter_nums
        self.filters = filter_nums
        w = 2
        
        self.root = nn.Sequential(
                nn.Conv2d(3, int(self.in_planes*w), kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(int(self.in_planes*w)),
                nn.ReLU(),
                nn.Conv2d(int(self.in_planes*w), self.in_planes*1, kernel_size=3, padding=1, bias=False),
                nn.BatchNorm2d(self.in_planes),
                nn.ReLU(),
                )        
        self.emb = CPA(3, self.in_planes, same=True)
        
        self.layer1 = self._make_layer(block, int(self.filters*1), num_blocks[0], 1)
        self.layer2 = self._make_layer(block, int(self.filters*2), num_blocks[1], 2)
        self.layer3 = self._make_layer(block, int(self.filters*4), num_blocks[2], 2)
        self.layer4 = self._make_layer(block, int(self.filters*8), num_blocks[3], 2)
        
        self.spa0 = SPA(img)
        self.spa1 = SPA(img)
        self.spa2 = SPA(int(img*0.5))
        self.spa3 = SPA(int(img*0.25))
        self.spa4 = SPA(int(img*0.125))

        self.linear = nn.Linear(int(self.filters*31), num_classes)
        self.bn = nn.BatchNorm1d(int(self.filters*31))
     
    def _make_layer(self, block, planes, num_blocks, stride):
        
        strides = [stride] + [1]*(num_blocks - 1)
        layers = []
        self.planes = planes
        planes = planes // num_blocks

        for i, stride in enumerate(strides):
            
            if i == 0 and stride == 1:
                layers.append(block(self.planes, self.planes, stride, same=True))
                strides.append(1)
                self.in_planes = self.planes
                
            elif i != 0 and stride == 1:
                layers.append(block(self.in_planes, planes, stride, cat=True))                
                self.in_planes = self.in_planes + planes 
                    
            else:   
                layers.append(block(self.in_planes, self.planes, stride))
                strides.append(1)
                self.in_planes = self.planes
                
        return nn.Sequential(*layers)

    def forward(self, x):
                
        out01 = self.root(x)
        print("out01 shape:", out01.shape)
        out0 = self.emb(x, out01)
        print("out0 shape:", out0.shape)
        
        out1 = self.layer1(out0)
        print("out1 shape:", out1.shape)
        out2 = self.layer2(out1)
        print("out2 shape:", out2.shape)
        out3 = self.layer3(out2)
        print("out3 shape:", out3.shape)
        out4 = self.layer4(out3)
        print("out4 shape:", out4.shape)

        out0_spa = self.spa0(out0)
        print("out0_spa shape:", out0_spa.shape)
        out1_spa = self.spa1(out1)
        print("out1_spa shape:", out1_spa.shape)
        out2_spa = self.spa2(out2)
        print("out2_spa shape:", out2_spa.shape)
        out3_spa = self.spa3(out3)
        print("out3_spa shape:", out3_spa.shape)
        out4_spa = self.spa4(out4)
        print("out4_spa shape:", out4_spa.shape)
        
        out0_gap = F.avg_pool2d(out0, out0.size()[2:])
        print("out0_gap shape:", out0_gap.shape)
        out1_gap = F.avg_pool2d(out1, out1.size()[2:])
        print("out1_gap shape:", out1_gap.shape)
        out2_gap = F.avg_pool2d(out2, out2.size()[2:])
        print("out2_gap shape:", out2_gap.shape)
        out3_gap = F.avg_pool2d(out3, out3.size()[2:])
        print("out3_gap shape:", out3_gap.shape)
        out4_gap = F.avg_pool2d(out4, out4.size()[2:])
        print("out4_gap shape:", out4_gap.shape)
      
        out0 = out0_gap + out0_spa
        print("out0 shape:", out0.shape)
        out1 = out1_gap + out1_spa
        print("out1 shape:", out1.shape)
        out2 = out2_gap + out2_spa
        print("out2 shape:", out2.shape)
        out3 = out3_gap + out3_spa
        print("out3 shape:", out3.shape)
        out4 = out4_gap + out4_spa
        print("out4 shape:", out4.shape)
        
        out0 = F.layer_norm(out0, out0.size()[1:])
        print("out0 shape:", out0.shape)
        out1 = F.layer_norm(out1, out1.size()[1:])
        print("out1 shape:", out1.shape)
        out2 = F.layer_norm(out2, out2.size()[1:])
        print("out2 shape:", out2.shape)
        out3 = F.layer_norm(out3, out3.size()[1:])
        print("out3 shape:", out3.shape)
        out4 = F.layer_norm(out4, out4.size()[1:])
        print("out4 shape:", out4.shape)
        
        out = torch.cat([out4, out3, out2, out1, out0], 1)
        print("out shape:", out.shape)
        
        out = out.view(out.size(0), -1)
        # out = self.bn(out) # please exclude when using the test function
        out = self.linear(out)
        print("out shape:", out.shape)

        return out

def UPANets(f, c = 100, block = 1, img = 32):
    
    return upanets(upa_block, [int(4*block), int(4*block), int(4*block), int(4*block)], f, num_classes=c, img=img)

#test()

In [63]:
def test():
    
    net = UPANets(16, 10, 1, 64)
    y = net(torch.randn(1, 3, 64, 64))
    print(y.size())

test()

out01 shape: torch.Size([1, 16, 64, 64])
out0 shape: torch.Size([1, 16, 64, 64])
out1 shape: torch.Size([1, 32, 64, 64])
out2 shape: torch.Size([1, 64, 32, 32])
out3 shape: torch.Size([1, 128, 16, 16])
out4 shape: torch.Size([1, 256, 8, 8])
out0_spa shape: torch.Size([1, 16, 1, 1])
out1_spa shape: torch.Size([1, 32, 1, 1])
out2_spa shape: torch.Size([1, 64, 1, 1])
out3_spa shape: torch.Size([1, 128, 1, 1])
out4_spa shape: torch.Size([1, 256, 1, 1])
out0_gap shape: torch.Size([1, 16, 1, 1])
out1_gap shape: torch.Size([1, 32, 1, 1])
out2_gap shape: torch.Size([1, 64, 1, 1])
out3_gap shape: torch.Size([1, 128, 1, 1])
out4_gap shape: torch.Size([1, 256, 1, 1])
out0 shape: torch.Size([1, 16, 1, 1])
out1 shape: torch.Size([1, 32, 1, 1])
out2 shape: torch.Size([1, 64, 1, 1])
out3 shape: torch.Size([1, 128, 1, 1])
out4 shape: torch.Size([1, 256, 1, 1])
out0 shape: torch.Size([1, 16, 1, 1])
out1 shape: torch.Size([1, 32, 1, 1])
out2 shape: torch.Size([1, 64, 1, 1])
out3 shape: torch.Size([1, 12