In [0]:
!pip install torchvision



In [0]:
from torchvision.datasets import FashionMNIST
from torchvision import transforms

import torch
from torch.utils.data import DataLoader
from torch import nn
from torch import optim

from matplotlib import pyplot as plt
import numpy as np

import os

In [0]:
# Use the GPU when one is available; otherwise fall back to the CPU so the
# notebook still runs (more slowly) on machines without CUDA.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [0]:
# configs
batch_size = 16

# Convert each image to a tensor, then normalise its single channel with
# mean 0.5 and std 0.5.
transformations = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5], std=[0.5]),
])

# The train and test splits are downloaded into separate root directories.
train_dataset = FashionMNIST(root='./train', train=True, download=True,
                             transform=transformations)
test_dataset = FashionMNIST(root='./test', train=False, download=True,
                            transform=transformations)

# Training batches are reshuffled every epoch; the test loader keeps the
# DataLoader default of one sample per batch.
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=1)

In [0]:
class Discriminator(nn.Module):
    """Fully connected network mapping flattened inputs to ``num_classes`` raw scores.

    The input is flattened to ``(-1, input_size)`` and passed through three
    hidden stages (512, 256 and 128 units, each Linear -> BatchNorm1d ->
    LeakyReLU(0.05)) followed by a final linear layer. No activation is
    applied to the output.
    """

    def __init__(self, input_size: int, num_classes: int):
        """Build the network.

        Args:
            input_size: Number of features per sample after flattening.
            num_classes: Number of output units.
        """
        super().__init__()
        self.input_size = input_size
        self.num_classes = num_classes

        # Stages are created in order pack1..pack4, so parameter
        # initialisation consumes the RNG in a fixed sequence.
        self.pack1 = self._hidden_stage(input_size, 512)
        self.pack2 = self._hidden_stage(512, 256)
        self.pack3 = self._hidden_stage(256, 128)
        self.pack4 = nn.Sequential(nn.Linear(128, num_classes))

    @staticmethod
    def _hidden_stage(in_features: int, out_features: int) -> nn.Sequential:
        """Return one Linear -> BatchNorm1d -> LeakyReLU(0.05) stage."""
        return nn.Sequential(
            nn.Linear(in_features, out_features),
            nn.BatchNorm1d(out_features),
            nn.LeakyReLU(0.05),
        )

    def forward(self, x):
        """Flatten ``x`` to ``(-1, input_size)`` and return the raw output scores."""
        hidden = x.view(-1, self.input_size)
        for stage in (self.pack1, self.pack2, self.pack3, self.pack4):
            hidden = stage(hidden)
        return hidden

In [0]:
class Generator(nn.Module):
    """Fully connected network mapping noise vectors to square one-channel images.

    Three hidden stages (128, 256 and 512 units, each Linear -> BatchNorm1d
    -> ReLU) are followed by a linear layer with ``output_width ** 2`` units.
    The result is reshaped to ``(-1, 1, output_width, output_width)``; no
    activation is applied to the output.
    """

    def __init__(self, noise_size: int, output_width: int):
        """Build the network.

        Args:
            noise_size: Number of features in each input noise vector.
            output_width: Side length of the generated square images.
        """
        super().__init__()
        self.noise_size = noise_size
        self.output_width = output_width

        # Hidden stage 1
        self.linear1 = nn.Linear(noise_size, 128)
        self.batch_norm1 = nn.BatchNorm1d(128)
        # Hidden stage 2
        self.linear2 = nn.Linear(128, 256)
        self.batch_norm2 = nn.BatchNorm1d(256)
        # Hidden stage 3
        self.linear3 = nn.Linear(256, 512)
        self.batch_norm3 = nn.BatchNorm1d(512)
        # Output projection: one unit per pixel
        self.linear4 = nn.Linear(512, output_width ** 2)
        # Shared by all hidden stages
        self.activation = nn.ReLU()

    def forward(self, x):
        """Return generated images of shape ``(-1, 1, output_width, output_width)``."""
        hidden_stages = (
            (self.linear1, self.batch_norm1),
            (self.linear2, self.batch_norm2),
            (self.linear3, self.batch_norm3),
        )
        hidden = x
        for linear, batch_norm in hidden_stages:
            hidden = self.activation(batch_norm(linear(hidden)))

        flat_pixels = self.linear4(hidden)
        side = self.output_width
        return flat_pixels.view(-1, 1, side, side)

