This notebook adapts code and uses ideas from the following sources:

1) C. Willcocks, https://colab.research.google.com/gist/cwkx/f4b49cd3efc0e624bc22c89c90921931/spectral-norm-gan.ipynb (for base of code)

2) https://scikit-learn.org/stable/modules/generated/sklearn.neighbors.NearestNeighbors.html#sklearn.neighbors.NearestNeighbors.kneighbors (for kNN selection of images)

3) U. Desai, https://medium.com/@utk.is.here/keep-calm-and-train-a-gan-pitfalls-and-tips-on-training-generative-adversarial-networks-edd529764aa9 and https://github.com/utkd/gans/blob/master/cifar10dcgan.ipynb (for ideas on adding noise and flipping labels)

4) R. Chavhan https://github.com/ruchikachavhan/GANs (for ideas for layers in generator and discriminator)

5) J. Brownlee https://machinelearningmastery.com/how-to-train-stable-generative-adversarial-networks/ (for ideas for layers in GAN and layer/optimiser parameters)

In [0]:
%%capture
# Bootstrap cell: probe the runtime for a CUDA library / NVIDIA device and
# install the packages the rest of the notebook imports. `%%capture` hides
# the cell's output.
from os.path import exists
# NOTE(review): `wheel.pep425tags` is an internal module of `wheel`; newer
# releases may no longer provide it - confirm this import still works.
from wheel.pep425tags import get_abbr_impl, get_impl_ver, get_abi_tag
# Wheel-style platform tag built from the interpreter implementation,
# version and ABI. Not referenced again in the visible notebook.
platform = '{}{}-{}'.format(get_abbr_impl(), get_impl_ver(), get_abi_tag())
# Turn the version suffix of the installed cudart library into a `cuXY` tag.
cuda_output = !ldconfig -p|grep cudart.so|sed -e 's/.*\.\([0-9]*\)\.\([0-9]*\)$/cu\1\2/'
# Use the CUDA tag only when an NVIDIA device node exists; otherwise 'cpu'.
# Not referenced again in the visible notebook.
accelerator = cuda_output[0] if exists('/dev/nvidia0') else 'cpu'
!pip install -q torch torchvision livelossplot

**Main imports**

In [0]:
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from time import sleep
from livelossplot import PlotLosses

# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Number of channels in the noise tensor that is fed to the generator.
latent_space_size = 100

In [0]:
# Size handed to the resize transform; every image tensor below is allocated
# as 3 x imageSize x imageSize.
imageSize = 64
# `transforms.Scale` was deprecated in favour of `transforms.Resize` and is
# not available in current torchvision releases; `Resize` is its replacement
# and takes the same size argument.
transform2 = transforms.Compose([transforms.Resize(imageSize)])
# Download (if needed) and open both CIFAR-10 splits with the resize applied.
trainset = torchvision.datasets.CIFAR10('data', train=True, download=True, transform=transform2)
print(trainset)
testset = torchvision.datasets.CIFAR10('data', train=False, download=True, transform=transform2)
print(testset)


In [0]:
# Collect every horse (label 7) and bird (label 2) image from both splits and
# normalise it. The inputs are 0-255 for each channel; they are stored as
# float32, channel-first, with each value roughly in the range (-1.0, 1.0).
# Several papers stress the importance of normalizing the input.
# NOTE(review): horses are shifted by 0.996 and birds by 1.0, so the two
# classes end up with very slightly different value ranges.

horses = np.empty([6000, 3, imageSize, imageSize], dtype=np.float32)
birds = np.empty([6000, 3, imageSize, imageSize], dtype=np.float32)
horsecount = 0
birdcount = 0
for dataset in (trainset, testset):
  for image, label in dataset:
    if label == 7:
      # (H, W, C) -> (C, H, W), then rescale
      horses[horsecount] = (np.array(image).transpose(2, 0, 1) / 128) - 0.996
      horsecount += 1
    elif label == 2:
      birds[birdcount] = (np.array(image).transpose(2, 0, 1) / 128) - 1.0
      birdcount += 1

from sklearn.neighbors import NearestNeighbors

# Find the 2000 horse images nearest to a hand-picked reference image
# (horses[4]), comparing flattened pixel vectors with the estimator's
# default distance metric.
nbrs = NearestNeighbors(n_neighbors=2000, algorithm='auto').fit(np.reshape(horses, (horses.shape[0], -1)))
distances, indices = nbrs.kneighbors(horses[4].reshape(1, -1))

