In [25]:
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import torch
import torch.nn.functional as F
from torch import nn
from torch.autograd import Variable

import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
from tqdm import tqdm_notebook

from torchvision.utils import save_image

In [26]:
train_loader = DataLoader(
    datasets.FashionMNIST('./fashion', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
#                        transforms.Normalize(0, 1)
                   ])),
    batch_size=32, shuffle=True)

In [27]:
test_loader = DataLoader(
    datasets.FashionMNIST('./fashion', train=False, transform=transforms.Compose([
                       transforms.ToTensor(),
#                        transforms.Normalize((0), (1))
                   ])),
    batch_size=32, shuffle=True)


In [28]:
class DCGenerator(nn.Module):
    def __init__(self):
        super(DCGenerator, self).__init__()
        g_convs = [(64, 7, 1, 0), (32, 4, 2, 1), (1, 4, 2, 1)]

        self.convs = nn.Sequential(
            nn.ConvTranspose2d(1, 64, 7, 1, 0),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, 2, 1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 1, 4, 2, 1),
            nn.Tanh()
        )
        
    def forward(self, x):
        return self.convs(x)

In [39]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.convs = nn.Sequential(
            nn.Conv2d(1, 32, 4, 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(32, 64, 4, 2, 1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 1, 7, 1, 0),
        )

    def forward(self, x):
        x = self.convs(x)
        x = x.view(x.size(0), -1)
        x = F.sigmoid(x)
        return x

In [40]:
def sample_noise(batch_size, channels):
    return torch.randn(batch_size, channels, 1, 1).float()

In [41]:
max_iter = 25
trans = transforms.Compose([transforms.ToTensor(),
                            transforms.Normalize([0.5,], [0.5,])])

mnist = datasets.MNIST('./mnist/', train=True, transform=trans, download=True)

batch_size = 64

In [42]:
discriminator = Discriminator()
generator = DCGenerator()
print(discriminator)
print(generator)

Discriminator(
  (convs): Sequential(
    (0): Conv2d(1, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU(0.2)
    (2): Conv2d(32, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
    (4): LeakyReLU(0.2)
    (5): Conv2d(64, 1, kernel_size=(7, 7), stride=(1, 1))
  )
)
DCGenerator(
  (convs): Sequential(
    (0): ConvTranspose2d(1, 64, kernel_size=(7, 7), stride=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
    (2): ReLU()
    (3): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (4): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True)
    (5): ReLU()
    (6): ConvTranspose2d(32, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (7): Tanh()
  )
)


In [43]:
dataloader = DataLoader(mnist, batch_size=batch_size, shuffle=True, drop_last=True)

optimizer_d = torch.optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_g = torch.optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
real_label = 1
fake_label = 0
criterion = nn.BCELoss()

In [44]:
fixed_noise = sample_noise(batch_size, 1)
fixed_noise = Variable(fixed_noise, volatile=True)

In [None]:
for epoch in tqdm_notebook(range(1, max_iter+1)):
    for i, (image, label) in enumerate(dataloader):
        
        #real image

        optimizer_d.zero_grad()
        image = Variable(image)
        output = discriminator(image)
        real_v = Variable(torch.ones(batch_size).float())
        loss_d = criterion(output, real_v)
        loss_d.backward()
        
        Dx = torch.mean(output.data, dim=0)[0]
        
        # training D on fake data
        z = sample_noise(batch_size, 1)
        z = Variable(z)

        fake = generator(z)
        
        output = discriminator(fake.detach())
        
        fake_v = Variable(torch.zeros(batch_size).float())

        loss_g = criterion(output, fake_v)
        loss_g.backward()
        
        optimizer_d.step()

        err_D = loss_d.data[0] + loss_g.data[0]

        # training G
        optimizer_g.zero_grad()
        output = discriminator(fake)
        real_v = Variable(torch.ones(batch_size).float())

        loss = criterion(output, real_v)
        loss.backward()
        optimizer_g.step()
        
        err_G = loss.data[0]
        DGz = torch.mean(output.data, dim=0)[0]

        print('[{:02d}/{:02d}],[{:03d}/{:03d}], errD: {:.4f}, D(x): {:.4f}, errG: {:.4f}, D(G(z)): {:.4f}'.format(
              epoch, max_iter, i, len(dataloader), err_D, Dx, err_G, DGz))

    fake = generator(fixed_noise)

    save_image(fake.data, './gans/mnist-fake-{:02d}.png'.format(epoch),
               normalize=True)

  "Please ensure they have the same size.".format(target.size(), input.size()))


[01/25],[000/937], errD: 1.3461, D(x): 0.5256, errG: 0.7659, D(G(z)): 0.4661
[01/25],[001/937], errD: 1.2844, D(x): 0.5537, errG: 0.7789, D(G(z)): 0.4603
[01/25],[002/937], errD: 1.2544, D(x): 0.5639, errG: 0.8002, D(G(z)): 0.4516
[01/25],[003/937], errD: 1.1886, D(x): 0.5942, errG: 0.8212, D(G(z)): 0.4422
[01/25],[004/937], errD: 1.1382, D(x): 0.6087, errG: 0.8597, D(G(z)): 0.4268
[01/25],[005/937], errD: 1.0995, D(x): 0.6327, errG: 0.8674, D(G(z)): 0.4237
[01/25],[006/937], errD: 1.0446, D(x): 0.6475, errG: 0.9088, D(G(z)): 0.4072
[01/25],[007/937], errD: 1.0234, D(x): 0.6513, errG: 0.9321, D(G(z)): 0.3978
[01/25],[008/937], errD: 0.9554, D(x): 0.6673, errG: 0.9959, D(G(z)): 0.3748
[01/25],[009/937], errD: 0.9145, D(x): 0.6864, errG: 1.0219, D(G(z)): 0.3657
[01/25],[010/937], errD: 0.8849, D(x): 0.6791, errG: 1.0787, D(G(z)): 0.3459
[01/25],[011/937], errD: 0.8321, D(x): 0.6985, errG: 1.1300, D(G(z)): 0.3308
[01/25],[012/937], errD: 0.7938, D(x): 0.7164, errG: 1.1607, D(G(z)): 0.3218