In [1]:
# adapted (copy pasted) from https://github.com/znxlwm/pytorch-MNIST-CelebA-GAN-DCGAN
import os, time
import matplotlib.pyplot as plt
import itertools
import pickle
# import imageio
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
# from torch.autograd import Variable

In [None]:
# current with mnist_dcgan.py as of 4/3

In [121]:
# G(z)
class generator(nn.Module):
    """DCGAN-style generator: latent vectors ``(N, z_dim)`` -> images ``(N, out_channels, 28, 28)``.

    A linear layer projects the latent vector to a ``(d*8, 2, 2)`` feature map, which four
    stride-2 transposed convolutions upsample 2 -> 4 -> 7 -> 14 -> 28. Three of the
    transposed convolutions produce one row/column too many, so their output is cropped.
    The final activation is ``tanh``, so outputs lie in [-1, 1].

    Args:
        d: base channel width; the widest feature map has ``d*8`` channels.
        z_dim: size of the latent vector (default 100, the value that used to be hard-coded).
        out_channels: number of image channels produced (default 3, the value that used to
            be hard-coded).
    """
    def __init__(self, d=64, z_dim=100, out_channels=3):
        super(generator, self).__init__()
        self.d = d
        self.z_dim = z_dim
        # Layers are created in the same order and under the same attribute names as before,
        # so saved state_dicts and the random initialisation sequence are unaffected.
        self.linear = nn.Linear(z_dim, 2*2*d*8)
        self.linear_bn = nn.BatchNorm1d(2*2*d*8)
        self.deconv1 = nn.ConvTranspose2d(d*8, d*4, 5, 2, 1)
        self.deconv1_bn = nn.BatchNorm2d(d*4)
        self.deconv2 = nn.ConvTranspose2d(d*4, d*2, 5, 2, 2)
        self.deconv2_bn = nn.BatchNorm2d(d*2)
        self.deconv3 = nn.ConvTranspose2d(d*2, d, 5, 2, 1)
        self.deconv3_bn = nn.BatchNorm2d(d)
        self.deconv4 = nn.ConvTranspose2d(d, out_channels, 5, 2, 1)

    def weight_init(self, mean, std):
        """Apply ``normal_init`` to every direct child layer.

        ``normal_init`` is a notebook-level function defined in a later cell; that cell must
        have been run before this method is called.
        """
        for m in self.children():
            normal_init(m, mean, std)

    def forward(self, input):
        """Generate images from a batch of latent vectors of shape ``(N, z_dim)``."""
        x = F.relu(self.linear_bn(self.linear(input)))
        x = x.view(-1, self.d*8, 2, 2)                   # (N, d*8, 2, 2)
        x = F.relu(self.deconv1_bn(self.deconv1(x)))     # 2 -> 5
        x = x[:, :, :-1, :-1]                            # crop 5 -> 4 (like "SAME" padding in tf)
        x = F.relu(self.deconv2_bn(self.deconv2(x)))     # 4 -> 7
        x = F.relu(self.deconv3_bn(self.deconv3(x)))     # 7 -> 15
        x = x[:, :, :-1, :-1]                            # crop 15 -> 14
        x = torch.tanh(self.deconv4(x))                  # 14 -> 29
        x = x[:, :, :-1, :-1]                            # crop 29 -> 28
        return x

In [122]:
# Smoke test: build a generator, time one forward pass on a dummy batch of 43 latent
# vectors, and check the output shape. `time` is already imported in the first cell.
G = generator()
start = time.time()
x = torch.zeros(43, 100)
y = G(x)
print(time.time()-start)
assert y.shape == (43, 3, 28, 28), y.shape

0.30336642265319824


In [123]:
class discriminator(nn.Module):
    """DCGAN-style discriminator: images ``(N, in_channels, 28, 28)`` -> probabilities ``(N, 1)``.

    Four stride-2 convolutions reduce the spatial size 28 -> 14 -> 7 -> 4 -> 2; the resulting
    ``(d*8, 2, 2)`` feature map is flattened and fed to a linear layer followed by a sigmoid.

    Args:
        d: base channel width; the last convolution outputs ``d*8`` channels.
        in_channels: number of input image channels (default 3, the value that used to be
            hard-coded).
    """
    def __init__(self, d=64, in_channels=3):
        super(discriminator, self).__init__()
        self.d = d
        # Same attribute names and creation order as before, so existing checkpoints load.
        self.conv1 = nn.Conv2d(in_channels, d, 5, 2, 2)
        self.conv2 = nn.Conv2d(d, d*2, 5, 2, 2)
        self.conv2_bn = nn.BatchNorm2d(d*2)
        self.conv3 = nn.Conv2d(d*2, d*4, 5, 2, 2)
        self.conv3_bn = nn.BatchNorm2d(d*4)
        self.conv4 = nn.Conv2d(d*4, d*8, 5, 2, 2)
        self.conv4_bn = nn.BatchNorm2d(d*8)
        self.linear = nn.Linear(2*2*d*8, 1)

    def weight_init(self, mean, std):
        """Apply ``normal_init`` to every direct child layer.

        ``normal_init`` is a notebook-level function defined in a later cell; that cell must
        have been run before this method is called.
        """
        for m in self.children():
            normal_init(m, mean, std)

    def forward(self, input):
        """Return the per-image probability of being real, shape ``(N, 1)``."""
        x = F.leaky_relu(self.conv1(input), 0.2)
        x = F.leaky_relu(self.conv2_bn(self.conv2(x)), 0.2)
        x = F.leaky_relu(self.conv3_bn(self.conv3(x)), 0.2)
        x = F.leaky_relu(self.conv4_bn(self.conv4(x)), 0.2)
        # Flatten per sample. The previous ``view(-1, 2*2*d*8)`` silently folded extra spatial
        # positions into the batch dimension when the input was not 28x28 (e.g. 64x64 gave
        # 4*N rows); now a wrong input size makes the linear layer raise a shape error.
        x = x.view(x.size(0), -1)
        x = torch.sigmoid(self.linear(x))
        return x

In [124]:
# Smoke test: time building a discriminator plus one forward pass on a dummy batch of
# 43 three-channel 28x28 images, and check that it yields one score per image.
start = time.time()
D = discriminator()
x = torch.zeros(43, 3, 28, 28)
y = D(x)
print(time.time()-start)
assert y.shape == (43, 1), y.shape

0.24593663215637207


In [125]:
def normal_init(m, mean, std):
    """Re-initialise a convolutional layer in place.

    If ``m`` is a ``Conv2d`` or ``ConvTranspose2d``, its weights are drawn from
    N(mean, std) and its bias, when it has one, is set to zero. Any other module
    is left untouched.

    Args:
        m: the module to initialise.
        mean: mean of the normal distribution for the weights.
        std: standard deviation of the normal distribution for the weights.
    """
    if isinstance(m, (nn.ConvTranspose2d, nn.Conv2d)):
        nn.init.normal_(m.weight, mean, std)
        # Layers built with bias=False have ``bias is None``; skip instead of crashing.
        if m.bias is not None:
            nn.init.constant_(m.bias, 0.0)

# Fixed noise: 25 latent vectors of size 100 (one per cell of a 5x5 grid), drawn once
# so the same inputs can be reused for every "fixed" preview.
fixed_z_ = torch.randn(5 * 5, 100)

In [13]:
def show_result(num_epoch, show = False, save = False, path = 'result.png', isFix=False):
    """Draw a 5x5 grid of generator samples captioned with the epoch number.

    Uses the notebook globals ``G`` (the generator) and ``fixed_z_`` (fixed noise).

    Args:
        num_epoch: epoch number shown in the figure caption.
        show: display the figure with ``plt.show()``; otherwise the figure is closed.
        save: write the figure to ``path``. Earlier versions wrote the file regardless of
            this flag; it is now honoured, matching ``show_train_hist``.
        path: output image file used when ``save`` is true.
        isFix: sample from ``fixed_z_`` instead of freshly drawn noise.
    """
    size_figure_grid = 5
    n_images = size_figure_grid * size_figure_grid

    # Sample in eval mode without building an autograd graph, then put the generator
    # back into whatever mode it was in before the call.
    was_training = G.training
    G.eval()
    try:
        with torch.no_grad():
            z_ = fixed_z_ if isFix else torch.randn(n_images, 100)
            test_images = G(z_).cpu()
    finally:
        G.train(was_training)

    fig, ax = plt.subplots(size_figure_grid, size_figure_grid, figsize=(5, 5))
    for k in range(n_images):
        i, j = divmod(k, size_figure_grid)
        # channels-first -> channels-last so all 3 channels are drawn as RGB
        ti = test_images[k].numpy().transpose(1, 2, 0)
        # The generator ends in tanh, so values are in [-1, 1]; imshow expects float RGB
        # in [0, 1] and would clip everything below zero.
        ax[i, j].imshow((ti + 1.0) / 2.0)
        ax[i, j].get_xaxis().set_visible(False)
        ax[i, j].get_yaxis().set_visible(False)

    label = 'Epoch {0}'.format(num_epoch)
    fig.text(0.5, 0.04, label, ha='center')
    if save:
        fig.savefig(path)
    if show:
        plt.show()
    else:
        plt.close(fig)
        
def show_train_hist(hist, show = False, save = False, path = 'Train_hist.png'):
    """Plot the discriminator and generator loss histories on one set of axes.

    Args:
        hist: mapping with the sequences ``hist['D_losses']`` and ``hist['G_losses']``;
            the x axis runs over ``range(len(hist['D_losses']))``.
        show: display the figure with ``plt.show()``; otherwise the figure is closed.
        save: write the figure to ``path`` before showing/closing it.
        path: output image file used when ``save`` is true.
    """
    d_curve, g_curve = hist['D_losses'], hist['G_losses']
    steps = range(len(d_curve))
    for values, name in ((d_curve, 'D_loss'), (g_curve, 'G_loss')):
        plt.plot(steps, values, label=name)

    plt.xlabel('Iter')
    plt.ylabel('Loss')
    plt.legend(loc=4)
    plt.grid(True)
    plt.tight_layout()

    if save:
        plt.savefig(path)
    if not show:
        plt.close()
        return
    plt.show()

In [14]:
# training parameters
batch_size = 63  # multiple of 3: stack() folds three grayscale digits into one 3-channel image
lr = 0.0002
train_epoch = 50

# data_loader
img_size = 28
transform = transforms.Compose([
        # transforms.Scale was deprecated and later removed from torchvision; Resize replaces it.
        transforms.Resize(img_size),
        transforms.ToTensor(),
        # MNIST images have a single channel, so Normalize gets one mean/std pair
        # (a 3-tuple fails on current torchvision); this maps [0, 1] to [-1, 1].
        transforms.Normalize(mean=(0.5,), std=(0.5,))
])
train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=True, download=True, transform=transform),
    batch_size=batch_size, shuffle=True)

In [128]:
# 60000 dataset stacked is 20000
# repeat 6 times per epoch to get 120000 (pacgan does 128000)
# alternatively, can use load_mnist() function, but it is much slower
# from load_mnist import *
# img, lab = load_mnist(128000)

In [129]:
# Build fresh networks for training. weight_init() hands every direct child layer to
# normal_init(), which re-draws only Conv2d / ConvTranspose2d weights from N(0, 0.02) and
# zeroes their biases; the other layers keep their default initialisation.
G = generator()
D = discriminator()
G.weight_init(mean=0.0, std=0.02)
D.weight_init(mean=0.0, std=0.02)
# G.cuda()
# D.cuda()

# Binary Cross Entropy loss (applied to the discriminator's sigmoid outputs)
BCE_loss = nn.BCELoss()

# Adam optimizer: one per network, both using the learning rate `lr` from the parameter cell
G_optimizer = optim.Adam(G.parameters(), lr=lr, betas=(0.5, 0.999))
D_optimizer = optim.Adam(D.parameters(), lr=lr, betas=(0.5, 0.999))

In [130]:
def stack(x):
    """Fold every three consecutive batch items into one item along the channel axis.

    Items 0, 1, 2 become the channels of output 0, items 3, 4, 5 those of output 1, and
    so on. The batch size has to be a multiple of 3; otherwise the three strided slices
    differ in length and ``torch.cat`` fails.
    """
    thirds = [x[offset::3] for offset in range(3)]
    return torch.cat(thirds, dim=1)

In [131]:
# results save folders (makedirs also creates the shared parent directory)
os.makedirs('MNIST_DCGAN_results/Random_results', exist_ok=True)
os.makedirs('MNIST_DCGAN_results/Fixed_results', exist_ok=True)

# running record of the training run, filled in by the training cells below
train_hist = {
    'D_losses': [],
    'G_losses': [],
    'per_epoch_ptimes': [],
    'total_ptime': [],
}
num_iter = 0

In [132]:
# Peek at the image tensor of one batch from the loader (before stacking) to check its shape.
x = next(iter(train_loader))[0]
x.shape

torch.Size([63, 1, 28, 28])

In [133]:
# Sanity check: one latent vector per stacked image should give one 3x28x28 sample each.
# z_ is created here so the cell runs on a fresh kernel; it previously relied on a z_
# left over from an earlier session.
z_ = torch.randn(stack(x).size(0), 100, 1, 1)
print(z_.shape)
zz = z_.squeeze(3).squeeze(2)  # (N, 100, 1, 1) -> (N, 100), the shape G.forward expects
y = G(zz)
y.shape

torch.Size([21, 100, 1, 1])


torch.Size([21, 3, 28, 28])

In [134]:
print('training start!')
start_time = time.time()
for epoch in range(train_epoch):
    D_losses = []
    G_losses = []
    epoch_start_time = time.time()
    # One pass over MNIST gives 20000 stacked images; six passes per epoch give 120000.
    for i in range(6):
        for x_, _ in train_loader:
            # train discriminator D
            D.zero_grad()

            x_ = stack(x_)  # three grayscale digits -> one 3-channel image
            mini_batch = x_.size()[0]

            y_real_ = torch.ones(mini_batch)
            y_fake_ = torch.zeros(mini_batch)

            # view(-1) rather than squeeze(): keeps a 1-D result even for a batch of one,
            # so it always matches the shape of the targets.
            D_result = D(x_).view(-1)
            D_real_loss = BCE_loss(D_result, y_real_)

            z_ = torch.randn((mini_batch, 100))
            # detach(): the discriminator update does not need gradients w.r.t. G, and G's
            # gradients are zeroed before its own update anyway.
            G_result = G(z_).detach()

            D_result = D(G_result).view(-1)
            D_fake_loss = BCE_loss(D_result, y_fake_)

            D_train_loss = D_real_loss + D_fake_loss

            D_train_loss.backward()
            D_optimizer.step()

            D_losses.append(D_train_loss.item())

            # train generator G
            # Run g_optim twice to make sure that d_loss does not go to zero (different from paper)
            for j in range(2):
                G.zero_grad()

                z_ = torch.randn((mini_batch, 100))

                G_result = G(z_)
                D_result = D(G_result).view(-1)
                # The generator is rewarded when D labels its samples as real.
                G_train_loss = BCE_loss(D_result, y_real_)
                G_train_loss.backward()
                G_optimizer.step()
                G_losses.append(G_train_loss.item())

            num_iter += 1

    epoch_end_time = time.time()
    per_epoch_ptime = epoch_end_time - epoch_start_time

    # Epoch means are computed once and kept as plain floats, so the history pickles and
    # plots without carrying tensors around.
    mean_d_loss = torch.mean(torch.FloatTensor(D_losses)).item()
    mean_g_loss = torch.mean(torch.FloatTensor(G_losses)).item()

    print('[%d/%d] - ptime: %.2f, loss_d: %.3f, loss_g: %.3f' % ((epoch + 1), train_epoch, per_epoch_ptime, mean_d_loss, mean_g_loss))
    p = 'MNIST_DCGAN_results/Random_results/MNIST_DCGAN_' + str(epoch + 1) + '.png'
    fixed_p = 'MNIST_DCGAN_results/Fixed_results/MNIST_DCGAN_' + str(epoch + 1) + '.png'
    show_result((epoch+1), save=True, path=p, isFix=False)
    show_result((epoch+1), save=True, path=fixed_p, isFix=True)
    train_hist['D_losses'].append(mean_d_loss)
    train_hist['G_losses'].append(mean_g_loss)
    train_hist['per_epoch_ptimes'].append(per_epoch_ptime)

    # Checkpoint every other epoch so an interrupted run keeps recent weights.
    if epoch % 2 == 0:
        torch.save(G.state_dict(), "MNIST_DCGAN_results/generator_param.pkl")
        torch.save(D.state_dict(), "MNIST_DCGAN_results/discriminator_param.pkl") # for safety!

training start!


KeyboardInterrupt: 

In [None]:
# Wrap-up after training: record the total wall-clock time, save the final weights and
# the training history, and plot the loss curves.
end_time = time.time()
total_ptime = end_time - start_time
train_hist['total_ptime'].append(total_ptime)

# NOTE: the message reports the configured train_epoch, not the number of epochs that
# actually completed (the two differ if the training cell was interrupted).
print("Avg per epoch ptime: %.2f, total %d epochs ptime: %.2f" % (torch.mean(torch.FloatTensor(train_hist['per_epoch_ptimes'])), train_epoch, total_ptime))
print("Training finish!... save training results")
# Same file names as the periodic checkpoints in the training loop, so these overwrite them.
torch.save(G.state_dict(), "MNIST_DCGAN_results/generator_param.pkl")
torch.save(D.state_dict(), "MNIST_DCGAN_results/discriminator_param.pkl")
with open('MNIST_DCGAN_results/train_hist.pkl', 'wb') as f:
    pickle.dump(train_hist, f)

# save=True with the default show=False: the figure is written to disk and then closed.
show_train_hist(train_hist, save=True, path='MNIST_DCGAN_results/MNIST_DCGAN_train_hist.png')

# images = []
# for e in range(train_epoch):
#     img_name = 'MNIST_DCGAN_results/Fixed_results/MNIST_DCGAN_' + str(e + 1) + '.png'
#     images.append(imageio.imread(img_name))
# imageio.mimsave('MNIST_DCGAN_results/generation_animation.gif', images, fps=5)

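Extra: alternative all-convolutional generator/discriminator variants (note: these cells redefine `generator` and `discriminator`, shadowing the classes defined above)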

In [113]:
class generator(nn.Module):
    """All-convolutional generator variant built from four transposed convolutions.

    Takes noise shaped ``(N, 100, 1, 1)``. The first three stages are each followed by
    batch normalisation and ReLU; the last stage outputs 3 channels through ``tanh``.

    Args:
        d: base channel width; the first stage outputs ``d*4`` channels.
    """
    def __init__(self, d=128):
        super(generator, self).__init__()
        self.deconv1 = nn.ConvTranspose2d(100, d*4, kernel_size=4, stride=1, padding=0)
        self.deconv1_bn = nn.BatchNorm2d(d*4)
        self.deconv2 = nn.ConvTranspose2d(d*4, d*2, kernel_size=3, stride=2, padding=1)
        self.deconv2_bn = nn.BatchNorm2d(d*2)
        self.deconv3 = nn.ConvTranspose2d(d*2, d, kernel_size=4, stride=2, padding=1)
        self.deconv3_bn = nn.BatchNorm2d(d)
        self.deconv4 = nn.ConvTranspose2d(d, 3, kernel_size=4, stride=2, padding=1)

    def weight_init(self, mean, std):
        """Pass every direct child layer to the notebook-level ``normal_init``."""
        for layer in self._modules.values():
            normal_init(layer, mean, std)

    def forward(self, input):
        """Run the noise through the three normalised stages and the tanh output stage."""
        hidden = input
        stages = (
            (self.deconv1, self.deconv1_bn),
            (self.deconv2, self.deconv2_bn),
            (self.deconv3, self.deconv3_bn),
        )
        for deconv, bn in stages:
            hidden = F.relu(bn(deconv(hidden)))
        return torch.tanh(self.deconv4(hidden))

In [114]:
# Small instance of the all-convolutional generator plus a batch of 43 noise vectors,
# already shaped (N, 100, 1, 1) for its first transposed convolution.
G = generator(32)
z = torch.randn(43, 100, 1, 1)

In [115]:
# Trace the feature-map shape after each transposed convolution (batch norm and
# activations are skipped; only the shapes are of interest here).
print(G.deconv1(z).shape)
print(G.deconv2(G.deconv1(z)).shape)
print(G.deconv3(G.deconv2(G.deconv1(z))).shape)
print(G.deconv4(G.deconv3(G.deconv2(G.deconv1(z)))).shape)

torch.Size([43, 128, 4, 4])
torch.Size([43, 64, 7, 7])
torch.Size([43, 32, 14, 14])
torch.Size([43, 3, 28, 28])


In [128]:
class discriminator(nn.Module):
    """All-convolutional discriminator variant built from four convolutions.

    Takes 3-channel images. The first convolution is followed by leaky ReLU only, the next
    two by batch normalisation and leaky ReLU, and the last outputs a single channel
    through a sigmoid.

    Args:
        d: base channel width; the third convolution outputs ``d*4`` channels.
    """
    def __init__(self, d=128):
        super(discriminator, self).__init__()
        self.conv1 = nn.Conv2d(3, d, kernel_size=4, stride=2, padding=1)
        self.conv2 = nn.Conv2d(d, d*2, kernel_size=4, stride=2, padding=1)
        self.conv2_bn = nn.BatchNorm2d(d*2)
        self.conv3 = nn.Conv2d(d*2, d*4, kernel_size=3, stride=2, padding=1)
        self.conv3_bn = nn.BatchNorm2d(d*4)
        self.conv4 = nn.Conv2d(d*4, 1, kernel_size=4, stride=1, padding=0)

    def weight_init(self, mean, std):
        """Pass every direct child layer to the notebook-level ``normal_init``."""
        for layer in self._modules.values():
            normal_init(layer, mean, std)

    def forward(self, input):
        """Score a batch of images; the sigmoid output keeps its trailing spatial dims."""
        hidden = F.leaky_relu(self.conv1(input), 0.2)
        for conv, bn in ((self.conv2, self.conv2_bn), (self.conv3, self.conv3_bn)):
            hidden = F.leaky_relu(bn(conv(hidden)), 0.2)
        return torch.sigmoid(self.conv4(hidden))

In [129]:
# Small instance of the all-convolutional discriminator and a dummy batch of 43
# three-channel 28x28 images (zeros stand in for a stacked MNIST batch).
D = discriminator(32)
# x = stack(next(iter(train_loader))[0])
x = torch.zeros(43,3,28,28)
x.shape

torch.Size([43, 3, 28, 28])

In [130]:
# Trace the feature-map shape after each convolution (batch norm and leaky ReLU are
# skipped; only the shapes are of interest), then confirm the sigmoid leaves it unchanged.
print(D.conv1(x).shape)
print(D.conv2(D.conv1(x)).shape)
print(D.conv3(D.conv2(D.conv1(x))).shape)
print(D.conv4(D.conv3(D.conv2(D.conv1(x)))).shape)
print(torch.sigmoid(D.conv4(D.conv3(D.conv2(D.conv1(x))))).shape)

torch.Size([43, 32, 14, 14])
torch.Size([43, 64, 7, 7])
torch.Size([43, 128, 4, 4])
torch.Size([43, 1, 1, 1])
torch.Size([43, 1, 1, 1])
