In [1]:
import pandas as pd
import numpy as np

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torchvision.datasets as datasets

from torch.utils.tensorboard import SummaryWriter

In [2]:
class Discriminator(nn.Module):
    def __init__(self,img_dim):
        super().__init__()
        self.disc = nn.Sequential(nn.Linear(img_dim,128),
                                 nn.LeakyReLU(0.1),
                                  nn.Linear(128,1),
                                  nn.Sigmoid()
                                 )
    def forward(self,X):
        return self.disc(X)

In [3]:
class Generator(nn.Module):
    def __init__(self,z_dim,img_dim):
        super().__init__()
        self.gen = nn.Sequential(nn.Linear(z_dim,256),
                                nn.LeakyReLU(0.1),
                                nn.Linear(256,img_dim),
                                nn.Tanh())
    def forward(self,X):
        return self.gen(X)

In [4]:
#Hyperparamtres, etc.
device = "cuda" if torch.cuda.is_available() else "cpu"
lr = 3e-4
z_dim = 64
image_dim = 28*28*1
batch_size = 32
num_epoch = 50

In [5]:
disc = Discriminator(image_dim).to(device)
gen = Generator(z_dim,image_dim).to(device)
fixed_noise = torch.randn((batch_size,z_dim)).to(device)

In [6]:
transforms = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,),(0.3081,))
])

In [7]:
dataset = datasets.MNIST(root="dataset/",transform=transforms,download=True)

In [8]:
loader = DataLoader(dataset,batch_size=batch_size,shuffle=True)

In [9]:
opt_disc = optim.Adam(disc.parameters(),lr=lr)
opt_gen = optim.Adam(gen.parameters(),lr=lr)

In [10]:
criterion = nn.BCELoss()

In [11]:
writer_fake = SummaryWriter(f"runs/GAN_MNIST/fake")
writer_real = SummaryWriter(f"runs/GAN_MNIST/real")
step = 0

In [16]:
for epoch in range(num_epoch):
    for batch_idx,(real,_) in enumerate(loader):
        real = real.view(-1,784).to(device)
        batch_size = real.shape[0]
        
        ###Train Descriminator max log(D(real)) + log(1-D(G(z)))
        noise = torch.randn(batch_size,z_dim).to(device)
        fake = gen(noise)
        disc_real = disc(real).view(-1)
        lossD_real = criterion(disc_real,torch.ones_like(disc_real))
        disc_fake = disc(fake).view(-1)
        lossD_fake = criterion(disc_fake,torch.zeros_like(disc_fake))
        lossD = (lossD_real+lossD_fake)/2
        disc.zero_grad()
        lossD.backward(retain_graph=True)
        opt_disc.step()
        
        
        ###Train Generator min log(1-D(G(Z))) ~ max(D(G(z)))
        output = disc(fake).view(-1)
        lossG = criterion(output,torch.ones_like(output))
        gen.zero_grad()
        lossG.backward()
        opt_gen.step()
        
        if batch_idx ==0:
            print(
            f'Epoch [{epoch}/{num_epoch}] \ '
            f'Loos D: {lossD:.4f}, loss G: {lossG:0.4f}')
            
        with torch.no_grad():
            fake = gen(fixed_noise).reshape(-1,1,28,28)
            data = real.reshape(-1,1,28,28)
            img_grid_fake = torchvision.utils.make_grid(fake,normalize=True)
            img_grid_real = torchvision.utils.make_grid(real,normalize=True)
            
            writer_fake.add_image("MNIST fake Images",
                                 img_grid_fake,
                                 global_step=step)
            writer_real.add_image("MNIST real Images",
                                 img_grid_real,
                                 global_step=step)
            step +=1

Epoch [0/50] \ Loos D: 0.4535, loss G: 0.7510
Epoch [1/50] \ Loos D: 0.1944, loss G: 1.9177
Epoch [2/50] \ Loos D: 0.0794, loss G: 3.0048
Epoch [3/50] \ Loos D: 0.0573, loss G: 4.3026
Epoch [4/50] \ Loos D: 0.1316, loss G: 3.4630
Epoch [5/50] \ Loos D: 0.0249, loss G: 5.0518
Epoch [6/50] \ Loos D: 0.0794, loss G: 5.2957
Epoch [7/50] \ Loos D: 0.0203, loss G: 4.8293
Epoch [8/50] \ Loos D: 0.0205, loss G: 4.6067
Epoch [9/50] \ Loos D: 0.0243, loss G: 4.6479
Epoch [10/50] \ Loos D: 0.0178, loss G: 4.4805
Epoch [11/50] \ Loos D: 0.0045, loss G: 6.4864
Epoch [12/50] \ Loos D: 0.0186, loss G: 4.9694
Epoch [13/50] \ Loos D: 0.0078, loss G: 5.7522
Epoch [14/50] \ Loos D: 0.0038, loss G: 7.1585
Epoch [15/50] \ Loos D: 0.0061, loss G: 6.3335
Epoch [16/50] \ Loos D: 0.0313, loss G: 5.9002
Epoch [17/50] \ Loos D: 0.0050, loss G: 6.0848
Epoch [18/50] \ Loos D: 0.0043, loss G: 6.5770
Epoch [19/50] \ Loos D: 0.0050, loss G: 6.1573
Epoch [20/50] \ Loos D: 0.0184, loss G: 5.9929
Epoch [21/50] \ Loos D:

KeyboardInterrupt: 