Experiment with different recombination (crossover) methods on the generator's latent vectors

In [None]:
import os
import random
import torch
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
from numpy.random import randint
import scipy.stats as stats

# save images
from torchvision.utils import save_image
# Root folder for every image written below. The trailing slash matters:
# file names are appended to it by plain string concatenation.
img_save_path = './plots/'
os.makedirs(name=img_save_path, exist_ok=True)

from CAN.parameters import *
from CAN.dataloader_wikiart import *
from CAN.model_CAN_16_9 import * 

from deap import base, creator, tools, algorithms

In [None]:

# Seed Python's, torch's and numpy's RNGs with one fixed value so runs are reproducible.
manualSeed = 3
# manualSeed = random.randint(1, 10000)  # use if you want new results
print("Random Seed: ", manualSeed)
for seed_fn in (random.seed, torch.manual_seed, np.random.seed):
    seed_fn(manualSeed)


# Use the first GPU only when CUDA exists and ngpu requests at least one GPU.
use_cuda = torch.cuda.is_available() and ngpu > 0
device = torch.device("cuda:0" if use_cuda else "cpu")


In [None]:
### load model

# Location of the saved generator checkpoint
PATH_GEN = './CAN/models/GEN_16_9.pth'

# Build the generator on the selected device plus an Adam optimizer for it
netG = Generator(ngpu).to(device)
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

# Restore generator weights and optimizer state (checkpoint tensors are mapped to CPU on load)
checkpoint = torch.load(PATH_GEN, map_location=torch.device('cpu'))
for target, state_key in ((netG, 'model_state_dict'), (optimizerG, 'optimizer_state_dict')):
    target.load_state_dict(checkpoint[state_key])

# Only sampling happens below, so switch the generator to evaluation mode
netG.eval()

In [None]:
## deap setup

# Maximization with a single objective (fitness weight +1.0)
creator.create("FitnessMax", base.Fitness, weights=(1.0,))
# An Individual is a plain list of genes that carries a FitnessMax attribute
creator.create("Individual", list, fitness=creator.FitnessMax)

toolbox = base.Toolbox()
# Each gene is one draw from numpy's global normal RNG
toolbox.register("attribute", np.random.normal)
# An individual is nz such genes ...
toolbox.register("individual", tools.initRepeat,
                 container=creator.Individual, func=toolbox.attribute, n=nz)
# ... and a population is a list of individuals (size given at call time via n=)
toolbox.register("population", tools.initRepeat, list, toolbox.individual)

In [None]:
## Use self-made functions for one-point crossover and Interpolation so they can operate directly on torch tensors

def crossover(p1, p2, pt):
    """One-point crossover of two parent tensors at index ``pt``.

    The first child is ``p1[:pt]`` followed by ``p2[pt:]``; the second child is
    ``p2[:pt]`` followed by ``p1[pt:]``. Concatenation is along the first axis.

    :param p1: first parent tensor.
    :param p2: second parent tensor (assumed to match ``p1`` beyond axis 0).
    :param pt: cut index along the first axis.
    :returns: tuple ``(c1, c2)`` with the two children.
    """
    head_a, tail_a = p1[:pt], p1[pt:]
    head_b, tail_b = p2[:pt], p2[pt:]
    return torch.cat((head_a, tail_b)), torch.cat((head_b, tail_a))

def Interpolation(ind1, ind2, p = 0.5):
    """Whole arithmetic crossover (linear interpolation) of two parents.

    Produces ONE child, ``p * ind1 + (1 - p) * ind2``. With ``p = 1`` the child
    equals ``ind1``; with ``p = 0`` it equals ``ind2``.

    :param ind1: first parent; any value supporting ``*`` by a float and ``+``
        (e.g. a torch tensor).
    :param ind2: second parent, compatible with ``ind1``.
    :param p: weight given to ``ind1``; ``ind2`` receives ``1 - p``.
    :returns: a single child (not a tuple).
    """
    return ind1*p + ind2*(1-p)

In [None]:
## start individuals, 2 parents
genes = torch.randn(2, nz, 1, 1, device=device)

# DEAP view of the same two parents: each (nz, 1, 1) latent tensor is
# flattened into an Individual holding nz Python floats.
pop = [creator.Individual(g.flatten().tolist()) for g in genes]

img_list = []

# save_image does not create missing directories
os.makedirs(img_save_path + 'Recombination', exist_ok=True)

