In [31]:
import torch
import torch.nn as nn
import torchvision
import torch.optim as optim
import torchvision.datasets as dataset
from torchvision import datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
from torch.utils.tensorboard import SummaryWriter

In [32]:
class Discriminator(nn.Module):
    """Fully connected discriminator that scores flattened images.

    The network is Linear -> LeakyReLU -> Linear -> Sigmoid, so every input
    row is mapped to a single value in [0, 1] that can be fed directly to
    ``nn.BCELoss``.

    Args:
        in_features: Length of each flattened input vector.
        hidden_dim: Width of the single hidden layer. Defaults to 128, the
            value that used to be hard-coded.
    """

    def __init__(self, in_features: int, hidden_dim: int = 128) -> None:
        super().__init__()
        # The attribute name and layer order are kept stable so that
        # state_dict keys ("disc.0.weight", ...) do not change.
        self.disc = nn.Sequential(
            nn.Linear(in_features, hidden_dim),
            nn.LeakyReLU(0.01),
            nn.Linear(hidden_dim, 1),
            nn.Sigmoid(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return one sigmoid score per input row (last dimension has size 1)."""
        return self.disc(x)

In [33]:
class Generator(nn.Module):
    """Fully connected generator that maps latent vectors to flattened images.

    The network is Linear -> LeakyReLU -> Linear -> Tanh.

    Args:
        z_dim: Length of each latent (noise) vector.
        img_dim: Length of each generated, flattened image vector.
        hidden_dim: Width of the single hidden layer. Defaults to 256, the
            value that used to be hard-coded.
    """

    def __init__(self, z_dim: int, img_dim: int, hidden_dim: int = 256) -> None:
        super().__init__()
        # The attribute name and layer order are kept stable so that
        # state_dict keys ("gen.0.weight", ...) do not change.
        self.gen = nn.Sequential(
            nn.Linear(z_dim, hidden_dim),
            nn.LeakyReLU(0.01),
            nn.Linear(hidden_dim, img_dim),
            # Tanh bounds the output to [-1, 1]; the training images are
            # expected to be normalized to that same range.
            nn.Tanh(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return generated vectors of length ``img_dim`` with values in [-1, 1]."""
        return self.gen(x)

In [34]:
# --- Run configuration -------------------------------------------------------
# Seed the torch RNG up front so weight initialisation, the sampled noise and
# anything else drawing from the global generator is repeatable across
# Restart & Run All on the same setup.
seed = 42
torch.manual_seed(seed)

device = "cuda" if torch.cuda.is_available() else "cpu"  # prefer the GPU when present
lr = 3e-4  # learning rate shared by both optimizers
z_dim = 64  # length of the latent noise vector fed to the generator
image_dim = 28 * 28 * 1  # 784: height * width * channels of a flattened image
batch_size = 32
num_epochs = 50

In [35]:
# Build the two networks (discriminator first, then generator) and move them
# to the selected device.
disc = Discriminator(in_features=image_dim).to(device)
gen = Generator(z_dim=z_dim, img_dim=image_dim).to(device)
# One noise batch drawn once and reused for every logged sample grid, so the
# grids from different epochs are directly comparable.
fixed_noise = torch.randn(batch_size, z_dim).to(device)

In [36]:
# Preprocessing pipeline: convert each image to a tensor, then normalize with
# mean 0.5 and std 0.5, i.e. (x - 0.5) / 0.5.
#
# NOTE: the name `transforms` is kept because later cells read it, but it
# rebinds the `torchvision.transforms` module alias from the import cell.
# The pipeline is therefore built through `torchvision.transforms` so this
# cell can be re-run; with `transforms.Compose(...)` a second run failed
# because `transforms` was already a Compose object.
transforms = torchvision.transforms.Compose(
    [
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.5,), (0.5,)),
    ]
)

In [37]:
# MNIST with the preprocessing pipeline applied; downloaded into the root
# directory if it is not already there.
dataset = datasets.MNIST(
    root='/dataset/',
    download=True,
    transform=transforms,
)
# Shuffled mini-batches for the training loop.
mloader = DataLoader(
    dataset=dataset,
    shuffle=True,
    batch_size=batch_size,
)

In [38]:
# One Adam optimizer per network, both using the shared learning rate.
opt_disc, opt_gen = (optim.Adam(net.parameters(), lr=lr) for net in (disc, gen))

