Commit
reinforce without correction added for testing
awarebayes committed Nov 20, 2019
1 parent 8119df9 commit 520ec28
Showing 8 changed files with 253 additions and 95 deletions.
3 changes: 3 additions & 0 deletions docs/source/data.rst
@@ -33,6 +33,9 @@ dataset_functions
What?
+++++

Chain of responsibility pattern:
refactoring.guru/design-patterns/chain-of-responsibility/python/example

RecNN is designed to work with your dataflow.
Functions whose names contain 'dataset' are needed to interact with the environment.
The environment is provided via the env argument.
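For readers who have not followed the link above, the next snippet is a minimal, self-contained sketch of the chain-of-responsibility idea behind the data pipeline (illustrative names only, not RecNN's actual code): each handler receives keyword arguments, transforms what it needs, and returns the full argument dict so the next handler can continue::

    # Minimal chain-of-responsibility sketch; function and key names are illustrative.
    def add_one(values, **kwargs):
        # transform the data, then hand everything (including untouched kwargs) onward
        return {'values': [v + 1 for v in values], **kwargs}

    def square(values, **kwargs):
        return {'values': [v * v for v in values], **kwargs}

    def run_chain(chain, **kwargs):
        # each handler consumes the previous handler's output as its kwargs
        for handler in chain:
            kwargs = handler(**kwargs)
        return kwargs

    result = run_chain([add_one, square], values=[1, 2, 3], note='kept intact')
    print(result)  # {'values': [4, 9, 16], 'note': 'kept intact'}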
3 changes: 2 additions & 1 deletion docs/source/examples/your_data.rst
@@ -53,7 +53,8 @@ Here is how default ML20M dataset is processed. Use this as a reference::

Although not required, it is advised that you return all of the arguments + kwargs. If the function is the
last one in the chain this may work fine without them, but if you are using **build_data_pipeline**, you need to
return them as described. Look in reference/data/dataset_functions for further details. Chain of responsibility pattern:
refactoring.guru/design-patterns/chain-of-responsibility/python/example. A hedged sketch of such a step is shown below.
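As a hedged illustration (the function below is made up, not part of RecNN), a custom step meant to sit inside **build_data_pipeline** would follow the same convention and hand every argument plus **kwargs to the next function in the chain::

    # Hypothetical pipeline step: drop users with fewer than min_ratings interactions.
    # It assumes a pandas DataFrame with a 'userId' column, like the ML20M frame above.
    def filter_short_users(df, key_to_id, frame_size, env, min_ratings=5,
                           sort_users=False, **kwargs):
        counts = df.groupby('userId').size()
        df = df[df['userId'].isin(counts[counts >= min_ratings].index)]
        # return everything so the next function in the chain gets a complete kwarg dict
        return {'df': df, 'key_to_id': key_to_id, 'frame_size': frame_size,
                'env': env, 'min_ratings': min_ratings, 'sort_users': sort_users, **kwargs}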

Toy Dataset
+++++++++++

Large diffs are not rendered by default.

Large diffs are not rendered by default.

25 changes: 14 additions & 11 deletions recnn/data/dataset_functions.py
@@ -4,6 +4,9 @@
What?
+++++
Chain of responsibility pattern.
https://refactoring.guru/design-patterns/chain-of-responsibility/python/example
RecNN is designed to work with your dataflow.
Functions whose names contain 'dataset' are needed to interact with the environment.
The environment is provided via the env argument.
@@ -46,6 +49,8 @@ def prepare_dataset(df, key_to_id, frame_size, env, sort_users=False, **kwargs):
[1, 34, 123, 2000], recnn makes it look like [0,1,2,3] for you.
"""

key_to_id = env.key_to_id

df['rating'] = df['rating'].progress_apply(lambda i: 2 * (i - 2.5))
df['movieId'] = df['movieId'].progress_apply(key_to_id.get)
users = df[['userId', 'movieId']].groupby(['userId']).size()
@@ -70,21 +75,19 @@ def app(x):
env.users = users

return {'df': df, 'key_to_id': key_to_id,
'frame_size': frame_size, 'env': env, 'sort_users': sort_users,
**kwargs}
'frame_size': frame_size, 'env': env, 'sort_users': sort_users, **kwargs}
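As a small worked illustration of the two transforms above (using plain .apply instead of tqdm's .progress_apply, and a toy frame instead of ML20M), the rating rescaling maps the 0.5-5 star scale to roughly [-4, 5] and the id remapping produces dense indices:

    import pandas as pd

    # toy frame mirroring the docstring example: raw ids [1, 34, 123, 2000] become [0, 1, 2, 3]
    df = pd.DataFrame({'movieId': [1, 34, 123, 2000], 'rating': [0.5, 2.5, 4.0, 5.0]})
    key_to_id = {1: 0, 34: 1, 123: 2, 2000: 3}

    df['rating'] = df['rating'].apply(lambda i: 2 * (i - 2.5))  # -> -4.0, 0.0, 3.0, 5.0
    df['movieId'] = df['movieId'].apply(key_to_id.get)          # -> 0, 1, 2, 3
    print(df)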


def truncate_dataset(df, key_to_id, frame_size, env, reduce_items_to, sort_users=False, **kwargs):
"""
Truncate #items to num_items provided in the arguments
"""

# here n items to keep are adjusted
num_items = reduce_items_to

value_counts = df['movieId'].value_counts().sort_values()

to_remove = value_counts[:-num_items].index
to_keep = value_counts[-num_items:].index
to_remove = df['movieId'].value_counts().sort_values()[:-num_items].index
to_keep = df['movieId'].value_counts().sort_values()[-num_items:].index
to_remove_indices = df[df['movieId'].isin(to_remove)].index
num_removed = len(to_remove)

@@ -95,24 +98,24 @@ def truncate_dataset(df, key_to_id, frame_size, env, reduce_items_to, sort_users
del env.movie_embeddings_key_dict[i]

env.embeddings, env.key_to_id, env.id_to_key = make_items_tensor(env.movie_embeddings_key_dict)

print('action space is reduced to {} - {} = {}'.format(num_items + num_removed, num_removed,
num_items))

return {'df': df, 'key_to_id': key_to_id,
'frame_size': frame_size, 'env': env, 'sort_users': sort_users,
'reduce_items_to': reduce_items_to, **kwargs}
return {'df': df, 'key_to_id': env.key_to_id, 'env': env,
'frame_size': frame_size, 'sort_users': sort_users, **kwargs}
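A standalone sketch of the truncation logic above, on toy data and without the env bookkeeping: keep only the reduce_items_to most frequent item ids and drop every row that refers to anything else:

    import pandas as pd

    df = pd.DataFrame({'movieId': [1, 1, 1, 2, 2, 3, 4, 4, 4, 4]})
    num_items = 2  # reduce_items_to

    value_counts = df['movieId'].value_counts().sort_values()  # ascending frequency
    to_remove = value_counts[:-num_items].index                # least frequent ids (2 and 3)
    to_keep = value_counts[-num_items:].index                  # most frequent ids (1 and 4)
    df = df.drop(df[df['movieId'].isin(to_remove)].index)

    print(sorted(to_keep.tolist()))   # [1, 4]
    print(df['movieId'].unique())     # only ids 1 and 4 remain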


def build_data_pipeline(chain, **kwargs):
"""
curry function chain
Chain of responsibility pattern
:param chain: array of callable
:param **kwargs: any kwargs you like
"""

kwargdict = kwargs
for call in chain:
kwargdict = call(**kwargs)
kwargdict = call(**kwargdict)
return kwargdict
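The one-character-level fix in this loop matters: with call(**kwargs) every handler would see the original arguments and the intermediate results would be thrown away, while call(**kwargdict) threads each handler's output into the next call. A tiny sketch with made-up handlers:

    def double(x, **kwargs):
        return {'x': x * 2, **kwargs}

    def add_ten(x, **kwargs):
        return {'x': x + 10, **kwargs}

    def build_pipeline(chain, **kwargs):
        kwargdict = kwargs
        for call in chain:
            kwargdict = call(**kwargdict)  # the fixed line: feed the result forward
        return kwargdict

    print(build_pipeline([double, add_ten], x=3))  # {'x': 16}, not {'x': 13}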

4 changes: 2 additions & 2 deletions recnn/nn/algo.py
@@ -172,7 +172,7 @@ def __init__(self, policy_net, value_net1, value_net2):
class Reinforce(Algo):
def __init__(self, policy_net, value_net):

super(Algo, self).__init__()
super(Reinforce, self).__init__()

self.algorithm = update.reinforce_update

@@ -203,7 +203,7 @@ def __init__(self, policy_net, value_net):
'value_optimizer': value_optimizer
}

params = {
self.params = {
'reinforce': ChooseREINFORCE(ChooseREINFORCE.basic_reinforce),
'gamma': 0.99,
'min_value': -10,
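The super() change above fixes a subtle bug: inside a subclass, super(Reinforce, self) starts the method lookup after Reinforce and therefore runs the parent initializer, while the old super(Algo, self) started after Algo and skipped it. A generic illustration (not RecNN's classes):

    class Base:
        def __init__(self):
            self.initialized = True

    class Child(Base):
        def __init__(self, correct=True):
            if correct:
                super(Child, self).__init__()  # runs Base.__init__
            else:
                super(Base, self).__init__()   # skips Base.__init__ entirely

    print(hasattr(Child(correct=True), 'initialized'))   # True
    print(hasattr(Child(correct=False), 'initialized'))  # False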
12 changes: 11 additions & 1 deletion recnn/nn/models.py
@@ -73,7 +73,7 @@ def forward(self, state, tanh=False):


class DiscreteActor(nn.Module):
def __init__(self, input_dim, action_dim, hidden_size, init_w=2e-1):
def __init__(self, input_dim, action_dim, hidden_size, init_w=0):
super(DiscreteActor, self).__init__()

self.linear1 = nn.Linear(input_dim, hidden_size)
@@ -82,13 +82,23 @@ def __init__(self, input_dim, action_dim, hidden_size, init_w=2e-1):
self.saved_log_probs = []
self.rewards = []

# with large action spaces these buffers can overflow memory
# in order to prevent this, I set a max limit

self.save_limit = 15

def forward(self, inputs):
x = inputs
x = F.relu(self.linear1(x))
action_scores = self.linear2(x)
return F.softmax(action_scores)

def select_action(self, state):

if len(self.saved_log_probs) > self.save_limit:
del self.saved_log_probs[:]
del self.rewards[:]

probs = self.forward(state)
m = Categorical(probs)
action = m.sample()
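A minimal standalone sketch of the sampling path added above (toy sizes, not the library class): the state is pushed through two linear layers, softmaxed into action probabilities, and sampled with a Categorical distribution, while the log-prob buffer is cleared once it exceeds the save limit:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    from torch.distributions import Categorical

    class TinyDiscretePolicy(nn.Module):
        def __init__(self, input_dim=8, action_dim=4, hidden_size=16, save_limit=15):
            super().__init__()
            self.linear1 = nn.Linear(input_dim, hidden_size)
            self.linear2 = nn.Linear(hidden_size, action_dim)
            self.saved_log_probs, self.rewards = [], []
            self.save_limit = save_limit

        def forward(self, state):
            return F.softmax(self.linear2(F.relu(self.linear1(state))), dim=-1)

        def select_action(self, state):
            if len(self.saved_log_probs) > self.save_limit:
                # keep the buffers from growing without bound between updates
                del self.saved_log_probs[:]
                del self.rewards[:]
            m = Categorical(self.forward(state))
            action = m.sample()
            self.saved_log_probs.append(m.log_prob(action))
            return action

    policy = TinyDiscretePolicy()
    print(policy.select_action(torch.randn(8)))  # e.g. tensor(2)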
33 changes: 19 additions & 14 deletions recnn/nn/update/reinforce.py
@@ -3,8 +3,8 @@
from recnn import utils
from recnn import data
from recnn.utils import soft_update

from recnn.nn.update import value_update
import gc


class ChooseREINFORCE:
@@ -25,8 +25,7 @@ def basic_reinforce(policy, returns, *args, **kwargs):
def reinforce_with_correction():
raise NotImplemented

def __call__(self, policy, optimizer):

def __call__(self, policy, optimizer, learn=True):
R = 0

returns = []
@@ -39,9 +38,10 @@ def __call__(self, policy, optimizer):

policy_loss = self.method(policy, returns)

optimizer.zero_grad()
policy_loss.backward()
optimizer.step()
if learn:
optimizer.zero_grad()
policy_loss.backward()
optimizer.step()

del policy.rewards[:]
del policy.saved_log_probs[:]
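For reference, a standalone sketch of what the discounted-return recursion in __call__ and a basic REINFORCE loss look like (toy numbers; the normalization step is a common variant and not necessarily what the library does):

    import torch

    gamma = 0.99
    rewards = [1.0, 0.0, 2.0]
    log_probs = [torch.tensor(-0.5, requires_grad=True),
                 torch.tensor(-1.2, requires_grad=True),
                 torch.tensor(-0.3, requires_grad=True)]

    R, returns = 0.0, []
    for r in reversed(rewards):       # accumulate the discounted return from the end
        R = r + gamma * R
        returns.insert(0, R)

    returns = torch.tensor(returns)
    returns = (returns - returns.mean()) / (returns.std() + 1e-8)  # optional normalization

    # classic REINFORCE objective: sum of -log_prob(a_t) * R_t
    policy_loss = torch.stack([-lp * ret for lp, ret in zip(log_probs, returns)]).sum()
    policy_loss.backward()            # gradients flow back into the log-probs
    print(returns.tolist(), policy_loss.item())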
@@ -51,7 +51,7 @@

def reinforce_update(batch, params, nets, optimizer,
device=torch.device('cpu'),
debug=None, writer= utils.DummyWriter(),
debug=None, writer=utils.DummyWriter(),
learn=False, step=-1):
state, action, reward, next_state, done = data.get_base_batch(batch)

@@ -60,17 +60,22 @@ def reinforce_update(batch, params, nets, optimizer,
nets['policy_net'].rewards.append(reward.mean())

value_loss = value_update(batch, params, nets, optimizer,
writer=writer, device=device,
debug=debug, learn=learn, step=step)
writer=writer,
device=device,
debug=debug, learn=learn, step=step)

if len(nets['policy_net'].saved_log_probs) > params['policy_step'] and learn:
policy_loss = params['reinforce'](nets['policy_net'], optimizer['policy_optimizer'], learn=learn)

print('step: ', step, '| value:', value_loss.item(), '| policy', policy_loss.item())

utils.soft_update(nets['value_net'], nets['target_value_net'], soft_tau=params['soft_tau'])
utils.soft_update(nets['policy_net'], nets['target_policy_net'], soft_tau=params['soft_tau'])

if step % params['policy_step'] == 0 and step > 0:
policy_loss = params['reinforce'](nets['policy_net'], optimizer['policy_optimizer'])
del nets['policy_net'].rewards[:]
del nets['policy_net'].saved_log_probs[:]
print('step: ', step, '| value:', value_loss.item(), '| policy', policy_loss.item())

soft_update(nets['value_net'], nets['target_value_net'], soft_tau=params['soft_tau'])
soft_update(nets['policy_net'], nets['target_policy_net'], soft_tau=params['soft_tau'])
gc.collect()

losses = {'value': value_loss.item(),
'policy': policy_loss.item(),
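The soft_update calls above blend the online networks into their targets. As a hedged sketch of what such a Polyak update typically looks like (illustrative, not necessarily the exact recnn.utils implementation):

    import torch

    def soft_update(net, target_net, soft_tau=1e-2):
        # move each target parameter a small step toward the online network
        for param, target_param in zip(net.parameters(), target_net.parameters()):
            target_param.data.copy_(soft_tau * param.data + (1.0 - soft_tau) * target_param.data)

    net, target = torch.nn.Linear(4, 2), torch.nn.Linear(4, 2)
    soft_update(net, target, soft_tau=0.1)  # target moves 10% of the way toward net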
