In [15]:
import torch
import torch.nn as nn
from torchinfo import summary

In [16]:
class Generator(nn.Module):
    """DCGAN-style generator that maps latent (random) vectors to images.

    Input shape:  (N, in_dim), where N is the batch size and in_dim is the
                  dimensionality of the random vector.
    Output shape: (N, img_channels, 64, 64), i.e. N generated 64x64 images
                  (colour images with the default img_channels=3). Pixel
                  values lie in [-1, 1] because of the final Tanh.

    Args:
        in_dim: size of the latent vector fed to the network.
        dim: base channel width. The feature maps use dim*8, dim*4, dim*2 and
            dim channels while growing from 4x4 up to 32x32.
        img_channels: number of channels in the generated image (3 for RGB,
            1 for grayscale). Defaults to 3, which matches the original
            fixed behaviour.
    """

    def __init__(self, in_dim, dim=64, img_channels=3):
        super().__init__()

        def dconv_bn_relu(in_dim, out_dim):
            # kernel=5, stride=2, padding=2, output_padding=1 exactly doubles the
            # spatial size: (H - 1) * 2 - 2 * 2 + (5 - 1) + 1 + 1 = 2H.
            return nn.Sequential(
                # bias=False: presumably because the following BatchNorm subtracts
                # the per-channel mean, which would cancel a conv bias anyway.
                nn.ConvTranspose2d(in_dim, out_dim, 5, 2,
                                   padding=2, output_padding=1, bias=False),
                nn.BatchNorm2d(out_dim),
                nn.ReLU()
            )

        # 1. Project the random vector with a linear layer into dim*8 feature
        #    maps of size 4x4 (flattened here; reshaped in forward()).
        self.l1 = nn.Sequential(
            nn.Linear(in_dim, dim * 8 * 4 * 4, bias=False),
            nn.BatchNorm1d(dim * 8 * 4 * 4),
            nn.ReLU()
        )

        # 2. Repeatedly apply transposed convolutions. Each one doubles the
        #    spatial size (4 -> 8 -> 16 -> 32 -> 64) and reduces the channel
        #    count, ending in an img_channels x 64 x 64 image.
        self.l2_5 = nn.Sequential(
            dconv_bn_relu(dim * 8, dim * 4),
            dconv_bn_relu(dim * 4, dim * 2),
            dconv_bn_relu(dim * 2, dim),
            nn.ConvTranspose2d(dim, img_channels, 5, 2, padding=2, output_padding=1),
            nn.Tanh()
        )

    def forward(self, x):
        """Generate a batch of images from latent vectors x of shape (N, in_dim)."""
        y = self.l1(x)
        # Unflatten the linear projection into (N, dim*8, 4, 4) feature maps.
        y = y.view(y.size(0), -1, 4, 4)
        y = self.l2_5(y)
        return y

In [17]:
# Generator fed with 100-dimensional random vectors.
net = Generator(in_dim=100)

# input_size is (batch_size, latent_dim) -> here both happen to be 100.
summary(net, input_size=(100, 100))

  action_fn=lambda data: sys.getsizeof(data.storage()),
  return super().__sizeof__() + self.nbytes()


Layer (type:depth-idx)                   Output Shape              Param #
Generator                                [100, 3, 64, 64]          --
├─Sequential: 1-1                        [100, 8192]               --
│    └─Linear: 2-1                       [100, 8192]               819,200
│    └─BatchNorm1d: 2-2                  [100, 8192]               16,384
│    └─ReLU: 2-3                         [100, 8192]               --
├─Sequential: 1-2                        [100, 3, 64, 64]          --
│    └─Sequential: 2-4                   [100, 256, 8, 8]          --
│    │    └─ConvTranspose2d: 3-1         [100, 256, 8, 8]          3,276,800
│    │    └─BatchNorm2d: 3-2             [100, 256, 8, 8]          512
│    │    └─ReLU: 3-3                    [100, 256, 8, 8]          --
│    └─Sequential: 2-5                   [100, 128, 16, 16]        --
│    │    └─ConvTranspose2d: 3-4         [100, 128, 16, 16]        819,200
│    │    └─BatchNorm2d: 3-5             [100, 128, 16, 16]    

