Skip to content

Commit

Permalink
Merge pull request #94 from f0k/reshape-layer
Browse files Browse the repository at this point in the history
Add ReshapeLayer
  • Loading branch information
benanne committed Feb 8, 2015
2 parents 62ac92f + cd34882 commit ab058fa
Show file tree
Hide file tree
Showing 2 changed files with 199 additions and 1 deletion.
114 changes: 113 additions & 1 deletion lasagne/layers/shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
__all__ = [
"FlattenLayer",
"flatten",
"ReshapeLayer",
"reshape",
"PadLayer",
"pad",
]
Expand All @@ -26,6 +28,116 @@ def get_output_for(self, input, *args, **kwargs):
flatten = FlattenLayer # shortcut


class ReshapeLayer(Layer):
"""
A layer reshaping its input tensor to another tensor of the same total
number of elements.
:parameters:
- incoming : a :class:`Layer` instance or a tuple
the layer feeding into this layer, or the expected input shape
- shape : tuple
The target shape specification. Any of its elements can be `[i]`,
a single-element list of int, denoting to use the size of the ith
input dimension. At most one element can be `-1`, denoting to
infer the size for this dimension to match the total number of
elements of the input tensor. Any remaining elements must be
positive integers directly giving the size of the corresponding
dimension.
:usage:
>>> from lasagne.layers import InputLayer, ReshapeLayer
>>> l_in = InputLayer((None, 100, 20))
>>> l1 = ReshapeLayer(l_in, ([0], [1], 2, 10))
>>> l1.get_output_shape()
(None, 100, 2, 10)
>>> l2 = ReshapeLayer(l_in, ([0], 1, 2, 5, -1))
>>> l2.get_output_shape()
(None, 1, 2, 5, 200)
:note:
The tensor elements will be fetched and placed in C-like order. That
is, reshaping `[1,2,3,4,5,6]` to shape `(2,3)` will result in a matrix
`[[1,2,3],[4,5,6]]`, not in `[[1,3,5],[2,4,6]]` (Fortran-like order),
regardless of the memory layout of the input tensor. For C-contiguous
input, reshaping is cheap, for others it may require copying the data.
"""

def __init__(self, incoming, shape):
super(ReshapeLayer, self).__init__(incoming)
shape = tuple(shape)
for s in shape:
if isinstance(s, int):
if s == 0 or s < - 1:
raise ValueError("`shape` integers must be positive or -1")
elif isinstance(s, list):
if len(s) != 1 or not isinstance(s[0], int) or s[0] < 0:
raise ValueError("`shape` input references must be "
"single-element lists of int >= 0")
else:
raise ValueError("`shape` must be a tuple of int and/or [int]")
if sum(s == -1 for s in shape) > 1:
raise ValueError("`shape` cannot contain multiple -1")
self.shape = shape

def get_output_shape_for(self, input_shape, *args, **kwargs):
# Initialize output shape from shape specification
output_shape = list(self.shape)
# First, replace all `[i]` with the corresponding input dimension, and
# mask parts of the shapes thus becoming irrelevant for -1 inference
masked_input_shape = list(input_shape)
masked_output_shape = list(output_shape)
for dim, o in enumerate(output_shape):
if isinstance(o, list):
if o[0] >= len(input_shape):
raise ValueError("specification contains [%d], but input "
"shape has %d dimensions only" %
(o[0], len(input_shape)))
output_shape[dim] = input_shape[o[0]]
masked_output_shape[dim] = input_shape[o[0]]
if ((input_shape[o[0]] is None)
and (masked_input_shape[o[0]] is None)):
# first time we copied this unknown input size: mask it, we
# have a 1:1 correspondence between out[dim] and in[o[0]]
# and can ignore it for -1 inference even if it is unknown.
masked_input_shape[o[0]] = 1
masked_output_shape[dim] = 1
# From the shapes, compute the sizes of the input and output tensor
noneprod = lambda x, y: None if x is None or y is None else x * y
input_size = reduce(noneprod, masked_input_shape)
output_size = reduce(noneprod, masked_output_shape)
del masked_input_shape, masked_output_shape
# Finally, infer value for -1 if needed
if -1 in output_shape:
dim = output_shape.index(-1)
if (input_size is None) or (output_size is None):
output_shape[dim] = None
output_size = None
else:
output_size *= -1
output_shape[dim] = input_size // output_size
output_size *= output_shape[dim]
# Sanity check
if ((input_size is not None) and (output_size is not None)
and (input_size != output_size)):
raise ValueError("%s cannot be reshaped to specification %s. "
"The total size mismatches." %
(input_shape, self.shape))
return tuple(output_shape)