# Gather the selected images in ascending dataset order. A single sorted
# fancy-index yields the same array as scanning all 6000 positions and
# testing each for membership, without the quadratic cost or the hard-coded
# sizes.
new_horses = horses[np.sort(indices[0])]
num_horses = len(new_horses)

# Preview the 25 closest matches, labelled with their dataset index.
plt.figure(figsize=(10, 10))
for i in range(25):
  plt.subplot(5, 5, i + 1)
  plt.xticks([])
  plt.yticks([])
  plt.grid(False)
  # Map values back to [0, 1] and move channels last for imshow.
  plt.imshow(((horses[indices[0][i]] / 2) + 0.5).transpose(1, 2, 0), cmap=plt.cm.binary)
  plt.xlabel(indices[0][i])

# Two bird subsets are selected by nearest-neighbour search around
# hand-picked reference images: 1400 around birds[3387] and 400 around
# birds[626].
# NOTE(review): nothing prevents an image from appearing in both subsets.
nbrs = NearestNeighbors(n_neighbors=1400, algorithm='auto').fit(np.reshape(birds, (birds.shape[0], -1)))
distances, indices = nbrs.kneighbors(birds[3387].reshape(1, -1)) #closed wing birds

# Sorted fancy-indexing returns the neighbours in ascending dataset order,
# exactly what a scan over all 6000 positions with a membership test would
# produce, but in one vectorised step and without hard-coded sizes.
new_birds = birds[np.sort(indices[0])]

nbrs = NearestNeighbors(n_neighbors=400, algorithm='auto').fit(np.reshape(birds, (birds.shape[0], -1)))
distances, places = nbrs.kneighbors(birds[626].reshape(1, -1)) #open wing birds

open_birds = birds[np.sort(places[0])]
num_birds = len(open_birds)

# Preview the 25 closest open-wing matches, labelled with their dataset index.
plt.figure(figsize=(10, 10))
for i in range(25):
  plt.subplot(5, 5, i + 1)
  plt.xticks([])
  plt.yticks([])
  plt.grid(False)
  # Map values back to [0, 1] and move channels last for imshow.
  plt.imshow(((birds[places[0][i]] / 2) + 0.5).transpose(1, 2, 0), cmap=plt.cm.binary)
  plt.xlabel(places[0][i])

# Stack the selected images: all horses first, then both bird subsets.
all_birds = np.concatenate((new_birds, open_birds))
pegasus = np.concatenate((new_horses, all_birds))
# Use one-hot encoding for classes: column 1 marks horses, column 2 marks
# birds (column 0 is only set for generated images in the training loop).
# Row counts are taken from the arrays themselves so the labels cannot drift
# out of step with the neighbour counts chosen above.
pegasus_labels = np.concatenate((
    np.full((len(new_horses), 3), [0, 1, 0]),
    np.full((len(all_birds), 3), [0, 0, 1]),
))

# If you want to shuffle the images:
# random_idx = np.random.permutation(len(pegasus)) # Create a shuffle index
# pegasus = pegasus[random_idx] # Sort both arrays with the same random index
# pegasus_labels = pegasus_labels[random_idx]


In [0]:
from torch.utils.data import Dataset
class CustomDataset(Dataset):
    """Map-style dataset pairing two index-aligned sequences.

    Item ``idx`` is the tuple ``(data[idx], labels[idx])``; the length is
    the length of ``data``.
    """

    def __init__(self, data, labels):
        self.data, self.labels = data, labels

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        target = self.labels[idx]
        return sample, target

# Pair the combined horse/bird image array with its class-label array so the
# two can be batched together by a DataLoader.
peg_dataset = CustomDataset(pegasus, pegasus_labels)

In [0]:
def cycle(iterable):
    """Yield the items of ``iterable`` over and over.

    Unlike ``itertools.cycle`` nothing is cached: the iterable is iterated
    afresh on every pass, so a re-iterable source (such as a DataLoader) is
    restarted each time round.

    If a pass produces no items at all (an empty iterable, or a one-shot
    iterator that is already exhausted) the generator stops instead of
    spinning forever without yielding.
    """
    while True:
        produced_any = False
        for x in iterable:
            produced_any = True
            yield x
        if not produced_any:
            return

