# ConvNext in PyTorch

An implementation of ConvNeXt, the architecture introduced in the paper [A ConvNet for the 2020s](https://arxiv.org/pdf/2201.03545.pdf).

In [1]:
import torch
from torch import nn, optim
import torch.nn.functional as F
from torchinfo import summary

In [2]:
import random
import numpy as np
import math
from timm.models.layers import trunc_normal_, DropPath

In [4]:
def conv1x1(inplanes, out_planes, stride=1, bias=False):
    """Build a 1x1 (pointwise) 2-D convolution.

    Args:
        inplanes: Number of input channels.
        out_planes: Number of output channels.
        stride: Convolution stride; defaults to 1.
        bias: Whether the layer carries a learnable bias. Defaults to False,
            so existing callers keep getting a bias-free layer.

    Returns:
        An ``nn.Conv2d`` with ``kernel_size=1``.
    """
    return nn.Conv2d(inplanes, out_planes, kernel_size=1, stride=stride, bias=bias)

In [5]:
class LayerNorm(nn.Module):
    # from https://github.com/facebookresearch/ConvNeXt
    r"""Layer normalisation over the channel axis for two tensor layouts.

    ``data_format`` selects where the channel axis lives:

    * ``"channels_last"`` (default): channels are the final dimension, e.g.
      (batch_size, height, width, channels); delegated to ``F.layer_norm``.
    * ``"channels_first"``: channels are dimension 1 and are followed by two
      more dimensions, e.g. (batch_size, channels, height, width); the
      statistics are computed by hand along dimension 1.

    Args:
        normalized_shape: Number of channels (size of the learnable scale/shift).
        eps: Constant added to the variance before the square root.
        data_format: One of the two layout names above.

    Raises:
        NotImplementedError: If ``data_format`` is any other value.
    """
    def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
        super().__init__()
        # Learnable per-channel scale (starts at 1) and shift (starts at 0).
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if data_format not in ("channels_last", "channels_first"):
            raise NotImplementedError
        # F.layer_norm expects the normalised shape as a tuple.
        self.normalized_shape = (normalized_shape, )

    def forward(self, x):
        if self.data_format == "channels_first":
            # Manual normalisation along dim 1 using the biased variance.
            centered = x - x.mean(1, keepdim=True)
            variance = centered.pow(2).mean(1, keepdim=True)
            normalized = centered / torch.sqrt(variance + self.eps)
            # Reshape (C,) -> (C, 1, 1) so scale/shift broadcast over the
            # two trailing dimensions.
            scale = self.weight[:, None, None]
            shift = self.bias[:, None, None]
            return scale * normalized + shift
        if self.data_format == "channels_last":
            return F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)

In [6]:
class ConvNextBlock(nn.Module):
    """Residual block: 7x7 depthwise conv -> LayerNorm -> 1x1 expand -> GELU -> 1x1 project.

    The branch output is optionally multiplied by a learnable per-channel
    scale (``gamma``), passed through ``DropPath`` and added to the input.

    Args:
        dim: Number of input and output channels.
        drop_path: Value handed to ``DropPath`` for the residual branch.
        layer_scale_init_value: Initial value of every ``gamma`` entry; when
            it is not greater than 0 no ``gamma`` is created and the scaling
            step is skipped.
    """
    def __init__(self, dim, drop_path=0., layer_scale_init_value = 1e-6):
        super().__init__()

        hidden_dim = dim * 4
        # groups=dim makes the 7x7 convolution depthwise; padding=3 keeps the
        # spatial size unchanged.
        depthwise = nn.Conv2d(dim, dim, 7, padding=3, groups=dim)
        norm = LayerNorm(dim, eps=1e-6, data_format="channels_first")
        expand = conv1x1(dim, hidden_dim)
        activation = nn.GELU()
        project = conv1x1(hidden_dim, dim)
        self.layers = nn.Sequential(depthwise, norm, expand, activation, project)

        if layer_scale_init_value > 0:
            self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((dim)),
                                      requires_grad=True)
        else:
            self.gamma = None
        self.drop_path = DropPath(drop_path)

    def forward(self, x):
        branch = self.layers(x)
        if self.gamma is not None:
            # (dim,) -> (1, dim, 1, 1) so the scale broadcasts over batch and space.
            branch = branch.mul(self.gamma.reshape(1, -1, 1, 1))
        return x + self.drop_path(branch)

