In [1]:
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torchvision import datasets, transforms
from torchvision.utils import save_image

from gan import Discriminator, Generator, G_train, D_train, ConvDiscriminator, ConvGenerator

In [2]:
# Pick the best available accelerator instead of hard-coding Apple's MPS
# backend, so the notebook also runs on CUDA machines and on CPU-only machines.
# `torch.backends.mps` does not exist on older PyTorch builds, hence the getattr.
_mps_backend = getattr(torch.backends, 'mps', None)
if _mps_backend is not None and _mps_backend.is_available():
    device = torch.device('mps')
elif torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

This part was heavily inspired by [this GitHub repository](https://github.com/lyeoni/pytorch-mnist-GAN).

#### 1. Get the dataset

In [3]:
# Number of samples per batch: used for both data loaders, passed to the
# D_train / G_train calls, and used as the number of images sampled at the end.
batch_size = 100

In [4]:
# Preprocessing applied to every MNIST image:
#   1. resize so the image is 64 pixels on its shorter side,
#   2. convert to a float tensor,
#   3. normalize the single channel with mean 0.5 and std 0.5.
# (An earlier Grayscale + ToTensor-only pipeline was replaced by this one.)
preprocessing_steps = [
    transforms.Resize(64),
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
]
transform = transforms.Compose(preprocessing_steps)

In [5]:
# MNIST train and test splits, both stored under the same root directory and
# both run through `transform`. Only the training split requests a download.
mnist_root = './data/mnist_data/'
train_dataset = datasets.MNIST(
    root=mnist_root,
    train=True,
    transform=transform,
    download=True,
)
test_dataset = datasets.MNIST(
    root=mnist_root,
    train=False,
    transform=transform,
    download=False,
)

In [6]:
# Batch iterators over the two splits: training batches are shuffled, test
# batches keep the dataset order.
train_loader = torch.utils.data.DataLoader(
    train_dataset,
    batch_size=batch_size,
    shuffle=True,
)
test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=batch_size,
    shuffle=False,
)

#### 2. Set up the GAN itself

In [7]:
# Size of the latent noise vector fed to the generator.
z_dim = 1000

# Flattened size of one raw MNIST image (height * width). `.data` replaces the
# deprecated `.train_data` alias, which emits a warning on access.
# NOTE(review): this is measured on the raw, un-resized images, not on the
# 64-pixel output of `transform` -- presumably only the fully connected models
# rely on it; confirm in gan.py.
mnist_dim = train_dataset.data.size(1) * train_dataset.data.size(2)

# Fully connected alternative, kept for reference:
# G = Generator(g_input_dim = z_dim, g_output_dim = mnist_dim).to(device)
# D = Discriminator(mnist_dim).to(device)

# Convolutional generator / discriminator working on single-channel images.
G = ConvGenerator(g_input_dim=z_dim, g_output_channels=1).to(device)
D = ConvDiscriminator(d_input_channels=1).to(device)



##### 2.1 Special weight initialization

In [8]:
def weights_init(m):
    """Apply DCGAN-style weight initialization to a single module.

    Intended to be passed to ``nn.Module.apply``, which calls it on every
    submodule of a network.

    * ``nn.Conv2d`` and ``nn.ConvTranspose2d`` weights are drawn from
      N(0, 0.02). ``ConvTranspose2d`` is not a subclass of ``Conv2d``, so it
      must be listed explicitly; otherwise a generator built from transposed
      convolutions silently keeps PyTorch's default initialization.
    * ``nn.BatchNorm2d`` scales are drawn from N(1, 0.02) and shifts are set
      to 0.
    * Any other module is left untouched.

    Args:
        m: The module to (possibly) re-initialize in place.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        nn.init.normal_(m.weight, mean=0.0, std=0.02)

    elif isinstance(m, nn.BatchNorm2d):
        # With affine=False the layer has no learnable scale/shift to initialize.
        if m.weight is not None:
            nn.init.normal_(m.weight, mean=1.0, std=0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias, 0)

In [9]:
# Run `weights_init` over every submodule of the generator, then of the
# discriminator. The loop form also avoids echoing the module repr.
for network in (G, D):
    network.apply(weights_init)

#### 3. Hyperparameters

In [10]:
# Binary cross-entropy loss, handed to the D_train / G_train steps.
criterion = nn.BCELoss()

# Both networks use Adam with the same learning rate and beta coefficients.
lr = 0.0002
adam_betas = (0.5, 0.999)
G_optimizer = optim.Adam(G.parameters(), lr=lr, betas=adam_betas)
D_optimizer = optim.Adam(D.parameters(), lr=lr, betas=adam_betas)

In [11]:
# Display the generator's layer stack as a sanity check of the architecture.
G

ConvGenerator(
  (conv1): ConvTranspose2d(1000, 512, kernel_size=(4, 4), stride=(1, 1), bias=False)
  (batch1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv2): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
  (batch2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv3): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
  (batch3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv4): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
  (batch4): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (conv5): ConvTranspose2d(64, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
)

#### 4. Train the model

In [12]:
n_epoch = 5
for epoch in range(1, n_epoch + 1):
    # Per-batch losses collected over this epoch; their means go into the
    # progress line printed after the last batch.
    D_losses = []
    G_losses = []
    for x, _ in train_loader:
        # Labels are ignored. D_train is called before G_train on every batch.
        d_loss = D_train(D, G, D_optimizer, mnist_dim, z_dim, batch_size, device, criterion, x, use_conv=True)
        D_losses.append(d_loss)
        g_loss = G_train(D, G, G_optimizer, mnist_dim, z_dim, batch_size, device, criterion, x, use_conv=True)
        G_losses.append(g_loss)

    mean_d_loss = torch.FloatTensor(D_losses).mean()
    mean_g_loss = torch.FloatTensor(G_losses).mean()
    print('[%d/%d]: loss_d: %.3f, loss_g: %.3f' % (epoch, n_epoch, mean_d_loss, mean_g_loss))

[1/5]: loss_d: 0.371, loss_g: 26.262
[2/5]: loss_d: 0.747, loss_g: 3.661
[3/5]: loss_d: 0.533, loss_g: 3.055
[4/5]: loss_d: 0.357, loss_g: 4.117


KeyboardInterrupt: 

In [None]:
# Sample a batch of images from the generator and save them as one image grid.
with torch.no_grad():
    # `Variable` is a deprecated no-op wrapper, so the noise is created as a
    # plain tensor directly on the target device.
    # NOTE(review): the noise keeps its original (batch_size, z_dim) shape; this
    # assumes ConvGenerator reshapes it internally -- confirm in gan.py.
    test_z = torch.randn(batch_size, z_dim, device=device)
    generated = G(test_z)

    # Infer the (square, single-channel) image side from the output size rather
    # than hard-coding 28: the training images are resized to 64 pixels and the
    # generator's layer stack upsamples accordingly, so a fixed 28x28 view
    # cannot hold the output and raises a shape error.
    n_samples = generated.size(0)
    side = int(round((generated.numel() // n_samples) ** 0.5))
    images = generated.reshape(n_samples, 1, side, side)

    # save_image does not create missing directories.
    os.makedirs('./samples', exist_ok=True)
    # Training images are normalized with mean 0.5 / std 0.5, so generated
    # values are not confined to [0, 1]; normalize=True rescales them to the
    # displayable range instead of letting save_image clip them.
    save_image(images, './samples/gan_sample_' + '.png', normalize=True)