In [19]:
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter

In [20]:
class Discriminator(nn.Module):
    def __init__(self, in_features):
        super().__init__()
        self.disc = nn.Sequential(
            nn.Linear(in_features, 128),
            nn.LeakyReLU(0.1),
            nn.Linear(128,1),
            nn.Sigmoid(),
        )
    
    def forward(self, x):
        return self.disc(x)

In [21]:
class Generator(nn.Module):
    def __init__(self, z_dim, img_dim):
        super().__init__()
        self.gen = nn.Sequential(
            nn.Linear(z_dim, 256),
            nn.LeakyReLU(0.1),
            nn.Linear(256, img_dim),
            nn.Tanh(),
        )
    
    def forward(self, x):
        return self.gen(x)

In [22]:
#hyperparamters, very sensitive to these as this is a simple GAN
device = "cuda" if torch.cuda.is_available() else "cpu"
lr = 3e-4 # learning rate, can play around to see what happens when you change it
z_dim = 64 #128, 256 or smaller
image_dim = 28*28*1 #784
batch_size = 32
num_epochs = 50

In [23]:
disc = Discriminator(image_dim).to(device) #discriminator of image dimension sent to device
gen = Generator(z_dim, image_dim).to(device) # Generator of z dimension and image dimension sent to device
fixed_noise = torch.randn((batch_size, z_dim)).to(device) # fixed noise to see how it has changed across the epochs
transforms = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))] # mean and standard deviation in that order, actual mean and sd of mnis data
)


In [24]:
dataset = datasets.MNIST(root="dataset/", transform=transforms, download=True) #getting mnist in
loader = DataLoader(dataset, batch_size=batch_size, shuffle=True) #load all the data with the amount you want and shuffle them
opt_disc = optim.Adam(disc.parameters(), lr=lr)#optimiser for the discriminator 
opt_gen = optim.Adam(gen.parameters(), lr=lr)#optimiser for the generator
criterion = nn.BCELoss() #loss function, BCE Loss
writer_fake = SummaryWriter(f"logs/fake") #writer showing fake images from the generator
writer_real = SummaryWriter(f"logs/real") # writer showing real images
step = 0

In [None]:
for epoch in range(num_epochs):
    for batch_idx, (real, _) in enumerate(loader): #don't need to use the actual labels, GAN unsupervised
        real = real.view(-1, 784).to(device) #Keep the number of examples, then flatten everything else
        batch_size = real.shape[0] #first dimension
        
        #Train Discriminator: maximise log(D(real)) + log(1-D(G(z))) where z is some random noise
        
        noise = torch.randn(batch_size, z_dim).to(device) #generating the noise
        fake = gen(noise) #generate some fake images
        disc_real = disc(real).view(-1) #what the discriminator outputs on the real lines
        lossD_real = criterion(disc_real, torch.ones_like(disc_real)) #first term in the equation above, see BCELoss page
        disc_fake = disc(fake).view(-1) #now the same thing for the second term of the equation above this is D(G(z))
        lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake)) 
        lossD = (lossD_real + lossD_fake) / 2 #putting the equation together essentially
        disc.zero_grad()
        lossD.backward(retain_graph=True) #saves some of the terms from above
        opt_disc.step()
        
        #Train Generator min log(1-D(G(z))) or better maximise log(D(G(z)))
        
        output = disc(fake).view(-1) #using the second term from above
        lossG = criterion(output, torch.ones_like(output))
        gen.zero_grad()
        lossG.backward()
        opt_gen.step()
        
        if batch_idx == 0:
            print(
                f"Epoch [{epoch}/{num_epochs}] Batch {batch_idx}/{len(loader)} \
                      Loss D: {lossD:.4f}, loss G: {lossG:.4f}"
            )

            with torch.no_grad():
                fake = gen(fixed_noise).reshape(-1, 1, 28, 28)
                data = real.reshape(-1, 1, 28, 28)
                img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)
                img_grid_real = torchvision.utils.make_grid(data, normalize=True)

                writer_fake.add_image(
                    "Mnist Fake Images", img_grid_fake, global_step=step
                )
                writer_real.add_image(
                    "Mnist Real Images", img_grid_real, global_step=step
                )
                step += 1

Epoch [0/50] Batch 0/1875                       Loss D: 0.6864, loss G: 0.6809
Epoch [1/50] Batch 0/1875                       Loss D: 0.1926, loss G: 1.9405
Epoch [2/50] Batch 0/1875                       Loss D: 0.0783, loss G: 2.7705
Epoch [3/50] Batch 0/1875                       Loss D: 0.0486, loss G: 3.7118
Epoch [4/50] Batch 0/1875                       Loss D: 0.1290, loss G: 3.2759
Epoch [5/50] Batch 0/1875                       Loss D: 0.1040, loss G: 4.9335
Epoch [6/50] Batch 0/1875                       Loss D: 0.0645, loss G: 5.3071
Epoch [7/50] Batch 0/1875                       Loss D: 0.0086, loss G: 5.8743
Epoch [8/50] Batch 0/1875                       Loss D: 0.0336, loss G: 4.7474
Epoch [9/50] Batch 0/1875                       Loss D: 0.0096, loss G: 5.7562
