In [1]:
import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.optim as optim

%matplotlib inline

In [97]:
class Memory:
    """Episode buffer filled step by step and consumed by the policy update.

    Three parallel lists are kept in time order:
      policies -- the action-probability tensor recorded at each step
      actions  -- the index of the action sampled at that step
      rewards  -- the reward received after that step
    """

    def __init__(self):
        # Start from the same empty state that _clear() restores.
        self._clear()

    def store(self, a_proba, a_index, r):
        """Append one step's probabilities, chosen action index and reward."""
        for bucket, item in ((self.policies, a_proba),
                             (self.actions, a_index),
                             (self.rewards, r)):
            bucket.append(item)

    def _clear(self):
        """Drop every stored step by rebinding the three lists to fresh empty ones."""
        self.policies, self.actions, self.rewards = [], [], []

    def withdraw(self):
        """Return the buffered episode without clearing it.

        Returns:
            (policies, actions, rewards): the stored probability tensors
            concatenated along dim 0 (presumably (t, n_actions) when each
            stored tensor is a single (1, n_actions) row -- confirm at the
            call site), the action indices converted with torch.tensor, and
            the reward list object itself (not a copy).
        """
        stacked_policies = torch.cat(self.policies, dim=0)
        action_tensor = torch.tensor(self.actions)
        return stacked_policies, action_tensor, self.rewards

In [73]:
class PolicyGradient:
    """REINFORCE (Monte-Carlo policy gradient) agent with a small softmax policy network.

    Per-step action probabilities, sampled action indices and rewards are
    buffered in ``self.memory`` via ``store``; ``update_policy`` consumes the
    buffered episode, performs a single Adam step and empties the buffer.

    Args:
        state_dim: Length of the observation vector fed to the network.
        n_actions: Number of discrete actions.
        gamma: Discount factor used when computing returns.
        learning_rate: Adam learning rate.
    """

    # Lower bound applied to the chosen-action probability before ``log`` so a
    # saturated softmax cannot turn the loss into -inf / NaN.
    _MIN_PROBA = 1e-8

    def __init__(self, state_dim, n_actions, gamma, learning_rate):
        self.state_dim = state_dim
        self.n_actions = n_actions
        self.gamma = gamma
        self.learning_rate = learning_rate
        self.memory = Memory()
        self.network = self._build()
        self.optimizer = optim.Adam(self.network.parameters(), lr=self.learning_rate)

    def _build(self):
        """Return the policy network: state_dim -> 16 (ReLU) -> n_actions (softmax)."""
        network = torch.nn.Sequential(
            nn.Linear(self.state_dim, 16),
            nn.ReLU(inplace=True),
            nn.Linear(16, self.n_actions),
            nn.Softmax(dim=-1)
        )
        return network

    def _calculate_reward(self, rewards):
        """Compute discounted returns G_t = r_t + gamma * G_{t+1}.

        Args:
            rewards: Sequence of per-step rewards in time order.

        Returns:
            float32 tensor of shape (t, 1) holding the return of each step.
        """
        returns = []
        running = 0.0
        for r in reversed(rewards):
            running = float(r) + self.gamma * running
            returns.append(running)
        returns.reverse()  # O(t), instead of repeated insert(0, ...)
        # Explicit float32 so numpy/int rewards cannot change the dtype.
        return torch.tensor(returns, dtype=torch.float32).unsqueeze(1)  # (t, 1)

    def _one_hot_action(self, actions):
        """Return a (t, n_actions) float mask with 1 at each step's action index.

        Args:
            actions: 1-D int64 tensor of action indices (scatter_ requires int64).
        """
        t = actions.size(0)
        actions = torch.unsqueeze(actions, dim=1)  # (t, 1)
        one_hot_actions = torch.zeros((t, self.n_actions), requires_grad=False)  # (t, n_actions)
        one_hot_actions.scatter_(dim=1, index=actions, value=1)

        return one_hot_actions

    def _clear(self):
        """Empty the episode buffer."""
        self.memory._clear()

    def store(self, a_proba, a_index, r):
        """Record one step (probabilities, sampled action index, reward) in the buffer."""
        self.memory.store(a_proba, a_index, r)

    def choose_action(self, s):
        """Sample an action from the current policy for observation ``s``.

        Returns:
            (a_proba, a): ``a_proba`` is the (1, n_actions) probability tensor,
            still attached to the autograd graph so it can be stored and
            back-propagated later; ``a`` is the sampled action index from
            ``np.random.choice``.
        """
        s = torch.tensor(s, dtype=torch.float32)
        s = torch.unsqueeze(s, dim=0)  # add batch dimension
        a_proba = self.network(s)
        p = a_proba.detach().numpy().reshape(-1)
        a = np.random.choice(self.n_actions, p=p)

        return a_proba, a

    def update_policy(self):
        """Run one REINFORCE update on the buffered episode, then clear the buffer.

        If nothing was stored, only the (empty) buffer is cleared. The buffer
        is cleared even when the update raises. Otherwise tensors whose graph
        buffers were already freed by ``backward()`` would stay behind, and the
        next call would fail with "backward through the graph a second time".
        """
        try:
            if not self.memory.rewards:
                return
            policies, actions, rewards = self.memory.withdraw()

            returns = self._calculate_reward(rewards)  # (t, 1)
            if returns.size(0) > 1:
                # Standardize returns; the std of a single sample is undefined (NaN).
                returns = (returns - returns.mean()) / (returns.std() + 1e-8)

            # Probability of the action actually taken at each step, (t, 1).
            # Selecting before the log avoids 0 * log(0) = NaN on unchosen actions.
            chosen = policies.gather(1, actions.long().view(-1, 1))
            log_proba = torch.log(chosen.clamp(min=self._MIN_PROBA))
            loss = -torch.sum(returns * log_proba)

            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
        finally:
            self._clear()

In [13]:
import gym

In [14]:
# Gym environment used for training.
ENV_ID = 'MountainCar-v0'
env = gym.make(ENV_ID)
# Initial observation; its length sizes the policy network's input below.
state = env.reset()

[33mWARN: gym.spaces.Box autodetected dtype as <class 'numpy.float32'>. Please provide explicit dtype.[0m


In [101]:
# REINFORCE agent sized from the environment's observation and action spaces.
actor = PolicyGradient(
    state_dim=len(state),          # observation length from the reset state
    n_actions=env.action_space.n,  # number of discrete actions
    gamma=0.9,                     # discount factor for returns
    learning_rate=0.001,           # Adam step size
)

In [102]:
# Training configuration. Rendering every frame is slow and needs a display;
# set RENDER = False for headless or faster runs.
N_EPISODES = 500
RENDER = True

step = 0        # environment steps across all episodes
rewards = []    # undiscounted total reward of each episode

for episode in range(N_EPISODES):
    period = 0          # environment steps in the current episode
    total_reward = 0

    state = env.reset()

    while True:
        if RENDER:
            env.render()

        action_proba, action = actor.choose_action(state)

        state_next, reward, done, info = env.step(action)

        total_reward += reward

        actor.store(action_proba, action, reward)

        state = state_next

        # Count the step before the done check so the terminal transition is included.
        step += 1
        period += 1

        if done:
            # One policy-gradient update per completed episode.
            actor.update_policy()
            break

    print('Episode: {:3d}, reward: {:3d}'.format(episode, int(total_reward)))
    rewards.append(total_reward)

print('game over')

Episode:   0, reward: -200


RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.