# Recurrent GAN:
________

In this notebook we present a recurrent GAN (RGAN) that uses stacked LSTMs for both the generator and the discriminator. It is built to generate MNIST numbers seen as time-series. 

Before delving into the actual code, let's take a look at the math behind the model.

## MNIST as time series:

The MNIST hand-written digit dataset is ubiquitous in machine learning research. Classification accuracy is high enough to consider the problem *solved*, and generating MNIST digits isn't an issue for traditional GANs. However, sequential generation is less commonly done. To serialize the images, each $28 \times 28$ image is flattened into a 784-dimensional vector, which is the sequence we aim to generate with the RGAN.

## Recurrent GAN model:

### Discriminator:

The discriminator is trained to minimize the average negative cross-entropy between its predictions *per time-step* and the labels of the sequence. If we denote by $RNN(X)$ the vector of outputs from an RNN receiving $X$ as input, then the loss is:

$$D_{loss}(X_n, y_n) = - CE(RNN_D(X_n), y_n)$$

For real sequences, $y_n$ is a vector of 1s; for synthetic sequences, it is a vector of 0s.

### Generator:

The objective for the generator is then to trick the discriminator into classifying its outputs as true. It thus wishes to minimize the average negative cross-entropy between the discriminator's predictions on generated synthetic sequences and the *true* label, the vector of 1s (written $\mathbb{1}$).

$$G_{loss}(Z_n) = D_{loss}(RNN_G(Z_n), \mathbb{1})$$

In [None]:
import torch
import torchvision
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.nn as nn
import numpy as np
#from skimage.io import imsave
import os
#from tensorboardX import SummaryWriter

from tqdm import tqdm

In [None]:
# Evaluated once; also serves as the default for gpu()'s ``gpu`` flag.
use_gpu = torch.cuda.is_available()


def gpu(tensor, gpu=use_gpu):
    """Return ``tensor`` moved to CUDA when ``gpu`` is truthy, else unchanged.

    Args:
        tensor: Any object exposing ``.cuda()`` (a tensor or an ``nn.Module``).
        gpu: Whether to call ``.cuda()``. Defaults to the value ``use_gpu``
            had when this function was defined.

    Returns:
        ``tensor.cuda()`` if ``gpu`` is truthy, otherwise ``tensor`` itself.
    """
    return tensor.cuda() if gpu else tensor

In [None]:
# --- Image geometry ---------------------------------------------------------
img_height = 28
img_width = 28
img_size = img_width * img_height  # length of one flattened image

# --- Run switches and output location ---------------------------------------
to_train = True
to_restore = False
output_path = "output"

max_epoch = 1000

# --- Model and batching hyper-parameters ------------------------------------
hg_size = 150
hd_size = 300
z_size = 100        # input feature size of the generator LSTM
batch_size = 256
seq_size = 4
n_hidden = 300      # hidden size shared by both LSTMs
tr_data_num = 60000
g_num_layers = 2    # stacked LSTM layers in the generator
d_num_layers = 2    # stacked LSTM layers in the discriminator

In [None]:
# Directory holding the MNIST files. The original machine-specific path is
# kept as the default; set the MNIST_ROOT environment variable to point the
# notebook at a different location without editing this cell.
root_dir = os.environ.get("MNIST_ROOT", "/home/majrda/Scripts/data")

In [None]:
class GaussianNoise(nn.Module):
    """Add zero-mean Gaussian noise to the input while in training mode.

    In evaluation mode (``self.training`` is False) the input is returned
    unchanged.

    Args:
        stddev: Standard deviation of the added noise.
    """

    def __init__(self, stddev = 0.1):
        super().__init__()
        self.stddev = stddev

    def forward(self, din):
        """Return ``din`` plus noise when training, otherwise ``din`` itself."""
        if self.training:
            # randn_like samples on din's own device and dtype, so the layer
            # works on CPU as well as on GPU (the previous unconditional
            # .cuda() call failed on machines without CUDA).
            return din + torch.randn_like(din) * self.stddev
        return din

