# InfoGAN Demo
A quick demo of InfoGAN

## Packages & Parameter Setting

In [1]:
from InfoGAN import Generator, DiscriminatorFrontEnd, DiscriminatorBackend,DiscriminatorInfo
from util import *
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable

import torchvision.datasets as dset
from torch.utils.data import DataLoader
import torchvision.transforms as transforms

import matplotlib.pyplot as plt
% matplotlib inline

# Run configuration for the demo.
BATCH_SIZE = 100   # mini-batch size handed to the DataLoader
NUM_EPOCHS = 200   # number of passes of the outer training loop
USE_GPU = True     # when True, models, criteria and tensors are moved with .cuda()

DISPLAY_STEP = 100 # print running losses every DISPLAY_STEP steps within an epoch
PLOT_EPOCH = 5     # call save_fig on the demo noise every PLOT_EPOCH epochs

# Dataset Preparation
Set `download=True` if MNIST is not available on your machine.

In [2]:
# Load MNIST from ./data/ (no download attempted) with the ToTensor transform,
# then serve it in shuffled mini-batches of BATCH_SIZE.
mnist_dataset = dset.MNIST(
    root='./data/',
    transform=transforms.ToTensor(),
    download=False,
)
dataloader = DataLoader(
    mnist_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=1,
)

# Model Construction

For details about model implementation, please refer to `InfoGAN.py`

In [None]:
# Networks. The tuple is evaluated left to right, so the modules are still
# constructed in the order D, G, TF, Q.
D, G, TF, Q = (
    DiscriminatorFrontEnd(),
    Generator(),
    DiscriminatorBackend(),
    DiscriminatorInfo(),
)

# Loss functions: real/fake decision, discrete code, continuous code.
D_criterion, Q_discr_criterion, Q_conti_criterion = (
    nn.BCEWithLogitsLoss(),
    nn.CrossEntropyLoss(),
    nn.MSELoss(),
)

if USE_GPU:
    # Move every module and criterion to the GPU.
    D, G, TF, Q = D.cuda(), G.cuda(), TF.cuda(), Q.cuda()
    D_criterion, Q_discr_criterion, Q_conti_criterion = (
        D_criterion.cuda(),
        Q_discr_criterion.cuda(),
        Q_conti_criterion.cuda(),
    )

# optimD updates D and TF; optimG updates G and Q. Each network keeps its
# own parameter group.
optimD = optim.Adam(
    [dict(params=D.parameters()), dict(params=TF.parameters())],
    lr=0.0002,
)
optimG = optim.Adam(
    [dict(params=G.parameters()), dict(params=Q.parameters())],
    lr=0.001,
)

# Training Progress

Training Log will be stored as `infogan.log` under `InfoGAN_pytorch/`

The generator's output for a fixed 10-class input, combined with 1 continuous latent code swept over range(-2,2,0.5),

will be stored under `InfoGAN_pytorch/` automatically every 5 epochs.



## WorkFlow
### Step 1. Training Discriminator
   
   - Sample random noise to create fake image
   - Train Discriminator with Real image from MNIST and Fake image from Generator

### Step 2. Training Generator and Q

   - Update Generator with fixed D, loss = D's classification error
   - As proposed in the InfoGAN paper, update Q jointly with the Generator (important)

In [None]:
# Training cell: alternates a discriminator update (D + TF) with a joint
# generator / Q update for NUM_EPOCHS epochs, writing per-step losses to
# `infogan.log` and printing running / per-epoch averages.
training_message = 'epoch-{:3}-step-{:3}-D_loss-{:.6f}-GQ_loss-{:.4f}'
epoch_end_message = 'epoch-{:3}-D_loss-{:.4f}-GQ_loss-{:.4f}-Image_loss-{:.4f}-Disc_loss-{:.4f}-Conti_loss-{:.4f}'
log_message = '{:.4f},{:.4f},{:.4f},{:.4f}\n'

# The two Q (mutual-information) loss terms are weighted 0.0 before this epoch
# index and 1.0 from it onwards.
INFO_LOSS_START_EPOCH = 50


def _to_float(loss):
    """Return a scalar loss as a plain Python float.

    Going through ``.data`` drops the autograd history, so values stored for
    logging or running totals do not keep the computation graph alive.
    """
    return float(loss.data.cpu().numpy())


def _on_device(tensor):
    """Wrap ``tensor`` in a Variable, moving it to the GPU first if USE_GPU."""
    return Variable(tensor.cuda()) if USE_GPU else Variable(tensor)


demo_z1, demo_z2 = get_test_noise()
demo_z1 = _on_device(demo_z1)
demo_z2 = _on_device(demo_z2)

