In [1]:
import numpy as np
import os
import torch
import torch.nn
import torch.multiprocessing as mp
import matplotlib.pyplot as plt
from rollout import Rollout

# All models in this notebook are placed on the CUDA device.
device = torch.device('cuda')
# Worker processes are started with 'spawn'. force=True keeps this cell
# re-runnable: without it, executing the cell a second time in the same kernel
# raises "RuntimeError: context has already been set".
mp.set_start_method('spawn', force=True)

In [2]:
device = torch.device('cuda')


def _load_shared_model(path):
    """Load a pickled model from `path` and prepare it for use by worker processes.

    The model is moved to `device`, its storage is put in shared memory so it
    can be handed to processes of a multiprocessing pool, and it is switched to
    evaluation mode.

    NOTE(security): torch.load unpickles the file, which can execute arbitrary
    code. Only load checkpoints from a trusted source.
    """
    # map_location lets a checkpoint saved on another device (e.g. a different
    # GPU index) load here instead of failing during deserialization.
    model = torch.load(path, map_location=device)
    model.to(device)
    model.share_memory()
    model.eval()
    return model


encoder = _load_shared_model("models/encoder-five-epochs.pt")
rnn = _load_shared_model("models/mdnrnn-wo-training.pt")
cont = _load_shared_model("models/controller-best-wo-training.pt")

# Tag the controller instance; the training loop asserts on this tag to make
# sure it is still working with this module-level object.
cont.__name__ = 'global'



In [3]:
# Best reward seen so far across all generations. Kept in its own cell so that
# re-running the training loop cell does not reset it.
# Starts at -inf rather than 0: with a baseline of 0, no controller would ever
# be saved as "best" if every candidate's reward were negative.
curr_best_score = float("-inf")
curr_mean_score = 0

In [4]:
from collections import OrderedDict 

# Number of CMA-ES generations the training loop iterates over (range() bound).
NUM_GENERATION = 250
# Presumably the number of rollouts whose rewards are averaged for one candidate
# solution; the value matches the 4 rollouts run per solution in
# get_reward_for_solution -- TODO confirm it is meant to drive that count.
NUM_AVERAGE_REWARD_OVER = 4

def param2numpy(model):
    """Flatten the first two parameters of `model` into one 1-D NumPy array.

    The first parameter (the weight matrix) is flattened row-major and the
    second parameter (the bias vector) is appended after it. The result is
    detached from autograd and copied to host memory.
    """
    param_iter = iter(model.parameters())
    weight = next(param_iter)  # e.g. shape 3x288 for the controller
    bias = next(param_iter)    # e.g. shape 3 for the controller
    pieces = (weight.detach().reshape(-1), bias.detach())
    joined = torch.cat(pieces, dim=-1)
    return joined.cpu().numpy()

def load_param(model, params):
    """Load a flat parameter vector into the `fc` layer of `model`.

    This is the inverse of `param2numpy`: the leading entries of `params` fill
    `fc.weight` (row-major) and the remaining entries fill `fc.bias`.

    The split point, target shapes, dtype and device are taken from the model's
    current `fc.weight` / `fc.bias` instead of being hard-coded, so the function
    works for any layer size and on whichever device the model lives.

    Args:
        model: module whose state dict consists of "fc.weight" and "fc.bias".
        params: 1-D array-like of length fc.weight.numel() + fc.bias.numel().

    Raises:
        ValueError: if `params` does not have exactly the expected length.
    """
    state = model.state_dict()
    weight_ref, bias_ref = state["fc.weight"], state["fc.bias"]
    n_weight = weight_ref.numel()
    n_expected = n_weight + bias_ref.numel()

    flat = torch.tensor(params, dtype=weight_ref.dtype, device=weight_ref.device).flatten()
    if flat.numel() != n_expected:
        raise ValueError(
            "expected {} parameters, got {}".format(n_expected, flat.numel())
        )

    updated_dict = OrderedDict([
        ("fc.weight", flat[:n_weight].reshape(weight_ref.shape)),
        ("fc.bias", flat[n_weight:].reshape(bias_ref.shape)),
    ])
    model.load_state_dict(updated_dict)