# Endless stream of (images, labels) batches of up to 64 samples.
# shuffle=False keeps the dataset order, so batches follow the order in which
# `pegasus` was assembled (see the optional shuffle snippet where it is built).
horse_iterator = iter(cycle(torch.utils.data.DataLoader(peg_dataset,
                                           batch_size=64,
                                           shuffle=False)))




**Define two models: (1) Generator, and (2) Discriminator**

In [0]:
# define the model

from  torch.nn.modules.upsampling import Upsample

# custom weights initialization called on netG and netD
def weights_init(m):
    """Initialise one module in place; intended for ``Module.apply``.

    * Modules whose class name contains ``Conv``: weights drawn from a
      normal distribution with mean 0.0 and stdev 0.02.
    * Modules whose class name contains ``BatchNorm``: scale drawn from a
      normal distribution with mean 1.0 and stdev 0.02, bias set to 0.
    * Anything else is left untouched.

    A convolution wrapped with ``torch.nn.utils.spectral_norm`` keeps its
    trainable tensor in ``weight_orig`` and recomputes ``weight`` from it, so
    writing to ``weight`` would be discarded. The underlying ``weight_orig``
    is initialised instead when it exists.
    """
    classname = m.__class__.__name__
    if 'Conv' in classname:
        weight = getattr(m, 'weight_orig', None)
        if weight is None:
            weight = m.weight
        nn.init.normal_(weight.data, 0.0, 0.02)
    elif 'BatchNorm' in classname:
        # Non-affine batch-norm layers have no scale/bias to initialise.
        if m.weight is not None:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
        if m.bias is not None:
            nn.init.constant_(m.bias.data, 0)
class Generator(nn.Module):
    """Transposed-convolution generator.

    Takes a noise tensor with ``latent_space_size`` channels and produces a
    3-channel image squashed by ``tanh``. Four ConvTranspose -> BatchNorm ->
    ReLU stages are followed by a final ConvTranspose; for a 1x1 spatial
    input the sizes go 1 -> 4 -> 8 -> 16 -> 32 -> 64.
    """

    def __init__(self):
        super().__init__()
        # (in_channels, out_channels, stride, padding) of each hidden stage;
        # every stage uses a 4x4 kernel and no bias.
        hidden_stages = (
            (latent_space_size, 512, 1, 0),
            (512, 256, 2, 1),
            (256, 128, 2, 1),
            (128, 64, 2, 1),
        )
        layers = []
        for in_ch, out_ch, stride, padding in hidden_stages:
            layers += [
                nn.ConvTranspose2d(in_ch, out_ch, 4, stride, padding, bias=False),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(True),
            ]
        # Output stage: down to 3 channels, squashed by tanh.
        layers += [nn.ConvTranspose2d(64, 3, 4, 2, 1, bias=False), nn.Tanh()]
        self.model = nn.Sequential(*layers)

    def generate(self, input):
        """Run ``input`` through the layer stack and return the result."""
        return self.model(input)

class Discriminator(nn.Module):
    """Spectrally-normalised convolutional classifier with three outputs.

    Five 4x4 convolutions (all wrapped in spectral norm, no bias) with
    LeakyReLU(0.3) activations and batch-norm on the three middle stages; the
    final convolution emits 3 channels that pass through a sigmoid.
    """

    def __init__(self):
        super().__init__()
        sn = torch.nn.utils.spectral_norm
        # Input stage has no batch-norm.
        layers = [
            sn(nn.Conv2d(3, 64, 4, 2, 1, bias=False)),
            nn.LeakyReLU(0.3, inplace=True),
        ]
        # Three stride-2 stages that double the channel count each time.
        for in_ch, out_ch in ((64, 128), (128, 256), (256, 512)):
            layers.extend([
                sn(nn.Conv2d(in_ch, out_ch, 4, 2, 1, bias=False)),
                nn.BatchNorm2d(out_ch),
                nn.LeakyReLU(0.3, inplace=True),
            ])
        # Output stage: collapse to 3 sigmoid-activated channels.
        layers.extend([
            sn(nn.Conv2d(512, 3, 4, 1, 0, bias=False)),
            nn.Sigmoid(),
        ])
        self.model = nn.Sequential(*layers)

    def discriminate(self, input):
        """Score ``input`` and return the result reshaped to ``(-1, 3)``."""
        scores = self.model(input)
        return scores.view(-1, 3)

G = Generator().to(device)
D = Discriminator().to(device)

