[NPU] fix FLAGS_npu_storage_format flag in python, test=develop (#48976)
qili93 committed Dec 13, 2022
1 parent 29d9dbe commit f3982a9
Showing 4 changed files with 5 additions and 14 deletions.
2 changes: 0 additions & 2 deletions paddle/phi/core/flags.cc
@@ -1040,7 +1040,6 @@ PADDLE_DEFINE_EXPORTED_string(jit_engine_type,
                               "Predictor",
                               "Choose default funciton type in JitLayer.");
 
-#ifdef PADDLE_WITH_CUSTOM_DEVICE
 /**
  * Custom Device NPU related FLAG
  * Name: FLAGS_npu_storage_format
@@ -1050,7 +1049,6 @@ PADDLE_DEFINE_EXPORTED_string(jit_engine_type,
  * Note: Enable NPU Storage Format for Ascend910 performance improvement.
  */
 PADDLE_DEFINE_EXPORTED_bool(npu_storage_format, false, "");
-#endif
 
 #ifdef PADDLE_WITH_CUDNN_FRONTEND
 /**
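With the #ifdef PADDLE_WITH_CUSTOM_DEVICE guard removed, FLAGS_npu_storage_format is registered in every build, so the Python-side lookups changed below can read it unconditionally instead of probing the process environment. A minimal sketch of inspecting and toggling the flag from Python, assuming a Paddle build where the flag is exported (which the change above makes unconditional):

    import paddle

    # Query the current value of the exported flag (defaults to False).
    print(paddle.get_flags(['FLAGS_npu_storage_format']))

    # Enable it at runtime; the flag-guarded branches below will then see True.
    paddle.set_flags({'FLAGS_npu_storage_format': True})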
5 changes: 2 additions & 3 deletions python/paddle/fluid/dygraph/varbase_patch_methods.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import os
 import inspect
 import numpy as np
 import warnings
@@ -42,6 +41,7 @@
 from paddle.profiler.utils import in_profiler_mode
 from paddle import _C_ops, _legacy_C_ops
 from paddle.device import get_all_custom_device_type
+from paddle.fluid.framework import _global_flags
 
 _grad_scalar = None
 
@@ -381,8 +381,7 @@ def gradient(self):
         new_ivar = self._grad_ivar()
         # TODO(qili93): temporary for ascned npu performance to be removed along with npu_identity op
         if (
-            os.environ.get('FLAGS_npu_storage_format', None)
-            in [1, '1', True, 'True', 'true']
+            _global_flags()['FLAGS_npu_storage_format']
             and 'npu' in get_all_custom_device_type()
         ):
             new_ivar = paddle.incubate._npu_identity(x=new_ivar, format=-1)
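The same substitution repeats in the files below. The old check read the raw environment variable, so it only reflected whatever string the shell had exported and never picked up flags set at runtime; the new check reads the parsed boolean from Paddle's global flag registry. A small sketch contrasting the two, with hypothetical values for illustration:

    import os

    import paddle
    from paddle.fluid.framework import _global_flags

    # Old style: depends on the literal string exported in the shell, if any.
    env_value = os.environ.get('FLAGS_npu_storage_format', None)  # e.g. None or 'true'
    old_enabled = env_value in [1, '1', True, 'True', 'true']

    # New style: reads the parsed flag, which also reflects runtime updates.
    paddle.set_flags({'FLAGS_npu_storage_format': True})
    new_enabled = _global_flags()['FLAGS_npu_storage_format']  # True

    print(old_enabled, new_enabled)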
8 changes: 2 additions & 6 deletions python/paddle/nn/functional/conv.py
@@ -12,8 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import os
-
 from paddle import _C_ops, _legacy_C_ops, get_flags, in_dynamic_mode
 from paddle.device import (
     get_all_custom_device_type,
@@ -152,8 +150,7 @@ def _conv_nd(
         bias = bias.reshape(new_shape)
         # TODO(qili93): temporary for ascned npu performance to be removed along with npu_identity op
         if (
-            os.environ.get('FLAGS_npu_storage_format', None)
-            in [1, '1', True, 'True', 'true']
+            _global_flags()['FLAGS_npu_storage_format']
             and 'npu' in get_all_custom_device_type()
         ):
             with no_grad():
@@ -753,8 +750,7 @@ def conv2d(
         )
         # TODO(qili93): temporary for ascned npu performance to be removed along with npu_identity op
         if (
-            os.environ.get('FLAGS_npu_storage_format', None)
-            in [1, '1', True, 'True', 'true']
+            _global_flags()['FLAGS_npu_storage_format']
             and 'npu' in get_all_custom_device_type()
         ):
             with no_grad():
4 changes: 1 addition & 3 deletions python/paddle/nn/layer/norm.py
@@ -28,7 +28,6 @@
 # TODO: define normalization api
 
 import numbers
-import os
 import warnings
 
 import numpy as np
@@ -688,8 +687,7 @@ def __init__(
 
         # TODO(qili93): temporary for ascned npu performance to be removed along with npu_identity op
         if (
-            os.environ.get('FLAGS_npu_storage_format', None)
-            in [1, '1', True, 'True', 'true']
+            _global_flags()['FLAGS_npu_storage_format']
             and 'npu' in get_all_custom_device_type()
         ):
             with no_grad():
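Taken together, each guarded branch fires only when the flag is on and an 'npu' custom device is registered. A hedged end-to-end sketch of turning the path on before building layers, assuming an Ascend build of Paddle with the 'npu' custom-device plugin installed (device name and shapes are illustrative):

    import paddle
    from paddle.device import get_all_custom_device_type

    paddle.set_flags({'FLAGS_npu_storage_format': True})

    if 'npu' in get_all_custom_device_type():
        paddle.set_device('npu')  # assumes the Ascend custom-device plugin registers 'npu'
        conv = paddle.nn.Conv2D(3, 8, kernel_size=3)
        y = conv(paddle.randn([1, 3, 32, 32]))  # the flag-guarded branches above now apply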
