In [1]:
import torch
import torchvision
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST

import matplotlib as plt

In [2]:
# Seed PyTorch's global RNG so that repeated runs of the notebook draw the
# same random numbers (weight init, torch.randn, random_split, shuffling).
random_seed = 42
torch.manual_seed(random_seed)

# Default mini-batch size; used as the default `batch_size` of MNISTDataModule.
BATCH_SIZE = 128

In [4]:
class MNISTDataModule:
    """Bundle the MNIST train / validation / test datasets with their loaders.

    This is a plain Python class: ``torch.utils.data`` exposes no
    ``DataModule`` attribute, so inheriting from it fails with
    ``AttributeError`` the moment the class statement is executed.

    Typical use: construct the object, call :meth:`setup` once, then obtain
    loaders from :meth:`train_dataloader`, :meth:`val_dataloader` and
    :meth:`test_dataloader`.

    Args:
        data_dir: Directory the MNIST files are downloaded to and read from.
        batch_size: Mini-batch size used by every DataLoader.
    """

    def __init__(self, data_dir=".", batch_size=BATCH_SIZE):
        self.data_dir = data_dir
        self.batch_size = batch_size
        # ToTensor followed by a per-channel (x - 0.1307) / 0.3081 rescale;
        # these are the commonly quoted MNIST pixel mean / std constants.
        self.transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
        )

    def setup(self, stage=None):
        """Download MNIST if needed and build the three dataset splits.

        Args:
            stage: Accepted and ignored; lets callers that pass a stage name
                (e.g. ``"fit"``) use this method unchanged.
        """
        # Honour the configured directory instead of a hard-coded ".".
        self.train_full = MNIST(
            root=self.data_dir, train=True, transform=self.transform, download=True
        )
        # 60,000 training images -> 55,000 for training, 5,000 for validation.
        # The split draws from the global torch RNG.
        self.train_dataset, self.val_dataset = torch.utils.data.random_split(
            self.train_full, [55000, 5000]
        )
        self.test_dataset = MNIST(
            root=self.data_dir, train=False, transform=self.transform, download=True
        )

    def train_dataloader(self):
        """Return a shuffled DataLoader over the training split."""
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        """Return an in-order DataLoader over the validation split."""
        return DataLoader(self.val_dataset, batch_size=self.batch_size, shuffle=False)

    def test_dataloader(self):
        """Return an in-order DataLoader over the test split."""
        return DataLoader(self.test_dataset, batch_size=self.batch_size, shuffle=False)

