Skip to content

Commit

Permalink
Merge pull request #121 from alexwwang/tf.keras-version
Browse files Browse the repository at this point in the history
Constructed the first batch of the tf.keras-version model zoo.
  • Loading branch information
BrikerMan committed Jun 28, 2019
2 parents 9e2369e + 69ff590 commit 7098d65
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 8 deletions.
15 changes: 11 additions & 4 deletions kashgari/layers/att_wgt_avg_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@
import tensorflow as tf
from tensorflow.python import keras
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.engine.input_spec import InputSpec

L = keras.layers
initializers = keras.initializers
InputSpec = L.InputSpec

if tf.test.is_gpu_available(cuda_only=True):
L.LSTM = L.CuDNNLSTM
Expand All @@ -35,10 +35,12 @@ def build(self, input_shape):
self.input_spec = [InputSpec(ndim=3)]
assert len(input_shape) == 3

self.W = self.add_weight(shape=(input_shape[2], 1),
self.W = self.add_weight(shape=(input_shape[2].value, 1),
name='{}_w'.format(self.name),
initializer=self.init)
self.trainable_weights = [self.W]
initializer=self.init,
trainable=True
)
# self.trainable_weights = [self.W]
super(AttentionWeightedAverageLayer, self).build(input_shape)

def call(self, x, mask=None):
Expand Down Expand Up @@ -77,6 +79,11 @@ def compute_mask(self, inputs, input_mask=None):
else:
return None

def get_config(self):
config = {'return_attention': self.return_attention, }
base_config = super(AttentionWeightedAverageLayer, self).get_config()
return dict(list(base_config.items()) + list(config.items()))


AttentionWeightedAverage = AttentionWeightedAverageLayer
AttWgtAvgLayer = AttentionWeightedAverageLayer
Expand Down
2 changes: 1 addition & 1 deletion kashgari/layers/kmax_pool_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@
import tensorflow as tf
from tensorflow.python import keras
from tensorflow.python.keras import backend as K
from tensorflow.python.keras.engine.input_spec import InputSpec

L = keras.layers
InputSpec = L.InputSpec

if tf.test.is_gpu_available(cuda_only=True):
L.LSTM = L.CuDNNLSTM
Expand Down
4 changes: 2 additions & 2 deletions kashgari/tasks/classification/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,13 +283,13 @@ def build_model_arc(self):
tensors_conv = [layer_conv(embed_tensor) for layer_conv in layers_conv]
tensors_matrix_sensor = []
for tensor_conv in tensors_conv:
tensor_sensors = []
tensor_sensors = [layer_sensor(tensor_conv) for layer_sensor in layers_sensor]
# tensor_sensors = []
# tensor_sensors.append(L.GlobalMaxPooling1D()(tensor_conv))
# tensor_sensors.append(AttentionWeightedAverageLayer()(tensor_conv))
# tensor_sensors.append(L.GlobalAveragePooling1D()(tensor_conv))
tensors_matrix_sensor.append(tensor_sensors)
tensors_views = [layer_view(tensors) for tensors in zip(*tensors_matrix_sensor)]
tensors_views = [layer_view(list(tensors)) for tensors in zip(*tensors_matrix_sensor)]
tensor = layer_allviews(tensors_views)
# tensors_v_cols = [L.concatenate(tensors, **config['v_col3']) for tensors
# in zip(*tensors_matrix_sensor)]
Expand Down
2 changes: 1 addition & 1 deletion tests/classification/test_bi_lstm.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def test_custom_hyper_params(self):
if isinstance(value, bool):
pass
elif isinstance(value, int):
hyper_params[layer][key] = value + 15
hyper_params[layer][key] = value + 15 if value > 64 else value
model = self.model_class(embedding=w2v_embedding_variable_len,
hyper_parameters=hyper_params)
model.fit(valid_x, valid_y, epochs=1)
Expand Down

0 comments on commit 7098d65

Please sign in to comment.