In [1]:
import os
import pickle

import numpy as np
import pandas as pd
import torch
import torch.nn.functional as F
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from PIL import ImageFile
from torch import nn
from torch import optim
from torch.autograd import Variable
ImageFile.LOAD_TRUNCATED_IMAGES = True

In [2]:
# Hyper-parameters shared by every cell below.
n_channel = 3      # image channels consumed by the encoder / produced by the generator
n_disc = 16        # base width of both discriminators
n_gen = 64         # base width of the generator
n_encode = 64      # base width of the encoder
n_l = 10           # number of age classes (columns of the one-hot age code)
n_z = 50           # latent vector size
img_size = 128     # size handed to the resize transform
batchSize = 20     # mini-batch size handed to the DataLoader

# Use the GPU whenever PyTorch can see one.
use_cuda = torch.cuda.is_available()

# How many times the age code and the gender scalar are tiled before being
# concatenated with the latent vector.
n_age = n_z // n_l
n_gender = n_z // 2

In [3]:
# Root folder scanned by ImageFolder. The class index it yields for each image
# is decoded into an age group and a gender by the training loop.
des_dir = "./data/"

# `transforms.Scale` was renamed to `transforms.Resize` and later removed from
# torchvision. Prefer the current name and only fall back to the legacy one on
# old installs, so this cell runs on both.
_resize = transforms.Resize if hasattr(transforms, "Resize") else transforms.Scale

dataset = dset.ImageFolder(root=des_dir,
                           transform=transforms.Compose([
                               _resize(img_size),
                               transforms.ToTensor(),
                               # (x - 0.5) / 0.5 per channel
                               transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
                           ]))

# Reshuffled every epoch; the last batch may be smaller than batchSize.
dataloader = torch.utils.data.DataLoader(dataset,
                                         batch_size= batchSize,
                                         shuffle=True)

In [4]:
class Encoder(nn.Module):
    """Convolutional encoder mapping an image batch to latent vectors of size ``n_z``.

    Four 5x5 stride-2 convolutions (each followed by ReLU) widen the channels
    from ``n_channel`` to ``8*n_encode`` while halving the spatial size each
    time; the flattened result is projected to the latent size by one linear
    layer. The linear layer is sized for an 8x8 final feature map, i.e. the
    3x128x128 input noted below.
    """

    def __init__(self):
        super(Encoder,self).__init__()
        self.conv = nn.Sequential(
            #input: 3*128*128
            nn.Conv2d(n_channel,n_encode,5,2,2),
            nn.ReLU(),

            nn.Conv2d(n_encode,2*n_encode,5,2,2),
            nn.ReLU(),

            nn.Conv2d(2*n_encode,4*n_encode,5,2,2),
            nn.ReLU(),

            nn.Conv2d(4*n_encode,8*n_encode,5,2,2),
            nn.ReLU(),

        )
        # The output width is the latent size. Use n_z rather than a literal 50
        # so the encoder stays in sync with the generator and Dz, which are
        # both sized from n_z.
        self.fc = nn.Linear(8*n_encode*8*8,n_z)

    def forward(self,x):
        """Encode image batch ``x`` and return the ``(batch, n_z)`` latent codes."""
        # Flatten the (8*n_encode, 8, 8) feature map of every sample.
        flat = self.conv(x).view(-1,8*n_encode*8*8)
        return self.fc(flat)