In [0]:
class NoiseGenerator:
    """Produces batches of standard-normal noise vectors on a given device."""

    def __init__(self, noise_dim: int, device):
        """Store the noise vector length and the device batches are moved to."""
        self.device = device
        self.noise_dim: int = noise_dim

    def generate(self, batch_size: int):
        """Return a ``(batch_size, noise_dim)`` tensor of N(0, 1) samples on ``self.device``."""
        shape = (batch_size, self.noise_dim)
        # Sampled with torch.randn's default placement, then moved to the target device.
        samples = torch.randn(shape)
        return samples.to(self.device)

In [0]:
def init_nets(image_width: int,
              noise_dim: int,
              device):
    """Create the discriminator/generator pair and move both to ``device``.

    Args:
        image_width: Side length of the square images; the discriminator
            takes ``image_width ** 2`` input features and the generator
            produces images of this width.
        noise_dim: Length of the generator's input noise vectors.
        device: Device both networks are moved to.

    Returns:
        A ``(discriminator, generator)`` tuple. The discriminator has a
        single output unit.
    """
    n_pixels = image_width ** 2
    # The discriminator is built first, then the generator.
    discriminator = Discriminator(n_pixels, 1).to(device)
    generator = Generator(noise_dim, image_width).to(device)
    return discriminator, generator

# Hyper-parameters
image_width = 28
noise_dim = 100
num_classes = len(train_dataset.classes)

# Models
D, G = init_nets(image_width=image_width, noise_dim=noise_dim, device=device)
noise_generator = NoiseGenerator(noise_dim=noise_dim, device=device)

# Losses
# BCEWithLogitsLoss is applied directly to the discriminator's raw outputs.
criterion = nn.BCEWithLogitsLoss()

# Optimizers: one RMSprop instance per network, same learning rate.
G_opt = optim.RMSprop(G.parameters(), lr=1e-3)
D_opt = optim.RMSprop(D.parameters(), lr=1e-3)

In [0]:
def get_sample_image(generator: nn.Module, 
                     n_images: int, 
                     noise_generator: "NoiseGenerator"):
    """Generate ``n_images`` samples and tile them side by side in a single row.

    The image height and width are read from the generator output, so the
    function works for any output size rather than only 28x28.

    Args:
        generator: Module mapping a noise batch to images of shape
            ``(n_images, 1, H, W)``; the single-channel axis is squeezed out.
        n_images: Number of images to generate and tile.
        noise_generator: Object whose ``generate(n)`` returns the noise batch
            fed to ``generator``.

    Returns:
        A float64 NumPy array of shape ``(H, n_images * W)`` with the
        generated images placed left to right.
    """
    noise = noise_generator.generate(n_images)

    # Sampling needs no gradients; skip building the autograd graph.
    with torch.no_grad():
        batch = generator(noise).detach().cpu().numpy()

    fake_images = np.squeeze(batch, axis=1)  # (n_images, H, W)
    _, height, width = fake_images.shape

    # (N, H, W) -> (H, N, W) -> (H, N * W): each output row is the matching
    # row of every image, concatenated left to right.
    tiled = fake_images.transpose(1, 0, 2).reshape(height, n_images * width)
    return tiled.astype(np.float64)