In [None]:
class Generator(nn.Module):
    """Stacked-LSTM generator mapping noise vectors to flattened images.

    The network is an LSTM (``z_size`` -> ``n_hidden``, ``g_num_layers``
    layers) followed by LeakyReLU, a linear projection to ``img_size`` and a
    tanh, so every output value lies in [-1, 1].
    """

    def __init__(self):
        super(Generator, self).__init__()

        # batch_first=True: forward() builds a (batch, 1, z_size) tensor.
        # With the default layout (seq, batch, feature) the LSTM would treat
        # the batch axis as time and carry state from one sample to the next.
        self.lstm_G = nn.LSTM(input_size=z_size,
                              hidden_size=n_hidden,
                              num_layers=g_num_layers,
                              bias=True,
                              batch_first=True)

        self.Lrelu = nn.LeakyReLU()

        self.MLP = nn.Linear(n_hidden, img_size)

    def forward(self, x):
        """Generate one flattened image per noise vector.

        Args:
            x: Noise tensor of shape ``(batch, z_size)``.

        Returns:
            Tensor of shape ``(batch, 1, img_size)`` with values in [-1, 1].
        """
        # Insert a length-1 sequence axis: (batch, z_size) -> (batch, 1, z_size).
        x = x.unsqueeze(1)
        output, _ = self.lstm_G(x)
        output = torch.tanh(self.MLP(self.Lrelu(output)))
        return output

In [None]:
class Discriminator(nn.Module):
    """Stacked-LSTM discriminator scoring sequences as real or synthetic.

    The network is an LSTM (``img_size`` -> ``n_hidden``, ``d_num_layers``
    layers), Gaussian noise on the LSTM outputs (training mode only),
    LeakyReLU, a linear projection to one logit and a sigmoid.
    """

    def __init__(self):
        super(Discriminator, self).__init__()

        # batch_first=True so that inputs are read as (batch, seq, feature),
        # which is the layout forward() indexes with outputs[:, -1, :].
        self.lstm_D = nn.LSTM(input_size=img_size,
                              hidden_size=n_hidden,
                              num_layers=d_num_layers,
                              bias=True,
                              batch_first=True)

        self.Lrelu = nn.LeakyReLU()

        self.MLP = nn.Linear(n_hidden, 1)

        self.noise = GaussianNoise(.3)

    def forward(self, x):
        """Score each sequence in the batch.

        Args:
            x: Tensor of shape ``(batch, seq_len, img_size)``.

        Returns:
            Tensor of shape ``(batch, 1)`` with probabilities in (0, 1),
            computed from the LSTM output at the last time step.
        """
        outputs, _ = self.lstm_D(x)
        outputs = self.noise(outputs)
        # Keep only the last time step of every sequence: (batch, n_hidden).
        res = self.MLP(self.Lrelu(outputs[:, -1, :]))
        # One score per sample. The previous narrow(0, 0, x[0].shape[0])
        # trimmed the batch axis to the sequence length and thereby threw
        # away all but the first score(s).
        y_data = torch.sigmoid(res)
        return y_data

In [None]:
# Training split; download=False means the files must already be in root_dir.
mnist_trainset = datasets.MNIST(
    root=root_dir,
    train=True,
    download=False,
    transform=transforms.ToTensor(),
)
# Held-out test split, loaded with the same transform.
mnist_testset = datasets.MNIST(
    root=root_dir,
    train=False,
    download=False,
    transform=transforms.ToTensor(),
)

In [None]:
# Mini-batch iterator over the training set, reshuffled every epoch.
train_loader = torch.utils.data.DataLoader(
    dataset=mnist_trainset,
    batch_size=batch_size,
    shuffle=True,
)

# Mini-batch iterator over the test set, in fixed order.
test_loader = torch.utils.data.DataLoader(
    dataset=mnist_testset,
    batch_size=batch_size,
    shuffle=False,
)

