In [17]:
import torch 
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.datasets as datasets 
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter

In [18]:
class discriminator(nn.Module):
    def __init__(self,img_dim):
        super().__init__()

        self.disc=nn.Sequential(
            nn.Linear(img_dim,128),
            nn.LeakyReLU(0.1), # leaky relu is better for gans i general 
            nn.Linear(128,1), 
            nn.Sigmoid(),
        )

    def forward(self,x):
        return self.disc(x)  

In [19]:
class generator(nn.Module):
    def __init__(self,z_dim, img_dim):  #z_dim is latent noise
        super().__init__()
        self.gen=nn.Sequential(
            nn.Linear(z_dim,256),
            nn.LeakyReLU(0.1),
            nn.Linear(256,img_dim),
            nn.Tanh(),  # normalizig output to -1 and 1 as our input range will also be -1 to 1 after normalization
        )

    def forward(self,x):
        return self.gen(x)

In [20]:
#hyperparameters
device=torch.device("cuda" if torch.cuda.is_available() else "cpu")
lr=3e-4  # generally 3e-4 is the best learning rate for adam optimizer
z_dim=64
image_dim=28*28*1
batch_size=32
num_epochs=50

Discriminator (D) tries to maximize the probability of correctly classifying real vs. fake.

Generator (G) tries to minimize the probability that D correctly classifies fake samples.

In [21]:
disc=discriminator(image_dim).to(device)
gen=generator(z_dim,image_dim).to(device)
fixed_noise=torch.randn((batch_size,z_dim)).to(device)

# reason why we are doing somme fixed noise is because we can then how it has changed across the epochs 

In [22]:
transforms=transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,),(0.3081,)),
])

In [23]:
dataset=datasets.MNIST(root="dataset/",transform=transforms,download=True)
loader=DataLoader(dataset,batch_size=batch_size,shuffle=True)
optimzer_disc=optim.Adam(disc.parameters(),lr=lr)
optimzer_gen=optim.Adam(gen.parameters(),lr=lr)
criterion=nn.BCELoss()  
writer_fake=SummaryWriter(f"runs/GAN_mnist/fake")
writer_real=SummaryWriter(f"runs/GAN_mnist/real")
step=0

In [24]:
for epoch in range(num_epochs):
    for batch_idx, (images,labels) in enumerate(loader):
        images=images.view(-1,784).to(device)
        batch_size=images.shape[0]

        # train discriminator: max log(D(real)) + log(1-D(G(z)))
        noise=torch.randn(batch_size,z_dim).to(device)

        fake=gen(noise)
        disc_real=disc(images).view(-1)  # log (d(real))  .view(-1) flattens
        lossd_real=criterion(disc_real,torch.ones_like(disc_real))

        # disc_fake=disc(fake.detach()).view(-1)  # as we want to use fake image later in generator, we detach it or retain_graph=True in loss so those computations remain even if we do backward pass 
        disc_fake=disc(fake).view(-1)  # as we want to use fake image later in generator, we detach it or retain_graph=True in loss so those computations remain even if we do backward pass 
        lossd_fake=criterion(disc_fake,torch.zeros_like(disc_fake))

        lossd=(lossd_real+lossd_fake)/2

        disc.zero_grad()
        lossd.backward(retain_graph=True)
        optimzer_disc.step()

        # train genereator min (log (1-d(g(z)))) = max log(d(g(z)))

        output=disc(fake).view(-1)
        lossg=criterion(output,torch.ones_like(output))
        gen.zero_grad()
        lossg.backward()
        optimzer_gen.step()

        if batch_idx == 0:
            print(
                f"Epoch [{epoch}/{num_epochs}] Batch {batch_idx}/{len(loader)} \
                      Loss D: {lossd:.4f}, loss G: {lossg:.4f}"
            )

            with torch.no_grad():
                fake = gen(fixed_noise).reshape(-1, 1, 28, 28)
                data = images.reshape(-1, 1, 28, 28)
                img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)
                img_grid_real = torchvision.utils.make_grid(data, normalize=True)

                writer_fake.add_image(
                    "Mnist Fake Images", img_grid_fake, global_step=step
                )
                writer_real.add_image(
                    "Mnist Real Images", img_grid_real, global_step=step
                )
                step += 1

Epoch [0/50] Batch 0/1875                       Loss D: 0.6620, loss G: 0.7225
Epoch [1/50] Batch 0/1875                       Loss D: 0.2270, loss G: 1.6645
Epoch [2/50] Batch 0/1875                       Loss D: 0.1329, loss G: 2.9073
Epoch [3/50] Batch 0/1875                       Loss D: 0.0885, loss G: 3.0628
Epoch [4/50] Batch 0/1875                       Loss D: 0.0567, loss G: 3.2308
Epoch [5/50] Batch 0/1875                       Loss D: 0.0837, loss G: 4.4954
Epoch [6/50] Batch 0/1875                       Loss D: 0.0603, loss G: 4.6880
Epoch [7/50] Batch 0/1875                       Loss D: 0.0659, loss G: 4.6345
Epoch [8/50] Batch 0/1875                       Loss D: 0.1427, loss G: 4.5405
Epoch [9/50] Batch 0/1875                       Loss D: 0.0267, loss G: 5.0754
Epoch [10/50] Batch 0/1875                       Loss D: 0.0226, loss G: 4.6562
Epoch [11/50] Batch 0/1875                       Loss D: 0.0097, loss G: 4.8237
Epoch [12/50] Batch 0/1875                       L

KeyboardInterrupt: 