Skip to content

Commit

Permalink
Merge pull request #123 from alexwwang/tf.keras-version
Browse files Browse the repository at this point in the history
Tf.keras version fixed a 🐛 in AVRNN architecture
  • Loading branch information
BrikerMan committed Jun 28, 2019
2 parents 95119f8 + 7afa174 commit 6be7590
Showing 1 changed file with 3 additions and 6 deletions.
9 changes: 3 additions & 6 deletions kashgari/tasks/classification/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -508,11 +508,10 @@ def build_model_arc(self):

layer_bi_rnn1 = L.Bidirectional(L.GRU(**config['rnn_1']))

layers_concat = []
layers_concat.append(L.Concatenate(**config['concat_rnn']))
layers_concat.append(L.Lambda(lambda t: t[:, -1], name='last'))
layer_concat = L.Concatenate(**config['concat_rnn'])

layers_sensor = []
layers_sensor.append(L.Lambda(lambda t: t[:, -1], name='last'))
layers_sensor.append(L.GlobalMaxPooling1D())
layers_sensor.append(AttentionWeightedAverageLayer())
layers_sensor.append(L.GlobalAveragePooling1D())
Expand All @@ -526,9 +525,7 @@ def build_model_arc(self):
tensor_rnn = embed_model.output
for layer in layers_rnn0:
tensor_rnn = layer(tensor_rnn)
tensor_concat = [tensor_rnn, layer_bi_rnn1(tensor_rnn)]
for layer in layers_concat:
tensor_concat = layer(tensor_concat)
tensor_concat = layer_concat([tensor_rnn, layer_bi_rnn1(tensor_rnn)])
tensor_sensors = [layer(tensor_concat) for layer in layers_sensor]
tensor_output = layer_allviews(tensor_sensors)
for layer in layers_full_connect:
Expand Down

0 comments on commit 6be7590

Please sign in to comment.