In [17]:
"""imports"""

import os #,argparse 
import gzip
import torch.nn as nn
import numpy as np
import scipy.misc
import imageio
import matplotlib.pyplot as plt

import torch, time, pickle
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
from torch.utils.data import DataLoader
from torchvision import datasets, transforms

In [3]:
"""input arguments"""

dataset = 'mnist'
#dataset = 'celebA'
epoch = 1 #25
batch_size = 64
sample_num = 100 #16
save_dir = './models'
result_dir = './results'
log_dir = './logs'
lrG = 0.0002
lrD = 0.0002
beta1 = 0.5
beta2 = 0.999
gpu_mode = False
model_name = 'GAN'

c = 0.01     # clipping value
n_critic = 5 # the number of iterations of the critic per generator iteration

In [4]:
"""checking arguments"""

# --save_dir
if not os.path.exists(save_dir):
    os.makedirs(save_dir)

# --result_dir
if not os.path.exists(result_dir):
    os.makedirs(result_dir)

# --result_dir
if not os.path.exists(log_dir):
    os.makedirs(log_dir)

# --epoch
try:
    assert epoch >= 1
except:
    print('number of epochs must be larger than or equal to one')

# --batch_size
try:
    assert batch_size >= 1
except:
    print('batch size must be larger than or equal to one')

In [5]:
"""print network"""

def print_network(net):
    num_params = 0
    for param in net.parameters():
        num_params += param.numel()
    print(net)
    print('Total number of parameters: %d' % num_params)

In [6]:
"""save images"""

def save_images(images, size, image_path):
    return imsave(images, size, image_path)

def imsave(images, size, path):
    image = np.squeeze(merge(images, size))
    return scipy.misc.imsave(path, image)

"""merge images"""

def merge(images, size):
    h, w = images.shape[1], images.shape[2]
    if (images.shape[3] in (3,4)):
        c = images.shape[3]
        img = np.zeros((h * size[0], w * size[1], c))
        for idx, image in enumerate(images):
            i = idx % size[1]
            j = idx // size[1]
            img[j * h:j * h + h, i * w:i * w + w, :] = image
        return img
    elif images.shape[3]==1:
        img = np.zeros((h * size[0], w * size[1]))
        for idx, image in enumerate(images):
            i = idx % size[1]
            j = idx // size[1]
            img[j * h:j * h + h, i * w:i * w + w] = image[:,:,0]
        return img
    else:
        raise ValueError('in merge(images,size) images parameter ''must have dimensions: HxW or HxWx3 or HxWx4')

"""generate animation"""
        
def generate_animation(path, num):
    images = []
    for e in range(num):
        img_name = path + '_epoch%03d' % (e+1) + '.png'
        images.append(imageio.imread(img_name))
    imageio.mimsave(path + '_generate_animation.gif', images, fps=5)

In [7]:
"""plot loss"""

def loss_plot(hist, path = 'Train_hist.png', model_name = ''):
    x = range(len(hist['D_loss']))

    y1 = hist['D_loss']
    y2 = hist['G_loss']

    plt.plot(x, y1, label='D_loss')
    plt.plot(x, y2, label='G_loss')

    plt.xlabel('Iter')
    plt.ylabel('Loss')

    plt.legend(loc=4)
    plt.grid(True)
    plt.tight_layout()

    path = os.path.join(path, model_name + '_loss.png')

    plt.savefig(path)

    plt.close()

In [8]:
"""initialize weights"""

def initialize_weights(net):
    for m in net.modules():
        if isinstance(m, nn.Conv2d):
            m.weight.data.normal_(0, 0.02)
            m.bias.data.zero_()
        elif isinstance(m, nn.ConvTranspose2d):
            m.weight.data.normal_(0, 0.02)
            m.bias.data.zero_()
        elif isinstance(m, nn.Linear):
            m.weight.data.normal_(0, 0.02)
            m.bias.data.zero_()

In [9]:
"""load celebA"""

def load_celebA(dir, transform, batch_size, shuffle):
    # transform = transforms.Compose([
    #     transforms.CenterCrop(160),
    #     transform.Scale(64)
    #     transforms.ToTensor(),
    #     transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5))
    # ])

    # data_dir = 'data/celebA'  # this path depends on your computer
    dset = datasets.ImageFolder(dir, transform)
    data_loader = torch.utils.data.DataLoader(dset, batch_size, shuffle)

    return data_loader


