# MNIST WGAN with feature matching

In [35]:
import numpy as np

import torch
from torch import nn
from torch import autograd
import torch.nn.functional as F
import torch.optim as optim

import torchvision.datasets as dset
import torchvision.transforms as transforms

from matplotlib import pyplot as plt
import matplotlib.gridspec as gridspec
import os

%matplotlib inline

use_cuda = torch.cuda.is_available()
if use_cuda:
    gpu = 0

In [36]:
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('../../MNIST_data', one_hot=True)

Extracting ../../MNIST_data/train-images-idx3-ubyte.gz
Extracting ../../MNIST_data/train-labels-idx1-ubyte.gz
Extracting ../../MNIST_data/t10k-images-idx3-ubyte.gz
Extracting ../../MNIST_data/t10k-labels-idx1-ubyte.gz


In [37]:
batch_size = 32
epochs = 100000
h_dim = 128
z_dim = 10
image_dim = mnist.train.images.shape[1]
target_dim = mnist.train.labels.shape[1]
learning_rate = 1e-3

In [38]:
image_dim

784

In [39]:
target_dim

10

# MODELLING

## Generator Network

In [40]:
class Generator(nn.Module):
    def __init__(self, z_dim, h_dim, image_dim):
        super(Generator, self).__init__()
        self.G = nn.Sequential(
            nn.Linear(z_dim, h_dim),
            nn.ReLU(),
            nn.Linear(h_dim, image_dim),
            nn.Sigmoid()
        )
    
    def forward(self, x):
        return self.G(x)

## Discriminator Network

In [41]:
class Discriminator(nn.Module):
    def __init__(self, h_dim, image_dim):
        super(Discriminator, self).__init__()
        self.D = nn.Sequential(
            nn.Linear(image_dim, h_dim),
            nn.ReLU()
        )
        self.out = nn.Linear(h_dim, 1)
    
    def forward(self, x):
        feat = self.D(x)
        out = self.out(feat)
        return feat, out

## Custom JSD Loss

In [45]:
class JSDLoss(nn.Module):
    def __init__(self):
        super(JSDLoss,self).__init__()

    def forward(self, f_real, f_synt):
        assert f_real.size()[1] == f_synt.size()[1]

        f_num_features = f_real.size()[1]
        batch_size = f_real.size()[0]
        identity = autograd.Variable(torch.eye(f_num_features)*0.1)

        if use_cuda:
            identity = identity.cuda(gpu)

        f_real_mean = torch.mean(f_real, 0, keepdim=True)
        f_synt_mean = torch.mean(f_synt, 0, keepdim=True)

        dev_f_real = f_real - f_real_mean.expand(batch_size,f_num_features) # batch_size x num_feat
        dev_f_synt = f_synt - f_synt_mean.expand(batch_size,f_num_features) # batch_size x num_feat

        f_real_xx = torch.mm(torch.t(dev_f_real), dev_f_real) # num_feat x batch_size * batch_size x num_feat = num_feat x num_feat
        f_synt_xx = torch.mm(torch.t(dev_f_synt), dev_f_synt) # num_feat x batch_size * batch_size x num_feat = num_feat x num_feat

        cov_mat_f_real = f_real_xx / (batch_size-1) - torch.mm(f_real_mean, torch.t(f_real_mean)) + identity # num_feat x num_feat
        cov_mat_f_synt = f_synt_xx / (batch_size-1) - torch.mm(f_synt_mean, torch.t(f_synt_mean)) + identity # num_feat x num_feat

        cov_mat_f_real_inv = torch.inverse(cov_mat_f_real)
        cov_mat_f_synt_inv = torch.inverse(cov_mat_f_synt)

        temp1 = torch.trace(torch.add(torch.mm(cov_mat_f_synt_inv, torch.t(cov_mat_f_real)), torch.mm(cov_mat_f_real_inv, torch.t(cov_mat_f_synt))))
#         temp1 = temp1.view(1,1)
        print("temp1: ", temp1)
        temp2 = torch.mm(torch.mm((f_synt_mean - f_real_mean), (cov_mat_f_synt_inv + cov_mat_f_real_inv)), torch.t(f_synt_mean - f_real_mean))
        print("temp2: ", temp2)
        loss_g = temp1 + temp2

        return loss_g

