In [1]:
import torch
import torch.nn as nn
import numpy as np
from torch.autograd import Variable
from glob import glob
import torch.optim as optim
import torchvision.datasets as dset
import torchvision.transforms as transforms
import os
from PIL import Image
from glob import glob

In [2]:
#params
noise_dim = 64
n_epochs = 5
batch_size = 16
imageH = 64
imageW = 64
global_kernel_num = 128
learning_rate = 0.00008
k_init = 0
lam = 0.001
gamma = 0.5
if_cuda = True

In [None]:
111111111

# Data

In [None]:
if not os.path.exists('./data/CelebA/64_64'):
    os.mkdir('./data/CelebA/64_64')

In [None]:
file_list_ori = glob('./data/CelebA/splits/train/*')
for i in file_list_ori:
    img = Image.open(i)
    img = img.crop((25, 50, 25+128, 50+128))
    img = img.resize((64, 64))
    img.save('./data/CelebA/64_64/' + i.split('/')[-1])

In [3]:
file_list = glob('./data/CelebA/64_64/*')
mini_batch = len(file_list) / batch_size
def get_permutation():
    return np.random.permutation(len(file_list))
def get_samples(index, i):
    data = torch.FloatTensor(batch_size, 3, imageH, imageW)
    j = 0
    for i in range(batch_size*i, batch_size*(i+1)):
        data[j] = transforms.ToTensor()(Image.open(file_list[index[i]]))
        j += 1
    return data

# Network

In [4]:
#Generator
class G_CNN(nn.Module):
    def __init__(self, embedding_size, kernel_num):
        super(G_CNN, self).__init__()
        self.kernel_num = kernel_num
        FC_start_size = 8 * 8 * self.kernel_num
        self.start = nn.Linear(embedding_size, FC_start_size)
        #all layers
        self.main_arch = nn.Sequential(
            #repeated units
            nn.Conv2d(in_channels=kernel_num, out_channels=kernel_num, kernel_size=3, padding=1),
            nn.ELU(inplace=True),
            nn.Conv2d(in_channels=kernel_num, out_channels=kernel_num, kernel_size=3, padding=1),
            nn.ELU(inplace=True),
            nn.UpsamplingNearest2d(scale_factor=2),
            #repeated units
            nn.Conv2d(in_channels=kernel_num, out_channels=kernel_num, kernel_size=3, padding=1),
            nn.ELU(inplace=True),
            nn.Conv2d(in_channels=kernel_num, out_channels=kernel_num, kernel_size=3, padding=1),
            nn.ELU(inplace=True),
            nn.UpsamplingNearest2d(scale_factor=2),
            #repeated units
            nn.Conv2d(in_channels=kernel_num, out_channels=kernel_num, kernel_size=3, padding=1),
            nn.ELU(inplace=True),
            nn.Conv2d(in_channels=kernel_num, out_channels=kernel_num, kernel_size=3, padding=1),
            nn.ELU(inplace=True),
            nn.UpsamplingNearest2d(scale_factor=2),
            #last layer
            nn.Conv2d(in_channels=kernel_num, out_channels=3, kernel_size=3, padding=1)
        )

    def forward(self, in_data):
        in_data = self.start(in_data)
        in_data = in_data.view(-1, self.kernel_num, 8, 8)
        output = self.main_arch(in_data)
        return output

In [5]:
#Discriminator
class D_CNN(nn.Module):
    def __init__(self, embedding_size, base_kernel_num):
        super(D_CNN, self).__init__()
        self.base_kernel_num = base_kernel_num
        
        #encoder w/o flatten and FC
        self.encoder_wo_FC = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=base_kernel_num, kernel_size=3, padding=1),
            nn.ELU(inplace=True),
            
            nn.Conv2d(in_channels=base_kernel_num, out_channels=base_kernel_num, kernel_size=3, padding=1),
            nn.ELU(inplace=True),
            nn.Conv2d(in_channels=base_kernel_num, out_channels=base_kernel_num, kernel_size=3, padding=1),
            nn.ELU(inplace=True),
            #downsampling
            nn.Conv2d(in_channels=base_kernel_num, out_channels=base_kernel_num, kernel_size=3, padding=1, stride=2),
            nn.ELU(inplace=True),
            
            nn.Conv2d(in_channels=base_kernel_num, out_channels=2*base_kernel_num, kernel_size=3, padding=1),
            nn.ELU(inplace=True),
            nn.Conv2d(in_channels=2*base_kernel_num, out_channels=2*base_kernel_num, kernel_size=3, padding=1),
            nn.ELU(inplace=True),
            #downsampling
            nn.Conv2d(in_channels=2*base_kernel_num, out_channels=2*base_kernel_num, kernel_size=3, padding=1, stride=2),
            nn.ELU(inplace=True),
            
            nn.Conv2d(in_channels=2*base_kernel_num, out_channels=3*base_kernel_num, kernel_size=3, padding=1),
            nn.ELU(inplace=True),
            nn.Conv2d(in_channels=3*base_kernel_num, out_channels=3*base_kernel_num, kernel_size=3, padding=1),
            nn.ELU(inplace=True),
            #downsampling
            nn.Conv2d(in_channels=3*base_kernel_num, out_channels=3*base_kernel_num, kernel_size=3, padding=1, stride=2),
            nn.ELU(inplace=True),
            
            nn.Conv2d(in_channels=3*base_kernel_num, out_channels=4*base_kernel_num, kernel_size=3, padding=1),
            nn.ELU(inplace=True),
            nn.Conv2d(in_channels=4*base_kernel_num, out_channels=4*base_kernel_num, kernel_size=3, padding=1),
            nn.ELU(inplace=True)
        )
        FC_mid1_size = 8 * 8 * 4 * self.base_kernel_num
        FC_mid2_size = 8 * 8 * self.base_kernel_num
        self.FC1 = nn.Linear(FC_mid1_size, embedding_size)
        self.FC2 = nn.Linear(embedding_size, FC_mid2_size)
        self.decoder_wo_FC = nn.Sequential(
            #repeated units
            nn.Conv2d(in_channels=base_kernel_num, out_channels=base_kernel_num, kernel_size=3, padding=1),
            nn.ELU(inplace=True),
            nn.Conv2d(in_channels=base_kernel_num, out_channels=base_kernel_num, kernel_size=3, padding=1),
            nn.ELU(inplace=True),
            nn.UpsamplingNearest2d(scale_factor=2),
            #repeated units
            nn.Conv2d(in_channels=base_kernel_num, out_channels=base_kernel_num, kernel_size=3, padding=1),
            nn.ELU(inplace=True),
            nn.Conv2d(in_channels=base_kernel_num, out_channels=base_kernel_num, kernel_size=3, padding=1),
            nn.ELU(inplace=True),
            nn.UpsamplingNearest2d(scale_factor=2),
            #repeated units
            nn.Conv2d(in_channels=base_kernel_num, out_channels=base_kernel_num, kernel_size=3, padding=1),
            nn.ELU(inplace=True),
            nn.Conv2d(in_channels=base_kernel_num, out_channels=base_kernel_num, kernel_size=3, padding=1),
            nn.ELU(inplace=True),
            nn.UpsamplingNearest2d(scale_factor=2),
            #last layer
            nn.Conv2d(in_channels=base_kernel_num, out_channels=3, kernel_size=3, padding=1)
        )

    def forward(self, in_data):
        in_data = self.encoder_wo_FC(in_data)
        
        in_data = in_data.view(-1, 4*self.base_kernel_num * 8 * 8);
        in_data = self.FC1(in_data)
        in_data = self.FC2(in_data)
        in_data = in_data.view(-1, self.base_kernel_num, 8, 8)
        
        output = self.decoder_wo_FC(in_data)
        return output

