In [3]:
import torch
import torchvision
import torchvision.transforms as transforms
import torchvision.datasets as dset 
from torchvision.utils import save_image
from torch.autograd import Variable
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt 
import torch.optim as optim
import math
import itertools
from IPython import display

In [4]:
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,),(0.5,))])
realimages = torchvision.datasets.MNIST(root='./data', train=True,download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(realimages, batch_size=100,shuffle=True, num_workers=2)

In [93]:
class FrontEnd(nn.Module):
    """Front End for D and Q"""
    def __init__(self):
        super(FrontEnd,self).__init__()
        self.frontend = nn.Sequential(
            nn.Conv2d(1, 64, 4, 2, 1),
          nn.LeakyReLU(0.1, inplace=True),
          nn.Conv2d(64, 128, 4, 2, 1, bias=False),
          nn.BatchNorm2d(128),
          nn.LeakyReLU(0.1, inplace=True),
) 
        
    def forward(self,x):
        output = self.frontend(x)
        #print(x.shape)
        output = output.view(-1,128*7*7)
        return output

In [94]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator,self).__init__()
        self.probability = nn.Sequential(
            nn.Linear(128*7*7, 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 128),
            nn.BatchNorm1d(128),
            nn.LeakyReLU(0.2),
            nn.Linear(128, 1),
        )
    def forward(self,x):
        x = self.probability(x)
        x = F.sigmoid(x)
        return x

In [95]:
class Generator(nn.Module):
    def __init__(self):
        super(Generator,self).__init__()
        self.fc = nn.Sequential(
            nn.Linear(74, 1024),
            nn.BatchNorm1d(1024),
            nn.ReLU(),
            nn.Linear(1024, 128 *7*7),
            nn.BatchNorm1d(128 *7*7),
            nn.ReLU(),
        )
        self.generate = nn.Sequential(
            nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            nn.ConvTranspose2d(64, 1, 4, 2, 1, bias=False),
            nn.Tanh()
        )
    def forward(self,x):
        x = self.fc(x)
        x = x.view(-1, 128, 7, 7)
        x = self.generate(x)
        return x

In [96]:
class Q(nn.Module):
    def __init__(self):
        super(Q,self).__init__()
        self.classprob = nn.Sequential(
            nn.Linear(128*7*7, 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 512),
            nn.BatchNorm1d(512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 128),
            nn.BatchNorm1d(128),
            nn.LeakyReLU(0.2),
            nn.Linear(128, 10),
        )
    def forward(self,x):
        x = self.classprob(x)
        return x

def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

In [97]:
def generate_noise(batch_size):
    idx = np.random.randint(10,size=batch_size)
    c = np.zeros((batch_size,10))
    c[range(batch_size),idx] = 1
    c = torch.Tensor(c)
    noise = torch.FloatTensor(batch_size,64)
    noise.data.uniform_(-10,10)
    z = torch.cat([noise,c],1).view(-1,74)
    return z,idx

In [110]:
FE = FrontEnd()
D = Discriminator()
G = Generator()
QE = Q()
FE,D,G,QE = FE.cuda(),D.cuda(),G.cuda(),QE.cuda()
for i in [FE,D,G,QE]:
    i.apply(weights_init)

In [111]:
criterionD = nn.BCELoss()
criterionQ = nn.CrossEntropyLoss()
optimizerD = optim.Adam([{'params':FE.parameters()}, {'params':D.parameters()}], lr=0.0002, betas=(0.5, 0.99))
optimizerG = optim.Adam([{'params':G.parameters()}, {'params':QE.parameters()}], lr=0.001, betas=(0.5, 0.99))

In [112]:
num_epochs = 25
num_test_samples =100
idix = np.arange(10).repeat(10)
one_hot = np.zeros((num_test_samples, 10))
one_hot[range(num_test_samples), idix] = 1
fix_noise = torch.FloatTensor(num_test_samples, 64)
fix_noise.data.uniform_(-10, 10)
fixed_z = torch.cat([fix_noise,torch.Tensor(one_hot)],1).view(-1,74)
#print(fixed_z.shape)
fixed_z = fixed_z.cuda()

# create figure for plotting
size_figure_grid = int(math.sqrt(num_test_samples))


