In [1]:
# loading libraries

In [2]:
import numpy as np
import pandas as pd

# pytorch
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter # to print to tensorboard

In [3]:
# Creating a generator class

In [5]:
class Generator(nn.Module):
    """Fully connected generator: noise vector in, flat sample vector out.

    The network is three linear layers (input_dim -> 128 -> 256 -> output_dim)
    with LeakyReLU(0.01) between them and a final Tanh, so every output value
    lies in [-1, 1].

    Args:
        input_dim: Number of features in the input noise vector.
        output_dim: Number of features in the generated output vector.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()

        hidden_sizes = (128, 256)
        layers = []
        in_features = input_dim
        for out_features in hidden_sizes:
            layers.append(nn.Linear(in_features, out_features))
            layers.append(nn.LeakyReLU(0.01))
            in_features = out_features
        layers.append(nn.Linear(in_features, output_dim))
        # Tanh keeps outputs in [-1, 1], the same range as the normalized inputs.
        layers.append(nn.Tanh())

        self.generator = nn.Sequential(*layers)

    def forward(self, x):
        """Run `x` through the generator stack and return the result."""
        return self.generator(x)

In [6]:
# Creating a Discriminator

In [7]:
class Discriminator(nn.Module):
    """Fully connected discriminator: flat sample vector in, one score out.

    The network is three linear layers (input_dim -> 256 -> 128 -> 1) with
    LeakyReLU(0.01) between them and a final Sigmoid, so the score for each
    sample lies in [0, 1].

    Args:
        input_dim: Number of features in each input sample.
    """

    def __init__(self, input_dim):
        super().__init__()

        first = nn.Linear(input_dim, 256)
        second = nn.Linear(256, 128)
        head = nn.Linear(128, 1)

        self.discriminator = nn.Sequential(
            first,
            nn.LeakyReLU(0.01),
            second,
            nn.LeakyReLU(0.01),
            head,
            nn.Sigmoid(),  # squash the single logit into [0, 1]
        )

    def forward(self, x):
        """Run `x` through the discriminator stack and return the scores."""
        return self.discriminator(x)

In [9]:
# Hyperparameters

In [10]:
# Use the first CUDA GPU when one is available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda:0"
else:
    device = "cpu"
device

'cpu'

In [11]:
# Training configuration.
seed = 42  # fixed seed so weight init, noise, and shuffling are reproducible
lr = 3e-4  # learning rate
input_dim = 64  # size of the noise vector
image_dim = 28 * 28 * 1  # 784 values per flattened MNIST image
batch_size = 64
num_epochs = 5

# Seed PyTorch's default generator before any model or noise tensor is
# created, so a fresh "Restart & Run All" gives the same run every time.
torch.manual_seed(seed)

In [12]:
# Creating model instance

In [14]:
# Build both networks, then move each one to the selected device.
gen = Generator(input_dim=input_dim, output_dim=image_dim)
gen = gen.to(device)

disc = Discriminator(input_dim=image_dim)
disc = disc.to(device)

In [15]:
# defining transforms

In [16]:
# Per-image preprocessing: convert to a tensor, then normalize the single
# channel with mean 0.5 and std 0.5.
#
# NOTE: the result is bound to `transforms`, which shadows the
# `torchvision.transforms` module alias from the import cell; the dataset cell
# reads this name. The module is referenced by its full path here so that
# re-running this cell still works after the alias has been overwritten.
transforms = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5,), (0.5,)),
    ]
)

In [18]:
# One fixed batch of noise vectors; the training loop feeds it to the generator
# each time it logs images.
fixed_noise = torch.randn(batch_size, input_dim).to(device)

In [19]:
# Setting up the optimizers

In [20]:
# One Adam optimizer per network, both using the shared learning rate.
opt_gen = optim.Adam(params=gen.parameters(), lr=lr)
opt_disc = optim.Adam(params=disc.parameters(), lr=lr)

In [21]:
# Setting the loss function

In [22]:
# Binary cross-entropy on probabilities, averaged over the batch.
criterion = nn.BCELoss(reduction="mean")

In [23]:
# setting for tensorboard summary writer

In [24]:
# Separate TensorBoard writers for generated and real image grids.
writer_fake = SummaryWriter(log_dir="logs/fake")
writer_real = SummaryWriter(log_dir="logs/real")

# Counter used as the global step of each logged image grid.
step = 0

In [25]:
# Loading dataset

In [26]:
# MNIST stored under dataset/, downloaded if missing, with the preprocessing
# pipeline applied to every image.
dataset = datasets.MNIST(
    root="dataset/",
    transform=transforms,
    download=True,
)

Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz to dataset/MNIST\raw\train-images-idx3-ubyte.gz


100%|████████████████████████████████████████████████████████████████████| 9912422/9912422 [00:19<00:00, 509180.12it/s]


Extracting dataset/MNIST\raw\train-images-idx3-ubyte.gz to dataset/MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz to dataset/MNIST\raw\train-labels-idx1-ubyte.gz


100%|███████████████████████████████████████████████████████████████████████| 28881/28881 [00:00<00:00, 7260155.46it/s]

Extracting dataset/MNIST\raw\train-labels-idx1-ubyte.gz to dataset/MNIST\raw






Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz to dataset/MNIST\raw\t10k-images-idx3-ubyte.gz


100%|████████████████████████████████████████████████████████████████████| 1648877/1648877 [00:03<00:00, 539525.33it/s]


Extracting dataset/MNIST\raw\t10k-images-idx3-ubyte.gz to dataset/MNIST\raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz to dataset/MNIST\raw\t10k-labels-idx1-ubyte.gz


100%|█████████████████████████████████████████████████████████████████████████| 4542/4542 [00:00<00:00, 1515073.07it/s]

Extracting dataset/MNIST\raw\t10k-labels-idx1-ubyte.gz to dataset/MNIST\raw






In [27]:
# Shuffled mini-batches over the dataset.
loader = DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    shuffle=True,
)

In [28]:
# Training loop

In [29]:
# Alternating GAN training: for every mini-batch, one discriminator update
# followed by one generator update. On the first batch of each epoch the
# losses are printed and real/generated image grids are written to TensorBoard.
for epoch in range(num_epochs):
    for batch_idx, (real, _) in enumerate(loader):
        # Flatten every image to a vector of `image_dim` values for the linear layers.
        real = real.view(-1, image_dim).to(device)
        # Size of this batch, kept under its own name so the global
        # `batch_size` hyperparameter is not overwritten.
        cur_batch_size = real.shape[0]

        ### Train Discriminator: max log(D(x)) + log(1 - D(G(z)))
        noise = torch.randn(cur_batch_size, input_dim).to(device)
        fake = gen(noise)
        disc_real = disc(real).view(-1)
        lossD_real = criterion(disc_real, torch.ones_like(disc_real))
        # detach(): the discriminator update does not need gradients for the
        # generator, so cut the graph here. This also removes the need for
        # retain_graph=True, because the generator step below runs its own
        # forward pass through the discriminator.
        disc_fake = disc(fake.detach()).view(-1)
        lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        lossD = (lossD_real + lossD_fake) / 2
        disc.zero_grad()
        lossD.backward()
        opt_disc.step()

        ### Train Generator: min log(1 - D(G(z))) <-> max log(D(G(z))
        # where the second option of maximizing doesn't suffer from
        # saturating gradients
        output = disc(fake).view(-1)
        lossG = criterion(output, torch.ones_like(output))
        gen.zero_grad()
        lossG.backward()
        opt_gen.step()

        if batch_idx == 0:
            # Adjacent f-strings instead of a backslash continuation, which
            # embedded the next line's indentation in the printed message.
            print(
                f"Epoch [{epoch}/{num_epochs}] Batch {batch_idx}/{len(loader)} "
                f"Loss D: {lossD:.4f}, loss G: {lossG:.4f}"
            )

            with torch.no_grad():
                # Reshape flat vectors back to 1x28x28 images for the grids.
                fake_images = gen(fixed_noise).reshape(-1, 1, 28, 28)
                real_images = real.reshape(-1, 1, 28, 28)
                img_grid_fake = torchvision.utils.make_grid(fake_images, normalize=True)
                img_grid_real = torchvision.utils.make_grid(real_images, normalize=True)

                writer_fake.add_image(
                    "Mnist Fake Images", img_grid_fake, global_step=step
                )
                writer_real.add_image(
                    "Mnist Real Images", img_grid_real, global_step=step
                )
                step += 1

# Write any pending events to disk now rather than waiting for the writers'
# periodic flush.
writer_fake.flush()
writer_real.flush()

Epoch [0/5] Batch 0/938                       Loss D: 0.7109, loss G: 0.7026
Epoch [1/5] Batch 0/938                       Loss D: 0.0226, loss G: 6.1919
Epoch [2/5] Batch 0/938                       Loss D: 1.3222, loss G: 1.2148
Epoch [3/5] Batch 0/938                       Loss D: 0.7110, loss G: 1.3979
Epoch [4/5] Batch 0/938                       Loss D: 0.4796, loss G: 1.3377
