In [None]:
import torch
import torch.nn as nn
from torchvision.utils import make_grid
from torch.autograd import Variable
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
torch.autograd.set_detect_anomaly(True)
from utils import show_and_save, vae_gan_training, dataloader

In [None]:
# Compute device for the whole notebook: first CUDA GPU when available, CPU otherwise.
# This rebinds the `device` name already set in the import cell above.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

In [None]:
def weights_init(m):
  """Re-initialise the parameters of convolution and batch-norm modules in place.

  Intended to be handed to ``nn.Module.apply``, which calls it on every
  submodule. The module kind is detected from its class name:

  * name contains ``Conv``: weights are drawn from N(mean=0.0, std=0.02);
    the bias, if any, is left untouched.
  * name contains ``BatchNorm``: weights are drawn from N(mean=1.0, std=0.02)
    and the bias is set to 0.

  Every other module is left unchanged. Modules that match by name but have
  no learnable ``weight``/``bias`` (e.g. ``BatchNorm2d(affine=False)``) are
  skipped instead of raising.

  Args:
    m: the module to initialise.
  """
  classname = m.__class__.__name__
  # Parameters may be missing or None (affine=False, bias=False).
  weight = getattr(m, 'weight', None)
  bias = getattr(m, 'bias', None)

  if 'Conv' in classname:
    if weight is not None:
      nn.init.normal_(weight, mean=0.0, std=0.02)
  elif 'BatchNorm' in classname:
    if weight is not None:
      nn.init.normal_(weight, mean=1.0, std=0.02)
    if bias is not None:
      nn.init.constant_(bias, 0.0)

## The Encoder

The encoder is the first component of our VAE-GAN; it is implemented by the `Encoder` class. It passes the input image through a series of layers until it produces a latent-space representation of that image.
 **The encoder** is composed of 3 convolutional layers, each with a kernel size of 5, a stride of 2 and a padding of 2.
  The output of the last convolutional layer is then passed through a fully connected (i.e. linear) layer, from which the mean and the log-variance (logvar) of the latent-space representation are computed.
  Each convolutional layer is followed by a batch normalization layer and a LeakyReLU activation function.
![VAE_GAN](assets/VAE_diagram.png)

A word of advice: every convolutional or linear layer is followed by a batch-norm layer, and every batch-norm layer is followed by a ReLU activation, but **never after the last layer of the encoder, the last layer of the decoder or the last layer of the discriminator.**

In [None]:
class Encoder(nn.Module):
  """Exercise skeleton: maps an image batch to the mean and log-variance of its latent code.

  The ``...`` placeholders are meant to be replaced with real layers; as
  written, ``forward`` cannot run because ``out`` is still the Ellipsis
  object when ``out.view`` is called.

  ``VAE_GAN.forward`` adds noise of shape ``(batch, 128)`` to the returned
  mean, so both outputs are expected to have 128 features per sample.
  """

  def __init__(self):
    """Create the layers used by ``forward`` (to be filled in)."""
    super(Encoder,self).__init__()
    # TODO: define every layer/activation that forward() uses

  def forward(self,x):
    """Encode ``x`` and return the tuple ``(mean, logvar)``.

    The first dimension of ``x`` is read as the batch size.
    """
    batch_size=x.size()[0]
    # TODO: the first 3 (2-D convolutional) layers of the encoder
    out = ...
    ...
    # flatten the feature maps to one vector per sample
    out=out.view(batch_size,-1)
    # TODO: the last (1-D, fully connected) layer of the encoder
    out=...
    # TODO: the mean and the logvar of the latent-space representation
    mean=...
    logvar=...

    return mean,logvar

## The Decoder

The decoder is the second component of our VAE-GAN; it is implemented by the `Decoder` class. It passes a point of the latent space through a series of layers to try to rebuild the input image.
**The decoder** is composed of a linear function followed by 3 transposed convolutional layers in 2D, each with a kernel size of 6, a stride of 2 and a padding of 2. It finishes with a tanh function (in the skeleton below: `self.tanh(self.deconv4(x))`).

In [None]:
class Decoder(nn.Module):
  """Exercise skeleton: maps a latent vector back to an image.

  The ``...`` placeholders are meant to be replaced with real layers; as
  written, ``forward`` cannot run because ``x`` is the Ellipsis object when
  ``x.view`` is called.

  ``forward`` already references ``self.deconv4`` and ``self.tanh``, so
  ``__init__`` must define at least those two attributes.
  """

  def __init__(self):
    """Create the layers used by ``forward`` (to be filled in)."""
    super(Decoder,self).__init__()
    # TODO: define every layer/activation that forward() uses

  def forward(self,x):
    """Decode the latent batch ``x`` and return the tanh-activated output."""
    # TODO: the first (1-D, linear) layer of the decoder
    x = ...
    # reshape the flat vector into 256 feature maps of size 8x8 per sample
    x=x.view(-1,256,8,8)
    # TODO: the 3 transposed-convolution (2-D) layers of the decoder
    x = ...
    # final layer followed by tanh
    x=self.tanh(self.deconv4(x))
    return x