with torch.no_grad():
    fake = netG(genes).detach().cpu()
    grid = vutils.make_grid(fake, padding=2, normalize=True)
    img_list.append(grid)
    save_image(fake.data, img_save_path + 'Recombination/evolution_0.png', normalize=True)

# plot them (reuse the grid built above: CHW -> HWC for imshow)
plt.figure(figsize=(15,15))
plt.axis("off")
plt.title("Start Images")
plt.imshow(np.transpose(grid, (1, 2, 0)))
plt.show()

In [None]:
#### Part 1: Check individual methods:


## OnePoint Crossover

# save_image does not create missing directories
os.makedirs(img_save_path + 'Recombination/OnePoint', exist_ok=True)

children_OnePoint = []
# cut both parents at positions 10, 20, ..., 90 and swap the tails
for i in range(9):
    c1, c2 = crossover(genes[0], genes[1], pt = i*10+10)
    children_OnePoint.append(c1)
    children_OnePoint.append(c2)

    with torch.no_grad():
        fake = netG(torch.stack((c1, c2), 0)).detach().cpu()
        img_list.append(vutils.make_grid(fake, padding=2, normalize=True))
        save_image(fake.data, img_save_path + 'Recombination/OnePoint/evolution_%d.png' % (i+1), normalize=True)

In [None]:
## Interpolation

# save_image does not create missing directories
os.makedirs(img_save_path + 'Recombination/Interpolation', exist_ok=True)

children_Interpolation = []
# blend weight of parent 0 runs p = 0.1, 0.2, ..., 0.9 (one child per step)
for i in range(9):
    c1 = Interpolation(genes[0], genes[1], p = (i*10+10)/100)
    children_Interpolation.append(c1)

    with torch.no_grad():
        fake = netG(c1.view(1,nz,1,1)).detach().cpu()
        img_list.append(vutils.make_grid(fake, padding=2, normalize=True))
        save_image(fake.data, img_save_path + 'Recombination/Interpolation/evolution_%d.png' % (i+1), normalize=True)


In [None]:
## Two Point Crossover, random

# save_image does not create missing directories
os.makedirs(img_save_path + 'Recombination/TwoPoint', exist_ok=True)

for i in range(9):

    # clone first: DEAP crossover operators modify the individuals in place
    child1, child2 = [toolbox.clone(ind) for ind in (pop[0], pop[1])]
    tools.cxTwoPoint(child1, child2)

    with torch.no_grad():
        # torch.tensor defaults to the CPU; build the batch on netG's device
        new_genes = torch.tensor([child1, child2], device=device)
        new_genes = new_genes.view(2,nz,1,1)
        fake = netG(new_genes.float()).detach().cpu()
        img_list.append(vutils.make_grid(fake, padding=2, normalize=True))
        save_image(fake.data, img_save_path + 'Recombination/TwoPoint/evolution_%d.png' % (i+1), normalize=True)

In [None]:
## uniform crossover, random

# save_image does not create missing directories
os.makedirs(img_save_path + 'Recombination/Uniform', exist_ok=True)

# per-gene swap probability runs indpb = 0.1, 0.2, ..., 0.9
for i in range(9):

    # clone first: DEAP crossover operators modify the individuals in place
    child1, child2 = [toolbox.clone(ind) for ind in (pop[0], pop[1])]
    tools.cxUniform(child1, child2, indpb = (i*10+10)/100)

    with torch.no_grad():
        # torch.tensor defaults to the CPU; build the batch on netG's device
        new_genes = torch.tensor([child1, child2], device=device)
        new_genes = new_genes.view(2,nz,1,1)
        fake = netG(new_genes.float()).detach().cpu()
        img_list.append(vutils.make_grid(fake, padding=2, normalize=True))
        save_image(fake.data, img_save_path + 'Recombination/Uniform/evolution_%d.png' % (i+1), normalize=True)

In [None]:
## Blend crossover
## NOTE: blending can push genes outside the range of the parents, i.e. out
## of the distribution the generator was trained on!

# save_image does not create missing directories
os.makedirs(img_save_path + 'Recombination/Blend', exist_ok=True)

