In [2]:
import torch
import torch.nn as nn

#### Генератор и дискриминатор строим по архитектуре из статьи [DCGAN (Radford et al., 2015)](https://arxiv.org/abs/1511.06434)
<br>

![Схема генератора из статьи](./images/generator_scheme.png)

In [3]:
class Discriminator(nn.Module):
    """DCGAN-style convolutional discriminator.

    Layout: one strided 4x4 convolution + LeakyReLU (no BatchNorm on the input
    layer), three strided Conv -> BatchNorm -> LeakyReLU blocks, then a final
    4x4 convolution to a single channel followed by a sigmoid. Each strided
    layer (kernel 4, stride 2, padding 1) halves H and W, so a 64x64 input is
    reduced to 4x4 and the final convolution yields an output of shape
    ``(N, 1, 1, 1)``. Other input sizes give a different spatial output size.

    The output is already passed through ``nn.Sigmoid``, i.e. it holds
    probabilities, so it should be paired with a loss that expects
    probabilities (e.g. ``nn.BCELoss``), not one that expects logits.

    Args:
        channels: number of channels of the input images.
        features: width of the first convolution; the following blocks use
            2x, 4x and 8x this value. The default (128) reproduces the original
            128-256-512-1024 stack.

    Raises:
        ValueError: if ``features`` is not positive.
    """

    def __init__(self, channels, features=128):
        super(Discriminator, self).__init__()
        if features < 1:
            raise ValueError(f"features must be a positive integer, got {features}")
        # Module order/nesting is kept identical to the original layout so that
        # state_dict keys (discriminator.0, discriminator.2.0, ...) stay compatible.
        self.discriminator = nn.Sequential(
            # Input layer: plain conv + LeakyReLU, no BatchNorm.
            nn.Conv2d(
                in_channels=channels,
                out_channels=features,
                kernel_size=(4, 4),
                stride=(2, 2),
                padding=1,
            ),
            nn.LeakyReLU(0.2),
            self._conv_block(features, features * 2, 4, 2, 1),
            self._conv_block(features * 2, features * 4, 4, 2, 1),
            self._conv_block(features * 4, features * 8, 4, 2, 1),
            # Collapses the remaining 4x4 map (for a 64x64 input) to one score per image.
            nn.Conv2d(in_channels=features * 8, out_channels=1, kernel_size=(4, 4), stride=(1, 1), padding=0),
            nn.Sigmoid(),
        )

    def _conv_block(self, in_channels, out_channels, kernel, stride, padding):
        """Build a Conv2d -> BatchNorm2d -> LeakyReLU(0.2) block.

        The convolution has ``bias=False`` because the following BatchNorm
        layer has its own learnable shift.
        """
        return nn.Sequential(
            nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel,
                stride=stride,
                padding=padding,
                bias=False,
            ),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2),
        )

    def forward(self, x):
        """Return per-image probabilities; shape (N, 1, 1, 1) for (N, channels, 64, 64) input."""
        return self.discriminator(x)

In [60]:
class Generator(nn.Module):
    """DCGAN-style generator mapping latent vectors to images.

    Layout: four ConvTranspose2d -> BatchNorm2d -> ReLU blocks followed by an
    output ConvTranspose2d (no BatchNorm, with bias) and Tanh. The first block
    (kernel 4, stride 1, padding 0) expands a 1x1 input to 4x4; every later
    layer (kernel 4, stride 2, padding 1) doubles H and W, so an input of shape
    ``(N, z_dim, 1, 1)`` produces images of shape ``(N, channels, 64, 64)``
    with values in [-1, 1].

    Args:
        channels: number of channels of the generated images.
        z_dim: length of the latent vector (channel count of the 1x1 input).
        features: width of the last hidden block; earlier blocks use 8x, 4x
            and 2x this value. The default (128) reproduces the original
            1024-512-256-128 stack.

    Raises:
        ValueError: if ``features`` is not positive.
    """

    def __init__(self, channels, z_dim, features=128):
        super(Generator, self).__init__()
        if features < 1:
            raise ValueError(f"features must be a positive integer, got {features}")
        # Module order/nesting matches the original layout so state_dict keys
        # (generator.0.0, ..., generator.4) stay compatible.
        self.generator = nn.Sequential(
            # 1x1 -> 4x4
            self._unconv_block(in_channels=z_dim, out_channels=features * 8, kernel=(4, 4), stride=(1, 1), padding=0),
            # 4x4 -> 8x8 -> 16x16 -> 32x32
            self._unconv_block(in_channels=features * 8, out_channels=features * 4, kernel=(4, 4), stride=(2, 2), padding=1),
            self._unconv_block(in_channels=features * 4, out_channels=features * 2, kernel=(4, 4), stride=(2, 2), padding=1),
            self._unconv_block(in_channels=features * 2, out_channels=features, kernel=(4, 4), stride=(2, 2), padding=1),
            # Output layer, 32x32 -> 64x64: no BatchNorm, so the conv keeps its bias.
            nn.ConvTranspose2d(
                in_channels=features,
                out_channels=channels,
                kernel_size=(4, 4),
                stride=(2, 2),
                padding=1,
            ),
            nn.Tanh(),
        )

    def _unconv_block(self, in_channels, out_channels, kernel, stride, padding):
        """Build a ConvTranspose2d -> BatchNorm2d -> ReLU block.

        The transposed convolution has ``bias=False`` because the following
        BatchNorm layer has its own learnable shift.
        """
        return nn.Sequential(
            nn.ConvTranspose2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel,
                stride=stride,
                padding=padding,
                bias=False,
            ),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(),
        )

    def forward(self, x):
        """Map (N, z_dim, 1, 1) latents to (N, channels, 64, 64) images in [-1, 1]."""
        return self.generator(x)

In [61]:
# Generator producing 3-channel images from 100-dimensional latent vectors.
a = Generator(channels=3, z_dim=100)

In [62]:
# Dummy latent batch: 32 vectors of length 100 (= z_dim), shaped as 1x1 feature maps.
test_batch = torch.ones(32, 100, 1, 1)

In [63]:
# Expected: torch.Size([32, 100, 1, 1])
test_batch.size()

torch.Size([32, 100, 1, 1])

In [64]:
# Call the module itself rather than .forward() so nn.Module hooks run;
# no_grad because a shape check does not need an autograd graph.
# NOTE: `a` is still in training mode, so this call updates its BatchNorm running statistics.
with torch.no_grad():
    fake_images = a(test_batch)
fake_images.shape

torch.Size([32, 3, 64, 64])