In [6]:
noise = (torch.FloatTensor(batch_size, noise_dim))
input = (torch.FloatTensor(batch_size, 3, imageH, imageW))
GNet = G_CNN(embedding_size=noise_dim, kernel_num=global_kernel_num)
DNet = D_CNN(embedding_size=noise_dim, base_kernel_num=global_kernel_num)
criterion = nn.L1Loss()

In [7]:
#move to gpu
if if_cuda:
    GNet.cuda()
    DNet.cuda()
    noise = noise.cuda()
    input = input.cuda()
    criterion = criterion.cuda()

In [8]:
optimG = optim.Adam(GNet.parameters(), lr=learning_rate)
optimD = optim.Adam(DNet.parameters(), lr=learning_rate)
noise = Variable(noise)
input = Variable(input)

In [9]:
lossG_rec = []
lossD_rec = []

In [10]:
def train_model(n_epochs):
    k = k_init
    for i in range(n_epochs):
        index = get_permutation()
        for j in range(int(mini_batch)):
            #Discriminator
            #real
            DNet.zero_grad()
            #input.data.copy_(data)
            input.data.copy_(get_samples(index, j))
            output_real = DNet(input)
            D_real_loss = criterion(output_real, input)
            
            
            #fake
            noise.data.normal_(0, 1)
            fake = GNet(noise)
            #detach from Generator maybe save computing grad from G
            output_fake = DNet(fake)
            D_fake_loss = criterion(output_fake, fake.detach())
            
            #update
            D_loss = D_real_loss - k * D_fake_loss
            D_loss.backward(retain_variables=True)
            optimD.step()
            
            #generator
            DNet.zero_grad()
            GNet.zero_grad()
#             fake = GNet(noise)
#             G_out = DNet(fake)
#             G_loss = criterion(G_out, fake.detach()) #target can should make require_grad false
            G_loss = D_fake_loss
            G_loss.backward(retain_variables=True)
            optimG.step()
            if j == 0:
                for s in range(batch_size):
                    img_sample = transforms.ToPILImage()(fake[s].data.cpu())
                    img_sample.save('./model/'+str(i) + '_' + str(j) + '_' + str(s) + '.jpg')
            #update k
            k += lam * (gamma * D_real_loss.data.cpu()[0] - D_fake_loss.data.cpu()[0])
            k = max(min(1, k), 0)
            if j % 20 == 0:
                print 'Epoch %d' %i, 'Iter %d'%j, '/', int(mini_batch)
                print 'Epoch [%d/%d] Loss of G: %.4f \tLoss of D: %.4f \tk: %.4f' % (i, n_epochs, G_loss.data[0], D_loss.data[0], k)
            lossD_rec.append(D_loss.data[0])
            lossG_rec.append(G_loss.data[0])
        # do checkpointing
        torch.save(GNet.state_dict(), './model/netG_epoch_%d.pth' % (i))
        torch.save(DNet.state_dict(), './model/netD_epoch_%d.pth' % (i))

In [11]:
train_model(n_epochs)
np.save(np.array('DLoss.npy', lossD_rec))
np.save(np.array('GLoss.npy', lossG_rec))

Epoch 0 Iter 0 / 10173
Epoch [0/5] Loss of G: 0.0210 	Loss of D: 0.4257 	k: 0.0002
Epoch 0 Iter 20 / 10173
Epoch [0/5] Loss of G: 0.6551 	Loss of D: 0.2053 	k: 0.0000
Epoch 0 Iter 40 / 10173
Epoch [0/5] Loss of G: 9.6285 	Loss of D: 0.2069 	k: 0.0000


KeyboardInterrupt: 

In [None]:
DNet(Variable(get_samples(get_permutation(), 3)))