# DCGAN PyTorch Implementation

In this notebook we will be firstly implementing DCGAN for cifar-10, and the applying it to collected Minecraft textures. The resource used for the first step was: 

- Blog tutorial on Cifar 
(https://debuggercafe.com/implementing-deep-convolutional-gan-with-pytorch/)
- PyTorch tut on faces (https://pytorch.org/tutorials/beginner/dcgan_faces_tutorial.html)

## Imports

In [1]:
import torch
import torch.nn as nn
from torchvision.utils import save_image

import torchvision.transforms as transforms
import torch.optim as optim
import torchvision.datasets as datasets

from torch.utils.data import DataLoader

# set the computation device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [2]:
import numpy as np
from matplotlib import pyplot as plt
from tqdm.notebook import tqdm

import matplotlib
matplotlib.style.use('ggplot')

## Dimensional params

In [64]:
g_size = d_size = image_size = 32 # true width x height
nz = 100 # latent vector size
batch_size = 128 # loading data in batches

## DCGAN Model

In [88]:
# Generator
class Generator(nn.Module):
    def __init__(self, nz):
        super(Generator, self).__init__()
        self.nz = nz  # noise vector to be used as input
            
        # Conv 1
        self.conv1 = nn.Sequential(
            nn.ConvTranspose2d(nz, g_size * 8, kernel_size=4, 
                               stride=2, padding=0, bias=False),
            nn.BatchNorm2d(g_size * 8),
            nn.ReLU(True),
        )

        # Conv 2
        self.conv2 = nn.Sequential(
            nn.ConvTranspose2d(g_size * 8, g_size * 4, kernel_size=4,
                              stride=2, padding=1, bias=False),
            nn.BatchNorm2d(g_size * 4),
            nn.ReLU(True),
        )

        # Conv 3
        self.conv3 = nn.Sequential(
            nn.ConvTranspose2d(g_size * 4, g_size * 2, kernel_size=4,
                               stride=2, padding=1, bias=False),
            nn.BatchNorm2d(g_size * 2),
            nn.ReLU(True),
        )

        # Conv 4
        self.conv4 = nn.Sequential(
            nn.ConvTranspose2d(g_size * 2, g_size, kernel_size=4, 
                               stride=2, padding=1, bias=False),
            nn.BatchNorm2d(g_size),
            nn.ReLU(True),
        )
            
        # Conv 5 (Flatten to image)
        self.conv5 = nn.Sequential(
            nn.ConvTranspose2d(g_size, 3, kernel_size=1, 
                               stride=1, padding=0, bias=False),
            nn.Tanh()       
        )
        
    def forward(self, x):
        #print('G')
        #print(x.size())
        x = self.conv1(x)
        #print(x.size())
        x = self.conv2(x)
        #print(x.size())
        x = self.conv3(x)
        #print(x.size())
        x = self.conv4(x)
        #print(x.size())
        x = self.conv5(x)
        #print(x.size())
        #print()
        return x

In [89]:
# Discriminator
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
            
        # Conv 1
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, d_size, kernel_size=4, 
                      stride=2, padding=1, bias=False),
            nn.LeakyReLU(0.2, inplace=True)
        )
            
        # Conv 2
        self.conv2 = nn.Sequential(
            nn.Conv2d(d_size, d_size * 2, kernel_size=4, 
                      stride=2, padding=1, bias=False),
            nn.BatchNorm2d(d_size * 2),
            nn.LeakyReLU(0.2, inplace=True)
        )
            
        # Conv 3
        self.conv3 = nn.Sequential(
            nn.Conv2d(d_size * 2, d_size * 4, kernel_size=4,  
                      stride=2, padding=1, bias=False),
            nn.BatchNorm2d(d_size * 4),
            nn.LeakyReLU(0.2, inplace=True)
        )
            
        # Conv 4
        self.conv4 = nn.Sequential(
            nn.Conv2d(d_size * 4, d_size * 8, kernel_size=4, 
                      stride=2, padding=1, bias=False),
            nn.BatchNorm2d(d_size * 8),
            nn.LeakyReLU(0.2, inplace=True)
        )
            
        # Conv 5 (Final layer for fake or real)
        self.conv5 = nn.Sequential(
            nn.Conv2d(d_size * 8, 1, kernel_size=2, 
                      stride=1, padding=0, bias=False),
            nn.Sigmoid()
        )
            
            
    def forward(self, x):
        #print('D')
        #print(x.size())
        x = self.conv1(x)
        #print(x.size())
        x = self.conv2(x)
        #print(x.size())
        x = self.conv3(x)
        #print(x.size())
        x = self.conv4(x)
        #print(x.size())
        x = self.conv5(x)
        #print(x.size())
        #print()
        return x

