Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

transfer oneflow_compile from onediff to oneflow #10404

Closed
Closed
1 change: 1 addition & 0 deletions python/oneflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -497,6 +497,7 @@ def atexit_hook(hook):
)
import oneflow.utils.data
import oneflow.framework.docstr as docstr
import oneflow.framework.infer_compiler as infer_compiler
import oneflow.cuda
import oneflow.multiprocessing
import oneflow.asyncs
Expand Down
17 changes: 17 additions & 0 deletions python/oneflow/framework/infer_compiler/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
"""
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from .transform import *
from .with_oneflow_compile import compile_from_torch
17 changes: 17 additions & 0 deletions python/oneflow/framework/infer_compiler/import_tools/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
"""
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
""" Tools for importing modules and packages"""
from .importer import LazyMocker, import_module_from_path
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
"""
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import inspect
from types import FunctionType
from typing import Union


class MockEntityNameFormatter:
def __init__(self, prefix: str = "mock_", suffix: str = "_oflow"):
self.prefix = prefix
self.suffix = suffix

def _format_pkg_name(self, pkg_name: str) -> str:
if pkg_name.startswith(self.prefix) and pkg_name.endswith(self.suffix):
return pkg_name
return self.prefix + pkg_name + self.suffix

def _reverse_pkg_name(self, pkg_name: str) -> str:
assert pkg_name.startswith(self.prefix) and pkg_name.endswith(
self.suffix
), f"Package name must start with {self.prefix} and end with {self.suffix}, but got {pkg_name}"
return pkg_name[len(self.prefix) : -len(self.suffix)]

def _format_full_class_name(self, obj: Union[str, type, FunctionType]):
if isinstance(obj, type):
obj = f"{obj.__module__}.{obj.__qualname__}"

elif isinstance(obj, FunctionType):
module = inspect.getmodule(obj)
obj = f"{module.__name__}.{obj.__qualname__}"

assert isinstance(obj, str), f"obj must be str, but got {type(obj)}"

if "." in obj:
pkg_name, cls_name = obj.split(".", 1)
return f"{self._format_pkg_name(pkg_name)}.{cls_name}"
else:
return self._format_pkg_name(obj)

def format(self, entity: Union[str, type, FunctionType]) -> str:
return self._format_full_class_name(entity)

def unformat(self, mock_entity_name: str) -> str:
if "." in mock_entity_name:
pkg_name, cls_name = mock_entity_name.split(".", 1)
return f"{self._reverse_pkg_name(pkg_name)}.{cls_name}"
else: # mock_entity_name is a pkg_name
return self._reverse_pkg_name(mock_entity_name)
127 changes: 127 additions & 0 deletions python/oneflow/framework/infer_compiler/import_tools/importer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
"""
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import importlib
import os
import sys
from importlib.metadata import requires
from pathlib import Path
from types import FunctionType, ModuleType
from typing import Optional, Union

from oneflow.mock_torch import DynamicMockModule

from .format_utils import MockEntityNameFormatter

__all__ = ["import_module_from_path", "LazyMocker", "is_need_mock"]


def is_need_mock(cls) -> bool:
assert isinstance(cls, (type, str))
main_pkg = cls.__module__.split(".")[0]
try:
pkgs = requires(main_pkg)
except Exception as e:
return True
if pkgs:
for pkg in pkgs:
pkg = pkg.split(" ")[0]
if pkg == "torch":
return True
return False
return True


def import_module_from_path(module_path: Union[str, Path]) -> ModuleType:
if isinstance(module_path, Path):
module_path = str(module_path)
module_name = os.path.basename(module_path)
if os.path.isfile(module_path):
sp = os.path.splitext(module_path)
module_name = sp[0]

if os.path.isfile(module_path):
module_spec = importlib.util.spec_from_file_location(module_name, module_path)
module_dir = os.path.split(module_path)[0]
else:
module_spec = importlib.util.spec_from_file_location(
module_name, os.path.join(module_path, "__init__.py")
)
module_dir = module_path

module = importlib.util.module_from_spec(module_spec)
sys.modules[module_name] = module
module_spec.loader.exec_module(module)
return module


class LazyMocker:
def __init__(self, prefix: str, suffix: str, tmp_dir: Optional[Union[str, Path]]):
self.prefix = prefix
self.suffix = suffix
self.tmp_dir = tmp_dir
self.mocked_packages = set()
self.cleanup_list = []

def mock_package(self, package: str):
pass

def cleanup(self):
pass

def get_mock_entity_name(self, entity: Union[str, type, FunctionType]):
formatter = MockEntityNameFormatter(prefix=self.prefix, suffix=self.suffix)
full_obj_name = formatter.format(entity)
return full_obj_name

def mock_entity(self, entity: Union[str, type, FunctionType]):
"""Mock the entity and return the mocked entity

Example:
>>> mocker = LazyMocker(prefix="mock_", suffix="_of", tmp_dir="tmp")
>>> mocker.mock_entity("models.DemoModel")
<class 'mock_models_of.DemoModel'>
>>> cls_obj = models.DemoModel
>>> mocker.mock_entity(cls_obj)
<class 'mock_models_of.DemoModel'>
"""
return self.load_entity_with_mock(entity)

def add_mocked_package(self, package: str):
if package in self.mocked_packages:
return

self.mocked_packages.add(package)
package = sys.modules.get(package, None)

# TODO remove code below
# fix the mock error in https://github.com/siliconflow/oneflow/blob/main/python/oneflow/mock_torch/mock_importer.py#L105-L118
if package and getattr(package, "__file__", None) is not None:
pkg_path = Path(package.__file__).parents[1]
if pkg_path not in sys.path:
sys.path.append(str(pkg_path))

def load_entity_with_mock(self, entity: Union[str, type, FunctionType]):
formatter = MockEntityNameFormatter(prefix=self.prefix, suffix=self.suffix)
full_obj_name = formatter.format(entity)
attrs = full_obj_name.split(".")

# add package path to sys.path to avoid mock error
self.add_mocked_package(attrs[0])

mock_pkg = DynamicMockModule.from_package(attrs[0], verbose=False)
for name in attrs[1:]:
mock_pkg = getattr(mock_pkg, name)
return mock_pkg
26 changes: 26 additions & 0 deletions python/oneflow/framework/infer_compiler/transform/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
"""
Copyright 2020 The OneFlow Authors. All rights reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
"""Module to convert PyTorch code to OneFlow."""
from .builtin_transform import (
ProxySubmodule,
default_converter,
get_attr,
map_args,
proxy_class,
torch2oflow,
)
from .custom_transform import register
from .manager import transform_mgr
Loading
Loading