# Discriminator and Generator implementation

In this notebook, you will implement the generator and discriminator models. These models will be used in the last exercise of this lesson to train your first GAN!

## Discriminator

The discriminator network is going to be a fairly typical fully connected classifier. To make this network a universal function approximator, we'll need at least one hidden layer, and these hidden layers should have one key attribute:
> All hidden layers will have a [Leaky ReLU](https://pytorch.org/docs/stable/nn.html#torch.nn.LeakyReLU) activation function applied to their outputs.

<img src='../assets/gan_network.png' width=70% />

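#### Leaky ReLU

We use a leaky ReLU so that gradients can keep flowing backwards through the layer even for negative inputs. A leaky ReLU is like a normal ReLU, except that negative inputs produce a small non-zero output (the input multiplied by a small slope) instead of zero, so their gradient is small but never exactly zero.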

<img src='../assets/leaky_relu.png' width=40% />
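A quick check of the difference between the two activations, using PyTorch's `nn.ReLU` and `nn.LeakyReLU` (whose default `negative_slope` is 0.01). The input values are arbitrary and chosen only for illustration.

```python
import torch
import torch.nn as nn

# Arbitrary inputs spanning negative and positive values; track gradients w.r.t. them.
x = torch.tensor([-3.0, -1.0, 0.5, 2.0], requires_grad=True)

relu = nn.ReLU()
leaky = nn.LeakyReLU(negative_slope=0.01)  # 0.01 is also PyTorch's default slope

# Forward pass: ReLU outputs 0 for negatives, leaky ReLU outputs 0.01 * x.
relu_out = relu(x)
leaky_out = leaky(x)

# Backward pass: gradient of the summed output w.r.t. each input element.
(relu_grad,) = torch.autograd.grad(relu_out.sum(), x)
(leaky_grad,) = torch.autograd.grad(leaky_out.sum(), x)

print(relu_grad)   # zero gradient for the two negative inputs
print(leaky_grad)  # small (0.01) but non-zero gradient for the negative inputs
```

With plain ReLU, a unit whose inputs are always negative receives no gradient at all; the leaky slope keeps a weak learning signal alive.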

#### Output

We'll also take the approach of using a more numerically stable loss function on the outputs. Recall that we want the discriminator to output a value between 0 and 1 indicating whether an image is _real or fake_.
> We will ultimately use [BCEWithLogitsLoss](https://pytorch.org/docs/stable/nn.html#bcewithlogitsloss), which combines a `sigmoid` activation function **and** binary cross entropy loss in one function. 

So, our final output layer should not have any activation function applied to it.
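A small sketch comparing `nn.BCEWithLogitsLoss` on raw logits with an explicit `torch.sigmoid` followed by `nn.BCELoss`. The logits and labels are made up for illustration. For moderate logits the two agree; for an extreme logit, the float32 sigmoid rounds to exactly 1.0, and `BCELoss` (which clamps its log outputs at -100) returns a wrong value, while `BCEWithLogitsLoss` stays accurate.

```python
import torch
import torch.nn as nn

# Made-up raw discriminator scores (logits) for 4 images; no activation applied.
logits = torch.tensor([[2.0], [-1.0], [0.5], [-3.0]])
# Targets: 1.0 = real, 0.0 = fake.
labels = torch.tensor([[1.0], [0.0], [1.0], [0.0]])

# Preferred: BCEWithLogitsLoss applies the sigmoid internally.
stable_loss = nn.BCEWithLogitsLoss()(logits, labels)
# Two-step version: explicit sigmoid, then plain binary cross entropy.
two_step_loss = nn.BCELoss()(torch.sigmoid(logits), labels)
print(stable_loss.item(), two_step_loss.item())  # nearly identical for moderate logits

# A very confident (and wrong) "real" score for a fake image: the true loss is about 40.
extreme_logit = torch.tensor([[40.0]])
fake_label = torch.tensor([[0.0]])
stable_extreme = nn.BCEWithLogitsLoss()(extreme_logit, fake_label).item()
# sigmoid(40.0) rounds to exactly 1.0 in float32, so log(1 - p) = log(0);
# BCELoss clamps that log at -100, so the result is far from the true value.
two_step_extreme = nn.BCELoss()(torch.sigmoid(extreme_logit), fake_label).item()
print(stable_extreme, two_step_extreme)
```

This is why the discriminator returns raw logits and leaves the sigmoid to the loss function.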

#### Structure

The discriminator takes a high-dimensional input (for example, an image) and outputs a single score. Its linear layers should shrink progressively: each layer's output dimension should be smaller than its input dimension.
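A sketch of how the layer sizes have to chain together, written with `nn.Sequential` for brevity. The hidden widths (256, then 128) are hypothetical choices for illustration. The key constraint is that each `nn.Linear`'s `in_features` equals the previous layer's `out_features`; if they differ, PyTorch raises a `RuntimeError` during the forward pass.

```python
import torch
import torch.nn as nn

# Hypothetical sizes: a flattened 28x28 image (784 values) narrowed down to one score.
input_dim, hidden_dim = 784, 256

layers = nn.Sequential(
    nn.Flatten(),                            # (B, 1, 28, 28) -> (B, 784)
    nn.Linear(input_dim, hidden_dim),        # 784 -> 256
    nn.LeakyReLU(),
    nn.Linear(hidden_dim, hidden_dim // 2),  # 256 -> 128 (in_features matches previous out_features)
    nn.LeakyReLU(),
    nn.Linear(hidden_dim // 2, 1),           # 128 -> 1: raw logit, no activation
)

images = torch.randn(16, 1, 28, 28)  # random stand-in for a batch of 16 grayscale images
scores = layers(images)
print(scores.shape)  # torch.Size([16, 1])
```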

### First exercise

Implement a discriminator network. Your network should:
* use fully connected layers and leaky ReLU activations
* output a single logit
* take an image as input

In [2]:
import torch
import torch.nn as nn

import tests

In [12]:
class Discriminator(nn.Module):
    """
    Discriminator model:
    args:
    - input_dim: dimension of the input data. For example, for a 28 by 28 grayscale image, the input size is 784
    - hidden_dim: a parameter that controls the dimensions of the hidden layers.
    """
    def __init__(self, input_dim: int, hidden_dim: int):
        super().__init__()
        ####
        # IMPLEMENT HERE
        ####
        # Flatten (batch, 1, 28, 28) images into (batch, input_dim) vectors
        self.flatten = nn.Flatten()

        # Each layer's in_features must match the previous layer's out_features;
        # the widths shrink toward a single output logit.
        self.fc1 = nn.Linear(input_dim, hidden_dim // 2)
        self.fc2 = nn.Linear(hidden_dim // 2, hidden_dim // 4)
        self.fc3 = nn.Linear(hidden_dim // 4, 1)

        # Leaky ReLU for the hidden layers (PyTorch default negative_slope=0.01)
        self.activation = nn.LeakyReLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        ####
        # IMPLEMENT HERE
        ####
        x = self.flatten(x)

        x = self.fc1(x)
        x = self.activation(x)

        x = self.fc2(x)
        x = self.activation(x)

        # No activation on the output: BCEWithLogitsLoss applies the sigmoid
        x = self.fc3(x)

        return x

In [13]:
# for a 28x28 grayscale image flattened, the input dim is 784
input_dim = 784
hidden_dim = 256

discriminator = Discriminator(input_dim, hidden_dim)
tests.check_discriminator(discriminator, input_dim)

## Generator

The generator network will be very similar to the discriminator network, except that it maps a low-dimensional latent vector to a high-dimensional output, and we apply a [tanh activation function](https://pytorch.org/docs/stable/nn.html#tanh) to its output layer.

#### tanh Output
The generator has been found to perform best with $\tanh$ as its output activation, which scales the output to be between -1 and 1 instead of 0 and 1.

<img src='../assets/tanh_fn.png' width=40% />

Recall that we also want these outputs to be comparable to the *real* input pixel values, which are read in as normalized values between 0 and 1. 
> So, we'll also have to **scale our real input images to have pixel values between -1 and 1** when we train the discriminator. 
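A minimal sketch of the rescaling, using random tensors as stand-ins for real images and for generator pre-activations. Mapping $[0, 1]$ to $[-1, 1]$ is just `x * 2 - 1`; torchvision's `transforms.Normalize((0.5,), (0.5,))` computes the same thing, since $(x - 0.5) / 0.5 = 2x - 1$.

```python
import torch

# Stand-in for real grayscale images loaded as floats in [0, 1].
real = torch.rand(8, 1, 28, 28)

# Rescale real images from [0, 1] to [-1, 1] so they match the tanh output range.
real_scaled = real * 2 - 1

# Stand-in for generator output: tanh squashes any pre-activation into [-1, 1].
pre_activation = torch.randn(8, 784) * 10
fake = torch.tanh(pre_activation)

print(real_scaled.min().item(), real_scaled.max().item())  # both within [-1, 1]
print(fake.min().item(), fake.max().item())                # both within [-1, 1]

# To display generated samples, map them back to [0, 1].
fake_for_display = (fake + 1) / 2
```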

### Second exercise
Implement a generator network. Your network should:
* use fully connected layers with leaky ReLU activations and a tanh output
* take a latent vector as input
* output a vector (we will later reshape it as an image)
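A minimal sketch of one way to approach this exercise, mirroring the discriminator. It is named `SketchGenerator` (a hypothetical name) so it does not clash with the `Generator` stub. The hidden widths (`hidden_dim`, then `2 * hidden_dim`) and sampling the latent vector from a standard normal distribution are assumptions for illustration, not requirements of the exercise.

```python
import torch
import torch.nn as nn


class SketchGenerator(nn.Module):
    """Hypothetical fully connected generator: latent vector -> flattened image in [-1, 1]."""

    def __init__(self, latent_dim: int, hidden_dim: int, output_size: int):
        super().__init__()
        # Assumed widths: grow from the latent size toward the output size.
        self.fc1 = nn.Linear(latent_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim * 2)
        self.fc3 = nn.Linear(hidden_dim * 2, output_size)
        self.activation = nn.LeakyReLU()

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        x = self.activation(self.fc1(z))
        x = self.activation(self.fc2(x))
        # tanh keeps every output value in [-1, 1], matching the rescaled real images.
        return torch.tanh(self.fc3(x))


latent_dim, hidden_dim, output_dim = 128, 256, 784
generator = SketchGenerator(latent_dim, hidden_dim, output_dim)
z = torch.randn(16, latent_dim)       # assumption: latent vectors drawn from N(0, I)
samples = generator(z)
print(samples.shape)                  # torch.Size([16, 784])
images = samples.view(-1, 1, 28, 28)  # reshape into 28x28 grayscale images
```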

In [None]:
class Generator(nn.Module):
    def __init__(self, latent_dim: int, hidden_dim: int, output_size: int):
        super(Generator, self).__init__()
        #### 
        # IMPLEMENT HERE
        ####

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        #### 
        # IMPLEMENT HERE
        ####
        return x

In [None]:
latent_dim = 128
hidden_dim = 256
output_dim = 784

generator = Generator(latent_dim, hidden_dim, output_dim)
tests.check_generator(generator, latent_dim, output_dim)