# GAN_MNIST
This notebook trains a Generative Adversarial Network (GAN) to generate images that resemble the handwritten digits in the MNIST dataset. It is a minimal, end-to-end demonstration of GAN training in PyTorch.


## Import header

In [None]:
import os
import torch
import torch.nn             as nn
import torch.optim          as optim
import torchvision
import torchvision.datasets as dsets
from   torch.utils.data       import DataLoader
from   torchvision.transforms import transforms
from   torchvision.utils      import save_image

## Select CUDA or CPU

In [None]:
# Run on the GPU when PyTorch reports one as available; otherwise use the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print('device: {}'.format(device))

device: cuda


## Set the random seed

In [None]:
# Seed PyTorch's random number generator so repeated runs start from the same state.
torch.manual_seed(777)

# When running on the GPU, seed the CUDA generators of all devices as well.
if device == 'cuda':
    torch.cuda.manual_seed_all(777)

## Hyper parameters

In [None]:
# Network dimensions
image_size = 784        # length of one flattened image (28 * 28, matching the reshape used when samples are saved)
hidden_size = 256       # width of the hidden layers in both networks
latent_code_size = 64   # length of the noise vector fed to the generator

# Training schedule
batch_size = 100
total_epoch = 200

# Directory that receives the image grids written during training
sample_dir = 'G_image'

# exist_ok=True keeps this cell safe to re-run and avoids the
# check-then-create race of a separate os.path.exists() test.
os.makedirs(sample_dir, exist_ok=True)

## Dataset

In [None]:
# Preprocessing: convert each image to a tensor, then normalise it with
# mean 0.5 and std 0.5, i.e. (x - 0.5) / 0.5, which denorm() later inverts.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5], std=[0.5]),
    ]
)

# Training and test splits, downloaded into MNIST_data/ when not already present.
MNIST_train = dsets.MNIST(
    root='MNIST_data/',
    train=True,
    transform=transform,
    download=True,
)
MNIST_test = dsets.MNIST(
    root='MNIST_data/',
    train=False,
    transform=transform,
    download=True,
)

# Shuffled mini-batches; drop_last=True discards a trailing partial batch.
dataloader = DataLoader(
    MNIST_train,
    batch_size=batch_size,
    shuffle=True,
    drop_last=True,
)

# Number of mini-batches per epoch.
iteration = len(dataloader)
print('iteration: {}'.format(iteration))

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to MNIST_data/MNIST/raw/train-images-idx3-ubyte.gz


  0%|          | 0/9912422 [00:00<?, ?it/s]

Extracting MNIST_data/MNIST/raw/train-images-idx3-ubyte.gz to MNIST_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to MNIST_data/MNIST/raw/train-labels-idx1-ubyte.gz


  0%|          | 0/28881 [00:00<?, ?it/s]

Extracting MNIST_data/MNIST/raw/train-labels-idx1-ubyte.gz to MNIST_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to MNIST_data/MNIST/raw/t10k-images-idx3-ubyte.gz


  0%|          | 0/1648877 [00:00<?, ?it/s]

Extracting MNIST_data/MNIST/raw/t10k-images-idx3-ubyte.gz to MNIST_data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to MNIST_data/MNIST/raw/t10k-labels-idx1-ubyte.gz


  0%|          | 0/4542 [00:00<?, ?it/s]

Extracting MNIST_data/MNIST/raw/t10k-labels-idx1-ubyte.gz to MNIST_data/MNIST/raw

iteration: 600


  return torch.from_numpy(parsed.astype(m[2], copy=False)).view(*s)


## Model

