From f8c44d242c925fef56e70f81fbd1ed1c17f24539 Mon Sep 17 00:00:00 2001 From: zkh2016 Date: Tue, 14 Jun 2022 11:58:26 +0000 Subject: [PATCH] add sparse SyncBatchNorm --- python/paddle/incubate/sparse/nn/__init__.py | 3 ++- python/paddle/incubate/sparse/nn/layer/norm.py | 4 +++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/python/paddle/incubate/sparse/nn/__init__.py b/python/paddle/incubate/sparse/nn/__init__.py index be4985e694b4b..e0d9b0d6524f8 100644 --- a/python/paddle/incubate/sparse/nn/__init__.py +++ b/python/paddle/incubate/sparse/nn/__init__.py @@ -15,7 +15,7 @@ from . import functional from .layer.activation import ReLU -from .layer.norm import BatchNorm +from .layer.norm import BatchNorm, SyncBatchNorm from .layer.conv import Conv3D from .layer.conv import SubmConv3D from .layer.pooling import MaxPool3D @@ -23,6 +23,7 @@ __all__ = [ 'ReLU', 'BatchNorm', + 'SyncBatchNorm', 'Conv3D', 'SubmConv3D', 'MaxPool3D', diff --git a/python/paddle/incubate/sparse/nn/layer/norm.py b/python/paddle/incubate/sparse/nn/layer/norm.py index cc094487c465c..5a35aaabe0076 100644 --- a/python/paddle/incubate/sparse/nn/layer/norm.py +++ b/python/paddle/incubate/sparse/nn/layer/norm.py @@ -27,6 +27,8 @@ import paddle import warnings +from paddle.nn.layer.norm import _BatchNormBase +from paddle.framework import no_grad class BatchNorm(paddle.nn.BatchNorm1D): @@ -171,7 +173,7 @@ def __init__(self, name=None): super(SyncBatchNorm, self).__init__(num_features, momentum, epsilon, weight_attr, - bias_attr, data_format, None, name) + bias_attr, data_format, name) def forward(self, x): out = super(SyncBatchNorm, self).forward(x.values())