# InfoGAN
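Based on [this blog post](http://peluigi.hatenablog.com/entry/2018/08/29/120314). The mutual information between the latent codes and the generated images is estimated with separate MINE-style statistics networks (one for the discrete code, one per continuous code), trained with a Jensen–Shannon (softplus) objective.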

In [1]:
# Automatically re-import edited project modules (e.g. InfoGAN.model) before each cell runs.
%load_ext autoreload
%autoreload 2

In [2]:
import os, argparse, sys 
sys.path.append('..')
from pathlib import Path

import torch
from torch import nn, optim
from torch.nn import functional as F
from torchvision import datasets, transforms
import torchvision.utils as vutils
from torchvision.datasets import MNIST
from torchvision.transforms import ToTensor
from torch.utils.data import DataLoader
import numpy as np

from InfoGAN.model import Generator, Discriminator, ContStatistics, DiscStatistics

# Hyper-parameters, device and output location.
# NOTE(review): `lr` and `latent_size` are not referenced by the cells shown here
# (the optimizers set their own learning rates) — confirm before tuning them.
batch_size = 100
lr = 1e-4
latent_size = 256
num_epochs = 100
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
output_dir = Path("../results/InfoGAN")
output_dir.mkdir(parents=True, exist_ok=True)

# Seed every RNG the notebook draws from (model init, data shuffling, latent
# sampling) so that Restart & Run All reproduces the same samples and curves.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

In [3]:
# Quick check of which datasets are present under ../data (MNIST is loaded in the next cell).
!ls ../data

ALI_BiGAN  MNIST  tv


In [4]:
# Data, models, optimizers and the fixed latent grid used for per-epoch samples.
CLASS_LABEL = 10          # number of categories of the discrete (one-hot) code
BATCH_SIZE = batch_size   # single source of truth: the config cell
NUM_WORKERS = 8
RANGE = 1                 # scale applied to continuous codes drawn from [-1, 1]
NOISE_DIM = 62            # size of the incompressible noise vector z
assert BATCH_SIZE % CLASS_LABEL == 0, "BATCH_SIZE must be a multiple of CLASS_LABEL for the sample grid"

train_data = MNIST("../data/MNIST", train=True, download=True, transform=ToTensor())
loader = DataLoader(train_data, batch_size=BATCH_SIZE, shuffle=True,
                    drop_last=True, num_workers=NUM_WORKERS)

G = Generator().to(device)
D = Discriminator().to(device)
DiscS = DiscStatistics().to(device)
ContS1 = ContStatistics().to(device)
ContS2 = ContStatistics().to(device)
optimG = optim.Adam(G.parameters(), lr=1e-3, betas=(0.5, 0.999))
optimD = optim.Adam(D.parameters(), lr=2e-4, betas=(0.5, 0.999))
optimS = optim.Adam([{"params": DiscS.parameters()}, {"params": ContS1.parameters()},
                     {"params": ContS2.parameters()}], lr=1e-3, betas=(0.5, 0.999))
# Reusable BCE target buffer; the training loop fills it with real_label / fake_label.
label = torch.zeros(BATCH_SIZE, device=device)
real_label, fake_label = 1, 0

# Fixed latent grid: consecutive blocks of `n_cols` samples share one discrete
# class, and within each block one continuous code sweeps [-RANGE, RANGE]
# while the other is held at 0.
n_cols = BATCH_SIZE // CLASS_LABEL
c = torch.linspace(-1, 1, n_cols).repeat(CLASS_LABEL).reshape(-1, 1)
c1 = torch.cat([c, torch.zeros_like(c)], 1) * RANGE   # sweep code 1
c2 = torch.cat([torch.zeros_like(c), c], 1) * RANGE   # sweep code 2
idx = torch.arange(CLASS_LABEL).repeat_interleave(n_cols)
one_hot = torch.zeros(BATCH_SIZE, CLASS_LABEL)
one_hot[torch.arange(BATCH_SIZE), idx] = 1
fix_z = torch.empty(BATCH_SIZE, NOISE_DIM).uniform_(-1, 1)
# Layout [z | continuous codes | one-hot], reshaped to (N, C, 1, 1) for the generator.
fix_noise1 = torch.cat([fix_z, c1, one_hot], 1)[..., None, None].to(device)
fix_noise2 = torch.cat([fix_z, c2, one_hot], 1)[..., None, None].to(device)

In [5]:
def jsd_mi_loss(stat_net, x, c_joint, c_marginal):
    """Jensen-Shannon (MINE-style) mutual-information loss between ``x`` and a latent code.

    Args:
        stat_net: statistics network scoring concatenated ``[x, c]`` rows
            (assumed to return one score per row — confirm in InfoGAN.model).
        x: flattened generated samples, one row per sample.
        c_joint: the codes that actually produced ``x`` (joint samples).
        c_marginal: codes drawn independently from the prior (product of marginals).

    Returns:
        Scalar ``softplus(-T(x, c_joint)).mean() + softplus(T(x, c_marginal)).mean()``,
        the negative of the Jensen-Shannon MI lower bound. Minimising it (for
        ``stat_net`` and, through ``x``, the generator) raises the MI estimate.
    """
    joint = stat_net(torch.cat([x, c_joint], 1))
    marginal = stat_net(torch.cat([x, c_marginal], 1))
    return F.softplus(-joint).mean() + F.softplus(marginal).mean()


# Derived from the setup cell so the loop cannot drift out of sync with it.
last_iter = len(loader) - 1             # index of the final batch in an epoch
grid_nrow = BATCH_SIZE // CLASS_LABEL   # images per row in the saved sample grids
z_dim = fix_z.shape[1]                  # training noise must match the fixed-noise layout

for epoch in range(num_epochs):
    for i, (d_real, _) in enumerate(loader):
        # ---- Discriminator update ----
        optimD.zero_grad()
        d_real = d_real.to(device)
        label.fill_(real_label)
        real_prob = D(d_real).squeeze()
        real_loss = F.binary_cross_entropy_with_logits(real_prob, label)
        real_loss.backward()

        label.fill_(fake_label)
        # Latent input layout: [z | 2 continuous codes | one-hot discrete code].
        idx = torch.randint(0, CLASS_LABEL, (BATCH_SIZE,))
        disc_c = torch.eye(CLASS_LABEL)[idx][..., None, None].to(device)
        cont_c = torch.zeros(BATCH_SIZE, 2, 1, 1).uniform_(-1, 1).to(device) * RANGE
        z = torch.zeros(BATCH_SIZE, z_dim, 1, 1).uniform_(-1, 1).to(device)
        noise = torch.cat([z, cont_c, disc_c], 1)
        d_fake = G(noise)
        d_fake_series = d_fake.flatten(1)
        fake_prob = D(d_fake.detach()).squeeze()
        fake_loss = F.binary_cross_entropy_with_logits(fake_prob, label)
        fake_loss.backward()
        loss_D = real_loss + fake_loss
        optimD.step()

        # ---- Generator + statistics-network update ----
        optimG.zero_grad()
        optimS.zero_grad()
        label.fill_(real_label)
        # Adversarial loss. Its backward also deposits gradients in D; those are
        # discarded by optimD.zero_grad() at the start of the next iteration.
        inv_fake_prob = D(d_fake).squeeze()
        inv_fake_loss = F.binary_cross_entropy_with_logits(inv_fake_prob, label)
        # Independent draws c_bar ~ P(C) for the product-of-marginals term.
        idx = torch.randint(0, CLASS_LABEL, (BATCH_SIZE,))
        disc_c_bar = torch.eye(CLASS_LABEL)[idx].to(device)
        cont_c_bar = torch.zeros(BATCH_SIZE, 2, 1, 1).uniform_(-1, 1).to(device) * RANGE
        mi_disc = jsd_mi_loss(DiscS, d_fake_series, disc_c.flatten(1), disc_c_bar)
        mi_cont1 = jsd_mi_loss(ContS1, d_fake_series,
                               cont_c[:, 0].flatten(1), cont_c_bar[:, 0].flatten(1))
        mi_cont2 = jsd_mi_loss(ContS2, d_fake_series,
                               cont_c[:, 1].flatten(1), cont_c_bar[:, 1].flatten(1))
        mi = (mi_disc + mi_cont1 + mi_cont2) / 3

        loss = mi + inv_fake_loss
        loss.backward()
        optimG.step()
        optimS.step()
        if i == last_iter and epoch % 4 == 0:
            print("epoch [{}/{}], iter [{}/{}], D : {:.3}, G : {:.3}, S : {:.3}".format(
                epoch, num_epochs, i, len(loader), loss_D.item(), inv_fake_loss.item(), mi.item()
            ))
    # Render the two fixed latent grids once per epoch (no autograd graph needed).
    with torch.no_grad():
        fake1 = G(fix_noise1)
        fake2 = G(fix_noise2)
        vutils.save_image(fake1,
                          str(output_dir / f"{epoch}epoch_fake1.png"),
                          normalize=True, nrow=grid_nrow)
        vutils.save_image(fake2,
                          str(output_dir / f"{epoch}epoch_fake2.png"),
                          normalize=True, nrow=grid_nrow)

epoch [0/100], iter [599/600], D : 0.509, G : 1.59, S : 0.0288
epoch [4/100], iter [599/600], D : 1.23, G : 0.88, S : 0.00743
epoch [8/100], iter [599/600], D : 1.11, G : 0.877, S : 0.0388
epoch [12/100], iter [599/600], D : 1.24, G : 1.13, S : 0.00333
epoch [16/100], iter [599/600], D : 1.09, G : 0.994, S : 0.00202
epoch [20/100], iter [599/600], D : 1.04, G : 1.16, S : 0.0108
epoch [24/100], iter [599/600], D : 1.04, G : 0.962, S : 0.0277
epoch [28/100], iter [599/600], D : 1.17, G : 1.37, S : 0.0017
epoch [32/100], iter [599/600], D : 0.918, G : 1.26, S : 0.000884
epoch [36/100], iter [599/600], D : 1.07, G : 1.48, S : 0.00046
epoch [40/100], iter [599/600], D : 0.938, G : 1.18, S : 0.000251
epoch [44/100], iter [599/600], D : 1.01, G : 0.816, S : 0.00739
epoch [48/100], iter [599/600], D : 0.878, G : 1.11, S : 0.000651
epoch [52/100], iter [599/600], D : 0.851, G : 1.04, S : 0.000169
epoch [56/100], iter [599/600], D : 0.935, G : 1.13, S : 0.000798
epoch [60/100], iter [599/600], D