In [1]:
import torch
from torch import nn
import torchvision
from torch.utils.tensorboard import SummaryWriter

In [2]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

'cuda'

In [3]:
# Import the module under an alias so the pipeline object below does not hide it.
from torchvision import transforms as T

# ToTensor scales pixels to [0, 1]; Normalize(mean=0.5, std=0.5) then maps them to [-1, 1],
# the same range as the generator's Tanh output. Normalize expects one value per channel,
# hence the 1-tuples for single-channel MNIST.
# NOTE: the pipeline keeps the name `transforms` because the dataset cell passes it as
# `transform=transforms`; the torchvision module itself stays reachable as `T`.
transforms = T.Compose([T.ToTensor(), T.Normalize((0.5,), (0.5,))])
transforms

Compose(
    ToTensor()
    Normalize(mean=0.5, std=0.5)
)

In [5]:
from torchvision import datasets

# MNIST training split (train=True is the default), stored under ./data and fetched if
# missing, with every image passed through the preprocessing pipeline `transforms`.
dataset = datasets.MNIST(
    root="data",
    train=True,
    download=True,
    transform=transforms,
)
dataset

Dataset MNIST
    Number of datapoints: 60000
    Root location: data
    Split: Train
    StandardTransform
Transform: Compose(
               ToTensor()
               Normalize(mean=0.5, std=0.5)
           )

In [6]:
from torch.utils.data import DataLoader

# Mini-batches of 32 samples, reshuffled each epoch.
# NOTE(review): 32 duplicates the `batch_size` hyperparameter defined in a later cell —
# keep the two in sync if either changes.
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
dataloader

<torch.utils.data.dataloader.DataLoader at 0x23713a71570>

In [7]:
class Discriminator(nn.Module):
    """Two-layer MLP that scores flattened images with a probability of being real.

    Args:
        image_dim: number of features in each flattened input image.
    """

    def __init__(self, image_dim) -> None:
        super().__init__()
        layers = [
            nn.Linear(image_dim, 128),
            nn.LeakyReLU(negative_slope=0.01),
            nn.Linear(128, 1),
            nn.Sigmoid(),  # squashes the single logit into (0, 1) for BCELoss
        ]
        self.disc = nn.Sequential(*layers)

    def forward(self, x):
        """Score a batch; assumes x is (batch, image_dim) and returns (batch, 1)."""
        scores = self.disc(x)
        return scores

In [8]:
class Generator(nn.Module):
    """Two-layer MLP that maps latent noise vectors to flattened images.

    Args:
        z_dim: length of each input noise vector.
        image_dim: number of features in each flattened output image.
    """

    def __init__(self, z_dim, image_dim) -> None:
        super().__init__()
        hidden = 256
        self.gen = nn.Sequential(
            nn.Linear(z_dim, hidden),
            nn.LeakyReLU(negative_slope=0.01),
            nn.Linear(hidden, image_dim),
            nn.Tanh(),  # outputs in [-1, 1]
        )

    def forward(self, x):
        """Generate images; assumes x is (batch, z_dim) and returns (batch, image_dim)."""
        images = self.gen(x)
        return images

In [9]:
# Training hyperparameters.
lr = 3e-4                 # Adam learning rate, shared by both networks
z_dim = 64                # length of each latent noise vector fed to the generator
image_dim = 1 * 28 * 28   # flattened MNIST image: 1 channel x 28 x 28 pixels = 784
batch_size = 32           # size of fixed_noise; the DataLoader also hard-codes 32
epochs = 100              # number of passes over the training set

In [10]:
# Instantiate both networks (discriminator first) and move their parameters to `device`.
discriminator = Discriminator(image_dim=image_dim).to(device)
generator = Generator(z_dim=z_dim, image_dim=image_dim).to(device)

In [11]:
# Fixed latent batch, reused at every logging step so generated grids are comparable over
# training. Allocated directly on `device` instead of CPU-then-copy.
fixed_noise = torch.randn(batch_size, z_dim, device=device)
optimizer_generator = torch.optim.Adam(generator.parameters(), lr=lr)
optimizer_discriminator = torch.optim.Adam(discriminator.parameters(), lr=lr)
# BCELoss takes probabilities, which is what the Discriminator's final Sigmoid produces.
loss_fn = nn.BCELoss()
# Separate TensorBoard log directories for generated and real image grids.
fake_writer = SummaryWriter("logs/fake")
real_writer = SummaryWriter("logs/real")

In [None]:
# Adversarial training loop. Once per epoch (first batch) it prints losses and logs
# image grids of generated and real samples to TensorBoard.
step = 0
for epoch in range(epochs):
    for i, (real, _) in enumerate(dataloader):
        # Flatten each image into a vector for the MLP networks.
        real = real.view(-1, image_dim).to(device=device)
        # Local name so the global `batch_size` hyperparameter is not overwritten;
        # the final batch of an epoch may be smaller.
        current_batch_size = real.shape[0]

        ### Train Discriminator: max log(D(x)) + log(1 - D(G(z)))
        noise = torch.randn(current_batch_size, z_dim, device=device)
        fake = generator(noise)
        real_discriminator = discriminator(real).view(-1)
        # Real samples are labelled 1: BCE with target 1 is -log(D(x)).
        real_discriminator_loss = loss_fn(
            real_discriminator, torch.ones_like(real_discriminator)
        )
        # Detach so this update does not backpropagate into the generator.
        fake_discriminator = discriminator(fake.detach()).view(-1)
        # Fake samples are labelled 0: BCE with target 0 is -log(1 - D(G(z))).
        fake_discriminator_loss = loss_fn(
            fake_discriminator, torch.zeros_like(fake_discriminator)
        )
        discriminator_loss = (real_discriminator_loss + fake_discriminator_loss) / 2
        optimizer_discriminator.zero_grad()
        discriminator_loss.backward()
        optimizer_discriminator.step()

        ### Train Generator: min log(1 - D(G(z))) <-> max log(D(G(z)))
        # Re-score the non-detached fakes with the freshly updated discriminator; target 1
        # gives the non-saturating generator loss -log(D(G(z))).
        output = discriminator(fake).view(-1)
        generator_loss = loss_fn(output, torch.ones_like(output))
        optimizer_generator.zero_grad()
        generator_loss.backward()
        optimizer_generator.step()

        if i == 0:
            print(
                f"Epoch [{epoch}/{epochs}] Batch {i}/{len(dataloader)} "
                f"Discriminator Loss: {discriminator_loss.item():.4f}, "
                f"Generator Loss: {generator_loss.item():.4f}"
            )
            with torch.no_grad():
                fake_images = generator(fixed_noise).reshape(-1, 1, 28, 28)
                real_images = real.reshape(-1, 1, 28, 28)
                img_grid_fake = torchvision.utils.make_grid(fake_images, normalize=True)
                img_grid_real = torchvision.utils.make_grid(real_images, normalize=True)
                fake_writer.add_image(
                    "Mnist Fake Images", img_grid_fake, global_step=step
                )
                real_writer.add_image(
                    "Mnist Real Images", img_grid_real, global_step=step
                )
            step += 1