Skip to content

Commit

Permalink
Lazily import classes as needed
Browse files Browse the repository at this point in the history
Prevents loading in large libraries (e.g., TF) unnecessarily
  • Loading branch information
WardLT committed Oct 24, 2022
1 parent f9146e4 commit 2b25677
Showing 1 changed file with 16 additions and 18 deletions.
34 changes: 16 additions & 18 deletions dlhub_sdk/client.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import importlib
import logging
import json
import os
Expand All @@ -14,13 +15,6 @@
from funcx.sdk.client import FuncXClient
from globus_sdk.scopes import AuthScopes, SearchScopes

from dlhub_sdk.models.servables.keras import KerasModel
from dlhub_sdk.models.servables.pytorch import TorchModel
from dlhub_sdk.models.servables.python import PythonClassMethodModel
from dlhub_sdk.models.servables.python import PythonStaticMethodModel
from dlhub_sdk.models.servables.tensorflow import TensorFlowModel
from dlhub_sdk.models.servables.sklearn import ScikitLearnModel

from dlhub_sdk.config import DLHUB_SERVICE_ADDRESS, CLIENT_ID
from dlhub_sdk.utils.futures import DLHubFuture
from dlhub_sdk.utils.schemas import validate_against_dlhub_schema
Expand Down Expand Up @@ -128,7 +122,7 @@ def __init__(self, dlh_authorizer: Optional[GlobusAuthorizer] = None,
AuthScopes.openid: openid_authorizer,
SearchScopes.all: search_authorizer,
'dlhub': dlh_authorizer,
}
}

login_manager = FuncXLoginManager(authorizers=auth_dict)
self._fx_client = FuncXClient(login_manager=login_manager)
Expand Down Expand Up @@ -255,7 +249,7 @@ def describe_methods(self, name, method=None):

def run(self, name: str, inputs: Any, parameters: Optional[Dict[str, Any]] = None,
asynchronous: bool = False, debug: bool = False, validate_input: bool = False,
async_wait: float = 5, timeout: Optional[float] = None)\
async_wait: float = 5, timeout: Optional[float] = None) \
-> Union[
DLHubFuture,
Tuple[Any, Dict[str, Any]],
Expand Down Expand Up @@ -378,19 +372,23 @@ def easy_publish(self, title: str, creators: Union[str, List[str]], short_name:
ValueError: If the given servable_type is not in the list of acceptable types
Exception: If the serv_options are incomplete or the request to publish results in an error
"""
# conversion table for model string names to classes
models = {"static_method": PythonStaticMethodModel,
"class_method": PythonClassMethodModel,
"keras": KerasModel,
"pytorch": TorchModel,
"tensorflow": TensorFlowModel,
"sklearn": ScikitLearnModel}
# conversion table for model string names to class paths
model_paths = {"static_method": ("dlhub_sdk.models.servables.python", "PythonStaticMethodModel"),
"class_method": ("dlhub_sdk.models.servables.python", "PythonClassMethodModel"),
"keras": ("dlhub_sdk.models.servables.keras", "KerasModel"),
"pytorch": ("dlhub_sdk.models.servables.pytorch", "TorchModel"),
"tensorflow": ("dlhub_sdk.models.servables.tensorflow", "TensorFlowModel"),
"sklearn": ("dlhub_sdk.models.servables.sklearn", "ScikitLearnModel")}

# raise an error if the provided servable_type is invalid
model = models.get(servable_type)
if model is None:
if servable_type not in model_paths:
raise ValueError(f"dl.easy_publish given invalid servable type: {servable_type}, please refer to the docstring")

# Load the model in using importlib
module_name, class_name = model_paths[servable_type]
mod = importlib.import_module(module_name)
model = getattr(mod, class_name)

# attempt to construct the model and raise a helpful error if needed
try:
# if the servable is a python function, set the parameter to attempt to auto-generate the inputs/outputs
Expand Down

0 comments on commit 2b25677

Please sign in to comment.