In [17]:
from typing import List

import torch
from torch import nn
from torchvision import models

# Swin

In [2]:
# Build torchvision's Swin-T classifier with its DEFAULT pretrained weights.
model = models.swin_t(
    weights=models.Swin_T_Weights.DEFAULT,
)

In [3]:
# Show the top-level submodules, one name per line, to see what surrounds `features`.
print(*(name for name, _ in model.named_children()), sep="\n")

features
norm
permute
avgpool
flatten
head


In [4]:
# Drop every top-level child except `features`, leaving only the feature extractor.
for tail_name in ("norm", "permute", "avgpool", "flatten", "head"):
    delattr(model, tail_name)

In [5]:
# Display the `features` Sequential, the only child left after the deletions above.
model.features

Sequential(
  (0): Sequential(
    (0): Conv2d(3, 96, kernel_size=(4, 4), stride=(4, 4))
    (1): Permute()
    (2): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
  )
  (1): Sequential(
    (0): SwinTransformerBlock(
      (norm1): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
      (attn): ShiftedWindowAttention(
        (qkv): Linear(in_features=96, out_features=288, bias=True)
        (proj): Linear(in_features=96, out_features=96, bias=True)
      )
      (stochastic_depth): StochasticDepth(p=0.0, mode=row)
      (norm2): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
      (mlp): MLP(
        (0): Linear(in_features=96, out_features=384, bias=True)
        (1): GELU(approximate='none')
        (2): Dropout(p=0.0, inplace=False)
        (3): Linear(in_features=384, out_features=96, bias=True)
        (4): Dropout(p=0.0, inplace=False)
      )
    )
    (1): SwinTransformerBlock(
      (norm1): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
      (attn): Sh

In [6]:
# Push a dummy 224x224 image through `features` in chunks and print the size
# after each chunk. The chunk outputs are channels-last (N, H, W, C); the
# feat* tensors are the same maps permuted to channels-first (N, C, H, W).
x = torch.ones(1, 3, 224, 224)
print(x.size())
model.eval()
with torch.inference_mode():
    # Indices 0-1 of `features` run first; their output is printed but not kept.
    x = model.features[:2](x)
    print(x.size())

    pyramid = []
    for span in (slice(2, 4), slice(4, 6), slice(6, None)):
        x = model.features[span](x)
        pyramid.append(x.permute(0, 3, 1, 2))
        print(x.size())

feat1, feat2, feat3 = pyramid

torch.Size([1, 3, 224, 224])
torch.Size([1, 56, 56, 96])
torch.Size([1, 28, 28, 192])
torch.Size([1, 14, 14, 384])
torch.Size([1, 7, 7, 768])


In [7]:
# Channels-first sizes of the three collected feature maps, on one line.
print(*(level.size() for level in (feat1, feat2, feat3)))

torch.Size([1, 192, 28, 28]) torch.Size([1, 384, 14, 14]) torch.Size([1, 768, 7, 7])


# SwinV2

In [8]:
# Build torchvision's Swin V2-T classifier with its DEFAULT pretrained weights.
model = models.swin_v2_t(
    weights=models.Swin_V2_T_Weights.DEFAULT,
)

In [9]:
# Show the top-level submodules of the V2 model, one name per line.
print(*(name for name, _ in model.named_children()), sep="\n")

features
norm
permute
avgpool
flatten
head


In [10]:
# Drop every top-level child except `features`, leaving only the feature extractor.
for tail_name in ("norm", "permute", "avgpool", "flatten", "head"):
    delattr(model, tail_name)

In [11]:
# Display the `features` Sequential, the only child left after the deletions above.
model.features

Sequential(
  (0): Sequential(
    (0): Conv2d(3, 96, kernel_size=(4, 4), stride=(4, 4))
    (1): Permute()
    (2): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
  )
  (1): Sequential(
    (0): SwinTransformerBlockV2(
      (norm1): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
      (attn): ShiftedWindowAttentionV2(
        (qkv): Linear(in_features=96, out_features=288, bias=True)
        (proj): Linear(in_features=96, out_features=96, bias=True)
        (cpb_mlp): Sequential(
          (0): Linear(in_features=2, out_features=512, bias=True)
          (1): ReLU(inplace=True)
          (2): Linear(in_features=512, out_features=3, bias=False)
        )
      )
      (stochastic_depth): StochasticDepth(p=0.0, mode=row)
      (norm2): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
      (mlp): MLP(
        (0): Linear(in_features=96, out_features=384, bias=True)
        (1): GELU(approximate='none')
        (2): Dropout(p=0.0, inplace=False)
        (3): Linear(in_fe

In [12]:
# Push a dummy 256x256 image through `features` in chunks and print the size
# after each chunk. The chunk outputs are channels-last (N, H, W, C); the
# feat* tensors are the same maps permuted to channels-first (N, C, H, W).
x = torch.ones(1, 3, 256, 256)
print(x.size())
model.eval()
with torch.inference_mode():
    # Indices 0-1 of `features` run first; their output is printed but not kept.
    x = model.features[:2](x)
    print(x.size())

    pyramid = []
    for span in (slice(2, 4), slice(4, 6), slice(6, None)):
        x = model.features[span](x)
        pyramid.append(x.permute(0, 3, 1, 2))
        print(x.size())

feat1, feat2, feat3 = pyramid

torch.Size([1, 3, 256, 256])
torch.Size([1, 64, 64, 96])
torch.Size([1, 32, 32, 192])
torch.Size([1, 16, 16, 384])
torch.Size([1, 8, 8, 768])


In [13]:
# Channels-first sizes of the three collected feature maps, on one line.
print(*(level.size() for level in (feat1, feat2, feat3)))

torch.Size([1, 192, 32, 32]) torch.Size([1, 384, 16, 16]) torch.Size([1, 768, 8, 8])


# Swin backbone

In [18]:
class Swin(nn.Module):
    """Multi-scale feature extractor built from torchvision's Swin / Swin V2 models.

    The wrapped classifier keeps only its ``features`` stack; the ``norm``,
    ``permute``, ``avgpool``, ``flatten`` and ``head`` children are deleted.
    ``forward`` returns three channels-first feature maps taken after
    ``features[:4]``, ``features[4:6]`` and ``features[6:]``.
    """

    # Supported variant name -> name of the matching weights enum in
    # ``torchvision.models``. The builder function has the same name as the key.
    _WEIGHTS_ENUMS = {
        "swin_t": "Swin_T_Weights",
        "swin_s": "Swin_S_Weights",
        "swin_b": "Swin_B_Weights",
        "swin_v2_t": "Swin_V2_T_Weights",
        "swin_v2_s": "Swin_V2_S_Weights",
        "swin_v2_b": "Swin_V2_B_Weights",
    }

    # Children of the torchvision model that are not needed for feature extraction.
    _CLASSIFIER_TAIL = ("norm", "permute", "avgpool", "flatten", "head")

    def __init__(self, variance: str = "swin_t", pretrained=True) -> None:
        """Build the requested Swin model and strip its classification tail.

        Args:
            variance (str, optional): Model variant, one of
                [swin_t, swin_s, swin_b, swin_v2_t, swin_v2_s, swin_v2_b].
                Defaults to "swin_t".
            pretrained (bool, optional): If truthy, load the variant's
                ``DEFAULT`` torchvision weights; otherwise pass ``weights=None``.
                Defaults to True.

        Raises:
            ValueError: If ``variance`` is not one of the supported names.
        """
        super().__init__()

        # Fail early with a clear message; previously an unknown name left
        # ``self.model`` unset and surfaced as an AttributeError further down.
        if variance not in self._WEIGHTS_ENUMS:
            supported = ", ".join(self._WEIGHTS_ENUMS)
            raise ValueError(
                f"Unsupported Swin variant {variance!r}; expected one of: {supported}"
            )

        builder = getattr(models, variance)
        if pretrained:
            weights = getattr(models, self._WEIGHTS_ENUMS[variance]).DEFAULT
        else:
            weights = None
        self.model = builder(weights=weights)

        # Keep only ``features``; everything after it belongs to the classifier.
        for child_name in self._CLASSIFIER_TAIL:
            delattr(self.model, child_name)

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
        """Extract three feature maps from an image batch.

        Args:
            x (torch.Tensor): Input batch; the notebook feeds (N, 3, H, W) images.

        Returns:
            List[torch.Tensor]: ``[feat1, feat2, feat3]``, each permuted with
            ``(0, 3, 1, 2)`` from the channels-last layout produced by
            ``features`` to channels-first. For a 640x640 input to swin_v2_t
            the spatial sizes are 80, 40 and 20.
        """
        stages = self.model.features

        x = stages[:4](x)
        feat1 = x.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)

        x = stages[4:6](x)
        feat2 = x.permute(0, 3, 1, 2)

        x = stages[6:](x)
        feat3 = x.permute(0, 3, 1, 2)

        return [feat1, feat2, feat3]

In [40]:
# Randomly initialised Swin V2-T backbone (no weight download) for the shape check below.
model = Swin(variance="swin_v2_t", pretrained=False)

In [41]:
# Shape check of the backbone on a dummy 640x640 image.
x = torch.ones(1, 3, 640, 640)

# Match the earlier shape-check cells: evaluation mode and no autograd
# bookkeeping, so `feats` does not keep a computation graph alive.
# Note: this leaves `model` in eval mode; call model.train() before training.
model.eval()
with torch.inference_mode():
    feats = model(x)

for feat in feats:
    print(feat.size())

torch.Size([1, 192, 80, 80])
torch.Size([1, 384, 40, 40])
torch.Size([1, 768, 20, 20])


> Swin v1 and v2 output the same number of channels for each model size (t / s / b):

| ch    | t    | s    | b    |
| ----- | ---- | ---- | ---- |
| feat1 | 192  | 192  | 256  |
| feat2 | 384  | 384  | 512  |
| feat3 | 768  | 768  | 1024 |