In [5]:
class Generator(nn.Module):
    """Decoder that turns a latent code plus age/gender conditioning into an image.

    The conditioning is tiled (age ``n_age`` times, gender ``n_gender`` times),
    concatenated with ``z`` and projected to a ``16*n_gen x 8 x 8`` feature map.
    Four 4x4 stride-2 transposed convolutions then double the spatial size each
    time while halving the channels, and a final 3x3 transposed convolution
    with Tanh produces ``n_channel`` output channels.
    """

    def __init__(self):
        super(Generator, self).__init__()

        # Width of [z | tiled age code | tiled gender].
        cond_dim = n_z + n_l * n_age + n_gender
        self.fc = nn.Sequential(
            nn.Linear(cond_dim, 8 * 8 * n_gen * 16),
            nn.ReLU(),
        )

        # Channel schedule of the four up-sampling stages.
        widths = [16 * n_gen, 8 * n_gen, 4 * n_gen, 2 * n_gen, n_gen]
        stages = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            stages.append(nn.ConvTranspose2d(c_in, c_out, 4, 2, 1))
            stages.append(nn.ReLU())
        # Size-preserving output layer squashed by Tanh.
        stages.append(nn.ConvTranspose2d(n_gen, n_channel, 3, 1, 1))
        stages.append(nn.Tanh())
        self.upconv = nn.Sequential(*stages)

    def forward(self, z, age, gender):
        """Generate images from latent ``z``, age code ``age`` and ``gender``.

        ``age`` is tiled ``n_age`` times along dim 1; ``gender`` is reshaped to
        a column and tiled ``n_gender`` times before concatenation with ``z``.
        """
        age_code = age.repeat(1, n_age)
        gender_code = gender.view(-1, 1).repeat(1, n_gender)

        conditioned = torch.cat([z, age_code, gender_code], dim=1)
        feature_map = self.fc(conditioned).view(-1, 16 * n_gen, 8, 8)
        return self.upconv(feature_map)

In [6]:
class Dimg(nn.Module):
    """Image discriminator conditioned on age and gender.

    The image goes through one 4x4 stride-2 convolution. The tiled age/gender
    code (a 1x1 map) is expanded by a 64x64 transposed convolution so that it
    can be concatenated channel-wise with the image features. Three more
    stride-2 convolutions and a shared 1024-unit layer feed two heads:

    * ``fc_head1`` - sigmoid score in [0, 1];
    * ``fc_head2`` - softmax over the ``n_l`` classes.

    Layer sizes assume the 8x8 final feature map obtained from 128x128 inputs.
    """

    def __init__(self):
        super(Dimg,self).__init__()
        self.conv_img = nn.Sequential(
            nn.Conv2d(n_channel,n_disc,4,2,1),
        )
        self.conv_l = nn.Sequential(
            # Kernel 64 on a 1x1 input yields a 64x64 map, matching conv_img's
            # output for 128x128 images.
            nn.ConvTranspose2d(n_l*n_age+n_gender, n_l*n_age+n_gender, 64, 1, 0),
            nn.ReLU()
        )
        self.total_conv = nn.Sequential(
            nn.Conv2d(n_disc+n_l*n_age+n_gender,n_disc*2,4,2,1),
            nn.ReLU(),

            nn.Conv2d(n_disc*2,n_disc*4,4,2,1),
            nn.ReLU(),

            nn.Conv2d(n_disc*4,n_disc*8,4,2,1),
            nn.ReLU()
        )

        # Flattened size of total_conv's output: n_disc*8 channels on an 8x8
        # map. Derive it from the channel count rather than from img_size,
        # which only happens to equal n_disc*8 with the current configuration.
        self._flat_features = 8*8*n_disc*8

        self.fc_common = nn.Sequential(
            nn.Linear(self._flat_features,1024),
            nn.ReLU()
        )
        self.fc_head1 = nn.Sequential(
            nn.Linear(1024,1),
            nn.Sigmoid()
        )
        self.fc_head2 = nn.Sequential(
            nn.Linear(1024,n_l),
            # NOTE(review): no explicit `dim`; newer PyTorch releases warn about
            # the implicit dimension choice. Left unchanged for compatibility
            # with the legacy API used throughout this notebook.
            nn.Softmax()
        )

    def forward(self,img,age,gender):
        """Score ``img`` given its conditioning.

        ``age`` and ``gender`` are 4-D; the training loop passes them as
        ``(batch, n_l, 1, 1)`` and ``(batch, 1, 1, 1)``. Returns the tuple
        ``(head1, head2)``.
        """
        # Tile the conditioning along the channel axis.
        l = age.repeat(1,n_age,1,1,)
        k = gender.repeat(1,n_gender,1,1,)
        conv_img = self.conv_img(img)
        conv_l   = self.conv_l(torch.cat([l,k],dim=1))
        catted   = torch.cat((conv_img,conv_l),dim=1)
        total_conv = self.total_conv(catted).view(-1,self._flat_features)
        body = self.fc_common(total_conv)

        head1 = self.fc_head1(body)
        head2 = self.fc_head2(body)

        return head1,head2

