<a href="https://colab.research.google.com/github/Spartan-119/PyTorch-project-to-build-a-GAN-model-on-MNIST-dataset/blob/main/GAN_Implementation.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# importing all the necessary libraries
import torch
from torchvision.datasets import MNIST
from torch.utils.data import DataLoader
from torchvision import transforms
import torch.nn as nn
import numpy as np

In [2]:
# method to transform and load the MNIST data with Torch
def get_dl(batchsize: int):
    """Download MNIST and wrap its train and test splits in DataLoaders.

    Args:
        batchsize: number of images per batch for both loaders.

    Returns:
        A ``(train_loader, test_loader)`` tuple. Because ``drop_last=True``,
        every batch holds exactly ``batchsize`` samples; images are float
        tensors produced by ``transforms.ToTensor()`` (values in ``[0, 1]``).
    """
    # ToTensor converts each PIL image to a float tensor scaled to [0, 1].
    # NOTE(review): the Generator in this notebook ends in Tanh (range [-1, 1]);
    # consider adding transforms.Normalize((0.5,), (0.5,)) so real and fake
    # images share the same value range.
    train_transforms = transforms.Compose([transforms.ToTensor()])

    train_data = MNIST(root='./train.', train=True, download=True, transform=train_transforms)
    # train=False selects the held-out test split (the original loaded the
    # training split here). download=True fetches it when missing instead of
    # failing with "Dataset not found".
    test_data = MNIST(root='./test.', train=False, download=True, transform=train_transforms)

    # Shuffle the training set every epoch so batches are not always drawn in
    # the same order; evaluation order does not matter, so the test set is not
    # shuffled. drop_last=True keeps every batch at exactly `batchsize` samples.
    train_loader = DataLoader(train_data, batch_size=batchsize, shuffle=True, drop_last=True)
    test_loader = DataLoader(test_data, batch_size=batchsize, shuffle=False, drop_last=True)

    return train_loader, test_loader

## Generator Network

In [4]:
class Generator(nn.Module):
    """Map latent noise vectors to single-channel 28x28 images.

    Architecture: Linear(input_dim -> 128) -> LeakyReLU -> Linear(128 -> 784)
    -> Tanh, reshaped to ``(N, 1, 28, 28)``.

    Args:
        batch_size: batch size the model was configured with. It is kept as an
            attribute for backward compatibility, but ``forward`` infers the
            batch size from its input, so any number of samples can be generated.
        input_dim: length of each latent noise vector.
    """

    def __init__(self, batch_size, input_dim):
        super().__init__()
        self.batch_size = batch_size
        self.input_dim = input_dim
        self.fc1 = nn.Linear(input_dim, 128)
        self.LRelu = nn.LeakyReLU()
        self.fc2 = nn.Linear(128, 1 * 28 * 28)
        self.tanh = nn.Tanh()

    # the method for forward propagation
    def forward(self, x):
        """Generate images from noise.

        Args:
            x: noise tensor of shape ``(N, input_dim)``.

        Returns:
            Tensor of shape ``(N, 1, 28, 28)`` with values in ``[-1, 1]`` (Tanh).
        """
        hidden = self.LRelu(self.fc1(x))
        pixels = self.tanh(self.fc2(hidden))
        # Infer N from the data (-1) instead of the fixed self.batch_size, which
        # crashed whenever a different number of samples was requested.
        return pixels.view(-1, 1, 28, 28)

## Discriminator Network

In [5]:
class Discriminator(nn.Module):
    """Score 28x28 single-channel images with a probability in ``[0, 1]``.

    Architecture: flatten -> Linear(784 -> 128) -> LeakyReLU -> Linear(128 -> 1)
    -> Sigmoid.

    Args:
        batch_size: batch size the model was configured with. It is kept as an
            attribute for backward compatibility, but ``forward`` infers the
            batch size from its input.
    """

    def __init__(self, batch_size):
        super().__init__()
        self.batch_size = batch_size
        self.fc1 = nn.Linear(1 * 28 * 28, 128)
        self.LRelu = nn.LeakyReLU()
        self.fc2 = nn.Linear(128, 1)
        self.SigmoidL = nn.Sigmoid()

    # the method for forward propagation
    def forward(self, x):
        """Return one Sigmoid score per input image.

        Args:
            x: images of shape ``(N, 1, 28, 28)`` (or already flat ``(N, 784)``).

        Returns:
            Tensor of shape ``(N,)`` with values in ``[0, 1]``.
        """
        # Flatten each sample independently; the original view(self.batch_size, -1)
        # failed for any batch whose size differed from the configured one.
        flat = x.flatten(start_dim=1)
        hidden = self.LRelu(self.fc1(flat))
        scores = self.SigmoidL(self.fc2(hidden))
        # (N, 1) -> (N,)
        return scores.view(-1)