In [10]:
"""generator"""

class generator(nn.Module):
    # Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657)
    # Architecture : FC1024_BR-FC7x7x128_BR-(64)4dc2s_BR-(1)4dc2s_S
    def __init__(self):
        #print('---------- generator -------------')
        super(generator, self).__init__()
        if dataset == 'mnist' or dataset == 'fashion-mnist':
            self.input_height = 28
            self.input_width = 28
            self.input_dim = 62
            self.output_dim = 1
        elif dataset == 'celebA':
            self.input_height = 64
            self.input_width = 64
            self.input_dim = 62
            self.output_dim = 3

        self.fc = nn.Sequential(
            nn.Linear(self.input_dim, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Linear(1024, 128 * (self.input_height // 4) * (self.input_width // 4)),
            nn.BatchNorm1d(128 * (self.input_height // 4) * (self.input_width // 4)),
            nn.ReLU(),
        )
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 4, 2, 1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.ConvTranspose2d(64, self.output_dim, 4, 2, 1),
            nn.Sigmoid(),
        )
        #utils.
        initialize_weights(self)

    def forward(self, input):
        x = self.fc(input)
        x = x.view(-1, 128, (self.input_height // 4), (self.input_width // 4))
        x = self.deconv(x)

        return x

In [11]:
"""discriminator"""

class discriminator(nn.Module):
    # Network Architecture is exactly same as in infoGAN (https://arxiv.org/abs/1606.03657)
    # Architecture : (64)4c2s-(128)4c2s_BL-FC1024_BL-FC1_S
    def __init__(self):
        super(discriminator, self).__init__()
        if dataset == 'mnist' or dataset == 'fashion-mnist':
            self.input_height = 28
            self.input_width = 28
            self.input_dim = 1
            self.output_dim = 1
        elif dataset == 'celebA':
            self.input_height = 64
            self.input_width = 64
            self.input_dim = 3
            self.output_dim = 1

        self.conv = nn.Sequential(
            nn.Conv2d(self.input_dim, 64, 4, 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, 4, 2, 1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
        )
        self.fc = nn.Sequential(
            nn.Linear(128 * (self.input_height // 4) * (self.input_width // 4), 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, self.output_dim),
            nn.Sigmoid(),
        )
        #utils.
        initialize_weights(self)
        
    def forward(self, input):
        x = self.conv(input)
        x = x.view(-1, 128 * (self.input_height // 4) * (self.input_width // 4))
        x = self.fc(x)

        return x

In [15]:
class WGAN():
    def __init__(self):
        # parameters
        #self.epoch = args.epoch
        #self.sample_num = 16
        #self.batch_size = args.batch_size
        #self.save_dir = args.save_dir
        #self.result_dir = args.result_dir
        #self.dataset = 'mnist'
        #self.log_dir = args.log_dir
        #self.gpu_mode = args.gpu_mode
        #self.model_name = args.gan_type

        # networks init
        self.G = generator()
        self.D = discriminator()
        #print('---------- GAN -------------')
        self.G_optimizer = optim.Adam(self.G.parameters(), lr=lrG, betas=(beta1, beta2))
        self.D_optimizer = optim.Adam(self.D.parameters(), lr=lrD, betas=(beta1, beta2))

        if gpu_mode:
            self.G.cuda()
            self.D.cuda()

        print('---------- Networks architecture -------------')
        #utils.
        print_network(self.G)
        print('-----------------------------------------------')
        #utils.
        print_network(self.D)
        print('-----------------------------------------------')

        # load dataset
        if dataset == 'mnist':
            self.data_loader = DataLoader(datasets.MNIST('data/mnist', train=True, download=True,
                                                                          transform=transforms.Compose(
                                                                              [transforms.ToTensor()])),
                                                           batch_size=batch_size, shuffle=True)
            print "Size of %s data loader : " % (dataset), self.data_loader.dataset.__len__()
            print('-----------------------------------------------')
            
        elif dataset == 'fashion-mnist':
            self.data_loader = DataLoader(
                datasets.FashionMNIST('data/fashion-mnist', train=True, download=True, transform=transforms.Compose(
                    [transforms.ToTensor()])),
                batch_size=batch_size, shuffle=True)
            print "Size of %s data loader : " % (dataset), self.data_loader.dataset.__len__()
            print('-----------------------------------------------')
            
        elif dataset == 'celebA':
            self.data_loader = load_celebA('data/celebA', transform=transforms.Compose(
                [transforms.CenterCrop(160), transforms.Scale(64), transforms.ToTensor()]), batch_size=batch_size, shuffle=True)
            print "Size of %s data loader : " % (dataset), self.data_loader.dataset.__len__()
            print('-----------------------------------------------')
            
        self.z_dim = 62

        # fixed noise
        if gpu_mode:
            self.sample_z_ = Variable(torch.rand((batch_size, self.z_dim)).cuda(), volatile=True)
        else:
            self.sample_z_ = Variable(torch.rand((batch_size, self.z_dim)), volatile=True)

    def train(self):
        self.train_hist = {}
        self.train_hist['D_loss'] = []
        self.train_hist['G_loss'] = []
        self.train_hist['per_epoch_time'] = []
        self.train_hist['total_time'] = []

        if gpu_mode:
            self.y_real_, self.y_fake_ = Variable(torch.ones(batch_size, 1).cuda()), Variable(torch.zeros(batch_size, 1).cuda())
        else:
            self.y_real_, self.y_fake_ = Variable(torch.ones(batch_size, 1)), Variable(torch.zeros(batch_size, 1))

        self.D.train()
        print('training start!!')
        start_time = time.time()
        for epochs in range(epoch):
            self.G.train()
            epoch_start_time = time.time()
            for iter, (x_, _) in enumerate(self.data_loader):
                if iter == self.data_loader.dataset.__len__() // batch_size:
                    break

                z_ = torch.rand((batch_size, self.z_dim))

                if gpu_mode:
                    x_, z_ = Variable(x_.cuda()), Variable(z_.cuda())
                else:
                    x_, z_ = Variable(x_), Variable(z_)

                # update D network
                self.D_optimizer.zero_grad()

                D_real = self.D(x_)
                D_real_loss = -torch.mean(D_real)

                G_ = self.G(z_)
                D_fake = self.D(G_)
                D_fake_loss = torch.mean(D_fake)

                D_loss = D_real_loss + D_fake_loss
                #self.train_hist['D_loss'].append(D_loss.data[0])

                D_loss.backward()
                self.D_optimizer.step()

                # clipping D
                for p in self.D.parameters():
                    p.data.clamp_(-c, c)

                if ((iter+1) % n_critic) == 0:
                    # update G network
                    self.G_optimizer.zero_grad()

                    G_ = self.G(z_)
                    D_fake = self.D(G_)
                    G_loss = -torch.mean(D_fake)
                    self.train_hist['G_loss'].append(G_loss.data[0])

                    G_loss.backward()
                    self.G_optimizer.step()

                    self.train_hist['D_loss'].append(D_loss.data[0])

                if ((iter + 1) % 100) == 0:
                    print("Epoch: [%2d] [%4d/%4d] D_loss: %.8f, G_loss: %.8f" %
                          ((epochs + 1), (iter + 1), self.data_loader.dataset.__len__() // batch_size, D_loss.data[0], G_loss.data[0]))

            self.train_hist['per_epoch_time'].append(time.time() - epoch_start_time)
            self.visualize_results((epochs+1))

        self.train_hist['total_time'].append(time.time() - start_time)
        print("Avg one epoch time: %.2f, total %d epochs time: %.2f" % (np.mean(self.train_hist['per_epoch_time']),
              epoch, self.train_hist['total_time'][0]))
        print("Training finish!... save training results")

        self.save()
        #utils.
        generate_animation(result_dir + '/' + dataset + '/' + model_name + '/' + model_name,
                                 epoch)
        #utils.
        loss_plot(self.train_hist, os.path.join(save_dir, dataset, model_name), model_name)

    def visualize_results(self, epochs, fix=True):
        self.G.eval()

        if not os.path.exists(result_dir + '/' + dataset + '/' + model_name):
            os.makedirs(result_dir + '/' + dataset + '/' + model_name)

        tot_num_samples = min(sample_num, batch_size)
        image_frame_dim = int(np.floor(np.sqrt(tot_num_samples)))

        if fix:
            """ fixed noise """
            samples = self.G(self.sample_z_)
        else:
            """ random noise """
            if self.gpu_mode:
                sample_z_ = Variable(torch.rand((batch_size, self.z_dim)).cuda(), volatile=True)
            else:
                sample_z_ = Variable(torch.rand((batch_size, self.z_dim)), volatile=True)

            samples = self.G(sample_z_)

        if gpu_mode:
            samples = samples.cpu().data.numpy().transpose(0, 2, 3, 1)
        else:
            samples = samples.data.numpy().transpose(0, 2, 3, 1)

        #utils.
        save_images(samples[:image_frame_dim * image_frame_dim, :, :, :], [image_frame_dim, image_frame_dim],
                          result_dir + '/' + dataset + '/' + model_name + '/' + model_name + '_epoch%03d' % epochs + '.png')
    
    def forward(self, input):
        x = self.conv(input)
        x = x.view(-1, 128 * (self.input_height // 4) * (self.input_width // 4))
        x = self.fc(x)

        return x


    def save(self):
        save_dir_full = os.path.join(save_dir, dataset, model_name)

        if not os.path.exists(save_dir_full):
            os.makedirs(save_dir_full)

        torch.save(self.G.state_dict(), os.path.join(save_dir_full, model_name + '_G.pkl'))
        torch.save(self.D.state_dict(), os.path.join(save_dir_full, model_name + '_D.pkl'))

        with open(os.path.join(save_dir_full, model_name + '_history.pkl'), 'wb') as f:
            pickle.dump(self.train_hist, f)

    def load(self):
        save_dir_full = os.path.join(save_dir, dataset, model_name)

        self.G.load_state_dict(torch.load(os.path.join(save_dir_full, model_name + '_G.pkl')))
        self.D.load_state_dict(torch.load(os.path.join(save_dir_full, model_name + '_D.pkl')))

In [18]:
"""run GAN"""

gan = WGAN()

---------- Networks architecture -------------
generator (
  (fc): Sequential (
    (0): Linear (62 -> 1024)
    (1): BatchNorm1d(1024, eps=1e-05, momentum=0.1, affine=True)
    (2): ReLU ()
    (3): Linear (1024 -> 6272)
    (4): BatchNorm1d(6272, eps=1e-05, momentum=0.1, affine=True)
    (5): ReLU ()
  )
  (deconv): Sequential (
    (0): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True)
    (2): ReLU ()
    (3): ConvTranspose2d(64, 1, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (4): Sigmoid ()
  )
)
Total number of parameters: 6640193
-----------------------------------------------
discriminator (
  (conv): Sequential (
    (0): Conv2d(1, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): LeakyReLU (0.2)
    (2): Conv2d(64, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (3): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True)
    (4): LeakyReLU (0.2)
  )
  

In [19]:
"""train GAN"""

gan.train()
print(" [*] Training finished!")


training start!!
Epoch: [ 1] [ 100/ 937] D_loss: -0.00425082, G_loss: -0.49844608
Epoch: [ 1] [ 200/ 937] D_loss: -0.00443804, G_loss: -0.49809074
Epoch: [ 1] [ 300/ 937] D_loss: -0.00451562, G_loss: -0.49785259
Epoch: [ 1] [ 400/ 937] D_loss: -0.00467357, G_loss: -0.49759185
Epoch: [ 1] [ 500/ 937] D_loss: -0.00453326, G_loss: -0.49762213
Epoch: [ 1] [ 600/ 937] D_loss: -0.00469473, G_loss: -0.49750751
Epoch: [ 1] [ 700/ 937] D_loss: -0.00471956, G_loss: -0.49749386
Epoch: [ 1] [ 800/ 937] D_loss: -0.00466073, G_loss: -0.49747309
Epoch: [ 1] [ 900/ 937] D_loss: -0.00472355, G_loss: -0.49756277
Avg one epoch time: 407.00, total 1 epochs time: 407.14
Training finish!... save training results
 [*] Training finished!


In [20]:
"""test GAN"""

# visualize learned generator
gan.visualize_results(epoch)
print(" [*] Testing finished!")

 [*] Testing finished!
