Skip to content

Commit

Permalink
Cleanup visual obs setup (#2647)
Browse files Browse the repository at this point in the history
* DRY up the setup code

* fstrings
  • Loading branch information
Chris Elion committed Oct 1, 2019
1 parent fbc8857 commit d873d16
Showing 1 changed file with 28 additions and 42 deletions.
70 changes: 28 additions & 42 deletions ml-agents/mlagents/trainers/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,8 +247,8 @@ def create_vector_observation_encoder(
)
return hidden

@staticmethod
def create_visual_observation_encoder(
self,
image_input: tf.Tensor,
h_size: int,
activation: ActivationFunction,
Expand Down Expand Up @@ -288,13 +288,13 @@ def create_visual_observation_encoder(
hidden = c_layers.flatten(conv2)

with tf.variable_scope(scope + "/" + "flat_encoding"):
hidden_flat = self.create_vector_observation_encoder(
hidden_flat = LearningModel.create_vector_observation_encoder(
hidden, h_size, activation, num_layers, scope, reuse
)
return hidden_flat

@staticmethod
def create_nature_cnn_visual_observation_encoder(
self,
image_input: tf.Tensor,
h_size: int,
activation: ActivationFunction,
Expand Down Expand Up @@ -343,13 +343,13 @@ def create_nature_cnn_visual_observation_encoder(
hidden = c_layers.flatten(conv3)

with tf.variable_scope(scope + "/" + "flat_encoding"):
hidden_flat = self.create_vector_observation_encoder(
hidden_flat = LearningModel.create_vector_observation_encoder(
hidden, h_size, activation, num_layers, scope, reuse
)
return hidden_flat

@staticmethod
def create_resnet_visual_observation_encoder(
self,
image_input: tf.Tensor,
h_size: int,
activation: ActivationFunction,
Expand Down Expand Up @@ -411,7 +411,7 @@ def create_resnet_visual_observation_encoder(
hidden = c_layers.flatten(hidden)

with tf.variable_scope(scope + "/" + "flat_encoding"):
hidden_flat = self.create_vector_observation_encoder(
hidden_flat = LearningModel.create_vector_observation_encoder(
hidden, h_size, activation, num_layers, scope, reuse
)
return hidden_flat
Expand Down Expand Up @@ -470,7 +470,7 @@ def create_observation_streams(
num_layers: int,
vis_encode_type: EncoderType = EncoderType.SIMPLE,
stream_scopes: List[str] = None,
) -> tf.Tensor:
) -> List[tf.Tensor]:
"""
Creates encoding stream for observations.
:param num_streams: Number of streams to create.
Expand All @@ -491,54 +491,40 @@ def create_observation_streams(
self.visual_in.append(visual_input)
vector_observation_input = self.create_vector_input()

# Pick the encoder function based on the EncoderType
create_encoder_func = LearningModel.create_visual_observation_encoder
if vis_encode_type == EncoderType.RESNET:
create_encoder_func = LearningModel.create_resnet_visual_observation_encoder
elif vis_encode_type == EncoderType.NATURE_CNN:
create_encoder_func = (
LearningModel.create_nature_cnn_visual_observation_encoder
)

final_hiddens = []
for i in range(num_streams):
visual_encoders = []
hidden_state, hidden_visual = None, None
_scope_add = stream_scopes[i] if stream_scopes else ""
if self.vis_obs_size > 0:
if vis_encode_type == EncoderType.RESNET:
for j in range(brain.number_visual_observations):
encoded_visual = self.create_resnet_visual_observation_encoder(
self.visual_in[j],
h_size,
activation_fn,
num_layers,
_scope_add + "main_graph_{}_encoder{}".format(i, j),
False,
)
visual_encoders.append(encoded_visual)
elif vis_encode_type == EncoderType.NATURE_CNN:
for j in range(brain.number_visual_observations):
encoded_visual = self.create_nature_cnn_visual_observation_encoder(
self.visual_in[j],
h_size,
activation_fn,
num_layers,
_scope_add + "main_graph_{}_encoder{}".format(i, j),
False,
)
visual_encoders.append(encoded_visual)
else:
for j in range(brain.number_visual_observations):
encoded_visual = self.create_visual_observation_encoder(
self.visual_in[j],
h_size,
activation_fn,
num_layers,
_scope_add + "main_graph_{}_encoder{}".format(i, j),
False,
)
visual_encoders.append(encoded_visual)
for j in range(brain.number_visual_observations):
encoded_visual = create_encoder_func(
self.visual_in[j],
h_size,
activation_fn,
num_layers,
scope=f"{_scope_add}main_graph_{i}_encoder{j}",
reuse=False,
)
visual_encoders.append(encoded_visual)
hidden_visual = tf.concat(visual_encoders, axis=1)
if brain.vector_observation_space_size > 0:
hidden_state = self.create_vector_observation_encoder(
vector_observation_input,
h_size,
activation_fn,
num_layers,
_scope_add + "main_graph_{}".format(i),
False,
scope=f"{_scope_add}main_graph_{i}",
reuse=False,
)
if hidden_state is not None and hidden_visual is not None:
final_hidden = tf.concat([hidden_visual, hidden_state], axis=1)
Expand Down

0 comments on commit d873d16

Please sign in to comment.