In [1]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torchvision.utils as vutils

In [2]:
# Select the compute device: the GPU when CUDA is available, the CPU otherwise
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

Using device: cuda


# Hyperparameters taken from the DCGAN paper (Radford et al., 2015): https://arxiv.org/pdf/1511.06434

In [3]:
# Hyperparameters

# Network sizes
nz = 100    # length of the latent vector z fed to the generator
ngf = 64    # base number of feature maps in the generator
ndf = 64    # base number of feature maps in the discriminator
nc = 3      # channels of the training images (3 for colour images)

# Optimisation
lr = 2e-4               # learning rate for the optimizers
betas = (0.5, 0.999)    # (beta1, beta2) for the Adam optimizer
batch_size = 128        # images per training batch
num_epochs = 50         # number of training epochs

# Normal-distribution parameters read by weights_init
weight_init_mean = 0.0
weight_init_std = 0.02

# Network structure taken from the DCGAN paper (Radford et al., 2015): https://arxiv.org/pdf/1511.06434

In [4]:
# Generator implementation
class Generator(nn.Module):
    """DCGAN generator built from a stack of transposed convolutions.

    A latent tensor of shape (N, nz, 1, 1) is up-sampled 1 -> 4 -> 8 -> 16 ->
    32 -> 64 pixels per side, giving an (N, nc, 64, 64) image whose values are
    squashed into [-1, 1] by the final tanh.

    Args:
        nz: Number of channels of the latent input.
        ngf: Base number of generator feature maps (widest layer uses ngf * 8).
        nc: Number of channels of the generated image.
    """

    def __init__(self, nz, ngf, nc):
        super().__init__()

        # (in_channels, out_channels, stride, padding) for every hidden stage.
        # The first stage expands the 1x1 latent to 4x4; the others double the size.
        hidden_stages = (
            (nz, ngf * 8, 1, 0),
            (ngf * 8, ngf * 4, 2, 1),
            (ngf * 4, ngf * 2, 2, 1),
            (ngf * 2, ngf, 2, 1),
        )

        layers = []
        for in_ch, out_ch, stride, padding in hidden_stages:
            layers.append(
                nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=stride, padding=padding, bias=False)
            )
            layers.append(nn.BatchNorm2d(out_ch))
            layers.append(nn.ReLU(inplace=True))

        # Output stage: project to the image channels, no batch norm, tanh activation.
        layers.append(nn.ConvTranspose2d(ngf, nc, kernel_size=4, stride=2, padding=1, bias=False))
        layers.append(nn.Tanh())

        # Kept under the attribute name `main` so state-dict keys stay `main.<index>.*`.
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Run the latent batch through the network and return the generated images."""
        generated = self.main(input)
        return generated

In [5]:
# Discriminator implementation
class Discriminator(nn.Module):
    """DCGAN discriminator built from a stack of strided convolutions.

    Four stride-2 convolutions halve the spatial size each time; a final 4x4
    convolution plus sigmoid then reduces every image to one score in [0, 1].
    For 64x64 inputs this yields exactly one score per image.

    Args:
        nc: Number of channels of the input images.
        ndf: Base number of discriminator feature maps (widest layer uses ndf * 8).
    """

    def __init__(self, nc, ndf):
        super().__init__()

        def down_block(in_ch, out_ch, normalize=True):
            # One down-sampling stage: stride-2 conv, optional batch norm, LeakyReLU(0.2).
            stage = [nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1, bias=False)]
            if normalize:
                stage.append(nn.BatchNorm2d(out_ch))
            stage.append(nn.LeakyReLU(0.2, inplace=True))
            return stage

        # Kept under the attribute name `main` so state-dict keys stay `main.<index>.*`.
        self.main = nn.Sequential(
            *down_block(nc, ndf, normalize=False),  # the input stage has no batch norm
            *down_block(ndf, ndf * 2),
            *down_block(ndf * 2, ndf * 4),
            *down_block(ndf * 4, ndf * 8),
            nn.Conv2d(ndf * 8, 1, kernel_size=4, stride=1, padding=0, bias=False),
            nn.Sigmoid(),
        )

    def forward(self, input):
        """Score a batch of images and return the scores flattened to one dimension."""
        scores = self.main(input)
        return scores.view(-1)

In [6]:
# Weight initialization function for the generator and discriminator
def weights_init(m):
    """Initialise one module in place; intended for use with ``net.apply(weights_init)``.

    The module is matched by its class name:
      * names containing 'Conv' get weights drawn from
        N(weight_init_mean, weight_init_std);
      * names containing 'BatchNorm' get weights drawn from
        N(1.0, weight_init_std) and biases set to 0;
      * every other module is left untouched.

    ``weight_init_mean`` and ``weight_init_std`` are read from module scope.
    """
    layer_kind = type(m).__name__

    if 'Conv' in layer_kind:
        nn.init.normal_(m.weight.data, weight_init_mean, weight_init_std)
        return

    if 'BatchNorm' in layer_kind:
        nn.init.normal_(m.weight.data, 1.0, weight_init_std)
        nn.init.constant_(m.bias.data, 0)

In [7]:
def train_gan(data_folder, output_folder, nz=100, ngf=64, ndf=64, nc=3, num_epochs=50, lr=0.0002, betas=(0.5, 0.999), batch_size=128):
    """Train a DCGAN on an image folder, saving sample grids and the final weights.

    Images are read with ``datasets.ImageFolder``, resized and centre-cropped to
    64x64 and normalised to [-1, 1]. The normalisation is hard-coded for 3
    channels.

    Outputs written below ``output_folder``:
      * ``samples_epoch_<epoch>.png`` - a grid generated from a fixed noise
        batch, saved every 10th epoch (starting with epoch 0);
      * ``models/generator.pth`` and ``models/discriminator.pth`` - the state
        dicts of both networks after the last epoch.

    Args:
        data_folder: Root directory passed to ``datasets.ImageFolder``.
        output_folder: Directory for samples and models; created if missing.
        nz: Number of channels of the latent noise fed to the generator.
        ngf: Base number of generator feature maps.
        ndf: Base number of discriminator feature maps.
        nc: Number of image channels.
        num_epochs: Number of passes over the dataset.
        lr: Learning rate of both Adam optimizers.
        betas: ``(beta1, beta2)`` of both Adam optimizers.
        batch_size: Number of images per training batch.

    Returns:
        None.
    """

    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Create the output folders up front. torch.save does not create missing
    # parent directories, so without the models/ sub-folder the run would only
    # fail after the whole training had finished. makedirs also creates
    # output_folder itself as an intermediate directory.
    models_folder = os.path.join(output_folder, 'models')
    os.makedirs(models_folder, exist_ok=True)

    # Image transformations: 64x64 crop, then scale pixel values to [-1, 1]
    # to match the generator's tanh output range
    transform = transforms.Compose([
        transforms.Resize(64),
        transforms.CenterCrop(64),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
    ])

    # Load dataset
    dataset = datasets.ImageFolder(root=data_folder, transform=transform)
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # Create the generator and discriminator
    netG = Generator(nz=nz, ngf=ngf, nc=nc).to(device)
    netG.apply(weights_init)

    netD = Discriminator(nc=nc, ndf=ndf).to(device)
    netD.apply(weights_init)

    # Loss function and optimizers
    criterion = nn.BCELoss()
    optimizerD = optim.Adam(netD.parameters(), lr=lr, betas=betas)
    optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=betas)

    # Fixed noise batch so that sample grids from different epochs are comparable
    fixed_noise = torch.randn(64, nz, 1, 1, device=device)

    # Training loop
    for epoch in range(num_epochs):
        for i, data in enumerate(dataloader, 0):

            # Update Discriminator with real data (target label 1)
            netD.zero_grad()
            real_images = data[0].to(device)
            b_size = real_images.size(0)
            label = torch.full((b_size,), 1., device=device)
            output = netD(real_images).view(-1)
            errD_real = criterion(output, label)
            errD_real.backward()
            D_x = output.mean().item()

            # Update Discriminator with fake data (target label 0).
            # detach() keeps this backward pass from reaching the generator.
            noise = torch.randn(b_size, nz, 1, 1, device=device)
            fake_images = netG(noise)
            label.fill_(0.)
            output = netD(fake_images.detach()).view(-1)
            errD_fake = criterion(output, label)
            errD_fake.backward()
            D_G_z1 = output.mean().item()
            errD = errD_real + errD_fake
            optimizerD.step()

            # Update Generator: the fakes are scored by the just-updated
            # discriminator and the generator is trained towards label 1
            netG.zero_grad()
            label.fill_(1.)
            output = netD(fake_images).view(-1)
            errG = criterion(output, label)
            errG.backward()
            D_G_z2 = output.mean().item()
            optimizerG.step()

            # Print statistics
            if i % 50 == 0:
                print(f'Epoch [{epoch}/{num_epochs}] Batch {i}/{len(dataloader)} '
                      f'Loss D: {errD.item():.4f}, Loss G: {errG.item():.4f}, '
                      f'D(x): {D_x:.4f}, D(G(z)): {D_G_z1:.4f} / {D_G_z2:.4f}')

        # Save images every 10 epochs
        if epoch % 10 == 0:
            with torch.no_grad():
                fake_images = netG(fixed_noise).detach().cpu()
            vutils.save_image(fake_images, os.path.join(output_folder, f'samples_epoch_{epoch}.png'), normalize=True)

    # Save final models
    torch.save(netG.state_dict(), os.path.join(models_folder, 'generator.pth'))
    torch.save(netD.state_dict(), os.path.join(models_folder, 'discriminator.pth'))

### D_x: Describes the average output of the Discriminator for real images
###
### D_G_z1: Describes the average output of the Discriminator for fake images before the Discriminator update
###
### D_G_z2: Describes the average output of the Discriminator for the same fake images after the Discriminator update (measured just before the Generator update)

In [8]:
# Dataset location and destination folder for sample grids and model weights
data_folder = "D:/Uni/GAN_powered_Pokemon_Generation/data/All_Generations"
output_folder = "D:/Uni/GAN_powered_Pokemon_Generation/lab/output"

# Train for 500 epochs; every other hyperparameter keeps train_gan's default
train_gan(data_folder=data_folder, output_folder=output_folder, num_epochs=500)

Epoch [0/500] Batch 0/8 Loss D: 1.8052, Loss G: 1.9593, D(x): 0.2843, D(G(z)): 0.2644 / 0.1809
Epoch [1/500] Batch 0/8 Loss D: 0.3008, Loss G: 6.4154, D(x): 0.8794, D(G(z)): 0.0602 / 0.0029
Epoch [2/500] Batch 0/8 Loss D: 0.1789, Loss G: 5.8592, D(x): 0.8893, D(G(z)): 0.0366 / 0.0047
Epoch [3/500] Batch 0/8 Loss D: 0.0299, Loss G: 4.4396, D(x): 0.9889, D(G(z)): 0.0145 / 0.0175
Epoch [4/500] Batch 0/8 Loss D: 0.1812, Loss G: 9.2975, D(x): 0.8885, D(G(z)): 0.0002 / 0.0002
Epoch [5/500] Batch 0/8 Loss D: 4.3191, Loss G: 22.5546, D(x): 0.0478, D(G(z)): 0.0000 / 0.0000
Epoch [6/500] Batch 0/8 Loss D: 1.3619, Loss G: 18.2902, D(x): 0.9836, D(G(z)): 0.6345 / 0.0000
Epoch [7/500] Batch 0/8 Loss D: 0.2262, Loss G: 10.8749, D(x): 0.8811, D(G(z)): 0.0000 / 0.0001
Epoch [8/500] Batch 0/8 Loss D: 1.0400, Loss G: 15.5426, D(x): 0.9786, D(G(z)): 0.5488 / 0.0000
Epoch [9/500] Batch 0/8 Loss D: 0.2173, Loss G: 5.0966, D(x): 0.8786, D(G(z)): 0.0328 / 0.0277
Epoch [10/500] Batch 0/8 Loss D: 2.0145, Loss 