In [1]:
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torch.optim as optim
import torchvision.datasets as datasets
import imageio
import numpy as np
import matplotlib

from torchvision.utils import make_grid, save_image
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from tqdm import tqdm
import os

matplotlib.style.use('ggplot')

In [2]:
BATCH_SIZE = 16
EPOCHS = 200
SAMPLE_SIZE = 64
NZ = 128
k = 1
device = "cuda" if torch.cuda.is_available() else "cpu"

In [3]:
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,))
    ]
)
to_pil_image = transforms.ToPILImage()

In [4]:
train_data = datasets.MNIST(root="../input/data/",
                            train=True,
                            download=True,
                            transform=transform)
train_dataloader = DataLoader(train_data,
                              batch_size=BATCH_SIZE,
                              shuffle=True,
                              num_workers=os.cpu_count())

In [5]:
class Generator(nn.Module):
    def __init__(self, nz):
        super().__init__()
        self.nz = nz
        self.layer_stack = nn.Sequential(
            nn.Linear(self.nz, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 784),
            nn.Tanh()
        )
    def forward(self, x):
        return self.layer_stack(x).view(-1, 1, 28, 28)

In [6]:
class Discriminator(nn.Module):
    def __init__(self, input):
        super().__init__()
        self.input = input
        self.layer_stack = nn.Sequential(
            nn.Linear(self.input, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 512),
            nn.LeakyReLU(0.2),
            nn.Linear(512, 256),
            nn.LeakyReLU(0.2),
            nn.Linear(256, 1),
            nn.Sigmoid()
        )
    def forward(self, x):
        x = nn.Flatten(x)
        return self.layer_stack(x)

In [7]:
generator = Generator(NZ).to(device)
discriminator = Discriminator(NZ).to(device)

In [8]:
optim_g = optim.Adam(generator.parameters(), 0.0002)
optim_d = optim.Adam(discriminator.parameters(), 0.0002)

In [9]:
loss_fn = nn.BCELoss()

In [10]:
losses_g = []
losses_d = []
images = []

In [11]:
def label_real(size):
    data = torch.ones(size, 1)
    return data.to(device)

In [12]:
def label_fake(size):
    data = torch.zeros(size, 1)
    return data.to(device)

In [13]:
def create_noise(sample_size, nz):
    return torch.randn(sample_size, nz).to(device)

In [14]:
def save_generator_image(image, path):
    save_image(image, path)

In [15]:
def train_discriminator(optimizer, data_real, data_fake):
    b_size = data_real.size(0)
    real_label = label_real(b_size)
    fake_label = label_fake(b_size)

    optimizer.zero_grad()
    real_preds = discriminator(data_real)
    real_loss = loss_fn(real_preds, real_label)

    fake_preds = discriminator(data_fake)
    fake_loss = loss_fn(fake_preds, fake_label)

    real_loss.backward()
    fake_loss.backward()
    optimizer.step()

    return real_loss + fake_loss

In [16]:
def train_generator(optimizer, data_fake):
    b_size = data_fake.size(0)
    real_label = label_real(b_size)

    optimizer.zero_grad()
    preds = discriminator(data_fake)
    loss = loss_fn(preds, real_label)

    loss.backward()
    optimizer.step()
    return loss

In [17]:
noise = create_noise(SAMPLE_SIZE, NZ)

In [18]:
noise[0].shape

torch.Size([128])

In [19]:
generator.train()
discriminator.train()

Discriminator(
  (layer_stack): Sequential(
    (0): Linear(in_features=128, out_features=1024, bias=True)
    (1): LeakyReLU(negative_slope=0.2)
    (2): Linear(in_features=1024, out_features=512, bias=True)
    (3): LeakyReLU(negative_slope=0.2)
    (4): Linear(in_features=512, out_features=256, bias=True)
    (5): LeakyReLU(negative_slope=0.2)
    (6): Linear(in_features=256, out_features=1, bias=True)
    (7): Sigmoid()
  )
)

In [20]:
for epoch in range(EPOCHS):
    loss_g = 0.0
    loss_d = 0.0
    for bi, data in tqdm(enumerate(train_dataloader)):
        image, _  = data
        image = image.to(device)
        b_size = len(image)
        for step in range(k):
            data_fake = generator(create_noise(b_size, NZ)).detach()
            data_real = image
            loss_d += train_discriminator(optimizer=optim_d, data_fake=data_fake, data_real=data_real)
        data_fake = generator(create_noise(b_size, NZ))
        loss_g += train_generator(optimizer=optim_g, data_fake=data_fake)
    generated_img = generator(noise).cpu().detach()
    generated_img = make_grid(generated_img)
    save_generator_image(generated_img,  f"../outputs/gen_img{epoch}.png")
    images.append(generated_img)
    epoch_loss_g = loss_g / bi
    epoch_loss_d = loss_d / bi
    losses_g.append(epoch_loss_g)
    losses_d.append(epoch_loss_d)
    print(f"Epoch {epoch} of {EPOCHS}")
    print(f"Generator loss: {epoch_loss_g:.8f}, Discriminator loss: {epoch_loss_d:.8f}")

0it [00:00, ?it/s]

0it [00:12, ?it/s]


TypeError: linear(): argument 'input' (position 1) must be Tensor, not Flatten