-
Notifications
You must be signed in to change notification settings - Fork 657
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Import oneflow as torch #6076
Merged
Merged
Import oneflow as torch #6076
Changes from 9 commits
Commits
Show all changes
29 commits
Select commit
Hold shift + click to select a range
bafd0f6
add oneflow_pytorch_compatiblity_test
BBuf 6b205e1
add compatiblity test
BBuf c5d3c8a
align init model
BBuf a8e81d2
add alexnet
BBuf 2c62837
fix comments
BBuf 56b92d5
restruct code structure
BBuf 9b978a3
add resnet50 and restruct structure
BBuf fdd53a1
Delete loss_compare.png
jackalcooper b235352
Merge branch 'import_oneflow_as_torch' of https://github.com/Oneflow-…
jackalcooper 528a8e5
fix comment
BBuf 7c62def
make dataset and modelzoo read only
jackalcooper 261e82f
Merge branch 'import_oneflow_as_torch' of https://github.com/Oneflow-…
jackalcooper f3c4b1b
fix comments
BBuf 53eb985
Merge branch 'import_oneflow_as_torch' of github.com:Oneflow-Inc/onef…
BBuf d7b4e9c
refine
BBuf 09877fc
fix bug
BBuf 0f7b5a7
fix comments
BBuf bfb3631
Merge branch 'master' into import_oneflow_as_torch
BBuf 8e1a31f
Merge branch 'master' into import_oneflow_as_torch
BBuf 71f36bd
Merge branch 'master' into import_oneflow_as_torch
oneflow-ci-bot c1aa300
auto format by CI
oneflow-ci-bot 7eb8ff1
Merge branch 'master' into import_oneflow_as_torch
oneflow-ci-bot 434413b
Merge branch 'master' into import_oneflow_as_torch
oneflow-ci-bot 5b2b9a9
fix ci error
BBuf d06abb3
Merge branch 'import_oneflow_as_torch' of github.com:Oneflow-Inc/onef…
BBuf 00f956e
Merge branch 'master' into import_oneflow_as_torch
BBuf 9594f88
Merge branch 'master' into import_oneflow_as_torch
oneflow-ci-bot 0c6e340
Merge branch 'master' into import_oneflow_as_torch
BBuf be60e54
Merge branch 'master' into import_oneflow_as_torch
oneflow-ci-bot File filter
Filter by extension
Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
""" | ||
Copyright 2020 The OneFlow Authors. All rights reserved. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
""" | ||
import os | ||
import importlib.machinery | ||
|
||
|
||
def _download_file_from_remote_location(fpath: str, url: str) -> None: | ||
pass | ||
|
||
|
||
def _is_remote_location_available() -> bool: | ||
return False | ||
|
||
|
||
try: | ||
from torch.hub import load_state_dict_from_url | ||
except ImportError: | ||
from torch.utils.model_zoo import load_url as load_state_dict_from_url | ||
|
||
|
||
def _get_extension_path(lib_name): | ||
|
||
lib_dir = os.path.dirname(__file__) | ||
if os.name == "nt": | ||
# Register the main torchvision library location on the default DLL path | ||
import ctypes | ||
import sys | ||
|
||
kernel32 = ctypes.WinDLL("kernel32.dll", use_last_error=True) | ||
with_load_library_flags = hasattr(kernel32, "AddDllDirectory") | ||
prev_error_mode = kernel32.SetErrorMode(0x0001) | ||
|
||
if with_load_library_flags: | ||
kernel32.AddDllDirectory.restype = ctypes.c_void_p | ||
|
||
if sys.version_info >= (3, 8): | ||
os.add_dll_directory(lib_dir) | ||
elif with_load_library_flags: | ||
res = kernel32.AddDllDirectory(lib_dir) | ||
if res is None: | ||
err = ctypes.WinError(ctypes.get_last_error()) | ||
err.strerror += f' Error adding "{lib_dir}" to the DLL directories.' | ||
raise err | ||
|
||
kernel32.SetErrorMode(prev_error_mode) | ||
|
||
loader_details = ( | ||
importlib.machinery.ExtensionFileLoader, | ||
importlib.machinery.EXTENSION_SUFFIXES, | ||
) | ||
|
||
extfinder = importlib.machinery.FileFinder(lib_dir, loader_details) | ||
ext_specs = extfinder.find_spec(lib_name) | ||
if ext_specs is None: | ||
raise ImportError | ||
|
||
return ext_specs.origin |
30 changes: 30 additions & 0 deletions
30
python/oneflow/test/modules/oneflow_pytorch_compatiblity.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
""" | ||
Copyright 2020 The OneFlow Authors. All rights reserved. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
""" | ||
import os | ||
import sys | ||
|
||
test_util_parent_dir = os.path.dirname( | ||
os.path.dirname(os.path.dirname(os.path.abspath(__file__))) | ||
) | ||
oneflow_test_utils_dir_from_env = os.getenv("ONEFLOW_TEST_UTILS_DIR") | ||
if oneflow_test_utils_dir_from_env: | ||
from pathlib import Path | ||
|
||
oneflow_test_utils_dir_from_env = Path(oneflow_test_utils_dir_from_env) | ||
test_util_parent_dir = str(oneflow_test_utils_dir_from_env.parent.absolute()) | ||
sys.path.append(test_util_parent_dir) | ||
|
||
from test_utils.oneflow_pytorch_compatiblity import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
""" | ||
Copyright 2020 The OneFlow Authors. All rights reserved. | ||
|
||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
|
||
http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. | ||
""" | ||
import torch | ||
import torch.nn as nn | ||
from _internally_replaced_utils import load_state_dict_from_url | ||
from typing import Any | ||
|
||
|
||
__all__ = ["AlexNet", "alexnet"] | ||
|
||
|
||
model_urls = { | ||
"alexnet": "https://download.pytorch.org/models/alexnet-owt-7be5be79.pth", | ||
} | ||
|
||
|
||
class AlexNet(nn.Module): | ||
def __init__(self, num_classes: int = 1000) -> None: | ||
super(AlexNet, self).__init__() | ||
self.features = nn.Sequential( | ||
nn.Conv2d(3, 64, kernel_size=11, stride=4, padding=2), | ||
nn.ReLU(inplace=True), | ||
nn.MaxPool2d(kernel_size=3, stride=2), | ||
nn.Conv2d(64, 192, kernel_size=5, padding=2), | ||
nn.ReLU(inplace=True), | ||
nn.MaxPool2d(kernel_size=3, stride=2), | ||
nn.Conv2d(192, 384, kernel_size=3, padding=1), | ||
nn.ReLU(inplace=True), | ||
nn.Conv2d(384, 256, kernel_size=3, padding=1), | ||
nn.ReLU(inplace=True), | ||
nn.Conv2d(256, 256, kernel_size=3, padding=1), | ||
nn.ReLU(inplace=True), | ||
nn.MaxPool2d(kernel_size=3, stride=2), | ||
) | ||
self.avgpool = nn.AdaptiveAvgPool2d((6, 6)) | ||
self.classifier = nn.Sequential( | ||
nn.Dropout(), | ||
nn.Linear(256 * 6 * 6, 4096), | ||
nn.ReLU(inplace=True), | ||
nn.Dropout(), | ||
nn.Linear(4096, 4096), | ||
nn.ReLU(inplace=True), | ||
nn.Linear(4096, num_classes), | ||
) | ||
|
||
def forward(self, x: torch.Tensor) -> torch.Tensor: | ||
x = self.features(x) | ||
x = self.avgpool(x) | ||
x = torch.flatten(x, 1) | ||
x = self.classifier(x) | ||
return x | ||
|
||
|
||
def alexnet(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> AlexNet: | ||
r"""AlexNet model architecture from the | ||
`"One weird trick..." <https://arxiv.org/abs/1404.5997>`_ paper. | ||
The required minimum input size of the model is 63x63. | ||
Args: | ||
pretrained (bool): If True, returns a model pre-trained on ImageNet | ||
progress (bool): If True, displays a progress bar of the download to stderr | ||
""" | ||
model = AlexNet(**kwargs) | ||
if pretrained: | ||
state_dict = load_state_dict_from_url(model_urls["alexnet"], progress=progress) | ||
model.load_state_dict(state_dict) | ||
return model |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
note 去掉oneflow export之后这种hack的文件可以不用写了,直接import oneflow.test_utils
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
已删除。