From b2113b3b0a0fd0fedf0aea152b2e8200575c9bd0 Mon Sep 17 00:00:00 2001
From: skaae
Date: Tue, 29 Dec 2015 19:47:26 +0100
Subject: [PATCH] remove Tensorvariable init

---
 lasagne/layers/recurrent.py            | 102 +++++---------------
 lasagne/tests/layers/test_recurrent.py | 124 -------------------------
 2 files changed, 22 insertions(+), 204 deletions(-)

diff --git a/lasagne/layers/recurrent.py b/lasagne/layers/recurrent.py
index e3d3f695..ee1e49c5 100644
--- a/lasagne/layers/recurrent.py
+++ b/lasagne/layers/recurrent.py
@@ -121,18 +121,14 @@ class CustomRecurrentLayer(MergeLayer):
     nonlinearity : callable or None
         Nonlinearity to apply when computing new state (:math:`\sigma`).
         If None is provided, no nonlinearity will be applied.
-    hid_init : callable, np.ndarray, theano.shared, TensorVariable or Layer
-        Initializer for initial hidden state (:math:`h_0`). If a
-        TensorVariable (Theano expression) is supplied, it will not be learned
-        regardless of the value of `learn_init`.
+    hid_init : callable, np.ndarray, theano.shared or :class:`Layer`
+        Initializer for initial hidden state (:math:`h_0`).
     backwards : bool
         If True, process the sequence backwards and then reverse the
         output again such that the output from the layer is always
         from :math:`x_1` to :math:`x_n`.
     learn_init : bool
-        If True, initial hidden values are learned. If `hid_init` is a
-        TensorVariable then the TensorVariable is used and
-        `learn_init` is ignored.
+        If True, initial hidden values are learned.
     gradient_steps : int
         Number of timesteps to include in the backpropagated gradient.
         If -1, backpropagate through the entire sequence.
@@ -315,13 +311,7 @@ def __init__(self, incoming, input_to_hidden, hidden_to_hidden,
             self.nonlinearity = nonlinearity

         # Initialize hidden state
-        if isinstance(hid_init, T.TensorVariable):
-            if hid_init.ndim != len(hidden_to_hidden.output_shape):
-                raise ValueError(
-                    "When hid_init is provided as a TensorVariable, it should "
-                    "have the same shape as hidden_to_hidden.output_shape")
-            self.hid_init = hid_init
-        elif isinstance(hid_init, Layer):
+        if isinstance(hid_init, Layer):
             self.hid_init = hid_init
         else:
             self.hid_init = self.add_param(
@@ -450,12 +440,7 @@ def step_masked(input_n, mask_n, hid_previous, *args):
             sequences = input
             step_fun = step

-        if isinstance(self.hid_init, Layer):
-            pass
-        elif isinstance(self.hid_init, T.TensorVariable):
-            # When hid_init is provided as a TensorVariable, use it as-is
-            hid_init = self.hid_init
-        else:
+        if not isinstance(self.hid_init, Layer):
             # The code below simply repeats self.hid_init num_batch times in
             # its first dimension. Turns out using a dot product and a
             # dimshuffle is faster than T.repeat.
@@ -535,17 +520,14 @@ class RecurrentLayer(CustomRecurrentLayer):
     nonlinearity : callable or None
         Nonlinearity to apply when computing new state (:math:`\sigma`).
         If None is provided, no nonlinearity will be applied.
-    hid_init : callable, np.ndarray, theano.shared, TensorVariable or Layer
-        Initializer for initial hidden state (:math:`h_0`). If a
-        TensorVariable (Theano expression) is supplied, it will not be learned
-        regardless of the value of `learn_init`.
+    hid_init : callable, np.ndarray, theano.shared or :class:`Layer`
+        Initializer for initial hidden state (:math:`h_0`).
     backwards : bool
         If True, process the sequence backwards and then reverse the
         output again such that the output from the layer is always
         from :math:`x_1` to :math:`x_n`.
     learn_init : bool
-        If True, initial hidden values are learned. If `hid_init` is a
-        TensorVariable then `learn_init` is ignored.
+        If True, initial hidden values are learned.
     gradient_steps : int
         Number of timesteps to include in the backpropagated gradient.
         If -1, backpropagate through the entire sequence.
@@ -748,22 +730,16 @@ class LSTMLayer(MergeLayer):
     nonlinearity : callable or None
         The nonlinearity that is applied to the output (:math:`\sigma_h`). If
         None is provided, no nonlinearity will be applied.
-    cell_init : callable, np.ndarray, theano.shared, TensorVariable or Layer
-        Initializer for initial cell state (:math:`c_0`). If a
-        TensorVariable (Theano expression) is supplied, it will not be learned
-        regardless of the value of `learn_init`.
-    hid_init : callable, np.ndarray, theano.shared, TensorVariable or Layer
-        Initializer for initial hidden state (:math:`h_0`). If a
-        TensorVariable (Theano expression) is supplied, it will not be learned
-        regardless of the value of `learn_init`.
+    cell_init : callable, np.ndarray, theano.shared or :class:`Layer`
+        Initializer for initial cell state (:math:`c_0`).
+    hid_init : callable, np.ndarray, theano.shared or :class:`Layer`
+        Initializer for initial hidden state (:math:`h_0`).
     backwards : bool
         If True, process the sequence backwards and then reverse the
         output again such that the output from the layer is always
         from :math:`x_1` to :math:`x_n`.
     learn_init : bool
-        If True, initial hidden values are learned. If `hid_init` or
-        `cell_init` are TensorVariables then the TensorVariable is used and
-        `learn_init` is ignored for that initial state.
+        If True, initial hidden values are learned.
     peepholes : bool
         If True, the LSTM uses peephole connections.
         When False, `ingate.W_cell`, `forgetgate.W_cell` and
@@ -909,26 +885,14 @@ def add_gate_params(gate, gate_name):
                 outgate.W_cell, (num_units, ), name="W_cell_to_outgate")

         # Setup initial values for the cell and the hidden units
-        if isinstance(cell_init, T.TensorVariable):
-            if cell_init.ndim != 2:
-                raise ValueError(
-                    "When cell_init is provided as a TensorVariable, it should"
-                    " have 2 dimensions and have shape (num_batch, num_units)")
-            self.cell_init = cell_init
-        elif isinstance(cell_init, Layer):
+        if isinstance(cell_init, Layer):
             self.cell_init = cell_init
         else:
             self.cell_init = self.add_param(
                 cell_init, (1, num_units), name="cell_init",
                 trainable=learn_init, regularizable=False)

-        if isinstance(hid_init, T.TensorVariable):
-            if hid_init.ndim != 2:
-                raise ValueError(
-                    "When hid_init is provided as a TensorVariable, it should "
-                    "have 2 dimensions and have shape (num_batch, num_units)")
-            self.hid_init = hid_init
-        elif isinstance(hid_init, Layer):
+        if isinstance(hid_init, Layer):
             self.hid_init = hid_init
         else:
             self.hid_init = self.add_param(
@@ -1092,19 +1056,11 @@ def step_masked(input_n, mask_n, cell_previous, hid_previous, *args):
             step_fun = step

         ones = T.ones((num_batch, 1))
-        if isinstance(self.cell_init, Layer):
-            pass
-        elif isinstance(self.cell_init, T.TensorVariable):
-            cell_init = self.cell_init
-        else:
+        if not isinstance(self.cell_init, Layer):
             # Dot against a 1s vector to repeat to shape (num_batch, num_units)
             cell_init = T.dot(ones, self.cell_init)

-        if isinstance(self.hid_init, Layer):
-            pass
-        elif isinstance(self.hid_init, T.TensorVariable):
-            hid_init = self.hid_init
-        else:
+        if not isinstance(self.hid_init, Layer):
             # Dot against a 1s vector to repeat to shape (num_batch, num_units)
             hid_init = T.dot(ones, self.hid_init)

@@ -1196,18 +1152,14 @@ class GRULayer(MergeLayer):
     hidden_update : Gate
         Parameters for the hidden update (:math:`c_t`): :math:`W_{xc}`,
         :math:`W_{hc}`, :math:`b_c`, and :math:`\sigma_c`.
-    hid_init : callable, np.ndarray, theano.shared, TensorVariable or Layer
-        Initializer for initial hidden state (:math:`h_0`). If a
-        TensorVariable (Theano expression) is supplied, it will not be learned
-        regardless of the value of `learn_init`.
+    hid_init : callable, np.ndarray, theano.shared or :class:`Layer`
+        Initializer for initial hidden state (:math:`h_0`).
     backwards : bool
         If True, process the sequence backwards and then reverse the
         output again such that the output from the layer is always
         from :math:`x_1` to :math:`x_n`.
     learn_init : bool
-        If True, initial hidden values are learned. If `hid_init` is a
-        TensorVariable then the TensorVariable is used and
-        `learn_init` is ignored.
+        If True, initial hidden values are learned.
     gradient_steps : int
         Number of timesteps to include in the backpropagated gradient.
         If -1, backpropagate through the entire sequence.
@@ -1335,13 +1287,7 @@ def add_gate_params(gate, gate_name):
             hidden_update, 'hidden_update')

         # Initialize hidden state
-        if isinstance(hid_init, T.TensorVariable):
-            if hid_init.ndim != 2:
-                raise ValueError(
-                    "When hid_init is provided as a TensorVariable, it should "
-                    "have 2 dimensions and have shape (num_batch, num_units)")
-            self.hid_init = hid_init
-        elif isinstance(hid_init, Layer):
+        if isinstance(hid_init, Layer):
             self.hid_init = hid_init
         else:
             self.hid_init = self.add_param(
@@ -1487,11 +1433,7 @@ def step_masked(input_n, mask_n, hid_previous, *args):
             sequences = [input]
             step_fun = step

-        if isinstance(self.hid_init, Layer):
-            pass
-        elif isinstance(self.hid_init, T.TensorVariable):
-            hid_init = self.hid_init
-        else:
+        if not isinstance(self.hid_init, Layer):
             # Dot against a 1s vector to repeat to shape (num_batch, num_units)
             hid_init = T.dot(T.ones((num_batch, 1)), self.hid_init)

diff --git a/lasagne/tests/layers/test_recurrent.py b/lasagne/tests/layers/test_recurrent.py
index dc38cafc..32e5770d 100644
--- a/lasagne/tests/layers/test_recurrent.py
+++ b/lasagne/tests/layers/test_recurrent.py
@@ -104,49 +104,12 @@ def test_recurrent_hid_init_mask():
     output = lasagne.layers.get_output(l_rec, inputs)


-def test_recurrent_tensor_init():
-    # check if passing in a TensorVariable to hid_init works
-    num_units = 5
-    batch_size = 3
-    seq_len = 2
-    n_inputs = 4
-    in_shp = (batch_size, seq_len, n_inputs)
-    l_inp = InputLayer(in_shp)
-    hid_init = T.matrix()
-    x = T.tensor3()
-
-    l_rec = RecurrentLayer(l_inp, num_units, learn_init=True,
-                           hid_init=hid_init)
-    # check that the tensor is used
-    assert hid_init == l_rec.hid_init
-
-    # b, W_hid_to_hid and W_in_to_hid, should not return any inits
-    assert len(lasagne.layers.get_all_params(l_rec, trainable=True)) == 3
-
-    # b, should not return any inits
-    assert len(lasagne.layers.get_all_params(l_rec, regularizable=False)) == 1
-
-    # check that it compiles and runs
-    output = lasagne.layers.get_output(l_rec, x)
-    x_test = np.ones(in_shp, dtype='float32')
-    hid_init_test = np.ones((batch_size, num_units), dtype='float32')
-    output_val = output.eval({x: x_test, hid_init: hid_init_test})
-    assert isinstance(output_val, np.ndarray)
-
-
 def test_recurrent_incoming_tuple():
     input_shape = (2, 3, 4)
     l_rec = lasagne.layers.RecurrentLayer(input_shape, 5)
     assert l_rec.input_shapes[0] == input_shape


-def test_recurrent_init_val_error():
-    # check if errors are raised when init is non matrix tensor
-    hid_init = T.vector()
-    with pytest.raises(ValueError):
-        l_rec = RecurrentLayer(InputLayer((2, 2, 3)), 5,
-                               hid_init=hid_init)
-
-
 def test_recurrent_name():
     l_in = lasagne.layers.InputLayer((2, 3, 4))
     layer_name = 'l_rec'
@@ -445,44 +408,6 @@ def test_lstm_nparams_learn_init():
     assert len(lasagne.layers.get_all_params(l_lstm, regularizable=False)) == 6


-def test_lstm_tensor_init():
-    # check if passing in TensorVariables to cell_init and hid_init works
-    num_units = 5
-    batch_size = 3
-    seq_len = 2
-    n_inputs = 4
-    in_shp = (batch_size, seq_len, n_inputs)
-    l_inp = InputLayer(in_shp)
-    hid_init = T.matrix()
-    cell_init = T.matrix()
-    x = T.tensor3()
-
-    l_lstm = LSTMLayer(l_inp, num_units, peepholes=False, learn_init=True,
-                       hid_init=hid_init, cell_init=cell_init)
-
-    # check that the tensors are used and not overwritten
-    assert cell_init == l_lstm.cell_init
-    assert hid_init == l_lstm.hid_init
-
-    # 3*n_gates, should not return any inits
-    # the 3 is because we have hid_to_gate, in_to_gate and bias for each gate
-    assert len(lasagne.layers.get_all_params(l_lstm, trainable=True)) == 12
-
-    # bias params(4), , should not return any inits
-    assert len(lasagne.layers.get_all_params(l_lstm, regularizable=False)) == 4
-
-    # check that it compiles and runs
-    output = lasagne.layers.get_output(l_lstm, x)
-
-    x_test = np.ones(in_shp, dtype='float32')
-    hid_init_test = np.ones((batch_size, num_units), dtype='float32')
-    cell_init_test = np.ones_like(hid_init_test)
-    output_val = output.eval(
-        {x: x_test, cell_init: cell_init_test, hid_init: hid_init_test})
-
-    assert isinstance(output_val, np.ndarray)
-
-
 def test_lstm_hid_init_layer():
     # test that you can set hid_init to be a layer
     l_inp = InputLayer((2, 2, 3))
@@ -536,16 +461,6 @@ def test_lstm_hid_init_mask():
     output = lasagne.layers.get_output(l_lstm, inputs)


-def test_lstm_init_val_error():
-    # check if errors are raised when inits are non matrix tensor
-    vector = T.vector()
-    with pytest.raises(ValueError):
-        l_rec = LSTMLayer(InputLayer((2, 2, 3)), 5, hid_init=vector)
-
-    with pytest.raises(ValueError):
-        l_rec = LSTMLayer(InputLayer((2, 2, 3)), 5, cell_init=vector)
-
-
 def test_lstm_grad_clipping():
     # test that you can set grad_clip variable
     x = T.tensor3()
@@ -763,45 +678,6 @@ def test_gru_nparams_learn_init_true():
     assert len(lasagne.layers.get_all_params(l_gru, regularizable=False)) == 4


-def test_gru_tensor_init():
-    # check if passing in a TensorVariable to hid_init works
-    num_units = 5
-    batch_size = 3
-    seq_len = 2
-    n_inputs = 4
-    in_shp = (batch_size, seq_len, n_inputs)
-    l_inp = InputLayer(in_shp)
-    hid_init = T.matrix()
-    x = T.tensor3()
-
-    l_lstm = GRULayer(l_inp, num_units, learn_init=True, hid_init=hid_init)
-
-    # check that the tensors are used and not overwritten
-    assert hid_init == l_lstm.hid_init
-
-    # 3*n_gates, should not return any inits
-    # the 3 is because we have hid_to_gate, in_to_gate and bias for each gate
-    assert len(lasagne.layers.get_all_params(l_lstm, trainable=True)) == 9
-
-    # bias params(3), , should not return any inits
-    assert len(lasagne.layers.get_all_params(l_lstm, regularizable=False)) == 3
-
-    # check that it compiles and runs
-    output = lasagne.layers.get_output(l_lstm, x)
-    x_test = np.ones(in_shp, dtype='float32')
-    hid_init_test = np.ones((batch_size, num_units), dtype='float32')
-
-    output_val = output.eval({x: x_test, hid_init: hid_init_test})
-    assert isinstance(output_val, np.ndarray)
-
-
-def test_gru_init_val_error():
-    # check if errors are raised when init is non matrix tensorVariable
-    vector = T.vector()
-    with pytest.raises(ValueError):
-        l_rec = GRULayer(InputLayer((2, 2, 3)), 5,
-                         hid_init=vector)
-
-
 def test_gru_hid_init_layer():
     # test that you can set hid_init to be a layer
     l_inp = InputLayer((2, 2, 3))
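
Usage note (not part of the patch): with direct TensorVariable initialization removed, a symbolic initial state can still be supplied by passing a Layer as hid_init, which is the code path the surviving isinstance(hid_init, Layer) branches and the remaining *_hid_init_layer tests exercise. Below is a minimal sketch of that pattern; the shapes are borrowed from the removed tests, and the names l_hid_init and h0 are illustrative, not anything defined by the patch.

    import numpy as np
    import theano.tensor as T
    import lasagne
    from lasagne.layers import InputLayer, RecurrentLayer

    # Illustrative shapes only.
    batch_size, seq_len, n_inputs, num_units = 3, 2, 4, 5

    l_inp = InputLayer((batch_size, seq_len, n_inputs))
    # The initial hidden state enters the graph through its own InputLayer
    # instead of a raw TensorVariable.
    l_hid_init = InputLayer((batch_size, num_units))
    l_rec = RecurrentLayer(l_inp, num_units, hid_init=l_hid_init)

    x = T.tensor3()
    h0 = T.matrix()
    # Map each InputLayer to its symbolic input when building the output.
    output = lasagne.layers.get_output(l_rec, {l_inp: x, l_hid_init: h0})
    output_val = output.eval({
        x: np.ones((batch_size, seq_len, n_inputs), dtype='float32'),
        h0: np.ones((batch_size, num_units), dtype='float32'),
    })

The same pattern applies to LSTMLayer's cell_init and GRULayer's hid_init, which the diff handles with the identical isinstance(..., Layer) branch.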