In [None]:
import os
import random
import torch
import copy
import torch.optim as optim
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
from numpy.random import randint
import scipy.stats as stats

# save images
from torchvision.utils import save_image
# Output directory for the evolution image grids / hall-of-fame images saved below
# (created here if it does not exist yet).
img_save_path = './plots/'
os.makedirs(img_save_path, exist_ok=True)

from CAN.parameters import *
from CAN.model_CAN_16_9 import * 

from deap import base, creator, tools, algorithms

from keras.models import Model
from keras.layers import Dense, Dropout
from keras.applications.inception_resnet_v2 import InceptionResNetV2
from keras.applications.inception_resnet_v2 import preprocess_input
from keras.utils.image_utils import load_img, img_to_array
import tensorflow as tf
from PIL import Image

from NIMA.utils.score_utils import mean_score

In [None]:
# Set random seed for reproducibility.
# Seeds used for the reported simulations: 3, 5, 7, 9, 11
# (for an unseeded exploratory run use: manualSeed = random.randint(1, 10000)).
manualSeed = 7
print("Random Seed: ", manualSeed)

# Seed every RNG the notebook draws from: Python `random` (crossover / local-search decisions),
# torch (initial latent vectors) and NumPy (DEAP individual initialisation, random fitness).
for _seed_rng in (random.seed, torch.manual_seed, np.random.seed):
    _seed_rng(manualSeed)

# Decide which device we want to run on: the first GPU only if CUDA is available and
# `ngpu` (presumably provided by the CAN star imports) asks for at least one GPU.
_use_gpu = torch.cuda.is_available() and ngpu > 0
device = torch.device("cuda:0" if _use_gpu else "cpu")

In [None]:
### load Generator model

# Trained CAN generator checkpoint; it must contain 'model_state_dict' and 'optimizer_state_dict'.
PATH_GEN = './CAN/models/GEN_16_9.pth'

# Create the generator on the chosen device together with its Adam optimizer
# (`ngpu`, `lr` and `beta1` presumably come from CAN.parameters).
netG = Generator(ngpu).to(device)
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

# Deserialize onto the CPU first, so a checkpoint saved on a GPU also loads on CPU-only machines;
# load_state_dict then copies the stored tensors into the (possibly GPU-resident) objects.
checkpoint = torch.load(PATH_GEN, map_location=torch.device('cpu'))
for _target, _state_key in ((netG, 'model_state_dict'), (optimizerG, 'optimizer_state_dict')):
    _target.load_state_dict(checkpoint[_state_key])

# Inference only from here on (the optimizer state is restored but not used below).
netG.eval()

In [None]:
### load Evaluation model

# NIMA aesthetic scorer: an InceptionResNetV2 backbone without its classifier (global average
# pooling), followed by Dropout and a 10-way softmax head; trained weights come from the .h5 file.
backbone = InceptionResNetV2(input_shape=(None, None, 3), include_top=False, pooling='avg', weights=None)
dropped_features = Dropout(0.75)(backbone.output)
rating_distribution = Dense(10, activation='softmax')(dropped_features)

netEval = Model(backbone.input, rating_distribution)
netEval.load_weights('./NIMA/weights/inception_resnet_weights.h5')

In [None]:
### Fitness functions

def fitnessRandom(ind):
    """Baseline fitness: ignore the individual and return a uniformly random score.

    Args:
        ind: the DEAP individual (unused).

    Returns:
        A 1-tuple ``(score,)`` with ``score`` drawn by ``np.random.uniform()`` from [0, 1);
        DEAP expects fitness values as a tuple.
    """
    score = np.random.uniform()
    return (score,)


# automatic evaluation

def AutomaticEvaluation(img_path):
    """Rate an image file with the NIMA model ``netEval``.

    The image is read back from disk (converting the torch tensor to a PIL image directly
    proved problematic), resized to 224x224, preprocessed for InceptionResNetV2 and scored
    with TensorFlow pinned to the CPU.

    Args:
        img_path: path of the image file to rate.

    Returns:
        The scalar produced by NIMA's ``mean_score`` from the predicted rating distribution
        (presumably the expected rating).
    """
    with tf.device('/CPU:0'):
        pixels = img_to_array(load_img(img_path, target_size=(224, 224)))
        batch = preprocess_input(np.expand_dims(pixels, axis=0))

        # single-image batch -> keep the one row of predicted probabilities
        rating_distribution = netEval.predict(batch, batch_size=1, verbose=0)[0]
        return mean_score(rating_distribution)