In [None]:
class Binary_Classfier(nn.Module):
    """Discriminator: maps a flattened image to a single sigmoid score.

    The network is three fully connected stages. The first two are
    ``Linear`` + ``LeakyReLU(0.2)``; the last is ``Linear`` + ``Sigmoid``
    with one output unit per sample. Layer widths come from the
    module-level ``image_size`` and ``hidden_size``.
    """

    def __init__(self):
        super().__init__()
        negative_slope = 0.2
        # Stages are created in input-to-output order and keep the names
        # layer1/layer2/layer3 so state_dict keys stay the same.
        self.layer1 = nn.Sequential(
            nn.Linear(image_size, hidden_size), nn.LeakyReLU(negative_slope)
        )
        self.layer2 = nn.Sequential(
            nn.Linear(hidden_size, hidden_size), nn.LeakyReLU(negative_slope)
        )
        self.layer3 = nn.Sequential(nn.Linear(hidden_size, 1), nn.Sigmoid())

    def forward(self, x):
        """Return the score for each row of ``x`` (rows have width ``image_size``)."""
        features = self.layer2(self.layer1(x))
        return self.layer3(features)

class GAN(nn.Module):
    """Generator: maps a latent noise vector to a flattened image.

    The network is three fully connected stages: ``Linear`` +
    ``LeakyReLU(0.2)``, then ``Linear`` + ``ReLU``, then ``Linear`` +
    ``Tanh``. Layer widths come from the module-level
    ``latent_code_size``, ``hidden_size`` and ``image_size``.
    """

    def __init__(self):
        super().__init__()
        # Stages are created in input-to-output order and keep the names
        # layer1/layer2/layer3 so state_dict keys stay the same.
        self.layer1 = nn.Sequential(
            nn.Linear(latent_code_size, hidden_size), nn.LeakyReLU(0.2)
        )
        self.layer2 = nn.Sequential(nn.Linear(hidden_size, hidden_size), nn.ReLU())
        self.layer3 = nn.Sequential(nn.Linear(hidden_size, image_size), nn.Tanh())

    def forward(self, x):
        """Return one generated vector of width ``image_size`` per row of ``x``."""
        features = self.layer2(self.layer1(x))
        return self.layer3(features)

# Discriminator, built first and then moved to the selected device.
BC_model = Binary_Classfier()
BC_model.to(device)

# Generator, built second and then moved to the selected device.
GAN_model = GAN()
GAN_model.to(device)

In [None]:
# Binary cross-entropy, applied to the discriminator's sigmoid output.
criterion = nn.BCELoss().to(device)

# One Adam optimiser per network, both with a learning rate of 2e-4.
BC_optimizer = optim.Adam(params=BC_model.parameters(), lr=2e-4)
GAN_optimizer = optim.Adam(params=GAN_model.parameters(), lr=2e-4)

## Training

In [None]:
def denorm(x):
    """Map values from the [-1, 1] range to [0, 1] and clip anything outside.

    Computes ``(x + 1) / 2`` and clamps the result to ``[0, 1]``.
    """
    rescaled = (x + 1) / 2
    return rescaled.clamp(min=0, max=1)

In [None]:
# Per-epoch average losses; one value is appended to each list after every epoch.
D_loss_array = list()
G_loss_array = list()