## Utils

In [19]:
# Creating real labels by batch size
def label_real(size):
    data = torch.ones(size, 1)
    return data.to(device)


# Creating fake labels by batch size
def label_fake(size):
    data = torch.zeros(size, 1)
    return data.to(device)


# Creating noise tensor by sample/batch size and latent vector size
def create_noise(sample_size, nz):
    return torch.randn(sample_size, nz, 1, 1).to(device)

In [20]:
def save_generator_image(image, path):
    save_image(image, path, normalize=True)

In [21]:
# Initializes model weights from dist with mu=0, s=0.02
def weights_init(m):
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        nn.init.normal_(m.weight.data, 0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        nn.init.normal_(m.weight.data, 1.0, 0.02)
        nn.init.constant_(m.bias.data, 0)

## Data

In [35]:
# Resizes image and normalizes values
transform = transforms.Compose([
    #transforms.Resize(image_size),
    transforms.ToTensor(),
    transforms.Normalize((0.5, 0.5, 0.5), 
    (0.5, 0.5, 0.5)),
])

In [36]:
# prepare the data
train_data = datasets.CIFAR10(
    root='../input/cifar10',
    train=True,
    download=True,
    transform=transform
)
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True)

Files already downloaded and verified


## Training

Hyperparameters

In [47]:
# Training
beta1 = 0.5 # beta1 value for Adam optimizer
lr = 0.0002 # learning rate according to paper
sample_size = 64 # fixed sample size
epochs = 25 # number of epoch to train

In [90]:
# Initialise models
generator = Generator(nz).to(device)
discriminator = Discriminator().to(device)

# Initialize weights
generator.apply(weights_init)
discriminator.apply(weights_init)

print('Initialized')

Initialized


In [80]:
# optimizers
optim_g = optim.Adam(generator.parameters(), lr=lr, betas=(beta1, 0.999))
optim_d = optim.Adam(discriminator.parameters(), lr=lr, betas=(beta1, 0.999))

# loss function
criterion = nn.BCELoss()  # Binary Cross Entropy

# Losses per epoch to later plot
losses_g = []
losses_d = []

In [81]:
# Train the discriminator for an epoch
def train_discriminator(optimizer, data_real, data_fake):
    b_size = data_real.size(0)
    
    # Get real and fake label vectors
    real_label = label_real(b_size) 
    fake_label = label_fake(b_size)
    optimizer.zero_grad()
    
    # Get the outputs by doing real data forward pass
    output_real = discriminator(data_real)
    output_real = torch.squeeze(output_real)
    real_label = torch.squeeze(real_label)
    loss_real = criterion(output_real, real_label)
    
    # Get the outputs by doing fake data forward pass
    output_fake = discriminator(data_fake)
    output_fake = torch.squeeze(output_fake)
    fake_label = torch.squeeze(fake_label)
    loss_fake = criterion(output_fake, fake_label)
    
    # Compute gradients of real and fake losses 
    loss_real.backward()
    loss_fake.backward()
    
    # update discriminator parameters
    optimizer.step()
    
    return loss_real + loss_fake