def get_reward_for_solution(soln):
    """Score one candidate parameter vector; lower return values are better.

    The candidate is loaded into the module-level controller `cont`, which is
    left holding these parameters afterwards. `Rollout` is then run
    NUM_AVERAGE_REWARD_OVER times in parallel worker processes and the rewards
    are averaged.

    Args:
        soln: flat parameter vector accepted by `load_param`.

    Returns:
        The negated mean reward as a Python float. It is negated because the
        optimizer minimizes its objective.
    """
    load_param(cont, soln)

    n_rollouts = NUM_AVERAGE_REWARD_OVER
    jobs = [(encoder, rnn, cont, False) for _ in range(n_rollouts)]

    pool = mp.Pool(processes=n_rollouts)
    try:
        rewards = pool.starmap(Rollout, jobs)
        pool.close()
    except BaseException:
        # Includes KeyboardInterrupt: stop the workers instead of leaking them.
        pool.terminate()
        raise
    finally:
        # Wait for the worker processes to exit so they do not pile up across
        # the many calls made per generation.
        pool.join()

    return -np.mean(rewards).item()

In [6]:
import cma
# Start the CMA-ES search from the controller's current parameters, with an
# initial step size (sigma) of 0.01.
es = cma.CMAEvolutionStrategy(
    x0=param2numpy(cont),
    sigma0=0.01,
)

(12_w,24)-aCMA-ES (mu_w=7.0,w_1=24%) in dimension 867 (seed=464670, Thu Dec 26 11:52:14 2019)


In [7]:
# CMA-ES training loop: sample candidates, score each one, checkpoint the
# controller, then update the search distribution.
for generation in range(NUM_GENERATION):
    # Guard that `cont` is still the tagged module-level controller.
    assert cont.__name__ == 'global'

    print("Generation:: {}..".format(generation + 1))
    solutions = es.ask(24)

    # Objective values are negated mean rewards (lower is better).
    function_values = [get_reward_for_solution(s) for s in solutions]

    # get_reward_for_solution overwrites `cont` with every candidate it scores,
    # so after the line above `cont` holds the LAST candidate, not the best.
    # Reload this generation's best candidate so the checkpoints written below
    # contain the parameters that achieved the reported reward.
    best_index = int(np.argmin(function_values))
    generation_best_score = -function_values[best_index]
    load_param(cont, solutions[best_index])

    if generation_best_score > curr_best_score:
        curr_best_score = generation_best_score
        torch.save(cont, "models/controller-best.pt")
        print("Best saved with reward {}".format(generation_best_score))

    # Periodic checkpoint of the generation's best candidate.
    if generation % 10 == 0:
        torch.save(cont, "models/controller-generation-{}".format(generation))
        print("Model saved as: models/controller-generation-{}".format(generation))

    es.tell(solutions, function_values)
    es.logger.add()
    es.disp()

# After the loop `cont` holds the best candidate of the final generation.
torch.save(cont, "models/controller-generation-last")

Generation:: 1..
Best saved with reward 535.9256078198114
Model saved as: models/controller-generation-0
Iterat #Fevals   function value  axis ratio  sigma  min&max std  t[m:s]
    1     24 -5.359256078198114e+02 1.0e+00 9.91e-03  1e-02  1e-02 8:30.9
Generation:: 2..
Best saved with reward 545.8722922098467
    2     48 -5.458722922098467e+02 1.0e+00 9.83e-03  1e-02  1e-02 16:55.4
Generation:: 3..
    3     72 -5.097833143133718e+02 1.0e+00 9.76e-03  1e-02  1e-02 25:06.7
Generation:: 4..
Best saved with reward 611.6602876683748
    4     96 -6.116602876683748e+02 1.0e+00 9.69e-03  1e-02  1e-02 33:36.6
Generation:: 5..
Best saved with reward 613.4733060113632
    5    120 -6.134733060113632e+02 1.0e+00 9.62e-03  1e-02  1e-02 42:00.0
Generation:: 6..
Best saved with reward 615.5078399969178
    6    144 -6.155078399969178e+02 1.0e+00 9.55e-03  1e-02  1e-02 50:12.5
Generation:: 7..
Best saved with reward 707.8666333721158
    7    168 -7.078666333721158e+02 1.0e+00 9.49e-03  9e-03  9e-03 

KeyboardInterrupt: 