## The Discriminator

The discriminator is the third and last component of our VAE-GAN; it is implemented by the `Discriminator` class. It passes the input image through a series of layers to classify it as real data or fake data (generated by the generator). The generator and the discriminator are put in competition during the learning process to see which one gets the best results.
**The discriminator** is composed of 4 convolutional layers, each with a kernel size of 5, a stride of 2 and a padding of 2. After that, the discriminator has a fully connected part, i.e. linear layers followed by a sigmoid function. The objective of the discriminator is to classify the input image as real or fake.
As you may have understood, it is very similar to the encoder, because it also has to build an internal representation of the input image. However, the discriminator uses that representation differently: it has to classify the input image as real or fake, so it has to be more precise than the encoder. Each convolutional layer is followed by a batch normalization layer and a LeakyReLU activation function.

In [None]:
class Discriminator(nn.Module):
  """Exercise skeleton: scores an image batch and also exposes its flattened conv features.

  The ``...`` placeholders are meant to be replaced with real layers; as
  written, ``forward`` cannot run because ``x`` is the Ellipsis object when
  ``x.view`` is called.
  """

  def __init__(self):
    """Create the layers used by ``forward`` (to be filled in)."""
    super(Discriminator,self).__init__()
    # TODO: define every layer/activation that forward() uses

  def forward(self,x):
    """Return ``(x, x1)``: the final output and the flattened convolutional features.

    NOTE(review): ``x1`` is presumably consumed by the training loop as a
    feature-level signal - confirm in ``utils.vae_gan_training``.
    """
    batch_size=x.size()[0]
    # TODO: the first (2-D) layer of the discriminator, without batch norm
    x=...
    # TODO: the next 3 (2-D) layers of the discriminator, similar to the encoder
    x=...
    ...
    # flatten to 256*8*8 features per sample
    x=x.view(-1,256*8*8)
    # keep the flattened features; they are returned next to the final output
    x1=x
    # TODO: the last (1-D) layers of the discriminator
    x=...
    ...

    return x,x1

## The VAE-GAN

It is the combination of the three previous classes, the encoder, the decoder and the discriminator.

In [None]:
class VAE_GAN(nn.Module):
  """Container holding the encoder, the decoder and a discriminator.

  All three sub-networks are created in ``__init__`` and re-initialised with
  ``weights_init``. ``forward`` only runs the encoder and the decoder; the
  discriminator is stored but not called here.

  NOTE(review): the training cell also builds a separate ``Discriminator``
  instance - confirm which one ``vae_gan_training`` actually trains.
  """

  def __init__(self):
    super(VAE_GAN,self).__init__()
    self.encoder=Encoder()
    self.decoder=Decoder()
    self.discriminator=Discriminator()
    # apply() visits every submodule, so each conv / batch-norm layer is re-initialised
    self.encoder.apply(weights_init)
    self.decoder.apply(weights_init)
    self.discriminator.apply(weights_init)


  def forward(self,x):
    """Encode ``x``, sample a latent code and decode it.

    Args:
      x: input batch passed to the encoder.

    Returns:
      Tuple ``(z_mean, z_logvar, x_tilda)``: the two encoder outputs and the
      decoder output for the sampled latent code.
    """
    z_mean,z_logvar=self.encoder(x)
    # standard deviation from log-variance: std = exp(logvar / 2)
    std=torch.exp(0.5*z_logvar)
    # Noise is created with the shape, dtype and device of `std`, so no latent
    # size is hard-coded and no module-level `device` is needed.
    epsilon=torch.randn_like(std)
    # sample z = mean + std * eps, which keeps z differentiable w.r.t. mean and logvar
    z=z_mean+std*epsilon
    x_tilda=self.decoder(z)

    return z_mean,z_logvar,x_tilda

## Training the VAE-GAN

Run the cell below to check that the model you created works. If you don't have a powerful computer, you can set the number of epochs to 5 or 10 and modify the size of the dataset to check that the model works.

In [None]:
# Hyper-parameters handed to vae_gan_training (their exact meaning is defined in utils).
epochs = 5
lr = 3e-4
alpha = 0.1
gamma = 15
max_train = 10

# Build the data pipeline and the two networks, in this order.
data_loader = dataloader(64)
gen = VAE_GAN().to(device)
discrim = Discriminator().to(device)

# Take one batch as a visual reference and save it as an 8-column image grid.
# The images are rescaled with x * 0.5 + 0.5 before display.
real_batch = next(iter(data_loader))
show_and_save(
    "training",
    make_grid((real_batch[0] * 0.5 + 0.5).cpu(), nrow=8),
)

vae_gan_training(
    gen, discrim, data_loader,
    epochs, lr, alpha, gamma,
    device, real_batch, max_train,
)
