In [1]:
import torch

In [2]:
# Build the 5x5 demo input directly in (batch, channels, height, width) layout.
# float32 is required here: without a float dtype, MaxPool2d() was observed to
# fail with RuntimeError: "max_pool2d" not implemented for 'Long'.
a = torch.tensor(
    [
        [1, 2, 0, 3, 1],
        [0, 1, 2, 3, 1],
        [1, 2, 1, 0, 0],
        [5, 2, 3, 1, 1],
        [2, 1, 0, 1, 1],
    ],
    dtype=torch.float32,
).reshape(-1, 1, 5, 5)

a.shape

torch.Size([1, 1, 5, 5])

In [13]:
class MyModule(torch.nn.Module):
    """Minimal module that wraps a single 3x3 max-pooling layer.

    The stride is not given, so MaxPool2d falls back to its default (the
    kernel size, i.e. 3). ``ceil_mode=True`` rounds the output size up, so
    partial windows at the right/bottom border still produce an output value.
    """

    def __init__(self):
        super().__init__()
        self.maxpool_00 = torch.nn.MaxPool2d(kernel_size=3, ceil_mode=True)

    def forward(self, input):
        """Apply the max-pooling layer to ``input`` and return the result."""
        return self.maxpool_00(input)


# Instantiate the pooling module and run it on the 5x5 demo tensor `a`.
# The bare call on the last line lets the notebook display the pooled result.
module = MyModule()
module(a)

tensor([[[[2., 3.],
          [5., 1.]]]])

## 演示最大池化的作用

卷积——提取特征；池化——压缩特征。

In [10]:
import torchvision

# CIFAR-10 files are stored under the sibling 05-Transforms directory.
DATA_DIR = '../05-Transforms/data/'

# The only preprocessing step is converting the images to tensors.
transformer = torchvision.transforms.Compose([torchvision.transforms.ToTensor()])

# CIFAR-10 test split (train=False); download=True fetches it into DATA_DIR
# when it is not already present there.
test_data = torchvision.datasets.CIFAR10(
    DATA_DIR,
    train=False,
    transform=transformer,
    download=True,
)

# Shuffled batches of 64 images; drop_last=False keeps the final, smaller batch.
test_loader = torch.utils.data.DataLoader(
    test_data,
    batch_size=64,
    shuffle=True,
    num_workers=4,
    drop_last=False,
)

Files already downloaded and verified


In [11]:
from torch.utils.tensorboard import SummaryWriter

# TensorBoard event files for this notebook are written under ./logs.
writer = SummaryWriter(log_dir='./logs')

In [14]:
# Run every test batch through the max-pooling module and log the input and
# output image grids to TensorBoard.
#
# Shapes: with batch_size=64 and 32x32 RGB images, a full batch of `imgs` is
# [64, 3, 32, 32]. Max pooling (kernel_size=3, default stride 3, ceil_mode=True)
# keeps the channel count and shrinks each image to 11x11, so `outputs` is
# [64, 3, 11, 11]. Pooling does not change the number of channels, unlike a
# convolution layer.
try:
    for i, data in enumerate(test_loader):
        imgs, targets = data
        print(imgs.shape)
        outputs = module(imgs)
        print(outputs.shape)

        # The batch index is used as the global step, so each batch gets its
        # own entry in the TensorBoard image slider.
        writer.add_images('Input', imgs, i)
        writer.add_images('Output', outputs, i)
finally:
    # Close the writer even if a batch fails midway, so events logged so far
    # are flushed to disk and the event file is not left open.
    writer.close()
# To inspect the results, run: tensorboard --logdir='logs'

torch.Size([64, 3, 32, 32])
torch.Size([64, 3, 11, 11])
torch.Size([64, 3, 32, 32])
torch.Size([64, 3, 11, 11])
torch.Size([64, 3, 32, 32])
torch.Size([64, 3, 11, 11])
torch.Size([64, 3, 32, 32])
torch.Size([64, 3, 11, 11])
torch.Size([64, 3, 32, 32])
torch.Size([64, 3, 11, 11])
torch.Size([64, 3, 32, 32])
torch.Size([64, 3, 11, 11])
torch.Size([64, 3, 32, 32])
torch.Size([64, 3, 11, 11])
torch.Size([64, 3, 32, 32])
torch.Size([64, 3, 11, 11])
torch.Size([64, 3, 32, 32])
torch.Size([64, 3, 11, 11])
torch.Size([64, 3, 32, 32])
torch.Size([64, 3, 11, 11])
torch.Size([64, 3, 32, 32])
torch.Size([64, 3, 11, 11])
torch.Size([64, 3, 32, 32])
torch.Size([64, 3, 11, 11])
torch.Size([64, 3, 32, 32])
torch.Size([64, 3, 11, 11])
torch.Size([64, 3, 32, 32])
torch.Size([64, 3, 11, 11])
torch.Size([64, 3, 32, 32])
torch.Size([64, 3, 11, 11])
torch.Size([64, 3, 32, 32])
torch.Size([64, 3, 11, 11])
torch.Size([64, 3, 32, 32])
torch.Size([64, 3, 11, 11])
torch.Size([64, 3, 32, 32])
torch.Size([64, 3, 1