Skip to content

Commit

Permalink
Rename function
Browse files Browse the repository at this point in the history
  • Loading branch information
awjuliani committed Jun 22, 2018
1 parent 5424a46 commit 0022854
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 13 deletions.
10 changes: 5 additions & 5 deletions python/unitytrainers/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ def create_normalizer_update(self, vector_input):
return update_mean, update_variance

@staticmethod
def create_continuous_observation_encoder(observation_input, h_size, activation, num_layers, scope, reuse):
def create_vector_observation_encoder(observation_input, h_size, activation, num_layers, scope, reuse):
"""
Builds a set of hidden state encoders.
:param reuse: Whether to re-use the weights within the same scope.
Expand Down Expand Up @@ -128,8 +128,8 @@ def create_visual_observation_encoder(self, image_input, h_size, activation, num
hidden = c_layers.flatten(conv2)

with tf.variable_scope(scope+'/'+'flat_encoding'):
hidden_flat = self.create_continuous_observation_encoder(hidden, h_size, activation,
num_layers, scope, reuse)
hidden_flat = self.create_vector_observation_encoder(hidden, h_size, activation,
num_layers, scope, reuse)
return hidden_flat

def create_observation_streams(self, num_streams, h_size, num_layers):
Expand Down Expand Up @@ -165,8 +165,8 @@ def create_observation_streams(self, num_streams, h_size, num_layers):
visual_encoders.append(encoded_visual)
hidden_visual = tf.concat(visual_encoders, axis=1)
if brain.vector_observation_space_size > 0:
hidden_state = self.create_continuous_observation_encoder(vector_observation_input,
h_size, activation_fn, num_layers,
hidden_state = self.create_vector_observation_encoder(vector_observation_input,
h_size, activation_fn, num_layers,
"main_graph_{}".format(i), False)
if hidden_state is not None and hidden_visual is not None:
final_hidden = tf.concat([hidden_visual, hidden_state], axis=1)
Expand Down
16 changes: 8 additions & 8 deletions python/unitytrainers/ppo/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,15 +96,15 @@ def create_curiosity_encoders(self):
self.next_vector_in = tf.placeholder(shape=[None, self.o_size], dtype=tf.float32,
name='next_vector_observation')

encoded_vector_obs = self.create_continuous_observation_encoder(self.vector_in,
self.curiosity_enc_size,
self.swish, 2, "vector_obs_encoder",
False)
encoded_next_vector_obs = self.create_continuous_observation_encoder(self.next_vector_in,
self.curiosity_enc_size,
self.swish, 2,
encoded_vector_obs = self.create_vector_observation_encoder(self.vector_in,
self.curiosity_enc_size,
self.swish, 2, "vector_obs_encoder",
False)
encoded_next_vector_obs = self.create_vector_observation_encoder(self.next_vector_in,
self.curiosity_enc_size,
self.swish, 2,
"vector_obs_encoder",
True)
True)
encoded_state_list.append(encoded_vector_obs)
encoded_next_state_list.append(encoded_next_vector_obs)

Expand Down

0 comments on commit 0022854

Please sign in to comment.