def get_output_for(self, input, *args, **kwargs):
# Replace all `[i]` with the corresponding input dimension
output_shape = list(self.shape)
for dim, o in enumerate(output_shape):
if isinstance(o, list):
output_shape[dim] = input.shape[o[0]]
# Everything else is handled by Theano
return input.reshape(tuple(output_shape))

reshape = ReshapeLayer # shortcut


class PadLayer(Layer):
def __init__(self, incoming, width, val=0, batch_ndim=2, **kwargs):
super(PadLayer, self).__init__(incoming, **kwargs)
Expand All @@ -46,4 +158,4 @@ def get_output_shape_for(self, input_shape):
def get_output_for(self, input, *args, **kwargs):
return padding.pad(input, self.width, self.val, self.batch_ndim)

pad = PadLayer # shortcut
pad = PadLayer # shortcut
86 changes: 86 additions & 0 deletions lasagne/tests/layers/test_shape.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
from mock import Mock
import numpy
import pytest
import theano


class TestReshapeLayer:
@pytest.fixture
def layerclass(self):
from lasagne.layers.shape import ReshapeLayer
return ReshapeLayer

@pytest.fixture
def two_unknown(self):
from lasagne.layers.input import InputLayer
shape = (16, 3, None, None, 10)
return (InputLayer(shape),
theano.shared(numpy.ones((16, 3, 5, 7, 10))))

def test_no_reference(self, layerclass, two_unknown):
inputlayer, inputdata = two_unknown
layer = layerclass(inputlayer, (16, 3, 5, 7, 2, 5))
assert layer.get_output_shape() == (16, 3, 5, 7, 2, 5)
result = layer.get_output_for(inputdata).eval()
assert result.shape == (16, 3, 5, 7, 2, 5)

def test_reference_both(self, layerclass, two_unknown):
inputlayer, inputdata = two_unknown
layer = layerclass(inputlayer, (-1, [1], [2], [3], 2, 5))
assert layer.get_output_shape() == (16, 3, None, None, 2, 5)
result = layer.get_output_for(inputdata).eval()
assert result.shape == (16, 3, 5, 7, 2, 5)

def test_reference_one(self, layerclass, two_unknown):
inputlayer, inputdata = two_unknown
layer = layerclass(inputlayer, (-1, [1], [2], 7, 2, 5))
assert layer.get_output_shape() == (None, 3, None, 7, 2, 5)
result = layer.get_output_for(inputdata).eval()
assert result.shape == (16, 3, 5, 7, 2, 5)

def test_reference_twice(self, layerclass, two_unknown):
inputlayer, inputdata = two_unknown
layer = layerclass(inputlayer, (-1, [1], [2], [3], 2, [2]))
assert layer.get_output_shape() == (None, 3, None, None, 2, None)
result = layer.get_output_for(inputdata).eval()
assert result.shape == (16, 3, 5, 7, 2, 5)

def test_merge_with_unknown(self, layerclass, two_unknown):
inputlayer, inputdata = two_unknown
layer = layerclass(inputlayer, ([0], [1], [2], -1))
assert layer.get_output_shape() == (16, 3, None, None)
result = layer.get_output_for(inputdata).eval()
assert result.shape == (16, 3, 5, 70)

def test_merge_two_unknowns(self, layerclass, two_unknown):
inputlayer, inputdata = two_unknown
layer = layerclass(inputlayer, ([0], [1], -1, [4]))
assert layer.get_output_shape() == (16, 3, None, 10)
result = layer.get_output_for(inputdata).eval()
assert result.shape == (16, 3, 35, 10)

def test_size_mismatch(self, layerclass, two_unknown):
inputlayer, inputdata = two_unknown
layer = layerclass(inputlayer, (17, 3, [2], [3], -1))
with pytest.raises(ValueError) as excinfo:
layer.get_output_shape() == (16, 3, None, 10)
assert 'match' in str(excinfo.value)

def test_invalid_spec(self, layerclass, two_unknown):
inputlayer, inputdata = two_unknown
with pytest.raises(ValueError):
layerclass(inputlayer, (-16, 3, 5, 7, 10))
with pytest.raises(ValueError):
layerclass(inputlayer, (-1, 3, 5, 7, -1))
with pytest.raises(ValueError):
layerclass(inputlayer, ([-1], 3, 5, 7, 10))
with pytest.raises(ValueError):
layerclass(inputlayer, ([0, 1], 3, 5, 7, 10))
with pytest.raises(ValueError):
layerclass(inputlayer, (None, 3, 5, 7, 10))

def test_reference_out_of_range(self, layerclass, two_unknown):
inputlayer, inputdata = two_unknown
layer = layerclass(inputlayer, (16, 3, 5, 7, [5]))
with pytest.raises(ValueError):
layer.get_output_for(inputdata)

0 comments on commit ab058fa

Please sign in to comment.