-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #67 from DLHub-Argonne/pytorch
basic pytorch support
- Loading branch information
Showing
6 changed files
with
228 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
import torch | ||
|
||
from dlhub_sdk.models.servables.python import BasePythonServableModel | ||
from dlhub_sdk.utils.types import compose_argument_block | ||
|
||
|
||
class TorchModel(BasePythonServableModel): | ||
"""Servable based on a Torch Model object. | ||
Assumes that the model has been saved to a pt or a pth file""" | ||
|
||
@classmethod | ||
def create_model(cls, model_path, input_shape, output_shape): | ||
"""Initialize a PyTorch model. | ||
Args: | ||
model_path (string): Path to the pt or pth file that contains the weights and | ||
the architecture | ||
input_shape (list): Shape of input matrix to model | ||
output_shape (list): Shape of output matrix from model | ||
""" | ||
output = super(TorchModel, cls).create_model('__call__') | ||
|
||
# Add model as a file to be sent | ||
output.add_file(model_path, 'model') | ||
|
||
# Get the model details | ||
if model_path.endswith('.pt') or model_path.endswith('.pth'): | ||
model = torch.load(model_path) | ||
else: | ||
raise ValueError('File type for architecture not recognized') | ||
|
||
# Get the inputs of the model | ||
output['servable']['methods']['run']['input'] = output.format_layer_spec(input_shape) | ||
output['servable']['methods']['run']['output'] = output.format_layer_spec(output_shape) | ||
|
||
output['servable']['model_summary'] = str(model) | ||
output['servable']['model_type'] = 'Deep NN' | ||
|
||
# Add torch as a dependency | ||
output.add_requirement('torch', torch.__version__) | ||
|
||
return output | ||
|
||
def format_layer_spec(self, layers): | ||
"""Make a description of a list of input or output layers | ||
Args: | ||
layers (tuple or [tuple]): Shape of the layers | ||
Return: | ||
(dict) Description of the inputs / outputs | ||
""" | ||
if isinstance(layers, tuple): | ||
return compose_argument_block("ndarray", "Tensor", shape=list(layers)) | ||
else: | ||
return compose_argument_block("tuple", "Tuple of tensors", | ||
element_types=[self.format_layer_spec(i) for i in layers]) | ||
|
||
def _get_handler(self): | ||
return "torch.TorchServable" | ||
|
||
def _get_type(self): | ||
return "Torch Model" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
|
||
|
||
class Net(nn.Module): | ||
def __init__(self): | ||
super(Net, self).__init__() | ||
self.conv1 = nn.Conv2d(1, 20, 5, 1) | ||
self.conv2 = nn.Conv2d(20, 50, 5, 1) | ||
self.fc1 = nn.Linear(4*4*50, 500) | ||
self.fc2 = nn.Linear(500, 10) | ||
|
||
def forward(self, x): | ||
x = F.relu(self.conv1(x)) | ||
x = F.max_pool2d(x, 2, 2) | ||
x = F.relu(self.conv2(x)) | ||
x = F.max_pool2d(x, 2, 2) | ||
x = x.view(-1, 4*4*50) | ||
x = F.relu(self.fc1(x)) | ||
x = self.fc2(x) | ||
return F.log_softmax(x, dim=1) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,80 @@ | ||
from datetime import datetime | ||
from tempfile import mkdtemp | ||
import shutil | ||
import os | ||
|
||
from unittest import TestCase | ||
|
||
import torch | ||
from Net import Net | ||
|
||
from dlhub_sdk.models.servables.pytorch import TorchModel | ||
from dlhub_sdk.utils.schemas import validate_against_dlhub_schema | ||
from dlhub_sdk.version import __version__ | ||
|
||
_year = str(datetime.now().year) | ||
|
||
|
||
def _make_simple_model(): | ||
model = Net() | ||
return model | ||
|
||
|
||
class TestTorch(TestCase): | ||
maxDiff = 4096 | ||
|
||
def test_torch_single_input(self): | ||
# Make a Keras model | ||
model = _make_simple_model() | ||
|
||
# Save it to disk | ||
tempdir = mkdtemp() | ||
try: | ||
model_path = os.path.join(tempdir, 'model.pt') | ||
torch.save(model, model_path) | ||
|
||
# Create a model | ||
metadata = TorchModel.create_model(model_path, (2, 4), (3, 5)) | ||
metadata.set_title('Torch Test') | ||
metadata.set_name('mlp') | ||
|
||
output = metadata.to_dict() | ||
self.assertEqual(output, { | ||
"datacite": {"creators": [], "titles": [{"title": "Torch Test"}], | ||
"publisher": "DLHub", "publicationYear": _year, | ||
"identifier": {"identifier": "10.YET/UNASSIGNED", | ||
"identifierType": "DOI"}, | ||
"resourceType": {"resourceTypeGeneral": "InteractiveResource"}, | ||
"descriptions": [], | ||
"fundingReferences": [], | ||
"relatedIdentifiers": [], | ||
"alternateIdentifiers": [], | ||
"rightsList": []}, | ||
"dlhub": {"version": __version__, "domains": [], | ||
"visible_to": ["public"], | ||
'type': 'servable', | ||
"name": "mlp", "files": {"model": model_path}, | ||
"dependencies": {"python": { | ||
'torch': torch.__version__ | ||
}}}, | ||
"servable": {"methods": {"run": { | ||
"input": {"type": "ndarray", "description": "Tensor", "shape": [2, 4]}, | ||
"output": {"type": "ndarray", "description": "Tensor", | ||
"shape": [3, 5]}, "parameters": {}, | ||
"method_details": { | ||
"method_name": "__call__" | ||
}}}, | ||
"type": "Torch Model", | ||
"shim": "torch.TorchServable", | ||
"model_type": "Deep NN", | ||
"model_summary": """Net( | ||
(conv1): Conv2d(1, 20, kernel_size=(5, 5), stride=(1, 1)) | ||
(conv2): Conv2d(20, 50, kernel_size=(5, 5), stride=(1, 1)) | ||
(fc1): Linear(in_features=800, out_features=500, bias=True) | ||
(fc2): Linear(in_features=500, out_features=10, bias=True) | ||
)"""}}) | ||
|
||
# Validate against schema | ||
validate_against_dlhub_schema(output, 'servable') | ||
finally: | ||
shutil.rmtree(tempdir) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,3 +11,4 @@ scipy>=0.19.1 | |
sphinx_rtd_theme>=0.4.2 | ||
tensorflow>=1.8.0 | ||
mdf_toolbox>=0.4.0 | ||
torch>=1.1.0 |