In [68]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, datasets
from torch.utils.data import DataLoader

In [69]:
# The discriminator estimates the probability that a sample came from the real
# data distribution rather than from the generator.
class Discriminator(nn.Module):
    """Single-hidden-layer MLP that scores how "real" each input sample looks.

    Args:
        img_dim: Length of each flattened input vector (28 * 28 in this notebook).
        hidden_dim: Width of the hidden layer. Defaults to 128.
        negative_slope: Negative slope of the hidden LeakyReLU. Defaults to 0.1.

    ``forward`` maps a tensor whose last dimension is ``img_dim`` to a tensor
    whose last dimension is 1, with values in [0, 1] because of the final
    Sigmoid.
    """

    def __init__(self, img_dim, hidden_dim=128, negative_slope=0.1):
        super().__init__()
        # Keep the attribute name `disc` and this layer order so state_dict keys
        # ("disc.0.weight", "disc.2.weight", ...) stay checkpoint-compatible.
        self.disc = nn.Sequential(
            nn.Linear(img_dim, hidden_dim),
            nn.LeakyReLU(negative_slope),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),  # outputs probabilities, as required by nn.BCELoss
        )

    def forward(self, x):
        """Return the discriminator's probability that each sample in ``x`` is real."""
        return self.disc(x)

In [70]:
# The generator maps random noise vectors to samples that the discriminator
# cannot distinguish from the training data.
class Generator(nn.Module):
    """Single-hidden-layer MLP that maps noise vectors to flattened samples.

    Args:
        z_dim: Length of each input noise vector.
        img_dim: Length of each generated (flattened) sample.
        hidden_dim: Width of the hidden layer. Defaults to 256.
        negative_slope: Negative slope of the hidden LeakyReLU. Defaults to 0.1.

    ``forward`` maps a tensor whose last dimension is ``z_dim`` to a tensor
    whose last dimension is ``img_dim``, with values in [-1, 1] because of the
    final Tanh.
    """

    def __init__(self, z_dim, img_dim, hidden_dim=256, negative_slope=0.1):
        super().__init__()
        # Keep the attribute name `gen` and this layer order so state_dict keys
        # ("gen.0.weight", "gen.2.weight", ...) stay checkpoint-compatible.
        self.gen = nn.Sequential(
            nn.Linear(z_dim, hidden_dim),
            nn.LeakyReLU(negative_slope),
            nn.Linear(hidden_dim, img_dim),
            nn.Tanh(),  # [-1, 1], the same range as the normalized real images
        )

    def forward(self, x):
        """Return generated samples for a batch of noise vectors ``x``."""
        return self.gen(x)

In [71]:
# training hyperparameters

device = "cuda" if torch.cuda.is_available() else "cpu"
lr = 3e-4
z_dim = 64         # length of the generator's input noise vector
img_dim = 28 * 28  # flattened 28x28 MNIST image
batch_size = 64
num_epochs = 50

# Seed torch's RNGs (CPU and all CUDA devices) so that model initialization,
# fixed_noise and the DataLoader shuffle order are reproducible on
# Restart & Run All (GPU kernels may still introduce small nondeterminism).
SEED = 42
torch.manual_seed(SEED)

In [72]:
# Initialize the models, a fixed noise batch, the MNIST data pipeline,
# the optimizers, and the loss function.

# Networks, moved to the selected device. Construction order matters for RNG
# reproducibility: discriminator, then generator, then the fixed noise batch.
disc = Discriminator(img_dim=img_dim).to(device)
gen = Generator(z_dim=z_dim, img_dim=img_dim).to(device)
fixed_noise = torch.randn(batch_size, z_dim).to(device)

# ToTensor yields pixels in [0, 1]; Normalize(0.5, 0.5) rescales them to
# [-1, 1], the same range as the generator's Tanh output.
normalize_step = transforms.Normalize(mean=(0.5,), std=(0.5,))
transform = transforms.Compose([transforms.ToTensor(), normalize_step])

dataset = datasets.MNIST(
    root="dataset/",
    transform=transform,
    download=True,
)
loader = DataLoader(dataset=dataset, batch_size=batch_size, shuffle=True)

# One Adam optimizer per network, sharing the same learning rate, plus the
# binary cross-entropy criterion applied to the discriminator's probabilities.
opt_disc = optim.Adam(disc.parameters(), lr=lr)
opt_gen = optim.Adam(gen.parameters(), lr=lr)
loss = nn.BCELoss()

In [None]:
# Training loop: one discriminator update followed by one generator update
# per mini-batch, both using the BCE criterion `loss`.
for epoch in range(num_epochs):
    for batch_idx, (real, _) in enumerate(loader):
        # Flatten each image into a vector of length img_dim for the MLPs.
        real = real.view(-1, img_dim).to(device)
        # Use a local name instead of reassigning the global `batch_size`: the
        # last batch can be smaller, and clobbering the hyperparameter would
        # silently change fixed_noise/loader if the setup cell is re-run.
        cur_batch_size = real.shape[0]

        # Train discriminator
        # we try to max log(D(real)) + log(1 - D(G(z)))
        noise = torch.randn(cur_batch_size, z_dim, device=device)
        fake = gen(noise)
        disc_real = disc(real).view(-1)
        lossD_real = loss(disc_real, torch.ones_like(disc_real))

        # detach(): the discriminator step must not backpropagate into the
        # generator. This also leaves the generator's graph untouched for the
        # generator step below, so retain_graph=True is no longer needed.
        disc_fake = disc(fake.detach()).view(-1)
        lossD_fake = loss(disc_fake, torch.zeros_like(disc_fake))

        lossD = (lossD_real + lossD_fake) / 2
        opt_disc.zero_grad()
        lossD.backward()
        opt_disc.step()

        # Train generator
        # we try to max log(D(G(z))) by labelling fakes as real
        output = disc(fake).view(-1)
        lossG = loss(output, torch.ones_like(output))
        opt_gen.zero_grad()
        lossG.backward()
        opt_gen.step()

        # Print losses (epoch shown 1-based so the final epoch reads N/N)
        if batch_idx % 100 == 0:
            print(
                f"Epoch [{epoch + 1}/{num_epochs}] Loss D: {lossD.item():.4f}, loss G: {lossG.item():.4f}"
            )