In [1]:
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
from torchvision import datasets, transforms
from torch.autograd import Variable
from torchvision.utils import save_image
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.pyplot import imshow, imsave
%matplotlib inline

In [2]:
# Run configuration: everything a reader might want to tune lives in this cell.
seed = 42         # RNG seed so weight init, shuffling and noise are repeatable
bs = 64           # mini-batch size
n_epoch = 40      # number of passes over the training set
z_dim = 100       # dimensionality of the generator's latent noise vector
mnist_dim = 784   # 28 * 28 pixels per MNIST image
lr = 0.001        # learning rate shared by both Adam optimizers
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Seed PyTorch's random number generators before any model or loader is built.
# NOTE: some GPU kernels may still be nondeterministic, so runs on CUDA are not
# guaranteed to be bit-for-bit identical.
torch.manual_seed(seed)

# Normalize(mean=0.5, std=0.5) computes (x - 0.5) / 0.5, i.e. [0, 1] pixel
# values map to [-1, 1], the range of the generator's Tanh output.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5])
])

train_dataset = datasets.MNIST(root='../data/', train=True, transform=transform, download=True)

# drop_last=True keeps every batch at exactly `bs` samples.
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=bs, shuffle=True, drop_last=True)

In [3]:
class Discriminator(nn.Module):
    """
        Convolutional Discriminator for MNIST.

        Three stride-2 convolutions (each followed by batch norm and LeakyReLU)
        shrink a 28x28 image to a 4x4 map with 128 channels. The map is
        average-pooled to 128 features, then a linear layer and a sigmoid turn
        those features into ``num_classes`` scores between 0 and 1.

        Args:
            in_channel: number of channels of the input image.
            num_classes: number of sigmoid outputs per sample. The default of 1
                gives the usual single real/fake probability.
    """
    def __init__(self, in_channel=1, num_classes=1):
        super(Discriminator, self).__init__()
        self.conv = nn.Sequential(
            # 28 -> 14
            nn.Conv2d(in_channel, 512, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2),
            # 14 -> 7
            nn.Conv2d(512, 256, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),
            # 7 -> 4
            nn.Conv2d(256, 128, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            # 4 -> 1: collapse the remaining spatial extent
            nn.AvgPool2d(4),
        )
        self.fc = nn.Sequential(
            # 128 pooled features -> num_classes scores (previously hard-coded
            # to 1, which silently ignored the constructor argument)
            nn.Linear(128, num_classes),
            nn.Sigmoid(),
        )

    def forward(self, x, y=None):
        """
            Score a batch of images.

            Args:
                x: image batch fed to the convolutional stack.
                y: accepted for interface compatibility; not used.

            Returns:
                Tensor with one row per sample and ``num_classes`` sigmoid
                scores per row.
        """
        y_ = self.conv(x)
        # flatten the pooled feature map to (batch, features) for the linear layer
        y_ = y_.view(y_.size(0), -1)
        y_ = self.fc(y_)
        return y_
    
class Generator(nn.Module):
    """
        Convolutional Generator for MNIST.

        A linear layer projects the latent vector to a 512-channel 4x4 volume,
        which three transposed convolutions upsample to a single-channel 28x28
        image squashed by Tanh.

        Args:
            input_size: length of the latent vector.
            num_classes: accepted but not used by this implementation.
    """
    def __init__(self, input_size=100, num_classes=784):
        super().__init__()

        # latent vector -> flat 512 * 4 * 4 activation
        projection = [
            nn.Linear(input_size, 4*4*512),
            nn.ReLU(),
        ]
        self.fc = nn.Sequential(*projection)

        upsampling = [
            # 4x4 -> 7x7
            nn.ConvTranspose2d(512, 256, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            # 7x7 -> 14x14
            nn.ConvTranspose2d(256, 128, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            # 14x14 -> 28x28
            nn.ConvTranspose2d(128, 1, kernel_size=4, stride=2, padding=1, bias=False),
            nn.Tanh(),
        ]
        self.conv = nn.Sequential(*upsampling)

    def forward(self, x, y=None):
        """
            Map a batch of latent vectors to images.

            Args:
                x: latent batch; any trailing dimensions are flattened.
                y: accepted for interface compatibility; not used.

            Returns:
                The output of the transposed-convolution stack (Tanh-activated).
        """
        batch = x.size(0)
        latent = x.view(batch, -1)
        volume = self.fc(latent)
        # reinterpret the flat projection as a 512-channel 4x4 feature map
        volume = volume.view(volume.size(0), 512, 4, 4)
        return self.conv(volume)

In [4]:
# Build both networks (discriminator first, then generator) and move them to
# the selected device.
D, G = Discriminator().to(device), Generator().to(device)

In [5]:
# Display the generator's layer structure.
G

Generator(
  (fc): Sequential(
    (0): Linear(in_features=100, out_features=8192, bias=True)
    (1): ReLU()
  )
  (conv): Sequential(
    (0): ConvTranspose2d(512, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): ReLU()
    (6): ConvTranspose2d(128, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1), bias=False)
    (7): Tanh()
  )
)

In [6]:
# Display the discriminator's layer structure.
D

Discriminator(
  (conv): Sequential(
    (0): Conv2d(1, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): LeakyReLU(negative_slope=0.2)
    (3): Conv2d(512, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (4): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (5): LeakyReLU(negative_slope=0.2)
    (6): Conv2d(256, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
    (7): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (8): LeakyReLU(negative_slope=0.2)
    (9): AvgPool2d(kernel_size=4, stride=4, padding=0)
  )
  (fc): Sequential(
    (0): Linear(in_features=128, out_features=1, bias=True)
    (1): Sigmoid()
  )
)

In [None]:
# Binary cross-entropy between the discriminator's sigmoid output and 0/1 targets.
criterion = nn.BCELoss()

# One Adam optimizer per network, both with the same learning rate and betas.
G_optimizer, D_optimizer = (
    optim.Adam(net.parameters(), lr=lr, betas=(0.5, 0.999)) for net in (G, D)
)

# Constant target labels, one per sample of a full batch.
D_real = torch.ones(bs, 1, device=device)   # target for "real"
D_fake = torch.zeros(bs, 1, device=device)  # target for "fake"

In [None]:
# Histories filled in by the training loop (one entry per logged step).
D_losses, G_losses = [], []
real_scores, fake_scores = [], []

# Fixed latent batch reused for every per-epoch sample grid, so grids from
# different epochs show the same latent points. The deprecated
# torch.autograd.Variable wrapper is not needed: plain tensors work directly.
fixed_noise = torch.randn(bs, z_dim).to(device)

In [None]:
# Create the output directories up front so the per-epoch saves below cannot
# fail because a directory is missing.
os.makedirs('./models/dcgan', exist_ok=True)
os.makedirs('./samples/dcgan', exist_ok=True)

for epoch in range(1, n_epoch+1):
    for batch_idx, (x, _) in enumerate(train_loader):
        # ---- discriminator step ----
        # real batch: the discriminator should answer "real"
        D_output = D(x.to(device))
        real_score = torch.mean(D_output)
        D_x_loss = criterion(D_output, D_real)

        # fake batch: the discriminator should answer "fake".
        # detach() stops this loss from back-propagating into G, whose weights
        # are not updated in this step.
        z = torch.randn(bs, z_dim).to(device)
        D_output = D(G(z).detach())
        fake_score = torch.mean(D_output)
        D_z_loss = criterion(D_output, D_fake)

        D_loss = D_x_loss + D_z_loss
        D_optimizer.zero_grad()
        D_loss.backward()
        D_optimizer.step()

        # ---- generator step ----
        # fresh noise; G is rewarded when D labels its output as "real"
        z = torch.randn(bs, z_dim).to(device)
        D_output = D(G(z))
        G_loss = criterion(D_output, D_real)
        # Only G's gradients need clearing here: gradients that G_loss leaves
        # on D are zeroed by D_optimizer.zero_grad() before D's next backward.
        G_optimizer.zero_grad()
        G_loss.backward()
        G_optimizer.step()

        if batch_idx % 100 == 99:
            # Log plain Python floats. Appending the tensors themselves would
            # keep autograd/device references alive and the resulting lists
            # could not be plotted directly.
            d_loss_val, g_loss_val = D_loss.item(), G_loss.item()
            real_val, fake_val = real_score.item(), fake_score.item()
            D_losses.append(d_loss_val)
            G_losses.append(g_loss_val)
            real_scores.append(real_val)
            fake_scores.append(fake_val)
            print('epoch[%d/%d]batch[%d/%d]: loss_d: %.3f, loss_g: %.3f, D(x): %.3f, D(G(z)): %.3f' %
                 (epoch, n_epoch, batch_idx, len(train_loader), d_loss_val, g_loss_val, real_val, fake_val))

    # save models (overwritten every epoch)
    torch.save(G.state_dict(), './models/dcgan/G.pth')
    torch.save(D.state_dict(), './models/dcgan/D.pth')

    # save fake images generated from the fixed noise batch
    with torch.no_grad():
        generated = G(fixed_noise)
        # nrow is passed by keyword: in current torchvision the third
        # positional parameter of save_image is `format`, not `nrow`.
        save_image(generated.view(generated.size(0), 1, 28, 28), './samples/dcgan/sample_%03d.png' % epoch, nrow=8)

epoch[1/40]batch[99/937]: loss_d: 0.753, loss_g: 1.313, D(x): 0.698, D(G(z)): 0.322
epoch[1/40]batch[199/937]: loss_d: 0.550, loss_g: 2.006, D(x): 0.832, D(G(z)): 0.283
epoch[1/40]batch[299/937]: loss_d: 0.449, loss_g: 2.007, D(x): 0.776, D(G(z)): 0.176
epoch[1/40]batch[399/937]: loss_d: 0.686, loss_g: 1.234, D(x): 0.554, D(G(z)): 0.078
epoch[1/40]batch[499/937]: loss_d: 0.298, loss_g: 2.094, D(x): 0.850, D(G(z)): 0.123
epoch[1/40]batch[599/937]: loss_d: 0.186, loss_g: 3.230, D(x): 0.941, D(G(z)): 0.117
epoch[1/40]batch[699/937]: loss_d: 0.197, loss_g: 2.490, D(x): 0.960, D(G(z)): 0.142
epoch[1/40]batch[799/937]: loss_d: 0.557, loss_g: 2.456, D(x): 0.610, D(G(z)): 0.030
epoch[1/40]batch[899/937]: loss_d: 0.030, loss_g: 4.401, D(x): 0.985, D(G(z)): 0.016
epoch[2/40]batch[99/937]: loss_d: 0.028, loss_g: 4.431, D(x): 0.989, D(G(z)): 0.016
epoch[2/40]batch[199/937]: loss_d: 0.618, loss_g: 2.195, D(x): 0.861, D(G(z)): 0.364
epoch[2/40]batch[299/937]: loss_d: 0.152, loss_g: 2.789, D(x): 0.91

In [None]:
def _as_floats(values):
    """Convert logged entries (Python numbers or one-element tensors) to floats.

    matplotlib cannot plot tensors that require grad or live on the GPU, so
    every entry is turned into a plain float first.
    """
    return [float(v) for v in values]


# Loss curves; one point per logged step (every 100 batches).
plt.figure()
plt.plot(_as_floats(D_losses), label='d loss')
plt.plot(_as_floats(G_losses), label='g loss')
plt.title('DCGAN training losses')
plt.xlabel('logged step (every 100 batches)')
plt.ylabel('BCE loss')
plt.legend()
plt.savefig('loss_dcgan.png')
plt.show()

# Mean discriminator output on real and generated batches.
plt.figure()
plt.plot(_as_floats(fake_scores), label='fake score')
plt.plot(_as_floats(real_scores), label='real score')
plt.title('Mean discriminator output')
plt.xlabel('logged step (every 100 batches)')
plt.ylabel('mean D output')
plt.legend()
plt.savefig('score_dcgan.png')
plt.show()