Skip to content

Commit

Permalink
Merge pull request open-mmlab#121 from openvinotoolkit/dp/hide_nncf_i…
Browse files Browse the repository at this point in the history
…mports

Hide NNCF imports
  • Loading branch information
druzhkov-paul committed Apr 23, 2021
2 parents da1efa0 + e38da35 commit fa55c70
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 55 deletions.
44 changes: 15 additions & 29 deletions mmdet/integration/nncf/compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,35 +10,6 @@
from .utils import (check_nncf_is_enabled, get_nncf_version, is_nncf_enabled,
load_checkpoint, no_nncf_trace)

if is_nncf_enabled():
try:
from nncf import (NNCFConfig, create_compressed_model,
register_default_init_args)
from nncf.dynamic_graph.input_wrapping import nncf_model_input
from nncf.initialization import InitializingDataLoader
from nncf.nncf_network import NNCFNetwork

class_InitializingDataLoader = InitializingDataLoader
except ImportError:
raise RuntimeError(
'Cannot import the standard functions of NNCF library '
'-- most probably, incompatible version of NNCF. '
'Please, use NNCF version pointed in the documentation.')
else:
class DummyInitializingDataLoader:
pass


class_InitializingDataLoader = DummyInitializingDataLoader


class MMInitializeDataLoader(class_InitializingDataLoader):
def get_inputs(self, dataloader_output):
# redefined InitializingDataLoader because
# of DataContainer format in mmdet
kwargs = {k: v.data[0] for k, v in dataloader_output.items()}
return (), kwargs


def get_nncf_metadata():
"""
Expand Down Expand Up @@ -114,7 +85,21 @@ def wrap_nncf_model(model,
Note that the parameter `get_fake_input_func` should be the function `get_fake_input`
-- cannot import this function here explicitly
"""

check_nncf_is_enabled()

from nncf import (NNCFConfig, create_compressed_model,
register_default_init_args)
from nncf.dynamic_graph.io_handling import nncf_model_input
from nncf.initialization import InitializingDataLoader

class MMInitializeDataLoader(InitializingDataLoader):
def get_inputs(self, dataloader_output):
# redefined InitializingDataLoader because
# of DataContainer format in mmdet
kwargs = {k: v.data[0] for k, v in dataloader_output.items()}
return (), kwargs

pathlib.Path(cfg.work_dir).mkdir(parents=True, exist_ok=True)
nncf_config = NNCFConfig(cfg.nncf_config)
logger = get_root_logger(cfg.log_level)
Expand Down Expand Up @@ -274,6 +259,7 @@ def run_hacked_export_quantization(self, x):
def get_uncompressed_model(module):
if not is_nncf_enabled():
return module
from nncf.nncf_network import NNCFNetwork
if isinstance(module, NNCFNetwork):
return module.get_nncf_wrapped_model()
return module
35 changes: 10 additions & 25 deletions mmdet/integration/nncf/utils.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,12 @@
import torch
import importlib
from collections import OrderedDict
from contextlib import contextmanager

try:
import nncf
_is_nncf_enabled = True
except ImportError:
_is_nncf_enabled = False
except RuntimeError as _e:
_is_nncf_enabled = False
print('Attention: RuntimeError happened when tried to import nncf')
print(' The reason may be in absent CUDA devices')
print(' RuntimeError:')
print(' ' + str(_e), flush=True)
import torch


_is_nncf_enabled = importlib.util.find_spec('nncf') is not None


def is_nncf_enabled():
return _is_nncf_enabled
Expand All @@ -28,22 +20,10 @@ def check_nncf_is_enabled():
def get_nncf_version():
if not is_nncf_enabled():
return None
import nncf
return nncf.__version__


if is_nncf_enabled():
try:
from nncf import load_state
from nncf.dynamic_graph.context import get_current_context
from nncf.dynamic_graph.context import \
no_nncf_trace as original_no_nncf_trace
except ImportError:
raise RuntimeError(
'Cannot import the standard functions of NNCF library '
'-- most probably, incompatible version of NNCF. '
'Please, use NNCF version pointed in the documentation.')


def load_checkpoint(model, filename, map_location=None, strict=False):
"""Load checkpoint from a file or URI.
Expand All @@ -57,6 +37,8 @@ def load_checkpoint(model, filename, map_location=None, strict=False):
Returns:
dict or OrderedDict: The loaded checkpoint.
"""
from nncf import load_state

checkpoint = torch.load(filename, map_location=map_location)
# get state_dict from checkpoint
if isinstance(checkpoint, OrderedDict):
Expand Down Expand Up @@ -84,6 +66,7 @@ def no_nncf_trace():
"""

if is_nncf_enabled():
from nncf.dynamic_graph.context import no_nncf_trace as original_no_nncf_trace
return original_no_nncf_trace()
return nullcontext()

Expand All @@ -92,6 +75,8 @@ def is_in_nncf_tracing():
if not is_nncf_enabled():
return False

from nncf.dynamic_graph.context import get_current_context

ctx = get_current_context()

if ctx is None:
Expand Down
2 changes: 1 addition & 1 deletion tools/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from mmdet.apis import get_fake_input, init_detector
from mmdet.integration.nncf import (check_nncf_is_enabled,
get_nncf_config_from_meta,
get_uncompressed_model, is_checkpoint_nncf,
is_checkpoint_nncf,
wrap_nncf_model)
from mmdet.models import detectors
from mmdet.utils import ExtendedDictAction
Expand Down

0 comments on commit fa55c70

Please sign in to comment.