Skip to content

Commit

Permalink
Merge pull request #67 from DLHub-Argonne/pytorch
Browse files Browse the repository at this point in the history
basic pytorch support
  • Loading branch information
WardLT committed Jul 25, 2019
2 parents 5b61452 + 865abf4 commit 87e55df
Show file tree
Hide file tree
Showing 6 changed files with 228 additions and 0 deletions.
63 changes: 63 additions & 0 deletions dlhub_sdk/models/servables/pytorch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
import torch

from dlhub_sdk.models.servables.python import BasePythonServableModel
from dlhub_sdk.utils.types import compose_argument_block


class TorchModel(BasePythonServableModel):
"""Servable based on a Torch Model object.
Assumes that the model has been saved to a pt or a pth file"""

@classmethod
def create_model(cls, model_path, input_shape, output_shape):
"""Initialize a PyTorch model.
Args:
model_path (string): Path to the pt or pth file that contains the weights and
the architecture
input_shape (list): Shape of input matrix to model
output_shape (list): Shape of output matrix from model
"""
output = super(TorchModel, cls).create_model('__call__')

# Add model as a file to be sent
output.add_file(model_path, 'model')

# Get the model details
if model_path.endswith('.pt') or model_path.endswith('.pth'):
model = torch.load(model_path)
else:
raise ValueError('File type for architecture not recognized')

# Get the inputs of the model
output['servable']['methods']['run']['input'] = output.format_layer_spec(input_shape)
output['servable']['methods']['run']['output'] = output.format_layer_spec(output_shape)

output['servable']['model_summary'] = str(model)
output['servable']['model_type'] = 'Deep NN'

# Add torch as a dependency
output.add_requirement('torch', torch.__version__)

return output

def format_layer_spec(self, layers):
"""Make a description of a list of input or output layers
Args:
layers (tuple or [tuple]): Shape of the layers
Return:
(dict) Description of the inputs / outputs
"""
if isinstance(layers, tuple):
return compose_argument_block("ndarray", "Tensor", shape=list(layers))
else:
return compose_argument_block("tuple", "Tuple of tensors",
element_types=[self.format_layer_spec(i) for i in layers])

def _get_handler(self):
return "torch.TorchServable"

def _get_type(self):
return "Torch Model"
21 changes: 21 additions & 0 deletions dlhub_sdk/models/servables/tests/Net.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import torch.nn as nn
import torch.nn.functional as F


class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 20, 5, 1)
self.conv2 = nn.Conv2d(20, 50, 5, 1)
self.fc1 = nn.Linear(4*4*50, 500)
self.fc2 = nn.Linear(500, 10)

def forward(self, x):
x = F.relu(self.conv1(x))
x = F.max_pool2d(x, 2, 2)
x = F.relu(self.conv2(x))
x = F.max_pool2d(x, 2, 2)
x = x.view(-1, 4*4*50)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return F.log_softmax(x, dim=1)
80 changes: 80 additions & 0 deletions dlhub_sdk/models/servables/tests/test_pytorch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
from datetime import datetime
from tempfile import mkdtemp
import shutil
import os

from unittest import TestCase

import torch
from Net import Net

from dlhub_sdk.models.servables.pytorch import TorchModel
from dlhub_sdk.utils.schemas import validate_against_dlhub_schema
from dlhub_sdk.version import __version__

_year = str(datetime.now().year)


def _make_simple_model():
model = Net()
return model


class TestTorch(TestCase):
maxDiff = 4096

def test_torch_single_input(self):
# Make a Keras model
model = _make_simple_model()

# Save it to disk
tempdir = mkdtemp()
try:
model_path = os.path.join(tempdir, 'model.pt')
torch.save(model, model_path)

# Create a model
metadata = TorchModel.create_model(model_path, (2, 4), (3, 5))
metadata.set_title('Torch Test')
metadata.set_name('mlp')