In [43]:
def reset_grad():
    G.zero_grad()
    D.zero_grad()

G = Generator(z_dim, h_dim, image_dim)
D = Discriminator(h_dim, image_dim)
G_optimizer = optim.Adam(G.parameters(), lr=learning_rate)
D_optimizer = optim.Adam(D.parameters(), lr=learning_rate)
jsdloss = JSDLoss()

if use_cuda:
    jsdloss = jsdloss.cuda(gpu)
    D = D.cuda(gpu)
    G = G.cuda(gpu)

In [46]:
for epoch in range(epochs):
    for _ in range(5):
        for p in D.parameters():
            p.requires_grad = True
        
        G = G.eval()
        D = D.train()
        D.zero_grad()
        z = autograd.Variable(torch.randn(batch_size, z_dim))
        X, _ = mnist.train.next_batch(batch_size)
        X = autograd.Variable(torch.from_numpy(X))

        G_sample = G(z)
        D_real_feat, D_real_out = D(X)
        D_fake_feat, D_fake_out = D(G_sample)

        D_loss_GAN = -(torch.mean(D_real_out) - torch.mean(D_fake_out))
        D_loss_feat_matching = jsdloss(D_real_feat, D_fake_feat)
        
        D_loss = D_loss_GAN + D_loss_feat_matching

        D_loss.backward(retain_graph=True)
        D_optimizer.step()

        for p in D.parameters():
            p.data.clamp_(-0.01, 0.01)

        reset_grad()
    
    for p in D.parameters():
        p.requires_grad = False
    
    G = G.train()
    D = D.eval()
    G.zero_grad()

    X, _ = mnist.train.next_batch(batch_size)
    X = autograd.Variable(torch.from_numpy(X))
    z = autograd.Variable(torch.randn(batch_size, z_dim))

    G_sample = G(z)
    D_fake = D(G_sample)

    G_loss = jsdloss(D_real_feat, D_fake_feat)

    G_loss.backward(retain_graph=True)
    G_optimizer.step()

    reset_grad()
    
    if epoch % 100 == 0:
        print('Iter-{}; D_loss: {}; G_loss: {}'
              .format(epoch, D_loss.data.numpy(), G_loss.data.numpy()))
        samples = G(z).data.numpy()[:16]

        fig = plt.figure(figsize=(4, 4))
        gs = gridspec.GridSpec(4, 4)
        gs.update(wspace=0.05, hspace=0.05)

        for i, sample in enumerate(samples):
            ax = plt.subplot(gs[i])
            plt.axis('off')
            ax.set_xticklabels([])
            ax.set_yticklabels([])
            ax.set_aspect('equal')
            plt.imshow(sample.reshape(28, 28), cmap='Greys_r')

        if not os.path.exists('out/'):
            os.makedirs('out/')

        plt.savefig('out/{}.png'.format(str(epoch).zfill(3)), bbox_inches='tight')
        epoch += 1
        plt.close(fig)

Variable containing:
 256.1435
[torch.FloatTensor of size 1]

Variable containing:
 256.3137
[torch.FloatTensor of size 1]

Variable containing:
 256.0270
[torch.FloatTensor of size 1]

Variable containing:
 256.2185
[torch.FloatTensor of size 1]

Variable containing:
 256.0122
[torch.FloatTensor of size 1]

Variable containing:
 256.0122
[torch.FloatTensor of size 1]

Iter-0; D_loss: [[256.03543]]; G_loss: [[256.03525]]
Variable containing:
 256.0945
[torch.FloatTensor of size 1]

Variable containing:
 256.0633
[torch.FloatTensor of size 1]

Variable containing:
 256.0435
[torch.FloatTensor of size 1]

Variable containing:
 256.0259
[torch.FloatTensor of size 1]

Variable containing:
 256.0709
[torch.FloatTensor of size 1]

Variable containing:
 256.0709
[torch.FloatTensor of size 1]

Variable containing:
 256.0193
[torch.FloatTensor of size 1]

Variable containing:
 256.0497
[torch.FloatTensor of size 1]

Variable containing:
 256.2111
[torch.FloatTensor of size 1]

Variable containi

KeyboardInterrupt: 