In [82]:
# Train the generator for an epoch
def train_generator(optimizer, data_fake):
    b_size = data_fake.size(0)
    
    # Get the real label vector
    real_label = label_real(b_size)
    optimizer.zero_grad()
    
    # Output by doing a forward pass of the fake data through discriminator
    output = discriminator(data_fake)
    output = torch.squeeze(output)
    real_label = torch.squeeze(real_label)
    loss = criterion(output, real_label)
    
    # compute gradients of loss
    loss.backward()
    
    # update generator parameters
    optimizer.step()
    
    return loss 

In [83]:
# Create the noise vector right before training
noise = create_noise(sample_size, nz)

In [91]:
# Training loop over multiple epochs
for epoch in tqdm(range(epochs)):
    loss_g = 0.0
    loss_d = 0.0
    
    for bi, data in tqdm(enumerate(train_loader), total=int(len(train_data)/train_loader.batch_size)):
        image, _ = data
        image = image.to(device)
        b_size = len(image)
        # forward pass through generator to create fake data
        data_fake = generator(create_noise(b_size, nz)).detach()
        data_real = image
        loss_d += train_discriminator(optim_d, data_real, data_fake)
        data_fake = generator(create_noise(b_size, nz))
        loss_g += train_generator(optim_g, data_fake)
        
    # Final forward pass through generator to create fake data after one epoch
    generated_img = generator(noise).cpu().detach()
    
    # Save the generated results to disk
    save_generator_image(generated_img, f"../outputs/cifar10/gen_img{epoch}.png")
    epoch_loss_g = loss_g / bi # total generator loss for the epoch
    epoch_loss_d = loss_d / bi # total discriminator loss for the epoch
    losses_g.append(epoch_loss_g)
    losses_d.append(epoch_loss_d)
    
    print(f"Epoch {epoch+1} of {epochs}")
    print(f"Generator loss: {epoch_loss_g:.8f}, Discriminator loss: {epoch_loss_d:.8f}")
    
# Save the model
print('DONE TRAINING')
# save the model weights to disk
torch.save(generator.state_dict(), '../outputs/cifar10/generator.pth')

# plot and save the generator and discriminator loss
plt.figure()
plt.plot(losses_g, label='Generator loss')
plt.plot(losses_d, label='Discriminator Loss')
plt.legend()
plt.savefig('../outputs/cifar10/loss.png')
plt.show()

  0%|          | 0/25 [00:00<?, ?it/s]

  0%|          | 0/390 [00:00<?, ?it/s]

Epoch 1 of 25
Generator loss: 0.84182638, Discriminator loss: 1.41060674


  0%|          | 0/390 [00:00<?, ?it/s]

Epoch 2 of 25
Generator loss: 0.84199321, Discriminator loss: 1.40939081


  0%|          | 0/390 [00:00<?, ?it/s]

Epoch 3 of 25
Generator loss: 0.84135550, Discriminator loss: 1.41057253


  0%|          | 0/390 [00:00<?, ?it/s]

Epoch 4 of 25
Generator loss: 0.84103590, Discriminator loss: 1.41049826


  0%|          | 0/390 [00:00<?, ?it/s]

Epoch 5 of 25
Generator loss: 0.84057927, Discriminator loss: 1.41007519


  0%|          | 0/390 [00:00<?, ?it/s]

Epoch 6 of 25
Generator loss: 0.84204525, Discriminator loss: 1.41032183


  0%|          | 0/390 [00:00<?, ?it/s]

Epoch 7 of 25
Generator loss: 0.84193182, Discriminator loss: 1.41063225


  0%|          | 0/390 [00:00<?, ?it/s]

Epoch 8 of 25
Generator loss: 0.84213740, Discriminator loss: 1.41107023


  0%|          | 0/390 [00:00<?, ?it/s]

Epoch 9 of 25
Generator loss: 0.84008247, Discriminator loss: 1.41167641


  0%|          | 0/390 [00:00<?, ?it/s]

KeyboardInterrupt: 