In [7]:
class ConvNextStage(nn.Module):
    """One resolution stage: a caller-supplied entry layer followed by ``depth`` blocks.

    Args:
        dim: Channel count used for every ``ConvNextBlock`` in the stage.
        depth: Number of ``ConvNextBlock`` modules to stack.
        dw_layer: Module that runs first in the stage (the caller supplies
            the stem or downsampling layers here).
        drop_rate: Passed as ``drop_path`` to every ``ConvNextBlock`` in the
            stage, which in turn hands it to ``DropPath``.
    """
    def __init__(self, dim, depth, dw_layer, drop_rate):
        super(ConvNextStage, self).__init__()

        # Each block receives the stage's drop_rate so the argument actually
        # controls the residual-branch DropPath.
        blocks = [ConvNextBlock(dim, drop_path=drop_rate) for _ in range(depth)]
        self.stage = nn.Sequential(dw_layer, *blocks)

    def forward(self, x):
        return self.stage(x)

In [8]:
class ConvNext(nn.Module):
    """ConvNext classifier: a stack of stages, global average pooling, LayerNorm and a linear head.

    The first stage starts with a 4x4 stride-4 convolution on a 3-channel
    input followed by LayerNorm; every later stage starts with LayerNorm and
    a 2x2 stride-2 convolution that changes the width from ``dims[i-1]`` to
    ``dims[i]``. No custom weight initialisation is applied here; modules
    keep their PyTorch defaults.

    Args:
        n_classes: Output size of the final linear layer.
        depths: Number of ``ConvNextBlock`` modules per stage.
        dims: Channel width per stage; must have as many entries as ``depths``.

    Raises:
        ValueError: If ``dims`` is empty or ``depths`` and ``dims`` differ in length.
    """
    def __init__(self, n_classes=1000,
                 depths=(3, 3, 9, 3), dims=(96, 192, 384, 768)):
        super(ConvNext, self).__init__()

        # Accept any sequence (lists still work) and take private copies.
        depths, dims = list(depths), list(dims)
        if not dims or len(depths) != len(dims):
            raise ValueError(
                "depths and dims must be non-empty and of equal length, got "
                "{} depths and {} dims".format(len(depths), len(dims))
            )

        stages = []

        # One stage per (dim, depth) pair, so the number of stages follows
        # the configuration and always matches dims[-1] used by norm/head.
        for i, (dim, depth) in enumerate(zip(dims, depths)):
            if i == 0:
                # Stem: 4x4 stride-4 convolution, then channel-wise LayerNorm.
                dw_layer = nn.Sequential(
                    nn.Conv2d(3, dim, 4, stride=4),
                    LayerNorm(dim, eps=1e-6, data_format="channels_first")
                )
            else:
                # Downsampling: LayerNorm, then 2x2 stride-2 convolution.
                dw_layer = nn.Sequential(
                    LayerNorm(dims[i-1], eps=1e-6, data_format="channels_first"),
                    nn.Conv2d(dims[i-1], dim, 2, stride=2)
                )

            stage = ConvNextStage(dim, depth, dw_layer, drop_rate=0.0)
            stages.append(stage)

        self.stages = nn.Sequential(*stages)
        self.norm = LayerNorm(dims[-1], eps=1e-6)
        self.head = nn.Linear(dims[-1], n_classes)

    def forward(self, x):
        x = self.stages(x)
        # Global average pooling over the last two (spatial) dimensions.
        x = x.mean([-2, -1])
        x = self.norm(x)
        x = self.head(x)
        return x

In [19]:
def convnext_tiny(n_classes=1000):
    """Return a ConvNext with depths [3, 3, 9, 3] and widths [96, 192, 384, 768].

    ``n_classes`` sets the output size of the classification head.
    """
    return ConvNext(
        n_classes=n_classes,
        depths=[3, 3, 9, 3],
        dims=[96, 192, 384, 768],
    )
def convnext_base(n_classes=1000):
    """Return a ConvNext with depths [3, 3, 27, 3] and widths [128, 256, 512, 1024].

    ``n_classes`` sets the output size of the classification head.
    """
    return ConvNext(
        n_classes=n_classes,
        depths=[3, 3, 27, 3],
        dims=[128, 256, 512, 1024],
    )
def convnext_large(n_classes=1000):
    """Return a ConvNext with depths [3, 3, 27, 3] and widths [192, 384, 768, 1536].

    ``n_classes`` sets the output size of the classification head.
    """
    return ConvNext(
        n_classes=n_classes,
        depths=[3, 3, 27, 3],
        dims=[192, 384, 768, 1536],
    )
