sources:  
https://arxiv.org/abs/1511.06434  - 'Unsupervised Representation Learning with Deep Convolutional Generative Adversarial Networks'  
https://github.com/pytorch/examples/blob/master/dcgan/main.py - PyTorch implementation of the paper




changes made:
- **Doubled image size** - now 128x128 instead of 64x64 (adding a layer in both networks)
- **Unbalanced G/D channels** - ngf=160/ndf=40 instead of ngf=64/ndf=64 (gives an advantage to the generator) 
- **Substituted command-line arguments** (command-prompt oriented) with hardcoded variables (notebook oriented)
- **Checkpoint** to resume training (the Kaggle CPU kernel time limit is 9 hours, which is around 50 epochs)

In [None]:
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import sys
import os
import torch

In [None]:
# list what is mounted under ../input (the image data root and the checkpoint folder loaded next)
os.listdir(path='../input')

In [None]:
# checkpoint

# Number of epochs already trained, i.e. the global index of this run's first epoch.
# 0 starts from scratch; any other value loads the checkpoint saved after that many epochs.
EPOCH_START = 650


cuda = torch.cuda.is_available()
# Remap tensors saved on a GPU onto the CPU when no GPU is available; None keeps saved devices.
map_location = None if cuda else 'cpu'

# ckpt stays 0 (falsy) for a fresh run; the later cells branch on `if ckpt:`.
# Deriving it from EPOCH_START keeps the epoch numbering and the loaded weights in sync.
# NOTE: torch.load unpickles the file, so only load checkpoints you produced yourself.
if EPOCH_START > 0:
    ckpt = torch.load(f'../input/d-dcgan-barrat-pytorch-my-landscapes-{EPOCH_START}e/checkpoint.tar',
                      map_location=map_location)
else:
    ckpt = 0

In [None]:
# hardcoded variables (stand-ins for the reference script's command-line arguments)

# --- data loading ---
WORKERS = 2            # DataLoader worker processes
DATAROOT = '../input'  # ImageFolder root

# --- image / network geometry ---
NC = 3                 # image channels
IMAGESIZE = 128        # expected side length; images are resized outside the notebook
BATCHSIZE = 64
NZ = 100               # length of the latent vector z
NGF = 160              # generator width unit (its layers use multiples of NGF)
NDF = 40               # discriminator width unit (its layers use multiples of NDF)

# --- optimisation ---
EPOCHS = 50            # epochs trained in this run
LR = 2e-4
BETA1, BETA2 = 0.5, 0.999  # Adam betas
MANUALSEED = None      # None -> a random seed is drawn at start-up

# --- outputs ---
PATH_OUT = '.'                        # root folder for images and the checkpoint
PATH_SAMPLES = f'{PATH_OUT}/samples'  # periodic sample grids
PATH_FINAL = f'{PATH_OUT}/final'      # final real/fake grids

In [None]:
from __future__ import print_function
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils




# seeding

# Pick a seed when none is configured, report it, then seed Python's and PyTorch's RNGs.
if MANUALSEED is None:
    MANUALSEED = random.randint(1, 10000)
print("Random Seed: ", MANUALSEED)
random.seed(MANUALSEED)
torch.manual_seed(MANUALSEED)


# device

device = torch.device("cuda:0") if cuda else torch.device("cpu")
# Let cuDNN benchmark convolution algorithms and cache the fastest one; this pays off
# because every batch has the same shape (pre-sized images, drop_last=True loader).
cudnn.benchmark = True

    
# data

