
Commit

Step based DQN
ShangtongZhang committed Jun 25, 2018
1 parent f8bbdfe commit aa5c265
Showing 9 changed files with 156 additions and 159 deletions.
42 changes: 8 additions & 34 deletions deep_rl/agent/BaseAgent.py
@@ -11,16 +11,10 @@
class BaseAgent:
def __init__(self, config):
self.config = config
self.evaluation_env = self.config.evaluation_env
if self.evaluation_env is not None:
self.evaluation_state = self.evaluation_env.reset()
self.evaluation_return = 0

def close(self):
if hasattr(self.task, 'close'):
self.task.close()
if hasattr(self.evaluation_env, 'close'):
self.evaluation_env.close()

def save(self, filename):
torch.save(self.network.state_dict(), filename)
@@ -29,44 +23,24 @@ def load(self, filename):
state_dict = torch.load(filename, map_location=lambda storage, loc: storage)
self.network.load_state_dict(state_dict)

def evaluation_action(self, state):
self.config.state_normalizer.set_read_only()
state = self.config.state_normalizer(np.stack([state]))
action = self.network.predict(state, to_numpy=True)
self.config.state_normalizer.unset_read_only()
return np.argmax(action.flatten())
def eval_step(self, state):
raise Exception('eval_step not implemented')

def deterministic_episode(self):
env = self.config.evaluation_env
def eval_episode(self):
env = self.config.eval_env
state = env.reset()
total_rewards = 0
while True:
action = self.evaluation_action(state)
action = self.eval_step(state)
state, reward, done, _ = env.step(action)
total_rewards += reward
if done:
break
return total_rewards

def evaluation_episodes(self):
interval = self.config.evaluation_episodes_interval
if not interval or self.total_steps % interval:
return
def eval_episodes(self):
rewards = []
for ep in range(self.config.evaluation_episodes):
rewards.append(self.deterministic_episode())
for ep in range(self.config.eval_episodes):
rewards.append(self.eval_episode())
self.config.logger.info('evaluation episode return: %f(%f)' % (
np.mean(rewards), np.std(rewards) / np.sqrt(len(rewards))))

def evaluate(self, steps=1):
config = self.config
if config.evaluation_env is None or self.config.evaluation_episodes_interval:
return
for _ in range(steps):
action = self.evaluation_action(self.evaluation_state)
self.evaluation_state, reward, done, _ = self.evaluation_env.step(action)
self.evaluation_return += reward
if done:
self.evaluation_state = self.evaluation_env.reset()
self.config.logger.info('evaluation episode return: %f' % (self.evaluation_return))
self.evaluation_return = 0
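The interval check that used to live inside evaluation_episodes(), and the step-synchronized evaluate() method removed above, now have to be driven from outside the agent. Below is a minimal sketch of such a step-based driver; the runner name run_steps and the eval_interval/save_interval/max_steps fields it reads are assumptions about the surrounding code, not part of this commit.

def run_steps(agent):
    # Sketch only: names and exact checks are assumptions, not this commit's code.
    config = agent.config
    while True:
        if config.eval_interval and agent.total_steps % config.eval_interval == 0:
            agent.eval_episodes()      # runs config.eval_episodes greedy episodes via eval_step()
        if config.save_interval and agent.total_steps % config.save_interval == 0:
            agent.save('model-%d.bin' % agent.total_steps)
        if config.max_steps and agent.total_steps >= config.max_steps:
            agent.close()
            break
        agent.step()                   # one environment interaction and (possibly) one SGD update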
118 changes: 58 additions & 60 deletions deep_rl/agent/DQN_agent.py
@@ -16,71 +16,69 @@ def __init__(self, config):
self.config = config
self.replay = config.replay_fn()
self.task = config.task_fn()
self.network = config.network_fn(self.task.state_dim, self.task.action_dim)
self.target_network = config.network_fn(self.task.state_dim, self.task.action_dim)
self.network = config.network_fn()
self.target_network = config.network_fn()
self.optimizer = config.optimizer_fn(self.network.parameters())
self.criterion = nn.MSELoss()
self.target_network.load_state_dict(self.network.state_dict())
self.policy = config.policy_fn()
self.total_steps = 0
self.episode_reward = 0
self.episode_rewards = []
self.state = self.task.reset()

def episode(self, deterministic=False):
episode_start_time = time.time()
state = self.task.reset()
total_reward = 0.0
steps = 0
while True:
value = self.network.predict(np.stack([self.config.state_normalizer(state)]), True).flatten()
if deterministic:
action = np.argmax(value)
elif self.total_steps < self.config.exploration_steps:
action = np.random.randint(0, len(value))
else:
action = self.policy.sample(value)
next_state, reward, done, _ = self.task.step(action)
total_reward += reward
reward = self.config.reward_normalizer(reward)
if not deterministic:
self.replay.feed([state, action, reward, next_state, int(done)])
self.total_steps += 1
steps += 1
state = next_state
self.batch_indices = range_tensor(self.replay.batch_size)

if not deterministic and self.total_steps > self.config.exploration_steps \
and self.total_steps % self.config.sgd_update_frequency == 0:
experiences = self.replay.sample()
states, actions, rewards, next_states, terminals = experiences
states = self.config.state_normalizer(states)
next_states = self.config.state_normalizer(next_states)
q_next = self.target_network.predict(next_states, False).detach()
if self.config.double_q:
_, best_actions = self.network.predict(next_states).detach().max(1)
q_next = q_next.gather(1, best_actions.unsqueeze(1)).squeeze(1)
else:
q_next, _ = q_next.max(1)
terminals = tensor(terminals)
rewards = tensor(rewards)
q_next = self.config.discount * q_next * (1 - terminals)
q_next.add_(rewards)
actions = tensor(actions).unsqueeze(1).long()
q = self.network.predict(states, False)
q = q.gather(1, actions).squeeze(1)
loss = self.criterion(q, q_next)
self.optimizer.zero_grad()
loss.backward()
nn.utils.clip_grad_norm_(self.network.parameters(), self.config.gradient_clip)
self.optimizer.step()
def eval_step(self, state):
self.config.state_normalizer.set_read_only()
state = self.config.state_normalizer(np.stack([state]))
q = self.network(state)
action = np.argmax(to_np(q).flatten())
self.config.state_normalizer.unset_read_only()
return action

self.evaluate()
if not deterministic and self.total_steps % self.config.target_network_update_freq == 0:
self.target_network.load_state_dict(self.network.state_dict())
if not deterministic and self.total_steps > self.config.exploration_steps:
self.policy.update_epsilon()
def step(self):
config = self.config
q_values = self.network(config.state_normalizer(np.stack([self.state])))
q_values = to_np(q_values).flatten()
if self.total_steps < config.exploration_steps \
or np.random.rand() < config.random_action_prob():
action = np.random.randint(0, config.action_dim)
else:
action = np.argmax(q_values)
next_state, reward, done, _ = self.task.step(action)
self.episode_reward += reward
self.total_steps += 1
if done:
self.episode_rewards.append(self.episode_reward)
self.episode_reward = 0
next_state = self.task.reset()
reward = config.reward_normalizer(reward)
self.replay.feed([self.state, action, reward, next_state, int(done)])
self.state = next_state

if done:
break
if self.total_steps > self.config.exploration_steps \
and self.total_steps % self.config.sgd_update_frequency == 0:
experiences = self.replay.sample()
states, actions, rewards, next_states, terminals = experiences
states = self.config.state_normalizer(states)
next_states = self.config.state_normalizer(next_states)
q_next = self.target_network(next_states).detach()
if self.config.double_q:
best_actions = torch.argmax(self.network(next_states), dim=-1)
q_next = q_next[self.batch_indices, best_actions]
else:
q_next = q_next.max(1)[0]
terminals = tensor(terminals)
rewards = tensor(rewards)
q_next = self.config.discount * q_next * (1 - terminals)
q_next.add_(rewards)
actions = tensor(actions).long()
q = self.network(states)
q = q[self.batch_indices, actions]
loss = (q_next - q).pow(2).mul(0.5).mean()
self.optimizer.zero_grad()
loss.backward()
nn.utils.clip_grad_norm_(self.network.parameters(), self.config.gradient_clip)
self.optimizer.step()

episode_time = time.time() - episode_start_time
self.config.logger.info('episode time %f, fps %f' %
(episode_time, float(steps) / episode_time))
return total_reward, steps
if self.total_steps % self.config.target_network_update_freq == 0:
self.target_network.load_state_dict(self.network.state_dict())
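A standalone sketch of the vectorized double-DQN target computed in step() above, assuming range_tensor(n) is equivalent to torch.arange(n): the online network picks the greedy action, the target network evaluates it, and the result is mixed with rewards and terminals exactly as in the agent code.

import torch

batch_size, n_actions = 4, 3
q_next_target = torch.rand(batch_size, n_actions)        # target_network(next_states)
q_next_online = torch.rand(batch_size, n_actions)        # network(next_states)
batch_indices = torch.arange(batch_size)                  # what range_tensor is assumed to provide

best_actions = torch.argmax(q_next_online, dim=-1)        # greedy action under the online network
double_q = q_next_target[batch_indices, best_actions]     # evaluated by the target network
vanilla_q = q_next_target.max(1)[0]                       # plain DQN alternative

rewards, terminals, discount = torch.rand(batch_size), torch.zeros(batch_size), 0.99
td_target = rewards + discount * double_q * (1 - terminals)   # fed into the 0.5 * squared-error loss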
40 changes: 17 additions & 23 deletions deep_rl/component/replay.py
@@ -8,6 +8,7 @@
import numpy as np
import torch.multiprocessing as mp
from collections import deque
from ..utils import *

class Replay:
def __init__(self, memory_size, batch_size):
@@ -43,59 +44,52 @@ def size(self):
def empty(self):
return not len(self.data)

class _ProcessWrapper(mp.Process):
class AsyncReplay(mp.Process):
FEED = 0
SAMPLE = 1
EXIT = 2
def __init__(self, pipe, memory_size, batch_size):
def __init__(self, memory_size, batch_size):
mp.Process.__init__(self)
self.pipe = pipe
self.__pipe, self.__worker_pipe = mp.Pipe()
self.memory_size = memory_size
self.batch_size = batch_size
self.cache_len = 2
self.__cache_len = 2
self.start()

def run(self):
torch.cuda.is_available()
from ..utils.torch_utils import tensor
replay = Replay(self.memory_size, self.batch_size)
cache = deque([], maxlen=self.cache_len)
cache = deque([], maxlen=self.__cache_len)

def sample():
batch_data = replay.sample()
batch_data = [tensor(x) for x in batch_data]
# for x in batch_data:
# x.share_memory_()
cache.append(batch_data)

while True:
op, data = self.pipe.recv()
op, data = self.__worker_pipe.recv()
if op == self.FEED:
replay.feed(data)
elif op == self.SAMPLE:
if len(cache) == 0:
sample()
self.pipe.send(cache.popleft())
while len(cache) < self.cache_len:
self.__worker_pipe.send(cache.popleft())
while len(cache) < self.__cache_len:
sample()
elif op == self.EXIT:
self.pipe.close()
self.__worker_pipe.close()
return
else:
raise Exception('Unknown command')


class AsyncReplay:
def __init__(self, memory_size, batch_size):
self.pipe, worker_pipe = mp.Pipe()
self.worker = _ProcessWrapper(worker_pipe, memory_size, batch_size)
self.worker.start()

def feed(self, exp):
self.pipe.send([_ProcessWrapper.FEED, exp])
self.__pipe.send([self.FEED, exp])

def sample(self):
self.pipe.send([_ProcessWrapper.SAMPLE, None])
return self.pipe.recv()
self.__pipe.send([self.SAMPLE, None])
return self.__pipe.recv()

def close(self):
self.pipe.send([_ProcessWrapper.EXIT, None])
self.__pipe.send([self.EXIT, None])
self.__pipe.close()
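AsyncReplay is now the mp.Process itself and calls self.start() in __init__, so a caller only sees feed/sample/close. A rough usage sketch, assuming the same [state, action, reward, next_state, done] transition layout the DQN agent feeds and that enough transitions have been stored before sampling (the transition variables below are placeholders):

replay = AsyncReplay(memory_size=int(1e5), batch_size=32)
replay.feed([state, action, reward, next_state, int(done)])            # repeat as transitions arrive
states, actions, rewards, next_states, terminals = replay.sample()     # worker returns pre-built tensors
replay.close()                                                         # asks the worker process to exit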

4 changes: 2 additions & 2 deletions deep_rl/component/task.py
@@ -44,15 +44,15 @@ def __init__(self, name='CartPole-v0', max_steps=200, log_dir=None):

class PixelAtari(BaseTask):
def __init__(self, name, seed=0, log_dir=None,
frame_skip=4, history_length=4, dataset=False):
frame_skip=4, history_length=4, dataset=False, episode_life=True):
BaseTask.__init__(self)
env = make_atari(name, frame_skip)
env.seed(seed)
if dataset:
env = DatasetEnv(env)
self.dataset_env = env
env = self.set_monitor(env, log_dir)
env = wrap_deepmind(env, history_length=history_length)
env = wrap_deepmind(env, history_length=history_length, episode_life=episode_life)
self.env = env
self.action_dim = self.env.action_space.n
self.state_dim = self.env.observation_space.shape
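The new episode_life flag passes straight through to wrap_deepmind, which suggests the usual split of life-loss termination during training and full episodes during evaluation. A hedged example (the environment id is only an illustration):

train_task = PixelAtari('BreakoutNoFrameskip-v4', episode_life=True)    # default: end episodes on life loss
eval_task = PixelAtari('BreakoutNoFrameskip-v4', episode_life=False)    # evaluate on full games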
4 changes: 1 addition & 3 deletions deep_rl/network/network_heads.py
@@ -14,11 +14,9 @@ def __init__(self, output_dim, body):
self.body = body
self.to(Config.DEVICE)

def predict(self, x, to_numpy=False):
def forward(self, x):
phi = self.body(tensor(x))
y = self.fc_head(phi)
if to_numpy:
y = y.cpu().detach().numpy()
return y

class DuelingNet(nn.Module, BaseNet):
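With predict() gone, the head is called through the standard nn.Module forward and callers convert to numpy explicitly. A sketch of the call-site change, assuming to_np is the repo's tensor-to-numpy helper seen in DQN_agent.py above; state and network are placeholders for a single observation and a VanillaNet-style head:

# old style: q = network.predict(np.stack([state]), to_numpy=True)
q = to_np(network(np.stack([state]))).flatten()
action = np.argmax(q)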
44 changes: 22 additions & 22 deletions deep_rl/utils/config.py
@@ -18,54 +18,54 @@ def __init__(self):
self.network_fn = None
self.actor_network_fn = None
self.critic_network_fn = None
self.policy_fn = None
self.replay_fn = None
self.random_process_fn = None
self.discount = 0.99
self.target_network_update_freq = 0
self.max_episode_length = 0
self.exploration_steps = 0
self.discount = None
self.target_network_update_freq = None
self.exploration_steps = None
self.logger = None
self.history_length = 1
self.history_length = None
self.double_q = False
self.tag = 'vanilla'
self.num_workers = 1
self.gradient_clip = 5
self.update_interval = 1
self.gradient_clip = 0.5
self.entropy_weight = 0.01
self.use_gae = False
self.gae_tau = 1.0
self.noise_decay_interval = 0
self.target_network_mix = 0.001
self.state_normalizer = RescaleNormalizer()
self.reward_normalizer = RescaleNormalizer()
self.hybrid_reward = False
self.episode_limit = 0
self.min_memory_size = 200
self.master_fn = None
self.master_optimizer_fn = None
self.num_heads = 10
self.min_epsilon = 0
self.save_interval = 0
self.max_steps = 0
self.render_episode_freq = 0
self.rollout_length = None
self.value_loss_weight = 1.0
self.iteration_log_interval = 30
self.categorical_v_min = -10
self.categorical_v_max = 10
self.categorical_n_atoms = 51
self.num_quantiles = 10
self.gaussian_noise_scale = 0.3
self.optimization_epochs = 4
self.num_mini_batches = 32
self.test_interval = 0
self.test_repetitions = 10
self.evaluation_env = None
self.termination_regularizer = 0
self.evaluation_episodes_interval = 0
self.evaluation_episodes = 0
self.sgd_update_frequency = 4
self.random_action_prob = None
self.__eval_env = None
self.log_interval = int(1e3)
self.save_interval = int(1e5)
self.eval_interval = 0
self.eval_episodes = 10

@property
def eval_env(self):
return self.__eval_env

@eval_env.setter
def eval_env(self, env):
self.__eval_env = env
self.state_dim = env.state_dim
self.action_dim = env.action_dim
self.task_name = env.name

def add_argument(self, *args, **kwargs):
self.parser.add_argument(*args, **kwargs)
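A hedged sketch of wiring the new step-based fields together. Schedule helpers are not part of this diff, so a constant lambda stands in for random_action_prob; the environment id is illustrative, and the eval_env setter assumes the task exposes state_dim/action_dim/name as in task.py above.

config = Config()
config.eval_env = PixelAtari('BreakoutNoFrameskip-v4', episode_life=False)   # setter also fills state_dim/action_dim/task_name
config.task_fn = lambda: PixelAtari('BreakoutNoFrameskip-v4', episode_life=True)
config.random_action_prob = lambda: 0.05      # called once per DQNAgent.step()
config.discount = 0.99
config.exploration_steps = int(5e4)
config.sgd_update_frequency = 4
config.target_network_update_freq = int(1e4)
config.eval_interval = int(1e4)               # how often an outer loop might call eval_episodes()
config.eval_episodes = 10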
