In [None]:
def show_state(env, step=0, name="", info="", fig_num=10):
    """Render the current frame of ``env`` inline in a notebook.

    Clears the previous cell output and redraws one reusable figure, so
    repeated calls from a game loop replace the frame instead of stacking
    new figures.

    Args:
        env: environment whose ``render(mode="rgb_array")`` result is drawn
            with ``plt.imshow`` (presumably an HxWx3 image array).
        step: step counter shown in the title.
        name: label shown at the start of the title.
        info: extra free-form text appended to the title.
        fig_num: matplotlib figure number reused between calls (default 10,
            the value that used to be hard-coded).
    """
    # Imported locally so the function does not depend on which notebook
    # cells have already been executed (the notebook imports these later).
    import matplotlib.pyplot as plt
    from IPython import display

    plt.figure(fig_num)
    plt.clf()
    plt.imshow(env.render(mode="rgb_array"))
    plt.title("{} | Step: {} {}".format(name, step, info))
    plt.axis('off')
    # wait=True keeps the old frame until the new output arrives (less flicker).
    display.clear_output(wait=True)
    display.display(plt.gcf())

In [None]:
import torch
from scipy import optimize
import torch.nn.functional as F
import math
import numpy as np
from functools import reduce
from collections import OrderedDict

class PyTorchObjective(object):
    """Expose a PyTorch agent's network as a flat-vector objective for scipy-style optimizers.

    The network's parameters are flattened into one 1-D numpy vector (``x0``).
    ``fun(x)`` writes a candidate vector back into the network and returns
    ``agent.evaluate()``. The last result is cached, so consecutive calls with
    the same vector evaluate the agent only once.

    Attributes:
        f: the agent's ``torch.nn.Module`` (``agent.nn``); it is overwritten in
            place whenever a new candidate is evaluated.
        param_shapes: mapping ``name -> torch.Size`` in ``named_parameters`` order.
        x0: 1-D numpy array with the network's initial parameter values.
        eval_fn: zero-argument callable (``agent.evaluate``) giving the objective.
        c: number of ``fun`` calls so far (cache hits included).
    """

    def __init__(self, agent):
        # The module whose weights the optimizer searches over.
        self.f = agent.nn
        parameters = OrderedDict(agent.nn.named_parameters())
        self.param_shapes = {n: p.size() for n, p in parameters.items()}
        # Ravel and concatenate all parameters to make x0. detach()/cpu() keeps
        # this working for tensors that require grad or live on a GPU.
        self.x0 = np.concatenate([p.detach().cpu().numpy().ravel()
                                  for p in parameters.values()])
        self.eval_fn = agent.evaluate
        self.c = 0

    def unpack_parameters(self, x):
        """Split a flat parameter vector into per-parameter torch tensors.

        Args:
            x: array-like with exactly one value per parameter entry (the
                same layout as ``x0``), as supplied by the optimizer.

        Returns:
            OrderedDict ``name -> tensor`` shaped like each parameter, ready
            for ``self.f.load_state_dict``.

        Raises:
            ValueError: if ``x`` has the wrong number of values.
        """
        x = np.asarray(x).ravel()
        # Initial value 1 makes 0-d (scalar) parameters, whose shape is
        # empty, count as a single element instead of crashing reduce().
        sizes = [reduce(lambda a, b: a * b, shape, 1)
                 for shape in self.param_shapes.values()]
        expected = sum(sizes)
        if x.size != expected:
            raise ValueError("expected {} parameter values, got {}".format(expected, x.size))

        named_parameters = OrderedDict()
        i = 0
        for (n, shape), size in zip(self.param_shapes.items(), sizes):
            # Pass the shape as a tuple so empty (0-d) shapes also work.
            param = x[i:i + size].reshape(tuple(shape))
            named_parameters[n] = torch.from_numpy(param)
            i += size
        return named_parameters

    def pack_grads(self):
        """Concatenate the gradients of all module parameters into one 1-D numpy array.

        Parameters with ``p.grad is None`` (no ``backward()`` yet, or after
        ``zero_grad`` in PyTorch versions that reset grads to None) contribute
        zeros, so the result always has the same length as ``x0``.
        """
        grads = []
        for p in self.f.parameters():
            grad = p.grad if p.grad is not None else torch.zeros_like(p)
            grads.append(grad.detach().cpu().numpy().ravel())
        return np.concatenate(grads)

    def is_new(self, x):
        """Return True if ``x`` differs from the last vector passed to ``cache``.

        Vectors with the same shape and every element within 1e-8 count as
        equal. A vector with a different shape is always new.
        """
        if not hasattr(self, 'cached_x'):
            # Nothing has been evaluated yet.
            return True
        x = np.asarray(x)
        cached = np.asarray(self.cached_x)
        if x.shape != cached.shape:
            return True
        if x.size == 0:
            return False
        return np.abs(x - cached).max() > 1e-8

    def cache(self, x):
        """Load ``x`` into the module, evaluate the agent, and remember the result.

        Returns:
            Whatever ``eval_fn`` returns for these parameters.
        """
        state_dict = self.unpack_parameters(x)
        self.f.load_state_dict(state_dict)
        # Keep a private copy. Holding a reference to the caller's array
        # would make is_new miss later in-place changes to that array.
        self.cached_x = np.array(x, copy=True)
        self.f.zero_grad()
        obj = self.eval_fn()
        # Backprop stays disabled: the objective is used gradient-free here.
        # obj.backward()
        self.cached_f = obj
        return obj

    def fun(self, x):
        """Objective callable for optimizers: the agent's evaluation at parameters ``x``.

        Re-evaluates only when ``x`` differs from the previously evaluated
        vector; otherwise returns the cached value. Every call increments ``c``.
        """
        self.c += 1
        if self.is_new(x):
            self.cache(x)
        return self.cached_f

