diff --git a/paddle/framework/executor_test.cc b/paddle/framework/executor_test.cc index 25b737cad62d6..d5a3d8d2011b6 100644 --- a/paddle/framework/executor_test.cc +++ b/paddle/framework/executor_test.cc @@ -46,6 +46,7 @@ void AddOp(const std::string& type, const VariableNameMap& inputs, for (auto kv : outputs) { for (auto v : kv.second) { auto var = block->NewVar(v); + var->SetType(VarDesc::LOD_TENSOR); var->SetDataType(paddle::framework::DataType::FP32); } } diff --git a/paddle/framework/framework.proto b/paddle/framework/framework.proto index b7a63f9ba10b7..65760b07ada7a 100644 --- a/paddle/framework/framework.proto +++ b/paddle/framework/framework.proto @@ -97,16 +97,26 @@ enum DataType { FP64 = 6; } -message LoDTensorDesc { +message TensorDesc { required DataType data_type = 1; repeated int64 dims = 2; // [UNK, 640, 480] is saved as [-1, 640, 480] - optional int32 lod_level = 3 [ default = 0 ]; +} + +message LoDTensorDesc { + required TensorDesc tensor = 1; + optional int32 lod_level = 2 [ default = 0 ]; } message VarDesc { + enum VarType { + LOD_TENSOR = 1; + SELECTED_ROWS = 2; + } required string name = 1; - optional LoDTensorDesc lod_tensor = 2; - optional bool persistable = 3 [ default = false ]; + required VarType type = 2; + optional LoDTensorDesc lod_tensor = 3; + optional TensorDesc selected_rows = 4; + optional bool persistable = 5 [ default = false ]; } message BlockDesc { diff --git a/paddle/framework/var_desc.cc b/paddle/framework/var_desc.cc index a88e813b5e7c7..c302217e5aacd 100644 --- a/paddle/framework/var_desc.cc +++ b/paddle/framework/var_desc.cc @@ -13,32 +13,58 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/framework/var_desc.h" +#include "paddle/platform/enforce.h" namespace paddle { namespace framework { void VarDescBind::SetShape(const std::vector &dims) { - VectorToRepeated(dims, desc_.mutable_lod_tensor()->mutable_dims()); + VectorToRepeated(dims, mutable_tensor_desc()->mutable_dims()); } void VarDescBind::SetDataType(DataType data_type) { - desc_.mutable_lod_tensor()->set_data_type(data_type); + mutable_tensor_desc()->set_data_type(data_type); } std::vector VarDescBind::Shape() const { - return RepeatedToVector(desc_.lod_tensor().dims()); + return RepeatedToVector(tensor_desc().dims()); } -DataType VarDescBind::GetDataType() const { - return desc_.lod_tensor().data_type(); -} +DataType VarDescBind::GetDataType() const { return tensor_desc().data_type(); } void VarDescBind::SetLoDLevel(int32_t lod_level) { + PADDLE_ENFORCE(desc_.type() == VarDesc::LOD_TENSOR); desc_.mutable_lod_tensor()->set_lod_level(lod_level); } int32_t VarDescBind::GetLodLevel() const { + PADDLE_ENFORCE(desc_.type() == VarDesc::LOD_TENSOR); return desc_.lod_tensor().lod_level(); } + +const TensorDesc &VarDescBind::tensor_desc() const { + PADDLE_ENFORCE(desc_.has_type(), "invoke TensorDesc must after set type"); + switch (desc_.type()) { + case VarDesc::SELECTED_ROWS: + return desc_.selected_rows(); + case VarDesc::LOD_TENSOR: + return desc_.lod_tensor().tensor(); + default: + PADDLE_THROW("Unexpected branch."); + } +} + +TensorDesc *VarDescBind::mutable_tensor_desc() { + PADDLE_ENFORCE(desc_.has_type(), + "invoke MutableTensorDesc must after set type"); + switch (desc_.type()) { + case VarDesc::SELECTED_ROWS: + return desc_.mutable_selected_rows(); + case VarDesc::LOD_TENSOR: + return desc_.mutable_lod_tensor()->mutable_tensor(); + default: + PADDLE_THROW("Unexpected branch."); + } +} } // namespace framework } // namespace paddle diff --git a/paddle/framework/var_desc.h b/paddle/framework/var_desc.h index 4436879564566..8ffb858eb244d 100644 --- a/paddle/framework/var_desc.h +++ b/paddle/framework/var_desc.h @@ -72,7 +72,14 @@ class VarDescBind { int32_t GetLodLevel() const; + VarDesc::VarType GetType() const { return desc_.type(); } + + void SetType(VarDesc::VarType type) { desc_.set_type(type); } + private: + const TensorDesc &tensor_desc() const; + TensorDesc *mutable_tensor_desc(); + VarDesc desc_; }; } // namespace framework diff --git a/paddle/pybind/protobuf.cc b/paddle/pybind/protobuf.cc index 7ab4e6a451846..89c625b42d485 100644 --- a/paddle/pybind/protobuf.cc +++ b/paddle/pybind/protobuf.cc @@ -162,7 +162,8 @@ void BindVarDsec(py::module &m) { .value("FP32", DataType::FP32) .value("FP64", DataType::FP64); - py::class_(m, "VarDesc", "") + py::class_ var_desc(m, "VarDesc", ""); + var_desc .def("name", [](const VarDescBind &self) { py::bytes name = self.Name(); @@ -174,7 +175,13 @@ void BindVarDsec(py::module &m) { .def("shape", &VarDescBind::Shape, py::return_value_policy::reference) .def("data_type", &VarDescBind::GetDataType) .def("lod_level", &VarDescBind::GetLodLevel) - .def("set_lod_level", &VarDescBind::SetLoDLevel); + .def("set_lod_level", &VarDescBind::SetLoDLevel) + .def("type", &VarDescBind::GetType) + .def("set_type", &VarDescBind::SetType); + + py::enum_(var_desc, "VarType", "") + .value("LOD_TENSOR", VarDesc::LOD_TENSOR) + .value("SELECTED_ROWS", VarDesc::SELECTED_ROWS); } void BindOpDesc(py::module &m) { diff --git a/python/paddle/v2/framework/framework.py b/python/paddle/v2/framework/framework.py index 2afbd0c83158d..c57d723996010 100644 --- a/python/paddle/v2/framework/framework.py +++ b/python/paddle/v2/framework/framework.py @@ -10,6 +10,7 @@ class Variable(object): def __init__(self, block, + type=core.VarDesc.VarType.LOD_TENSOR, name=None, shape=None, dtype=None, @@ -26,6 +27,14 @@ def __init__(self, self.desc = self.block.desc.new_var(name) is_new_var = True + if is_new_var: + self.desc.set_type(type) + elif self.desc.type() != type: + raise ValueError("Variable {0} has been created before. The " + "previous type is {1}; the new type is {2}. They" + " are not matched".format(self.name, + self.desc.type(), type)) + if shape is not None: if is_new_var: self.desc.set_shape(shape) diff --git a/python/paddle/v2/framework/tests/test_infer_shape.py b/python/paddle/v2/framework/tests/test_infer_shape.py index 99562890fdd4d..9d9fb1c3096ed 100644 --- a/python/paddle/v2/framework/tests/test_infer_shape.py +++ b/python/paddle/v2/framework/tests/test_infer_shape.py @@ -14,11 +14,14 @@ def test_sum_op(self): # prepare input/output x1 = block.new_var("x1") + x1.set_type(core.VarDesc.VarType.LOD_TENSOR) x1.set_shape(shape) x2 = block.new_var("x2") + x2.set_type(core.VarDesc.VarType.LOD_TENSOR) x2.set_shape(shape) out = block.new_var("out") + out.set_type(core.VarDesc.VarType.LOD_TENSOR) # prepare the operator sum_op_desc = block.append_op() @@ -40,11 +43,14 @@ def test_mul_op(self): # prepare input/output x1 = block.new_var("x") + x1.set_type(core.VarDesc.VarType.LOD_TENSOR) x1.set_shape(x_shape) x2 = block.new_var("y") + x2.set_type(core.VarDesc.VarType.LOD_TENSOR) x2.set_shape(y_shape) out = block.new_var("out") + out.set_type(core.VarDesc.VarType.LOD_TENSOR) # prepare the operator mul_op_desc = block.append_op() diff --git a/python/paddle/v2/framework/tests/test_protobuf_descs.py b/python/paddle/v2/framework/tests/test_protobuf_descs.py index af5ed6801fa7b..9b3a21261f02b 100644 --- a/python/paddle/v2/framework/tests/test_protobuf_descs.py +++ b/python/paddle/v2/framework/tests/test_protobuf_descs.py @@ -94,17 +94,21 @@ def test_shape(self): program_desc = core.ProgramDesc.__create_program_desc__() block = program_desc.block(0) var = block.new_var('my_var') + var.set_type(core.VarDesc.VarType.SELECTED_ROWS) src_shape = [3, 2, 10, 8] var.set_shape(src_shape) res_shape = var.shape() self.assertEqual(src_shape, res_shape) + self.assertEqual(core.VarDesc.VarType.SELECTED_ROWS, var.type()) def test_data_type(self): program_desc = core.ProgramDesc.__create_program_desc__() block = program_desc.block(0) var = block.new_var('my_var') + var.set_type(core.VarDesc.VarType.LOD_TENSOR) var.set_data_type(core.DataType.INT32) self.assertEqual(core.DataType.INT32, var.data_type()) + self.assertEqual(core.VarDesc.VarType.LOD_TENSOR, var.type()) class TestBlockDesc(unittest.TestCase):