In [1]:
import torch
import torch.nn as nn
import torch.autograd as autograd
import numpy as np

from pr_replay_buffer import PrioritizedBuffer
from network import ConvDQN, DQN

def mini_batch_train(env, agent, max_episodes, max_steps, batch_size):
    episode_rewards = []
    episodes=[]
    for episode in range(max_episodes):
        state = env.reset()
        episode_reward = 0
        episodes.append(episode)
        for step in range(max_steps):
            action = agent.get_action(state)
            next_state, reward, done, _ = env.step(action)
            agent.replay_buffer.push(state, action, reward, next_state, done)
            episode_reward += reward

            if len(agent.replay_buffer) > batch_size:
                agent.update(batch_size)   

            if done or step == max_steps-1:
                episode_rewards.append(episode_reward)
                print("Episode " + str(episode) + ": " + str(episode_reward))
                break
            env.render()
            state = next_state

    return episode_rewards,episodes
class PERAgent:

    def __init__(self, env, use_conv=True, learning_rate=3e-4, gamma=0.99, buffer_size=10000):
        self.env = env
        self.gamma = gamma
        self.replay_buffer = PrioritizedBuffer(buffer_size)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")        

        if use_conv:
            self.model = ConvDQN(self.env.observation_space.shape, env.action_space.n).to(self.device)
        else:
            self.model = DQN(self.env.observation_space.shape, env.action_space.n).to(self.device)
          
        self.optimizer = torch.optim.Adam(self.model.parameters())
        self.MSE_loss = nn.MSELoss()

    def get_action(self, state, eps=0.0):
        state = torch.FloatTensor(state).float().unsqueeze(0).to(self.device)
        qvals = self.model.forward(state)
        action = np.argmax(qvals.cpu().detach().numpy())
        
        if(np.random.rand() > eps):
            return self.env.action_space.sample()
          
        return action

    def _sample(self, batch_size):
        return self.replay_buffer.sample(batch_size)

    def _compute_TDerror(self, batch_size):
        transitions, idxs, IS_weights = self._sample(batch_size)
        states, actions, rewards, next_states, dones = transitions

        states = torch.FloatTensor(states).to(self.device)
        actions = torch.LongTensor(actions).to(self.device)
        rewards = torch.FloatTensor(rewards).to(self.device)
        next_states = torch.FloatTensor(next_states).to(self.device)
        dones = torch.FloatTensor(dones).to(self.device)
        IS_weights = torch.FloatTensor(IS_weights).to(self.device)

        curr_Q = self.model.forward(states).gather(1, actions.unsqueeze(1))
        curr_Q = curr_Q.squeeze(1)
        next_Q = self.model.forward(next_states)
        max_next_Q = torch.max(next_Q, 1)[0]
        expected_Q = rewards.squeeze(1) + self.gamma * max_next_Q

        td_errors = torch.pow(curr_Q - expected_Q, 2) * IS_weights

        return td_errors, idxs

    def update(self, batch_size):
        td_errors, idxs = self._compute_TDerror(batch_size)

        # update model
        td_errors_mean = td_errors.mean()
        self.optimizer.zero_grad()
        td_errors_mean.backward()
        self.optimizer.step()

        # update priorities
        for idx, td_error in zip(idxs, td_errors.cpu().detach().numpy()):
            self.replay_buffer.update_priority(idx, td_error)

  priority = 1.0 if self.current_length is 0 else self.sum_tree.tree.max()


In [2]:
import gym


env_id = "CartPole-v0"
MAX_EPISODES = 1000
MAX_STEPS = 500
BATCH_SIZE = 32

env = gym.make(env_id)
agent = PERAgent(env, use_conv=False)
episode_rewards = mini_batch_train(env, agent, MAX_EPISODES, MAX_STEPS, BATCH_SIZE)

Episode 0: 60.0
Episode 1: 15.0
Episode 2: 32.0
Episode 3: 25.0
Episode 4: 29.0
Episode 5: 21.0
Episode 6: 15.0
Episode 7: 30.0
Episode 8: 33.0
Episode 9: 22.0
Episode 10: 10.0
Episode 11: 16.0
Episode 12: 53.0
Episode 13: 29.0
Episode 14: 9.0
Episode 15: 12.0
Episode 16: 28.0
Episode 17: 11.0
Episode 18: 30.0
Episode 19: 24.0
Episode 20: 18.0
Episode 21: 24.0
Episode 22: 37.0
Episode 23: 27.0
Episode 24: 29.0
Episode 25: 20.0
Episode 26: 35.0
Episode 27: 17.0
Episode 28: 10.0
Episode 29: 18.0
Episode 30: 21.0
Episode 31: 37.0
Episode 32: 34.0
Episode 33: 11.0
Episode 34: 17.0
Episode 35: 19.0
Episode 36: 14.0
Episode 37: 26.0
Episode 38: 40.0
Episode 39: 14.0
Episode 40: 15.0
Episode 41: 13.0
Episode 42: 36.0
Episode 43: 25.0
Episode 44: 31.0
Episode 45: 17.0
Episode 46: 18.0
Episode 47: 14.0
Episode 48: 22.0
Episode 49: 38.0
Episode 50: 33.0
Episode 51: 44.0
Episode 52: 14.0
Episode 53: 13.0


KeyboardInterrupt: 