From de3e594d1db59423a7a965cb38a092ca4c2de4a1 Mon Sep 17 00:00:00 2001 From: BrikerMan Date: Wed, 10 Jun 2020 14:55:43 +0800 Subject: [PATCH] :green_heart: Fixing CI Build. --- docs/conf.py | 1 + kashgari/utils/__init__.py | 9 +++++++-- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/docs/conf.py b/docs/conf.py index ffccf673..4bdad823 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -45,6 +45,7 @@ def __getattr__(cls, name): 'tensorflow.keras.callbacks', 'tensorflow.keras.backend', 'tensorflow.python', + 'tensorflow.python.util', 'bert4keras', 'bert4keras.models', 'sklearn', diff --git a/kashgari/utils/__init__.py b/kashgari/utils/__init__.py index 73e23c0b..e15cc02b 100644 --- a/kashgari/utils/__init__.py +++ b/kashgari/utils/__init__.py @@ -9,6 +9,7 @@ import warnings import tensorflow as tf +from typing import TYPE_CHECKING, Union from tensorflow.keras.utils import CustomObjectScope from kashgari import custom_objects @@ -17,16 +18,20 @@ from .multi_label import MultiLabelBinarizer from .serialize import load_data_object +if TYPE_CHECKING: + from kashgari.tasks.labeling import ABCLabelingModel + from kashgari.tasks.classification import ABCClassificationModel + def custom_object_scope() -> CustomObjectScope: return tf.keras.utils.custom_object_scope(custom_objects) -def load_model(model_path, *args, **kwargs): +def load_model(model_path: str) -> Union["ABCLabelingModel", "ABCClassificationModel"]: warnings.warn("The 'load_model' function is deprecated, " "use 'XX_Model.load_model' instead", DeprecationWarning, 2) from kashgari.tasks.abs_task_model import ABCTaskModel - ABCTaskModel.load_model(model_path=model_path) + return ABCTaskModel.load_model(model_path=model_path) if __name__ == "__main__":