In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [2]:
def window_partition(x, window_size):
    """
    对特征图进行分块
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous()
    x = x.view(B, -1, window_size, window_size, C)
    return x

def window_reverse(x, window_size, H, W):
    """
    反向操作，将特征图块还原成完整的特征图
    """
    B, _, _, _, C = x.shape
    x = x.view(B, -1, window_size, window_size, C)
    x = x.permute(0, 1, 3, 2, 4).contiguous()
    x = x.view(B, H, W, C)
    return x

class SwinBlock(nn.Module):
    def __init__(self, in_channels, out_channels, window_size, shift_size):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=window_size, stride=shift_size, padding=window_size//2, groups=out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.conv3 = nn.Conv2d(out_channels, out_channels, kernel_size=1)
        self.bn3 = nn.BatchNorm2d(out_channels)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        identity = x
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = window_partition(x, window_size=2)
        x = self.conv2(x)
        x = self.bn2(x)
        x = self.relu(x)
        x = self.conv3(x)
        x = self.bn3(x)
        x = self.relu(x)
        x = window_reverse(x, window_size=2, H=identity.shape[2], W=identity.shape[3])
        x += identity
        return x


In [3]:
class SwinEncoder(nn.Module):
    def __init__(self, in_channels, hidden_size, num_layers, window_size, shift_size):
        super().__init__()
        self.layers = nn.ModuleList([
            SwinBlock(in_channels, hidden_size, window_size=window_size, shift_size=shift_size)
            for _ in range(num_layers)
        ])

    def forward(self, x):
        for layer in self.layers:
            x = layer(x)
        return x

class SwinDecoder(nn.Module):
    def __init__(self, hidden_size, num_classes):
        super().__init__()
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        x = x.mean(dim=[2, 3])
        x = self.fc(x)
        return x


In [6]:
from torchvision.models import resnet152
import torch.nn as nn
class Flatten(nn.Module):
    def __init__(self):
        super(Flatten, self).__init__()
        self.shape = 0

    def forward(self, x):
        self.shape = torch.prod(torch.tensor(x.shape[1:])).item()
        return x.view(-1, self.shape)

class SelfDefineModel(nn.Module):

    def __init__(self):
        super(SelfDefineModel, self).__init__()
        self.trained_model = resnet152(pretrained=True)  # .to(device)
        self.model1 = nn.Sequential(*list(self.trained_model.children())[:-1],  # 测试一下输出维度[b, 512, 1, 1]
                                    Flatten(),
                                    nn.Linear(2048, 256),
                                    nn.ReLU(inplace=True),
                                    nn.Linear(256, 8),
                                    )

In [7]:
def getModelSize(model):
    param_size = 0
    param_sum = 0
    for param in model.parameters():
        param_size += param.nelement() * param.element_size()
        param_sum += param.nelement()
    buffer_size = 0
    buffer_sum = 0
    for buffer in model.buffers():
        buffer_size += buffer.nelement() * buffer.element_size()
        buffer_sum += buffer.nelement()
    all_size = (param_size + buffer_size) / 1024 / 1024
    print('模型总大小为：{:.3f}MB'.format(all_size))
    return (param_size, param_sum, buffer_size, buffer_sum, all_size)


In [8]:
getModelSize(SelfDefineModel())

模型总大小为：232.205MB


(242877632, 60719408, 606936, 151579, 232.2049789428711)

In [11]:
from torchvision.models import vgg16,resnet152

In [12]:
a = resnet152()

In [10]:
list(a.children())[0][0:3]

Sequential(
  (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
  (1): ReLU(inplace=True)
  (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)

In [16]:
a

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): Bottleneck(
      (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (downsample): Sequential(
        (0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 