Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Dy2St] transforms.Resize Support static mode #49008

Closed
wants to merge 1 commit into from

Conversation

Aurelius84
Copy link
Contributor

@Aurelius84 Aurelius84 commented Dec 12, 2022

PR types

New features

PR changes

OPs

Describe

vision.transform变换类API新增支持静态图的样板间PR,供 #48612 参考。

支持静态图分支,实现API行为动静统一主要包括如下三个核心步骤:

step 1:tensor 判断逻辑升级

在 Resize API 中使用了 _is_tensor_image 函数。在静态图下,需要此函数支持 Variable 类型判断,同样返回 True 即可:

def _is_tensor_image(img):
    """
    Return True if img is a Tensor for dynamic mode or Variable for static mode.
    """
    return isinstance(img, (paddle.Tensor, Variable))

step 2:升级functional_tensor.py的核心接口

Resize 与Tensor相关的变换逻辑是在functional_tensor.py 中实现的,需要兼容适配下静态图逻辑。静态图下是通过append_op添加算子实现组网的,此部分逻辑大多数已经封装在了框架公用API中,只需要微调适配下即可。

注:要额外注意静态图下可能出现动态shape的场景,如image.shape = [-1, -1, 3],此时根据具体的实现判断是否需要特殊处理或捕获报错

def resize(img, size, interpolation='bilinear', data_format='CHW'):
   
    _assert_image_tensor(img, data_format)   # <----- 此部分要适配静态图Variable类型

    if not (
        isinstance(size, int)
        or (isinstance(size, (tuple, list)) and len(size) == 2)
    ):
        raise TypeError('Got inappropriate size arg: {}'.format(size))

    if isinstance(size, int):
        w, h = _get_image_size(img, data_format)   # <----- 静态图下 w,h 可能为 -1,要小心处理
        # TODO(Aurelius84): In static mode, w and h will be -1 for dynamic shape.
        # We should consider to support this case in future.
        if w <= 0 or h <= 0:
            raise NotImplementedError(
                "Not support while w<=0 or h<=0, but received w={}, h={}".format(
                    w, h
                )
            )
        if (w <= h and w == size) or (h <= w and h == size):
            return img
        if w < h:
            ow = size
            oh = int(size * h / w)
        else:
            oh = size
            ow = int(size * w / h)
    else:
        oh, ow = size

    img = img.unsqueeze(0)   # <---- 此接口已经是动静统一了,底层会自动走静态图append_op 分支
    img = F.interpolate(           # <---- 此接口已经是动静统一了,底层会自动走静态图append_op 分支
        img,
        size=(oh, ow),
        mode=interpolation.lower(),
        data_format='N' + data_format.upper(),
    )

    return img.squeeze(0)    # <---- 此接口已经是动静统一了,底层会自动走静态图append_op 分支

step 3:添加相应单测,确保静态图执行结果与动态图一致

可以统一添加到 test_transforms_static.py 文件里,统一继承TestTransformUnitTestBase基类即可。
对于新增单测,仅需要设置api信息即可,如有新需求,可扩展TestTransformUnitTestBase基类接口:

class TestResize(TestTransformUnitTestBase):
    def set_trans_api(self):
        self.api = transforms.Resize(size=(16, 16))



# 基类接口:
class TestTransformUnitTestBase(unittest.TestCase):
    def setUp(self):
        self.img = (np.random.rand(*self.get_shape()) * 255.0).astype(
            np.float32
        )
        self.set_trans_api()

    def get_shape(self):
        return (64, 64, 3)

    def set_trans_api(self):
        self.api = transforms.Resize(size=16)

    def dynamic_transform(self):
        paddle.seed(SEED)

        img_t = paddle.to_tensor(self.img)
        return self.api(img_t)

    def static_transform(self):
        paddle.enable_static()
        paddle.seed(SEED)

        main_program = paddle.static.Program()
        with paddle.static.program_guard(main_program):
            x = paddle.static.data(
                shape=self.get_shape(), dtype=paddle.float32, name='img'
            )
            out = self.api(x)

        exe = paddle.static.Executor()
        res = exe.run(main_program, fetch_list=[out], feed={'img': self.img})

        paddle.disable_static()
        return res[0]

    def test_transform(self):
        dy_res = self.dynamic_transform()
        st_res = self.static_transform()

        np.testing.assert_almost_equal(dy_res, st_res)

其他说明

1. 注意 API 入口逻辑分流

其他vision.transform 的API入口逻辑可能相对复杂,可以考虑在入口函数进行分流,如:

# 原始代码
class RandomHorizontalFlip(BaseTransform):

    def __init__(self, prob=0.5, keys=None):
        super().__init__(keys)
        assert 0 <= prob <= 1, "probability must be between 0 and 1"
        self.prob = prob

    def _apply_image(self, img):
        if random.random() < self.prob:
            return F.hflip(img)
        return img


# 修改思路
class RandomHorizontalFlip(BaseTransform):

    def __init__(self, prob=0.5, keys=None):
        super().__init__(keys)
        assert 0 <= prob <= 1, "probability must be between 0 and 1"
        self.prob = prob

    def _apply_image(self, img):
        if in_dynamic_mode():
             return self._dynamic_apply_image(img)
        else:
             return self._static_apply_image(img)
    
    def _dynamic_apply_image(self, img):
        if random.random() < self.prob:
            return F.hflip(img)
        return img

    def _static_apply_image(self, img):
        return  paddle.static.nn.cond(paddle.rand([1]) < self.prob, lambda : F.hflip(img), lambda: img)

2. 注意静态图下动态shape的适配或报错

@paddle-bot
Copy link

paddle-bot bot commented Dec 12, 2022

你的PR提交成功,感谢你对开源项目的贡献!
请关注后续CI自动化测试结果,详情请参考Paddle-CI手册
Your PR has been submitted. Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

@luotao1 luotao1 changed the title [API] transforms.Resize Support static mode [Dy2St] transforms.Resize Support static mode Dec 13, 2022
Copy link
Member

@SigureMo SigureMo left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM~~~

@Aurelius84
Copy link
Contributor Author

#49024 已包含此PR内容,故closed掉

@Aurelius84 Aurelius84 closed this Dec 13, 2022
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants