In [1]:
# Codeblock 1
import torch
import torch.nn as nn

In [4]:
# Codeblock 2a
class ConvNeXtBlock(nn.Module):
    """Residual block whose output has the same shape as its input.

    Main branch: 7x7 depthwise convolution -> LayerNorm over the channel
    dimension -> 1x1 convolution widening to ``4 * num_channels`` -> GELU ->
    1x1 convolution narrowing back to ``num_channels``. The block input is
    then added to the branch output.

    Args:
        num_channels: Number of channels of both the input and the output.
    """

    def __init__(self, num_channels):         #(1)
        super().__init__()
        expanded_channels = 4 * num_channels  #(2)

        # Depthwise convolution: one group per channel. Kernel 7 with
        # stride 1 and padding 3 leaves the spatial size unchanged.
        self.conv0 = nn.Conv2d(
            num_channels,         #(3) in_channels
            num_channels,         #(4) out_channels
            kernel_size=7,        #(5)
            stride=1,
            padding=3,            #(6)
            groups=num_channels,  #(7)
        )

        # Normalizes over the last dimension, which forward() arranges to be
        # the channel dimension.
        self.norm = nn.LayerNorm(num_channels)  #(8)

        # Pointwise expansion to four times the width.
        self.conv1 = nn.Conv2d(num_channels, expanded_channels,  #(9)
                               kernel_size=1, stride=1, padding=0)

        self.gelu = nn.GELU()  #(10)

        # Pointwise projection back to the block width.
        self.conv2 = nn.Conv2d(expanded_channels, num_channels,  #(11)
                               kernel_size=1, stride=1, padding=0)

# Codeblock 2b
    def forward(self, x):
        """Apply the block to ``x`` of shape (N, num_channels, H, W)."""
        shortcut = x                     #(1) kept for the skip connection

        y = self.conv0(x)                # (N, C, H, W)

        # LayerNorm works on the trailing dimension, so channels are moved
        # last for the normalization and moved back right after it.
        y = y.permute(0, 2, 3, 1)        #(2) (N, H, W, C)
        y = self.norm(y)
        y = y.permute(0, 3, 1, 2)        #(3) (N, C, H, W)

        # Expand -> non-linearity -> project back.
        y = self.conv2(self.gelu(self.conv1(y)))

        return y + shortcut              #(4)

In [3]:
# Codeblock 3
convnext_block_test = ConvNeXtBlock(num_channels=96)  #(1)
x_test = torch.rand(1, 96, 56, 56)  #(2)

out_test = convnext_block_test(x_test)

# The block adds its input to its output, so the shape must be preserved.
# Checking it here keeps the cell meaningful even when the debug prints
# inside the block are commented out.
assert out_test.shape == x_test.shape, out_test.shape

x & residual	: torch.Size([1, 96, 56, 56])
after conv0	: torch.Size([1, 96, 56, 56])
after permute	: torch.Size([1, 56, 56, 96])
after norm	: torch.Size([1, 56, 56, 96])
after permute	: torch.Size([1, 96, 56, 56])
after conv1	: torch.Size([1, 384, 56, 56])
after gelu	: torch.Size([1, 384, 56, 56])
after conv2	: torch.Size([1, 96, 56, 56])
after summation	: torch.Size([1, 96, 56, 56])


