In [1]:
# Codeblock 1
import torch
import torch.nn as nn

In [2]:
# Codeblock 2
# Reduction ratio for the channel-attention bottleneck MLP (matches CAM's default r).
R            = 16
# Number of groups used by the grouped 3x3 convolution inside Block.
CARDINALITY  = 32
# Channel widths: [input image, stem output, conv2, conv3, conv4, conv5 stage outputs].
NUM_CHANNELS = [3, 64, 256, 512, 1024, 2048]
# Number of Blocks stacked in each of the conv2..conv5 stages.
NUM_BLOCKS   = [3, 4, 6, 3]
# Output size of the final fully connected layer.
NUM_CLASSES  = 1000

In [5]:
# Codeblock 3
class CAM(nn.Module):
    """Channel Attention Module: re-weights each channel of a feature map.

    The input is squeezed spatially with both global max pooling and global
    average pooling. Each pooled vector goes through the same two-layer
    bottleneck MLP, the two results are summed, and a sigmoid turns the sum
    into one gate per channel, which is multiplied back onto the input.

    Args:
        num_channels: Number of channels C of the incoming feature map.
        r: Reduction ratio of the bottleneck; the hidden width is
            ``num_channels // r``, clamped to at least 1 so that a small
            ``num_channels`` does not produce an empty hidden layer.

    Raises:
        ValueError: If ``r`` is smaller than 1.
    """

    def __init__(self, num_channels, r=16):
        super().__init__()

        if r < 1:
            raise ValueError(f'r must be >= 1, got {r}')

        # Without the clamp, num_channels < r gives a zero-width hidden layer,
        # whose output is all zeros and therefore a constant 0.5 gate.
        hidden_channels = max(1, num_channels // r)

        self.maxpool = nn.AdaptiveMaxPool2d(output_size=(1,1))
        self.avgpool = nn.AdaptiveAvgPool2d(output_size=(1,1))

        # Shared MLP; indices 0 and 2 hold the two bias-free linear layers.
        self.mlp = nn.Sequential(
            nn.Linear(in_features=num_channels,
                      out_features=hidden_channels,
                      bias=False),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=hidden_channels,
                      out_features=num_channels,
                      bias=False)
        )

        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Apply channel attention.

        Args:
            x: Feature map of shape (N, C, H, W) with C == num_channels.

        Returns:
            Tensor of the same shape as ``x``, scaled per channel.
        """
        # (N, C, H, W) -> (N, C, 1, 1) -> (N, C) for each pooling branch.
        pooled_max = torch.flatten(self.maxpool(x), start_dim=1)
        pooled_avg = torch.flatten(self.avgpool(x), start_dim=1)

        # Both descriptors share the same MLP weights; results are summed.
        logits = self.mlp(pooled_max) + self.mlp(pooled_avg)

        # (N, C) -> (N, C, 1, 1) so the gate broadcasts over H and W.
        weights = self.sigmoid(logits)[:, :, None, None]

        return weights * x

In [4]:
# Codeblock 4
# Sanity check: push one random (1, 512, 28, 28) feature map through CAM.
# The module only rescales channels, so `out` has the same shape as `x`.
cam = CAM(num_channels=512, r=16)
x = torch.randn(1, 512, 28, 28)

out = cam(x)

original		: torch.Size([1, 512, 28, 28])

x after maxpool (x_max)	: torch.Size([1, 512, 1, 1])
x after avgpool (x_avg)	: torch.Size([1, 512, 1, 1])

x_max after flatten	: torch.Size([1, 512])
x_avg after flatten	: torch.Size([1, 512])

x_max after mlp		: torch.Size([1, 512])
x_avg after mlp		: torch.Size([1, 512])

after sum		: torch.Size([1, 512])
after sigmoid		: torch.Size([1, 512])
after reshape		: torch.Size([1, 512, 1, 1])
after multiply		: torch.Size([1, 512, 28, 28])


In [8]:
# Codeblock 5
class SAM(nn.Module):
    """Spatial Attention Module: re-weights each spatial location of a feature map.

    The input is reduced along the channel axis with both max and mean,
    the two single-channel maps are concatenated, and a bias-free
    convolution followed by a sigmoid produces one gate per location,
    which is multiplied back onto the input.

    Args:
        kernel_size: Side length of the square convolution kernel. Must be a
            positive odd integer so that ``padding = kernel_size // 2`` keeps
            the spatial size unchanged. Defaults to 7.

    Raises:
        ValueError: If ``kernel_size`` is not a positive odd integer.
    """

    def __init__(self, kernel_size=7):
        super().__init__()

        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(
                f'kernel_size must be a positive odd integer, got {kernel_size}'
            )

        # Two input channels: the channel-wise max map and the mean map.
        self.conv = nn.Conv2d(in_channels=2,
                              out_channels=1,
                              kernel_size=kernel_size,
                              padding=kernel_size // 2,
                              bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Apply spatial attention.

        Args:
            x: Feature map of shape (N, C, H, W).

        Returns:
            Tensor of the same shape as ``x``, scaled per spatial location.
        """
        # Collapse the channel axis two ways; each result is (N, 1, H, W).
        channel_max, _ = torch.max(x, dim=1, keepdim=True)
        channel_avg = torch.mean(x, dim=1, keepdim=True)

        # (N, 2, H, W) -> (N, 1, H, W) gate in (0, 1).
        descriptor = torch.cat([channel_max, channel_avg], dim=1)
        weights = self.sigmoid(self.conv(descriptor))

        # The single-channel gate broadcasts across all C channels.
        return weights * x

