Skip to content

Commit

Permalink
💚 Fixing CI Build.
Browse files Browse the repository at this point in the history
  • Loading branch information
BrikerMan committed Jun 10, 2020
1 parent 7b0647a commit de3e594
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 2 deletions.
1 change: 1 addition & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def __getattr__(cls, name):
'tensorflow.keras.callbacks',
'tensorflow.keras.backend',
'tensorflow.python',
'tensorflow.python.util',
'bert4keras',
'bert4keras.models',
'sklearn',
Expand Down
9 changes: 7 additions & 2 deletions kashgari/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

import warnings
import tensorflow as tf
from typing import TYPE_CHECKING, Union
from tensorflow.keras.utils import CustomObjectScope

from kashgari import custom_objects
Expand All @@ -17,16 +18,20 @@
from .multi_label import MultiLabelBinarizer
from .serialize import load_data_object

if TYPE_CHECKING:
from kashgari.tasks.labeling import ABCLabelingModel
from kashgari.tasks.classification import ABCClassificationModel


def custom_object_scope() -> CustomObjectScope:
return tf.keras.utils.custom_object_scope(custom_objects)


def load_model(model_path, *args, **kwargs):
def load_model(model_path: str) -> Union["ABCLabelingModel", "ABCClassificationModel"]:
warnings.warn("The 'load_model' function is deprecated, "
"use 'XX_Model.load_model' instead", DeprecationWarning, 2)
from kashgari.tasks.abs_task_model import ABCTaskModel
ABCTaskModel.load_model(model_path=model_path)
return ABCTaskModel.load_model(model_path=model_path)


if __name__ == "__main__":
Expand Down

0 comments on commit de3e594

Please sign in to comment.