for epoch in range(total_epoch):
    D_average_loss = 0.
    G_average_loss = 0.
    for (images, _) in dataloader:
        # Flatten every image to a vector. The batch size is read from the
        # tensor itself so the loop does not rely on drop_last=True.
        current_batch = images.size(0)
        images = images.reshape(current_batch, -1).to(device)

        # BCE targets: 1 for real images, 0 for generated ones.
        real_labels = torch.ones(current_batch, 1, device=device)
        fake_labels = torch.zeros(current_batch, 1, device=device)

        # ---- Discriminator step ----
        # Noise is sampled exactly as before (created, then moved to the
        # device) so the seeded random sequence is unchanged.
        z = torch.randn(current_batch, latent_code_size).to(device)

        D_score = BC_model(images)

        fake_images = GAN_model(z)
        # detach(): the discriminator loss must not back-propagate into the
        # generator. The generator gradients were discarded before its own
        # step anyway, so this only saves work and memory.
        G_score = BC_model(fake_images.detach())

        D_loss = criterion(D_score, real_labels) + criterion(G_score, fake_labels)
        BC_optimizer.zero_grad()
        # No retain_graph: the generator step below builds a fresh graph.
        D_loss.backward()
        BC_optimizer.step()

        # ---- Generator step ----
        z = torch.randn(current_batch, latent_code_size).to(device)
        fake_images = GAN_model(z)
        G_score = BC_model(fake_images)

        # Only the generator's optimiser steps here, so the discriminator's
        # parameters stay fixed. Gradients this backward pass leaves on the
        # discriminator are cleared by BC_optimizer.zero_grad() next iteration.
        G_loss = criterion(G_score, real_labels)
        GAN_optimizer.zero_grad()
        G_loss.backward()
        GAN_optimizer.step()

        # Accumulate plain floats. Adding the loss tensors themselves would keep
        # every iteration's autograd graph referenced until the epoch ends.
        D_average_loss += D_loss.item() / iteration
        G_average_loss += G_loss.item() / iteration

    D_loss_array.append(D_average_loss)
    G_loss_array.append(G_average_loss)

    # D_score / G_score are the mean discriminator outputs on the last batch of the epoch.
    print('Epoch: {:4d}/{} D_Loss: {:.5f} G_Loss: {:.5f} D_score: {:.5f} G_score: {:.5f}'.format(epoch, total_epoch, D_average_loss, G_average_loss, D_score.float().mean().item(), G_score.float().mean().item()))

    # Save one grid of real images, once, after the first epoch.
    if epoch == 0:
        images = images.reshape(images.size(0), 1, 28, 28)
        save_image(denorm(images), os.path.join(sample_dir, 'real_images.png'))

    # Save the last generated batch of this epoch (file names are 1-based).
    fake_images = fake_images.reshape(fake_images.size(0), 1, 28, 28)
    save_image(denorm(fake_images), os.path.join(sample_dir, 'fake_images-{}.png'.format(epoch+1)))

# Save the model checkpoints
torch.save(GAN_model.state_dict(), 'GAN_model.ckpt')
torch.save(BC_model.state_dict(), 'BC_model.ckpt')


Epoch:    0/200 D_Loss: 0.24757 G_Loss: 3.67362 D_score: 0.94149 G_score: 0.09379
Epoch:    1/200 D_Loss: 0.28397 G_Loss: 4.51259 D_score: 0.93468 G_score: 0.02459
Epoch:    2/200 D_Loss: 0.43253 G_Loss: 4.23841 D_score: 0.80063 G_score: 0.05733
Epoch:    3/200 D_Loss: 0.50343 G_Loss: 3.37753 D_score: 0.79487 G_score: 0.18658
Epoch:    4/200 D_Loss: 0.57226 G_Loss: 2.76673 D_score: 0.75646 G_score: 0.08846
Epoch:    5/200 D_Loss: 0.83640 G_Loss: 2.53615 D_score: 0.72359 G_score: 0.09790
Epoch:    6/200 D_Loss: 0.65673 G_Loss: 2.41268 D_score: 0.77749 G_score: 0.15134
Epoch:    7/200 D_Loss: 0.56302 G_Loss: 2.84430 D_score: 0.88739 G_score: 0.08046
Epoch:    8/200 D_Loss: 0.42204 G_Loss: 3.07882 D_score: 0.90318 G_score: 0.10524
Epoch:    9/200 D_Loss: 0.39532 G_Loss: 3.67709 D_score: 0.91229 G_score: 0.03076
Epoch:   10/200 D_Loss: 0.32124 G_Loss: 3.81268 D_score: 0.95800 G_score: 0.13232
Epoch:   11/200 D_Loss: 0.39949 G_Loss: 4.15244 D_score: 0.94661 G_score: 0.04471
Epoch:   12/200 

## Download image file from Colab

In [None]:
from google.colab import files

# Download a selection of the sample grids written during training.
# Paths are built from `sample_dir` with the same file-name pattern used when
# the images were saved, so they follow the directory the files were written to
# instead of a hard-coded absolute location.
for saved_epoch in (1, 25, 50, 75, 100, 125, 150, 175, 200):
    files.download(os.path.join(sample_dir, 'fake_images-{}.png'.format(saved_epoch)))

<IPython.core.display.Javascript object>

<IPython.core.display.Javascript object>