In [106]:
import random
import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision
from torchvision.utils import save_image
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.autograd import Variable
import math
import matplotlib.pyplot as plt
import numpy as np

In [149]:
def normal_init(m, mean, std):
    if isinstance(m, (nn.Linear, nn.Conv2d, nn.BatchNorm2d, nn.BatchNorm1d)):
        m.weight.data.normal_(mean, std)
        if m.bias.data is not None:
            m.bias.data.zero_()

class Generator(nn.Module):
    def __init__(self):
        super(Generator,self).__init__()
        self.fc = nn.Sequential(
            nn.ConvTranspose2d(100, 4*4*1024, 1)
        )
        self.conv = nn.Sequential(
            nn.ConvTranspose2d(1024, 512, 4,2,1),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
           nn.ConvTranspose2d(512, 256, 4, 2, 1),
          nn.BatchNorm2d(256),
            nn.ReLU(True),
          nn.ConvTranspose2d(256 , 128, 4, 2, 1),
          nn.BatchNorm2d(128),
            nn.ReLU(True),
           nn.ConvTranspose2d(128 , 3, 4, 2, 1),
            nn.Tanh()

        )
        """
        for m in self._modules:
            normal_init(self._modules[m], 0, 1)
            """

    def forward(self, x):
        x = x.view(x.size(0), 100, 1, 1)
        out = self.fc(x)
        out = out.view(-1, 1024, 4, 4)
        out  = self.conv(out)
        return out

G = Generator()
G = G.cuda()

In [154]:
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 64, 4, 2, 1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 64 * 2, 4, 2, 1),
            nn.BatchNorm2d(64 * 2),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64 * 2, 64* 4, 4, 2, 1),
            nn.BatchNorm2d(64 * 4),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64* 4, 64* 8, 4, 2, 1),
            nn.BatchNorm2d(64* 8),
            nn.LeakyReLU(0.2)
        )
        
        self.decoder = nn.Sequential(
        nn.ConvTranspose2d(64 * 8, 64 * 4, 4, 2, 1),
        nn.BatchNorm2d(64 * 4),
        nn.LeakyReLU(0.2),
        nn.ConvTranspose2d(64* 4, 64* 2, 4, 2, 1),
        nn.BatchNorm2d(64* 2),
        nn.LeakyReLU(0.2),
        nn.ConvTranspose2d(64* 2, 64, 4, 2, 1),
        nn.BatchNorm2d(64),
        nn.LeakyReLU(0.2),
        nn.ConvTranspose2d(64, 3, 4, 2, 1),
        nn.LeakyReLU(0.2),
        )
        """
        for m in self._modules:
            normal_init(self._modules[m], 0, 1)"""

    def forward(self, x):
        en =self.encoder(x)
        de = self.decoder(en)
        return de,en.view(x.size(0),-1)

D = Discriminator()
D = D.cuda()

In [166]:
criterion_MSE = nn.MSELoss()
criterion_MSE = criterion_MSE.cuda()
fixed_noise = torch.randn(10,100)
fixed_noise = Variable(fixed_noise.cuda())
optimizerD = optim.Adam(D.parameters(), lr=0.0002)
optimizerG = optim.Adam(G.parameters(), lr=0.0002)

def repelling_regularizer(s1, s2):
    n = s1.size(0)
    s1 = F.normalize(s1, p=2, dim=1)
    s2 = F.normalize(s2, p=2, dim=1)
    S1 = s1.unsqueeze(1).repeat(1, s2.size(0), 1)
    S2 = s2.unsqueeze(0).repeat(s1.size(0), 1, 1)
    f_PT = S1.mul(S2).sum(-1).pow(2)
    f_PT = torch.tril(f_PT, -1).sum().mul(2).div((n*(n-1)))
    return f_PT
#print(fixed_noise)

In [167]:
transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,0.5,0.5),(0.5,0.5,0.5))])
dataset = torchvision.datasets.ImageFolder(root='./celebA',transform=transform)
dataloader = torch.utils.data.DataLoader(dataset, batch_size=100,shuffle=True, num_workers=2)

