In [1]:
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torchvision.datasets
import torchvision.transforms as transforms
import torchvision.utils as vutils

In [2]:
class Generator(nn.Module):
    """Transposed-convolution generator producing one-channel square images.

    A latent vector is reshaped to ``(batch, z_dim, 1, 1)`` and upsampled by
    three transposed convolutions: ``1 -> base -> 2*base -> 4*base`` pixels
    per side, where ``base = sqrt(out_dim) / 4``. The final ``tanh`` keeps
    every output value in ``[-1, 1]``.

    Args:
        z_dim: Number of latent features per sample.
        out_dim: Total number of pixels of the generated image
            (``side * side``). Must be a perfect square whose side is a
            positive multiple of 4, because the two stride-2 layers each
            double the spatial size. The default 784 yields 28x28 images.

    Raises:
        ValueError: If ``out_dim`` cannot be produced by this architecture.
    """

    def __init__(self, z_dim=100, out_dim=784):
        super().__init__()

        # Previously ``out_dim`` was accepted but ignored (the first kernel was
        # hard-coded to 7). Derive the kernel from it so the argument is honored.
        side = int(round(out_dim ** 0.5)) if out_dim > 0 else 0
        if side == 0 or side * side != out_dim or side % 4 != 0:
            raise ValueError(
                "out_dim must be a perfect square whose side is a positive "
                "multiple of 4, got {}".format(out_dim)
            )
        base = side // 4  # spatial size after the first layer (7 for out_dim=784)

        # 1x1 -> base x base: with stride 1 and no padding the output size
        # equals the kernel size.
        self.deconv1 = nn.ConvTranspose2d(z_dim, 256, kernel_size=base, stride=1, padding=0)
        self.nonlin1 = nn.ReLU()
        # kernel 4 / stride 2 / padding 1 exactly doubles the spatial size.
        self.deconv2 = nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1)
        self.nonlin2 = nn.ReLU()
        self.deconv3 = nn.ConvTranspose2d(128, 1, kernel_size=4, stride=2, padding=1)

    def forward(self, x):
        """Generate images from latent input ``x``.

        Args:
            x: Tensor whose first dimension is the batch; all remaining
                dimensions are folded into the channel axis and must total
                ``z_dim`` elements per sample.

        Returns:
            Tensor of shape ``(batch, 1, side, side)`` with values in ``[-1, 1]``.
        """
        x = x.view(x.size(0), -1, 1, 1)  # reshape input to (batch_size, z_dim, 1, 1)
        h = self.nonlin1(self.deconv1(x))
        h = self.nonlin2(self.deconv2(h))
        out = torch.tanh(self.deconv3(h))
        return out

class Discriminator(nn.Module):
    """Convolutional scorer for one-channel square images.

    Two stride-2 convolutions each halve the spatial size (rounding down),
    the resulting 128-channel feature map is flattened, and a single linear
    unit followed by a sigmoid yields one score in ``[0, 1]`` per sample.

    Args:
        inp_dim: Total number of pixels of the input image (``side * side``).
            Must be a perfect square with ``side >= 4`` so that at least one
            spatial position survives both convolutions. The default 784
            corresponds to 28x28 inputs.

    Raises:
        ValueError: If ``inp_dim`` is not a perfect square or is too small.
    """

    def __init__(self, inp_dim=784):
        super().__init__()

        # Previously ``inp_dim`` was ignored and the linear layer was fixed at
        # 7 * 7 * 128 features; size it from the argument instead.
        side = int(round(inp_dim ** 0.5)) if inp_dim > 0 else 0
        if side < 4 or side * side != inp_dim:
            raise ValueError(
                "inp_dim must be a perfect square with side >= 4, got {}".format(inp_dim)
            )
        # kernel 4 / stride 2 / padding 1 maps n -> n // 2, so two layers give n // 4.
        feat_side = side // 4

        self.conv1 = nn.Conv2d(1, 64, kernel_size=4, stride=2, padding=1)
        self.nonlin1 = nn.LeakyReLU(0.2)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1)
        self.nonlin2 = nn.LeakyReLU(0.2)
        self.fc = nn.Linear(feat_side * feat_side * 128, 1)

    def forward(self, x):
        """Score a batch of images.

        Args:
            x: Tensor of shape ``(batch, 1, side, side)`` matching the
                ``inp_dim`` given at construction.

        Returns:
            Tensor of shape ``(batch, 1)`` with values in ``[0, 1]``.
        """
        h = self.nonlin1(self.conv1(x))
        h = self.nonlin2(self.conv2(h))
        # flatten to (batch_size, num_features); unlike .view this also
        # accepts a non-contiguous feature map
        h = torch.flatten(h, 1)
        out = torch.sigmoid(self.fc(h))
        return out

In [3]:
# Build both networks with their default sizes (discriminator first, then generator).
D, G = Discriminator(), Generator()

In [4]:
# Show the layer structure of each network, discriminator first.
print(D, G, sep="\n")

Discriminator(
  (conv1): Conv2d(1, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (nonlin1): LeakyReLU(negative_slope=0.2)
  (conv2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (nonlin2): LeakyReLU(negative_slope=0.2)
  (fc): Linear(in_features=6272, out_features=1, bias=True)
)
Generator(
  (deconv1): ConvTranspose2d(100, 256, kernel_size=(7, 7), stride=(1, 1))
  (nonlin1): ReLU()
  (deconv2): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
  (nonlin2): ReLU()
  (deconv3): ConvTranspose2d(128, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
)
