This GAN is built using fully connected layers. However, instead of ending the discriminator with a sigmoid to classify each image as real or fake, it drops the sigmoid and outputs an unbounded score, assigning higher scores to real images. Because it rates the images instead of classifying them, it is called a critic. Instead of using BCE loss, the critic tries to maximize the gap between its average score on real images and its average score on generated ones, which estimates the Wasserstein distance between the real and generated distributions. The generator, in turn, tries to maximize the critic's score on its outputs (make it think the images are real).
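The notebook imports its models from a `mnist_wgan` module that is not shown here. As a rough sketch of what a fully connected generator and critic could look like, the classes below reuse the constructor signatures from the hyperparameter cell (`Generator(noise_dimension, image_dimension, features)`, `Critic(image_dimension, features)`), but every layer choice is an assumption. The key point is the critic's last layer: a plain `nn.Linear` with no sigmoid, so it returns an unbounded score. The `Tanh` on the generator is also assumed, chosen to match images normalized to [-1, 1].

```python
import torch
from torch import nn


class Generator(nn.Module):
    """Hypothetical fully connected generator: noise vector -> flattened image."""

    def __init__(self, noise_dimension: int, image_dimension: int, features: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(noise_dimension, features),
            nn.ReLU(),
            nn.Linear(features, image_dimension),
            nn.Tanh(),  # outputs in [-1, 1], matching images normalized to [-1, 1]
        )

    def forward(self, z: torch.Tensor) -> torch.Tensor:
        return self.net(z)


class Critic(nn.Module):
    """Hypothetical fully connected critic: flattened image -> unbounded score."""

    def __init__(self, image_dimension: int, features: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(image_dimension, features),
            nn.LeakyReLU(0.2),
            nn.Linear(features, 1),  # no sigmoid: the score can be any real number
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.net(x)


generator = Generator(noise_dimension=100, image_dimension=28 * 28, features=256)
critic = Critic(image_dimension=28 * 28, features=256)
scores = critic(generator(torch.randn(16, 100)))
print(scores.shape)  # torch.Size([16, 1])
```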

The losses to minimize are as follows:
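$$Loss_D = \mathbb{E}_z[C(G(z))] - \mathbb{E}_x[C(x)]$$
$$Loss_G = -\mathbb{E}_z[C(G(z))]$$
where $C$ is the critic, $G$ is the generator, $z$ is the random noise the generator samples from, and $x$ is a real image. In practice, each expectation is estimated by the mean over a mini-batch.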
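In code, both losses reduce to batch means of the critic's scores. The sketch below assumes the critic returns one score per image; the function names `critic_loss` and `generator_loss` are hypothetical helpers, not part of the notebook.

```python
import torch


def critic_loss(real_scores: torch.Tensor, fake_scores: torch.Tensor) -> torch.Tensor:
    # Loss_D = E[C(G(z))] - E[C(x)], with each expectation estimated by a batch mean
    return fake_scores.mean() - real_scores.mean()


def generator_loss(fake_scores: torch.Tensor) -> torch.Tensor:
    # Loss_G = -E[C(G(z))]: minimizing this raises the critic's score on fakes
    return -fake_scores.mean()


real_scores = torch.tensor([2.0, 3.0])
fake_scores = torch.tensor([-1.0, 0.0])
print(critic_loss(real_scores, fake_scores).item())  # -3.0
print(generator_loss(fake_scores).item())            # 0.5
```

A more negative critic loss means the critic separates real from generated images more widely, so the negated critic loss is often tracked as a rough estimate of the Wasserstein distance.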

In the [paper](https://arxiv.org/pdf/1701.07875.pdf) (Arjovsky et al. 2017), $C$ must be a 1-Lipschitz continuous function. To enforce this, they clip the weights within a certain range. The authors state this is "clearly a terrible way" to do this, but that it is simple. Another option is the gradient penalty introduced in this [paper](https://arxiv.org/pdf/1704.00028.pdf) (Gulrajani et al. 2017), which adds a term
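$$\lambda \, \mathbb{E}_{\hat{x}}\left[(\|\nabla_{\hat{x}} C(\hat{x})\|_2 - 1)^2\right]$$
to the critic loss, where $\hat{x} = \epsilon x + (1 - \epsilon) G(z)$ with $\epsilon \sim U[0, 1]$ is a random point on the straight line between a real and a generated sample. This term penalizes critic gradient norms that deviate from 1, softly enforcing the Lipschitz constraint, so the weights no longer need to be clipped. The paper uses $\lambda = 10$.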
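A minimal sketch of the gradient penalty, assuming flattened images of shape `(batch, 784)` and a critic that returns one score per sample. `gradient_penalty` and `lambda_gp` are hypothetical names; `torch.autograd.grad` with `create_graph=True` keeps the penalty differentiable so it can be backpropagated into the critic.

```python
import torch
from torch import nn


def gradient_penalty(critic: nn.Module, real: torch.Tensor, fake: torch.Tensor,
                     lambda_gp: float = 10.0) -> torch.Tensor:
    """lambda * E[(||grad C(x_hat)||_2 - 1)^2] for inputs of shape (batch, features)."""
    batch_size = real.size(0)
    # One interpolation weight per sample, broadcast across all pixels
    epsilon = torch.rand(batch_size, 1, device=real.device)
    x_hat = (epsilon * real + (1 - epsilon) * fake).requires_grad_(True)
    scores = critic(x_hat)
    # Gradient of the scores with respect to the interpolated inputs
    (gradients,) = torch.autograd.grad(
        outputs=scores,
        inputs=x_hat,
        grad_outputs=torch.ones_like(scores),
        create_graph=True,  # so the penalty itself has gradients w.r.t. critic weights
    )
    grad_norm = gradients.view(batch_size, -1).norm(2, dim=1)
    return lambda_gp * ((grad_norm - 1) ** 2).mean()


critic = nn.Sequential(nn.Linear(784, 256), nn.LeakyReLU(0.2), nn.Linear(256, 1))
real = torch.rand(32, 784) * 2 - 1
fake = torch.rand(32, 784) * 2 - 1
gp = gradient_penalty(critic, real, fake)
print(gp.item() >= 0)  # True
```

The critic loss then becomes `fake_scores.mean() - real_scores.mean() + gp`, with the generated batch detached from the generator during the critic update. Gulrajani et al. also advise against batch normalization in the critic, since the penalty is defined per sample.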


---

## Imports

```python
import torch
import torchvision
from torch.utils.data import DataLoader
from torchvision import transforms

from mnist_wgan import Generator, Critic
```

## Get MNIST

```python
transform = transforms.Compose([
    transforms.ToTensor(),                 # PIL image -> float tensor in [0, 1]
    transforms.Normalize((0.5,), (0.5,)),  # (x - 0.5) / 0.5 -> values in [-1, 1]
])

data = torchvision.datasets.MNIST(
    root="../datasets/",
    train=True,
    download=True,
    transform=transform
)

loader = DataLoader(dataset=data, batch_size=32, shuffle=True)
```
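Because the models are fully connected, each batch of shape `(32, 1, 28, 28)` has to be flattened into 784-long vectors before it reaches the critic, which is why `image_dimension = 28 * 28` below. The sketch assumes the models take flattened input and uses random tensors in a `TensorDataset` as a stand-in for MNIST, so it runs without downloading anything.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Stand-in for MNIST: 64 random "images" shaped like normalized digits (1x28x28, in [-1, 1])
fake_mnist = TensorDataset(torch.rand(64, 1, 28, 28) * 2 - 1, torch.zeros(64, dtype=torch.long))
loader = DataLoader(dataset=fake_mnist, batch_size=32, shuffle=True)

for real, _ in loader:
    # Flatten (batch, 1, 28, 28) -> (batch, 784) for the fully connected models
    real = real.view(real.size(0), -1)
    print(real.shape)  # torch.Size([32, 784])
```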

## Hyperparameters and initialize models

```python
noise_dimension = 100
image_dimension = 28 * 28  # flattened MNIST image
features = 256
# Prefer CUDA, then Apple's MPS backend, else fall back to the CPU
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"
# The same noise every time, so generated samples can be compared across training
fixed_noise = torch.randn((16, noise_dimension)).to(device)

generator = Generator(noise_dimension, image_dimension, features).to(device)
critic = Critic(image_dimension, features).to(device)
gen_optimizer = torch.optim.Adam(generator.parameters(), lr=3e-4)
crit_optimizer = torch.optim.Adam(critic.parameters(), lr=3e-4)

generator_loss_history = []
discriminator_loss_history = []
```
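The training loop itself is not part of this section. As a sketch under stated assumptions, one WGAN iteration with weight clipping might look like the code below: the critic is updated `n_critic` times per generator update, then clipped after each step. `n_critic = 5` and `clip_value = 0.01` are the defaults from Arjovsky et al.; `train_step` and the small `nn.Sequential` models are hypothetical stand-ins for the notebook's `Generator` and `Critic`. For simplicity it reuses one real batch for all critic updates, whereas the paper draws a fresh batch for each.

```python
import torch
from torch import nn

# Hypothetical stand-ins; the notebook imports its own Generator and Critic from mnist_wgan
noise_dimension, image_dimension, features = 100, 28 * 28, 256
device = "cpu"
generator = nn.Sequential(
    nn.Linear(noise_dimension, features), nn.ReLU(),
    nn.Linear(features, image_dimension), nn.Tanh(),
).to(device)
critic = nn.Sequential(
    nn.Linear(image_dimension, features), nn.LeakyReLU(0.2),
    nn.Linear(features, 1),
).to(device)
gen_optimizer = torch.optim.Adam(generator.parameters(), lr=3e-4)
crit_optimizer = torch.optim.Adam(critic.parameters(), lr=3e-4)

n_critic = 5       # critic updates per generator update (Arjovsky et al. default)
clip_value = 0.01  # weight-clipping range [-c, c] (Arjovsky et al. default)
generator_loss_history = []
discriminator_loss_history = []


def train_step(real: torch.Tensor) -> None:
    """One WGAN iteration with weight clipping on a batch of flattened real images."""
    batch_size = real.size(0)
    for _ in range(n_critic):
        z = torch.randn(batch_size, noise_dimension, device=device)
        fake = generator(z).detach()  # no generator gradients during the critic update
        # Loss_D = E[C(G(z))] - E[C(x)]
        loss_critic = critic(fake).mean() - critic(real).mean()
        crit_optimizer.zero_grad()
        loss_critic.backward()
        crit_optimizer.step()
        with torch.no_grad():
            for p in critic.parameters():
                p.clamp_(-clip_value, clip_value)

    z = torch.randn(batch_size, noise_dimension, device=device)
    # Loss_G = -E[C(G(z))]
    loss_gen = -critic(generator(z)).mean()
    gen_optimizer.zero_grad()
    loss_gen.backward()
    gen_optimizer.step()

    discriminator_loss_history.append(loss_critic.item())
    generator_loss_history.append(loss_gen.item())


# Random tensors in [-1, 1] stand in for flattened MNIST batches
for _ in range(3):
    train_step(torch.rand(32, image_dimension, device=device) * 2 - 1)
print(len(generator_loss_history), len(discriminator_loss_history))  # 3 3
```

With the notebook's objects, the outer loop would instead iterate `for real, _ in loader:` and pass `real.view(real.size(0), -1).to(device)` to the step. Switching to the gradient penalty means dropping the clamp and adding the penalty term to `loss_critic`.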

```python
device
```