# `with` guarantees the log file is flushed and closed, even when training is
# interrupted or raises part-way through.
with open('infogan.log','w') as log:
    log.write('D_loss,G_loss,Disc_loss,Conti_loss\n')

    for epoch in range(NUM_EPOCHS):
        # Output Demo
        if (epoch%PLOT_EPOCH) == 0:
            save_fig(demo_z1,G,'fig/wiegthed_z1_epoch{}.jpg'.format(epoch))
            save_fig(demo_z2,G,'fig/wiegthed_z2_epoch{}.jpg'.format(epoch))

        # Weights of the discrete / continuous Q losses; constant within an epoch.
        w1 = w2 = 1.0 if epoch >= INFO_LOSS_START_EPOCH else 0.0

        # Running totals are plain floats: summing the loss tensors themselves
        # would retain every step's autograd graph until the end of the epoch.
        D_loss = 0.0        # accumulated loss of D
        G_loss = 0.0        # accumulated loss of G
        Q_loss_dis = 0.0    # accumulated loss of Q (discrete)
        Q_loss_conti = 0.0  # accumulated loss of Q (continuous)

        for step,batch_data in enumerate(dataloader):
            batch_size = batch_data[0].size(0)
            ones_label = _on_device(torch.ones(batch_size).view(-1,1))
            zeros_label = _on_device(torch.zeros(batch_size).view(-1,1))

            # Step 1. Discriminator update (D + TF).
            optimD.zero_grad()
            ### Real images, target 1
            real_x = _on_device(batch_data[0])
            prob_real = TF(D(real_x))
            D_real_loss = D_criterion(prob_real,ones_label)
            D_real_loss.backward()
            ### Fake images, target 0
            z, fake_idx = sample_noise(batch_size)
            z = _on_device(torch.Tensor(z))
            fake_x = G(z)
            # detach(): the discriminator step must not backpropagate into G.
            prob_fake = TF(D(fake_x.detach()))
            D_fake_loss = D_criterion(prob_fake,zeros_label)
            D_fake_loss.backward()
            optimD.step()

            # Step 2. Joint generator / Q update.
            optimG.zero_grad()
            ### Image reality: G is rewarded when fakes are scored as real (target 1).
            conv_feature = D(fake_x)
            discriminator_prediction = TF(conv_feature)
            generator_loss = D_criterion(discriminator_prediction,ones_label)
            ### Mutual info: the first 10 outputs of Q are scored against the
            ### sampled class index, the remaining outputs against the last
            ### two entries of z.
            pred_c = Q(conv_feature)
            fake_idx = _on_device(torch.LongTensor(fake_idx))
            digit_classify_loss = Q_discr_criterion(pred_c[:,:10],fake_idx)
            conti_loss = Q_conti_criterion(pred_c[:,10:],z[:,-2:])

            G_Q_loss = generator_loss + w1*digit_classify_loss + w2*conti_loss
            G_Q_loss.backward()
            # Only G and Q are stepped here; gradients that reached D/TF are
            # cleared by optimD.zero_grad() at the start of the next step.
            optimG.step()

            # Convert each loss once; reuse for both the log and the totals.
            d_step = _to_float(D_real_loss+D_fake_loss)
            g_step = _to_float(generator_loss)
            dis_step = _to_float(digit_classify_loss)
            conti_step = _to_float(conti_loss)

            D_loss += d_step
            G_loss += g_step
            Q_loss_dis += dis_step
            Q_loss_conti += conti_step

            log.write(log_message.format(d_step,g_step,dis_step,conti_step))

            if step%DISPLAY_STEP == 0:
                print(training_message.format(epoch+1,step,D_loss/(step+1),
                                              (G_loss+Q_loss_dis+Q_loss_conti)/(step+1)),
                      flush=True,end='\r')

        # End of epoch: turn the totals into per-step averages.
        num_steps = step+1
        D_loss = D_loss/num_steps
        G_loss = G_loss/num_steps
        Q_loss_dis = Q_loss_dis/num_steps
        Q_loss_conti = Q_loss_conti/num_steps
        print(epoch_end_message.format(epoch+1,D_loss,G_loss+Q_loss_dis+Q_loss_conti,G_loss,Q_loss_dis,Q_loss_conti))


epoch-  1-D_loss-1.1713-GQ_loss-3.3608-Image_loss-0.6297-Disc_loss-2.3405-Conti_loss-0.3906
epoch-  2-D_loss-1.2225-GQ_loss-3.3818-Image_loss-0.6563-Disc_loss-2.3370-Conti_loss-0.3885
epoch-  3-D_loss-1.2376-GQ_loss-3.4157-Image_loss-0.6763-Disc_loss-2.3325-Conti_loss-0.4069
epoch-  4-D_loss-1.2488-GQ_loss-3.4396-Image_loss-0.6769-Disc_loss-2.3362-Conti_loss-0.4265
epoch-  5-D_loss-1.2563-GQ_loss-3.4299-Image_loss-0.6776-Disc_loss-2.3356-Conti_loss-0.4167
epoch-  6-D_loss-1.2730-GQ_loss-3.4187-Image_loss-0.6764-Disc_loss-2.3423-Conti_loss-0.4000
epoch-  7-D_loss-1.2794-GQ_loss-3.4218-Image_loss-0.6761-Disc_loss-2.3418-Conti_loss-0.4040
epoch-  8-D_loss-1.2952-GQ_loss-3.4162-Image_loss-0.6749-Disc_loss-2.3407-Conti_loss-0.4007
epoch-  9-D_loss-1.3038-GQ_loss-3.4032-Image_loss-0.6754-Disc_loss-2.3406-Conti_loss-0.3871
epoch- 10-D_loss-1.3093-GQ_loss-3.4010-Image_loss-0.6745-Disc_loss-2.3396-Conti_loss-0.3869
epoch- 11-D_loss-1.3119-GQ_loss-3.4120-Image_loss-0.6756-Disc_loss-2.3439-Conti_