In [6]:
# 两个相加的feature map的形状必须是一样的

import time
import torch
from torch import nn, optim
import torch.nn.functional as F

import sys
sys.path.append("..")
import PyCode.PyTorch_Learn.d2lzh_pytorch as d2l
# Prefer the GPU when CUDA is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')


# 定义残差块：可以设定输入/输出通道数、是否使用1×1卷积，以及通过步长调整输出的形状
class Residual(nn.Module):
    """Residual block: two 3x3 conv + BatchNorm layers plus a shortcut.

    Main path: conv3x3(stride) -> BN -> ReLU -> conv3x3 -> BN.
    Shortcut: the identity, or a 1x1 conv with the same stride when
    ``use_1x1conv`` is True. The two paths are summed and passed through ReLU.

    Args:
        in_channels: number of channels of the input feature map.
        out_channels: number of channels produced by the block.
        use_1x1conv: if True, project the shortcut with a 1x1 convolution so
            it can match a different channel count and/or stride.
        stride: stride of the first 3x3 conv and of the 1x1 shortcut conv.

    Without the 1x1 projection the block only works when
    ``in_channels == out_channels`` and ``stride == 1``; otherwise the two
    feature maps being added have different shapes.
    """

    def __init__(self, in_channels, out_channels, use_1x1conv=False, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, stride=stride, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, padding=1)
        if use_1x1conv:
            self.conv3 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride)
        else:
            self.conv3 = None
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.bn2 = nn.BatchNorm2d(out_channels)

    def forward(self, X):
        """Apply the block to a 4-D input ``X`` (N, in_channels, H, W).

        Returns:
            ReLU(main_path(X) + shortcut(X)).

        Raises:
            RuntimeError: if the main path and the shortcut have different
                shapes. This is checked explicitly because PyTorch
                broadcasting would otherwise silently accept some mismatches
                (e.g. a 1-channel shortcut added to a 3-channel main path).
        """
        Y = F.relu(self.bn1(self.conv1(X)))
        Y = self.bn2(self.conv2(Y))
        # Compare with None explicitly instead of relying on the truthiness
        # of an nn.Module instance.
        if self.conv3 is not None:
            X = self.conv3(X)
        if Y.shape != X.shape:
            raise RuntimeError(
                f"Residual: main path shape {tuple(Y.shape)} does not match "
                f"shortcut shape {tuple(X.shape)}; use use_1x1conv=True when "
                f"in_channels != out_channels or stride != 1"
            )
        return F.relu(Y + X)


# Fix the RNG so both the block's randomly initialised weights and the
# sample input are reproducible on Restart & Run All.
torch.manual_seed(0)
blk = Residual(3, 3)
X = torch.rand((2, 3, 6, 6))  # (batch, channels, height, width)
print(X)

tensor([[[[0.5267, 0.9650, 0.4840],
          [0.3138, 0.7809, 0.8616],
          [0.1053, 0.4977, 0.0979],
          [0.0635, 0.6376, 0.7701],
          [0.6356, 0.4748, 0.6783],
          [0.8096, 0.5614, 0.5748]],

         [[0.5338, 0.6107, 0.3046],
          [0.1046, 0.2932, 0.4811],
          [0.8798, 0.7367, 0.9893],
          [0.6315, 0.2950, 0.8836],
          [0.2864, 0.1221, 0.0390],
          [0.1969, 0.4689, 0.5067]],

         [[0.0467, 0.0797, 0.6226],
          [0.6337, 0.9890, 0.1672],
          [0.9020, 0.2979, 0.7176],
          [0.5038, 0.8381, 0.2298],
          [0.4558, 0.3417, 0.6988],
          [0.9260, 0.7905, 0.0434]]],


        [[[0.4777, 0.9853, 0.2678],
          [0.9205, 0.0718, 0.7372],
          [0.7431, 0.5045, 0.8241],
          [0.7875, 0.0451, 0.4490],
          [0.2426, 0.3428, 0.3994],
          [0.7265, 0.9255, 0.3337]],

         [[0.8017, 0.1918, 0.6678],
          [0.9714, 0.8726, 0.1209],
          [0.1733, 0.7318, 0.6322],
          [0.5113,

In [7]:
blk(X).size()  # same torch.Size as .shape; expected (2, 3, 6, 6) for a 3->3, stride-1 block



torch.Size([2, 3, 6, 6])