def fitnessAutomaticEvaluation(ind):
    """DEAP fitness: decode ``ind`` with the generator and rate the image with NIMA.

    The generated image is written to a unique temporary PNG and scored through
    ``AutomaticEvaluation``, because converting the tensor to a PIL image directly changes
    it slightly. The temporary file is always deleted, even if saving or scoring fails.

    Args:
        ind: latent vector with ``nz`` values (a DEAP individual, i.e. a list of floats).

    Returns:
        A 1-tuple ``(score,)`` -- DEAP expects fitness values as a tuple.
    """
    import tempfile

    # part 1: create image
    with torch.no_grad():
        # Build the latent batch on `device` so it matches netG's parameters; a CPU tensor
        # raises a device-mismatch error when the generator runs on cuda:0.
        latent = torch.tensor(ind, dtype=torch.float32, device=device).view(1, nz, 1, 1)
        image = netG(latent).detach().cpu()

    # part 2: save to a unique temporary file, evaluate it, and delete it again
    fd, tmp_path = tempfile.mkstemp(suffix=".png")
    os.close(fd)  # save_image opens the path itself
    try:
        save_image(image.data, tmp_path, normalize=True)
        score = AutomaticEvaluation(tmp_path)
    finally:
        os.remove(tmp_path)

    return (score,)


In [None]:
### DEAP setup

# Single objective, maximization (weights=(1.0,)); individuals are plain Python lists.
# Note: re-running this cell re-creates both classes (DEAP emits a warning when overwriting them).
creator.create("FitnessMax", base.Fitness, weights=(1.0,))
creator.create("Individual", list, fitness=creator.FitnessMax)

toolbox = base.Toolbox()

# Individuals: `nz` genes, each drawn with np.random.normal (standard normal, the same
# distribution as the torch.randn initial latent vectors).
toolbox.register("attribute", np.random.normal)
toolbox.register("individual", tools.initRepeat, creator.Individual, toolbox.attribute, n=nz)
toolbox.register("population", tools.initRepeat, list, toolbox.individual)

### set up mating, mutation, selection, evaluation
# - mate:     uniform crossover; indpb = .25 was the most promising in recombination experiments
# - mutate:   Gaussian, sigma=1, indpb=1/100 (about one gene per call if nz == 100 -- assumption);
#             only used in local search, otherwise overwritten
# - select:   stochastic universal sampling
# - evaluate: NIMA-based automatic evaluation
_operator_setup = (
    ("mate", tools.cxUniform, {"indpb": 0.25}),
    ("mutate", tools.mutGaussian, {"mu": 0, "sigma": 1, "indpb": 1/100}),
    ("select", tools.selStochasticUniversalSampling, {}),
    ("evaluate", fitnessAutomaticEvaluation, {}),
)
for _alias, _operator, _options in _operator_setup:
    toolbox.register(_alias, _operator, **_options)

In [None]:
### save some statistics of each generation
# Fitness statistics per generation, registered in this column order: max, avg, min, std
# (the order becomes the order of `stats.fields` in the logbook header).
# NOTE: this rebinds `stats`, shadowing the `scipy.stats` module imported at the top.
stats = tools.Statistics(key=lambda ind: ind.fitness.values)
for _stat_name, _stat_fn in (("max", np.max), ("avg", np.mean), ("min", np.min), ("std", np.std)):
    stats.register(_stat_name, _stat_fn)