In [None]:
import gym_gvgai

In [None]:
from agent.NNagent import NNagent

In [None]:
from generator.env_gen_wrapper import GridGame

In [None]:
from scipy.optimize import Bounds

In [None]:
_x = NNagent(
    GridGame(
        game='zelda',
        play_length=200,
        path=gym_gvgai.dir + '/envs/games/zelda_v0/',
        lvl_name='zelda_lvl0.txt',
        # presumably '1'/'2'/'3' = monster types, '+' = key, 'g' = door/goal,
        # 'w' = wall — TODO confirm against the zelda_v0 game definition
        mechanics=['1', '2', '3', '+', 'g', 'w'],
    )
)

In [None]:
# Display the agent wrapper constructed above.
_x

In [None]:
# Inspect the torch module whose parameters PyTorchObjective flattens and optimizes.
_x.nn

In [None]:
# Flat-vector view of the agent's network for the optimizer: z.x0 holds the
# initial parameters, and z.fun(x) loads x into _x.nn and returns
# _x.evaluate() (cached when the same x is passed again).
z = PyTorchObjective(_x)

In [None]:
# Dimensionality of the search space: x0 is the 1-D concatenation of all network parameters.
z.x0.shape


In [None]:
import matplotlib.pyplot as plt


In [None]:
# Reset the agent's environment. The returned value (presumably the initial
# observation/level tile — confirm in NNagent.reset) feeds get_action next.
tile = _x.reset()

In [None]:
# Smoke test: query the not-yet-optimized network for an action on the reset state.
o = _x.get_action(tile)

In [None]:
# Display the action returned by get_action.
o

In [None]:
# Box constraints for the optimizer: each flattened parameter lies in [-1, 1].
bounds = [(-1, 1) for _ in range(z.x0.shape[0])]

In [None]:
# Random part of the initial DE population: 24 candidates, stacked under the
# network's current weights (z.x0) in the next cell. Drawn uniformly inside
# `bounds` instead of np.random.randn, whose N(0, 1) samples land outside
# [-1, 1] about 32% of the time (scipy's differential_evolution clips such an
# `init` to the bounds; NOTE(review): confirm utils.diff_evo does the same).
# Seeded so the optimisation run is reproducible.
rng = np.random.RandomState(0)
_lo, _hi = np.asarray(bounds, dtype=float).T
pop = rng.uniform(_lo, _hi, size=(24, z.x0.shape[0]))

In [None]:
# Initial population matrix: first row is the network's current parameters
# (z.x0), followed by the random rows in `pop`.
q = np.concatenate([z.x0[np.newaxis, :], pop], axis=0)

In [None]:
from utils.diff_evo import differential_evolution

In [None]:
# Optimise the flattened weights with differential evolution, starting from
# the population matrix `q` (current weights plus random rows).
# NOTE(review): in scipy.optimize.differential_evolution `popsize` is a
# multiplier of the problem dimension and is ignored when `init` is an array;
# utils.diff_evo is a local copy — confirm its semantics match.
ans = differential_evolution(
    z.fun,
    bounds,
    init=q,
    polish=False,
    popsize=q.shape[0],
    strategy='rand2exp',
)

In [None]:
# Inspect the optimizer result; ans.x is the parameter vector loaded into the network below.
ans

In [None]:
# Write the optimized flat vector back into the network. z.f is the agent's
# own module (PyTorchObjective stores agent.nn), so _x now uses these weights.
state_dict = z.unpack_parameters(ans.x)
z.f.load_state_dict(state_dict)

In [None]:
# Sanity check: the objective wraps the agent's own network object (not a
# copy), so the weights loaded above are the ones the agent plays with.
# torch.nn.Module does not define __eq__, so `==` was an identity test in
# disguise whose result was only displayed; use `is` and enforce it.
assert z.f is _x.nn, "PyTorchObjective must wrap the agent's network, not a copy"

In [None]:
import matplotlib.pyplot as plt
from IPython import display


In [None]:
# Run the agent's fitness evaluation with show_state as the `fn` callback so
# frames are drawn inline. Presumably NNagent.fitness calls fn(env, step, ...)
# during play — confirm against its implementation.
_x.fitness(fn=show_state)