Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding a fully connected visual encoder for super small visual input + tests #5351

Merged
merged 5 commits into from
May 11, 2021
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions ml-agents/mlagents/trainers/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ def as_dict(self):


class EncoderType(Enum):
FULLY_CONNECTED = "fully_connected"
MATCH3 = "match3"
SIMPLE = "simple"
NATURE_CNN = "nature_cnn"
Expand Down
11 changes: 10 additions & 1 deletion ml-agents/mlagents/trainers/tests/torch/test_encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
from mlagents.trainers.torch.encoders import (
VectorInput,
Normalizer,
SmallVisualEncoder,
FullyConnectedVisualEncoder,
SimpleVisualEncoder,
ResNetVisualEncoder,
NatureVisualEncoder,
Expand Down Expand Up @@ -73,7 +75,14 @@ def test_vector_encoder(mock_normalizer):

@pytest.mark.parametrize("image_size", [(36, 36, 3), (84, 84, 4), (256, 256, 5)])
@pytest.mark.parametrize(
"vis_class", [SimpleVisualEncoder, ResNetVisualEncoder, NatureVisualEncoder]
"vis_class",
[
SimpleVisualEncoder,
ResNetVisualEncoder,
NatureVisualEncoder,
SmallVisualEncoder,
FullyConnectedVisualEncoder,
],
)
def test_visual_encoder(vis_class, image_size):
num_outputs = 128
Expand Down
24 changes: 24 additions & 0 deletions ml-agents/mlagents/trainers/torch/encoders.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,30 @@ def update_normalization(self, inputs: torch.Tensor) -> None:
self.normalizer.update(inputs)


class FullyConnectedVisualEncoder(nn.Module):
def __init__(
self, height: int, width: int, initial_channels: int, output_size: int
):
super().__init__()
self.output_size = output_size
self.input_size = height * width * initial_channels
self.dense = nn.Sequential(
linear_layer(
self.input_size,
self.output_size,
kernel_init=Initialization.KaimingHeNormal,
kernel_gain=1.41, # Use ReLU gain
),
nn.LeakyReLU(),
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We could use another activation here, but I think ReLU is enough.

)

def forward(self, visual_obs: torch.Tensor) -> torch.Tensor:
if not exporting_to_onnx.is_exporting():
visual_obs = visual_obs.permute([0, 3, 1, 2])
hidden = visual_obs.reshape(-1, self.input_size)
return self.dense(hidden)


class SmallVisualEncoder(nn.Module):
"""
CNN architecture used by King in their Candy Crush predictor
Expand Down
3 changes: 3 additions & 0 deletions ml-agents/mlagents/trainers/torch/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
ResNetVisualEncoder,
NatureVisualEncoder,
SmallVisualEncoder,
FullyConnectedVisualEncoder,
VectorInput,
)
from mlagents.trainers.settings import EncoderType, ScheduleType
Expand All @@ -20,6 +21,7 @@ class ModelUtils:
# Minimum supported side for each encoder type. If refactoring an encoder, please
# adjust these also.
MIN_RESOLUTION_FOR_ENCODER = {
EncoderType.FULLY_CONNECTED: 0,
EncoderType.MATCH3: 5,
EncoderType.SIMPLE: 20,
EncoderType.NATURE_CNN: 36,
Expand Down Expand Up @@ -123,6 +125,7 @@ def get_encoder_for_type(encoder_type: EncoderType) -> nn.Module:
EncoderType.NATURE_CNN: NatureVisualEncoder,
EncoderType.RESNET: ResNetVisualEncoder,
EncoderType.MATCH3: SmallVisualEncoder,
EncoderType.FULLY_CONNECTED: FullyConnectedVisualEncoder,
}
return ENCODER_FUNCTION_BY_TYPE.get(encoder_type)

Expand Down