output = metadata.to_dict()
self.assertEqual(output, {
"datacite": {"creators": [], "titles": [{"title": "Torch Test"}],
"publisher": "DLHub", "publicationYear": _year,
"identifier": {"identifier": "10.YET/UNASSIGNED",
"identifierType": "DOI"},
"resourceType": {"resourceTypeGeneral": "InteractiveResource"},
"descriptions": [],
"fundingReferences": [],
"relatedIdentifiers": [],
"alternateIdentifiers": [],
"rightsList": []},
"dlhub": {"version": __version__, "domains": [],
"visible_to": ["public"],
'type': 'servable',
"name": "mlp", "files": {"model": model_path},
"dependencies": {"python": {
'torch': torch.__version__
}}},
"servable": {"methods": {"run": {
"input": {"type": "ndarray", "description": "Tensor", "shape": [2, 4]},
"output": {"type": "ndarray", "description": "Tensor",
"shape": [3, 5]}, "parameters": {},
"method_details": {
"method_name": "__call__"
}}},
"type": "Torch Model",
"shim": "torch.TorchServable",
"model_type": "Deep NN",
"model_summary": """Net(
(conv1): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1))
(conv2): Conv2d(20, 50, kernel_size=(5, 5), stride=(1, 1))
(fc1): Linear(in_features=800, out_features=500, bias=True)
(fc2): Linear(in_features=500, out_features=10, bias=True)
)"""}})

# Validate against schema
validate_against_dlhub_schema(output, 'servable')
finally:
shutil.rmtree(tempdir)
55 changes: 55 additions & 0 deletions docs/servable-types.rst
Original file line number Diff line number Diff line change
Expand Up @@ -185,6 +185,61 @@ but the model is ready to be served without any modifications.

The SDK also determines the version of Keras on your system, and saves that in the requirements.

PyTorch Models
--------------

**Model Class**: `TorchModel <source/dlhub_sdk.models.servables.html#dlhub_sdk.models.servables.pytorch.TorchModel>`_

DLHub serves PyTorch models using the .pt file saved using the ``torch.save`` function
(see `PyTorch FAQs <https://pytorch.org/tutorials/beginner/saving_loading_models.html>`_).
As an example, the description for a PyTorch model created using:

.. code-block:: python
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 20, 5, 1)
self.conv2 = nn.Conv2d(20, 50, 5, 1)
self.fc1 = nn.Linear(4*4*50, 500)
self.fc2 = nn.Linear(500, 10)
def forward(self, x):
x = F.relu(self.conv1(x))
x = F.max_pool2d(x, 2, 2)
x = F.relu(self.conv2(x))
x = F.max_pool2d(x, 2, 2)
x = x.view(-1, 4*4*50)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return F.log_softmax(x, dim=1)
model = Net()
torch.save(model, 'model.pt')
can be generated from the .pt file and the shapes of the input and output arrays.

.. code-block:: python
model_info = TorchModel.create_model('model.pt', (None, 1, 28, 28), (None, 10))
DLHub will need the definition for the ``Net`` module in order to load and run it.
You must add the Python libraries containing the module definitions as requirements,
or add the files defining the modules to the servable definition.

.. code-block:: python
model_info.add_file('Net.py')
As with Keras, we recommended changing the descriptions for the inputs and outputs from their
default values::

model_info['servable']['methods']['run']['output']['description'] = 'Response'

but the model is ready to be served without any modifications.

The SDK also determines the version of Torch on your system, and saves that in the requirements.cd

TensorFlow Graphs
-----------------

Expand Down
8 changes: 8 additions & 0 deletions docs/source/dlhub_sdk.models.servables.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,14 @@ dlhub\_sdk\.models\.servables\.keras module
:undoc-members:
:show-inheritance:

dlhub\_sdk\.models\.servables\.pytorch module
---------------------------------------------

.. automodule:: dlhub_sdk.models.servables.pytorch
:members:
:undoc-members:
:show-inheritance:

dlhub\_sdk\.models\.servables\.python module
--------------------------------------------

Expand Down
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,3 +11,4 @@ scipy>=0.19.1
sphinx_rtd_theme>=0.4.2
tensorflow>=1.8.0
mdf_toolbox>=0.4.0
torch>=1.1.0

0 comments on commit 87e55df

Please sign in to comment.