In [0]:
# labels
# Full-batch label tensors. The loop below derives its targets from the shape
# of the discriminator output instead, so these are retained only for code
# that may still reference them.
true_labels = torch.ones(batch_size, 1).to(device)
fake_labels = torch.zeros(batch_size, 1).to(device)

epochs = 25
n_critic = 1  # the generator is updated on every n_critic-th step

# Loss values recorded every 200 steps.
d_losses = []
g_losses = []

# Directory for the sample grids; created once, before training starts.
dirname = 'samples'
os.makedirs(dirname, exist_ok=True)

for epoch in range(epochs):
    for step, (images, labels) in enumerate(train_dataloader):
        # ---- Discriminator update ----
        noise = noise_generator.generate(batch_size)
        # Detached so this backward pass stops at the fake images: generator
        # gradients from this step would be zeroed before they were ever used.
        fake_images = G(noise).detach()
        predicted_fake_labels = D(fake_images)
        # Targets are shaped like the predictions, so a final batch smaller
        # than batch_size does not cause a size mismatch in the loss.
        fake_loss = criterion(predicted_fake_labels,
                              torch.zeros_like(predicted_fake_labels))

        images = images.to(device)
        predicted_true_labels = D(images)
        true_loss = criterion(predicted_true_labels,
                              torch.ones_like(predicted_true_labels))

        discriminator_loss = fake_loss + true_loss
        D.zero_grad()
        discriminator_loss.backward()
        D_opt.step()

        # ---- Generator update ----
        if step % n_critic == 0:
            noise = noise_generator.generate(batch_size)
            predicted_fake_labels = D(G(noise))
            # The generator is trained to make D label its samples as real.
            G_loss = criterion(predicted_fake_labels,
                               torch.ones_like(predicted_fake_labels))

            G.zero_grad()
            G_loss.backward()
            G_opt.step()

        # ---- Bookkeeping ----
        if step % 200 == 0:
            d_losses.append(discriminator_loss.item())
            g_losses.append(G_loss.item())

        if step % 499 == 0:
            print('Epoch: {}/{}, Step: {}, D Loss: {}, G Loss: {}'.format(epoch, epochs, step, discriminator_loss.item(), G_loss.item()))

            # Sample in eval mode, then switch back to train mode.
            G.eval()
            img = get_sample_image(G, 8, noise_generator)
            plt.imsave('{}/{}_{}_{}.jpg'.format(dirname, "Vanilla-GAN", str(epoch).zfill(2), str(step).zfill(4)), img, cmap='gray')
            G.train()

Epoch: 0/25, Step: 0, D Loss: 1.4579217433929443, G Loss: 0.7650353908538818
Epoch: 0/25, Step: 499, D Loss: 0.8551731109619141, G Loss: 1.15485680103302
Epoch: 0/25, Step: 998, D Loss: 1.4883277416229248, G Loss: 1.6045719385147095
Epoch: 0/25, Step: 1497, D Loss: 0.9425194263458252, G Loss: 3.7760517597198486
Epoch: 0/25, Step: 1996, D Loss: 0.08200238645076752, G Loss: 4.819361686706543
Epoch: 0/25, Step: 2495, D Loss: 0.12142277508974075, G Loss: 3.295644760131836
Epoch: 0/25, Step: 2994, D Loss: 0.5682553052902222, G Loss: 3.8770785331726074
Epoch: 0/25, Step: 3493, D Loss: 0.23580729961395264, G Loss: 1.5679724216461182
Epoch: 1/25, Step: 0, D Loss: 0.47671765089035034, G Loss: 4.796777248382568
Epoch: 1/25, Step: 499, D Loss: 1.0586532354354858, G Loss: 4.165494918823242
Epoch: 1/25, Step: 998, D Loss: 0.05622059106826782, G Loss: 3.729593515396118
Epoch: 1/25, Step: 1497, D Loss: 0.051049187779426575, G Loss: 4.0515007972717285
Epoch: 1/25, Step: 1996, D Loss: 0.132385581731796