In [18]:
class Discriminator(nn.Module):
    """DCGAN-style discriminator that scores images as real or fake.

    Input shape:  (N, in_dim, 64, 64), i.e. N 64x64 images (colour images with
                  the default in_dim=3).
    Output shape: (N,), one score per image in [0, 1] from the final Sigmoid.
                  The closer a score is to 1, the more the discriminator
                  believes that image is real.

    Args:
        in_dim: number of channels of the input images (3 for RGB). This is
            the image channel count, not a latent dimension.
        dim: base channel width. The feature maps use dim, dim*2, dim*4 and
            dim*8 channels as the image shrinks 64 -> 32 -> 16 -> 8 -> 4.

    Raises (from forward):
        ValueError: if the input size does not reduce to a 1x1 score map
            (e.g. a 128x128 image), instead of silently returning a vector
            longer than the batch.
    """

    def __init__(self, in_dim=3, dim=64):  # in_dim is the image channel count, hence 3
        super().__init__()

        def conv_bn_lrelu(in_dim, out_dim):
            # kernel=5, stride=2, padding=2 halves an even spatial size.
            return nn.Sequential(
                nn.Conv2d(in_dim, out_dim, 5, 2, 2),
                nn.BatchNorm2d(out_dim),
                nn.LeakyReLU(0.2),
            )

        # A stack of strided convolutions that reduces the image to a single
        # number per sample. The first layer has no BatchNorm.
        self.ls = nn.Sequential(
            nn.Conv2d(in_dim, dim, 5, 2, 2),
            nn.LeakyReLU(0.2),
            conv_bn_lrelu(dim, dim * 2),
            conv_bn_lrelu(dim * 2, dim * 4),
            conv_bn_lrelu(dim * 4, dim * 8),
            # Unpadded 4x4 kernel: a 4x4 feature map becomes 1x1.
            nn.Conv2d(dim * 8, 1, 4),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Return one realness score per image in x, as a tensor of shape (N,)."""
        y = self.ls(x)
        # Only a 1x1 score map flattens to (N,). Any other input size would make
        # view(-1) silently mix spatial positions into the batch dimension.
        if y.shape[1:] != (1, 1, 1):
            raise ValueError(
                f"Discriminator expects 64x64 inputs (score map of size 1x1); "
                f"got input shape {tuple(x.shape)} -> score map {tuple(y.shape)}"
            )
        y = y.view(-1)
        return y

In [19]:
# Discriminator with its defaults: 3 input channels, base width 64.
d_net = Discriminator()
# input_size is (batch_size, channels, height, width).
summary(d_net, input_size=(100, 3, 64, 64))

Layer (type:depth-idx)                   Output Shape              Param #
Discriminator                            [100]                     --
├─Sequential: 1-1                        [100, 1, 1, 1]            --
│    └─Conv2d: 2-1                       [100, 64, 32, 32]         4,864
│    └─LeakyReLU: 2-2                    [100, 64, 32, 32]         --
│    └─Sequential: 2-3                   [100, 128, 16, 16]        --
│    │    └─Conv2d: 3-1                  [100, 128, 16, 16]        204,928
│    │    └─BatchNorm2d: 3-2             [100, 128, 16, 16]        256
│    │    └─LeakyReLU: 3-3               [100, 128, 16, 16]        --
│    └─Sequential: 2-4                   [100, 256, 8, 8]          --
│    │    └─Conv2d: 3-4                  [100, 256, 8, 8]          819,456
│    │    └─BatchNorm2d: 3-5             [100, 256, 8, 8]          512
│    │    └─LeakyReLU: 3-6               [100, 256, 8, 8]          --
│    └─Sequential: 2-5                   [100, 512, 4, 4]          --
