Skip to content
This repository has been archived by the owner on Dec 11, 2022. It is now read-only.

Commit

Permalink
Channel order transpose, for image embedder. Updated unit test. (#87)
Browse files Browse the repository at this point in the history
  • Loading branch information
thomelane authored and galnov committed Nov 19, 2018
1 parent ff816b3 commit 7ba1a43
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
Expand Up @@ -70,7 +70,6 @@ def hybrid_forward(self, F: ModuleType, x: nd_sym_type, *args, **kwargs) -> nd_s
:param x: image representing environment state, of shape (batch_size, in_channels, height, width).
:return: embedding of environment state, of shape (batch_size, channels).
"""
if len(x.shape) != 4 and self.scheme != EmbedderScheme.Empty:
raise ValueError("Image embedders expect the input size to have 4 dimensions. The given size is: {}"
.format(x.shape))
# convert from NHWC to NCHW (default for MXNet Convolutions)
x = x.transpose((0,3,1,2))
return super(ImageEmbedder, self).hybrid_forward(F, x, *args, **kwargs)
Expand Up @@ -15,7 +15,8 @@ def test_image_embedder():
params = InputEmbedderParameters(scheme=EmbedderScheme.Medium)
emb = ImageEmbedder(params=params)
emb.initialize()
input_data = mx.nd.random.uniform(low=0, high=1, shape=(10, 3, 244, 244))
# input is NHWC, and not MXNet default NCHW
input_data = mx.nd.random.uniform(low=0, high=1, shape=(10, 244, 244, 3))
output = emb(input_data)
assert len(output.shape) == 2 # since last block was flatten
assert output.shape[0] == 10 # since batch_size is 10

0 comments on commit 7ba1a43

Please sign in to comment.