Commit

action source added to the reinforce
awarebayes committed Dec 9, 2019
1 parent 8794906 commit 0945ec0
Showing 8 changed files with 161 additions and 121 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -9,6 +9,7 @@ recnn/.idea/
*csv
*zip
.idea/
.desktopfolder

# Byte-compiled / optimized / DLL files
__pycache__/
@@ -187,7 +187,9 @@
"beta_net = Beta().to(cuda)\n",
"value_net = recnn.nn.Critic(1290, num_items, 2048, 54e-2).to(cuda)\n",
"policy_net = recnn.nn.DiscreteActor(1290, num_items, 2048).to(cuda)\n",
"\n",
"# as miracle24 has suggested https://github.com/awarebayes/RecNN/issues/7\n",
"# you can enable this to be more like the paper authors meant it to\n",
"# policy_net.action_source = {'pi': 'beta', 'beta': 'beta'}\n",
"\n",
"reinforce = recnn.nn.Reinforce(policy_net, value_net)\n",
"reinforce = reinforce.to(cuda)\n",
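A quick usage sketch of the new flag (hypothetical snippet, not part of this commit), reusing the names from the notebook cell above (recnn, num_items, cuda): the action source can be switched after constructing the policy network.

# hypothetical usage sketch, not part of this commit; assumes recnn, num_items and cuda from the cell above
policy_net = recnn.nn.DiscreteActor(1290, num_items, 2048).to(cuda)
policy_net.action_source                                    # default: {'pi': 'pi', 'beta': 'beta'}
policy_net.action_source = {'pi': 'beta', 'beta': 'beta'}   # sample both actions from beta, as in issue #7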

Large diffs are not rendered by default.

186 changes: 98 additions & 88 deletions examples/[Library Basics]/1. Getting Started.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion readme.md
@@ -83,7 +83,7 @@ The repo consists of two parts: the library (./recnn) and the playground (./exam
| Twin Delayed DDPG (TD3) | https://arxiv.org/abs/1802.09477 | examples/1.Vanilla RL/TD3 |
| Soft Actor-Critic | https://arxiv.org/abs/1801.01290 | examples/1.Vanilla RL/SAC |
| Batch Constrained Q-Learning | https://arxiv.org/abs/1812.02900 | examples/99.To be released/BCQ |
| REINFORCE Top-K Off-Policy Correction | https://arxiv.org/abs/1509.02971 | examples/2. REINFORCE TopK |
| REINFORCE Top-K Off-Policy Correction | https://arxiv.org/abs/1812.02353 | examples/2. REINFORCE TopK |

</p>

6 changes: 3 additions & 3 deletions recnn/data/dataset_functions.py
@@ -7,7 +7,7 @@
Chain of responsibility pattern.
https://refactoring.guru/design-patterns/chain-of-responsibility/python/example
RecNN is designed to work with your dataflow.
RecNN is designed to work with your data flow.
Functions that contain 'dataset' in their names are needed to interact with the environment.
The environment is provided via the env argument.
These functions can interact with env and set things up however you like.
@@ -38,7 +38,7 @@ def prepare_dataset(**kwargs):
Notice: prepare_dataset doesn't take a **reduce_items_to** argument, but it is required in truncate_dataset.
As I previously mentioned RecNN is designed to be argument agnostic, meaning you provide some kwarg in the
build_data_pipeline function and it is passed down the function chain. If needed, it will be used. Otherwise ignored
build_data_pipeline function, and it is passed down the function chain. If needed, it will be used. Otherwise, ignored
"""


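To illustrate the argument-agnostic chain described above, here is a toy sketch (illustrative names only, not the library's API): every step in the chain receives the full set of kwargs and picks out only what it needs.

# toy illustration of an argument-agnostic function chain; names here are made up
def step_a(df, **kwargs):
    # needs no extra arguments, silently ignores whatever else was passed down
    return df

def step_b(df, reduce_items_to=None, **kwargs):
    # consumes only the kwarg it cares about
    return df[:reduce_items_to] if reduce_items_to else df

def run_chain(df, chain, **kwargs):
    # every step gets the full kwargs and takes what it wants
    for step in chain:
        df = step(df, **kwargs)
    return df

run_chain(list(range(10)), [step_a, step_b], reduce_items_to=5)  # -> [0, 1, 2, 3, 4]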
@@ -82,7 +82,7 @@ def truncate_dataset(df, key_to_id, frame_size, env, reduce_items_to, sort_users
Truncate #items to num_items provided in the arguments
"""

# here n items to keep are adjusted
# here are adjusted n items to keep
num_items = reduce_items_to

to_remove = df['movieId'].value_counts().sort_values()[:-num_items].index
34 changes: 25 additions & 9 deletions recnn/nn/models.py
@@ -83,6 +83,12 @@ def __init__(self, input_dim, action_dim, hidden_size, init_w=0):
self.rewards = []
self.correction = []
self.lambda_k = []

# What's action source? See this issue: https://github.com/awarebayes/RecNN/issues/7
# by default {pi: pi, beta: beta}
# you can change it to be like {pi: beta, beta: beta} as miracle24 suggested

self.action_source = {'pi': 'pi', 'beta': 'beta'}
self.select_action = self._select_action

def forward(self, inputs):
@@ -98,25 +104,35 @@ def gc(self):
del self.lambda_k[:]

def _select_action(self, state, **kwargs):
probs = self.forward(state)
m = Categorical(probs)
action = m.sample()
self.saved_log_probs.append(m.log_prob(action))
return probs

# for reinforce without correction only pi_probs is available.
# the action source is ignored, since there is no beta

pi_probs = self.forward(state)
pi_categorical = Categorical(pi_probs)
pi_action = pi_categorical.sample()
self.saved_log_probs.append(pi_categorical.log_prob(pi_action))
return pi_probs

def pi_beta_sample(self, state, beta, action, **kwargs):
# 1. obtain probabilities
# note: detach is to block gradient
beta_probs = beta(state.detach(), action=action)
pi_probs = self.forward(state)

# 2. probabilities -> categorical distribution
# 2. probabilities -> categorical distribution.
beta_categorical = Categorical(beta_probs)
pi_categorical = Categorical(pi_probs)

# 3. sample actions
beta_action = beta_categorical.sample()
pi_action = pi_categorical.sample()
# 3. sample the actions
# See this issue: https://github.com/awarebayes/RecNN/issues/7
# usually it works like:
# pi_action = pi_categorical.sample(); beta_action = beta_categorical.sample();
# but with action_source set to {pi: beta, beta: beta} it becomes:
# pi_action = beta_categorical.sample(); beta_action = beta_categorical.sample();
available_actions = {'pi': pi_categorical.sample(), 'beta': beta_categorical.sample()}
pi_action = available_actions[self.action_source['pi']]
beta_action = available_actions[self.action_source['beta']]

# 4. calculate stuff we need
pi_log_prob = pi_categorical.log_prob(pi_action)
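A standalone sketch of the dispatch above (hypothetical, with made-up probabilities; assumes PyTorch is installed), showing how action_source decides which distribution the sampled actions come from while the log-probabilities are still evaluated under their own distributions:

import torch
from torch.distributions import Categorical

# made-up probabilities purely for illustration
pi_probs = torch.tensor([0.7, 0.2, 0.1])
beta_probs = torch.tensor([0.1, 0.3, 0.6])
pi_categorical, beta_categorical = Categorical(pi_probs), Categorical(beta_probs)

action_source = {'pi': 'beta', 'beta': 'beta'}  # the variant suggested in issue #7
available_actions = {'pi': pi_categorical.sample(), 'beta': beta_categorical.sample()}
pi_action = available_actions[action_source['pi']]    # drawn from beta here
beta_action = available_actions[action_source['beta']]

pi_log_prob = pi_categorical.log_prob(pi_action)       # still scored under pi
beta_log_prob = beta_categorical.log_prob(beta_action)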
4 changes: 2 additions & 2 deletions recnn/nn/update/reinforce.py
@@ -66,7 +66,7 @@ def reinforce_update(batch, params, nets, optimizer,
debug=None, writer=utils.DummyWriter(),
learn=True, step=-1):

# Due no its mechanics, reinforce doesn't support testing!
# Due to its mechanics, reinforce doesn't support testing!
learn = True

state, action, reward, next_state, done = data.get_base_batch(batch)
@@ -92,4 +92,4 @@

utils.write_losses(writer, losses, kind='train' if learn else 'test')

return losses
return losses
