Skip to content

Commit

Permalink
Added cx.Layer(bidirectional=mode)
Browse files Browse the repository at this point in the history
  • Loading branch information
dsblank committed Sep 7, 2018
1 parent 4866842 commit 4bab1a6
Showing 1 changed file with 32 additions and 3 deletions.
35 changes: 32 additions & 3 deletions conx/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,15 @@ def __init__(self, name, *args, **params):
else:
self.dropout = 0

if 'bidirectional' in params:
bidirectional = params['bidirectional']
del params["bidirectional"] # we handle it
if bidirectional not in ['sum', 'mul', 'concat', 'ave', True, None]:
raise Exception('bad bidirectional value: %s' % (bidirectional,))
self.bidirectional = bidirectional
else:
self.bidirectional = None

if 'time_distributed' in params:
time_distributed = params['time_distributed']
del params["time_distributed"] # we handle time distributed wrappers
Expand Down Expand Up @@ -266,23 +275,39 @@ def make_keras_functions(self):
Make all Keras functions for this layer, including its own,
dropout, etc.
"""
from keras.layers import TimeDistributed
from keras.layers import TimeDistributed, Bidirectional
k = self.make_keras_function() # can override
### wrap layer:
if self.bidirectional:
if self.bidirectional is True:
k = Bidirectional(k, name=self.name)
else:
k = Bidirectional(k, merge_mode=self.bidirectional, name=self.name)
if self.time_distributed:
k = TimeDistributed(k, name=self.name)
### sequence:
if self.dropout > 0:
return [k, keras.layers.Dropout(self.dropout)]
k = [k] + [keras.layers.Dropout(self.dropout)]
else:
return [k]
k = [k]
return k

def make_keras_functions_text(self):
"""
Make all Keras functions for this layer, including its own,
dropout, etc.
"""
def bidir_mode(name):
if name in [True, None]:
return "concat"
else:
return name
program = self.make_keras_function_text()
if self.time_distributed:
program = "keras.layers.TimeDistributed(%s, name='%s')" % (program, self.name)
if self.bidirectional:
program = "keras.layers.Bidirectional(%s, name='%s', mode='%s')" % (
program, self.name, bidir_mode(self.bidirectional))
if self.dropout > 0:
return "[%s, keras.layers.Dropout(self.dropout)]" % program
else:
Expand Down Expand Up @@ -455,6 +480,8 @@ def format_range(minmax):
retval += "\n shape = %s" % (self.shape, )
if self.dropout:
retval += "\n dropout = %s" % self.dropout
if self.bidirectional:
retval += "\n bidirectional = %s" % self.bidirectional
if kind == "input":
retval += "\n Keras class = Input"
else:
Expand Down Expand Up @@ -545,6 +572,8 @@ def print_summary(self, fp=sys.stdout):
print(" * **Activation function**:", self.activation, file=fp)
if self.dropout:
print(" * **Dropout percent** :", self.dropout, file=fp)
if self.bidirectional:
print(" * **Bidirectional mode** :", self.bidirectional, file=fp)

def make_keras_function(self):
"""
Expand Down

0 comments on commit 4bab1a6

Please sign in to comment.