# Solving Cart Pole using ES

In [1]:
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import torch
import matplotlib.pyplot as plt

In [2]:
# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
# NOTE(review): none of the visible cells move tensors to `device` — confirm
# whether it is used elsewhere.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [3]:
device

device(type='cuda', index=0)

In [4]:
def model(x, unpacked_params):
    """Run a three-layer feed-forward network on `x`.

    Args:
        x: Input tensor; it must be compatible with `F.linear` and the first
            weight matrix.
        unpacked_params: Sequence of exactly six tensors ordered as
            (weight, bias) for layer 1, layer 2 and layer 3.

    Returns:
        The output of the third linear layer. ReLU is applied after the first
        two layers only; the final output is left un-normalised.
    """
    w1, b1, w2, b2, w_out, b_out = unpacked_params

    hidden = x
    for weight, bias in ((w1, b1), (w2, b2)):
        hidden = F.linear(hidden, weight, bias).relu()

    return F.linear(hidden, w_out, b_out)


# def model(x,unpacked_params):
#     l1,b1,l2,b2,l3,b3 = unpacked_params #A
#     y = torch.nn.functional.linear(x,l1,b1) #B
#     y = torch.relu(y) #C
#     y = torch.nn.functional.linear(y,l2,b2)
#     y = torch.relu(y)
#     y = torch.nn.functional.linear(y,l3,b3)
#     y = torch.log_softmax(y,dim=0) #D
#     return y


In [5]:
def unpack_params(params, layers=((25, 4), (10, 25), (2, 10))):
    """Split a flat parameter vector into per-layer weight and bias tensors.

    The vector is consumed front to back: for every `(out_dim, in_dim)` entry
    in `layers`, the next `out_dim * in_dim` values become the weight matrix
    and the following `out_dim` values become the bias.

    Args:
        params: 1-D tensor holding all parameters back to back.
        layers: Iterable of `(out_dim, in_dim)` weight shapes, one per layer.
            The default is an immutable tuple so it cannot be shared and
            mutated across calls.

    Returns:
        A list `[w1, b1, w2, b2, ...]` of views into `params`.

    Raises:
        ValueError: If `params` has fewer elements than `layers` requires.
            Without this check the last bias slice would be silently truncated.
    """
    layers = [tuple(shape) for shape in layers]
    required = sum(int(np.prod(shape)) + shape[0] for shape in layers)
    if params.shape[0] < required:
        raise ValueError(
            f'params has {params.shape[0]} elements but layers need {required}'
        )

    unpacked_params = []
    end = 0
    for shape in layers:
        start, end = end, end + int(np.prod(shape))
        weights = params[start:end].view(shape)
        # The bias has one entry per output unit, i.e. per weight-matrix row.
        start, end = end, end + shape[0]
        bias = params[start:end]
        unpacked_params.extend([weights, bias])
    return unpacked_params

In [6]:
def spawn_population(N=50, size=407):
    """Create the initial population of agents.

    Args:
        N: Number of agents to create.
        size: Length of each agent's flat parameter vector.

    Returns:
        A list of `N` dicts. Each holds a `'params'` tensor of standard-normal
        samples divided by 2, and a `'fitness'` of 0.
    """
    return [
        {'params': torch.randn(size) / 2.0, 'fitness': 0}
        for _ in range(N)
    ]
    
def recombine(x1, x2):
    """Produce two children from two parents by single-point crossover.

    A split index is drawn uniformly from `[0, len(params))`. The first child
    takes `x1`'s parameters before the split and `x2`'s after it; the second
    child is the mirror image.

    The children are built with `torch.cat`, so they keep the parents' dtype
    and device rather than always being default-dtype CPU tensors. They are
    fresh tensors, so mutating a child in place never touches a parent.

    Args:
        x1: First parent, a dict with a 1-D `'params'` tensor.
        x2: Second parent, same layout as `x1`.

    Returns:
        A tuple `(c1, c2)` of child dicts, each with `'params'` and a
        `'fitness'` reset to 0.
    """
    genes_1 = x1['params']
    genes_2 = x2['params']
    split_pt = np.random.randint(genes_1.shape[0])

    child1 = torch.cat((genes_1[:split_pt], genes_2[split_pt:]))
    child2 = torch.cat((genes_2[:split_pt], genes_1[split_pt:]))

    c1 = {'params': child1, 'fitness': 0}
    c2 = {'params': child2, 'fitness': 0}
    return c1, c2

def mutate(x, rate=0.01):
    """Randomly overwrite a fraction of an agent's parameters.

    `int(rate * len(params))` positions are drawn with `np.random.randint`,
    so the same position can be drawn more than once. Each drawn position is
    overwritten with a standard-normal sample divided by 10.

    Args:
        x: Agent dict with a 1-D `'params'` tensor. The tensor is modified
            in place.
        rate: Fraction of the parameter vector to overwrite.

    Returns:
        The same dict `x` that was passed in.
    """
    genome = x['params']
    genome_len = genome.shape[0]
    n_mutations = int(rate*genome_len)

    positions = np.random.randint(low=0, high=genome_len, size=(n_mutations,))
    noise = torch.randn(n_mutations) / 10.0
    genome[positions] = noise

    x['params'] = genome
    return x
    

