From 7ba1a4393ff563beceb420b1ed870e3e016140f6 Mon Sep 17 00:00:00 2001
From: Thom Lane
Date: Mon, 19 Nov 2018 05:39:03 -0800
Subject: [PATCH] Channel order transpose, for image embedder. Updated unit
 test. (#87)

---
 .../mxnet_components/embedders/image_embedder.py      | 5 ++---
 .../mxnet_components/embedders/test_image_embedder.py | 3 ++-
 2 files changed, 4 insertions(+), 4 deletions(-)

diff --git a/rl_coach/architectures/mxnet_components/embedders/image_embedder.py b/rl_coach/architectures/mxnet_components/embedders/image_embedder.py
index 36842d8f5..bbfddbaf6 100644
--- a/rl_coach/architectures/mxnet_components/embedders/image_embedder.py
+++ b/rl_coach/architectures/mxnet_components/embedders/image_embedder.py
@@ -70,7 +70,6 @@ def hybrid_forward(self, F: ModuleType, x: nd_sym_type, *args, **kwargs) -> nd_sym_type:
         :param x: image representing environment state, of shape (batch_size, in_channels, height, width).
         :return: embedding of environment state, of shape (batch_size, channels).
         """
-        if len(x.shape) != 4 and self.scheme != EmbedderScheme.Empty:
-            raise ValueError("Image embedders expect the input size to have 4 dimensions. The given size is: {}"
-                             .format(x.shape))
+        # convert from NHWC to NCHW (the default layout for MXNet convolutions)
+        x = x.transpose((0, 3, 1, 2))
         return super(ImageEmbedder, self).hybrid_forward(F, x, *args, **kwargs)
diff --git a/rl_coach/tests/architectures/mxnet_components/embedders/test_image_embedder.py b/rl_coach/tests/architectures/mxnet_components/embedders/test_image_embedder.py
index 0e7f9da53..9fecff856 100644
--- a/rl_coach/tests/architectures/mxnet_components/embedders/test_image_embedder.py
+++ b/rl_coach/tests/architectures/mxnet_components/embedders/test_image_embedder.py
@@ -15,7 +15,8 @@ def test_image_embedder():
     params = InputEmbedderParameters(scheme=EmbedderScheme.Medium)
     emb = ImageEmbedder(params=params)
     emb.initialize()
-    input_data = mx.nd.random.uniform(low=0, high=1, shape=(10, 3, 244, 244))
+    # input is NHWC, not MXNet's default NCHW
+    input_data = mx.nd.random.uniform(low=0, high=1, shape=(10, 244, 244, 3))
     output = emb(input_data)
     assert len(output.shape) == 2  # since last block was flatten
     assert output.shape[0] == 10  # since batch_size is 10
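
For anyone reading the patch out of context: the core change replaces a shape check with an NHWC-to-NCHW transpose inside hybrid_forward, so callers can keep supplying images in the natural NHWC convention while MXNet's convolutions see NCHW. Below is a minimal standalone sketch of that transpose, not part of the patch itself; it assumes only that mxnet is installed, and the shapes simply mirror the updated unit test rather than any rl_coach code.

    import mxnet as mx

    # A batch of 10 images in NHWC layout (batch, height, width, channels),
    # which is how the updated test now supplies observations.
    nhwc = mx.nd.random.uniform(low=0, high=1, shape=(10, 244, 244, 3))

    # MXNet convolutions default to NCHW, so move channels (axis 3) to axis 1,
    # keeping batch (axis 0) first and height/width (axes 1, 2) after channels.
    nchw = nhwc.transpose((0, 3, 1, 2))

    print(nhwc.shape)  # (10, 244, 244, 3)
    print(nchw.shape)  # (10, 3, 244, 244)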