# Run weights_init over every sub-module of both networks: conv weights are
# drawn with mean=0, stdev=0.02 and batch-norm scales with mean=1, stdev=0.02.
G.apply(weights_init)
D.apply(weights_init)

# Report the total number of scalar parameters in each network.
print(f'> Number of generator parameters {sum(p.numel() for p in G.parameters())}')
print(f'> Number of discriminator parameters {sum(p.numel() for p in D.parameters())}')

print(G)

# Adam hyper-parameters shared by both optimisers.
lr = 0.0002
beta1 = 0.5

# initialise the optimisers, the loss, the epoch counter and the live plot
optimiser_G = torch.optim.Adam(params=G.parameters(), lr=lr, betas=(beta1, 0.999))
optimiser_D = torch.optim.Adam(params=D.parameters(), lr=lr, betas=(beta1, 0.999))
bce_loss = nn.BCELoss()
epoch = 0
liveplot = PlotLosses()


In [0]:
# Sanity check: pull one batch and show the dtype, shape and mean of its
# image tensor, one value per line. This consumes a batch from the iterator.
x = next(horse_iterator)
print(x[0].dtype, x[0].size(), x[0].mean(), sep='\n')

**Main training loop**

In [0]:
# training loop
from torch.autograd import Variable
import cv2

print("start")
# Restart the epoch counter so re-running this cell trains from epoch 0 again.
epoch = 0


def grayscale(data, dtype='float32'):
    """Collapse a channel-first RGB batch to a single luma channel.

    ``data`` is indexed as ``data[:, channel, :, :]`` with channels 0, 1, 2
    weighted 0.3, 0.59 and 0.11 respectively (luma coding weighted average in
    video systems). ``dtype`` is the dtype used for the weights. The result
    keeps a channel axis of length 1 at position 1.
    """
    channel_weights = (.3, .59, .11)
    weighted = [
        np.asarray(weight, dtype=dtype) * data[:, channel, :, :]
        for channel, weight in enumerate(channel_weights)
    ]
    luma = weighted[0] + weighted[1] + weighted[2]
    # add channel dimension
    return np.expand_dims(luma, axis=1)

# Constant BCE target rows over the discriminator's three outputs. Column 0
# is only ever set for generated images; columns 1 and 2 carry the horse and
# bird labels of the real data.
#   discriminator step, generated images -> [1, 0, 0]
#   generator step, desired scores       -> [0.05, 0.6, 0.4]
# Built once here instead of through NumPy on every iteration.
fake_target_row = torch.tensor([[1.0, 0.0, 0.0]], device=device)
generator_target_row = torch.tensor([[0.05, 0.6, 0.4]], device=device)

while (epoch<100):

    # arrays for metrics
    # NOTE(review): the losses are collected here but never passed to
    # `liveplot`; `logs` and `enc_loss_arr` are not used at all.
    logs = {}
    gen_loss_arr = np.zeros(0)
    dis_loss_arr = np.zeros(0)
    enc_loss_arr = np.zeros(0)

    # iterate over some of the train dataset
    for i in range(60):

        # train discriminator (two updates for every generator update)
        for j in range(2):
            x,labels = next(horse_iterator)
            x = x.to(device)
            optimiser_D.zero_grad()

            # The last mini batch may be smaller than batchSize
            mini_batch_size = x.size()[0]

            # # target = Variable(torch.ones(mini_batch_size)).to(device)

            # noise_prop = 0.05 # Randomly flip 5% of labels

            # # Prepare labels for real data
            # true_labels = np.ones((mini_batch_size)) - np.random.uniform(low=0.0, high=0.1, size=(mini_batch_size))
            # #flipped_idx = np.random.choice(np.arange(len(true_labels)), size=int(noise_prop*len(true_labels)))
            # #true_labels[flipped_idx] = 1 - true_labels[flipped_idx]
            # target = Variable(torch.from_numpy(true_labels)).float().to(device)

            fake_target = fake_target_row.repeat(mini_batch_size, 1)

            g = G.generate(torch.randn(mini_batch_size, latent_space_size, 1, 1).to(device))
            # real images -> their own horse/bird label row
            l_r = bce_loss(D.discriminate(x), labels.float().to(device))
            # generated images -> [1, 0, 0]; detached so only D gets gradients
            l_f = bce_loss(D.discriminate(g.detach()), fake_target)
            loss_d = (l_r + l_f)
            loss_d.backward()
            optimiser_D.step()

        # train generator
        # A batch is still drawn so the data stream advances exactly as
        # before, but only its size is needed, so it is not moved to the
        # device.
        x,labels = next(horse_iterator)
        mini_batch_size = x.size(0)
        optimiser_G.zero_grad()
        g = G.generate(torch.randn(mini_batch_size, latent_space_size, 1, 1).to(device))

        # push generated images towards a mostly-horse / partly-bird score
        loss_g = bce_loss(D.discriminate(g), generator_target_row.repeat(mini_batch_size, 1))
        loss_g.backward()
        optimiser_G.step()

        gen_loss_arr = np.append(gen_loss_arr, loss_g.item())
        dis_loss_arr = np.append(dis_loss_arr, loss_d.item())

    # plot some examples from the last generated batch
    imgs = g.detach().cpu().numpy()

    print("Epoch: %d" % epoch)
    plt.figure(figsize=(10,10))
    for i in range(min(10, len(imgs))):
        # Map from [-1, 1] to [0, 1] and move channels last BEFORE calling
        # OpenCV: cv2 reads arrays as (rows, cols, channels), so filtering
        # the channel-first array would blur across the wrong axes.
        k = np.ascontiguousarray(((imgs[i]/2)+0.5).transpose(1, 2, 0))
        # Unsharp mask: boost the image and subtract a blurred copy.
        gaussian = cv2.GaussianBlur(k, (7,7), 6.0)
        unsharp_image = cv2.addWeighted(k, 1.5, gaussian, -0.5, 0)
        img = cv2.normalize(unsharp_image, None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)
        plt.subplot(5,5,i+1)
        plt.xticks([])
        plt.yticks([])
        plt.grid(False)
        plt.imshow(img)
        plt.xlabel("Generated % d" % i)
    plt.show()

    epoch = epoch+1