In [7]:
# Codeblock 4a
class ConvNeXtBlockTransition(nn.Module):
    """Block that changes the channel count and halves the spatial size.

    Main branch: 7x7 depthwise-style convolution (``groups=in_channels``)
    mapping ``in_channels`` to ``out_channels`` -> LayerNorm -> 1x1 expansion
    to ``4 * out_channels`` -> GELU -> 1x1 projection to ``out_channels`` ->
    LayerNorm -> 2x2 stride-2 convolution. The shortcut branch is a 1x1
    stride-2 convolution of the block input. Both branches are summed.

    Args:
        in_channels: Number of channels of the input.
        out_channels: Number of channels of the output. It has to be a
            multiple of ``in_channels`` because ``conv0`` uses
            ``groups=in_channels``.
    """

    def __init__(self, in_channels, out_channels):  #(1)
        super().__init__()
        hidden_channels = out_channels * 4

        # Shortcut branch: matches the channel count and stride of the
        # main branch so both can be added.
        self.projection = nn.Conv2d(in_channels=in_channels,      #(2)
                                    out_channels=out_channels,
                                    kernel_size=1,
                                    stride=2,
                                    padding=0)

        # Kernel 7, stride 1, padding 3: spatial size is unchanged here.
        self.conv0 = nn.Conv2d(in_channels=in_channels,
                               out_channels=out_channels,
                               kernel_size=7,
                               stride=1,
                               padding=3,
                               groups=in_channels)

        self.norm0 = nn.LayerNorm(normalized_shape=out_channels)

        self.conv1 = nn.Conv2d(in_channels=out_channels,
                               out_channels=hidden_channels,
                               kernel_size=1,
                               stride=1,
                               padding=0)

        self.gelu = nn.GELU()

        self.conv2 = nn.Conv2d(in_channels=hidden_channels,
                               out_channels=out_channels,
                               kernel_size=1,
                               stride=1,
                               padding=0)

        self.norm1 = nn.LayerNorm(normalized_shape=out_channels)  #(3)

        self.downsample = nn.Conv2d(in_channels=out_channels,     #(4)
                                    out_channels=out_channels,
                                    kernel_size=2,
                                    stride=2)

    @staticmethod
    def _norm_over_channels(norm, x):
        """Apply ``norm`` over the channel axis of a (N, C, H, W) tensor."""
        # LayerNorm normalizes the trailing dimension, so channels are moved
        # last for the call and restored afterwards.
        return norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

# Codeblock 4b
    def forward(self, x):
        """Map (N, in_channels, H, W) to (N, out_channels, H // 2, W // 2)."""
        residual = self.projection(x)  #(1)

        x = self.conv0(x)
        x = self._norm_over_channels(self.norm0, x)

        x = self.conv1(x)
        x = self.gelu(x)
        x = self.conv2(x)

        x = self._norm_over_channels(self.norm1, x)

        x = self.downsample(x)  #(2)

        # For an odd H or W the 1x1 stride-2 shortcut yields ceil(size / 2)
        # while the 2x2 stride-2 downsample yields floor(size / 2), which
        # used to make the addition fail. The extra last row/column of the
        # shortcut comes from the one input position the downsample skips,
        # so it is dropped. For even sizes this slice is a no-op.
        residual = residual[..., :x.shape[-2], :x.shape[-1]]

        x = x + residual  #(3)

        return x

In [6]:
# Codeblock 5
convnext_block_transition_test = ConvNeXtBlockTransition(in_channels=96, 
                                                         out_channels=192)
x_test = torch.rand(1, 96, 56, 56)

out_test = convnext_block_transition_test(x_test)

# The transition block doubles the channels here (96 -> 192) and halves
# the spatial size (56 -> 28). Asserting it keeps the cell meaningful even
# when the debug prints inside the block are commented out.
assert out_test.shape == (1, 192, 28, 28), out_test.shape

original		: torch.Size([1, 96, 56, 56])
residual after proj	: torch.Size([1, 192, 28, 28])
after conv0		: torch.Size([1, 192, 56, 56])
after permute		: torch.Size([1, 56, 56, 192])
after norm1		: torch.Size([1, 56, 56, 192])
after permute		: torch.Size([1, 192, 56, 56])
after conv1		: torch.Size([1, 768, 56, 56])
after gelu		: torch.Size([1, 768, 56, 56])
after conv2		: torch.Size([1, 192, 56, 56])
after permute		: torch.Size([1, 56, 56, 192])
after norm1		: torch.Size([1, 56, 56, 192])
after permute		: torch.Size([1, 192, 56, 56])
after downsample	: torch.Size([1, 192, 28, 28])
after summation		: torch.Size([1, 192, 28, 28])


