-
Notifications
You must be signed in to change notification settings - Fork 4.1k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add different types of visual encoder (nature cnn/resnet) #2289
Conversation
Noticed this doesn't actually hook the different encoder types up to any config. Is that intended to be a separate PR / planned at all? BTW, also please make sure to make a descriptive commit message and PR description (they can be the same) for this. |
Added config option for encoder type in trainer config |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Minor change request on removing the quotes in yaml. This occurs both in tests and in the main trainer_config
. Otherwise LGTM!
@@ -302,7 +434,9 @@ def create_discrete_action_masking_layer(all_logits, action_masks, action_size): | |||
), | |||
) | |||
|
|||
def create_observation_streams(self, num_streams, h_size, num_layers): | |||
def create_observation_streams( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please add type annotations here.
@@ -401,13 +559,15 @@ def create_value_heads(self, stream_names, hidden_input): | |||
self.value_heads[name] = value | |||
self.value = tf.reduce_mean(list(self.value_heads.values()), 0) | |||
|
|||
def create_cc_actor_critic(self, h_size, num_layers): | |||
def create_cc_actor_critic(self, h_size, num_layers, vis_encode_type): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
type annotations
@@ -44,6 +44,7 @@ def __init__(self, seed, brain, trainer_params, is_training, load): | |||
m_size=self.m_size, | |||
seed=seed, | |||
stream_names=list(reward_signal_configs.keys()), | |||
vis_encode_type=trainer_params["vis_encode_type"], |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Convert to EncoderType enum here and pass that through instead of the string type. Handle a missing value with .get(). For example
vis_encode_type = EncoderType(trainer_params.get("vis_encode_type", "default"))
@dongruoping can you make sure to address the comments here from @chriselion in a following PR? |
Add different types of visual encoder (nature cnn/resnet)