Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[NewIR]Gen python c apis for new ir #56571

Merged
merged 7 commits into from Aug 25, 2023
Merged

Conversation

0x45f
Copy link
Contributor

@0x45f 0x45f commented Aug 23, 2023

PR types

Others

PR changes

Others

Description

  • 生成了static_op_function.h/cc文件
  • 对于没有可变attr的api,比如mean生成了如下代码:
PyObject *static_api_mean(PyObject *self, PyObject *args, PyObject *kwargs);

PyObject *static_api_mean(PyObject *self, PyObject *args, PyObject *kwargs) {
    try {
        VLOG(6) << "Add mean op into program";
        VLOG(8) << "args count: " << (PyTuple_Size(args) / 2);

        // Get OpResult from args
        PyObject *x_obj = PyTuple_GET_ITEM(args, 0);
        auto x = CastPyArg2OpResult(x_obj, "mean", 0);

        // Parse Attributes
        PyObject *axis_obj = PyTuple_GET_ITEM(args, 1);
        std::vector<int64_t> axis = CastPyArg2Longs(axis_obj, "mean", 1);
        PyObject *keepdim_obj = PyTuple_GET_ITEM(args, 2);
        bool keepdim = CastPyArg2Boolean(keepdim_obj, "mean", 2);

        // Call ir static api
        auto static_api_out = paddle::dialect::mean(x, axis, keepdim);

        return ToPyObject(static_api_out);
    } catch (...) {
        ThrowExceptionToPython(std::current_exception());
        return nullptr;
    }
}
  • 对于有可变的api,比如eye生成了如下代码:
PyObject *static_api_eye(PyObject *self, PyObject *args, PyObject *kwargs);

PyObject *static_api_eye(PyObject *self, PyObject *args, PyObject *kwargs) {
    try {
        VLOG(6) << "Add eye op into program";
        VLOG(8) << "args count: " << (PyTuple_Size(args) / 2);

        // Get OpResult from args

        // Parse Attributes
        PyObject *num_rows_obj = PyTuple_GET_ITEM(args, 0);
        PyObject *num_columns_obj = PyTuple_GET_ITEM(args, 1);
        PyObject *dtype_obj = PyTuple_GET_ITEM(args, 2);
        PyObject *place_obj = PyTuple_GET_ITEM(args, 3);

        // Check for mutable attrs
        bool has_mutable_attr = false;
        if (PyObject_CheckIROpResult(num_rows_obj)){
            has_mutable_attr = true;
        }
        if (PyObject_CheckIROpResult(num_columns_obj)){
            has_mutable_attr = true;
        }

        if (has_mutable_attr){
            ir::OpResult num_rows = CastPyArg2OpResult(num_rows_obj, "eye", 0);
            ir::OpResult num_columns = CastPyArg2OpResult(num_columns_obj, "eye", 1);
            phi::DataType dtype = CastPyArg2DataType(dtype_obj, "eye", 2);
            Place place = CastPyArg2Place(place_obj, "eye", 3);
            // Call ir static api
            auto static_api_out = paddle::dialect::eye(num_rows, num_columns, dtype, place);
            return ToPyObject(static_api_out);
        } else {
            float num_rows = CastPyArg2Float(num_rows_obj, "eye", 0);
            float num_columns = CastPyArg2Float(num_columns_obj, "eye", 1);
            phi::DataType dtype = CastPyArg2DataType(dtype_obj, "eye", 2);
            Place place = CastPyArg2Place(place_obj, "eye", 3);
            // Call ir static api
            auto static_api_out = paddle::dialect::eye(num_rows, num_columns, dtype, place);
            return ToPyObject(static_api_out);
        }
    } catch (...) {
        ThrowExceptionToPython(std::current_exception());
        return nullptr;
    }
}

TODO:

  • 对于可变attr的api,比如某个api输入为x, y,attr有3个为a,b,c,其中b,c是可变attr。当用户实际调用时b传入的是OpResult,c传入的不是OpResult,这个时候需要将c也full成OpResult。后续需要处理这类情况

Pcard-67164

Copy link
Contributor

@Aurelius84 Aurelius84 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@Aurelius84 Aurelius84 merged commit 048ce0c into PaddlePaddle:develop Aug 25, 2023
26 checks passed
BeingGod pushed a commit to BeingGod/Paddle that referenced this pull request Sep 9, 2023
* Gen all Apis

* Gen python c apis

* Add empty file

* Fix cast data type

* Fix None dtype
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

3 participants