In [1]:
import torch
from torch import nn
from torchvision import models

# ConvNeXt

In [2]:
# Reference model: ConvNeXt-Tiny with torchvision's default pretrained weights
# (fetched into the torch hub cache the first time this cell runs).
model = models.convnext_tiny(
    weights=models.ConvNeXt_Tiny_Weights.DEFAULT,
)

Downloading: "https://download.pytorch.org/models/convnext_tiny-983f1562.pth" to C:\Users\Administrator/.cache\torch\hub\checkpoints\convnext_tiny-983f1562.pth
100%|██████████| 109M/109M [00:09<00:00, 11.7MB/s] 


In [3]:
# List the top-level submodules of the model.
# The module half of each pair goes into `_module`, not `_`: assigning `_` in a
# notebook shadows IPython's last-output variable and stops it from updating.
for name, _module in model.named_children():
    print(name)

features
avgpool
classifier


In [4]:
# Drop the classification head; only `model.features` is used from here on.
# The hasattr guard keeps the cell idempotent: re-running it after the heads
# are gone is a no-op instead of an AttributeError.
for head_name in ("avgpool", "classifier"):
    if hasattr(model, head_name):
        delattr(model, head_name)

In [6]:
# Display the layer stack that remains once the head has been removed.
model.get_submodule("features")

Sequential(
  (0): Conv2dNormActivation(
    (0): Conv2d(3, 96, kernel_size=(4, 4), stride=(4, 4))
    (1): LayerNorm2d((96,), eps=1e-06, elementwise_affine=True)
  )
  (1): Sequential(
    (0): CNBlock(
      (block): Sequential(
        (0): Conv2d(96, 96, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=96)
        (1): Permute()
        (2): LayerNorm((96,), eps=1e-06, elementwise_affine=True)
        (3): Linear(in_features=96, out_features=384, bias=True)
        (4): GELU(approximate='none')
        (5): Linear(in_features=384, out_features=96, bias=True)
        (6): Permute()
      )
      (stochastic_depth): StochasticDepth(p=0.0, mode=row)
    )
    (1): CNBlock(
      (block): Sequential(
        (0): Conv2d(96, 96, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=96)
        (1): Permute()
        (2): LayerNorm((96,), eps=1e-06, elementwise_affine=True)
        (3): Linear(in_features=96, out_features=384, bias=True)
        (4): GELU(approximate='none')

In [10]:
# Shape check: push a dummy 224x224 image through `features` group by group.
# The printed spatial sizes 56/28/14/7 correspond to strides 4/8/16/32;
# feat1, feat2 and feat3 keep the last three group outputs.
x = torch.ones(1, 3, 224, 224)
print(x.size())
model.eval()
with torch.inference_mode():
    x = model.features[0:2](x)
    print(x.size())
    feat1 = x = model.features[2:4](x)
    print(feat1.size())
    feat2 = x = model.features[4:6](x)
    print(feat2.size())
    feat3 = x = model.features[6:](x)
    print(feat3.size())

torch.Size([1, 3, 224, 224])
torch.Size([1, 96, 56, 56])
torch.Size([1, 192, 28, 28])
torch.Size([1, 384, 14, 14])
torch.Size([1, 768, 7, 7])


In [11]:
# Shapes of the three retained feature maps, on one line.
print(*(feat.size() for feat in (feat1, feat2, feat3)))

torch.Size([1, 192, 28, 28]) torch.Size([1, 384, 14, 14]) torch.Size([1, 768, 7, 7])


# ConvNeXt backbone

In [14]:
class ConvNeXt(nn.Module):
    """ConvNeXt backbone that returns three multi-scale feature maps.

    Wraps a torchvision ConvNeXt classifier, deletes its ``avgpool`` and
    ``classifier`` heads and splits ``features`` into three consecutive
    slices (``[:4]``, ``[4:6]``, ``[6:]``). For a 640x640 input the outputs
    are 80x80, 40x40 and 20x20, i.e. strides 8, 16 and 32.
    """

    def __init__(self, variance: str = "convnext_tiny", pretrained: bool = True) -> None:
        """Build the backbone.

        Args:
            variance (str, optional): ConvNeXt variant (the argument name means
                "variant"), one of "convnext_tiny", "convnext_small",
                "convnext_base", "convnext_large". Defaults to "convnext_tiny".
            pretrained (bool, optional): if True, load torchvision's DEFAULT
                weights for the variant; otherwise build it with
                ``weights=None``. Defaults to True.

        Raises:
            ValueError: if ``variance`` is not one of the supported variants.
        """
        super().__init__()
        if variance == "convnext_tiny":
            weights = models.ConvNeXt_Tiny_Weights.DEFAULT if pretrained else None
            self.model = models.convnext_tiny(weights=weights)
        elif variance == "convnext_small":
            weights = models.ConvNeXt_Small_Weights.DEFAULT if pretrained else None
            self.model = models.convnext_small(weights=weights)
        elif variance == "convnext_base":
            weights = models.ConvNeXt_Base_Weights.DEFAULT if pretrained else None
            self.model = models.convnext_base(weights=weights)
        elif variance == "convnext_large":
            weights = models.ConvNeXt_Large_Weights.DEFAULT if pretrained else None
            self.model = models.convnext_large(weights=weights)
        else:
            # Fail here with a clear message; otherwise the `del` below would
            # raise a confusing AttributeError about a missing `model`.
            raise ValueError(
                f"Unsupported ConvNeXt variant {variance!r}; expected one of "
                "'convnext_tiny', 'convnext_small', 'convnext_base', 'convnext_large'."
            )

        # forward() only uses `features`; the classification head is dead weight.
        del self.model.avgpool
        del self.model.classifier

    def forward(self, x: torch.Tensor) -> "list[torch.Tensor]":
        """Extract three feature maps of increasing stride.

        Args:
            x (torch.Tensor): image batch, expected shape (N, 3, H, W).

        Returns:
            list[torch.Tensor]: ``[feat1, feat2, feat3]``, the outputs of
            ``features[:4]``, ``features[4:6]`` and ``features[6:]``
            (strides 8, 16 and 32 relative to ``x``).
        """
        stages = self.model.features
        feat1 = stages[:4](x)
        feat2 = stages[4:6](feat1)
        feat3 = stages[6:](feat2)
        return [feat1, feat2, feat3]

In [35]:
# Tiny backbone built without pretrained weights, used only for shape checks.
model = ConvNeXt(variance="convnext_tiny", pretrained=False)

In [36]:
# Shape check of the backbone on a 640x640 input. Gradients are not needed,
# so skip autograd bookkeeping, as the earlier inference-mode shape check does.
x = torch.ones(1, 3, 640, 640)
with torch.no_grad():
    feats = model(x)
for feat in feats:
    print(feat.size())

torch.Size([1, 192, 80, 80])
torch.Size([1, 384, 40, 40])
torch.Size([1, 768, 20, 20])


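> ConvNeXt v1 与 v2 对应变体的输出 channel 数量相同。下表列出各变体三个输出特征图的 channel 数，括号内为相对输入的下采样倍数（stride）。

| output (stride) | tiny (t) | small (s) | base (b) | large (l) |
| --------------- | -------- | --------- | -------- | --------- |
| feat1 (8)       | 192      | 192       | 256      | 384       |
| feat2 (16)      | 384      | 384       | 512      | 768       |
| feat3 (32)      | 768      | 768       | 1024     | 1536      |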