# ImageFolder treats every sub-folder of DATAROOT as one class and yields (image, class_index)
# pairs; the GAN ignores the class index.
# NOTE(review): DATAROOT ('../input') is also the parent of the checkpoint folder
# '../input/d-dcgan-barrat-pytorch-my-landscapes-<EPOCH_START>e/'. Confirm that folder holds
# no images, because they would be mixed into the training data. Also confirm the installed
# torchvision accepts a class folder with no valid images (recent versions may raise
# FileNotFoundError).
# Resize/center-crop to IMAGESIZE is done before upload, so the pipeline only converts to a
# tensor in [0, 1] and rescales it to [-1, 1], the range of the generator's Tanh output.
dataset = dset.ImageFolder(
    root=DATAROOT,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)),
    ]),
)

# drop_last=True guarantees every batch holds exactly BATCHSIZE images; the training loop
# relies on that when it builds its label tensors.
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=BATCHSIZE,
    shuffle=True,
    num_workers=WORKERS,
    drop_last=True,
)



# custom weights initialization, applied to netG and netD via `net.apply(weights_init)`

def weights_init(m):
    """Apply the DCGAN initialisation to a single module.

    Conv / ConvTranspose weights are drawn from N(0, 0.02). BatchNorm scales are drawn from
    N(1, 0.02) and BatchNorm shifts are zeroed. Conv biases and every other module type
    are left as PyTorch initialised them.
    """
    kind = type(m).__name__
    if 'Conv' in kind:
        m.weight.data.normal_(0.0, 0.02)
    elif 'BatchNorm' in kind:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.zero_()
        
        
#######
## G ##
#######

# architecture

