From 3b50f10630a5f5587294b9c6209357e8f76c82f6 Mon Sep 17 00:00:00 2001
From: lvapeab
Date: Wed, 3 Jan 2018 00:18:37 +0100
Subject: [PATCH] Fix Conditional RNNs input specs

---
 keras/layers/recurrent.py | 13 +++++++++++--
 1 file changed, 11 insertions(+), 2 deletions(-)

diff --git a/keras/layers/recurrent.py b/keras/layers/recurrent.py
index 3bd5a908c8a..e0ad256091d 100644
--- a/keras/layers/recurrent.py
+++ b/keras/layers/recurrent.py
@@ -1667,6 +1667,7 @@ def __init__(self, units,
                  recurrent_dropout=0.,
                  conditional_dropout=0.,
                  num_inputs=4,
+                 static_ctx=False,
                  **kwargs):
         super(GRUCond, self).__init__(**kwargs)
 
@@ -1705,6 +1706,9 @@ def __init__(self, units,
         self.conditional_dropout = min(1., max(0., conditional_dropout)) if conditional_dropout is not None else 0.
         self.num_inputs = num_inputs
-        self.input_spec = [InputSpec(ndim=3), InputSpec(ndim=3)]
+        if static_ctx:
+            self.input_spec = [InputSpec(ndim=3), InputSpec(ndim=2)]
+        else:
+            self.input_spec = [InputSpec(ndim=3), InputSpec(ndim=3)]
         for _ in range(len(self.input_spec), self.num_inputs):
             self.input_spec.append(InputSpec(ndim=2))
 
@@ -4226,6 +4231,7 @@ def __init__(self, units,
                  recurrent_dropout=0.,
                  conditional_dropout=0.,
                  num_inputs=4,
+                 static_ctx=False,
                  **kwargs):
         super(LSTMCond, self).__init__(**kwargs)
 
@@ -4271,7 +4277,11 @@ def __init__(self, units,
         self.recurrent_dropout = min(1., max(0., recurrent_dropout))if recurrent_dropout is not None else 0.
         self.conditional_dropout = min(1., max(0., conditional_dropout))if conditional_dropout is not None else 0.
         self.num_inputs = num_inputs
-        self.input_spec = [InputSpec(ndim=3), InputSpec(ndim=2)]
+        if static_ctx:
+            self.input_spec = [InputSpec(ndim=3), InputSpec(ndim=2)]
+        else:
+            self.input_spec = [InputSpec(ndim=3), InputSpec(ndim=3)]
+
         for _ in range(len(self.input_spec), self.num_inputs):
             self.input_spec.append(InputSpec(ndim=2))