def convnext_xlarge(n_classes=1000):
    """Return a ConvNext with depths [3, 3, 27, 3] and widths [256, 512, 1024, 2048].

    ``n_classes`` sets the output size of the classification head.
    """
    return ConvNext(
        n_classes=n_classes,
        depths=[3, 3, 27, 3],
        dims=[256, 512, 1024, 2048],
    )

In [20]:
# Instantiate one model per variant, built in the order tiny -> base -> large -> xlarge.
cn_tiny, cn_base, cn_large, cn_xlarge = (
    convnext_tiny(),
    convnext_base(),
    convnext_large(),
    convnext_xlarge(),
)

In [21]:
# Gather the four variants in one list so the reporting cells below can loop over them.
models = [cn_tiny, cn_base, cn_large, cn_xlarge]

In [22]:
%%time
# Smoke test: push one random 3x224x224 image through the tiny model and
# print the output shape; %%time reports how long the cell takes.
inp = torch.randn(1, 3, 224, 224)
# no_grad: inference only, so autograd bookkeeping is skipped.
with torch.no_grad():
    print(cn_tiny(inp).shape)

torch.Size([1, 1000])
CPU times: user 264 ms, sys: 59.6 ms, total: 323 ms
Wall time: 187 ms


In [23]:
import os
def print_size_of_model(model):
    """Print the on-disk size of ``model``'s serialized state dict in MB (1 MB = 1e6 bytes).

    The state dict is saved into a private temporary directory, so nothing
    in the current working directory is created, overwritten or removed,
    and the temporary file is cleaned up even if measuring or printing fails.

    Args:
        model: Any object exposing ``state_dict()`` (e.g. an ``nn.Module``).
    """
    import tempfile  # local import: only this helper needs it

    with tempfile.TemporaryDirectory() as tmp_dir:
        # NOTE(review): the "temp.p" basename is kept on the assumption that
        # torch may embed the file stem in the archive, which would keep the
        # reported size comparable to a plain "temp.p" save — confirm.
        path = os.path.join(tmp_dir, "temp.p")
        torch.save(model.state_dict(), path)
        print('Size (MB):', os.path.getsize(path)/1e6)

In [24]:
# Report the serialized state-dict size of each variant, in list order.
for m in models:
    print_size_of_model(m)

Size (MB): 114.275869
Size (MB): 354.103137
Size (MB): 790.626145
Size (MB): 1400.164257


In [25]:
def fmat(n):
    """Format a count in millions with two decimals, e.g. 1_500_000 -> "1.50M"."""
    millions = n / 1_000_000
    return f"{millions:.2f}M"

In [26]:
def params(model, f=True):
    """Count the trainable parameters of ``model``.

    Only parameters with ``requires_grad`` set are counted.

    Args:
        model: Object exposing ``parameters()`` (e.g. an ``nn.Module``).
        f: When truthy, return the count formatted by ``fmat``; otherwise
            return the raw integer.
    """
    trainable = 0
    for p in model.parameters():
        if p.requires_grad:
            trainable += p.numel()
    if f:
        return fmat(trainable)
    return trainable

In [27]:
# Print the trainable-parameter count (formatted in millions) of each variant, in list order.
for m in models:
    print(params(m))

28.56M
88.50M
197.63M
350.02M


In [18]:
# torchinfo layer-by-layer summary of the tiny model for one 3x224x224 input,
# expanding nested modules down to depth 4.
summary(cn_tiny, (1, 3, 224, 224), depth=4)

Layer (type:depth-idx)                             Output Shape              Param #
ConvNext                                           --                        --
├─Sequential: 1-1                                  [1, 768, 7, 7]            --
│    └─ConvNextStage: 2-1                          [1, 96, 56, 56]           --
│    │    └─Sequential: 3-1                        [1, 96, 56, 56]           --
│    │    │    └─Sequential: 4-1                   [1, 96, 56, 56]           4,896
│    │    │    └─ConvNextBlock: 4-2                [1, 96, 56, 56]           78,816
│    │    │    └─ConvNextBlock: 4-3                [1, 96, 56, 56]           78,816
│    │    │    └─ConvNextBlock: 4-4                [1, 96, 56, 56]           78,816
│    └─ConvNextStage: 2-2                          [1, 192, 28, 28]          --
│    │    └─Sequential: 3-2                        [1, 192, 28, 28]          --
│    │    │    └─Sequential: 4-5                   [1, 192, 28, 28]          74,112
│    │    │    └