In [None]:
def simple_Evolution(MU = 10, crossP = 0.5, mutP = 0.5, NGEN = 10):
    """Baseline GA: DEAP ``eaSimple`` with a *random* fitness, run one generation at a time.

    The initial population and the population after every generation are decoded by the
    generator, appended to ``img_list`` as an image grid and saved as
    ``evolution_random_<gen>.png``.

    The global ``toolbox`` is temporarily switched to ``fitnessRandom`` and a gentler Gaussian
    mutation (sigma=0.1, indpb=0.5). The previous ``evaluate``/``mutate`` registrations are
    restored on exit (also on error), so later runs such as ``complex_Evolution`` are not
    silently turned into random-fitness runs.

    Args:
        MU: population size.
        crossP: crossover probability passed to eaSimple.
        mutP: mutation probability passed to eaSimple.
        NGEN: number of generations.

    Returns:
        (pop, results, hof, img_list): final population, one DEAP Logbook per generation,
        hall of fame (up to 20 individuals) and the NGEN + 1 image grids.
    """
    saved_operators = {alias: getattr(toolbox, alias, None) for alias in ("evaluate", "mutate")}
    img_list = []

    def _record_generation(latent, gen):
        # Decode a latent batch, keep the grid and save it as evolution_random_<gen>.png.
        with torch.no_grad():
            fake = netG(latent.view(-1, nz, 1, 1)).detach().cpu()
        img_list.append(vutils.make_grid(fake, padding=2, normalize=True))
        save_image(fake.data, img_save_path + '/evolution_random_%d.png' % gen, normalize=True)

    try:
        # set fitness function and mutation for this experiment only
        toolbox.register("evaluate", fitnessRandom)
        toolbox.register("mutate", tools.mutGaussian, mu=0, sigma=0.1, indpb=0.5)

        genes = torch.randn(MU, nz, 1, 1, device=device)
        _record_generation(genes, 0)

        hof = tools.HallOfFame(maxsize=20)

        # toolbox.population only supplies the Individual containers (genes are replaced below);
        # the call is kept because it advances NumPy's RNG, which seeded runs depend on.
        pop = toolbox.population(n=MU)
        for k in range(MU):
            pop[k] = creator.Individual(genes[k].flatten().tolist())

        results = []
        for generation in range(NGEN):
            # one eaSimple generation at a time, so every intermediate population can be saved
            pop, gen_log = algorithms.eaSimple(pop, toolbox, cxpb=crossP, mutpb=mutP, ngen=1,
                                               stats=stats, halloffame=hof)
            results.append(gen_log)

            # Build the batch on `device`; a CPU tensor would not match netG on cuda:0.
            latent = torch.tensor(pop, dtype=torch.float32, device=device)
            _record_generation(latent, generation + 1)
    finally:
        # restore the toolbox registrations that were active before this run
        for alias, operator in saved_operators.items():
            if operator is None:
                toolbox.unregister(alias)
            else:
                setattr(toolbox, alias, operator)

    return pop, results, hof, img_list

In [None]:
## local search instead of random mutation
def local_Search(genes, NGEN = 100):
    """Hill-climb a single individual with an elitist (1+1) evolution strategy.

    Runs DEAP's ``eaMuPlusLambda`` with mu = lambda = 1, no crossover and mutation
    probability 1, using the currently registered ``toolbox.mutate`` / ``toolbox.evaluate``.
    Selection is temporarily switched to ``selBest``, so the parent is only replaced by a
    fitter child. The previous selection operator is restored afterwards, even on error.

    Args:
        genes: starting latent vector (list of floats or DEAP individual).
        NGEN: number of (1+1) generations.

    Returns:
        (best, same): the resulting individual (with a valid fitness) and True if its genes
        are identical to the starting genes, i.e. no mutation was accepted.
    """
    # toolbox.population(n=1) only supplies the container; its random genes are overwritten.
    # The call is kept because it advances NumPy's RNG, which seeded runs depend on.
    pop = toolbox.population(n=1)
    pop[0] = creator.Individual(genes)  # fresh copy with an invalid fitness -> gets evaluated
    reference = copy.deepcopy(pop[0])

    # switch to selBest; the mutation operator is only used here and can be kept
    previous_select = toolbox.select
    toolbox.register("select", tools.selBest)
    try:
        final_pop, _ = algorithms.eaMuPlusLambda(pop, toolbox, mu=1, lambda_=1,
            cxpb=0, mutpb=1, ngen=NGEN, stats=stats, verbose=False)
    finally:
        # switch back, even if evaluation raised
        toolbox.select = previous_select

    best = final_pop[0]
    same = list(best) == list(reference)
    return best, same

In [None]:
# diversity preservation metric to avoid too close individuals:

def preserve_Diversity(pop, MU = None, threshold = 25):
    """Flag individuals whose genes are too close (L1 distance) to a later individual.

    Every pair (i, j) with i < j among the first ``MU`` individuals is compared by the
    Manhattan distance of their genes. If the distance is below ``threshold``, the lower
    index i is flagged; this pair does not flag j.

    Args:
        pop: sequence of individuals (sequences of numbers).
        MU: number of leading individuals to check; defaults to ``len(pop)``.
        threshold: minimum allowed L1 distance (arbitrary, roughly 5-30 works).

    Returns:
        (invalid_inds, dist): sorted list of unique flagged indices, and all pairwise
        distances in (0,1), (0,2), ..., (1,2), ... order.
    """
    if MU is None:
        MU = len(pop)

    invalid_inds = set()
    dist = []
    for i in range(MU):
        for j in range(i + 1, MU):
            # compute each pairwise distance once and reuse it for the threshold test
            d = sum(abs(x - y) for x, y in zip(pop[i], pop[j]))
            dist.append(d)
            if d < threshold:
                invalid_inds.add(i)

    return sorted(invalid_inds), dist