In [7]:
class Dz(nn.Module):
    """MLP discriminator on latent vectors.

    Maps a ``(batch, n_z)`` input through hidden widths ``4*n_disc``,
    ``2*n_disc`` and ``n_disc`` (ReLU after each) to a single sigmoid score.
    """

    def __init__(self):
        super(Dz, self).__init__()

        widths = [n_z, n_disc * 4, n_disc * 2, n_disc]
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers += [nn.Linear(fan_in, fan_out), nn.ReLU()]
        # Output unit squashed to [0, 1].
        layers += [nn.Linear(n_disc, 1), nn.Sigmoid()]

        self.model = nn.Sequential(*layers)

    def forward(self, z):
        """Return the ``(batch, 1)`` score for latent batch ``z``."""
        score = self.model(z)
        return score

In [8]:
# Instantiate the four networks (same construction order as before), then move
# them to the GPU in one place when CUDA is available.
netE = Encoder()
netD_img = Dimg()
netD_z = Dz()
netG = Generator()

if use_cuda:
    netE = netE.cuda()
    netD_img = netD_img.cuda()
    netD_z = netD_z.cuda()
    netG = netG.cuda()

In [9]:
def weights_init(m):
    """Initialise one module in place; meant to be used with ``Module.apply``.

    Modules whose class name contains ``Conv`` or ``Linear`` get weights drawn
    from N(0, 0.02); their bias is left untouched. Modules whose class name
    contains ``BatchNorm`` get weights from N(1, 0.02) and a zero bias. Every
    other module is ignored.
    """
    layer_kind = type(m).__name__

    if "Conv" in layer_kind or "Linear" in layer_kind:
        m.weight.data.normal_(0.0, 0.02)
        return

    if "BatchNorm" in layer_kind:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)

In [10]:
# Re-initialise every sub-module of each network with weights_init, in the
# same order as before (encoder, image discriminator, latent discriminator).
for net in (netE, netD_img, netD_z):
    net.apply(weights_init)
# Generator last, kept as the cell's final expression so the notebook still
# echoes its structure.
netG.apply(weights_init)

