Skip to content

Commit

Permalink
Merge pull request #105 from DLHub-Argonne/update_skl
Browse files Browse the repository at this point in the history
Support newer versions of sklearn
  • Loading branch information
WardLT committed Dec 22, 2020
2 parents 5d615d7 + eda769c commit d920615
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 9 deletions.
Binary file removed dlhub_sdk/models/servables/tests/model.pkl
Binary file not shown.
19 changes: 12 additions & 7 deletions dlhub_sdk/models/servables/tests/test_python.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,26 @@
from sklearn import __version__ as skl_version
from numpy import __version__ as numpy_version
from datetime import datetime
import pickle as pkl
import unittest
import math
import os

_year = str(datetime.now().year)

_pickle_path = os.path.abspath(os.path.join(os.path.dirname(__file__), 'pickle.pkl'))


class TestPythonModels(unittest.TestCase):
maxDiff = 4096

def test_pickle(self):
pickle_path = os.path.abspath(os.path.join(os.path.dirname(__file__), 'model.pkl'))
def setUp(self):
with open(_pickle_path, 'wb') as fp:
pkl.dump(PythonClassMethodModel(), fp)

def test_pickle(self):
# Make the model
model = PythonClassMethodModel.create_model(pickle_path, 'predict_proba', {'fake': 'kwarg'})
model = PythonClassMethodModel.create_model(_pickle_path, 'to_dict', {'fake': 'kwarg'})
model.set_title('Python example').set_name("class_method")

# Make sure it throws value errors if inputs are not set
Expand All @@ -46,7 +51,7 @@ def test_pickle(self):

# Check the model output
output = model.to_dict()
assert output['dlhub']['files'] == {'pickle': pickle_path}
assert output['dlhub']['files'] == {'pickle': _pickle_path}
assert output['dlhub']['dependencies']['python'] == {
'scikit-learn': skl_version,
'numpy': numpy_version,
Expand All @@ -65,11 +70,11 @@ def test_pickle(self):
'shape': [None, 3]
}
assert (output['servable']['methods']['run']
['method_details']['class_name'].endswith('.SVC'))
['method_details']['class_name'].endswith('.PythonClassMethodModel'))
assert (output['servable']['methods']['run']
['method_details']['method_name'] == 'predict_proba')
['method_details']['method_name'] == 'to_dict')

self.assertEqual([pickle_path], model.list_files())
self.assertEqual([_pickle_path], model.list_files())
validate_against_dlhub_schema(output, 'servable')

def test_function(self):
Expand Down
13 changes: 11 additions & 2 deletions dlhub_sdk/models/servables/tests/test_sklearn.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from datetime import datetime
import pickle as pkl
import unittest
import os

from sklearn.svm import SVC
from sklearn import __version__ as skversion
import numpy as np

Expand All @@ -12,11 +14,18 @@
_year = str(datetime.now().year)


_svm_path = os.path.abspath(os.path.join(os.path.dirname(__file__), 'model.pkl'))


class TestSklearn(unittest.TestCase):
maxDiff = 4096

def setUp(self):
with open(_svm_path, 'wb') as fp:
pkl.dump(SVC(probability=True), fp)

def test_load_model(self):
model_path = os.path.abspath(os.path.join(os.path.dirname(__file__), 'model.pkl'))
model_path = _svm_path

# Load the model
model_info = ScikitLearnModel.create_model(model_path, n_input_columns=4, classes=3)
Expand All @@ -27,7 +36,7 @@ def test_load_model(self):

# Test key components
assert metadata['dlhub']['dependencies']['python'] == {
'scikit-learn': '0.19.1' # The version used to save the model
'scikit-learn': skversion
}
assert metadata['servable']['shim'] == 'sklearn.ScikitLearnServable'
assert metadata['servable']['model_type'] == 'SVC'
Expand Down

0 comments on commit d920615

Please sign in to comment.