Skip to content

Commit

Permalink
✅ Adding tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
BrikerMan committed Jun 24, 2019
1 parent 34b05ab commit 083b7d2
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 9 deletions.
9 changes: 8 additions & 1 deletion kashgari/tasks/classification/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,11 @@
# time: 2019-05-22 12:40


from kashgari.tasks.classification.models import BLSTMModel
from kashgari.tasks.classification.models import BiLSTM_Model
from kashgari.tasks.classification.models import CNN_Model
from kashgari.tasks.classification.models import CNN_LSTM_Model


BLSTMModel = BiLSTM_Model
CNNModel = CNN_Model
CNNLSTMModel = CNN_LSTM_Model
9 changes: 2 additions & 7 deletions kashgari/tasks/classification/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,20 +121,15 @@ def build_model_arc(self):
self.tf_model = tf.keras.Model(embed_model.inputs, tensor)


BLSTMModel = BiLSTM_Model
CNNModel = CNN_Model
CNNLSTMModel = CNN_LSTM_Model

if __name__ == "__main__":
print(BLSTMModel.get_default_hyper_parameters())
print(BiLSTM_Model.get_default_hyper_parameters())
logging.basicConfig(level=logging.DEBUG)
from kashgari.corpus import SMP2018ECDTCorpus

x, y = SMP2018ECDTCorpus.load_data()

m = BLSTMModel()
m = BiLSTM_Model()
m.build_model(x, y)
r = m.get_data_generator(x, y)
m.fit(x, y, epochs=5)
m.evaluate(x, y)
print(m.predict(x[:10]))
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def test_bert_model(self):
assert np.array_equal(new_res, res)


class TestCNN_LSTM_Model(unittest.TestCase):
class TestBi_LSTM_Model(unittest.TestCase):
@classmethod
def setUpClass(cls):
cls.model_class = BLSTMModel
Expand Down
22 changes: 22 additions & 0 deletions tests/classification/test_cnn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# encoding: utf-8

# author: BrikerMan
# contact: eliyar917@gmail.com
# blog: https://eliyar.biz

# file: test_cnn.py
# time: 21:35

import tests.classification.test_bi_lstm as base
from kashgari.tasks.classification import CNN_Model


class TestBiGRUModel(base.TestBi_LSTM_Model):
@classmethod
def setUpClass(cls):
cls.epochs = 1
cls.model_class = CNN_Model


if __name__ == "__main__":
print("hello, world")
22 changes: 22 additions & 0 deletions tests/classification/test_cnn_lstm.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# encoding: utf-8

# author: BrikerMan
# contact: eliyar917@gmail.com
# blog: https://eliyar.biz

# file: test_cnn_lstm.py
# time: 21:35

import tests.classification.test_bi_lstm as base
from kashgari.tasks.classification import CNN_LSTM_Model


class TestBiGRUModel(base.TestBi_LSTM_Model):
@classmethod
def setUpClass(cls):
cls.epochs = 1
cls.model_class = CNN_LSTM_Model


if __name__ == "__main__":
print("hello, world")

0 comments on commit 083b7d2

Please sign in to comment.