In [None]:
import torch
import torch.nn as nn
import torchvision
import torchvision.datasets as datasets
from torch.utils.data import DataLoader
import torchvision.transforms as transforms

import matplotlib.pyplot as plt


In [None]:
class Discriminator(nn.Module):
  """Small fully connected classifier producing one score per input row.

  The network is Linear -> ReLU -> Linear -> Sigmoid, so each score lies in
  the unit interval.

  Args:
    in_features: Size of the last dimension of the input (784 by default).
  """

  def __init__(self, in_features: int=784):
    super().__init__()
    hidden_units = 128
    layers = [
        nn.Linear(in_features, hidden_units),
        nn.ReLU(),
        nn.Linear(hidden_units, 1),
        nn.Sigmoid(),  # squash the single output into the unit interval
    ]
    # Kept under the attribute name `disc` so parameter names stay `disc.<idx>.*`.
    self.disc = nn.Sequential(*layers)

  def forward(self, x):
    """Apply the layer stack to `x`; the result has a trailing dimension of 1."""
    scores = self.disc(x)
    return scores


class Generator(nn.Module):
  """Small fully connected network mapping latent vectors to flat samples.

  The network is Linear -> LeakyReLU(0.2) -> Linear -> Tanh, so every output
  value lies in [-1, 1].

  Args:
    latent_dim: Size of the last dimension of the input (64 by default).
    out_features: Size of the last dimension of the output (784 by default).
  """

  def __init__(self, latent_dim: int = 64, out_features: int=784):
    super().__init__()
    hidden_units = 256
    layers = [
        nn.Linear(latent_dim, hidden_units),
        nn.LeakyReLU(0.2),
        nn.Linear(hidden_units, out_features),
        nn.Tanh(),  # bound the output to [-1, 1]
    ]
    # Kept under the attribute name `gen` so parameter names stay `gen.<idx>.*`.
    self.gen = nn.Sequential(*layers)

  def forward(self, x):
    """Apply the layer stack to `x` and return the generated batch."""
    generated = self.gen(x)
    return generated


In [None]:
# ToTensor yields float images in [0, 1]; normalizing with mean 0.5 / std 0.5
# rescales them to [-1, 1], the same range the Generator's Tanh output covers.
# mean/std are written as one-element tuples (one entry per channel): a bare
# `(0.5)` is just the float 0.5, not a sequence.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])
# Downloads MNIST into data/ on first use.
dataset = datasets.MNIST(root='data/', download=True, transform=transform)
train_dl = DataLoader(dataset, batch_size=64, shuffle=True)

In [None]:
def discriminator_step(discriminator, opt_d, generator, real_images, batch_size, latent_dim):
  """Run one optimizer update of the discriminator on a real and a fake batch.

  The loss is the sum of two binary cross-entropy terms: real images scored
  against a target of 1 and freshly generated images scored against 0.

  Args:
    discriminator: Module whose output is used directly as a probability
      (it is fed to `binary_cross_entropy`, so values must lie in [0, 1]).
    opt_d: Optimizer on which `zero_grad()` and `step()` are called.
    generator: Module called on latents of shape (batch_size, latent_dim).
      It is only used to produce samples; it is not updated here.
    real_images: Batch of real samples in the layout the discriminator accepts.
    batch_size: Number of fake samples to generate.
    latent_dim: Size of each latent vector.
  """
  real_preds = discriminator(real_images)
  real_targets = torch.ones_like(real_preds)
  real_loss = torch.nn.functional.binary_cross_entropy(real_preds, real_targets)

  # Sample the noise with the device and dtype of the real batch so the step
  # also works when the data and models are not plain CPU float32.
  latents = torch.randn(
      size=(batch_size, latent_dim),
      device=real_images.device,
      dtype=real_images.dtype,
  )
  # Only the discriminator is trained in this step: build the fakes without
  # autograd tracking so backward() neither walks through the generator nor
  # leaves stray gradients on its parameters.
  with torch.no_grad():
    fake_images = generator(latents)
  fake_preds = discriminator(fake_images)
  fake_targets = torch.zeros_like(fake_preds)
  fake_loss = torch.nn.functional.binary_cross_entropy(fake_preds, fake_targets)

  opt_d.zero_grad()
  loss = real_loss + fake_loss
  loss.backward()
  opt_d.step()


