Skip to content

Commit

Permalink
⚡ Improving performance.
Browse files Browse the repository at this point in the history
  • Loading branch information
BrikerMan committed May 21, 2019
1 parent 72d9042 commit 44ad63a
Show file tree
Hide file tree
Showing 5 changed files with 39 additions and 14 deletions.
1 change: 1 addition & 0 deletions kashgari/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import os
os.environ['TF_KERAS'] = '1'

from kashgari import layers
from kashgari import corpus
from kashgari import embeddings
from kashgari import macros
Expand Down
21 changes: 21 additions & 0 deletions kashgari/layers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
# encoding: utf-8

# author: BrikerMan
# contact: eliyar917@gmail.com
# blog: https://eliyar.biz

# file: layers.py
# time: 2019-05-21 18:55

import tensorflow as tf
from tensorflow.python import keras

L = keras.layers

if tf.test.is_gpu_available(cuda_only=True):
UnifiedLSTM = L.CuDNNLSTM
else:
UnifiedLSTM = L.LSTM

if __name__ == "__main__":
print("Hello world")
21 changes: 11 additions & 10 deletions kashgari/tasks/labeling/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,8 @@ def fit(self,
batch_size: Number of samples per gradient update, default to 64.
epochs: Integer. Number of epochs to train the model. default 5.
fit_kwargs: fit_kwargs: additional arguments passed to ``fit_generator()`` function from
``tensorflow.keras.Model`` - https://www.tensorflow.org/api_docs/python/tf/keras/models/Model#fit_generator
``tensorflow.keras.Model``
- https://www.tensorflow.org/api_docs/python/tf/keras/models/Model#fit_generator
**kwargs:
Returns:
Expand All @@ -134,13 +135,13 @@ def fit(self,
x_validate = utils.wrap_as_tuple(x_validate)
y_validate = utils.wrap_as_tuple(y_validate)
if self.embedding.token_count == 0:
if x_validate is not None:
x_all = (x_train + x_validate)
y_all = (y_train + y_validate)
else:
x_all = x_train
y_all = y_train
self.embedding.analyze_corpus(x_all, y_all)
# if x_validate is not None:
# y_all = (y_train + y_validate)
# x_all = (x_train + x_validate)
# else:
# x_all = x_train.copy()
# y_all = y_train.copy()
self.embedding.analyze_corpus(x_train, y_train)

if self.tf_model is None:
self.build_model_arc()
Expand All @@ -158,8 +159,8 @@ def fit(self,
batch_size)

fit_kwargs['validation_data'] = validation_generator
fit_kwargs['validation_steps'] = len(x_validate) // batch_size

fit_kwargs['validation_steps'] = len(x_validate[0]) // batch_size
print(fit_kwargs)
self.tf_model.fit_generator(train_generator,
steps_per_epoch=len(x_train[0]) // batch_size,
epochs=epochs,
Expand Down
8 changes: 5 additions & 3 deletions kashgari/tasks/labeling/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
from tensorflow import keras

from kashgari.tasks.labeling.base_model import BaseLabelingModel
from kashgari.layers import UnifiedLSTM


L = keras.layers

Expand Down Expand Up @@ -49,7 +51,7 @@ def build_model_arc(self):
config = self.hyper_parameters
embed_model = self.embedding.embed_model

layer_blstm = L.Bidirectional(L.LSTM(**config['layer_blstm']),
layer_blstm = L.Bidirectional(UnifiedLSTM(**config['layer_blstm']),
name='layer_blstm')

layer_dropout = L.Dropout(**config['layer_dropout'],
Expand Down Expand Up @@ -106,8 +108,8 @@ def build_model_arc(self):

layer_conv = L.Conv1D(**config['layer_conv'],
name='layer_conv')
layer_lstm = L.LSTM(**config['layer_lstm'],
name='layer_lstm')
layer_lstm = UnifiedLSTM(**config['layer_lstm'],
name='layer_lstm')
layer_dropout = L.Dropout(**config['layer_dropout'],
name='layer_dropout')
layer_time_distributed = L.TimeDistributed(L.Dense(output_dim,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_labeling.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def setUpClass(cls):

def test_basic_use_build(self):
model = self.model_class()
model.fit(valid_x, valid_y, epochs=1)
model.fit(valid_x, valid_y, valid_x, valid_y, epochs=1)
assert True

def test_w2v_model(self):
Expand Down

0 comments on commit 44ad63a

Please sign in to comment.