In [7]:
# Codeblock 6
# Sanity check: push one random (1, 512, 28, 28) feature map through SAM.
# The module only rescales spatial positions, so `out` has the same shape as `x`.
sam = SAM()
x = torch.randn(1, 512, 28, 28)

out = sam(x)

original		: torch.Size([1, 512, 28, 28])

x after maxpool (x_max)	: torch.Size([1, 1, 28, 28])
x after avgpool (x_avg)	: torch.Size([1, 1, 28, 28])

after concatenate	: torch.Size([1, 2, 28, 28])
after conv		: torch.Size([1, 1, 28, 28])
after sigmoid		: torch.Size([1, 1, 28, 28])
after multiply		: torch.Size([1, 512, 28, 28])


In [11]:
# Codeblock 7
class CBAM(nn.Module):
    """Convolutional Block Attention Module: channel attention, then spatial attention.

    Args:
        num_channels: Number of channels of the incoming feature map,
            forwarded to CAM.
        r: Reduction ratio forwarded to CAM. When left as ``None`` the
            notebook-level constant ``R`` is used, so existing callers that
            pass only ``num_channels`` keep the same configuration.
    """

    def __init__(self, num_channels, r=None):
        super().__init__()

        # Resolve at call time so the config cell stays the single source of truth.
        if r is None:
            r = R

        self.cam = CAM(num_channels=num_channels, r=r)
        self.sam = SAM()

    def forward(self, x):
        """Run CAM followed by SAM; the output has the same shape as ``x``."""
        x = self.cam(x)
        x = self.sam(x)
        return x

In [10]:
# Codeblock 8
# Sanity check: CAM followed by SAM on a random (1, 512, 28, 28) feature map.
# Both stages preserve the tensor shape, so `out` matches `x` in shape.
cbam = CBAM(num_channels=512)
x = torch.randn(1, 512, 28, 28)

out = cbam(x)

original		: torch.Size([1, 512, 28, 28])
after cam		: torch.Size([1, 512, 28, 28])
after sam		: torch.Size([1, 512, 28, 28])