def generator_step(generator, opt_g, discriminator, batch_size, latent_dim):
  """Run one optimizer update of the generator.

  A fresh batch of fakes is scored by the discriminator and the generator is
  pushed, via binary cross-entropy against a target of 1, towards samples the
  discriminator rates as real.

  Args:
    generator: Module called on latents of shape (batch_size, latent_dim).
    opt_g: Optimizer on which `zero_grad()` and `step()` are called.
    discriminator: Module whose output is used directly as a probability
      (it is fed to `binary_cross_entropy`, so values must lie in [0, 1]).
      Gradients flow through it, but it is not stepped here.
    batch_size: Number of fake samples to generate.
    latent_dim: Size of each latent vector.
  """
  # Match the noise to the generator's own parameters so the step also works
  # for a generator that is not plain CPU float32. A parameter-free generator
  # falls back to torch.randn's defaults.
  reference = next(generator.parameters(), None)
  factory_kwargs = {}
  if reference is not None:
    factory_kwargs = {"device": reference.device, "dtype": reference.dtype}
  latents = torch.randn(size=(batch_size, latent_dim), **factory_kwargs)

  fake_images = generator(latents)
  fake_preds = discriminator(fake_images)
  # Target is 1 ("real"): the generator is rewarded for fooling the discriminator.
  fake_targets = torch.ones_like(fake_preds)
  loss = torch.nn.functional.binary_cross_entropy(fake_preds, fake_targets)

  opt_g.zero_grad()
  loss.backward()
  opt_g.step()


def _show_generated_grid(generator, latents, nrow=8):
  """Plot the generator's output for `latents` as a grid of 28x28 images.

  Assumes each generated sample has 784 values in [-1, 1] (as produced by a
  Tanh output layer); they are reshaped to (1, 28, 28) and rescaled to [0, 1].
  """
  with torch.inference_mode():
    fake_images = generator(latents)
    fake_images = torch.reshape(fake_images, (latents.shape[0], 1, 28, 28))
    # Shift from [-1, 1] to [0, 1], the range imshow expects for float images.
    image = (torchvision.utils.make_grid(fake_images, nrow=nrow) + 1.0)/2.0
    fig = plt.figure(figsize=(8,8))
    # make_grid returns channels-first; imshow wants channels-last on the CPU.
    plt.imshow(image.permute(1,2,0).cpu())
    plt.axis('off')
    plt.show()
    # Release the figure so one preview per epoch does not pile up in memory.
    plt.close(fig)


def train(generator,
          discriminator,
          train_dl,
          batch_size,
          latent_dim,
          num_epochs):
  """Alternate discriminator and generator updates over `train_dl`.

  Both networks are optimized with Adam (lr=1e-4). A fixed batch of 64 latent
  vectors is drawn once and rendered as an 8x8 image grid before training and
  after every epoch, so the previews are comparable across epochs.

  Args:
    generator: Module mapping (N, latent_dim) latents to flat 784-value images.
    discriminator: Module scoring flattened images.
    train_dl: Iterable of (images, labels) batches; labels are ignored and
      images are flattened to (batch, -1) before use.
    batch_size: Number of fake samples generated in each update step.
    latent_dim: Size of each latent vector.
    num_epochs: Number of full passes over `train_dl`.
  """
  opt_d = torch.optim.Adam(discriminator.parameters(), lr=1e-4)
  opt_g = torch.optim.Adam(generator.parameters(), lr=1e-4)

  # Keep the preview latents and the data on the generator's device so the
  # loop also works for models that were moved off the CPU.
  reference = next(generator.parameters(), None)
  device = reference.device if reference is not None else torch.device('cpu')
  num_preview = 64
  fixed_latents = torch.randn(size=(num_preview, latent_dim), device=device)

  # Baseline preview of the untrained generator.
  _show_generated_grid(generator, fixed_latents)

  for epoch in range(num_epochs):
    print(f"Starting epoch {epoch + 1} of {num_epochs}...")
    for xb, _ in train_dl:
      # Flatten each image to a vector for the fully connected networks.
      xb = torch.reshape(xb, (xb.shape[0], -1)).to(device)
      discriminator_step(discriminator, opt_d, generator, xb, batch_size, latent_dim)
      generator_step(generator, opt_g, discriminator, batch_size, latent_dim)

    _show_generated_grid(generator, fixed_latents)

In [None]:
# Run configuration. LATENT_DIM is shared by the generator and the training
# loop so the two can never disagree about the noise size.
SEED = 0
BATCH_SIZE = 64   # number of fake samples per update step
LATENT_DIM = 64
NUM_EPOCHS = 50

# Seed before building the models so weight init, data shuffling and noise
# sampling are repeatable across Restart & Run All.
torch.manual_seed(SEED)

disc = Discriminator()
gen = Generator(latent_dim=LATENT_DIM)
train(gen, disc, train_dl,
      batch_size=BATCH_SIZE,
      latent_dim=LATENT_DIM,
      num_epochs=NUM_EPOCHS)