# alpha runs 0.05, 0.10, ..., 0.45 (kept below 0.5)
for i in range(9):

    # clone first: DEAP crossover operators modify the individuals in place
    child1, child2 = [toolbox.clone(ind) for ind in (pop[0], pop[1])]
    tools.cxBlend(child1, child2, alpha = (i*10+10)/200)

    with torch.no_grad():
        # torch.tensor defaults to the CPU; build the batch on netG's device
        new_genes = torch.tensor([child1, child2], device=device)
        new_genes = new_genes.view(2,nz,1,1)
        fake = netG(new_genes.float()).detach().cpu()
        img_list.append(vutils.make_grid(fake, padding=2, normalize=True))
        save_image(fake.data, img_save_path + 'Recombination/Blend/evolution_%d.png' % (i+1), normalize=True)

In [None]:
### Part 2: main experiment, repeated for several parent pairs:


def experiment(examples = 5):
    """Draw random parent pairs and render their children for four recombination methods.

    For each of ``examples`` rounds, two latent vectors are drawn with
    ``torch.randn(2, nz, 1, 1)`` and rendered. Then nine settings of each
    method are rendered:

    * Interpolation with ``p = 0.1, 0.2, ..., 0.9`` (one child per setting),
    * one-point crossover cut at ``10, 20, ..., 90`` (two children),
    * DEAP two-point crossover with random cut points (two children),
    * DEAP uniform crossover with ``indpb = 0.1, 0.2, ..., 0.9`` (two children).

    Every generated batch is appended to the returned list as an image grid and
    saved as a PNG under ``img_save_path + 'Recombination/Experiment/'``.

    :param examples: number of parent pairs to draw.
    :returns: list of image grids, in the order they were generated.
    """
    out_dir = img_save_path + 'Recombination/Experiment/'
    # save_image does not create missing directories
    os.makedirs(out_dir, exist_ok=True)

    img_list = []

    def _record(fake, filename):
        # Keep the normalized grid for the animation and write the batch to disk.
        img_list.append(vutils.make_grid(fake, padding=2, normalize=True))
        save_image(fake.data, out_dir + filename, normalize=True)

    def _render_individuals(*children):
        # DEAP individuals are plain lists of floats; torch.tensor builds on the
        # CPU by default, so place the batch on the generator's device explicitly.
        batch = torch.tensor(list(children), device=device).view(len(children), nz, 1, 1)
        return netG(batch.float()).detach().cpu()

    for i in range(examples):

        ## start individuals
        genes = torch.randn(2, nz, 1, 1, device=device)

        with torch.no_grad():
            _record(netG(genes).detach().cpu(), 'evolution_%d.png' % (i))

            ### recombinations

            # Interpolation: child = p * parent0 + (1 - p) * parent1
            for j in range(9):
                c1 = Interpolation(genes[0], genes[1], p = (j*10+10)/100)
                _record(netG(c1.view(1,nz,1,1)).detach().cpu(),
                        'evolution_%d_Interpolation_%d.png' % (i, j))

            # OnePoint X-Over at fixed cut positions 10, 20, ..., 90
            for j in range(9):
                c1, c2 = crossover(genes[0], genes[1], pt = j*10+10)
                _record(netG(torch.stack((c1, c2), 0)).detach().cpu(),
                        'evolution_%d_OnePoint_%d.png' % (i, j))

        # DEAP setup: one Individual (list of nz floats) per parent
        pop = [creator.Individual(g.flatten().tolist()) for g in genes]

        # TwoPoint X-Over, random cut points
        for j in range(9):
            # clone first: DEAP crossover operators modify individuals in place
            child1, child2 = [toolbox.clone(ind) for ind in pop]
            tools.cxTwoPoint(child1, child2)
            with torch.no_grad():
                _record(_render_individuals(child1, child2),
                        'evolution_%d_TwoPoint_%d.png' % (i, j))

        # Uniform X-Over; the swap probability follows the inner index j
        for j in range(9):
            child1, child2 = [toolbox.clone(ind) for ind in pop]
            tools.cxUniform(child1, child2, indpb = (j*10+10)/100)
            with torch.no_grad():
                _record(_render_individuals(child1, child2),
                        'evolution_%d_Uniform_%d.png' % (i, j))

    return img_list


img_list_experiment = experiment(5)

In [None]:
## Visualize experiment:
fig = plt.figure(figsize=(8,6))
plt.axis("off")
# One animation frame per stored grid; each frame is a single-artist list.
ims = []
for grid in img_list_experiment:
    frame = plt.imshow(np.transpose(grid, (1, 2, 0)), animated=True)
    ims.append([frame])
ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)

HTML(ani.to_jshtml())