In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from researchlib.single_import import *

In [3]:
train_loader = FromPublic('mnist', 'train', batch_size=128, normalize=False, shuffle=True, pin_memory=True, num_workers=4)
test_loader = FromPublic('mnist', 'test', batch_size=128, normalize=False, shuffle=True, pin_memory=True, num_workers=4)

In [50]:
front = builder([
    nn.Conv2d(1, 64, 4, 2, 1, bias=False),
    nn.BatchNorm2d(64),
    nn.SELU(inplace=True),
    nn.Conv2d(64, 128, 4, 2, 1, bias=False),
    nn.BatchNorm2d(128),
    nn.SELU(inplace=True),
    nn.Conv2d(128, 1024, 7, bias=False),
    nn.BatchNorm2d(1024),
    nn.SELU(inplace=True)
])

dd = builder([
    nn.Conv2d(1024, 1, 1),
    nn.Sigmoid(),
    Reshape((-1, 1))
])

class Q_(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(1024, 128, 1, bias=False)
        self.bn = nn.BatchNorm2d(128)
        self.lReLU = nn.LeakyReLU(0.1, inplace=True)
        self.conv_disc = nn.Conv2d(128, 10, 1)
        self.conv_mu = nn.Conv2d(128, 2, 1)
        self.conv_var = nn.Conv2d(128, 2, 1)

    def forward(self, x):
        y = self.conv(x)
        disc_logits = self.conv_disc(y).squeeze()
        mu = self.conv_mu(y).squeeze()
        var = self.conv_var(y).squeeze().exp()
        return disc_logits, mu, var 
qq = Q_()

class D_(nn.Module):
    def __init__(self, front, dd, qq):
        super().__init__()
        self.fe = front
        self.d = dd
        self.q = qq
    
    def forward(self, x):
        x = self.fe(x)
        dis = self.d(x)
        logits, mu, var = self.q(x)
        return dis, logits, mu, var
D = D_(front, dd, qq)
        

G = builder([
    nn.ConvTranspose2d(74, 1024, 1, 1, bias=False),
    nn.BatchNorm2d(1024),
    nn.SELU(inplace=True),
    nn.ConvTranspose2d(1024, 128, 7, 1, bias=False),
    nn.BatchNorm2d(128),
    nn.SELU(inplace=True),
    nn.ConvTranspose2d(128, 64, 4, 2, 1, bias=False),
    nn.BatchNorm2d(64),
    nn.SELU(inplace=True),
    nn.ConvTranspose2d(64, 1, 4, 2, 1, bias=False),
    nn.Sigmoid()
])

In [74]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import torch.autograd as autograd

import torchvision.datasets as dset
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.utils import save_image

import numpy as np

class log_gaussian:

  def __call__(self, x, mu, var):

    logli = -0.5*(var.mul(2*np.pi)+1e-6).log() - \
            (x-mu).pow(2).div(var.mul(2.0)+1e-6)
    
    return logli.sum(1).mean().mul(-1)

class Trainer:

  def __init__(self, G, D, train_loader):

    self.G = G.cuda()
    self.D = D.cuda()
    self.train_loader = train_loader
    self.batch_size = 128

  def _noise_sample(self, bs, dis_c=10, con_c=2, noise=62):
    idx = np.random.randint(dis_c, size=bs)
    c = np.zeros((bs, dis_c))
    c[range(bs), idx] = 1.0
    c = torch.from_numpy(c).float()

    con_c_ = torch.FloatTensor(bs, con_c)
    noise_ = torch.FloatTensor(bs, noise)
    
    con_c_.data.uniform_(-1.0, 1.0)
    noise_.data.uniform_(-1.0, 1.0)
    
    z = torch.cat([noise_, c, con_c_], 1).view(-1, 74, 1, 1)
    return z, idx, con_c_

  def train(self):
    label = torch.FloatTensor(self.batch_size, 1).cuda()
    label = Variable(label, requires_grad=False)
    
    criterionD = nn.BCELoss().cuda()
    criterionQ_dis = nn.CrossEntropyLoss().cuda()
    criterionQ_con = log_gaussian()

    optimD = optim.Adam(self.D.parameters(), lr=0.0002, betas=(0.5, 0.99))
    optimG = optim.Adam(self.G.parameters(), lr=0.001, betas=(0.5, 0.99))
    
    # for test
    c = np.linspace(-1, 1, 10).reshape(1, -1)
    c = np.repeat(c, 10, 0).reshape(-1, 1)
    c1 = np.hstack([c, np.zeros_like(c)])
    c2 = np.hstack([np.zeros_like(c), c])
    idx = np.arange(10).repeat(10)
    one_hot = np.zeros((100, 10))
    one_hot[range(100), idx] = 1
    fix_noise = torch.Tensor(100, 62).uniform_(-1, 1)


    for epoch in range(100):
      for num_iters, batch_data in enumerate(self.train_loader, 0):
        
        #--------------------------------------------------- D
        optimD.zero_grad()
                
        # data
        x, _ = batch_data
        label = torch.ones(x.size(0), 1).cuda()
        
        # forward
        probs_real = self.D(x.cuda())[0]
        
        # loss
        loss_real = criterionD(probs_real, label)
        loss_real.backward()
        
        #@@@@@@@@@@@@@@@@@
        
        # fake data
        z, idx, con_c = self._noise_sample(x.size(0))        
        z = z.cuda()
        con_c = con_c.cuda()
        label.data.fill_(0)
        
        # forward
        fake_x = self.G(z)
        probs_fake = self.D(fake_x.detach())[0]
        
        # loss
        loss_fake = criterionD(probs_fake, label)
        loss_fake.backward()
        D_loss = loss_real + loss_fake

        optimD.step()
        #----------------------------------------------------END D
        
        
        #----------------------------------------------------G,Q
        optimG.zero_grad()

        # data
        label.data.fill_(1.0)
        target = Variable(torch.LongTensor(idx).cuda())
        
        # forward
        probs_fake, q_logits, q_mu, q_var = self.D(fake_x)

        # loss
        reconstruct_loss = criterionD(probs_fake, label)
        dis_loss = criterionQ_dis(q_logits, target)
        con_loss = criterionQ_con(con_c, q_mu, q_var)*0.1
        G_loss = reconstruct_loss + dis_loss + con_loss
        
        G_loss.backward()
        optimG.step()
        #-----------------------------------------------------END G,Q
        
        
        if num_iters % 100 == 0:
            print('Epoch/Iter:{0}/{1}, Dloss: {2}, Gloss: {3}'.format(
            epoch, num_iters, D_loss.data.cpu().numpy(),
            G_loss.data.cpu().numpy())
            )
            
            noise = fix_noise.float()
            dis_c = torch.Tensor(one_hot).float()
            con_c = torch.from_numpy(c1).float()
            z = torch.cat([noise, dis_c, con_c], 1).view(-1, 74, 1, 1)
            x_save = self.G(z.cuda())
            save_image(x_save.data, './tmp/c1.png', nrow=10)

            con_c = torch.from_numpy(c2).float()
            z = torch.cat([noise, dis_c, con_c], 1).view(-1, 74, 1, 1)
            x_save = self.G(z.cuda())
            save_image(x_save.data, './tmp/c2.png', nrow=10)



In [75]:
t = Trainer(G, D, train_loader)
t.train()

Epoch/Iter:0/0, Dloss: 1.3062171936035156, Gloss: 2.7148499488830566
Epoch/Iter:0/100, Dloss: 1.085994839668274, Gloss: 2.6728336811065674
Epoch/Iter:0/200, Dloss: 1.2241331338882446, Gloss: 2.6113290786743164
Epoch/Iter:0/300, Dloss: 1.0840668678283691, Gloss: 2.5850718021392822
Epoch/Iter:0/400, Dloss: 1.1240341663360596, Gloss: 2.6426939964294434
Epoch/Iter:1/0, Dloss: 1.1300783157348633, Gloss: 2.6917238235473633
Epoch/Iter:1/100, Dloss: 1.100205421447754, Gloss: 2.6043779850006104
Epoch/Iter:1/200, Dloss: 1.1939834356307983, Gloss: 2.577333688735962
Epoch/Iter:1/300, Dloss: 1.2049721479415894, Gloss: 2.6343424320220947
Epoch/Iter:1/400, Dloss: 1.1715185642242432, Gloss: 2.549881935119629
Epoch/Iter:2/0, Dloss: 1.1423088312149048, Gloss: 2.6548547744750977
Epoch/Iter:2/100, Dloss: 1.1370999813079834, Gloss: 2.62497878074646
Epoch/Iter:2/200, Dloss: 1.1598628759384155, Gloss: 2.5694077014923096
Epoch/Iter:2/300, Dloss: 1.0810167789459229, Gloss: 2.6170501708984375
Epoch/Iter:2/400, 

Epoch/Iter:23/200, Dloss: 1.0819087028503418, Gloss: 2.71279239654541
Epoch/Iter:23/300, Dloss: 1.156919240951538, Gloss: 2.7179269790649414
Epoch/Iter:23/400, Dloss: 1.1303086280822754, Gloss: 2.735119104385376
Epoch/Iter:24/0, Dloss: 1.0527141094207764, Gloss: 2.7640984058380127
Epoch/Iter:24/100, Dloss: 1.0398590564727783, Gloss: 2.7919905185699463
Epoch/Iter:24/200, Dloss: 1.1066138744354248, Gloss: 2.770742177963257
Epoch/Iter:24/300, Dloss: 1.0491502285003662, Gloss: 2.7117767333984375
Epoch/Iter:24/400, Dloss: 1.0576646327972412, Gloss: 2.7272067070007324
Epoch/Iter:25/0, Dloss: 1.0098536014556885, Gloss: 2.77632999420166
Epoch/Iter:25/100, Dloss: 1.041809320449829, Gloss: 2.808562994003296
Epoch/Iter:25/200, Dloss: 1.0218186378479004, Gloss: 2.76668381690979
Epoch/Iter:25/300, Dloss: 1.0840929746627808, Gloss: 2.718836784362793
Epoch/Iter:25/400, Dloss: 1.076441764831543, Gloss: 2.804506301879883
Epoch/Iter:26/0, Dloss: 1.0222265720367432, Gloss: 2.7945737838745117
Epoch/Iter:2

Epoch/Iter:46/300, Dloss: 0.8499060869216919, Gloss: 3.0803775787353516
Epoch/Iter:46/400, Dloss: 0.8373176455497742, Gloss: 3.05374813079834
Epoch/Iter:47/0, Dloss: 0.9004802703857422, Gloss: 3.082582473754883
Epoch/Iter:47/100, Dloss: 0.9401005506515503, Gloss: 3.06014084815979
Epoch/Iter:47/200, Dloss: 0.9125829935073853, Gloss: 3.087968587875366
Epoch/Iter:47/300, Dloss: 1.0089216232299805, Gloss: 2.954106569290161
Epoch/Iter:47/400, Dloss: 0.8565082550048828, Gloss: 3.1822943687438965
Epoch/Iter:48/0, Dloss: 0.9764677882194519, Gloss: 3.030803680419922
Epoch/Iter:48/100, Dloss: 0.8825559616088867, Gloss: 3.0896494388580322
Epoch/Iter:48/200, Dloss: 0.9680340886116028, Gloss: 2.9830129146575928
Epoch/Iter:48/300, Dloss: 1.009861707687378, Gloss: 3.048816204071045
Epoch/Iter:48/400, Dloss: 1.041391134262085, Gloss: 3.0259833335876465
Epoch/Iter:49/0, Dloss: 1.084768533706665, Gloss: 3.090958833694458
Epoch/Iter:49/100, Dloss: 0.9710615277290344, Gloss: 2.999007225036621
Epoch/Iter:4

Epoch/Iter:69/400, Dloss: 0.8734657764434814, Gloss: 3.271695613861084
Epoch/Iter:70/0, Dloss: 0.7705744504928589, Gloss: 3.2441391944885254
Epoch/Iter:70/100, Dloss: 0.7803820371627808, Gloss: 3.2201735973358154
Epoch/Iter:70/200, Dloss: 0.9163256883621216, Gloss: 3.3016974925994873
Epoch/Iter:70/300, Dloss: 0.7577604055404663, Gloss: 3.3597397804260254
Epoch/Iter:70/400, Dloss: 0.844048023223877, Gloss: 3.274661064147949
Epoch/Iter:71/0, Dloss: 0.8053690791130066, Gloss: 3.2499380111694336
Epoch/Iter:71/100, Dloss: 0.8341843485832214, Gloss: 3.324840545654297
Epoch/Iter:71/200, Dloss: 0.7678184509277344, Gloss: 3.377678871154785
Epoch/Iter:71/300, Dloss: 0.7293052673339844, Gloss: 3.36234450340271
Epoch/Iter:71/400, Dloss: 0.8027940392494202, Gloss: 3.4565958976745605
Epoch/Iter:72/0, Dloss: 0.8044732809066772, Gloss: 3.291034698486328
Epoch/Iter:72/100, Dloss: 0.9450010061264038, Gloss: 3.234060049057007
Epoch/Iter:72/200, Dloss: 0.9053916931152344, Gloss: 3.2721660137176514
Epoch/I

Epoch/Iter:93/0, Dloss: 0.7030186057090759, Gloss: 3.5539932250976562
Epoch/Iter:93/100, Dloss: 0.7050765752792358, Gloss: 3.6799330711364746
Epoch/Iter:93/200, Dloss: 0.7126112580299377, Gloss: 3.699540138244629
Epoch/Iter:93/300, Dloss: 0.7981753945350647, Gloss: 3.6393070220947266
Epoch/Iter:93/400, Dloss: 0.8378759026527405, Gloss: 3.4463438987731934
Epoch/Iter:94/0, Dloss: 0.8283374309539795, Gloss: 3.4396426677703857
Epoch/Iter:94/100, Dloss: 0.7474976778030396, Gloss: 3.6722793579101562
Epoch/Iter:94/200, Dloss: 0.7146772146224976, Gloss: 3.461449384689331
Epoch/Iter:94/300, Dloss: 0.7444645166397095, Gloss: 3.574892282485962
Epoch/Iter:94/400, Dloss: 0.7916353940963745, Gloss: 3.633924961090088
Epoch/Iter:95/0, Dloss: 0.738510012626648, Gloss: 3.54362416267395
Epoch/Iter:95/100, Dloss: 0.6688709259033203, Gloss: 3.696139097213745
Epoch/Iter:95/200, Dloss: 0.6496422290802002, Gloss: 3.804899215698242
Epoch/Iter:95/300, Dloss: 0.6979778409004211, Gloss: 3.6664676666259766
Epoch/I