In [0]:
# Get the horsiness and birdiness of images
# def test_classifier(desc, dataset):
#   set_loader = iter(torch.utils.data.DataLoader(dataset,batch_size=16,shuffle=False))
#   batch = next(set_loader)
#   classes = D.discriminate(batch.to(device))
#   print('Classification of %s images:' % desc)
#   for real, horse, bird in classes:
#     print("%.4f %.4f %.4f" % (real, horse, bird))

# test_classifier("horse", horses)
# test_classifier("bird", birds)

In [0]:
def sortby(x):
    """Return the ranking key for each row of ``x``: column 1 plus column 2."""
    horsiness = x[:, 1]
    birdiness = x[:, 2]
    return horsiness + birdiness  # Sort by horsiness + birdiness

def best_pegasus(n_candidates=64):
    """Generate a batch of images and return the one the discriminator rates highest.

    ``n_candidates`` images are generated from fresh noise, scored by the
    discriminator and ranked with ``sortby`` (horse score + bird score).

    Returns a tuple ``(image, classes)``: the top-ranked image tensor and the
    three discriminator scores (NumPy array) of that same image.

    The previous implementation sorted ascending and flipped only the score
    array, so it paired the lowest-ranked image with the highest-ranked
    image's scores; both now come from the same, highest-ranked candidate.
    """
    # No gradients are needed for sampling/scoring.
    with torch.no_grad():
        gen_images = G.generate(torch.randn(n_candidates, latent_space_size, 1, 1).to(device))
        classes = D.discriminate(gen_images).cpu().numpy()
    top = int(np.argmax(sortby(classes)))
    return (gen_images[top], classes[top])

In [0]:
# Draw an 8x8 gallery; each cell shows the image picked by a fresh call to
# best_pegasus(), sharpened with an unsharp mask and rescaled to [0, 1].
plt.figure(figsize=(20,20))
for i in range(64):
    best_img, classes = best_pegasus()

    # (C, H, W) tensor -> (H, W, C) array, the layout OpenCV and imshow expect
    img = best_img.cpu().detach().numpy().transpose(1, 2, 0)

    # Unsharp mask: boost the image and subtract a blurred copy.
    gaussian = cv2.GaussianBlur(img, (7,7), 6.0)
    unsharp_image = cv2.addWeighted(img, 1.5, gaussian, -0.5, 0)
    final_img = cv2.normalize(unsharp_image, None, alpha=0, beta=1, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_32F)

    plt.subplot(8,8,i+1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(final_img, interpolation = 'bicubic')
    #plt.xlabel('%.3f %.5f %.5f' % (classes[0], classes[1], classes[2]))
    plt.xlabel(i)