In [8]:
# Codeblock 6
# Input configuration.
IN_CHANNELS = 3   #(1) channels of the input image fed to the stem
IMAGE_SIZE = 224  #(2) height and width of the square test input

# Per-stage configuration: entry i applies to the i-th stage (res2 .. res5).
NUM_BLOCKS = [3, 3, 9, 3]           #(3) number of blocks in each stage
OUT_CHANNELS = [96, 192, 384, 768]  #(4) channel width of each stage

# Classifier configuration.
NUM_CLASSES = 1000  #(5) number of outputs of the final linear layer

In [9]:
# Codeblock 7a
class ConvNeXt(nn.Module):
    """Four-stage ConvNeXt-style classifier built from the blocks above.

    Layout: 4x4 stride-4 stem convolution -> ReLU -> LayerNorm -> stages
    ``res2`` .. ``res5`` -> global average pooling -> LayerNorm -> linear
    classifier. ``res2`` consists of plain ``ConvNeXtBlock`` modules only;
    each later stage starts with one ``ConvNeXtBlockTransition`` followed by
    plain blocks.

    Every argument is optional. When omitted, the module-level constants
    ``IN_CHANNELS``, ``NUM_BLOCKS``, ``OUT_CHANNELS`` and ``NUM_CLASSES`` are
    used, so ``ConvNeXt()`` builds the same model as before.

    Args:
        in_channels: Number of channels of the input image.
        num_blocks: Four integers, the number of blocks of ``res2`` ..
            ``res5``. The transition block counts towards the total of
            ``res3`` .. ``res5`` and is always created.
        out_channels: Four integers, the channel width of ``res2`` ..
            ``res5``.
        num_classes: Number of outputs of the final linear layer.
        verbose: When True (default) ``forward`` prints the tensor shape
            after every step. Pass False to silence it, e.g. for training.

    Raises:
        ValueError: If ``num_blocks`` or ``out_channels`` does not contain
            exactly four entries.
    """

    def __init__(self, in_channels=None, num_blocks=None, out_channels=None,
                 num_classes=None, verbose=True):
        super().__init__()

        # Fall back to the notebook-level configuration at construction time.
        in_channels = IN_CHANNELS if in_channels is None else in_channels
        num_blocks = list(NUM_BLOCKS if num_blocks is None else num_blocks)
        out_channels = list(OUT_CHANNELS if out_channels is None else out_channels)
        num_classes = NUM_CLASSES if num_classes is None else num_classes

        if len(num_blocks) != 4 or len(out_channels) != 4:
            raise ValueError(
                'num_blocks and out_channels must each have 4 entries, got '
                f'{len(num_blocks)} and {len(out_channels)}'
            )

        self.verbose = verbose

        # Non-overlapping 4x4 patches: kernel size equals the stride.
        self.stem = nn.Conv2d(in_channels=in_channels,    #(1)
                              out_channels=out_channels[0],
                              kernel_size=4,
                              stride=4,
                             )

        self.normstem = nn.LayerNorm(normalized_shape=out_channels[0])  #(2)

        #(3) First stage: no change of width or resolution.
        self.res2 = self._make_stage(out_channels[0], out_channels[0],
                                     num_blocks[0], transition=False)

        #(4)
        self.res3 = self._make_stage(out_channels[0], out_channels[1],
                                     num_blocks[1], transition=True)

        #(5)
        self.res4 = self._make_stage(out_channels[1], out_channels[2],
                                     num_blocks[2], transition=True)

        #(6)
        self.res5 = self._make_stage(out_channels[2], out_channels[3],
                                     num_blocks[3], transition=True)

        self.avgpool = nn.AdaptiveAvgPool2d(output_size=(1,1))  #(7)
        self.normpool = nn.LayerNorm(normalized_shape=out_channels[3])  #(8)
        self.fc = nn.Linear(in_features=out_channels[3],        #(9)
                            out_features=num_classes)

        # NOTE(review): forward() applies this ReLU right after the stem.
        # The reference ConvNeXt stem has no activation there; kept as-is to
        # preserve the existing numerical behavior. Confirm it is intended.
        self.relu = nn.ReLU()

    @staticmethod
    def _make_stage(in_channels, out_channels, depth, transition):
        """Build one stage as a ModuleList of ``depth`` blocks in total."""
        blocks = nn.ModuleList()
        if transition:
            # The transition block changes width/resolution and takes the
            # place of the first plain block of the stage.
            blocks.append(ConvNeXtBlockTransition(in_channels=in_channels,
                                                  out_channels=out_channels))
            depth -= 1
        for _ in range(depth):
            blocks.append(ConvNeXtBlock(num_channels=out_channels))
        return blocks

    def _trace(self, label=None, x=None):
        """Print ``label`` and the shape of ``x`` (or a blank line) if verbose."""
        if not self.verbose:
            return
        if label is None:
            print()
        else:
            print(f'{label}\t: {x.size()}')

