# Generative Adversarial Networks #
 * Discriminative Models
    - Classification
    - Features $X$ → class $Y$, i.e. they model $P(Y \mid X)$
 * Generative Models
    - Noise $\xi$ (and optionally a class $Y$) → features $X$, i.e. they model $P(X \mid Y)$
 * GAN: Generator → Discriminator (the generator's samples are fed to the discriminator)
   - The generator learns to make fakes that look real
   - The discriminator learns to distinguish real from fake
      * It is a binary classifier (real vs. fake)

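 * Binary Cross Entropy
   - $ J\left(\theta\right)=-\frac{1}{m}\sum_{i=1}^{m}\left[y^{(i)}\log{h(x^{\left(i\right)},\theta)}+\left(1-y^{\left(i\right)}\right)\log(1-h(x^{\left(i\right)},\theta))\right] $
   - Here $m$ is the batch size, $y^{(i)} \in \{0, 1\}$ is the label of example $i$ (1 = real, 0 = fake), and $h(x^{(i)},\theta)$ is the model's predicted probability that $x^{(i)}$ is real.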
   

In [None]:
# Importing Libraries
import torch
from torch import nn
from torchvision import transforms
from torchvision.datasets import MNIST 
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from torchvision.io import write_png
from torchvision.transforms import transforms
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
class GAN():
    """Fully connected GAN: an MLP generator paired with an MLP discriminator.

    Naming note: ``numClass`` is used throughout as the length of the noise
    vector fed to the generator (see ``noiseMatrix``), not as a number of
    classes. ``dimImage`` is the length of a flattened image: the generator
    outputs, and the discriminator consumes, flat vectors of that length.

    ``initGenerator`` and ``initDiscriminator`` must be called before any of
    the forward or loss methods; until then both networks are ``None``.
    """

    def __init__(self, numClass, dimImage, numImages):
        # Stored for the caller's convenience only; the methods below receive
        # these values as explicit arguments rather than reading them here.
        self.dimImage = dimImage
        self.numClass = numClass
        self.numImages = numImages
        self.generator = None
        self.discriminator = None

    def noiseMatrix(self, n_samples, numClass, device):
        """Return an ``(n_samples, numClass)`` standard-normal noise tensor on ``device``."""
        return torch.randn(n_samples, numClass, device=device)

    @staticmethod
    def _generatorBlock(inDim, outDim):
        """Build one generator hidden block: Linear -> BatchNorm1d -> ReLU."""
        return nn.Sequential(
            nn.Linear(inDim, outDim),
            nn.BatchNorm1d(outDim),
            nn.ReLU(inplace=True),
        )

    @staticmethod
    def _discriminatorBlock(inDim, outDim):
        """Build one discriminator hidden block: Linear -> LeakyReLU(0.2)."""
        return nn.Sequential(nn.Linear(inDim, outDim), nn.LeakyReLU(0.2))

    def initGenerator(self, dimImage, numClass, device, hiddenDim=128):
        """Build the generator network and move it to ``device``.

        The network maps a noise vector of length ``numClass`` through hidden
        layers of width ``hiddenDim``, ``2*hiddenDim``, ``4*hiddenDim`` and
        ``8*hiddenDim`` to a flat image of length ``dimImage``. A final
        sigmoid keeps every output value in ``[0, 1]``.

        Args:
            dimImage: length of one flattened output image.
            numClass: length of the input noise vector.
            device: device the parameters are moved to.
            hiddenDim: width of the first hidden layer (default 128).

        Note: in training mode, ``BatchNorm1d`` rejects batches that contain
        a single sample.
        """
        self.generator = nn.Sequential(
            self._generatorBlock(numClass, hiddenDim),
            self._generatorBlock(hiddenDim, hiddenDim * 2),
            self._generatorBlock(hiddenDim * 2, hiddenDim * 4),
            self._generatorBlock(hiddenDim * 4, hiddenDim * 8),
            nn.Linear(hiddenDim * 8, dimImage),
            nn.Sigmoid(),
        ).to(device)

    def initDiscriminator(self, dimImage, numClass, device, hiddenDim=128):
        """Build the discriminator network and move it to ``device``.

        The network maps a flat image of length ``dimImage`` through hidden
        layers of width ``8*hiddenDim``, ``4*hiddenDim``, ``2*hiddenDim`` and
        ``hiddenDim`` to a single raw logit per image. There is no final
        sigmoid, because the loss methods use
        ``binary_cross_entropy_with_logits``.

        Args:
            dimImage: length of one flattened input image.
            numClass: unused; kept so the signature mirrors ``initGenerator``.
            device: device the parameters are moved to.
            hiddenDim: width of the last hidden layer (default 128).
        """
        self.discriminator = nn.Sequential(
            self._discriminatorBlock(dimImage, hiddenDim * 8),
            self._discriminatorBlock(hiddenDim * 8, hiddenDim * 4),
            self._discriminatorBlock(hiddenDim * 4, hiddenDim * 2),
            self._discriminatorBlock(hiddenDim * 2, hiddenDim),
            nn.Linear(hiddenDim, 1),
        ).to(device)

    def fwdGenerator(self, Images):
        """Run the generator on ``Images``, a batch of noise vectors of length ``numClass``."""
        return self.generator(Images)

    def fwdDiscriminator(self, Images):
        """Return the discriminator's ``(batch, 1)`` logits for flat ``Images``."""
        return self.discriminator(Images)

    def lossGenerator(self, numClass, numImages, device):
        """Return the generator loss on a freshly generated batch of fakes.

        ``numImages`` fakes are generated from noise of length ``numClass``
        and scored against an all-ones ("real") target, so minimizing the
        loss pushes the generator to fool the discriminator. Nothing is
        detached, so gradients flow back into the generator.

        Returns:
            A scalar tensor: the mean BCE-with-logits loss.
        """
        noisyVectors = self.noiseMatrix(numImages, numClass, device=device)
        fakeImages = self.generator(noisyVectors)
        fakeLogits = self.discriminator(fakeImages)
        # Build the target from the logits so that shape, dtype and device
        # always match them.
        return nn.functional.binary_cross_entropy_with_logits(
            fakeLogits, torch.ones_like(fakeLogits)
        )

    def lossDiscriminator(self, realImages, numClass, numImages, device):
        """Return the discriminator loss on ``realImages`` plus a batch of fakes.

        Fakes are labelled 0 and real images 1. The result is the mean of the
        two BCE-with-logits losses.

        Args:
            realImages: flat real images; presumably a ``(batch, dimImage)``
                float tensor on ``device`` (TODO confirm that the caller
                flattens the images).
            numClass: length of the noise vector used to generate fakes.
            numImages: number of fake images to generate.
            device: device the noise is created on.

        Returns:
            A scalar tensor.
        """
        noisyVectors = self.noiseMatrix(numImages, numClass, device=device)
        # detach(): this loss trains the discriminator only, so no gradient
        # may flow back into the generator.
        fakeImages = self.generator(noisyVectors).detach()

        fakeLogits = self.discriminator(fakeImages)
        lossFake = nn.functional.binary_cross_entropy_with_logits(
            fakeLogits, torch.zeros_like(fakeLogits)
        )

        # Targets come from each batch's own logits, so the real batch size
        # may differ from numImages (e.g. a short final batch).
        realLogits = self.discriminator(realImages)
        lossReal = nn.functional.binary_cross_entropy_with_logits(
            realLogits, torch.ones_like(realLogits)
        )

        return (lossFake + lossReal) / 2
        