Generator (
  (fc): Sequential (
    (0): Linear (125 -> 65536)
    (1): ReLU ()
  )
  (upconv): Sequential (
    (0): ConvTranspose2d(1024, 512, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (1): ReLU ()
    (2): ConvTranspose2d(512, 256, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (3): ReLU ()
    (4): ConvTranspose2d(256, 128, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (5): ReLU ()
    (6): ConvTranspose2d(128, 64, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
    (7): ReLU ()
    (8): ConvTranspose2d(64, 3, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): Tanh ()
  )
)

In [11]:
# One Adam optimizer per network, all with the same learning rate and betas.
adam_opts = dict(lr=0.0002, betas=(0.5, 0.999))

optimizerE = optim.Adam(netE.parameters(), **adam_opts)
optimizerD_z = optim.Adam(netD_z.parameters(), **adam_opts)
optimizerD_img = optim.Adam(netD_img.parameters(), **adam_opts)
optimizerG = optim.Adam(netG.parameters(), **adam_opts)

In [12]:
def one_hot(labelTensor, n_classes=None, cuda=None):
    """Encode integer class labels as a {-1, +1} one-hot matrix.

    Args:
        labelTensor: 1-D sequence or tensor of integer class indices.
        n_classes: number of columns; defaults to the module-level ``n_l``.
        cuda: whether to move the result to the GPU; defaults to the
            module-level ``use_cuda``.

    Returns:
        A ``Variable`` of shape ``(len(labelTensor), n_classes)`` filled with
        -1, except for a 1 in each row at that row's label index.
    """
    if n_classes is None:
        n_classes = n_l
    if cuda is None:
        cuda = use_cuda

    # Size the matrix from the labels themselves rather than from the global
    # `batchSize`, which is only correct if the caller reassigned it for the
    # current (possibly smaller, last) batch beforehand.
    n_rows = len(labelTensor)
    oneHot = -torch.ones(n_rows, n_classes)
    for row, label in enumerate(labelTensor):
        # int() accepts both plain ints and single-element tensors.
        oneHot[row, int(label)] = 1

    encoded = Variable(oneHot)
    return encoded.cuda() if cuda else encoded

In [13]:
# Loss criteria. Only BCE and L1 are referenced by the training loop in this
# notebook; CE and MSE are created alongside them.
BCE = nn.BCELoss()
L1 = nn.L1Loss()
CE = nn.CrossEntropyLoss()
MSE = nn.MSELoss()

if use_cuda:
    BCE, L1, CE, MSE = BCE.cuda(), L1.cuda(), CE.cuda(), MSE.cuda()

In [14]:
def TV_LOSS(imgTensor):
    """Total-variation penalty: mean squared difference of neighbouring pixels.

    Args:
        imgTensor: 4-D tensor; dims 2 and 3 are treated as the two spatial
            axes along which neighbours are compared.

    Returns:
        Mean squared difference along dim 2 plus mean squared difference along
        dim 3.
    """
    # Slice relative to the tensor itself (":-1") rather than the global
    # img_size, so inputs of any spatial size are handled.
    diff_rows = (imgTensor[:, :, 1:, :] - imgTensor[:, :, :-1, :]) ** 2
    diff_cols = (imgTensor[:, :, :, 1:] - imgTensor[:, :, :, :-1]) ** 2

    # The two difference maps have different shapes ((H-1) x W vs H x (W-1)),
    # so they cannot be added element-wise: legacy PyTorch only allowed it via
    # a deprecated same-numel fallback and current releases raise. Averaging
    # each map first gives the same value without relying on that fallback.
    return diff_rows.mean() + diff_cols.mean()

In [16]:
# Exclusive upper bound of the epoch counter used by the training loop.
niter = 150

In [17]:
# Fixed input batch reused at the end of every epoch; it is later wrapped as
# `fixed_img_v` and fed to the encoder to render comparable sample grids.
# NOTE(review): unpickling can execute arbitrary code - only load a file you
# produced yourself.
with open("fixed_noise.p", "rb") as noise_file:  # closes the handle deterministically
    fixed_noise = pickle.load(noise_file)

In [18]:
# 80 x 10 label matrix initialised to -1 everywhere; the next cell switches on
# one entry per row.
fixed_l = -torch.ones(80, 10)

In [19]:
# Switch on one class per row in blocks of eight rows:
# rows 0-7 -> class 0, rows 8-15 -> class 1, ..., rows 72-79 -> class 9.
for row_idx in range(fixed_l.size(0)):
    fixed_l[row_idx, row_idx // 8] = 1

In [21]:
# Eight-entry +/-1 gender pattern, sign-flipped and tiled ten times so there is
# one value per row of fixed_l (80 x 1).
gender_pattern = torch.FloatTensor([1, -1, -1, -1, -1, 1, 1, 1]).view(-1, 1)
fixed_g = (-1 * gender_pattern).repeat(10, 1)

In [22]:
# Wrap the fixed age codes, input batch and gender codes as Variables, and move
# them to the GPU when CUDA is in use.
fixed_l_v, fixed_img_v, fixed_g_v = (
    Variable(tensor) for tensor in (fixed_l, fixed_noise, fixed_g)
)
if use_cuda:
    fixed_l_v, fixed_img_v, fixed_g_v = (
        fixed_l_v.cuda(),
        fixed_img_v.cuda(),
        fixed_g_v.cuda(),
    )

In [23]:
# Directory that receives the per-epoch sample grids and periodic checkpoints.
outf='./result_tv_gender'
# Create it up front: the image and checkpoint writers do not create missing
# directories, so a missing folder would only surface as an error after the
# first full training epoch.
os.makedirs(outf, exist_ok=True)

In [25]:
# Training loop.
#
# For every batch:
#   1. update encoder + generator on
#      L1 reconstruction + 0.0001 * image-GAN loss + 0.01 * latent-GAN loss + TV loss;
#   2. update the latent discriminator (uniform prior = real, encoded z = fake);
#   3. update the image discriminator (dataset image = real, reconstruction = fake).
# After every epoch: save a sample grid from the fixed inputs, checkpoint when
# epoch % 10 == 0, and print the losses of the last batch.
#
# NOTE(review): the epoch counter starts at 30 but no checkpoint is loaded in
# the visible cells; this presumably resumes from networks already trained in
# the running kernel - confirm before a fresh "Restart & Run All".
for epoch in range(30,niter):
    for i,(img_data,img_label) in enumerate(dataloader):

        # make image variable and class variable

        img_data_v = Variable(img_data)
        # The folder class index packs both labels: index/2 is the age group,
        # index%2 mapped to {-1,+1} is the gender.
        # NOTE(review): relies on `/` being integer division for integer
        # tensors (legacy PyTorch); newer releases return floats - confirm.
        img_age = img_label/2
        img_gender = img_label%2*2-1

        img_age_v = Variable(img_age).view(-1,1)
        img_gender_v = Variable(img_gender.float())


        if use_cuda:
            img_data_v = img_data_v.cuda()
            img_age_v = img_age_v.cuda()
            img_gender_v = img_gender_v.cuda()

        # make one hot encoding version of label
        # (batchSize is refreshed because the last batch can be smaller)
        batchSize = img_data_v.size(0)
        age_ohe = one_hot(img_age)

        # prior distribution z_star, real_label, fake_label
        z_star = Variable(torch.FloatTensor(batchSize*n_z).uniform_(-1,1)).view(batchSize,n_z)
        real_label = Variable(torch.ones(batchSize).fill_(1)).view(-1,1)
        fake_label = Variable(torch.ones(batchSize).fill_(0)).view(-1,1)

        if use_cuda:
            z_star, real_label, fake_label = z_star.cuda(),real_label.cuda(),fake_label.cuda()


        ## train Encoder and Generator with reconstruction loss
        netE.zero_grad()
        netG.zero_grad()

        # One encoder/generator forward pass is shared by the L1 and image-GAN
        # terms; the networks are deterministic, so running it twice only
        # repeated the same computation.
        z = netE(img_data_v)
        reconst = netG(z,age_ohe,img_gender_v)

        # EG_loss 1. L1 reconstruction loss
        EG_L1_loss = L1(reconst,img_data_v)

        # EG_loss 2. GAN loss - image
        D_reconst,_ = netD_img(reconst,age_ohe.view(batchSize,n_l,1,1),img_gender_v.view(batchSize,1,1,1))
        G_img_loss = BCE(D_reconst,real_label)

        ## EG_loss 3. GAN loss - z
        Dz = netD_z(z)
        Ez_loss = BCE(Dz,real_label)

        ## EG_loss 4. TV loss - G
        # z is detached so the TV term only produces gradients for the generator.
        reconst = netG(z.detach(),age_ohe,img_gender_v)
        G_tv_loss = TV_LOSS(reconst)

        EG_loss = EG_L1_loss + 0.0001*G_img_loss + 0.01*Ez_loss + G_tv_loss
        EG_loss.backward()

        optimizerE.step()
        optimizerG.step()



        ## train netD_z with prior distribution U(-1,1)
        netD_z.zero_grad()
        Dz_prior = netD_z(z_star)
        Dz = netD_z(z.detach())

        Dz_loss = BCE(Dz_prior,real_label)+BCE(Dz,fake_label)
        Dz_loss.backward()
        optimizerD_z.step()



        ## train D_img with real images
        netD_img.zero_grad()
        D_img,D_clf = netD_img(img_data_v,age_ohe.view(batchSize,n_l,1,1),img_gender_v.view(batchSize,1,1,1))
        D_reconst,_ = netD_img(reconst.detach(),age_ohe.view(batchSize,n_l,1,1),img_gender_v.view(batchSize,1,1,1))

        D_loss = BCE(D_img,real_label)+BCE(D_reconst,fake_label)
        D_loss.backward()
        optimizerD_img.step()



    ## save a grid generated from the fixed inputs at the end of every epoch
    fixed_z = netE(fixed_img_v)
    fixed_fake = netG(fixed_z,fixed_l_v,fixed_g_v)
    vutils.save_image(fixed_fake.data,
                '%s/reconst_epoch%03d.png' % (outf,epoch+1),
                normalize=True)

    ## checkpoint (file names carry epoch+1)
    if epoch%10==0:
        torch.save(netE.state_dict(),"%s/netE_%03d.pth"%(outf,epoch+1))
        torch.save(netG.state_dict(),"%s/netG_%03d.pth"%(outf,epoch+1))
        torch.save(netD_img.state_dict(),"%s/netD_img_%03d.pth"%(outf,epoch+1))
        torch.save(netD_z.state_dict(),"%s/netD_z_%03d.pth"%(outf,epoch+1))


    # Report the values left over from the last batch of the epoch.
    msg1 = "epoch:{}, step:{}".format(epoch+1,i+1)
    msg2 = format("EG_L1_loss:%f"%(EG_L1_loss.data[0]),"<30")+"|"+format("G_img_loss:%f"%(G_img_loss.data[0]),"<30")
    msg5 = format("G_tv_loss:%f"%(G_tv_loss.data[0]),"<30")+"|"+"Ez_loss:%f"%(Ez_loss.data[0])
    msg3 = format("D_img:%f"%(D_img.mean().data[0]),"<30")+"|"+format("D_reconst:%f"%(D_reconst.mean().data[0]),"<30")\
    +"|"+format("D_loss:%f"%(D_loss.data[0]),"<30")
    msg4 = format("D_z:%f"%(Dz.mean().data[0]),"<30")+"|"+format("D_z_prior:%f"%(Dz_prior.mean().data[0]),"<30")\
    +"|"+format("Dz_loss:%f"%(Dz_loss.data[0]),"<30")

    print()
    print(msg1)
    print(msg2)
    print(msg5)
    print(msg3)
    print(msg4)
    print()
    print("-"*80)
        
        

  return a.add(b)



epoch:31, step:1186
EG_L1_loss:0.074755           |G_img_loss:10.135382          
G_tv_loss:0.003565            |Ez_loss:0.837289
D_img:0.999197                |D_reconst:0.001297            |D_loss:0.002108               
D_z:0.469071                  |D_z_prior:0.616150            |Dz_loss:1.201891              

--------------------------------------------------------------------------------

epoch:32, step:1186
EG_L1_loss:0.076463           |G_img_loss:8.965395           
G_tv_loss:0.003862            |Ez_loss:0.978929
D_img:0.977894                |D_reconst:0.011918            |D_loss:0.035351               
D_z:0.400765                  |D_z_prior:0.499361            |Dz_loss:1.253560              

--------------------------------------------------------------------------------

epoch:33, step:1186
EG_L1_loss:0.073955           |G_img_loss:10.829243          
G_tv_loss:0.003922            |Ez_loss:0.670335
D_img:0.909902                |D_reconst:0.040764            |D_loss:0.


epoch:52, step:1186
EG_L1_loss:0.061933           |G_img_loss:9.863884           
G_tv_loss:0.004212            |Ez_loss:0.941808
D_img:0.990994                |D_reconst:0.001856            |D_loss:0.011126               
D_z:0.421811                  |D_z_prior:0.492365            |Dz_loss:1.374661              

--------------------------------------------------------------------------------

epoch:53, step:1186
EG_L1_loss:0.066513           |G_img_loss:11.426128          
G_tv_loss:0.005444            |Ez_loss:0.933098
D_img:0.993847                |D_reconst:0.000207            |D_loss:0.006484               
D_z:0.432993                  |D_z_prior:0.409501            |Dz_loss:1.569698              

--------------------------------------------------------------------------------

epoch:54, step:1186
EG_L1_loss:0.064782           |G_img_loss:10.132051          
G_tv_loss:0.004250            |Ez_loss:1.049470
D_img:0.998850                |D_reconst:0.030500            |D_loss:0.


epoch:73, step:1186
EG_L1_loss:0.053262           |G_img_loss:7.900271           
G_tv_loss:0.004476            |Ez_loss:0.775865
D_img:0.879827                |D_reconst:0.003941            |D_loss:0.351988               
D_z:0.490823                  |D_z_prior:0.547798            |Dz_loss:1.364191              

--------------------------------------------------------------------------------

epoch:74, step:1186
EG_L1_loss:0.063578           |G_img_loss:9.692825           
G_tv_loss:0.004686            |Ez_loss:0.947681
D_img:0.999373                |D_reconst:0.003204            |D_loss:0.003871               
D_z:0.411077                  |D_z_prior:0.549572            |Dz_loss:1.204915              

--------------------------------------------------------------------------------

epoch:75, step:1186
EG_L1_loss:0.060853           |G_img_loss:5.634823           
G_tv_loss:0.004064            |Ez_loss:0.967266
D_img:0.971164                |D_reconst:0.058697            |D_loss:0.


epoch:94, step:1186
EG_L1_loss:0.057702           |G_img_loss:12.341487          
G_tv_loss:0.005413            |Ez_loss:0.822527
D_img:0.998125                |D_reconst:0.000290            |D_loss:0.002180               
D_z:0.464610                  |D_z_prior:0.518132            |Dz_loss:1.366867              

--------------------------------------------------------------------------------

epoch:95, step:1186
EG_L1_loss:0.056038           |G_img_loss:12.510416          
G_tv_loss:0.005178            |Ez_loss:1.000798
D_img:0.984432                |D_reconst:0.000044            |D_loss:0.016215               
D_z:0.399584                  |D_z_prior:0.682613            |Dz_loss:0.939096              

--------------------------------------------------------------------------------

epoch:96, step:1186
EG_L1_loss:0.061254           |G_img_loss:11.597952          
G_tv_loss:0.006253            |Ez_loss:1.127185
D_img:0.982024                |D_reconst:0.000350            |D_loss:0.


epoch:115, step:1186
EG_L1_loss:0.058240           |G_img_loss:7.303968           
G_tv_loss:0.006044            |Ez_loss:0.793804
D_img:0.999952                |D_reconst:0.004508            |D_loss:0.004590               
D_z:0.465815                  |D_z_prior:0.484638            |Dz_loss:1.435347              

--------------------------------------------------------------------------------

epoch:116, step:1186
EG_L1_loss:0.053413           |G_img_loss:12.095599          
G_tv_loss:0.004693            |Ez_loss:0.753219
D_img:0.999904                |D_reconst:0.000669            |D_loss:0.000766               
D_z:0.483004                  |D_z_prior:0.511638            |Dz_loss:1.364893              

--------------------------------------------------------------------------------

epoch:117, step:1186
EG_L1_loss:0.050686           |G_img_loss:12.946668          
G_tv_loss:0.004926            |Ez_loss:0.926873
D_img:0.933394                |D_reconst:0.000070            |D_loss


epoch:136, step:1186
EG_L1_loss:0.051005           |G_img_loss:11.805435          
G_tv_loss:0.004898            |Ez_loss:0.865574
D_img:0.997487                |D_reconst:0.002050            |D_loss:0.004590               
D_z:0.441171                  |D_z_prior:0.471866            |Dz_loss:1.423527              

--------------------------------------------------------------------------------

epoch:137, step:1186
EG_L1_loss:0.049611           |G_img_loss:10.011687          
G_tv_loss:0.005839            |Ez_loss:0.858592
D_img:0.995277                |D_reconst:0.000662            |D_loss:0.005453               
D_z:0.430759                  |D_z_prior:0.474046            |Dz_loss:1.388363              

--------------------------------------------------------------------------------

epoch:138, step:1186
EG_L1_loss:0.052184           |G_img_loss:10.447694          
G_tv_loss:0.004495            |Ez_loss:0.729217
D_img:0.999902                |D_reconst:0.003584            |D_loss