In [168]:
margin =20
for epoch in range(20):
    for i, data in enumerate(dataloader):

        optimizerD.zero_grad()
        real_cpu, _ = data
        real_cpu = real_cpu.cuda()
        real_cpu = Variable(real_cpu)
        # train with real
        output,_ = D(real_cpu)
        energyD_real = F.mse_loss(output,real_cpu)
        # generate fake
        noise =  torch.randn(100,100)
        noise =  Variable(noise.cuda())
        fake = G(noise)

        # train with fake
        output,_ = D(fake)
        energyD_fake = F.mse_loss(output,fake)  # score on fake
        errD_fake = margin - energyD_fake
        errD_fake = errD_fake.clamp(min=0)
        errD = (energyD_real + errD_fake)# score fore supervision
        errD.backward(retain_graph=True)
        optimizerD.step()

        optimizerG.zero_grad()

        noise =  torch.randn(100,100)
        noise =  Variable(noise.cuda())
        output,hidden = D(fake)
        G_loss_PT = repelling_regularizer(hidden,hidden)
        errG = F.mse_loss(output,fake) + 0.1*G_loss_PT
        errG.backward()
        optimizerG.step()
        if i%100 ==0:
            print('[%d/%d][%d/%d] Loss_D: %.4f Loss_G: %.4f'
                  % (epoch, 30, i, len(dataloader),
                     errD.data[0], errG.data[0]))
    test_images = G(fixed_noise)
    #print(test_images)
    save_image(test_images.data,'./samples/EBGAN/output.png',nrow=5)    



[0/30][0/128] Loss_D: 0.0056 Loss_G: 30893.8711
[0/30][100/128] Loss_D: 0.0076 Loss_G: 2046.1187
[1/30][0/128] Loss_D: 0.0057 Loss_G: 2272.3025
[1/30][100/128] Loss_D: 0.0058 Loss_G: 2099.6426
[2/30][0/128] Loss_D: 0.0057 Loss_G: 1577.9468
[2/30][100/128] Loss_D: 0.0052 Loss_G: 2260.7549
[3/30][0/128] Loss_D: 0.0058 Loss_G: 1513.0361
[3/30][100/128] Loss_D: 0.0067 Loss_G: 1321.8615
[4/30][0/128] Loss_D: 0.0051 Loss_G: 1428.2266
[4/30][100/128] Loss_D: 0.0067 Loss_G: 1097.3239
[5/30][0/128] Loss_D: 0.0054 Loss_G: 1325.9751
[5/30][100/128] Loss_D: 0.0051 Loss_G: 1569.0348
[6/30][0/128] Loss_D: 0.0098 Loss_G: 1245.1099
[6/30][100/128] Loss_D: 0.0222 Loss_G: 1051.3688
[7/30][0/128] Loss_D: 0.0048 Loss_G: 1393.5771
[7/30][100/128] Loss_D: 0.0051 Loss_G: 958.7703
[8/30][0/128] Loss_D: 0.0053 Loss_G: 727.2678
[8/30][100/128] Loss_D: 0.0048 Loss_G: 997.7540
[9/30][0/128] Loss_D: 0.0061 Loss_G: 1206.4103
[9/30][100/128] Loss_D: 0.0047 Loss_G: 897.0942


KeyboardInterrupt: 

Process Process-304:
Traceback (most recent call last):
  File "/usr/lib/python3.6/multiprocessing/process.py", line 258, in _bootstrap
    self.run()
  File "/usr/lib/python3.6/multiprocessing/process.py", line 93, in run
    self._target(*self._args, **self._kwargs)
  File "/home/infero/.local/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 96, in _worker_loop
    r = index_queue.get(timeout=MANAGER_STATUS_CHECK_INTERVAL)
  File "/usr/lib/python3.6/multiprocessing/queues.py", line 104, in get
    if not self._poll(timeout):
  File "/usr/lib/python3.6/multiprocessing/connection.py", line 257, in poll
    return self._poll(timeout)
Process Process-303:
  File "/usr/lib/python3.6/multiprocessing/connection.py", line 414, in _poll
    r = wait([self], timeout)
  File "/usr/lib/python3.6/multiprocessing/connection.py", line 911, in wait
    ready = selector.select(timeout)
  File "/usr/lib/python3.6/selectors.py", line 376, in select
    fd_event_list = self._poll.poll(t