In [14]:
# Codeblock 9
class Block(nn.Module):
    """Bottleneck residual block with a grouped 3x3 convolution and CBAM.

    Main branch: 1x1 conv -> BN -> ReLU -> grouped 3x3 conv -> BN -> ReLU
    -> 1x1 conv -> BN -> CBAM. The result is added to the shortcut and
    passed through a final ReLU.

    Args:
        in_channels: Number of channels of the incoming feature map.
        add_channel: If true, the output width is
            ``in_channels * channel_multiplier``; otherwise it equals
            ``in_channels``.
        channel_multiplier: Widening factor, only used when ``add_channel``
            is true.
        downsample: If true, the 3x3 convolution and the shortcut projection
            use stride 2.

    The bottleneck width is half the output width and is used with
    ``groups=CARDINALITY``, so it must be divisible by ``CARDINALITY``.
    """

    @staticmethod
    def _kaiming_conv(in_channels, out_channels, kernel_size, stride, padding, groups=1):
        """Build a bias-free Conv2d and re-initialise its weight with Kaiming normal."""
        conv = nn.Conv2d(in_channels=in_channels,
                         out_channels=out_channels,
                         kernel_size=kernel_size,
                         stride=stride,
                         padding=padding,
                         bias=False,
                         groups=groups)
        nn.init.kaiming_normal_(conv.weight, nonlinearity='relu')
        return conv

    def __init__(self, 
                 in_channels,
                 add_channel=False,
                 channel_multiplier=2,
                 downsample=False):
        super().__init__()

        self.add_channel = add_channel
        self.channel_multiplier = channel_multiplier
        self.downsample = downsample

        out_channels = in_channels*self.channel_multiplier if self.add_channel else in_channels
        mid_channels = out_channels//2
        stride = 2 if self.downsample else 1

        # Shortcut projection is only needed when the shape of x changes.
        if self.add_channel or self.downsample:
            self.projection = self._kaiming_conv(in_channels, out_channels,
                                                 kernel_size=1, stride=stride, padding=0)
            self.bn_proj = nn.BatchNorm2d(num_features=out_channels)

        # 1x1 reduction to the bottleneck width.
        self.conv0 = self._kaiming_conv(in_channels, mid_channels,
                                        kernel_size=1, stride=1, padding=0)
        self.bn0 = nn.BatchNorm2d(num_features=mid_channels)

        # Grouped 3x3 convolution; carries the stride when downsampling.
        self.conv1 = self._kaiming_conv(mid_channels, mid_channels,
                                        kernel_size=3, stride=stride, padding=1,
                                        groups=CARDINALITY)
        self.bn1 = nn.BatchNorm2d(num_features=mid_channels)

        # 1x1 expansion to the output width.
        self.conv2 = self._kaiming_conv(mid_channels, out_channels,
                                        kernel_size=1, stride=1, padding=0)
        self.bn2 = nn.BatchNorm2d(num_features=out_channels)

        self.relu = nn.ReLU()

        # Attention is applied to the main branch before the residual add.
        self.cbam = CBAM(num_channels=out_channels)

    def forward(self, x):
        """Return relu(cbam(main_branch(x)) + shortcut(x))."""
        if self.add_channel or self.downsample:
            shortcut = self.bn_proj(self.projection(x))
        else:
            shortcut = x

        out = self.relu(self.bn0(self.conv0(x)))
        out = self.relu(self.bn1(self.conv1(out)))
        out = self.bn2(self.conv2(out))
        out = self.cbam(out)

        return self.relu(out + shortcut)

In [13]:
# Codeblock 10
# Sanity check: a Block with neither widening nor downsampling uses the
# identity shortcut, so `out` has the same shape as the (1, 512, 28, 28) input.
block = Block(in_channels=512, add_channel=False, downsample=False)
x = torch.randn(1, 512, 28, 28)

out = block(x)

original		: torch.Size([1, 512, 28, 28])
no projection		: torch.Size([1, 512, 28, 28])
after conv0-bn0-relu	: torch.Size([1, 256, 28, 28])
after conv1-bn1-relu	: torch.Size([1, 256, 28, 28])
after conv2-bn2		: torch.Size([1, 512, 28, 28])
after cbam		: torch.Size([1, 512, 28, 28])
after summation		: torch.Size([1, 512, 28, 28])


