In [12]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter

In [13]:
class Discriminator(nn.Module):
    """Fully connected GAN discriminator for flattened images.

    Architecture: Linear(img_dim -> 128) -> LeakyReLU(0.1) -> Linear(128 -> 1) -> Sigmoid.

    Args:
        img_dim: Number of input features per sample (28 * 28 * 1 in the MNIST setup below).
    """

    def __init__(self, img_dim):
        super().__init__()
        layers = [
            nn.Linear(img_dim, 128),
            nn.LeakyReLU(0.1),
            nn.Linear(128, 1),
            # Squash to (0, 1) so the score can be fed directly to nn.BCELoss.
            nn.Sigmoid(),
        ]
        # The attribute name `disc` determines the state_dict keys ("disc.0.weight", ...).
        self.disc = nn.Sequential(*layers)

    def forward(self, x):
        """Score `x` (last dimension img_dim); the result's last dimension is 1."""
        probs = self.disc(x)
        return probs

In [14]:
class Generator(nn.Module):
    """Fully connected GAN generator mapping latent vectors to flattened images.

    Architecture: Linear(z_dim -> 256) -> LeakyReLU(0.1) -> Linear(256 -> img_dim) -> Tanh.

    Args:
        z_dim: Size of the latent noise vector.
        img_dim: Number of output features per generated sample.
    """

    def __init__(self, z_dim, img_dim):
        super().__init__()
        # Layers are created in the same order as they run, so weight init consumes the RNG identically.
        hidden = nn.Linear(z_dim, 256)
        activation = nn.LeakyReLU(0.1)
        to_image = nn.Linear(256, img_dim)
        # Tanh bounds outputs to [-1, 1], matching images normalized with Normalize((0.5,), (0.5,)).
        squash = nn.Tanh()
        self.gen = nn.Sequential(hidden, activation, to_image, squash)

    def forward(self, x):
        """Generate samples from `x` (last dimension z_dim); the result's last dimension is img_dim."""
        images = self.gen(x)
        return images

In [15]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
seed = 0
lr = 3e-4
z_dim = 64
image_dim = 28 * 28 * 1  # flattened 1-channel 28x28 MNIST image
batch_size = 32
num_epochs = 50

# Seed torch's global RNGs (weight init, noise, DataLoader shuffle order) so reruns are repeatable.
torch.manual_seed(seed)

disc = Discriminator(image_dim).to(device)
gen = Generator(z_dim, image_dim).to(device)
# Reused at every logging step so the fake-image grid shows the generator's
# progress on the same latent vectors.
fixed_noise = torch.randn((batch_size, z_dim), device=device)

# Named `transform` (not `transforms`) so the imported `torchvision.transforms` alias is not shadowed.
# ToTensor yields [0, 1]; Normalize(0.5, 0.5) maps that to [-1, 1], the Generator's Tanh range.
transform = torchvision.transforms.Compose(
    [torchvision.transforms.ToTensor(), torchvision.transforms.Normalize((0.5,), (0.5,))]
)

dataset = datasets.MNIST(root="dataset/", transform=transform, download=True)
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
opt_disc = optim.Adam(disc.parameters(), lr=lr)
opt_gen = optim.Adam(gen.parameters(), lr=lr)
# The Discriminator ends in Sigmoid, so its outputs are probabilities suitable for BCELoss.
criterion = nn.BCELoss()
# Log directories/tags keep their original spelling so earlier TensorBoard runs stay comparable.
writer_fake = SummaryWriter('runs/GAN_MINST/fake')
writer_real = SummaryWriter('runs/GAN_MINST/real')
step = 0  # TensorBoard global_step; advanced once per logged image pair (once per epoch)

