In [None]:
import os
import random
import torch
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torchvision.utils as vutils
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from IPython.display import HTML
from numpy.random import randint

# save images
from torchvision.utils import save_image
# Directory that receives one PNG batch per random-walk step; created if missing.
img_save_path = './plots/'
os.makedirs(name=img_save_path, exist_ok=True)

from CAN.parameters import *
from CAN.dataloader_wikiart import *
from CAN.model_CAN_16_9 import * 

In [None]:
### setup

# Fixed seed so latent vectors and mutations are reproducible across runs
# (replace the constant with random.randint(1, 10000) to explore new seeds).
manualSeed = 3
print("Random Seed: ", manualSeed)
random.seed(manualSeed)
torch.manual_seed(manualSeed)
np.random.seed(manualSeed)


# Use the first GPU only when CUDA is present and the config allows GPUs (ngpu > 0).
if torch.cuda.is_available() and ngpu > 0:
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

In [None]:
### load model

# Checkpoint containing 'model_state_dict' and 'optimizer_state_dict' entries
PATH_GEN = './CAN/models/GEN_16_9.pth'

# Generator placed on the selected device
# (ngpu, lr and beta1 are expected to come from the CAN.* star imports — TODO confirm)
netG = Generator(ngpu).to(device)

# Adam optimizer for G, rebuilt so that its saved state can be restored too
optimizerG = optim.Adam(netG.parameters(), lr=lr, betas=(beta1, 0.999))

# Deserialize onto the CPU; load_state_dict then copies values into the existing parameters
checkpoint = torch.load(PATH_GEN, map_location='cpu')
netG.load_state_dict(checkpoint['model_state_dict'])
optimizerG.load_state_dict(checkpoint['optimizer_state_dict'])

# Switch the generator to evaluation mode (the returned module is the cell's output)
netG.eval()

In [None]:
## initial population

# 64 latent vectors of length 100, shaped (64, 100, 1, 1) so they can be fed straight to netG
genes = torch.randn((64, 100, 1, 1), device=device)

In [None]:
## random walk of little gaussian mutations

def random_walk(genes, iterations = 1000, std = 0.05, indpb = 0.5):
    """Evolve latent vectors by repeated Gaussian mutation and render every step.

    Starting from ``genes``, each iteration mutates every latent entry
    independently with probability ``indpb`` by adding noise drawn from
    N(0, std**2). The starting population and the population after every
    step are rendered by the generator ``netG``, saved to ``img_save_path``
    as ``random_walk_<step>.png``, and collected as image grids.

    Args:
        genes: Batch of latent vectors passed straight to ``netG`` (created
            as shape (64, 100, 1, 1) in this notebook). It is NOT modified:
            the walk runs on a copy, so re-running the calling cell restarts
            from the same initial genes.
        iterations: Number of mutation steps.
        std: Standard deviation of the Gaussian mutation.
        indpb: Independent probability that a single latent entry is mutated
            in a given step.

    Returns:
        List of ``iterations + 1`` image grids (CPU tensors produced by
        ``vutils.make_grid``); the first grid shows the unmutated genes.
    """
    img_list = []

    def render(population, step):
        # Generate images for the current population, keep a grid for the
        # animation and write the whole batch to disk.
        fake = netG(population).cpu()
        img_list.append(vutils.make_grid(fake, padding=2, normalize=True))
        save_image(fake, os.path.join(img_save_path, 'random_walk_%d.png' % step), normalize=True)

    with torch.no_grad():
        # Copy so the caller's tensor is left untouched.
        walk = genes.clone()
        render(walk, 0)

        for step in range(1, iterations + 1):
            # Vectorised per-entry mutation: each entry is selected with
            # probability indpb (not 1 - indpb), and the noise is created on
            # the same device/dtype as the genes, so this also works on CUDA.
            mask = torch.rand_like(walk) < indpb
            walk += mask * torch.randn_like(walk) * std
            render(walk, step)

    return img_list

In [None]:
# Walk 100 steps from the initial genes; img_list holds the starting grid plus one grid per step.
img_list = random_walk(genes, iterations=100)

In [None]:
# Visualize Random Walk
fig, ax = plt.subplots(figsize=(8, 6))
ax.axis("off")
# img_list grids are channel-first (C, H, W); imshow expects channel-last (H, W, C).
ims = [[ax.imshow(grid.permute(1, 2, 0).numpy(), animated=True)] for grid in img_list]
ani = animation.ArtistAnimation(fig, ims, interval=1000, repeat_delay=1000, blit=True)
# Close the figure so Jupyter does not additionally render a static copy of the last frame;
# the animation still holds the figure and can render it to HTML.
plt.close(fig)

HTML(ani.to_jshtml())