In [3]:
%load_ext autoreload
%autoreload

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [4]:
import argparse
import os
import sys
import torch
import torchvision
from torch.backends import cudnn
from torch import optim
from torch.autograd import Variable
from torch.utils import data
from torchvision import transforms
from torchvision import datasets
from PIL import Image

import numpy as np

from matplotlib import pyplot as plt

In [5]:
from utils import *
# import infogan2

In [6]:
# test_noise = noise_sample(10,10,1,1,1,'cpu')
params = {}
params['num_z']      = 62
params['dis_c_dim']  = 10
params['num_con_c']  =  2

params['batch_size'] = 10
params['num_epochs'] = 30
params['epochs'] = 100

params['image_size']  = 28
params['num_workers'] = 1

#optimization params
params['lrD']   = 0.0002  # Learning rate for Discriminator
params['lrG']   = 0.001   # Learning rate for Generator
params['beta1'] = 0.5   # Momentum 1 in Adam
params['beta2'] = 0.999 # Momentum 2 in Adam

######################
test_noise = noise_sample(params['dis_c_dim'], 
                          params['num_con_c'], params['num_z'],
                          params['batch_size'],'cpu')

In [7]:
test_noise[0].size()

torch.Size([10, 74, 1, 1])

In [8]:
class Generator(nn.Module):
    def __init__(self, z_dim=62, cc_dim=2, dc_dim=10):
        super(Generator, self).__init__()
        i_dim = z_dim + cc_dim + dc_dim
        self.main = nn.Sequential(
          nn.ConvTranspose2d(i_dim, 1024, 1, 1, bias=False),
          nn.BatchNorm2d(1024),
          nn.ReLU(True),
          nn.ConvTranspose2d(1024, 128, 7, 1, bias=False),
          nn.BatchNorm2d(128),
          nn.ReLU(True),
          nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
          nn.BatchNorm2d(64),
          nn.ReLU(True),
          nn.ConvTranspose2d(64, 1, 4, 2, 1, bias=False),
          nn.Sigmoid()
    )
    def forward(self, z):
        out = self.main(z)
        return out
    
class Discriminator(nn.Module):
    def __init__(self, cc_dim = 1, dc_dim = 10):
        super(Discriminator, self).__init__()
        self.cc_dim = cc_dim
        self.dc_dim = dc_dim
        
        self.main = nn.Sequential(
          nn.Conv2d(1, 64, 4, 2, 1),
          nn.LeakyReLU(0.1, inplace=True),
          nn.Conv2d(64, 128, 4, 2, 1, bias=False),
          nn.BatchNorm2d(128),
          nn.LeakyReLU(0.1, inplace=True),
          nn.Conv2d(128, 1024, 7, bias=False),
          nn.BatchNorm2d(1024),
          nn.LeakyReLU(0.1, inplace=True),
        )
        
        self.linear = nn.Sequential(
            nn.Conv2d(1024, 1, 1),
            nn.Sigmoid()
        )
        
    def front_end(self, x):
        return self.main(x)
    
    def forward(self, x):
        out = self.main(x)
        out = self.linear(out).squeeze()
        return out #prob of being real

class Q(nn.Module):
    def __init__(self, cc_dim=2, dc_dim=10):
        super(Q, self).__init__()
        self.conv = nn.Conv2d(1024, 128, 1, bias=False)
        self.bn = nn.BatchNorm2d(128)
        self.lReLU = nn.LeakyReLU(0.1, inplace=True)
        self.conv_disc = nn.Conv2d(128, 10, 1)
        self.conv_mu = nn.Conv2d(128, 2, 1)
        self.conv_var = nn.Conv2d(128, 2, 1)
    def forward(self, x):
        y = self.lReLU(self.bn(self.conv(x)))
        
        disc_logits = self.conv_disc(y).squeeze()
        mu = self.conv_mu(y).squeeze()
        var = self.conv_var(y).squeeze().exp()

        return disc_logits, mu, var

In [17]:
gen = Generator(params['num_z'], params['num_con_c'], params['dis_c_dim'])
dsc = Discriminator(params['num_con_c'], params['dis_c_dim'])
q   = Q(params['num_con_c'], params['dis_c_dim'])
# gen(test_noise[0]).size()
q(dsc.front_end(gen(test_noise[0])))[2].size()
# dsc.front_end(gen(test_noise[0])).size()

torch.Size([10, 2])