# Binary cross-entropy on the discriminator's sigmoid outputs.
criterion = nn.BCELoss()

# Separate TensorBoard runs for generated and real image grids.
writer_fake = SummaryWriter("logs/fake")
writer_real = SummaryWriter("logs/real")

# global_step counter for the logged image grids.
step = 0

In [39]:
# Adversarial training loop: for every mini-batch, take one discriminator step
# followed by one generator step. On the first batch of each epoch, print the
# losses and log a grid of generated and real images to TensorBoard.
for epoch in range(num_epochs):
    for batch_idx, (real, _) in enumerate(mloader):
        # Flatten every image into a vector of length image_dim.
        real = real.view(-1, image_dim).to(device)
        # Use a separate name for the actual batch size so the global
        # `batch_size` configuration value is not overwritten.
        cur_batch_size = real.shape[0]

        ### Train Discriminator: max log(D(x)) + log(1 - D(G(z)))
        noise = torch.randn(cur_batch_size, z_dim).to(device)
        fake = gen(noise)
        disc_real = disc(real).view(-1)
        lossD_real = criterion(disc_real, torch.ones_like(disc_real))
        # Detach the fakes: the discriminator update must not backpropagate
        # into the generator. This leaves the generator's graph untouched for
        # the generator step below, so retain_graph=True is not needed.
        disc_fake = disc(fake.detach()).view(-1)
        lossD_fake = criterion(disc_fake, torch.zeros_like(disc_fake))
        lossD = (lossD_real + lossD_fake) / 2
        disc.zero_grad()
        lossD.backward()
        opt_disc.step()

        ### Train Generator: min log(1 - D(G(z))) <-> max log(D(G(z))
        # where the second option of maximizing doesn't suffer from
        # saturating gradients
        # Score the (non-detached) fakes again with the just-updated
        # discriminator; gradients flow through it into the generator.
        output = disc(fake).view(-1)
        lossG = criterion(output, torch.ones_like(output))
        gen.zero_grad()
        lossG.backward()
        opt_gen.step()

        if batch_idx == 0:
            print(
                f"Epoch [{epoch}/{num_epochs}] Batch {batch_idx}/{len(mloader)} \
                      Loss D: {lossD:.4f}, loss G: {lossG:.4f}"
            )

            # Logging only: no gradients are needed for the sample grids.
            with torch.no_grad():
                fake = gen(fixed_noise).reshape(-1, 1, 28, 28)
                data = real.reshape(-1, 1, 28, 28)
                img_grid_fake = torchvision.utils.make_grid(fake, normalize=True)
                img_grid_real = torchvision.utils.make_grid(data, normalize=True)

                writer_fake.add_image(
                    "Mnist Fake Images", img_grid_fake, global_step=step
                )
                writer_real.add_image(
                    "Mnist Real Images", img_grid_real, global_step=step
                )
                step += 1

# Push any buffered TensorBoard events to disk once training is finished.
writer_fake.flush()
writer_real.flush()

Epoch [0/50] Batch 0/1875                       Loss D: 0.6728, loss G: 0.6970
Epoch [1/50] Batch 0/1875                       Loss D: 0.2995, loss G: 1.7903
Epoch [2/50] Batch 0/1875                       Loss D: 0.6743, loss G: 0.6925
Epoch [3/50] Batch 0/1875                       Loss D: 0.8741, loss G: 0.7288
Epoch [4/50] Batch 0/1875                       Loss D: 0.4246, loss G: 1.2617
Epoch [5/50] Batch 0/1875                       Loss D: 0.3440, loss G: 1.6084
Epoch [6/50] Batch 0/1875                       Loss D: 0.4414, loss G: 1.2946
Epoch [7/50] Batch 0/1875                       Loss D: 0.7766, loss G: 1.3243
Epoch [8/50] Batch 0/1875                       Loss D: 0.4852, loss G: 1.3977
Epoch [9/50] Batch 0/1875                       Loss D: 0.5232, loss G: 1.4057
Epoch [10/50] Batch 0/1875                       Loss D: 0.2715, loss G: 2.1673
Epoch [11/50] Batch 0/1875                       Loss D: 0.3480, loss G: 1.9572
Epoch [12/50] Batch 0/1875                       L