Skip to content

Commit

Permalink
Merge pull request #87 from DLHub-Argonne/fix_skl
Browse files Browse the repository at this point in the history
More robust importing for joblib
  • Loading branch information
WardLT committed Aug 24, 2020
2 parents 6f561fa + 025f9b5 commit 6b88cb2
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 16 deletions.
2 changes: 1 addition & 1 deletion dlhub_sdk/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from .client import DLHubClient # noqa F401 (import unused)
from .version import __version__
from .version import __version__ # noqa F401 (import unused)
17 changes: 15 additions & 2 deletions dlhub_sdk/models/servables/sklearn.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
from dlhub_sdk.models.servables.python import BasePythonServableModel
from sklearn.base import is_classifier
from sklearn.pipeline import Pipeline
from sklearn.externals import joblib
import sklearn.base as sklbase
import pickle as pkl
import inspect

# Get a version of joblib
try:
from sklearn.externals import joblib
except ImportError:
try:
import joblib
except ImportError:
joblib = None

# scikit-learn stores the version used to create a model in the pickle file,
# but deletes it before unpickling the object. This code intercepts the version
Expand Down Expand Up @@ -106,7 +113,13 @@ def _load_model(path, serialization_method):
with open(path, 'rb') as fp:
model = pkl.load(fp)
elif serialization_method == "joblib":
model = joblib.load(path)
if joblib is None:
raise ImportError('joblib was not installed')
try:
model = joblib.load(path)
except ModuleNotFoundError:
raise ValueError('Model saved with sklearn.external.joblib. '
'Please install sklearn version 0.19.2 or earlier')
else:
raise Exception('Unknown serialization method: {}'.format(serialization_method))

Expand Down
6 changes: 4 additions & 2 deletions dlhub_sdk/models/servables/tests/test_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,10 @@ def test_pickle(self):
'description': 'Predicted probabilities of being each iris species',
'shape': [None, 3]
}
assert output['servable']['methods']['run']['method_details']['class_name'].endswith('.SVC')
assert output['servable']['methods']['run']['method_details']['method_name'] == 'predict_proba'
assert (output['servable']['methods']['run']
['method_details']['class_name'].endswith('.SVC'))
assert (output['servable']['methods']['run']
['method_details']['method_name'] == 'predict_proba')

self.assertEqual([pickle_path], model.list_files())
validate_against_dlhub_schema(output, 'servable')
Expand Down
27 changes: 16 additions & 11 deletions dlhub_sdk/models/servables/tests/test_sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@
import unittest
import os

import numpy as np
from sklearn import __version__ as skversion
import numpy as np

from dlhub_sdk.utils.schemas import validate_against_dlhub_schema
from dlhub_sdk.models.servables.sklearn import ScikitLearnModel
from dlhub_sdk.version import __version__


_year = str(datetime.now().year)
Expand Down Expand Up @@ -70,13 +69,19 @@ def test_regression(self):
model_path = os.path.abspath(os.path.join(os.path.dirname(__file__), 'model-lr.pkl'))

# Load the model
model_info = ScikitLearnModel.create_model(model_path, n_input_columns=2,
serialization_method='joblib',
classes=np.array(['number']))
if skversion > '0.23':
self.assertRaises(ValueError, ScikitLearnModel.create_model, model_path,
n_input_columns=2, serialization_method='joblib',
classes=np.array(['number']))
else:
model_info = ScikitLearnModel.create_model(model_path, n_input_columns=2,
serialization_method='joblib',
classes=np.array(['number']))

# Check that the metadata is as expected
self.assertEqual(model_info["servable"]["methods"]["run"]["method_details"]["method_name"],
"predict")
self.assertEqual([model_path], model_info.list_files())
self.assertEqual(['number'], model_info["servable"]["options"]["classes"])
self.assertEqual([None], model_info["servable"]["methods"]["run"]['output']['shape'])
# Check that the metadata is as expected
self.assertEqual(model_info["servable"]["methods"]["run"]
["method_details"]["method_name"],
"predict")
self.assertEqual([model_path], model_info.list_files())
self.assertEqual(['number'], model_info["servable"]["options"]["classes"])
self.assertEqual([None], model_info["servable"]["methods"]["run"]['output']['shape'])

0 comments on commit 6b88cb2

Please sign in to comment.