Add more choices for character feature extraction
CyberZHG committed Sep 26, 2018
1 parent c590878 commit 35f7bb7
Showing 7 changed files with 129 additions and 28 deletions.
6 changes: 4 additions & 2 deletions README.md
@@ -79,6 +79,8 @@ model.summary()

The output shape of `embd_layer` should be `(None, None, 600)`, which represents the batch size, the sentence length, and the length of the encoded word feature.

`char_hidden_layer_type` could be 'lstm', 'gru', 'cnn', a Keras layer or a list of Keras layers.
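
For example, a custom stack of layers can be passed as the character feature extractor. Below is a minimal sketch; the dictionary sizes and the layer stack are placeholders for illustration only:

```python
import keras
from keras_wc_embd import get_embedding_layer

inputs, embd_layer = get_embedding_layer(
    word_dict_len=10,        # placeholder dictionary sizes
    char_dict_len=20,
    max_word_len=5,
    word_embd_dim=100,
    char_embd_dim=30,
    char_hidden_layer_type=[
        keras.layers.Conv1D(filters=16, kernel_size=3, activation='relu'),
        keras.layers.Flatten(),
        keras.layers.Dense(units=50),
    ],
    char_mask_zero=False,    # the convolution/flatten stack does not support masking
)
# embd_layer has shape (None, None, 150): 100 word features + 50 character features
```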

#### `get_batch_input`

The function is used to generate the batch inputs for the model.
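
A rough sketch of a call is shown below; the toy dictionaries and the keyword names are assumptions for illustration, so check the actual signature before relying on them:

```python
from keras_wc_embd import get_batch_input

# Toy dictionaries; indices 0 and 1 are reserved (padding and unknown).
word_dict = {'': 0, '<UNK>': 1, 'All': 2, 'work': 3}
char_dict = {'': 0, '<UNK>': 1, 'A': 2, 'l': 3, 'w': 4, 'o': 5, 'r': 6, 'k': 7}

sentences = [['All', 'work'], ['work']]

# The keyword names below are assumptions, not the verbatim signature.
word_input, char_input = get_batch_input(
    sentences,
    max_word_len=5,
    word_dict=word_dict,
    char_dict=char_dict,
)
```
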
@@ -110,7 +112,7 @@ inputs, embd_layer = get_embedding_layer(
char_embd_dim=50,
char_hidden_dim=150,
word_embd_weights=word_embd_weights,
rnn='lstm',
char_hidden_layer_type='lstm',
)
```

@@ -157,6 +159,6 @@ model.fit_generator(
)
```

### Citation
## Citation

Several papers have done the same thing. Just choose the one you have seen.
6 changes: 4 additions & 2 deletions README.rst
@@ -99,6 +99,8 @@ Generate the first few layers that encodes words in a sentence:
The output shape of ``embd_layer`` should be ``(None, None, 600)``\ , which represents the batch size, the sentence length, and the length of the encoded word feature.

``char_hidden_layer_type`` could be 'lstm', 'gru', 'cnn', a Keras layer or a list of Keras layers.
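
For example, the built-in 'cnn' option can be selected as follows. This is a minimal sketch with placeholder dictionary sizes; character masking is disabled because the convolutional path does not support masking:

.. code-block:: python

    from keras_wc_embd import get_embedding_layer

    inputs, embd_layer = get_embedding_layer(
        word_dict_len=10,
        char_dict_len=20,
        max_word_len=7,
        word_embd_dim=300,
        char_embd_dim=50,
        char_hidden_dim=150,
        char_hidden_layer_type='cnn',
        char_mask_zero=False,
    )
    # embd_layer has shape (None, None, 450): 300 word features plus 150 character features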

``get_batch_input``
~~~~~~~~~~~~~~~~~~~~~~~

@@ -133,7 +135,7 @@ A helper function that loads pre-trained embeddings for initializing the weights
char_embd_dim=50,
char_hidden_dim=150,
word_embd_weights=word_embd_weights,
rnn='lstm',
char_hidden_layer_type='lstm',
)
Wrapper Class ``WordCharEmbd``
@@ -181,6 +183,6 @@ There is a wrapper class that makes things easier.
)
Citation
^^^^^^^^
--------

Several papers have done the same thing. Just choose the one you have seen.
2 changes: 1 addition & 1 deletion demo/sentiment_analysis.py
@@ -66,7 +66,7 @@
word_embd_dim=150,
char_embd_dim=30,
char_hidden_dim=75,
rnn='lstm'
char_hidden_layer_type='lstm'
)
lstm_layer = keras.layers.Bidirectional(
keras.layers.LSTM(units=50),
51 changes: 37 additions & 14 deletions keras_wc_embd/word_char_embd.py
@@ -54,12 +54,13 @@ def get_embedding_layer(word_dict_len,
word_embd_dim=300,
char_embd_dim=30,
char_hidden_dim=150,
rnn='lstm',
char_hidden_layer_type='lstm',
word_embd_weights=None,
char_embd_weights=None,
word_embd_trainable=None,
char_embd_trainable=None,
mask_zeros=True):
word_mask_zero=True,
char_mask_zero=True):
"""Get the merged embedding layer.
:param word_dict_len: The number of words in the dictionary including the ones mapped to 0 or 1.
@@ -72,8 +73,9 @@ def get_embedding_layer(word_dict_len,
:param char_embd_weights: A numpy array representing the pre-trained embeddings for characters.
:param word_embd_trainable: Whether the word embedding layer is trainable.
:param char_embd_trainable: Whether the character embedding layer is trainable.
:param rnn: The type of the recurrent layer, 'lstm' or 'gru'.
:param mask_zeros: Whether enable the mask.
:param char_hidden_layer_type: The type of the character hidden layer: 'lstm', 'gru', 'cnn', a Keras layer, or a list of Keras layers.
:param word_mask_zero: Whether to enable the mask for words.
:param char_mask_zero: Whether to enable the mask for characters.
:return inputs, embd_layer: The keras layer.
"""
@@ -99,21 +101,21 @@ def get_embedding_layer(word_dict_len,
word_embd_layer = keras.layers.Embedding(
input_dim=word_dict_len,
output_dim=word_embd_dim,
mask_zero=mask_zeros,
mask_zero=word_mask_zero,
weights=word_embd_weights,
trainable=word_embd_trainable,
name='Embedding_Word',
)(word_input_layer)
char_embd_layer = keras.layers.Embedding(
input_dim=char_dict_len,
output_dim=char_embd_dim,
mask_zero=mask_zeros,
mask_zero=char_mask_zero,
weights=char_embd_weights,
trainable=char_embd_trainable,
name='Embedding_Char_Pre',
)(char_input_layer)
if rnn == 'lstm':
char_rnn_layer = keras.layers.Bidirectional(
if char_hidden_layer_type == 'lstm':
char_hidden_layer = keras.layers.Bidirectional(
keras.layers.LSTM(
units=char_hidden_dim,
input_shape=(max_word_len, char_dict_len),
@@ -122,8 +124,8 @@ def get_embedding_layer(word_dict_len,
),
name='Bi-LSTM_Char',
)
else:
char_rnn_layer = keras.layers.Bidirectional(
elif char_hidden_layer_type == 'gru':
char_hidden_layer = keras.layers.Bidirectional(
keras.layers.GRU(
units=char_hidden_dim,
input_shape=(max_word_len, char_dict_len),
@@ -132,10 +134,31 @@ def get_embedding_layer(word_dict_len,
),
name='Bi-GRU_Char',
)
char_embd_layer = keras.layers.TimeDistributed(
layer=char_rnn_layer,
name='Embedding_Char'
)(char_embd_layer)
elif char_hidden_layer_type == 'cnn':
char_hidden_layer = [
keras.layers.Conv1D(
filters=max(1, char_hidden_dim // 5),
kernel_size=3,
activation='relu',
),
keras.layers.Flatten(),
keras.layers.Dense(
units=char_hidden_dim,
name='Dense_Char',
),
]
elif type(char_hidden_layer_type) is list or isinstance(char_hidden_layer_type, keras.layers.Layer):
char_hidden_layer = char_hidden_layer_type
else:
raise NotImplementedError('Unknown character hidden layer type: %s' % char_hidden_layer_type)
if type(char_hidden_layer) is not list:
char_hidden_layer = [char_hidden_layer]
for i, layer in enumerate(char_hidden_layer):
if i == len(char_hidden_layer) - 1:
name = 'Embedding_Char'
else:
name = 'Embedding_Char_Pre_%d' % (i + 1)
char_embd_layer = keras.layers.TimeDistributed(layer=layer, name=name)(char_embd_layer)
embd_layer = keras.layers.Concatenate(
name='Embedding',
)([word_embd_layer, char_embd_layer])
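
Every branch above ends in the same wrapping step: `char_hidden_layer` is normalized to a list and each element is applied through `TimeDistributed`, so the character pipeline runs once per word in the sentence. A minimal standalone sketch of that pattern, with illustrative sizes that are not taken from this commit:

```python
import keras

max_word_len, char_embd_dim, char_hidden_dim = 7, 30, 50

# Per-word character embeddings: (batch, sentence_len, max_word_len, char_embd_dim)
char_embd = keras.layers.Input(shape=(None, max_word_len, char_embd_dim))

# The feature extractor may be a single layer or a list; normalizing to a list
# lets both cases share the same loop.
char_hidden_layer = keras.layers.GRU(units=char_hidden_dim)
if not isinstance(char_hidden_layer, list):
    char_hidden_layer = [char_hidden_layer]

# TimeDistributed applies each stage once per word, collapsing the
# (max_word_len, char_embd_dim) block into a single feature vector.
output = char_embd
for i, layer in enumerate(char_hidden_layer):
    output = keras.layers.TimeDistributed(layer=layer, name='Char_Step_%d' % (i + 1))(output)
# output shape: (batch, sentence_len, char_hidden_dim)
```
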
15 changes: 9 additions & 6 deletions keras_wc_embd/wrapper.py
@@ -73,14 +73,15 @@ def get_embedding_layer(self,
word_embd_dim=300,
char_embd_dim=30,
char_hidden_dim=150,
rnn='lstm',
char_hidden_layer_type='lstm',
word_embd_weights=None,
word_embd_file_path=None,
char_embd_weights=None,
char_embd_file_path=None,
word_embd_trainable=None,
char_embd_trainable=None,
mask_zeros=True):
word_mask_zero=True,
char_mask_zero=True,):
"""Get the merged embedding layer.
:param word_embd_dim: The dimensions of the word embedding.
@@ -92,8 +93,9 @@ def get_embedding_layer(self,
:param char_embd_file_path: The file that contains the character embeddings.
:param word_embd_trainable: Whether the word embedding layer is trainable.
:param char_embd_trainable: Whether the character embedding layer is trainable.
:param rnn: The type of the recurrent layer, 'lstm' or 'gru'.
:param mask_zeros: Whether enable the mask.
:param char_hidden_layer_type: The type of the character hidden layer: 'lstm', 'gru', 'cnn', a Keras layer, or a list of Keras layers.
:param word_mask_zero: Whether to enable the mask for words.
:param char_mask_zero: Whether to enable the mask for characters.
:return inputs, embd_layer: The keras layer.
"""
@@ -111,12 +113,13 @@ def get_embedding_layer(self,
word_embd_dim=word_embd_dim,
char_embd_dim=char_embd_dim,
char_hidden_dim=char_hidden_dim,
rnn=rnn,
char_hidden_layer_type=char_hidden_layer_type,
word_embd_weights=word_embd_weights,
char_embd_weights=char_embd_weights,
word_embd_trainable=word_embd_trainable,
char_embd_trainable=char_embd_trainable,
mask_zeros=mask_zeros)
word_mask_zero=word_mask_zero,
char_mask_zero=char_mask_zero)

def get_batch_input(self, sentences):
"""Convert sentences to desired input tensors.
2 changes: 1 addition & 1 deletion setup.py
@@ -2,7 +2,7 @@

setup(
name='keras-word-char-embd',
version='0.13',
version='0.14',
packages=['keras_wc_embd'],
url='https://github.com/CyberZHG/keras-word-char-embd',
license='MIT',
75 changes: 73 additions & 2 deletions tests/test_get_embedding_layer.py
@@ -1,11 +1,12 @@
import unittest
import numpy
import keras
from keras_wc_embd import get_embedding_layer


class TestGetEmbeddingLayer(unittest.TestCase):

def test_shape(self):
def test_shape_rnn(self):
inputs, embd_layer = get_embedding_layer(
word_dict_len=3,
char_dict_len=5,
@@ -25,7 +26,7 @@ def test_shape(self):
word_embd_dim=300,
char_embd_dim=50,
char_hidden_dim=150,
rnn='gru',
char_hidden_layer_type='gru',
)
self.assertEqual(len(inputs), 2)
self.assertEqual(inputs[0]._keras_shape, (None, None))
@@ -71,3 +72,73 @@ def char_embd_wrong_shape():
)

self.assertRaises(ValueError, char_embd_wrong_shape)

def test_shape_cnn(self):
inputs, embd_layer = get_embedding_layer(
word_dict_len=3,
char_dict_len=5,
max_word_len=7,
word_embd_dim=300,
char_embd_dim=50,
char_hidden_dim=150,
char_hidden_layer_type='cnn',
char_mask_zero=False,
)
self.assertEqual(len(inputs), 2)
self.assertEqual(inputs[0]._keras_shape, (None, None))
self.assertEqual(inputs[1]._keras_shape, (None, None, 7))
self.assertEqual(embd_layer._keras_shape, (None, None, 450))

def test_custom(self):
layers = [
keras.layers.Conv1D(
filters=16,
kernel_size=3,
activation='relu',
),
keras.layers.Conv1D(
filters=16,
kernel_size=3,
activation='relu',
),
keras.layers.Flatten(),
keras.layers.Dense(
units=50,
name='Dense_Char',
),
]
inputs, embd_layer = get_embedding_layer(
word_dict_len=3,
char_dict_len=5,
max_word_len=7,
word_embd_dim=300,
char_embd_dim=50,
char_hidden_layer_type=layers,
char_mask_zero=False,
)
self.assertEqual(len(inputs), 2)
self.assertEqual(inputs[0]._keras_shape, (None, None))
self.assertEqual(inputs[1]._keras_shape, (None, None, 7))
self.assertEqual(embd_layer._keras_shape, (None, None, 350))
inputs, embd_layer = get_embedding_layer(
word_dict_len=3,
char_dict_len=5,
max_word_len=7,
word_embd_dim=300,
char_embd_dim=50,
char_hidden_layer_type=keras.layers.GRU(units=30),
char_mask_zero=False,
)
self.assertEqual(len(inputs), 2)
self.assertEqual(inputs[0]._keras_shape, (None, None))
self.assertEqual(inputs[1]._keras_shape, (None, None, 7))
self.assertEqual(embd_layer._keras_shape, (None, None, 330))

def test_not_implemented(self):
with self.assertRaises(NotImplementedError):
get_embedding_layer(
word_dict_len=3,
char_dict_len=5,
max_word_len=7,
char_hidden_layer_type='Jack',
)
