#  CE-40959: Deep Learning

## Homework 5 - 2:  EBGAN

The goal is to train a GAN whose discriminator is an auto-encoder (an Energy-Based GAN, EBGAN).
For further information, read the [EBGAN paper](https://arxiv.org/abs/1609.03126).

Good luck

In [0]:
import numpy as np

In [0]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms

In [0]:
# MNIST Dataset
# Both splits are stored under ./mnist_data/ (downloaded on first use).
original_train_dataset = datasets.MNIST(
    root='./mnist_data/',
    train=True,
    transform=transforms.ToTensor(),
    download=True,
)
original_test_dataset = datasets.MNIST(
    root='./mnist_data/',
    train=False,
    transform=transforms.ToTensor(),
    download=True,
)

In [0]:
# Use the GPU only when one is actually available, so the notebook also runs
# on CPU-only machines instead of failing at the first `.cuda()` call.
CUDA = torch.cuda.is_available()
# Number of samples per mini-batch used by the data loaders.
BATCH_SIZE = 64

In [0]:
# Define the train / test loaders.
# The raw pixel tensors are converted to float32 and divided by 255.
train_tensors = original_train_dataset.data.to(torch.float32).div(255)
test_tensors = original_test_dataset.data.to(torch.float32).div(255)

# Pair every image tensor with the matching entry of `targets`.
train_dataset = torch.utils.data.TensorDataset(
    train_tensors, original_train_dataset.targets
)
test_dataset = torch.utils.data.TensorDataset(
    test_tensors, original_test_dataset.targets
)

# Only the training split is shuffled.
train_loader = torch.utils.data.DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True
)
test_loader = torch.utils.data.DataLoader(
    test_dataset, batch_size=BATCH_SIZE, shuffle=False
)

In [0]:
import matplotlib.pyplot as plt
%matplotlib inline

In [0]:
def show(image_batch, rows=1):
    """Draw a batch of images on a grid with `rows` rows.

    Args:
        image_batch: indexable batch with a `shape` attribute; `shape[0]` is
            the number of images and every `image_batch[i]` is handed to
            `plt.imshow` (gray colormap, values mapped from 0 to 1).
        rows: number of grid rows; the column count is derived from it.
    """
    n_images = image_batch.shape[0]
    if n_images == 0:
        # Nothing to draw; also avoids requesting a zero-width figure.
        return

    # np.ceil returns a float, but plt.subplot needs integer grid dimensions.
    cols = int(np.ceil(n_images / rows))

    # Size only this figure (one inch per grid cell) instead of overwriting the
    # global rcParams default, which would leak into every later plot.
    plt.figure(figsize=(cols, rows))

    for i in range(n_images):
        plt.subplot(rows, cols, i + 1)
        plt.imshow(image_batch[i], cmap="gray", vmin=0, vmax=1)
        plt.axis('off')
    plt.show()

---

In [8]:
class AutoEncoderMSE(nn.Module):
    """Fully connected auto-encoder that also reports its reconstruction error.

    The input first goes through dropout, then through a stack of
    Linear + LeakyReLU(0.2) encoder layers, then through a stack of
    Linear + LeakyReLU(0.2) decoder layers, and finally through a plain Linear
    layer that maps back to `input_dim` (no output activation).

    Args:
        input_dim: size of the last dimension of the input.
        encoder_dims: output widths of the encoder layers; may be empty.
        decoder_dims: output widths of the hidden decoder layers; may be empty.
        dropout_rate: dropout probability applied to the input.
    """

    def __init__(self, input_dim, encoder_dims, decoder_dims, dropout_rate=0.5):
        super().__init__()

        self.input_dim = input_dim
        self.input_dropout = nn.Dropout(p=dropout_rate)

        # Encoder part: `in_features` tracks the width produced so far, so an
        # empty `encoder_dims` simply leaves it at `input_dim`.
        in_features = input_dim
        encoder_layers = []
        for out_features in encoder_dims:
            encoder_layers.append(nn.Linear(in_features, out_features))
            encoder_layers.append(nn.LeakyReLU(0.2))
            in_features = out_features
        self.encoder = nn.Sequential(*encoder_layers)

        # Decoder part: starts from the encoder's output width and ends with a
        # projection back to `input_dim`; also valid for empty `decoder_dims`.
        decoder_layers = []
        for out_features in decoder_dims:
            decoder_layers.append(nn.Linear(in_features, out_features))
            decoder_layers.append(nn.LeakyReLU(0.2))
            in_features = out_features
        decoder_layers.append(nn.Linear(in_features, input_dim))
        self.decoder = nn.Sequential(*decoder_layers)

        # Element-wise squared error. The original call passed the deprecated
        # `reduce=False`, which overrode `reduction='sum'`; 'none' states the
        # same thing without the deprecation warning.
        self.MSE = nn.MSELoss(reduction='none')

    def forward(self, x):
        """Reconstruct `x` and return the reconstruction with its error.

        Returns:
            dict with
              "reconstruct": decoder output, same shape as `x`;
              "loss": element-wise squared error between the original
                  (pre-dropout) `x` and the reconstruction, divided by
                  `x.shape[1]`.
        """
        # NOTE(review): dividing by x.shape[1] assumes x is (batch, input_dim).
        corrupted = self.input_dropout(x)
        reconstruction = self.decoder(self.encoder(corrupted))
        squared_error = torch.pow(x - reconstruction, 2) / x.shape[1]
        return {"reconstruct": reconstruction, "loss": squared_error}


