In [None]:
import torch as pt
from torchvision.datasets import MNIST
from torchvision import transforms
from torch.utils.data import DataLoader
from torchvision.utils import make_grid
from torchvision.utils import save_image

from matplotlib import pyplot as plt
%matplotlib inline

In [None]:
# Preprocessing applied to every MNIST image: convert to a tensor, then
# normalise with mean 0.5 / std 0.5, i.e. (x - 0.5) / 0.5. For inputs in
# [0, 1] this yields values in [-1, 1], the same range as the generator's
# tanh output.
transforms_ = transforms.Compose([
    transforms.ToTensor(),
    # One-element tuples (note the trailing comma): one mean/std per channel.
    # `(0.5)` without the comma is just the float 0.5, not a sequence.
    transforms.Normalize(mean=(0.5,), std=(0.5,)),
])

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cuda" if pt.cuda.is_available() else "cpu"

# Hyperparameters
latent_dim = 100  # channel count of the (B, latent_dim, 1, 1) noise fed to the generator
batch_size = 64   # images per training step (and default grid size in show_images)
gen_lr = 2e-4     # learning rate of the generator optimizer
disc_lr = 2e-4    # learning rate of the discriminator optimizer
epochs = 50       # number of passes over the data

# (mean, std) used for normalisation, one entry per image channel.
# The trailing commas matter: `(0.5), (0.5)` is the flat tuple (0.5, 0.5),
# on which `stats[0][0]` / `stats[1][0]` raise TypeError.
stats = (0.5,), (0.5,)

In [None]:
# Both MNIST splits live under ./data (download=True is passed for each) and
# share the same preprocessing pipeline.
train_dataset = MNIST(root="./data", train=True, transform=transforms_, download=True)
test_dataset = MNIST(root="./data", train=False, transform=transforms_, download=True)

# Each split gets its own shuffling loader yielding `batch_size` images per step.
train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=True)

In [None]:
# def denorm(img_tensors):
#     return img_tensors * stats[1][0] + stats[0][0]

In [None]:
def show_images(images, n_max=batch_size):
    """Display a batch of images as one grid on a fresh 8x8-inch figure.

    Args:
        images: Image tensor in (B, C, H, W) layout whose values lie in
            [-1, 1] -- the range produced by Normalize(mean=0.5, std=0.5) on
            the data and by the generator's tanh.
        n_max: Maximum number of images (taken from the front of the batch)
            to include in the grid.
    """
    _, ax = plt.subplots(figsize=(8, 8))
    ax.set_xticks([]); ax.set_yticks([])

    batch = images.detach().cpu()[:n_max]
    # Undo the normalisation (x * std + mean) so values are back in [0, 1];
    # imshow clips float data outside that range, which would black out
    # every pixel below zero.
    batch = batch * 0.5 + 0.5
    # make_grid returns (C, H, W); imshow wants channels last.
    ax.imshow(make_grid(batch).permute(1, 2, 0))

In [None]:
from torch import nn
from torch.nn import functional as F

