🐛 fix bugs caused by discrepancies between tf.keras & Keras
alexwwang committed Jun 28, 2019
1 parent b76d798 commit f15fd0a
Showing 3 changed files with 14 additions and 7 deletions.
kashgari/layers/att_wgt_avg_layer.py (11 additions, 4 deletions)

@@ -10,10 +10,10 @@
 import tensorflow as tf
 from tensorflow.python import keras
 from tensorflow.python.keras import backend as K
-from tensorflow.python.keras.engine.input_spec import InputSpec

 L = keras.layers
 initializers = keras.initializers
+InputSpec = L.InputSpec

 if tf.test.is_gpu_available(cuda_only=True):
     L.LSTM = L.CuDNNLSTM
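
Context for the change above: tensorflow.python.keras.engine.input_spec is a private module whose location has moved between TensorFlow releases, so importing InputSpec from it works only on some versions. Aliasing the public keras.layers.InputSpec sidesteps the private path entirely. A minimal sketch of the version-tolerant pattern; the fallback branch is an assumption for older tf.keras layouts, not part of this commit:

    # Prefer the public, version-stable symbol.
    from tensorflow.python import keras

    try:
        InputSpec = keras.layers.InputSpec  # public alias, as in this commit
    except AttributeError:
        # Assumed fallback for older tf.keras layouts.
        from tensorflow.python.keras.engine.base_layer import InputSpec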
@@ -35,10 +35,12 @@ def build(self, input_shape):
         self.input_spec = [InputSpec(ndim=3)]
         assert len(input_shape) == 3

-        self.W = self.add_weight(shape=(input_shape[2], 1),
+        self.W = self.add_weight(shape=(input_shape[2].value, 1),
                                  name='{}_w'.format(self.name),
-                                 initializer=self.init)
-        self.trainable_weights = [self.W]
+                                 initializer=self.init,
+                                 trainable=True
+                                 )
+        # self.trainable_weights = [self.W]
         super(AttentionWeightedAverageLayer, self).build(input_shape)

     def call(self, x, mask=None):
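
Two tf.keras discrepancies are fixed in build(). First, tf.keras under TF 1.x passes input_shape as a TensorShape whose entries are Dimension objects, so input_shape[2] needs .value before it can serve as a weight shape (standalone Keras passes plain ints). Second, trainable_weights is a read-only property in tf.keras, so the weight must be registered through add_weight(..., trainable=True) rather than assigned directly. A sketch that tolerates both int and Dimension entries; the hasattr guard is an assumption beyond what the commit does:

    def build(self, input_shape):
        self.input_spec = [InputSpec(ndim=3)]
        assert len(input_shape) == 3

        dim = input_shape[2]
        # tf.Dimension under TF 1.x graph mode, plain int elsewhere.
        dim = dim.value if hasattr(dim, 'value') else int(dim)
        # trainable=True registers W as a trainable weight; assigning to
        # self.trainable_weights would fail, as it is read-only in tf.keras.
        self.W = self.add_weight(shape=(dim, 1),
                                 name='{}_w'.format(self.name),
                                 initializer=self.init,
                                 trainable=True)
        super(AttentionWeightedAverageLayer, self).build(input_shape)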
@@ -77,6 +79,11 @@ def compute_mask(self, inputs, input_mask=None):
         else:
             return None

+    def get_config(self):
+        config = {'return_attention': self.return_attention,}
+        base_config = super(AttentionWeightedAverageLayer, self).get_config()
+        return dict(list(base_config.items()) + list(config.items()))


 AttentionWeightedAverage = AttentionWeightedAverageLayer
 AttWgtAvgLayer = AttentionWeightedAverageLayer
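
The new get_config() makes the custom layer serializable: without it, load_model() would rebuild the layer with default arguments and silently drop return_attention. A round-trip sketch, assuming the layer's __init__ accepts a return_attention keyword:

    from kashgari.layers.att_wgt_avg_layer import AttentionWeightedAverageLayer

    # from_config(cls(**config)) restores return_attention only because
    # get_config now includes it alongside the base layer config.
    layer = AttentionWeightedAverageLayer(return_attention=True)
    clone = AttentionWeightedAverageLayer.from_config(layer.get_config())
    assert clone.return_attention is True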
kashgari/layers/kmax_pool_layer.py (1 addition, 1 deletion)

@@ -10,9 +10,9 @@
 import tensorflow as tf
 from tensorflow.python import keras
 from tensorflow.python.keras import backend as K
-from tensorflow.python.keras.engine.input_spec import InputSpec

 L = keras.layers
+InputSpec = L.InputSpec

 if tf.test.is_gpu_available(cuda_only=True):
     L.LSTM = L.CuDNNLSTM
kashgari/tasks/classification/models.py (2 additions, 2 deletions)

@@ -283,13 +283,13 @@ def build_model_arc(self):
         tensors_conv = [layer_conv(embed_tensor) for layer_conv in layers_conv]
         tensors_matrix_sensor = []
         for tensor_conv in tensors_conv:
-            tensor_sensors = []
+            tensor_sensors = [layer_sensor(tensor_conv) for layer_sensor in layers_sensor]
             # tensor_sensors = []
             # tensor_sensors.append(L.GlobalMaxPooling1D()(tensor_conv))
             # tensor_sensors.append(AttentionWeightedAverageLayer()(tensor_conv))
             # tensor_sensors.append(L.GlobalAveragePooling1D()(tensor_conv))
             tensors_matrix_sensor.append(tensor_sensors)
-        tensors_views = [layer_view(tensors) for tensors in zip(*tensors_matrix_sensor)]
+        tensors_views = [layer_view(list(tensors)) for tensors in zip(*tensors_matrix_sensor)]
         tensor = layer_allviews(tensors_views)
         # tensors_v_cols = [L.concatenate(tensors, **config['v_col3']) for tensors
         #                  in zip(*tensors_matrix_sensor)]
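
Here the loop now builds tensor_sensors by applying every layer in layers_sensor to the convolution output, and list(tensors) works around another tf.keras discrepancy: zip(*tensors_matrix_sensor) yields tuples, and tf.keras merge layers such as concatenate are stricter than standalone Keras about receiving a Python list. A minimal sketch of that discrepancy; the shapes are illustrative:

    import tensorflow as tf
    from tensorflow.python import keras

    L = keras.layers

    a = L.Input(shape=(8,))
    b = L.Input(shape=(8,))

    # Some tf.keras versions reject the tuple form that standalone Keras
    # accepted, e.g. raising "A merge layer should be called on a list of
    # inputs", so the tuple is converted explicitly:
    # merged = L.concatenate((a, b))
    merged = L.concatenate(list((a, b)))
    model = keras.Model(inputs=[a, b], outputs=merged)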