class Generator(nn.Module):
    """DCGAN generator: maps an (NZ x 1 x 1) latent tensor to an (NC x 128 x 128) image.

    The stack is five ConvTranspose-BatchNorm-ReLU stages followed by a ConvTranspose + Tanh,
    so outputs lie in [-1, 1]. Layers are appended in a fixed order, which keeps the
    nn.Sequential indices (and therefore the checkpoint state_dict keys) stable.
    """

    def __init__(self):
        super().__init__()
        # feature-map widths: NZ -> 16*NGF (4x4) -> 8*NGF (8x8) -> 4*NGF (16x16)
        #                        -> 2*NGF (32x32) -> NGF (64x64)
        widths = [NZ, NGF * 16, NGF * 8, NGF * 4, NGF * 2, NGF]
        layers = []
        for stage, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
            # the first stage projects 1x1 -> 4x4 (stride 1, no padding); the others double H and W
            stride, padding = (1, 0) if stage == 0 else (2, 1)
            layers += [
                nn.ConvTranspose2d(c_in, c_out, 4, stride, padding, bias=False),
                nn.BatchNorm2d(c_out),
                nn.ReLU(True),
            ]
        # final up-sampling to NC x 128 x 128; keeps its bias because no BatchNorm follows
        layers += [
            nn.ConvTranspose2d(NGF, NC, 4, 2, 1, bias=True),
            nn.Tanh(),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Return generated images for a (batch, NZ, 1, 1) noise tensor."""
        return self.main(input)

    
    
# instantiating and initializing/loading weights

netG = Generator().to(device)
if not ckpt:
    # fresh run: DCGAN initialisation of the conv and BatchNorm layers
    netG.apply(weights_init)
else:
    # resumed run: checkpoint weights replace the default initialisation
    netG.load_state_dict(ckpt['G_state_dict'])
    print('Resuming training... G weights loaded')


#######
## D ##
#######

# architecture

class Discriminator(nn.Module):
    """DCGAN discriminator: maps an (NC x 128 x 128) image to a probability of being real.

    There are five stride-2 convolutions (128 -> 4 px) and then a 4x4 convolution down to a
    single unit followed by a Sigmoid. Layers are appended in a fixed order so the
    nn.Sequential indices (and the checkpoint state_dict keys) stay stable.
    """

    def __init__(self):
        super().__init__()
        # Input stage: no BatchNorm. A bias would be harmless here, but adding one would
        # introduce a 'main.0.bias' key and break loading existing checkpoints.
        layers = [
            nn.Conv2d(NC, NDF, 4, 2, 1, bias=False),
            nn.LeakyReLU(0.2, inplace=True),
        ]
        # NDF (64x64) -> 2*NDF (32x32) -> 4*NDF (16x16) -> 8*NDF (8x8) -> 16*NDF (4x4)
        widths = [NDF, NDF * 2, NDF * 4, NDF * 8, NDF * 16]
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            layers += [
                nn.Conv2d(c_in, c_out, 4, 2, 1, bias=False),
                nn.BatchNorm2d(c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        # 4x4 kernel, no padding, one output channel: reduces 4x4 maps to a single unit.
        # It keeps its bias because no BatchNorm follows.
        layers += [
            nn.Conv2d(NDF * 16, 1, 4, 1, 0, bias=True),
            nn.Sigmoid(),
        ]
        self.main = nn.Sequential(*layers)

    def forward(self, input):
        """Return a flat (batch,) tensor of real/fake probabilities."""
        scores = self.main(input)  # (batch, 1, 1, 1)
        return scores.view(-1)


# instantiating and initializing/loading weights

netD = Discriminator().to(device)
if not ckpt:
    # fresh run: DCGAN initialisation of the conv and BatchNorm layers
    netD.apply(weights_init)
else:
    # resumed run: checkpoint weights replace the default initialisation
    netD.load_state_dict(ckpt['D_state_dict'])
    print('Resuming training... D weights loaded')



##############
## TRAINING ##
##############


optimizerD = optim.Adam(netD.parameters(), lr=LR, betas=(BETA1, BETA2))
optimizerG = optim.Adam(netG.parameters(), lr=LR, betas=(BETA1, BETA2))
if ckpt:
    # Both optimizer states are written into every checkpoint (torch.save at the end of this
    # cell). Restoring them keeps Adam's moment estimates across a resume instead of silently
    # restarting the optimizer from zero state.
    # NOTE: this also restores the saved lr/betas, which override LR/BETA1/BETA2 if edited.
    optimizerD.load_state_dict(ckpt['D_optimizer_state_dict'])
    optimizerG.load_state_dict(ckpt['G_optimizer_state_dict'])
    print('Resuming training... optimizers loaded')


criterion = nn.BCELoss()
# Float labels: BCELoss needs float targets, and recent PyTorch versions make torch.full
# return an integer tensor when given an int fill value.
real_label = 1.
fake_label = 0.

# metric histories, continued from the checkpoint when resuming
LOG_KEYS = ('errD', 'errG', 'D_x', 'D_G_z1', 'D_G_z2', 'batch', 'epoch')
logs = {key: (ckpt['logs'][key] if ckpt else []) for key in LOG_KEYS}

# Global index of this run's first batch; earlier runs trained EPOCH_START full epochs.
# This keeps logs['batch'] consistent with logs['epoch'] and never reuses a batch index
# logged by a previous run. Assumes the dataset (hence len(dataloader)) is unchanged.
batch_start = EPOCH_START * len(dataloader) if ckpt else 0

# Fixed latent batch: every sample grid uses it so progress is visually comparable.
# It is stored in the checkpoint and reused on resume. Checkpoints without it, or with a
# different BATCHSIZE/NZ, fall back to fresh noise.
saved_noise = ckpt.get('fixed_noise') if ckpt else None
if saved_noise is not None and tuple(saved_noise.shape) == (BATCHSIZE, NZ, 1, 1):
    fixed_noise = saved_noise.to(device)
else:
    fixed_noise = torch.randn(BATCHSIZE, NZ, 1, 1, device=device)
os.makedirs(PATH_SAMPLES, exist_ok=True)

for epoch in range(EPOCHS):
    for i, (data, _) in enumerate(dataloader):  # ImageFolder class indices are unused

        ############################
        # (1) Update D network: maximize log(D(x)) + log(1 - D(G(z)))
        ###########################
        # train with real
        netD.zero_grad()
        real_images = data.to(device)
        # a constant BATCHSIZE is safe because the dataloader uses drop_last=True
        label = torch.full((BATCHSIZE,), real_label, dtype=torch.float, device=device)
        output = netD(real_images)
        D_x = output.mean().item()  # metric: mean D score on real images
        errD_real = criterion(output, label)
        errD_real.backward()

        # train with fake
        noise = torch.randn(BATCHSIZE, NZ, 1, 1, device=device)  # normal latent noise
        fake = netG(noise)
        label.fill_(fake_label)
        # detach(): the D loss must not backpropagate into G. The non-detached `fake`
        # (and its graph) is reused for the G update below.
        output = netD(fake.detach())
        D_G_z1 = output.mean().item()  # metric: D score on fakes before the D step
        errD_fake = criterion(output, label)
        errD = (errD_real + errD_fake) / 2  # metric only; gradients come from the backward() calls
        errD_fake.backward()
        optimizerD.step()

        ############################
        # (2) Update G network: maximize log(D(G(z)))
        ###########################
        netG.zero_grad()
        label.fill_(real_label)  # fake labels are real for generator cost
        output = netD(fake)
        D_G_z2 = output.mean().item()  # metric: D score on the same fakes after the D step
        errG = criterion(output, label)
        errG.backward()  # also leaves grads in netD; cleared by netD.zero_grad() next batch
        optimizerG.step()

        batches_done = epoch * len(dataloader) + i  # batch index within this run

        if batches_done % 50 == 0:  # every 50 batches

            # output log print

            log_string = (f'[{epoch+EPOCH_START:03d}/{EPOCHS+EPOCH_START:02d}][{i:02d}/{len(dataloader):03d}]'
                          f'\tLoss_D: {errD.item():.4f}  Loss_G: {errG.item():.4f}'
                          f'\t\tD(x): {D_x:.4f}  D(G(z)): {D_G_z1:.4f}>>{D_G_z2:.4f}')
            print(log_string, end='\r')

            # Save a 3x3 grid of samples from the fixed noise. no_grad skips building an
            # autograd graph that would never be used.
            with torch.no_grad():
                samples = netG(fixed_noise)
            vutils.save_image(samples[:9],
                              f'{PATH_SAMPLES}/fake_{batches_done:05d}_{epoch+EPOCH_START:03d}.png',
                              nrow=3,
                              normalize=True)

            # logs

            logs['errD'] += [errD.item()]  # errD is a Tensor
            logs['errG'] += [errG.item()]  # errG is a Tensor
            logs['D_x'] += [D_x]
            logs['D_G_z1'] += [D_G_z1]
            logs['D_G_z2'] += [D_G_z2]
            logs['batch'] += [batches_done + batch_start]
            logs['epoch'] += [epoch + EPOCH_START]


# checkpoint

torch.save({
    'D_state_dict': netD.state_dict(),
    'G_state_dict': netG.state_dict(),
    'D_optimizer_state_dict': optimizerD.state_dict(),
    'G_optimizer_state_dict': optimizerG.state_dict(),
    'logs': logs,
    'fixed_noise': fixed_noise.cpu(),  # lets a resumed run keep drawing the same sample grid
}, f'{PATH_OUT}/checkpoint.tar')  # .tar is the PyTorch convention for checkpoint dictionaries

In [None]:
## plot losses

# discriminator and generator losses against the global batch index
ax = plt.figure(figsize=(10, 10)).gca()
ax.axvline(x=batch_start, c='k', linewidth=1)  # marks where the current run resumed (batch_start)
ax.plot(logs['batch'], logs['errD'], 'b-', linewidth=1, label='discriminator')
ax.plot(logs['batch'], logs['errG'], 'r-', linewidth=1, label='generator')
ax.set_title('Losses')

# Top x axis in epochs: a twin axis scaled by an invisible (alpha=0) copy of one series.
# Axes.secondary_xaxis would be the cleaner tool but is unavailable in this matplotlib version.
ax2 = ax.twiny()
ax2.plot(logs['epoch'], logs['errD'], 'k-', alpha=0.)

In [None]:
## plot predictions

# Mean D outputs: on real images, on fakes before the D step (fake1), and on the same
# fakes after it (fake2).
ax = plt.figure(figsize=(10, 10)).gca()
ax.axvline(x=batch_start, c='k', linewidth=1)  # marks where the current run resumed (batch_start)
ax.plot(logs['batch'], logs['D_x'], 'b-', linewidth=1, label='real')
ax.plot(logs['batch'], logs['D_G_z1'], '-', c='orange', linewidth=1, label='fake1')
ax.plot(logs['batch'], logs['D_G_z2'], 'r-', linewidth=1, label='fake2')
ax.set_title('Discriminator predictions')

# Top x axis in epochs: a twin axis scaled by an invisible (alpha=0) copy of one series.
# Axes.secondary_xaxis would be the cleaner tool but is unavailable in this matplotlib version.
ax2 = ax.twiny()
ax2.plot(logs['epoch'], logs['D_x'], 'k-', alpha=0.)

In [None]:
## training process images

# File names embed a zero-padded batch index, so lexical order is chronological order.
sorted_images = sorted(os.listdir(PATH_SAMPLES))

n_images = len(sorted_images)

columns = 3
# add_subplot needs a positive *integer* row count; np.ceil returns a float, and an empty
# folder would otherwise give 0 rows.
rows = max(1, int(np.ceil(n_images / columns)))

fig = plt.figure(figsize=(40, 40 * rows / columns))  # wide enough to fill the notebook
fig.subplots_adjust(hspace=.05, wspace=0)
ax = []  # one Axes per sample grid

for i, file_name in enumerate(sorted_images):
    ax.append(fig.add_subplot(rows, columns, i + 1))
    ax[-1].set_xticks([])
    ax[-1].set_yticks([])
    ax[-1].title.set_text(file_name)
    ax[-1].title.set_fontsize(25)
    # Copy the pixels into an array inside `with` so the file handle is closed right away.
    with Image.open(f'{PATH_SAMPLES}/{file_name}') as img:
        ax[-1].imshow(np.asarray(img))

In [None]:
## final images


os.makedirs(PATH_FINAL, exist_ok=True)

# 5x5 grids: the last real training batch and the generator output for the fixed noise.
# normalize=True rescales the [-1, 1] tensors for saving as images.
vutils.save_image(real_images[:25],
                  f'{PATH_FINAL}/real.png',
                  nrow=5,
                  normalize=True)
with torch.no_grad():  # inference only: no autograd graph needed
    fake = netG(fixed_noise)
vutils.save_image(fake[:25],
                  f'{PATH_FINAL}/fake.png',
                  nrow=5,
                  normalize=True)


# os.listdir order is arbitrary; sort for a stable left-to-right layout.
sorted_images = sorted(os.listdir(PATH_FINAL))
n_images = len(sorted_images)

# One row with one column per file (2 for real.png / fake.png). Sizing the grid from the
# listing keeps add_subplot valid if PATH_FINAL holds extra files.
rows, columns = 1, max(1, n_images)

fig = plt.figure(figsize=(40, 40 * rows / columns))  # figure object
fig.subplots_adjust(hspace=.05, wspace=0)
ax = []  # axes objects (plotting)

for i, file_name in enumerate(sorted_images):
    ax.append(fig.add_subplot(rows, columns, i + 1))
    ax[-1].set_xticks([])
    ax[-1].set_yticks([])
    ax[-1].title.set_text(file_name)
    ax[-1].title.set_fontsize(50)
    # Copy the pixels into an array inside `with` so the file handle is closed right away.
    with Image.open(f'{PATH_FINAL}/{file_name}') as img:
        ax[-1].imshow(np.asarray(img))