reinforce patched
awarebayes committed Nov 21, 2019
1 parent 520ec28 commit 7c4e1d3
Showing 5 changed files with 38 additions and 46 deletions.
@@ -415,8 +415,10 @@
" if step % params['policy_step'] == 0 and step > 0:\n",
" \n",
" policy_loss = params['reinforce'](nets['policy_net'], optimizer['policy_optimizer'], learn=learn)\n",
" \n",
" del nets['policy_net'].rewards[:]\n",
" del nets['policy_net'].saved_log_probs[:]\n",
" \n",
" print('step: ', step, '| value:', value_loss.item(), '| policy', policy_loss.item())\n",
" \n",
" recnn.utils.soft_update(nets['value_net'], nets['target_value_net'], soft_tau=params['soft_tau'])\n",
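This notebook change pairs with the models.py change below: the reward and log-prob buffers are now flushed by the training loop right after each policy update, instead of being capped inside the model. Unescaped, the updated loop body reads roughly as follows (a sketch; nets, optimizer, and params are assumed to be wired up as elsewhere in the notebook):

    if step % params['policy_step'] == 0 and step > 0:

        policy_loss = params['reinforce'](nets['policy_net'],
                                          optimizer['policy_optimizer'],
                                          learn=learn)

        # flush the episode buffers right after the gradient step so they
        # cannot grow without bound between policy updates
        del nets['policy_net'].rewards[:]
        del nets['policy_net'].saved_log_probs[:]

        print('step: ', step, '| value:', value_loss.item(),
              '| policy', policy_loss.item())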

Large diffs are not rendered by default.

10 changes: 0 additions & 10 deletions recnn/nn/models.py
@@ -82,23 +82,13 @@ def __init__(self, input_dim, action_dim, hidden_size, init_w=0):
         self.saved_log_probs = []
         self.rewards = []
 
-        # with large action spaces it can be overflowed
-        # in order to prevent this, I set a max limit
-
-        self.save_limit = 15
-
     def forward(self, inputs):
         x = inputs
         x = F.relu(self.linear1(x))
         action_scores = self.linear2(x)
         return F.softmax(action_scores)
 
     def select_action(self, state):
-
-        if len(self.saved_log_probs) > self.save_limit:
-            del self.saved_log_probs[:]
-            del self.rewards[:]
-
         probs = self.forward(state)
         m = Categorical(probs)
         action = m.sample()
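With the save_limit cap removed, select_action only samples and records. A minimal sketch of the trimmed method, assuming it appends the log-probability and returns both the action and the probabilities, as the reinforce_update call site below suggests:

    from torch.distributions import Categorical

    def select_action(self, state):
        # sample an action from the softmax scores produced by forward()
        probs = self.forward(state)
        m = Categorical(probs)
        action = m.sample()
        # record the log-prob for the REINFORCE loss; the buffers are now
        # flushed by the training loop instead of being truncated here
        self.saved_log_probs.append(m.log_prob(action))
        return action, probs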
2 changes: 1 addition & 1 deletion recnn/nn/update/bcq.py
@@ -125,7 +125,7 @@ def bcq_update(batch, params, nets, optimizer,
     perturbator_loss = perturbator_loss.mean()
 
     if learn:
-        if step % params['perturbator_step']:
+        if step % params['perturbator_step'] == 0:
             optimizer['perturbator_optimizer'].zero_grad()
             perturbator_loss.backward()
             torch.nn.utils.clip_grad_norm_(nets['perturbator_net'].parameters(), -1, 1)
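The fix addresses an inverted condition: step % n is truthy on every step except multiples of n, so the old guard updated the perturbator on all steps except the intended ones. A standalone illustration:

    perturbator_step = 10

    for step in (9, 10, 11, 20):
        old = bool(step % perturbator_step)    # True, False, True, False
        new = step % perturbator_step == 0     # False, True, False, True
        print(step, old, new)

    # the old condition fired on steps 9 and 11 and skipped 10 and 20,
    # the exact opposite of "update every perturbator_step steps"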
13 changes: 8 additions & 5 deletions recnn/nn/update/reinforce.py
@@ -52,19 +52,22 @@ def __call__(self, policy, optimizer, learn=True):
 def reinforce_update(batch, params, nets, optimizer,
                      device=torch.device('cpu'),
                      debug=None, writer=utils.DummyWriter(),
-                     learn=False, step=-1):
+                     learn=True, step=-1):
+
+    # Due to its mechanics, reinforce doesn't support testing!
+    learn = True
+
     state, action, reward, next_state, done = data.get_base_batch(batch)
 
     predicted_action, predicted_probs = nets['policy_net'].select_action(state)
     reward = nets['value_net'](state, predicted_probs).detach()
     nets['policy_net'].rewards.append(reward.mean())
 
     value_loss = value_update(batch, params, nets, optimizer,
-                              writer=writer,
-                              device=device,
-                              debug=debug, learn=learn, step=step)
+                              writer=writer, device=device,
+                              debug=debug, learn=True, step=step)
 
-    if len(nets['policy_net'].saved_log_probs) > params['policy_step'] and learn:
+    if step % params['policy_step'] == 0 and step > 0:
         policy_loss = params['reinforce'](nets['policy_net'], optimizer['policy_optimizer'], learn=learn)
 
         print('step: ', step, '| value:', value_loss.item(), '| policy', policy_loss.item())
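For context, params['reinforce'] is a callable that consumes the rewards and log-probs stored on the policy net. A minimal sketch of a standard REINFORCE update in that shape (the discounting, return normalization, and gamma value are textbook defaults, not necessarily what the repo's callable does):

    import torch

    def reinforce_sketch(policy, optimizer, gamma=0.99, learn=True):
        # compute discounted returns from the stored step rewards
        R, returns = 0, []
        for r in reversed(policy.rewards):
            R = r + gamma * R
            returns.insert(0, R)
        returns = torch.tensor(returns)
        # normalize returns to reduce gradient variance
        returns = (returns - returns.mean()) / (returns.std() + 1e-8)
        # REINFORCE loss: negative log-prob weighted by the return
        loss = torch.stack([-lp * ret for lp, ret
                            in zip(policy.saved_log_probs, returns)]).sum()
        if learn:
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        return loss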