In [7]:
import gym

# Single module-level CartPole environment; test_model() resets and steps this
# same instance for every rollout.
env = gym.make('CartPole-v0')

def test_model(agent):
    """Play one episode in the global `env` with `agent` and return its score.

    At every step the current observation is fed through `model`, and an
    action is sampled from the categorical distribution defined by the
    resulting logits.

    This relies on the Gym API in which `env.reset()` returns only the
    observation and `env.step()` returns a 4-tuple.

    Args:
        agent: Dict whose `'params'` entry is the flat parameter vector.

    Returns:
        The sum of rewards collected until the environment reports `done`.
    """
    # The parameters do not change during an episode, so unpack them once
    # instead of on every environment step.
    params = unpack_params(agent['params'])

    state = torch.from_numpy(env.reset()).float()
    score = 0
    done = False
    while not done:
        logits = model(state, params)
        action = torch.distributions.Categorical(logits=logits).sample()
        state, reward, done, _ = env.step(action.item())
        state = torch.from_numpy(state).float()
        score += reward
    return score

def evaluate_population(pop):
    """Score every agent and report the population's mean fitness.

    Each agent is run through `test_model` once and the result is stored in
    its `'fitness'` entry (the dicts are updated in place).

    Args:
        pop: List of agent dicts.

    Returns:
        A tuple `(pop, mean_fitness)` where `pop` is the same list object.
    """
    fitness_sum = 0
    for member in pop:
        score = test_model(member)
        member['fitness'] = score
        fitness_sum += score

    return pop, fitness_sum / len(pop)

        


In [8]:
def next_generation(pop, mut_rate=0.001, tournament_size=0.2):
    """Breed a new population using tournament selection.

    A random subset of the population (the tournament) is drawn repeatedly.
    Its two fittest members are recombined and both children are mutated and
    added to the new population, until it is as large as `pop`.

    Contenders are drawn without replacement and every tournament has at
    least two of them, so two distinct parents always exist. The previous
    with-replacement draw could yield fewer than two distinct ids on small
    populations and fail with an IndexError.

    Args:
        pop: List of evaluated agent dicts, each with a `'fitness'` entry.
        mut_rate: Mutation rate forwarded to `mutate`.
        tournament_size: Fraction of the population entering each tournament.

    Returns:
        A new list of child agents with exactly `len(pop)` entries. Children
        are produced in pairs, so a surplus child from an odd population size
        is dropped rather than letting the population grow.

    Raises:
        ValueError: If `pop` contains exactly one individual, since two
            parents are needed for recombination.
    """
    lp = len(pop)
    if lp == 0:
        return []
    if lp < 2:
        raise ValueError('next_generation needs at least two individuals')

    n_contenders = min(lp, max(2, int(tournament_size * lp)))

    new_pop = []
    while len(new_pop) < lp:
        contenders = np.random.choice(lp, size=n_contenders, replace=False)
        # Rank the contenders from fittest to least fit; the top two breed.
        ranked = sorted(
            contenders, key=lambda idx: pop[idx]['fitness'], reverse=True
        )
        parent_0, parent_1 = pop[ranked[0]], pop[ranked[1]]
        children = recombine(parent_0, parent_1)
        new_pop.extend(mutate(child, rate=mut_rate) for child in children)
    return new_pop[:lp]

In [9]:
# Evolution hyper-parameters.
num_generations = 25
population_size = 500
mutation_rate = 0.01

# Mean population fitness, one entry per generation.
pop_fit = []
pop = spawn_population(N=population_size, size=407)
for i in range(num_generations):
    pop, avg_fit = evaluate_population(pop)
    print(f'Generation: {i+1}, score: {avg_fit}')
    # Record this generation's average score. The list was previously
    # appended to itself, so no fitness history was kept.
    pop_fit.append(avg_fit)
    pop = next_generation(pop, mut_rate=mutation_rate, tournament_size=0.2)


Generation: 1, score: 16.35
Generation: 2, score: 21.172
Generation: 3, score: 27.236
Generation: 4, score: 37.132
Generation: 5, score: 47.208
Generation: 6, score: 101.432
Generation: 7, score: 107.77
Generation: 8, score: 123.532
Generation: 9, score: 123.632
Generation: 10, score: 138.552
Generation: 11, score: 139.798
Generation: 12, score: 135.864
Generation: 13, score: 139.89
Generation: 14, score: 147.512
Generation: 15, score: 133.42
Generation: 16, score: 138.282
Generation: 17, score: 143.31
Generation: 18, score: 137.544
Generation: 19, score: 138.962
Generation: 20, score: 143.71
Generation: 21, score: 143.958
Generation: 22, score: 147.916
Generation: 23, score: 155.82
Generation: 24, score: 148.952
Generation: 25, score: 154.156
