# Simple DCGAN on MNIST

## imports

In [None]:
import torch
import torchvision
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np

## hyperparams

In [None]:
# --- data ---
PATH = 'data'       # root directory MNIST is downloaded to / read from
BATCHSIZE = 64      # images per training step

# --- architecture ---
INPUTCHANNELS = 1   # image channels (1 = grayscale)
NDF = 4             # base feature-map width of the discriminator
NGF = 32            # base feature-map width of the generator
NZ = 100            # length of the latent noise vector fed to the generator

# --- optimisation ---
LR = 0.001          # Adam learning rate, shared by both networks
EPOCHS = 20         # full passes over the training set

## load MNIST data

In [None]:
# Prefer the first CUDA GPU; fall back to the CPU when none is available.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# ToTensor converts each PIL image into a float tensor (pixel values scaled to [0, 1]).
transform = torchvision.transforms.ToTensor()
train_data = torchvision.datasets.MNIST(
    root=PATH,
    train=True,           # training split
    transform=transform,
    download=True,        # fetch into PATH if not already present
)

# Reshuffled every epoch; the last batch may be smaller than BATCHSIZE (drop_last defaults to False).
train_loader = torch.utils.data.DataLoader(
    dataset=train_data,
    batch_size=BATCHSIZE,
    shuffle=True,
)

## Generator and Discriminator classes

In [None]:
class Generator(nn.Module):
    """DCGAN generator: upsamples a latent tensor into a 28x28 image.

    Args:
        nz: channels of the latent input, which has shape (B, nz, 1, 1).
        ngf: base feature-map width; hidden layers use ngf*8, ngf*4 and ngf channels.
        nc: number of channels of the generated image.

    Output shape is (B, nc, 28, 28) with values in (0, 1) (final Sigmoid).
    """

    def __init__(self, nz, ngf, nc):
        super().__init__()

        def up_block(c_in, c_out, k, s, p):
            # transposed conv -> batch norm -> ReLU
            return [
                nn.ConvTranspose2d(in_channels=c_in, out_channels=c_out,
                                   kernel_size=k, stride=s, padding=p),
                nn.BatchNorm2d(c_out),
                nn.ReLU(),
            ]

        layers = []
        layers += up_block(nz, ngf * 8, 4, 1, 0)       # (B, nz, 1, 1)    -> (B, ngf*8, 4, 4)
        layers += up_block(ngf * 8, ngf * 4, 3, 2, 1)  # (B, ngf*8, 4, 4) -> (B, ngf*4, 7, 7)
        layers += up_block(ngf * 4, ngf, 4, 2, 1)      # (B, ngf*4, 7, 7) -> (B, ngf, 14, 14)
        layers += [
            # (B, ngf, 14, 14) -> (B, nc, 28, 28)
            nn.ConvTranspose2d(in_channels=ngf, out_channels=nc,
                               kernel_size=4, stride=2, padding=1),
            nn.Sigmoid(),  # squash pixels into (0, 1)
        ]
        self.gen = nn.Sequential(*layers)

    def forward(self, input):
        return self.gen(input)


class Discriminator(nn.Module):
    """Convolutional discriminator mapping a batch of square images to one probability each.

    Three Conv2d(kernel_size=4, stride=2, padding=1) layers shrink the image,
    then a linear layer + Sigmoid produce a score in (0, 1).

    Args:
        nc: number of channels of the input images.
        ndf: base width; the conv layers output ndf, ndf*4 and ndf*8 channels.
        img_size: side length of the square input images. The default (28)
            yields a 3x3 final feature map, exactly as before this parameter
            existed.

    Input shape (B, nc, img_size, img_size); output shape (B, 1).

    Raises:
        ValueError: if img_size is too small to survive three stride-2 convolutions.
    """

    def __init__(self, nc, ndf, img_size=28):
        super().__init__()

        # Conv2d output size: floor((s + 2*padding - kernel_size) / stride) + 1,
        # applied once per conv layer (28 -> 14 -> 7 -> 3 for the default).
        feat = img_size
        for _ in range(3):
            feat = (feat + 2 * 1 - 4) // 2 + 1
        if feat < 1:
            raise ValueError(
                f"img_size={img_size} is too small: three stride-2 convolutions "
                f"leave no spatial extent"
            )

        self.dis = nn.Sequential(
                                    nn.Conv2d(in_channels=nc, out_channels=ndf, kernel_size=4, stride=2, padding=1),
                                    nn.ReLU(),
                                    nn.Conv2d(in_channels=ndf, out_channels=ndf*4, kernel_size=4, stride=2, padding=1),
                                    nn.BatchNorm2d(ndf*4),
                                    nn.ReLU(),
                                    nn.Conv2d(in_channels=ndf*4, out_channels=ndf*8, kernel_size=4, stride=2, padding=1),
                                    nn.BatchNorm2d(ndf*8),
                                    nn.ReLU(),
                                    nn.Flatten(1,-1),
                                    # flattened size follows from img_size instead of a hard-coded 3*3
                                    nn.Linear(ndf*8 * feat * feat, 1),
                                    nn.Sigmoid()
                                )

    def forward(self, input):
        x = self.dis(input)
        return x