for epoch in range(num_epochs):
    for batch_idx, (real, _) in enumerate(loader):
        real = real.view(-1, image_dim).to(device)
        # The last batch may be smaller than batch_size; use a separate name so the
        # batch_size hyperparameter is not overwritten.
        cur_batch_size = real.shape[0]

        # Train discriminator: max log(D(real)) + log(1 - D(G(z))).
        noise = torch.randn(cur_batch_size, z_dim, device=device)
        fake = gen(noise)
        disc_real = disc(real).view(-1)
        lossD_real = criterion(disc_real, torch.ones_like(disc_real))
        # detach(): the discriminator loss must not backprop into the generator; this also
        # removes the need for retain_graph=True, since the generator graph is untouched here.
        disc_fake = disc(fake.detach()).view(-1)
        lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        lossD = (lossD_real + lossD_fake) / 2
        opt_disc.zero_grad()
        lossD.backward()
        opt_disc.step()

        # Train generator with the non-saturating loss: instead of min log(1 - D(G(z))),
        # which gives weak gradients while D easily rejects fakes, max log(D(G(z))),
        # i.e. BCE of D(G(z)) against all-ones labels. `fake` is not detached here, so
        # gradients flow through the just-updated discriminator into the generator.
        output = disc(fake).view(-1)
        lossG = criterion(output, torch.ones_like(output))
        opt_gen.zero_grad()
        lossG.backward()
        opt_gen.step()

        # Report and log once per epoch (first batch only); writing image grids on
        # every batch would emit tens of thousands of TensorBoard images and slow training.
        if batch_idx == 0:
            # '\\ ' prints a literal backslash; a bare '\ ' is an invalid escape sequence.
            print(
                f'Epoch [{epoch}/{num_epochs}] \\ '
                f'Loss D: {lossD.item(): .4f}, Loss G: {lossG.item(): .4f}'
            )

            with torch.no_grad():
                fake_images = gen(fixed_noise).reshape(-1, 1, 28, 28)
                real_images = real.reshape(-1, 1, 28, 28)
                img_grid_fake = torchvision.utils.make_grid(fake_images, normalize=True)
                img_grid_real = torchvision.utils.make_grid(real_images, normalize=True)

                writer_fake.add_image(
                    'Minst Fake Images', img_grid_fake, global_step=step
                )
                writer_real.add_image(
                    'Minst Real Images', img_grid_real, global_step=step
                )

            step += 1

# Push any buffered events to disk; the writers stay open for later cells.
writer_fake.flush()
writer_real.flush()

Epoch [0/50] \ Loss D:  0.6481, Loss G:  0.7164
Epoch [1/50] \ Loss D:  0.5057, Loss G:  0.9389
Epoch [2/50] \ Loss D:  0.5441, Loss G:  1.1495
Epoch [3/50] \ Loss D:  0.6509, Loss G:  1.2326
Epoch [4/50] \ Loss D:  0.7629, Loss G:  0.9128
Epoch [5/50] \ Loss D:  0.4804, Loss G:  1.2489
Epoch [6/50] \ Loss D:  0.3955, Loss G:  1.3431
Epoch [7/50] \ Loss D:  0.4762, Loss G:  1.0314
Epoch [8/50] \ Loss D:  0.4758, Loss G:  1.2929
Epoch [9/50] \ Loss D:  0.8534, Loss G:  0.6756
Epoch [10/50] \ Loss D:  0.5570, Loss G:  1.4140
Epoch [11/50] \ Loss D:  0.5457, Loss G:  0.9449
Epoch [12/50] \ Loss D:  0.4697, Loss G:  1.3705
Epoch [13/50] \ Loss D:  0.5434, Loss G:  1.2194
Epoch [14/50] \ Loss D:  0.7492, Loss G:  0.9982
Epoch [15/50] \ Loss D:  0.7176, Loss G:  0.8685
Epoch [16/50] \ Loss D:  0.6734, Loss G:  1.1610
Epoch [17/50] \ Loss D:  0.4378, Loss G:  1.0905
Epoch [18/50] \ Loss D:  0.6629, Loss G:  0.9937
Epoch [19/50] \ Loss D:  0.6833, Loss G:  1.0397
Epoch [20/50] \ Loss D:  0.656