In [5]:
class Discriminator(nn.Module):
    """Small CNN that maps an image to a single score in (0, 1).

    The flatten step hard-codes 320 features (20 channels x 4 x 4), which is
    what two 5x5 convolutions, each followed by 2x2 max-pooling, produce for
    a single-channel 28x28 input.
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=10, kernel_size=5, stride=1)
        self.conv2 = nn.Conv2d(in_channels=10, out_channels=20, kernel_size=5, stride=1)
        self.conv2_drop = nn.Dropout2d()
        self.fc1 = nn.Linear(in_features=320, out_features=50)
        self.fc2 = nn.Linear(in_features=50, out_features=1)

    def forward(self, x):
        """Return one sigmoid-activated score per input image, shape (N, 1)."""
        # Stage 1: conv -> 2x2 max-pool -> ReLU.
        features = self.conv1(x)
        features = F.relu(F.max_pool2d(features, 2))

        # Stage 2: conv -> channel dropout -> 2x2 max-pool -> ReLU.
        features = self.conv2_drop(self.conv2(features))
        features = F.relu(F.max_pool2d(features, 2))

        # Classifier head on the flattened feature maps.
        flat = features.view(-1, 320)
        hidden = F.relu(self.fc1(flat))
        # Functional dropout must be told explicitly whether we are training.
        hidden = F.dropout(hidden, training=self.training)
        score = self.fc2(hidden)
        return F.sigmoid(score)

In [6]:
class Generator(nn.Module):
    """Decode latent vectors of size ``latent_dim`` into 1x28x28 images.

    A linear layer produces a 64x7x7 feature map, two stride-2 transposed
    convolutions upsample it (7 -> 16 -> 34) and a final 7x7 convolution
    reduces it to a single 28x28 channel. No output activation is applied.
    """

    def __init__(self, latent_dim):
        super().__init__()
        self.lin1 = nn.Linear(in_features=latent_dim, out_features=7 * 7 * 64)
        self.ct1 = nn.ConvTranspose2d(
            in_channels=64, out_channels=32, kernel_size=4, stride=2
        )
        self.ct2 = nn.ConvTranspose2d(
            in_channels=32, out_channels=16, kernel_size=4, stride=2
        )
        self.conv = nn.Conv2d(in_channels=16, out_channels=1, kernel_size=7)

    def forward(self, x):
        """Return generated images of shape (N, 1, 28, 28) for latents (N, latent_dim)."""
        # Project the latent vector and fold it into a 64-channel 7x7 map.
        feature_map = F.relu(self.lin1(x)).view(-1, 64, 7, 7)

        # Two learned upsampling stages, each followed by ReLU.
        feature_map = F.relu(self.ct1(feature_map))  # 7x7   -> 16x16
        feature_map = F.relu(self.ct2(feature_map))  # 16x16 -> 34x34

        # Collapse to one channel; the 7x7 kernel trims 34x34 down to 28x28.
        return self.conv(feature_map)

In [None]:
class GAN(nn.Module):
    """Generator / discriminator pair with the losses needed to train them.

    The class derives from ``nn.Module`` (``torchvision`` has no ``Module``
    attribute), so both sub-networks are registered and follow ``.to()``,
    ``.train()`` and ``.eval()``. ``training_step``, ``config_optimizers``
    and ``on_epoch_end`` are ordinary methods: nothing in this class invokes
    them automatically, a training loop has to call them.

    NOTE(review): the method names look modelled on a trainer-hook API;
    confirm the exact names / signature expected by whatever loop drives
    this class (e.g. ``config_optimizers``, ``opt_idx``).

    Args:
        latent_dim: Size of the latent vector fed to the generator.
        lr: Learning rate shared by both Adam optimizers.
    """

    def __init__(self, latent_dim, lr=1e-3):
        super().__init__()

        self.lr = lr
        self.latent_dim = latent_dim

        self.generator = Generator(latent_dim)
        self.discriminator = Discriminator()

        # Fixed latent batch (6 samples -> the 2x3 grid in plot_img) so that
        # sample plots from different epochs are directly comparable.
        self.validation_z = torch.randn(6, latent_dim)

    def forward(self, x):
        """Score images ``x`` with the discriminator."""
        return self.discriminator(x)

    def adversarial_loss(self, y_hat, y):
        """Binary cross-entropy between predicted probabilities and targets."""
        return F.binary_cross_entropy(y_hat, y)

    def training_step(self, batch, batch_idx, opt_idx):
        """Compute the loss for one optimizer on one batch.

        Args:
            batch: ``(real_imgs, labels)`` pair; the labels are unused.
            batch_idx: Index of the batch (unused).
            opt_idx: ``0`` selects the generator loss, ``1`` the
                discriminator loss.

        Returns:
            A scalar loss tensor, or ``None`` for any other ``opt_idx``.
        """
        real_imgs, _ = batch
        # One latent vector per real image: the size must be the batch
        # dimension, not the first image tensor itself.
        z = torch.randn(real_imgs.size(0), self.latent_dim)
        z = z.type_as(real_imgs)

        # train generator: max log(D(G(z)))
        if opt_idx == 0:
            fake_imgs = self.generator(z)
            validity = self.discriminator(fake_imgs)
            # The generator wants its fakes to be labelled "real" (1).
            return self.adversarial_loss(validity, torch.ones_like(validity))

        # train discriminator: max log(D(x)) + log(1 - D(G(z)))
        if opt_idx == 1:
            real_validity = self.discriminator(real_imgs)
            d_real_loss = self.adversarial_loss(
                real_validity, torch.ones_like(real_validity)
            )

            # detach() keeps discriminator gradients out of the generator.
            fake_imgs = self.generator(z)
            fake_validity = self.discriminator(fake_imgs.detach())
            d_fake_loss = self.adversarial_loss(
                fake_validity, torch.zeros_like(fake_validity)
            )

            return (d_real_loss + d_fake_loss) / 2

        return None

    def config_optimizers(self):
        """Return ``([generator_opt, discriminator_opt], [])``.

        The list order matches the ``opt_idx`` convention of
        ``training_step``; the empty list is the (absent) LR schedulers.
        """
        opt_g = optim.Adam(self.generator.parameters(), lr=self.lr)
        opt_d = optim.Adam(self.discriminator.parameters(), lr=self.lr)
        return [opt_g, opt_d], []

    def plot_img(self):
        """Generate images from ``validation_z`` and show them in a 2x3 grid."""
        # Local import: the notebook-level ``plt`` name is bound to the
        # top-level ``matplotlib`` package, which has no subplot/imshow/show.
        import matplotlib.pyplot as plt

        z = self.validation_z.type_as(self.generator.lin1.weight)
        # Latent vectors go through the generator; ``self(z)`` would send
        # them to the discriminator, which expects images.
        with torch.no_grad():
            sample_imgs = self.generator(z).cpu()

        # ``current_epoch`` is not defined by nn.Module; print it only if a
        # training loop has set it on the instance.
        print("epoch", getattr(self, "current_epoch", None))
        for i in range(sample_imgs.size(0)):
            plt.subplot(2, 3, i + 1)
            plt.imshow(sample_imgs[i, 0, :, :], cmap="gray", interpolation="none")
            plt.title("Generated data")
            plt.xticks([])
            plt.yticks([])
            plt.axis("off")
        # Lay the grid out once, after every subplot exists.
        plt.tight_layout()
        plt.show()

    def on_epoch_end(self):
        """Show a grid of generated samples."""
        self.plot_img()