In [None]:
def _decode_individuals(individuals):
    """Run the generator on latent vectors and return its output on the CPU.

    Args:
        individuals: one latent vector or a list of them (each ``nz`` floats, e.g. DEAP individuals).

    Returns:
        The detached generator output for the batch (one image per latent vector), moved to the CPU.
    """
    with torch.no_grad():
        # Create the batch on `device` so it matches netG's parameters; a CPU tensor would
        # fail with a device mismatch when the generator lives on cuda:0.
        latent = torch.tensor(individuals, dtype=torch.float32, device=device)
        return netG(latent.view(-1, nz, 1, 1)).detach().cpu()


def complex_Evolution(MU = 15, crossP = 0.5, mutP = 0.5, NGEN = 25):
    """Memetic GA over generator latent vectors with random immigrants.

    Each generation does the following:
    - save the current population grid;
    - select and clone parents;
    - apply uniform crossover with probability ``crossP`` per pair;
    - run ``local_Search`` instead of plain mutation with probability ``mutP`` per individual;
    - replace individuals that are too close to others (``preserve_Diversity``) by
      locally-searched random immigrants;
    - evaluate invalid individuals and replace the population.

    After the last generation the hall of fame is saved and upscaled with
    ``high_resolution.upscale`` (scale 8).

    Args:
        MU: population size.
        crossP: crossover probability per pair.
        mutP: local-search probability per individual.
        NGEN: number of generations.

    Returns:
        (pop, logbook, hof, img_list): final population, Logbook with gen/nevals/imm plus
        the ``stats`` columns, hall of fame (up to 20), and NGEN + 1 population grids.
    """
    img_list = []

    def _save_population(individuals, gen):
        # keep a grid for the animation and write it to evolution_automatic_<gen>.png
        fake = _decode_individuals(individuals)
        img_list.append(vutils.make_grid(fake, padding=2, normalize=True, nrow = 5))
        save_image(fake.data, img_save_path + '/evolution_automatic_%d.png' % (gen), normalize=True, nrow = 5)

    # create first generation
    genes = torch.randn(MU, nz, 1, 1, device=device)

    #log and hof
    logbook = tools.Logbook()
    logbook.header = ['gen', 'nevals', 'imm'] + (stats.fields if stats else [])
    hof = tools.HallOfFame(maxsize=20)

    # toolbox.population only supplies the Individual containers (genes are replaced below);
    # the call is kept because it advances NumPy's RNG, which seeded runs depend on.
    pop = toolbox.population(n=MU)
    for k in range(MU):
        pop[k] = creator.Individual(genes[k].flatten().tolist())

    ## evaluation of first generation
    invalid_ind = [ind for ind in pop if not ind.fitness.valid]
    for ind, fit in zip(invalid_ind, toolbox.map(toolbox.evaluate, invalid_ind)):
        ind.fitness.values = fit

    # get stats from first generation
    record = stats.compile(pop)
    logbook.record(gen=0, nevals=len(invalid_ind), imm=0, **record)
    hof.update(pop)
    print(logbook.stream)

    for generation in range(NGEN):

        # save images of the current population
        _save_population(pop, generation)

        # Select the next generation individuals and clone them
        offspring = list(map(toolbox.clone, toolbox.select(pop, len(pop))))

        # Apply crossover on consecutive pairs of the offspring
        for child1, child2 in zip(offspring[::2], offspring[1::2]):
            if random.random() < crossP:
                toolbox.mate(child1, child2)
                del child1.fitness.values
                del child2.fitness.values

        # Apply mutation on the offspring: here a memetic algorithm using local search
        for idx in range(len(offspring)):
            if random.random() < mutP:
                offspring[idx], same = local_Search(offspring[idx])
                if not same:
                    # force a fresh evaluation (counted in nevals)
                    del offspring[idx].fitness.values

        # Preserve diversity, for novelty search + exploration + against user fatigue
        boring_inds, _ = preserve_Diversity(offspring, MU)
        for ind in boring_inds:

            ## save the individual being replaced, to check the distance measure
            save_image(_decode_individuals(offspring[ind]).data,
                       img_save_path + '/evolution_%d_replacedInd_%d.png' % (generation + 1, ind),
                       normalize=True, nrow = 5)

            # replace by a random immigrant if too close to other individuals
            offspring[ind] = toolbox.population(n=1)[0]
            # and go through local search directly
            offspring[ind], _ = local_Search(offspring[ind])
            del offspring[ind].fitness.values

        # Evaluate the individuals with an invalid fitness
        invalid_ind = [ind for ind in offspring if not ind.fitness.valid]
        for ind, fit in zip(invalid_ind, toolbox.map(toolbox.evaluate, invalid_ind)):
            ind.fitness.values = fit

        # The population is entirely replaced by the offspring
        pop[:] = offspring

        # statistics and hall of fame
        record = stats.compile(pop)
        logbook.record(gen=generation + 1, nevals=len(invalid_ind), imm=len(boring_inds), **record)
        hof.update(pop)
        print(logbook.stream)

        ## upscale hall of fame after the last generation, whatever NGEN is
        if generation == NGEN - 1:

            from high_resolution import upscale

            for j, best in enumerate(hof):
                hof_path = img_save_path + '/evolution_automatic_%d_hall_of_fame_%d.png' % (generation + 1, j)
                save_image(_decode_individuals(best).data, hof_path, normalize=True)
                upscale(img_path = hof_path, scale = 8)

    # finally also save the last generation (index NGEN; also valid when NGEN == 0)
    _save_population(pop, NGEN)

    return pop, logbook, hof, img_list