discriminator = AutoEncoderMSE(784, [256, 128, 64], [128, 256], dropout_rate=0.5)



In [0]:
# Generator: takes 128 input features and produces 784 values squashed by a
# final sigmoid.
_generator_layers = [
    nn.Linear(128, 128),
    nn.LeakyReLU(0.2),
    nn.Linear(128, 256),
    nn.Dropout(),
    nn.LeakyReLU(0.2),
    nn.Linear(256, 512),
    nn.LeakyReLU(0.2),
    nn.Linear(512, 784),
    nn.Sigmoid(),
]
generator = nn.Sequential(*_generator_layers)

In [0]:
# When CUDA is enabled, move both networks to the GPU.
# `Module.cuda()` returns the module itself, so rebinding keeps the same objects.
if CUDA:
    discriminator = discriminator.cuda()
    generator = generator.cuda()

In [0]:
# Adam step sizes for the discriminator and the generator.
LEARNING_RATE_D = 2e-4
LEARNING_RATE_G = 2e-4

# One optimizer per network, so each can be stepped independently.
opt_D = optim.Adam(params=discriminator.parameters(), lr=LEARNING_RATE_D)
opt_G = optim.Adam(params=generator.parameters(), lr=LEARNING_RATE_G)

In [0]:
# Number of full passes over the training loader.
N_EPOCH = 100

In [13]:
# Margin of the discriminator's hinge term: a generated sample contributes to
# the discriminator loss only while its reconstruction error is below m.
m = 16

for epoch in range(N_EPOCH):
    for i, (img, label) in enumerate(train_loader):
        # Flatten every image into a single feature vector.
        img = img.flatten(start_dim=1)

        real_img = img
        if CUDA:
            real_img = real_img.cuda()

        # One 128-dimensional noise vector per real image in the batch.
        z = torch.randn(img.shape[0], 128)
        if CUDA:
            z = z.cuda()
        fake_img = generator(z)

        # Discriminator Part
        opt_D.zero_grad()
        real_error = torch.sum(discriminator(real_img)["loss"])
        # Detach the generated batch: this step must only update the
        # discriminator, so there is no reason to backpropagate into the
        # generator (which also removes the need for retain_graph=True).
        fake_error = torch.sum(discriminator(fake_img.detach())["loss"], dim=1)

        # Low error on real data, at least `m` error on generated data.
        loss_d = real_error + torch.sum(torch.max(m - fake_error, torch.zeros_like(fake_error)))

        loss_d.backward()
        opt_D.step()

        # Generator Part
        # The generator is trained to make its samples easy to reconstruct for
        # the freshly updated discriminator.
        opt_G.zero_grad()
        loss_g = torch.sum(discriminator(fake_img)["loss"])

        loss_g.backward()
        opt_G.step()

    # Per-epoch report: losses of the last batch and up to 30 generated samples
    # from that batch, reshaped to 28x28.
    print("epoch: {} \t last batch loss D: {} \t last batch loss G: {}".format(epoch, loss_d.item(), loss_g.item()))
    imgs_to_show = fake_img[:30].view(-1, 28, 28).detach().cpu().numpy()
    show(imgs_to_show, rows=3)


Output hidden; open in https://colab.research.google.com to view.