In [None]:
# Adam learning rate (used for both optimizers) and number of training epochs.
lr, nb_epochs = 1e-4, 10

In [None]:
# Per-epoch loss histories for the discriminator and the generator.
loss_D_epoch, loss_G_epoch = [], []

In [None]:
# Build both networks (on the GPU when one is available) and one Adam
# optimizer per network.
net_G = gpu(Generator())
net_D = gpu(Discriminator())

optimizer_G = torch.optim.Adam(net_G.parameters(), lr=lr)
optimizer_D = torch.optim.Adam(net_D.parameters(), lr=lr)

# Not used by the loop below, which writes the GAN losses out explicitly;
# the name is kept in case another cell relies on it.
criterion = nn.CrossEntropyLoss()

# NOTE(review): real images come from transforms.ToTensor() while the
# generator ends in tanh -- confirm the two value ranges are meant to differ.
for e in range(nb_epochs):
    print("Epoch ",e)
    # Running sums kept as plain Python floats so that no autograd graph is
    # retained across iterations and the history can be plotted directly.
    loss_G = 0.0
    loss_D = 0.0
    for t, real_batch in enumerate(tqdm(train_loader)):
        # real_batch is (images, labels); each image becomes a length-1
        # sequence with img_size features.
        real_images = real_batch[0]
        real_seq = gpu(real_images.view(real_images.shape[0], 1, img_size))

        # improving D
        z = gpu(torch.empty(batch_size, z_size).normal_())
        # detach(): the discriminator step must not back-propagate into G.
        fake_batch = net_G(z).detach()

        D_scores_on_fake = net_D(fake_batch)
        D_scores_on_real = net_D(real_seq)

        # Average over every returned score. The two means are taken
        # separately so that a smaller final real batch does not have to
        # match the number of fake samples.
        loss = -(torch.mean(torch.log(1 - D_scores_on_fake))
                 + torch.mean(torch.log(D_scores_on_real)))

        optimizer_D.zero_grad()
        loss.backward()
        optimizer_D.step()
        loss_D += loss.item()

        # improving G
        z = gpu(torch.empty(batch_size, z_size).normal_())
        fake_batch = net_G(z)
        D_scores_on_fake = net_D(fake_batch)

        # Non-saturating generator loss: push D's scores on fakes towards 1.
        loss = -torch.mean(torch.log(D_scores_on_fake))

        optimizer_G.zero_grad()
        loss.backward()
        optimizer_G.step()
        loss_G += loss.item()

    loss_D_epoch.append(loss_D)
    loss_G_epoch.append(loss_G)
    print("Loss on Generator this epoch: {}\nLoss on Discriminator this epoch: {}".format(loss_G, loss_D))

In [None]:
import matplotlib.pyplot as plt

# float() accepts plain numbers as well as single-element tensors, so the
# histories can be drawn even when their entries are tensors (which
# matplotlib may fail to convert if they sit on the GPU or carry autograd
# history).
plt.plot([float(v) for v in loss_D_epoch], label="Discriminator")
plt.plot([float(v) for v in loss_G_epoch], label="Generator")
plt.xlabel("Epoch")
plt.ylabel("Loss summed over the epoch")
plt.legend()
plt.show()

In [None]:
# Draw a batch of noise vectors and run the generator once for inspection.
z = gpu(torch.empty(batch_size, z_size).normal_())
# No gradients are needed for sampling, so skip building the autograd graph.
with torch.no_grad():
    fake_samples = net_G(z)
# detach() replaces the legacy .data access before converting to NumPy.
fake_data = fake_samples.detach().cpu().numpy()

In [None]:
# Take the first generated sample (its single sequence step) and fold the
# flat vector back into a 28 x 28 grid.
x = fake_data[0, 0].reshape(28, 28)

In [None]:
import matplotlib.pyplot as plt

# Display the generated sample without smoothing between pixels.
plt.imshow(x, interpolation='nearest')
plt.show()