Skip to content

Commit

Permalink
attention network work
Browse files Browse the repository at this point in the history
  • Loading branch information
Adam Green committed Jun 3, 2021
1 parent fb6973b commit d3a7752
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 3 deletions.
4 changes: 2 additions & 2 deletions energypy/agent/qfunc.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,9 @@ def make_qfunc(obs_shape, n_actions, name, size_scale=1):
# if hyp.get('q-net') == 'attention':
# TODO
if False:
inputs, net = attention(obs_shape, 1, size_scale)
_, net = attention(obs_shape, 1, size_scale)
else:
inputs, net = dense(inputs, 1, size_scale)
_, net = dense(inputs, 1, size_scale)

return keras.Model(
inputs=[in_obs, in_act],
Expand Down
37 changes: 36 additions & 1 deletion energypy/train.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,14 @@
from energypy import utils, qfunc, policy, target, alpha


def train(
def train(*args, **kwargs):
    """Dispatch one training step to the matching implementation.

    All positional and keyword arguments are forwarded unchanged to
    ``train_one_head_network`` and its result is returned.

    Raises:
        NotImplementedError: if a ``network`` keyword argument is supplied;
            that path is not wired up to a training routine yet.
    """
    # membership is tested on the dict itself - no need to build a keys view
    if 'network' in kwargs:
        raise NotImplementedError(
            "training with a 'network' keyword argument is not implemented yet"
        )
    return train_one_head_network(*args, **kwargs)


def train_one_head_network(
batch,
actor,
onlines,
Expand Down Expand Up @@ -61,3 +68,31 @@ def train(
counters['alpha-update-seconds'] += utils.now() - st
counters['train-seconds'] += utils.now() - st
counters['train-steps'] += 1


def train_multi_head_network(
    batch,
    network,
    log_alpha,
    writer,
    optimizers,
    counters,
    hyp
):
    """Run one training step for a multi-head network (work in progress).

    Only the alpha update is wired up so far; the update of ``network``
    itself is still a placeholder.

    Args:
        batch: passed through to ``alpha.update``.
        network: the only model object this function receives; it is handed
            to ``alpha.update`` in the actor position (see review note below).
        log_alpha: passed through to ``alpha.update``.
        writer: passed through to ``alpha.update``.
        optimizers: mapping; the ``'alpha'`` entry is passed to
            ``alpha.update``.
        counters: mapping; ``'alpha-update-seconds'`` is incremented by the
            time spent in the alpha update.
        hyp: passed through to ``alpha.update``.
    """
    # TODO: train net - the update of `network` itself is not written yet

    # maybe do the actor fwd pass here...
    st = utils.now()
    alpha.update(
        batch,
        # NOTE(review): the original passed an undefined name `actor` here,
        # which raised NameError on every call. `network` is the only model
        # available in this signature - confirm it satisfies whatever
        # interface alpha.update expects of its actor argument.
        network,
        log_alpha,
        hyp,
        optimizers['alpha'],
        counters,
        writer
    )
    counters['alpha-update-seconds'] += utils.now() - st

0 comments on commit d3a7752

Please sign in to comment.