In [None]:
# Run the memetic evolution with its defaults (MU=15, NGEN=25), using whatever fitness is
# registered as toolbox.evaluate (fitnessAutomaticEvaluation / NIMA in the DEAP setup cell).
# Local search is run for part of the population every generation, so this cell is slow.
pop, log, hof, img_list = complex_Evolution()

In [None]:
# Check whether after evolution genes still fit expected distribution:
# genes start from a standard normal, so the population-averaged per-individual
# mean should stay near 0 and the averaged per-individual std near 1.
avg = [np.mean(ind) for ind in pop]
std = [np.std(ind) for ind in pop]
print(np.mean(avg))
print(np.mean(std))

In [None]:
# write Logbook as csv
import csv
from itertools import zip_longest

# One list per logbook column; these module-level lists (gen, avg, maxi, mini, std, ...)
# are reused by the plotting cell below.
_logbook_keys = ("gen", "nevals", "imm", "max", "avg", "min", "std")
gen, nevals, imm, maxi, avg, mini, std = ([entry[key] for entry in log] for key in _logbook_keys)

# Rows are padded with '' should the columns ever differ in length.
with open('evolution_automatic_results.csv', 'w', encoding="ISO-8859-1", newline='') as myfile:
    wr = csv.writer(myfile)
    wr.writerow(("gen", "nevals", "immigrants","max", "avg", "min", "std"))
    wr.writerows(zip_longest(gen, nevals, imm, maxi, avg, mini, std, fillvalue = ''))

In [None]:
## plot results

# Shaded band: mean fitness +/- standard error of the mean. The population size is taken
# from the final population instead of a hard-coded 15 (the default MU), so the band
# stays correct when complex_Evolution is run with a different MU.
n_individuals = len(pop)
se = [x / np.sqrt(n_individuals) for x in std]

lower = [m - s for m, s in zip(avg, se)]
upper = [m + s for m, s in zip(avg, se)]

fig, ax = plt.subplots()
ax.plot(gen, avg, color = 'blue', label = 'mean')
ax.fill_between(gen, lower, upper, alpha=0.2)
ax.plot(gen, maxi, color = 'darkblue', label = 'max', linestyle='dashed', linewidth = 0.5)
ax.plot(gen, mini, color = 'lightblue', label = 'min', linestyle='dashed', linewidth = 0.5)

ax.set_title('Fitness per generation (mean \u00b1 SE, min, max)')
ax.set_ylabel('Fitness')
ax.set_xlabel("Generation")
ax.legend(loc = 'upper left')

plt.show()

In [None]:
# Visualize Evolution: one animation frame per saved population grid.
fig = plt.figure(figsize=(8,6))
plt.axis("off")

ims = []
for grid in img_list:
    # grids are channel-first; imshow expects channels last, hence the (1, 2, 0) transpose
    frame = plt.imshow(np.transpose(grid, (1, 2, 0)), animated=True)
    ims.append([frame])

ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)

HTML(ani.to_jshtml())