In [18]:
#create dataset
transform = transforms.Compose([
    transforms.Resize((params['image_size'], params['image_size'])),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
dataset = datasets.MNIST('./MNIST', train=True, transform=transform, target_transform=None, download=True)

data_loader = data.DataLoader(dataset = dataset,
                              batch_size = params['batch_size'],
                              shuffle = True,
                              num_workers = params['num_workers'])

In [19]:
g_o = optim.Adam(gen.parameters(), 
                 lr=params['lrG'], betas=[params['beta1'], params['beta2']])
d_o = optim.Adam([{'params': dsc.parameters(), 'params': q.parameters()}], 
                 lr=params['lrD'], betas=[params['beta1'], params['beta2']])

In [20]:
device = 'cpu'
if torch.cuda.is_available():
    print('using cuda')
    gen.cuda()
    dsc.cuda()
    q.cuda()
    device = 'cuda'

gen.apply(weights_init)
dsc.apply(weights_init)
q.apply(weights_init)

using cuda


Q(
  (conv): Conv2d(1024, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)
  (bn): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (lReLU): LeakyReLU(negative_slope=0.1, inplace)
  (conv_disc): Conv2d(128, 10, kernel_size=(1, 1), stride=(1, 1))
  (conv_mu): Conv2d(128, 2, kernel_size=(1, 1), stride=(1, 1))
  (conv_var): Conv2d(128, 2, kernel_size=(1, 1), stride=(1, 1))
)

In [21]:
G_losses = []
D_losses = []

BCE = nn.BCELoss().to(device) #Binary Cross Entropy loss
CE  = nn.CrossEntropyLoss().to(device)
CQ  = NormalNLLLoss()

for epoch in range(params['epochs']):
    for n_i, batch_data in enumerate(data_loader):
        ##########################
        # Optimize Discriminator #
        ##########################
        d_o.zero_grad()
        ####### Real Data  #######
        real_im   = batch_data[0].to(device) # batch data also contains label info
        real_prob = dsc(real_im)  # propability of classifying as real
#         label.data.fill_(1)
        label = torch.full((params['batch_size'], ), 1, device=device)
        loss_real = BCE(real_prob, label) #
        loss_real.backward()
        ####### Fake Data  #######
        noise, idx = noise_sample(params['dis_c_dim'], 
                          params['num_con_c'], params['num_z'],
                          params['batch_size'],device)
        fake_im   = gen(noise)
        fake_prob = dsc(real_im)  # propability of classifying as real
        label.data.fill_(0)
        loss_fake = BCE(fake_prob, label)
        loss_fake.backward()
        
        discriminator_loss = loss_real + loss_fake
        d_o.step()
        
        ##########################
        #   Optimize Generator   #
        ##########################
        g_o.zero_grad()
        
        label.fill_(1)
        fake_prob = dsc(real_im)  # propability of classifying as real
        reconstruct_loss = BCE(fake_prob, label)
        
        q_logits, q_mu, q_var = q(dsc.front_end(fake_im))
        target = torch.LongTensor(idx).to(device)
        # Calculating loss for discrete latent code.
        dis_loss = CE(q_logits, target)
        # Calculating loss for continuous latent code.
        con_c = noise[:, params['num_z'] + params['dis_c_dim'] : ].view(-1, params['num_con_c'])

        con_loss = CQ(con_c, q_mu, q_var) * 0.1
        # Generator Loss (Reconstruct, Discreate and Latent code)
        generator_loss = reconstruct_loss + dis_loss + con_loss
        generator_loss.backward()
        g_o.step()
        ##########################
        #      Logging Part      #
        ##########################
        if (n_i) % 100 == 0 and n_i > 0:
            sys.stdout.write('Epoch/Iter:{0}/{1}, D_Loss: {2}, G_Loss: {3}\r'.format(
                epoch + 1, n_i, discriminator_loss.item(), generator_loss.item())
            )
            sys.stdout.flush()
    # Save the losses for plotting.
    G_losses.append(generator_loss.item())
    D_losses.append(discriminator_loss.item())
    print('')

Epoch/Iter:1/3200, D_Loss: 1.4587225914001465, G_Loss: 1.9764257669448853

KeyboardInterrupt: 

In [None]:
plt.plot(G_losses,label="G")
plt.plot(D_losses,label="D")
plt.show()

In [None]:
torch.save(gen.state_dict(), 'mnist_info_gan_gen')
torch.save(dsc.state_dict(), 'mnist_info_gan_dsc')