Skip to content

Commit

Permalink
✨ Add GRU classification models.
Browse files Browse the repository at this point in the history
  • Loading branch information
BrikerMan committed Jun 25, 2019
1 parent c6c492f commit 2a64749
Show file tree
Hide file tree
Showing 4 changed files with 114 additions and 1 deletion.
3 changes: 2 additions & 1 deletion kashgari/tasks/classification/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,10 @@


from kashgari.tasks.classification.models import BiLSTM_Model
from kashgari.tasks.classification.models import BiGRU_Model
from kashgari.tasks.classification.models import CNN_Model
from kashgari.tasks.classification.models import CNN_LSTM_Model

from kashgari.tasks.classification.models import CNN_GRU_Model

BLSTMModel = BiLSTM_Model
CNNModel = CNN_Model
Expand Down
68 changes: 68 additions & 0 deletions kashgari/tasks/classification/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,34 @@ def build_model_arc(self):
self.tf_model = tf.keras.Model(embed_model.inputs, output_tensor)


class BiGRU_Model(BaseClassificationModel):

@classmethod
def get_default_hyper_parameters(cls) -> Dict[str, Dict[str, Any]]:
return {
'layer_bi_gru': {
'units': 128,
'return_sequences': False
},
'layer_dense': {
'activation': 'softmax'
}
}

def build_model_arc(self):
output_dim = len(self.pre_processor.label2idx)
config = self.hyper_parameters
embed_model = self.embedding.embed_model

layer_bi_lstm = L.Bidirectional(L.GRU(**config['layer_bi_lstm']))
layer_dense = L.Dense(output_dim, **config['layer_dense'])

tensor = layer_bi_lstm(embed_model.output)
output_tensor = layer_dense(tensor)

self.tf_model = tf.keras.Model(embed_model.inputs, output_tensor)


class CNN_Model(BaseClassificationModel):

@classmethod
Expand Down Expand Up @@ -121,6 +149,46 @@ def build_model_arc(self):
self.tf_model = tf.keras.Model(embed_model.inputs, tensor)


class CNN_GRU_Model(BaseClassificationModel):

@classmethod
def get_default_hyper_parameters(cls) -> Dict[str, Dict[str, Any]]:
return {
'conv_layer': {
'filters': 32,
'kernel_size': 3,
'padding': 'same',
'activation': 'relu'
},
'max_pool_layer': {
'pool_size': 2
},
'gru_layer': {
'units': 100
},
'activation_layer': {
'activation': 'softmax'
},
}

def build_model_arc(self):
output_dim = len(self.pre_processor.label2idx)
config = self.hyper_parameters
embed_model = self.embedding.embed_model

layers_seq = []
layers_seq.append(L.Conv1D(**config['conv_layer']))
layers_seq.append(L.MaxPooling1D(**config['max_pool_layer']))
layers_seq.append(L.LSTM(**config['lstm_layer']))
layers_seq.append(L.Dense(output_dim, **config['activation_layer']))

tensor = embed_model.output
for layer in layers_seq:
tensor = layer(tensor)

self.tf_model = tf.keras.Model(embed_model.inputs, tensor)


if __name__ == "__main__":
print(BiLSTM_Model.get_default_hyper_parameters())
logging.basicConfig(level=logging.DEBUG)
Expand Down
22 changes: 22 additions & 0 deletions tests/classification/test_bi_gru.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# encoding: utf-8

# author: BrikerMan
# contact: eliyar917@gmail.com
# blog: https://eliyar.biz

# file: test_bi_gru.py
# time: 11:22

import tests.classification.test_bi_lstm as base
from kashgari.tasks.classification import BiGRU_Model


class TestBiGRUModel(base.TestBi_LSTM_Model):
@classmethod
def setUpClass(cls):
cls.epochs = 1
cls.model_class = BiGRU_Model


if __name__ == "__main__":
print("hello, world")
22 changes: 22 additions & 0 deletions tests/classification/test_cnn_gru.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# encoding: utf-8

# author: BrikerMan
# contact: eliyar917@gmail.com
# blog: https://eliyar.biz

# file: test_cnn_gru.py
# time: 11:22

import tests.classification.test_bi_lstm as base
from kashgari.tasks.classification import CNN_GRU_Model


class TestBiGRUModel(base.TestBi_LSTM_Model):
@classmethod
def setUpClass(cls):
cls.epochs = 1
cls.model_class = CNN_GRU_Model


if __name__ == "__main__":
print("hello, world")

0 comments on commit 2a64749

Please sign in to comment.