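# SelfGAN

This notebook trains a SelfGAN on MNIST: a generator and a discriminator are wrapped in one module and updated together by a single optimizer. It then trains a standard GAN with separate generator/discriminator optimizers as a baseline.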

<table class="tfo-notebook-buttons" align="left" >
 <td>
    <a target="_blank" href="https://colab.research.google.com/github/HighCWu/SelfGAN/blob/master/implementations/gan/self_gan.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/HighCWu/SelfGAN/blob/master/implementations/gan/self_gan.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a>
  </td>
</table>

## Prepare

In [0]:
import argparse
import os
import sys
import numpy as np
import math

import torchvision.transforms as transforms
from torchvision.utils import save_image

from torch.utils.data import DataLoader
from torchvision import datasets
from torch.autograd import Variable

import torch.nn as nn
import torch.nn.functional as F
import torch

# Output folders for sample grids: SelfGAN samples go to images/,
# baseline-GAN samples go to images_normal/.
for _sample_dir in ('images', 'images_normal'):
    os.makedirs(_sample_dir, exist_ok=True)

In [0]:
# Hyper-parameters, overridable from the command line.
# parse_known_args (instead of parse_args) ignores arguments it does not
# recognise; presumably this is so the extra flags a Jupyter/Colab kernel is
# launched with (e.g. `-f <connection file>`) do not cause an error.
# NOTE(review): --n_cpu is parsed but not passed to any DataLoader in this notebook.
parser = argparse.ArgumentParser()
parser.add_argument('--n_epochs', type=int, default=200, help='number of epochs of training')
parser.add_argument('--batch_size', type=int, default=64, help='size of the batches')
parser.add_argument('--lr', type=float, default=2e-4, help='adam: learning rate')
parser.add_argument('--b1', type=float, default=0.5, help='adam: decay of first order momentum of gradient')
parser.add_argument('--b2', type=float, default=0.999, help='adam: decay of second order momentum of gradient')
parser.add_argument('--n_cpu', type=int, default=8, help='number of cpu threads to use during batch generation')
parser.add_argument('--latent_dim', type=int, default=100, help='dimensionality of the latent space')
parser.add_argument('--img_size', type=int, default=28, help='size of each image dimension')
parser.add_argument('--channels', type=int, default=1, help='number of image channels')
parser.add_argument('--sample_interval', type=int, default=400, help='interval between image samples')

opt, _ = parser.parse_known_args()
print(opt)

In [0]:
# Shape of one image as (channels, height, width); both networks reshape to/from it.
img_shape = (opt.channels, opt.img_size, opt.img_size)

# Train on the GPU whenever PyTorch can see one.
cuda = torch.cuda.is_available()

In [0]:
class Generator(nn.Module):
    """MLP generator mapping latent vectors to images of shape ``img_shape``.

    Reads the module-level ``opt.latent_dim`` and ``img_shape``. Output pixel
    values lie in (-1, 1) because of the final ``Tanh``.
    """

    def __init__(self):
        super(Generator, self).__init__()

        def linear_stage(n_in, n_out, normalize=True):
            # Linear -> (optional BatchNorm) -> LeakyReLU(0.2)
            stage = [nn.Linear(n_in, n_out)]
            if normalize:
                # NOTE(review): BatchNorm1d's second positional argument is
                # `eps`, so this sets eps=0.8 (PyTorch default is 1e-5). It may
                # have been intended as a momentum value; kept unchanged here.
                stage.append(nn.BatchNorm1d(n_out, eps=0.8))
            stage.append(nn.LeakyReLU(0.2, inplace=True))
            return stage

        widths = [opt.latent_dim, 128, 256, 512, 1024]
        hidden = []
        for idx, (n_in, n_out) in enumerate(zip(widths[:-1], widths[1:])):
            # Only the first stage (latent -> 128) skips BatchNorm.
            hidden += linear_stage(n_in, n_out, normalize=idx > 0)

        self.model = nn.Sequential(
            *hidden,
            nn.Linear(widths[-1], int(np.prod(img_shape))),
            nn.Tanh(),
        )

    def forward(self, z):
        """Map a (batch, latent_dim) noise batch to a (batch, *img_shape) image batch."""
        flat = self.model(z)
        return flat.view(flat.size(0), *img_shape)

class Discriminator(nn.Module):
    """MLP discriminator scoring how likely each image is to be real.

    Images are flattened to ``prod(img_shape)`` features (module-level
    ``img_shape``); the output has shape ``(batch, 1)`` with values in (0, 1)
    from the final ``Sigmoid``.
    """

    def __init__(self):
        super(Discriminator, self).__init__()

        n_pixels = int(np.prod(img_shape))
        self.model = nn.Sequential(
            nn.Linear(n_pixels, 512), nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(512, 256), nn.LeakyReLU(0.2, inplace=True),
            nn.Linear(256, 1), nn.Sigmoid(),
        )

    def forward(self, img):
        """Flatten every image in the batch and return its realness score."""
        batch = img.size(0)
        return self.model(img.view(batch, -1))
      
class SelfGAN(nn.Module):
    """One module holding both a ``Generator`` and a ``Discriminator``.

    Because both sub-networks are registered on the same module, a single
    optimizer over ``self.parameters()`` can update them together.
    """

    def __init__(self):
        super(SelfGAN, self).__init__()

        # Sub-networks trained jointly.
        self.generator = Generator()
        self.discriminator = Discriminator()

    def forward(self, z, real_img, fake_img):
        """Generate from ``z`` and score three image batches with one discriminator.

        Returns ``(gen_img, validity_gen, validity_real, validity_fake)``, where
        the ``validity_*`` values are discriminator outputs for the freshly
        generated images, ``real_img`` and ``fake_img`` respectively.
        ``gen_img`` is not detached, so losses on ``validity_gen``
        back-propagate into both the generator and the discriminator.
        """
        gen_img = self.generator(z)
        critic = self.discriminator
        return (gen_img, critic(gen_img), critic(real_img), critic(fake_img))

## SelfGAN (one module, one optimizer)

In [0]:
# Loss functions: `adversarial_loss` averages over the batch, while
# `shard_adversarial_loss` keeps per-sample losses (reduction='none') so the
# sharded training loop can re-weight them before averaging.
adversarial_loss = torch.nn.BCELoss()
shard_adversarial_loss = torch.nn.BCELoss(reduction='none')

# Initialize SelfGAN model (generator + discriminator in one module)
self_gan = SelfGAN()

if cuda:
    self_gan.cuda()
    adversarial_loss.cuda()
    shard_adversarial_loss.cuda()

# Configure data loader.
# Normalize takes one mean/std entry per image channel. Passing 3-channel
# statistics for single-channel MNIST raises a broadcast error in current
# torchvision, so the statistics are sized by opt.channels (default 1).
# drop_last=True keeps every batch at exactly opt.batch_size; the training
# loops rely on this because `last_imgs` has a fixed batch dimension.
os.makedirs('data/mnist', exist_ok=True)
norm_stats = (0.5,) * opt.channels
dataloader = torch.utils.data.DataLoader(
    datasets.MNIST('data/mnist', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize(norm_stats, norm_stats)
                   ])),
    batch_size=opt.batch_size, shuffle=True, drop_last=True)

# Optimizer: a single Adam instance drives both sub-networks.
optimizer = torch.optim.Adam(self_gan.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

# The previous step's generated images, fed to the discriminator as "fake".
# Start from explicit zeros: Tensor(...) returns uninitialised memory, and
# multiplying it by 0.0 would leave any NaN/Inf values in place.
last_imgs = Tensor(opt.batch_size, *img_shape).zero_()

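### Option A: full-batch training (standard GPU performance)

Run either this cell or the sharded-training cell below, not both. They train the same `self_gan` model and save samples to the same `images/` folder.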

In [0]:
# ----------
#  Training (full batch)
# ----------
# One optimizer step per batch updates the generator and the discriminator
# together. In one forward pass the discriminator scores three batches:
#   - freshly generated images, target "real" (this term rewards fooling it);
#   - real images, target "real";
#   - the previous step's generated images `last_imgs`, target "fake".
# Because the fresh images are not detached inside SelfGAN.forward, the
# "generated -> real" term also pushes the discriminator's own parameters.
# NOTE: reuses `self_gan`, `optimizer` and `last_imgs` from the setup cell; if the
# sharded training cell already ran, training continues from that state.

for epoch in range(opt.n_epochs):
    for i, (imgs, _) in enumerate(dataloader):

        # Adversarial ground truths: all-ones ("real") and all-zeros ("fake"), shape (batch, 1)
        valid = Variable(Tensor(imgs.size(0), 1).fill_(1.0), requires_grad=False)
        fake = Variable(Tensor(imgs.size(0), 1).fill_(0.0), requires_grad=False)

        # Configure input
        real_imgs = Variable(imgs.type(Tensor))

        # -----------------
        #  Train SelfGAN
        # -----------------

        optimizer.zero_grad()

        # Sample noise as generator input
        z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))

        # Generate a batch of images and score generated / real / previous-step images
        gen_imgs, validity_gen, validity_real, validity_fake = self_gan(z, real_imgs, last_imgs)

        # Loss measures generator's ability to fool the discriminator and measure discriminator's ability to classify real from generated samples at the same time
        gen_loss = adversarial_loss(validity_gen, valid)
        real_loss = adversarial_loss(validity_real, valid)
        fake_loss = adversarial_loss(validity_fake, fake)
        # v_g: how far the discriminator is from scoring the new generations as real.
        # v_f: how strongly it still scores the previous generations as real.
        # They scale the generator and fake terms, together with fixed 0.1 / 0.9
        # factors whose choice is not explained here. Neither weight is detached,
        # so gradients also flow through them.
        v_g = 1 - torch.mean(validity_gen)
        v_f = torch.mean(validity_fake)
        s_loss = (real_loss + v_g*gen_loss*0.1 + v_f*fake_loss*0.9) / 2

        s_loss.backward()
        optimizer.step()
        
        # Keep this batch's generations, detached from the graph, as the next step's fakes.
        last_imgs = gen_imgs.detach()
        
        # NOTE(review): flushing before print() pushes out the previous progress line, not this one.
        sys.stdout.flush()
        print ("\r[Epoch %d/%d] [Batch %d/%d] [S loss: %f  R loss: %f  F loss: %f  G loss: %f]" % (epoch, opt.n_epochs, i, len(dataloader),
                                                            s_loss.item(), real_loss.item(), fake_loss.item(), gen_loss.item()),
              end='')

        # Every sample_interval batches, save a 5x5 grid of the first 25 generated images.
        batches_done = epoch * len(dataloader) + i
        if batches_done % opt.sample_interval == 0:
            save_image(gen_imgs.data[:25], 'images/%d.png' % batches_done, nrow=5, normalize=True)

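### Option B: sharded training on the GPU (intended to approximate TPU performance; unverified)

Each batch is split into at most 8 shards. Every shard gets its own forward and backward pass, gradients accumulate across shards, and the optimizer steps once per batch.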

In [0]:
# ----------
#  Training (sharded)
# ----------
# Each batch is split into at most `n_shards` shards. Every shard gets its own
# forward and backward pass; gradients accumulate across shards and the
# optimizer steps once per batch.
# NOTE: reuses `self_gan`, `optimizer` and `last_imgs` from the setup cell; if the
# full-batch training cell already ran, training continues from that state.

# Maximum number of shards each batch is split into.
n_shards = 8

for epoch in range(opt.n_epochs):
    for i, (imgs, _) in enumerate(dataloader):

        # Adversarial ground truths: all-ones ("real") and all-zeros ("fake"), shape (batch, 1)
        valid = Variable(Tensor(imgs.size(0), 1).fill_(1.0), requires_grad=False)
        fake = Variable(Tensor(imgs.size(0), 1).fill_(0.0), requires_grad=False)

        # Configure input
        real_imgs = Variable(imgs.type(Tensor))

        # -----------------
        #  Train SelfGAN
        # -----------------

        # Gradients from every shard accumulate; the optimizer steps once per batch.
        optimizer.zero_grad()

        # Sample noise as generator input
        z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))

        # Round the shard size up so the shards always cover the whole batch.
        # `opt.batch_size // 8` silently dropped the remainder of the batch and
        # gave zero-length shards when batch_size < 8. For the default batch
        # size of 64, both give eight shards of 8.
        batch_len = imgs.size(0)
        shard_size = max(1, math.ceil(batch_len / n_shards))
        for start in range(0, batch_len, shard_size):
            stop = start + shard_size

            # Generate a shard of images and score generated / real / previous-step images
            gen_imgs, validity_gen, validity_real, validity_fake = self_gan(
                z[start:stop], real_imgs[start:stop], last_imgs[start:stop])

            # Per-sample BCE losses (reduction='none') so they can be re-weighted below
            gen_loss = shard_adversarial_loss(validity_gen, valid[start:stop])
            real_loss = shard_adversarial_loss(validity_real, valid[start:stop])
            fake_loss = shard_adversarial_loss(validity_fake, fake[start:stop])

            # Weight each term by how far the mean discriminator score is from
            # that term's target label. The weights are normalised to sum to 1.
            # They are not detached, so gradients also flow through them.
            v_g = 1 - torch.mean(validity_gen)
            v_r = 1 - torch.mean(validity_real)
            v_f = torch.mean(validity_fake)
            v_sum = v_g + v_r + v_f
            s_loss = v_r*real_loss/v_sum + v_g*gen_loss/v_sum + v_f*fake_loss/v_sum

            # Reduce to scalars; the progress line printed below reports the last shard only.
            gen_loss = torch.mean(gen_loss)
            real_loss = torch.mean(real_loss)
            fake_loss = torch.mean(fake_loss)
            s_loss = torch.mean(s_loss)

            s_loss.backward()
            # Keep this after backward(): the slice was an input to this shard's graph.
            last_imgs[start:stop] = gen_imgs.detach()

        optimizer.step()

        print("\r[Epoch %d/%d] [Batch %d/%d] [S loss: %f  R loss: %f  F loss: %f  G loss: %f]" % (epoch, opt.n_epochs, i, len(dataloader),
                                                           s_loss.item(), real_loss.item(), fake_loss.item(), gen_loss.item()),
              end='', flush=True)

        # Every sample_interval batches, save a 5x5 grid of the latest generated images.
        batches_done = epoch * len(dataloader) + i
        if batches_done % opt.sample_interval == 0:
            save_image(last_imgs.data[:25], 'images/%d.png' % batches_done, nrow=5, normalize=True)

## Baseline: standard GAN (separate generator and discriminator optimizers)

In [0]:
# Loss function (batch-averaged binary cross-entropy)
adversarial_loss = torch.nn.BCELoss()

# Initialize generator and discriminator as two independent models
generator = Generator()
discriminator = Discriminator()

if cuda:
    generator.cuda()
    discriminator.cuda()
    adversarial_loss.cuda()

# Configure data loader.
# Normalize takes one mean/std entry per image channel. Passing 3-channel
# statistics for single-channel MNIST raises a broadcast error in current
# torchvision, so the statistics are sized by opt.channels (default 1).
os.makedirs('data/mnist', exist_ok=True)
norm_stats = (0.5,) * opt.channels
dataloader = torch.utils.data.DataLoader(
    datasets.MNIST('data/mnist', train=True, download=True,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize(norm_stats, norm_stats)
                   ])),
    batch_size=opt.batch_size, shuffle=True)

# Optimizers: one per network; the training loop alternates their steps.
optimizer_G = torch.optim.Adam(generator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))
optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=opt.lr, betas=(opt.b1, opt.b2))

Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor

In [0]:
# ----------
#  Training (baseline GAN)
# ----------
# Each batch does one generator step followed by one discriminator step, each
# with its own optimizer. Label tensors are sized from imgs.size(0), so a smaller
# final batch is handled.

for epoch in range(opt.n_epochs):
    for i, (imgs, _) in enumerate(dataloader):

        # Adversarial ground truths: all-ones ("real") and all-zeros ("fake"), shape (batch, 1)
        valid = Variable(Tensor(imgs.size(0), 1).fill_(1.0), requires_grad=False)
        fake = Variable(Tensor(imgs.size(0), 1).fill_(0.0), requires_grad=False)

        # Configure input
        real_imgs = Variable(imgs.type(Tensor))

        # -----------------
        #  Train Generator
        # -----------------

        optimizer_G.zero_grad()

        # Sample noise as generator input
        z = Variable(Tensor(np.random.normal(0, 1, (imgs.shape[0], opt.latent_dim))))

        # Generate a batch of images
        gen_imgs = generator(z)

        # Loss measures generator's ability to fool the discriminator.
        # This backward also fills discriminator gradients; they are discarded by
        # optimizer_D.zero_grad() before the discriminator step, and only
        # optimizer_G steps here.
        g_loss = adversarial_loss(discriminator(gen_imgs), valid)

        g_loss.backward()
        optimizer_G.step()

        # ---------------------
        #  Train Discriminator
        # ---------------------

        optimizer_D.zero_grad()

        # Measure discriminator's ability to classify real from generated samples.
        # gen_imgs is detached so this step does not back-propagate into the generator.
        real_loss = adversarial_loss(discriminator(real_imgs), valid)
        fake_loss = adversarial_loss(discriminator(gen_imgs.detach()), fake)
        d_loss = (real_loss + fake_loss) / 2

        d_loss.backward()
        optimizer_D.step()
        
        # NOTE(review): flushing before print() pushes out the previous progress line, not this one.
        sys.stdout.flush()
        print ("\r[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]" % (epoch, opt.n_epochs, i, len(dataloader),
                                                            d_loss.item(), g_loss.item()), 
               end='')

        # Every sample_interval batches, save a 5x5 grid of the first 25 generated images.
        batches_done = epoch * len(dataloader) + i
        if batches_done % opt.sample_interval == 0:
            save_image(gen_imgs.data[:25], 'images_normal/%d.png' % batches_done, nrow=5, normalize=True)