## initialize model

In [None]:
# Build both networks on the selected device (discriminator first, as before).
dis = Discriminator(INPUTCHANNELS, NDF).to(device)
gen = Generator(NZ, NGF, INPUTCHANNELS).to(device)

# The discriminator ends in a Sigmoid, so it outputs probabilities and plain BCELoss applies.
criterion = nn.BCELoss()

# DCGAN training uses Adam with beta1 = 0.5: the default beta1 = 0.9 momentum was
# reported (Radford et al., 2015) to make adversarial training oscillate.
ADAM_BETAS = (0.5, 0.999)
optimizer_dis = torch.optim.Adam(dis.parameters(), lr=LR, betas=ADAM_BETAS)
optimizer_gen = torch.optim.Adam(gen.parameters(), lr=LR, betas=ADAM_BETAS)

## training loop

In [None]:
# Per-iteration loss history (plotted after training).
gen_losses = []
dis_losses = []

# BatchNorm must normalize with batch statistics while training; reset the mode
# in case an earlier run switched either network to eval().
gen.train()
dis.train()

for epoch in range(EPOCHS):

    for i, (real_im, _) in enumerate(train_loader):  # labels unused: the GAN is unconditional
        real_im = real_im.to(device)
        batch_size = len(real_im)  # the last batch of an epoch can be smaller than BATCHSIZE

        # ---------------------
        #  Train Discriminator
        # ---------------------

        optimizer_dis.zero_grad()

        # Fakes for the D step need no generator gradients.
        with torch.no_grad():
            # sample noise directly on the target device (no host -> device copy)
            z = torch.randn(batch_size, NZ, 1, 1, device=device)   # batch-size, number-channels, height, width
            fake_im = gen(z)

        disc_real = dis(real_im)
        disc_fake = dis(fake_im)

        # 1 being label for real image, 0 label for generated image
        real_loss = criterion(disc_real, torch.ones_like(disc_real))
        fake_loss = criterion(disc_fake, torch.zeros_like(disc_fake))
        loss_dis = (real_loss + fake_loss) / 2

        loss_dis.backward()
        optimizer_dis.step()

        # -----------------
        #  Train Generator
        # -----------------

        optimizer_gen.zero_grad()

        z = torch.randn(batch_size, NZ, 1, 1, device=device)
        fake_im = gen(z)
        output = dis(fake_im)

        # Non-saturating generator loss: generated images are labeled as real (1).
        # backward() also fills dis's .grad buffers; optimizer_dis.zero_grad() clears
        # them at the start of the next D step before they are used.
        loss_gen = criterion(output, torch.ones_like(output))

        loss_gen.backward()
        optimizer_gen.step()

        # .item() copies to the host (a sync on CUDA): read each loss once per iteration.
        loss_dis_val = loss_dis.item()
        loss_gen_val = loss_gen.item()
        gen_losses.append(loss_gen_val)
        dis_losses.append(loss_dis_val)

        # print stats; the mean discriminator outputs are only computed when printed
        if i % 50 == 0:
            avg_pred_real = disc_real.mean().item()   # mean D(x) on this real batch
            avg_pred_gen1 = disc_fake.mean().item()   # mean D(G(z)) seen in the D step
            avg_pred_gen2 = output.mean().item()      # mean D(G(z)) after the D update, in the G step
            print(f'[{epoch+1}/{EPOCHS}] [{i}/{len(train_loader)}] \nLoss D: {loss_dis_val}, Loss G: {loss_gen_val}, Mean D(x): {avg_pred_real}, Mean D(G(z)):{avg_pred_gen1} / {avg_pred_gen2}')

## check results

In [None]:
# plot generator and discriminator loss (one value per training iteration)
plt.figure(figsize=(10,5))
plt.title('Generator and Discriminator Loss')
plt.plot(gen_losses, label='Generator')
plt.plot(dis_losses, label='Discriminator')
plt.xlabel('iterations')
plt.ylabel('loss')
plt.legend()
plt.show()

# Everything below is inference: no_grad avoids building autograd graphs.
# NOTE: both networks are left in train mode, so BatchNorm normalizes with the
# statistics of each evaluated batch, the same mode used for the training-time stats.
with torch.no_grad():
    # sample 10 latent vectors from a standard gaussian, shaped (10, NZ, 1, 1)
    z = torch.randn(10, NZ, 1, 1, device=device)
    # generate some images
    gen_img = gen(z)

# show generated images; (C, H, W) -> (H, W, C), and squeeze drops C when it is 1
for i in range(10):
    plt.subplot(2, 5, i + 1)
    plt.axis('off')
    plt.imshow(gen_img[i].permute(1, 2, 0).squeeze().cpu().numpy(), cmap='gray_r')

plt.show()

# check discriminator
with torch.no_grad():
    out_test = dis(gen_img)
print(f'Discriminator tested on generated images, MEAN prediction: {out_test.mean().item():.2f}')

real_batch, _ = next(iter(train_loader))
with torch.no_grad():
    real_test = dis(real_batch.to(device))
print(f'Discriminator tested on real images, MEAN prediction:  {real_test.mean().item():.2f}')