In [15]:
# Codeblock 11
class CBAMResNeXt(nn.Module):
    """Image classifier built from a convolutional stem and four stages of Blocks.

    Layer widths come from ``NUM_CHANNELS``, stage depths from ``NUM_BLOCKS``
    and the classifier size from ``NUM_CLASSES``. ``forward`` prints the
    tensor shape after every step.
    """

    @staticmethod
    def _build_stage(entry_block, width, depth):
        """Return a ModuleList of ``entry_block`` followed by ``depth - 1`` plain Blocks."""
        followers = [Block(in_channels=width) for _ in range(depth-1)]
        return nn.ModuleList([entry_block, *followers])

    def __init__(self):
        super().__init__()

        # Stem: 7x7 stride-2 convolution, batch norm, ReLU, 3x3 stride-2 max pool.
        self.resnext_conv1 = nn.Conv2d(in_channels=NUM_CHANNELS[0],
                                       out_channels=NUM_CHANNELS[1],
                                       kernel_size=7,
                                       stride=2,
                                       padding=3, 
                                       bias=False)
        nn.init.kaiming_normal_(self.resnext_conv1.weight, 
                                nonlinearity='relu')
        self.resnext_bn1 = nn.BatchNorm2d(num_features=NUM_CHANNELS[1])
        self.relu = nn.ReLU()
        self.resnext_maxpool1 = nn.MaxPool2d(kernel_size=3,
                                             stride=2, 
                                             padding=1)

        # conv2 stage: widens by 4x, keeps the spatial resolution.
        self.resnext_conv2 = self._build_stage(
            Block(in_channels=NUM_CHANNELS[1],
                  add_channel=True,
                  channel_multiplier=4,
                  downsample=False),
            width=NUM_CHANNELS[2],
            depth=NUM_BLOCKS[0])

        # conv3..conv5 stages: the first Block doubles the width and halves the resolution.
        self.resnext_conv3 = self._build_stage(
            Block(in_channels=NUM_CHANNELS[2], add_channel=True, downsample=True),
            width=NUM_CHANNELS[3],
            depth=NUM_BLOCKS[1])

        self.resnext_conv4 = self._build_stage(
            Block(in_channels=NUM_CHANNELS[3], add_channel=True, downsample=True),
            width=NUM_CHANNELS[4],
            depth=NUM_BLOCKS[2])

        self.resnext_conv5 = self._build_stage(
            Block(in_channels=NUM_CHANNELS[4], add_channel=True, downsample=True),
            width=NUM_CHANNELS[5],
            depth=NUM_BLOCKS[3])

        # Head: global average pool followed by a linear classifier.
        self.avgpool = nn.AdaptiveAvgPool2d(output_size=(1,1))
        self.fc = nn.Linear(in_features=NUM_CHANNELS[5],
                            out_features=NUM_CLASSES)

    def forward(self, x):
        """Run the network on ``x`` and return the ``fc`` output, printing shapes along the way."""
        print(f'original\t\t: {x.size()}')

        x = self.resnext_conv1(x)
        x = self.resnext_bn1(x)
        x = self.relu(x)
        print(f'after resnext_conv1\t: {x.size()}')

        x = self.resnext_maxpool1(x)
        print(f'after resnext_maxpool1\t: {x.size()}')

        for i, layer in enumerate(self.resnext_conv2):
            x = layer(x)
            print(f'after resnext_conv2 #{i}\t: {x.size()}')

        for i, layer in enumerate(self.resnext_conv3):
            x = layer(x)
            print(f'after resnext_conv3 #{i}\t: {x.size()}')

        for i, layer in enumerate(self.resnext_conv4):
            x = layer(x)
            print(f'after resnext_conv4 #{i}\t: {x.size()}')

        for i, layer in enumerate(self.resnext_conv5):
            x = layer(x)
            print(f'after resnext_conv5 #{i}\t: {x.size()}')

        x = self.avgpool(x)
        print(f'after avgpool\t\t: {x.size()}')

        x = torch.flatten(x, start_dim=1)
        print(f'after flatten\t\t: {x.size()}')

        x = self.fc(x)
        print(f'after fc\t\t: {x.size()}')

        return x

In [16]:
# Codeblock 12
# End-to-end check: run one random 224x224 RGB image through the full model.
# CBAMResNeXt.forward prints the tensor shape after every stage, and the
# final `out` has NUM_CLASSES values per sample.
cbamresnext = CBAMResNeXt()
x = torch.randn(1, 3, 224, 224)

out = cbamresnext(x)

original		: torch.Size([1, 3, 224, 224])
after resnext_conv1	: torch.Size([1, 64, 112, 112])
after resnext_maxpool1	: torch.Size([1, 64, 56, 56])
after resnext_conv2 #0	: torch.Size([1, 256, 56, 56])
after resnext_conv2 #1	: torch.Size([1, 256, 56, 56])
after resnext_conv2 #2	: torch.Size([1, 256, 56, 56])
after resnext_conv3 #0	: torch.Size([1, 512, 28, 28])
after resnext_conv3 #1	: torch.Size([1, 512, 28, 28])
after resnext_conv3 #2	: torch.Size([1, 512, 28, 28])
after resnext_conv3 #3	: torch.Size([1, 512, 28, 28])
after resnext_conv4 #0	: torch.Size([1, 1024, 14, 14])
after resnext_conv4 #1	: torch.Size([1, 1024, 14, 14])
after resnext_conv4 #2	: torch.Size([1, 1024, 14, 14])
after resnext_conv4 #3	: torch.Size([1, 1024, 14, 14])
after resnext_conv4 #4	: torch.Size([1, 1024, 14, 14])
after resnext_conv4 #5	: torch.Size([1, 1024, 14, 14])
after resnext_conv5 #0	: torch.Size([1, 2048, 7, 7])
after resnext_conv5 #1	: torch.Size([1, 2048, 7, 7])
after resnext_conv5 #2	: torch.Size([1, 