In [113]:
for epoch in range(num_epochs):
    for n, (images, _) in enumerate(train_loader):
        bs = images.size(0)
        images = images.cuda()
        images = Variable(images)
        #print(images.shape)
        # Discriminator
        optimizerD.zero_grad()
        label = torch.Tensor(np.ones(bs))
        label = label.cuda()
        label = Variable(label,requires_grad=False)
        feout1 =  FE(images)
        real_prob = D(feout1)
        real_loss = criterionD(real_prob,label)
        real_loss.backward()
        
        z,idx = generate_noise(bs)
        z = z.cuda()
        z = Variable(z)
        fake_img = G(z)
        feout1 = FE(fake_img)
        fake_prob = D(feout1)
        label2 = torch.Tensor(np.zeros(bs))
        label2 = label2.cuda()
        label2 = Variable(label2,requires_grad = False)
        fake_loss = criterionD(fake_prob,label2)
        fake_loss.backward(retain_graph=True)
        
        D_loss = real_loss + fake_loss
        optimizerD.step()
        
        # Generator and Q
        optimizerG.zero_grad()
        feout1 = FE(fake_img)
        fake_prob = D(feout1)
        label = torch.Tensor(np.ones(bs))
        label = label.cuda()
        label = Variable(label,requires_grad=False)
        reconstruct_loss = criterionD(fake_prob,label)
        
        q_logits = QE(feout1)
        class_ = torch.LongTensor(idx).cuda()
        target = Variable(class_)
        #print(q_logits.shape,target.shape)
        q_loss = criterionQ(q_logits,target)
        G_loss = reconstruct_loss + q_loss
        G_loss.backward(retain_graph=True)
        optimizerG.step()
        
        if (n+1) % 100 == 0:
            print('Epoch/Iter:{0}/{1}, Dloss: {2}, Gloss: {3}'.format(
            epoch, n, D_loss.data.cpu().numpy(),
            G_loss.data.cpu().numpy())
          )  
    test_images = G(fixed_z)
    #print(test_images[0].shape)
    save_image(test_images.data,'./samples/infoGAN2/epoch_{:d}_pytorch.png'.format(epoch),nrow=10)

  "Please ensure they have the same size.".format(target.size(), input.size()))


Epoch/Iter:0/99, Dloss: 0.9887743592262268, Gloss: 3.324737548828125
Epoch/Iter:0/199, Dloss: 0.747309148311615, Gloss: 3.584533452987671
Epoch/Iter:0/299, Dloss: 0.5177753567695618, Gloss: 3.793801784515381
Epoch/Iter:0/399, Dloss: 0.4259243607521057, Gloss: 3.3666489124298096
Epoch/Iter:0/499, Dloss: 0.3369847238063812, Gloss: 2.796945810317993
Epoch/Iter:0/599, Dloss: 0.2611375153064728, Gloss: 2.7085366249084473
Epoch/Iter:1/99, Dloss: 0.18423286080360413, Gloss: 2.6812591552734375
Epoch/Iter:1/199, Dloss: 0.15251821279525757, Gloss: 3.141993999481201
Epoch/Iter:1/299, Dloss: 0.12600767612457275, Gloss: 3.281402826309204
Epoch/Iter:1/399, Dloss: 0.08165092021226883, Gloss: 3.599621534347534
Epoch/Iter:1/499, Dloss: 0.07537064701318741, Gloss: 3.6825878620147705
Epoch/Iter:1/599, Dloss: 0.053766991943120956, Gloss: 3.902778148651123
Epoch/Iter:2/99, Dloss: 0.050118181854486465, Gloss: 4.235359191894531
Epoch/Iter:2/199, Dloss: 0.06988492608070374, Gloss: 4.467583179473877
Epoch/Iter

Epoch/Iter:19/199, Dloss: 0.07615428417921066, Gloss: 5.925811767578125
Epoch/Iter:19/299, Dloss: 0.03363516926765442, Gloss: 5.4948506355285645
Epoch/Iter:19/399, Dloss: 0.12530186772346497, Gloss: 6.18087911605835
Epoch/Iter:19/499, Dloss: 0.021310409530997276, Gloss: 6.340580940246582
Epoch/Iter:19/599, Dloss: 0.07556824386119843, Gloss: 6.676409721374512
Epoch/Iter:20/99, Dloss: 0.036128487437963486, Gloss: 5.577206134796143
Epoch/Iter:20/199, Dloss: 0.030040476471185684, Gloss: 5.977469444274902
Epoch/Iter:20/299, Dloss: 0.029988255351781845, Gloss: 5.273538589477539
Epoch/Iter:20/399, Dloss: 0.02847667597234249, Gloss: 6.0708465576171875
Epoch/Iter:20/499, Dloss: 0.007585631683468819, Gloss: 6.227371692657471
Epoch/Iter:20/599, Dloss: 0.07143455743789673, Gloss: 6.247979640960693
Epoch/Iter:21/99, Dloss: 0.08764276653528214, Gloss: 6.224170207977295
Epoch/Iter:21/199, Dloss: 0.028515130281448364, Gloss: 6.63681697845459
Epoch/Iter:21/299, Dloss: 0.1726539134979248, Gloss: 3.77998