In [21]:
import gym

import numpy as np
import torch
from torch import nn
from torch.distributions import Categorical
import matplotlib.pyplot as plt

In [22]:
env = gym.make('CartPole-v0')

In [23]:
env.action_space.n

2

In [24]:
env.observation_space

Box([-4.8000002e+00 -3.4028235e+38 -4.1887903e-01 -3.4028235e+38], [4.8000002e+00 3.4028235e+38 4.1887903e-01 3.4028235e+38], (4,), float32)

In [25]:
state_shape, action_shape = env.observation_space.shape, (2,)

In [26]:
env.action_space

Discrete(2)

In [51]:
class Net(nn.Module):
    def __init__(self, state_shape, action_shape):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(np.prod(state_shape), 16), nn.ReLU(inplace=True),
            nn.Linear(16, 30), nn.ReLU(inplace=True),
            nn.Linear(30, 30), nn.ReLU(inplace=True),
            nn.Linear(30, np.prod(action_shape)),
        )
        
    def forward(self, obs):
        if not isinstance(obs, torch.Tensor):
            obs = torch.tensor(obs, dtype=torch.float)
        batch = obs.shape[0]
        logits = self.model(obs.view(batch, -1))
        return logits
    
    def calc_action(self, obs):
        bs = obs.shape[:-1]
        obs = obs.reshape(-1, obs.shape[-1])
        logits = self.forward(obs)
        m = Categorical(logits.softmax(-1))
        action = m.sample()
        return action.reshape(*bs, 1).numpy()

class Net(nn.Module):
    def __init__(self, state_shape, action_shape):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(np.prod(state_shape), 1),
        )
        
    def forward(self, obs):
        if not isinstance(obs, torch.Tensor):
            obs = torch.tensor(obs, dtype=torch.float)
        batch = obs.shape[0]
        logits = self.model(obs.view(batch, -1))
        return logits
    
    def calc_action(self, obs):
        # obs = obs.reshape(1, obs.shape[-1])
        logits = self.forward(obs[None])[0].item()
        action = 1 if logits>0 else 0
        return np.array(action)[None]
    
def random_policy():
    net = Net(state_shape, action_shape)
    # optim = torch.optim.Adam(net.parameters(), lr=1e-3)
    # policy = ts.policy.DQNPolicy(net, optim, discount_factor=0.9, estimation_step=3, target_update_freq=320)
    return net

In [52]:
agent = random_policy()

In [53]:
# %%time
import gym
env = gym.make('CartPole-v0')
r = 0
for i_episode in range(100):
    obs = env.reset()
    for t in range(100):
        action = agent.calc_action(obs)[0]
        obs, reward, done, info = env.step(action)
        if done:
            break
        r+=1
env.close()

In [54]:
r

1971

In [55]:
def run_agent(env, agent, n_episodes=None, n_trans=None):
    obs = env.reset()
    
    rs_eps = []
    rs = []
    
    i = 0
    while (n_episodes is not None and i<n_episodes) or (n_trans is not None and i<n_trans):
        action = agent.calc_action(obs)[0]
        obs, reward, done, info = env.step(action)
        rs.append(reward)
        
        if done:
            obs = env.reset()
            rs_eps.append(rs)
            rs = []
            
            if n_episodes is not None:
                i+=1
                
        if n_trans is not None:
            i+=1
    rs_eps.append(rs)
            
    return np.array([np.sum(rs) for rs in rs_eps]), [np.array(rs) for rs in rs_eps]
    

r, _ = run_agent(env, agent, n_trans=1000)

In [57]:
def train_es(agent, env, n_step, n_pop, n_trans):
    theta = nn.utils.parameters_to_vector(agent.parameters()).detach()
    print(theta.shape)
    sigma = 1e-1
    alpha = 1e-3
    
    for i in range(n_step):
        eps = torch.randn(n_pop, theta.shape[0])
        theta_dev = theta+sigma*eps
        rs = []
        for thetap in theta_dev:
            nn.utils.vector_to_parameters(thetap, agent.parameters())
            r, _ = run_agent(env, agent, n_episodes=100, n_trans=None)
            r = r.mean()
            rs.append(r)
        rs = torch.tensor(rs, dtype=torch.float32)
        print(rs.mean().item())
        rs = (rs-rs.mean())/rs.std()
        # print(rs.shape)
        theta += alpha/n_pop/sigma * (rs[:, None]*eps).sum(dim=0)
        # print(rs.mean())

agent = random_policy()
train_es(agent, env, 100, 50, 1000)

torch.Size([5])
77.57148742675781
70.43228149414062
76.85861206054688
76.88138580322266
91.90811157226562
98.66375732421875
91.12118530273438
101.0101089477539
109.3673324584961
97.78672790527344
95.45445251464844
98.8166275024414
101.55445098876953
107.7499008178711


KeyboardInterrupt: 


$\sigma \epsilon$