In [None]:
class Generator(nn.Module):
    """Convolutional generator: latent noise -> 28x28 single-channel image.

    The input is a latent tensor of shape (B, latent_dim, 1, 1). A stack of
    transposed convolutions grows it to 68x68, three plain convolutions then
    shrink it to 28x28, and ``forward`` applies tanh so the output lies in
    [-1, 1].

    Size formulas used in the per-layer comments (height == width throughout,
    default dilation / output_padding):
        ConvTranspose2d: out = stride * (in - 1) + kernel - 2 * padding
        Conv2d:          out = floor((in - kernel + 2 * padding) / stride) + 1

    Args:
        latent_dim: Number of channels of the latent input.
        *args, **kwargs: Forwarded unchanged to ``nn.Module.__init__``.
    """

    def __init__(self, latent_dim, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Input: (B, latent_dim, 1, 1), i.e. (B, C, H, W) format.

        self.conv_transpose_block1 = nn.Sequential(
            nn.ConvTranspose2d(latent_dim, 128, kernel_size=6, padding=0, stride=1, bias=False),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            # 1 * (1 - 1) + 6 - 2*0 = 6            -> (B, 128, 6, 6)

            nn.ConvTranspose2d(128, 256, kernel_size=6, padding=0, stride=2, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            # 2 * (6 - 1) + 6 - 2*0 = 16           -> (B, 256, 16, 16)

            nn.ConvTranspose2d(256, 512, kernel_size=4, padding=0, stride=2, bias=False),
            nn.BatchNorm2d(512),
            nn.ReLU(True),
            # 2 * (16 - 1) + 4 - 2*0 = 34          -> (B, 512, 34, 34)

            nn.ConvTranspose2d(512, 256, kernel_size=4, padding=1, stride=2, bias=False),
            nn.BatchNorm2d(256),
            nn.ReLU(True),
            # 2 * (34 - 1) + 4 - 2*1 = 68          -> (B, 256, 68, 68)

            nn.ConvTranspose2d(256, 128, kernel_size=3, padding=1, stride=1, bias=True),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            # 1 * (68 - 1) + 3 - 2*1 = 68          -> (B, 128, 68, 68)

            nn.ConvTranspose2d(128, 64, kernel_size=3, padding=1, stride=1, bias=True),
            nn.BatchNorm2d(64),
            nn.ReLU(True),
            # 1 * (68 - 1) + 3 - 2*1 = 68          -> (B, 64, 68, 68)

            nn.ConvTranspose2d(64, 32, kernel_size=3, padding=1, stride=1, bias=True),
            nn.BatchNorm2d(32),
            nn.ReLU(True),
            # 1 * (68 - 1) + 3 - 2*1 = 68          -> (B, 32, 68, 68)

            # in_channels must equal the 32 channels produced just above;
            # declaring 64 here makes the forward pass fail with a channel
            # mismatch.
            nn.ConvTranspose2d(32, 28, kernel_size=3, padding=1, stride=1, bias=True),
            nn.BatchNorm2d(28),
            nn.ReLU(True),
            # 1 * (68 - 1) + 3 - 2*1 = 68          -> (B, 28, 68, 68)

            nn.Conv2d(28, 1, kernel_size=3, padding=0, stride=2, bias=False),
            # floor((68 - 3 + 2*0) / 2) + 1 = 33   -> (B, 1, 33, 33)
            nn.Conv2d(1, 1, kernel_size=3, padding=0, stride=1, bias=False),
            # (33 - 3 + 2*0) / 1 + 1 = 31          -> (B, 1, 31, 31)
            nn.Conv2d(1, 1, kernel_size=4, padding=0, stride=1, bias=False),
            # (31 - 4 + 2*0) / 1 + 1 = 28          -> (B, 1, 28, 28)
        )

    def forward(self, x):
        """Map latent tensors (B, latent_dim, 1, 1) to images in [-1, 1]."""
        x = self.conv_transpose_block1(x)
        return F.tanh(x)

In [None]:
# Build the generator, then move its parameters to the selected device.
generator = Generator(latent_dim)
generator = generator.to(device)

# Smoke test: one batch of random latent tensors in (B, C, H, W) layout,
# sampled on the CPU and transferred only for the forward pass.
xb = pt.randn(batch_size, latent_dim, 1, 1)
fake_images = generator(xb.to(device))

# Report the output shape and show what the untrained network produces.
print(fake_images.shape)
show_images(fake_images)

In [None]:
class Discriminator(nn.Module):
    """Convolutional binary classifier for 28x28 single-channel images.

    Three identical stages (3x3 conv -> batch norm -> 2x2 max-pool ->
    LeakyReLU(0.2)) widen the channels 1 -> 64 -> 128 -> 256 while the
    pooling shrinks the spatial size 28 -> 14 -> 7 -> 3. The flattened
    256*3*3 features feed one linear unit, and ``forward`` applies a sigmoid,
    so the output has shape (B, 1) with values in [0, 1].
    """

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

        stages = []
        in_channels = 1
        for out_channels in (64, 128, 256):
            # The padded 3x3 / stride-1 conv keeps the spatial size;
            # MaxPool2d(2) halves it (rounding down: 7 -> 3 in the last stage).
            stages.extend([
                nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(out_channels),
                nn.MaxPool2d(2),
                nn.LeakyReLU(0.2, inplace=True),
            ])
            in_channels = out_channels

        # A single flat Sequential, so sub-modules are indexed 0..11.
        self.discriminator = nn.Sequential(*stages)
        self.dense1 = nn.Linear(256*3*3, 1)

    def forward(self, x):
        """Return one value in [0, 1] per input image, shape (B, 1)."""
        features = self.discriminator(x)
        flat = features.view(-1, 256*3*3)
        return F.sigmoid(self.dense1(flat))

In [None]:
# Build the discriminator, then move its parameters to the selected device.
discriminator = Discriminator()
discriminator = discriminator.to(device)

In [None]:
def train_generator(opt_g):
    """Run one generator optimisation step.

    A batch of `batch_size` fake images is generated from fresh noise and
    scored by the discriminator; the loss is the binary cross-entropy between
    those scores and an all-ones target.

    Args:
        opt_g: Optimizer holding the generator's parameters.

    Returns:
        The generator loss for this step as a Python float.
    """
    opt_g.zero_grad()

    # Fresh noise -> fake batch, with the generator switched to training mode.
    noise = pt.randn(batch_size, latent_dim, 1, 1, device=device)
    generated = generator.train()(noise)

    # The generator wants every fake to be scored as real (target 1).
    want_real = pt.ones(batch_size, 1, device=device)
    g_loss = F.binary_cross_entropy(discriminator(generated), want_real)

    # Backpropagate and apply the update through the generator's optimizer.
    g_loss.backward()
    opt_g.step()

    return g_loss.item()

In [None]:
def train_discriminator(real_images, opt_d):
    """Run one discriminator optimisation step on a real and a fake batch.

    Args:
        real_images: Batch of real images, already on `device`.
        opt_d: Optimizer holding the discriminator's parameters.

    Returns:
        A 3-tuple of Python floats ``(loss, real_loss, fake_loss)``: the
        summed loss that was optimised, and its two binary cross-entropy
        terms. The last two are losses, not mean predictions, even though
        the training loop in this notebook stores them under ``*_score``
        names.
    """
    # Clear discriminator gradients
    opt_d.zero_grad()

    # Real images should be classified as 1.
    real_preds = discriminator.train()(real_images)
    real_targets = pt.ones(real_images.size(0), 1, device=device)
    real_loss = F.binary_cross_entropy(real_preds, real_targets)

    # Generate fake images. Only the discriminator is updated here, so no
    # autograd graph is recorded for the generator: otherwise loss.backward()
    # would also backpropagate through the entire generator for gradients
    # that are thrown away. The generator stays in training mode as before.
    # NOTE: the fake batch always has `batch_size` items, even when the real
    # batch is smaller.
    latent = pt.randn(batch_size, latent_dim, 1, 1, device=device)
    with pt.no_grad():
        fake_images = generator.train()(latent)

    # Fake images should be classified as 0.
    fake_targets = pt.zeros(fake_images.size(0), 1, device=device)
    fake_preds = discriminator.train()(fake_images)
    fake_loss = F.binary_cross_entropy(fake_preds, fake_targets)

    # Update discriminator weights on the sum of both terms.
    loss = real_loss + fake_loss
    loss.backward()
    opt_d.step()
    return loss.item(), real_loss.item(), fake_loss.item()

In [None]:
import os

# Folder that receives the sample grids written by save_samples; it is
# created up front and left untouched if it already exists.
sample_dir = "generated-images"
os.makedirs(name=sample_dir, exist_ok=True)

In [None]:
def save_samples(index, latent_tensors, show=True):
    """Generate images from `latent_tensors` and save them as one PNG grid.

    The file is written to `sample_dir` as ``generated-images-NNNN.png``,
    where NNNN is `index` zero-padded to four digits, eight images per row.

    Args:
        index: Integer used in the file name (the training loop passes the
            1-based epoch number).
        latent_tensors: Latent batch fed to the generator.
        show: When true, also display the grid with matplotlib.

    Note:
        The generator is switched to eval mode and left there; the training
        helpers switch it back with ``generator.train()``.
    """
    # Pure inference: no autograd graph is needed.
    with pt.no_grad():
        fake_images = generator.eval()(latent_tensors)

    # The generator ends in tanh, so values are in [-1, 1]. Both save_image
    # and imshow expect [0, 1] and clip everything below zero, so undo the
    # Normalize(mean=0.5, std=0.5) scaling first (x * std + mean).
    fake_images = (fake_images * 0.5 + 0.5).cpu()

    fake_fname = 'generated-images-{0:0=4d}.png'.format(index)
    save_image(
        fake_images,
        os.path.join(sample_dir, fake_fname),
        nrow=8
    )
    print(f"Saving {fake_fname}...")
    if show:
        _, ax = plt.subplots(figsize=(8, 8))
        ax.set_xticks([]); ax.set_yticks([])
        # make_grid returns (C, H, W); imshow wants channels last.
        ax.imshow(make_grid(fake_images, nrow=8).permute(1, 2, 0))
        plt.show()

In [None]:
# One Adam optimizer per network, each with its own learning rate and
# betas=(0.5, 0.999).
opt_g = pt.optim.Adam(
    params=generator.parameters(),
    lr=gen_lr,
    betas=(0.5, 0.999),
)
opt_d = pt.optim.Adam(
    params=discriminator.parameters(),
    lr=disc_lr,
    betas=(0.5, 0.999),
)

# A fixed batch of 64 latent tensors, sampled once; the training loop passes
# it to save_samples every epoch, so the saved grids all come from the same
# noise.
fixed_latent = pt.randn(64, latent_dim, 1, 1, device=device)

In [None]:
import numpy as np
from tqdm import tqdm

In [None]:
# Per-epoch history: mean generator loss, mean discriminator loss, and the
# means of the two extra values returned by train_discriminator (kept under
# the *_scores names).
losses_g, losses_d = [], []
real_scores, fake_scores = [], []

for epoch in range(epochs):
    # Per-batch accumulators, started afresh for every epoch.
    batch_loss_g, batch_loss_d = [], []
    batch_real_score, batch_fake_score = [], []

    # Both loaders are walked one after the other, so the MNIST test split
    # is used as additional training data.
    for loader in (train_loader, test_loader):
        for real_images, _ in tqdm(loader):
            # One discriminator step on this real batch ...
            loss_d, real_score, fake_score = train_discriminator(real_images.to(device), opt_d)
            # ... followed by one generator step.
            loss_g = train_generator(opt_g)

            batch_loss_d.append(loss_d)
            batch_real_score.append(real_score)
            batch_fake_score.append(fake_score)
            batch_loss_g.append(loss_g)

    # Epoch summary: the mean over every batch seen in this epoch.
    losses_g.append(np.mean(batch_loss_g))
    losses_d.append(np.mean(batch_loss_d))
    real_scores.append(np.mean(batch_real_score))
    fake_scores.append(np.mean(batch_fake_score))

    # Drop the per-batch values now that they have been aggregated.
    batch_loss_g, batch_loss_d = [], []
    batch_real_score, batch_fake_score = [], []

    # Write this epoch's sample grid (generated from the fixed latent batch).
    save_samples(epoch+1, fixed_latent, show=False)

    # Log the values of the last batch only, not the epoch means.
    print(
        f"Epoch [{epoch+1}/{epochs}], loss_g: {loss_g:.4f}, loss_d: {loss_d:.4f},"
        f" real_score: {real_score:.4f}, fake_score: {fake_score:.4f}\n"
    )

In [None]:
# Plot the four per-epoch histories recorded by the training loop in a 2x2 grid.
plt.figure(figsize=(10, 8), dpi=100)

# (series, y-axis label, panel title), in subplot order 1..4. Driving the
# panels from one table replaces four copy-pasted plotting stanzas.
panels = (
    (losses_g, 'Loss', 'Generator Loss'),
    (losses_d, 'Loss', 'Discriminator Loss'),
    (real_scores, 'Score', 'Real Images Discriminator Scores'),
    (fake_scores, 'Score', 'Fake Images Discriminator Scores'),
)

for position, (series, y_label, heading) in enumerate(panels, start=1):
    plt.subplot(2, 2, position)
    plt.plot(series, '-')
    plt.xlabel('Epoch')
    plt.ylabel(y_label)
    plt.title(heading)

# Render the figure explicitly so the cell does not end on a bare
# expression whose repr would be echoed as cell output.
plt.show()

In [None]:
# Persist the trained weights (state_dicts only) under ./models.
# pt.save does not create missing parent directories, so make sure the
# folder exists first; otherwise the save fails after the whole training run.
os.makedirs("models", exist_ok=True)
pt.save(generator.state_dict(), "models/generator-model")
pt.save(discriminator.state_dict(), "models/discriminator-model")