Merge pull request #74 from DLHub-Argonne/pytorch
Pytorch improvements
WardLT committed Jul 25, 2019
2 parents aaab524 + 3eaffae commit 02dbbd1
Showing 3 changed files with 62 additions and 12 deletions.
22 changes: 14 additions & 8 deletions dlhub_sdk/models/servables/pytorch.py
@@ -10,14 +10,16 @@ class TorchModel(BasePythonServableModel):
Assumes that the model has been saved to a pt or a pth file"""

@classmethod
def create_model(cls, model_path, input_shape, output_shape):
def create_model(cls, model_path, input_shape, output_shape, input_type='float', output_type='float'):
"""Initialize a PyTorch model.
Args:
model_path (string): Path to the pt or pth file that contains the weights and
the architecture
input_shape (list): Shape of input matrix to model
output_shape (list): Shape of output matrix from model
input_shape (tuple or [tuple]): Shape of input matrix to model
output_shape (tuple or [tuple]): Shape of output matrix from model
input_type (str or [str]): Data type of inputs
output_type (str or [str]): Data type of outputs
"""
output = super(TorchModel, cls).create_model('__call__')

@@ -31,8 +33,8 @@ def create_model(cls, model_path, input_shape, output_shape):
raise ValueError('File type for architecture not recognized')

# Get the inputs of the model
output['servable']['methods']['run']['input'] = output.format_layer_spec(input_shape)
output['servable']['methods']['run']['output'] = output.format_layer_spec(output_shape)
output['servable']['methods']['run']['input'] = output.format_layer_spec(input_shape, input_type)
output['servable']['methods']['run']['output'] = output.format_layer_spec(output_shape, output_type)

output['servable']['model_summary'] = str(model)
output['servable']['model_type'] = 'Deep NN'
@@ -42,19 +44,23 @@ def create_model(cls, model_path, input_shape, output_shape):

return output

def format_layer_spec(self, layers):
def format_layer_spec(self, layers, datatypes):
"""Make a description of a list of input or output layers
Args:
layers (tuple or [tuple]): Shape of the layers
datatypes (str or [str]): Data type of each input layer
Return:
(dict) Description of the inputs / outputs
"""
if isinstance(layers, tuple):
return compose_argument_block("ndarray", "Tensor", shape=list(layers))
return compose_argument_block("ndarray", "Tensor", shape=list(layers), item_type=datatypes)
else:
if isinstance(datatypes, str):
datatypes = [datatypes] * len(layers)
return compose_argument_block("tuple", "Tuple of tensors",
element_types=[self.format_layer_spec(i) for i in layers])
element_types=[self.format_layer_spec(i, t)
for i, t in zip(layers, datatypes)])

def _get_handler(self):
return "torch.TorchServable"
46 changes: 43 additions & 3 deletions dlhub_sdk/models/servables/tests/test_pytorch.py
@@ -1,11 +1,12 @@
from datetime import datetime
from tempfile import mkdtemp
from tempfile import mkdtemp, TemporaryDirectory
import shutil
import os

from unittest import TestCase

import torch
from torch import nn
from Net import Net

from dlhub_sdk.models.servables.pytorch import TorchModel
@@ -20,6 +21,16 @@ def _make_simple_model():
return model


class MultiNetwork(nn.Module):

def __init__(self):
super().__init__()
self.layer = nn.Linear(4, 1)

def forward(self, x, y):
return self.layer(x), self.layer(y)


class TestTorch(TestCase):
maxDiff = 4096

@@ -58,9 +69,10 @@ def test_torch_single_input(self):
'torch': torch.__version__
}}},
"servable": {"methods": {"run": {
"input": {"type": "ndarray", "description": "Tensor", "shape": [2, 4]},
"input": {"type": "ndarray", "description": "Tensor", "shape": [2, 4],
"item_type": {"type": "float"}},
"output": {"type": "ndarray", "description": "Tensor",
"shape": [3, 5]}, "parameters": {},
"shape": [3, 5], "item_type": {"type": "float"}}, "parameters": {},
"method_details": {
"method_name": "__call__"
}}},
@@ -78,3 +90,31 @@ def test_torch_single_input(self):
validate_against_dlhub_schema(output, 'servable')
finally:
shutil.rmtree(tempdir)

def test_multinetwork(self):
model = MultiNetwork()

with TemporaryDirectory() as tp:
model_path = os.path.join(tp, 'model.pth')
torch.save(model, model_path)

metadata = TorchModel.create_model(model_path, [(None, 4)]*2, [(None, 1)]*2,
input_type='float', output_type=['float', 'float'])
metadata.set_name('t').set_title('t')

# Test the output shapes
self.assertEqual(metadata['servable']['methods']['run']['input'],
{'type': 'tuple', 'description': 'Tuple of tensors',
'element_types': [{'type': 'ndarray', 'description': 'Tensor', 'shape': [None, 4],
'item_type': {'type': 'float'}},
{'type': 'ndarray', 'description': 'Tensor', 'shape': [None, 4],
'item_type': {'type': 'float'}}]})
self.assertEqual(metadata['servable']['methods']['run']['output'],
{'type': 'tuple', 'description': 'Tuple of tensors',
'element_types': [{'type': 'ndarray', 'description': 'Tensor', 'shape': [None, 1],
'item_type': {'type': 'float'}},
{'type': 'ndarray', 'description': 'Tensor', 'shape': [None, 1],
'item_type': {'type': 'float'}}]})

validate_against_dlhub_schema(metadata.to_dict(), 'servable')

6 changes: 5 additions & 1 deletion docs/servable-types.rst
@@ -238,7 +238,11 @@ default values::

but the model is ready to be served without any modifications.

The SDK also determines the version of Torch on your system, and saves that in the requirements.cd
In some cases, you may need to specify the data types of your input and output arrays using the ``input_type`` and ``output_type`` keyword arguments of ``create_model``.
The type specifications are needed because PyTorch does not cast types automatically.
If in doubt, the data type is likely ``float``, in which case the default settings are sufficient.
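For illustration, a call that overrides the defaults for a model with two input tensors and two output tensors might look like the following (the file name and shapes are invented for the example)::

    from dlhub_sdk.models.servables.pytorch import TorchModel

    # One type string per tensor; a single string is applied to every tensor
    model_info = TorchModel.create_model('multi_model.pth',
                                         [(None, 4), (None, 4)],
                                         [(None, 1), (None, 1)],
                                         input_type='float',
                                         output_type=['float', 'float'])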

The SDK also determines the version of Torch on your system, and saves that in the requirements.

TensorFlow Graphs
-----------------