# Codeblock 7b
    def forward(self, x):
        """Map images (N, in_channels, H, W) to class scores (N, num_classes)."""
        self._trace('original', x)

        x = self.relu(self.stem(x))
        self._trace('after stem', x)

        # LayerNorm normalizes the last dimension: move channels last and back.
        x = x.permute(0, 2, 3, 1)
        self._trace('after permute', x)

        x = self.normstem(x)
        self._trace('after normstem', x)

        x = x.permute(0, 3, 1, 2)
        self._trace('after permute', x)

        stages = (('res2', self.res2),    #(1)
                  ('res3', self.res3),    #(2)
                  ('res4', self.res4),    #(3)
                  ('res5', self.res5))    #(4)
        for name, stage in stages:
            self._trace()
            for i, block in enumerate(stage):
                x = block(x)
                self._trace(f'after {name} #{i}', x)

        self._trace()
        x = self.avgpool(x)
        self._trace('after avgpool', x)

        x = x.permute(0, 2, 3, 1)
        self._trace('after permute', x)

        x = self.normpool(x)
        self._trace('after normpool', x)

        x = x.permute(0, 3, 1, 2)
        self._trace('after permute', x)

        # Drop the two singleton spatial dimensions left by the pooling.
        x = x.reshape(x.shape[0], -1)             #(5)
        self._trace('after reshape', x)

        x = self.fc(x)
        self._trace('after fc', x)                #(6)

        return x

In [10]:
# Codeblock 8
convnext_test = ConvNeXt()

x_test   = torch.rand(1, IN_CHANNELS, IMAGE_SIZE, IMAGE_SIZE)
out_test = convnext_test(x_test)

# The classifier must return one score per class for every input image.
assert out_test.shape == (x_test.shape[0], NUM_CLASSES), out_test.shape

original	: torch.Size([1, 3, 224, 224])
after stem	: torch.Size([1, 96, 56, 56])
after permute	: torch.Size([1, 56, 56, 96])
after normstem	: torch.Size([1, 56, 56, 96])
after permute	: torch.Size([1, 96, 56, 56])

after res2 #0	: torch.Size([1, 96, 56, 56])
after res2 #1	: torch.Size([1, 96, 56, 56])
after res2 #2	: torch.Size([1, 96, 56, 56])

after res3 #0	: torch.Size([1, 192, 28, 28])
after res3 #1	: torch.Size([1, 192, 28, 28])
after res3 #2	: torch.Size([1, 192, 28, 28])

after res4 #0	: torch.Size([1, 384, 14, 14])
after res4 #1	: torch.Size([1, 384, 14, 14])
after res4 #2	: torch.Size([1, 384, 14, 14])
after res4 #3	: torch.Size([1, 384, 14, 14])
after res4 #4	: torch.Size([1, 384, 14, 14])
after res4 #5	: torch.Size([1, 384, 14, 14])
after res4 #6	: torch.Size([1, 384, 14, 14])
after res4 #7	: torch.Size([1, 384, 14, 14])
after res4 #8	: torch.Size([1, 384, 14, 14])

after res5 #0	: torch.Size([1, 768, 7, 7])
after res5 #1	: torch.Size([1, 768, 7, 7])
after res5 #2	: torch.Si