Skip to content

Commit

Permalink
auto format by CI
Browse files Browse the repository at this point in the history
  • Loading branch information
oneflow-ci-bot committed Mar 29, 2022
1 parent 6114b70 commit 18215dc
Show file tree
Hide file tree
Showing 8 changed files with 24 additions and 32 deletions.
3 changes: 2 additions & 1 deletion oneflow/core/common/shape.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ int64_t ShiftNegativeAxis(int64_t axis, const int64_t num_axes) {
return axis;
}

Shape::Shape(const std::initializer_list<int64_t>& dim_vec) : dim_vec_(dim_vec), is_initialized_(true) {}
Shape::Shape(const std::initializer_list<int64_t>& dim_vec)
: dim_vec_(dim_vec), is_initialized_(true) {}
Shape::Shape(const DimVector& dim_vec) : dim_vec_(dim_vec), is_initialized_(true) {}
Shape::Shape(DimVector&& dim_vec) : dim_vec_(std::move(dim_vec)), is_initialized_(true) {}
Shape::Shape(const ShapeProto& shape_proto) : is_initialized_(true) {
Expand Down
8 changes: 2 additions & 6 deletions oneflow/core/common/shape.h
Original file line number Diff line number Diff line change
Expand Up @@ -81,13 +81,9 @@ class Shape final {

Maybe<Shape> Slice(int64_t start_dim, int64_t end_dim) const;

ShapeView ToShapeView() const {
return ShapeView(dim_vec_.data(), dim_vec_.size());
}
ShapeView ToShapeView() const { return ShapeView(dim_vec_.data(), dim_vec_.size()); }

MutShapeView ToMutShapeView() {
return MutShapeView(dim_vec_.data(), dim_vec_.size());
}
MutShapeView ToMutShapeView() { return MutShapeView(dim_vec_.data(), dim_vec_.size()); }

private:
DimVector dim_vec_;
Expand Down
10 changes: 6 additions & 4 deletions oneflow/core/eager/eager_blob_object.h
Original file line number Diff line number Diff line change
Expand Up @@ -141,12 +141,14 @@ class EagerBlobObject final {
const Shape& shape() const { return *shape_; }
Shape& mut_shape() { return *shape_; }

size_t ByteSizeOfBlobBody() const {
return shape_->elem_cnt() * GetSizeOfDataType(data_type_);
size_t ByteSizeOfBlobBody() const { return shape_->elem_cnt() * GetSizeOfDataType(data_type_); }
size_t AlignedByteSizeOfBlobBody() const {
return RoundUp(ByteSizeOfBlobBody(), kBlobBodyAlignSize);
}
size_t AlignedByteSizeOfBlobBody() const { return RoundUp(ByteSizeOfBlobBody(), kBlobBodyAlignSize); }
size_t ByteSizeOfBlobHeader() const { return shape().NumAxes() * sizeof(int64_t); }
size_t AlignedByteSizeOfBlobHeader() const { return RoundUp(ByteSizeOfBlobHeader(), kBlobHeaderAlignSize); }
size_t AlignedByteSizeOfBlobHeader() const {
return RoundUp(ByteSizeOfBlobHeader(), kBlobHeaderAlignSize);
}

template<typename T = void>
const T* dptr() const {
Expand Down
5 changes: 2 additions & 3 deletions oneflow/core/framework/op_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -164,10 +164,9 @@ class KernelInferContext {

virtual ep::Stream* stream() = 0;
virtual Tensor* Tensor4ArgNameAndIndex(const std::string& arg_name, int32_t arg_index) = 0;
virtual ShapeView ShapeView4ArgNameAndIndex(const std::string& arg_name,
int32_t arg_index) = 0;
virtual ShapeView ShapeView4ArgNameAndIndex(const std::string& arg_name, int32_t arg_index) = 0;
virtual MutShapeView MutShapeView4ArgNameAndIndex(const std::string& arg_name,
int32_t arg_index) = 0;
int32_t arg_index) = 0;

const std::string& input(const std::string& arg_name, int32_t index) const {
return user_op_conf().input(arg_name, index);
Expand Down
4 changes: 3 additions & 1 deletion oneflow/core/job/runtime.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,9 @@ bool HasNonCtrlConsumedRegstDescId(const TaskProto& task) {

} // namespace

Runtime::Runtime(const Plan& plan, const HashMap<std::string, vm::EagerBlobObject*>& variable_op_name2eager_blob_object) {
Runtime::Runtime(
const Plan& plan,
const HashMap<std::string, vm::EagerBlobObject*>& variable_op_name2eager_blob_object) {
{
// NOTE(chengcheng): All runtime Global objects AddPlan
Global<RegstMgr>::Get()->AddPlan(plan, variable_op_name2eager_blob_object);
Expand Down
5 changes: 2 additions & 3 deletions oneflow/core/kernel/user_kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -411,15 +411,14 @@ class UserKernelInferContext final : public user_op::KernelInferContext {
CHECK(it != arg2tensor_.end()) << "Arg (" << arg_name << "," << arg_index << ") is not found";
return it->second.get();
}
ShapeView ShapeView4ArgNameAndIndex(const std::string& arg_name,
int32_t arg_index) override {
ShapeView ShapeView4ArgNameAndIndex(const std::string& arg_name, int32_t arg_index) override {
user_op::Tensor* arg_tensor = Tensor4ArgNameAndIndex(arg_name, arg_index);
CHECK(arg_tensor != nullptr) << "Tensor of arg (" << arg_name << "," << arg_index
<< ") is not found";
return arg_tensor->shape();
}
MutShapeView MutShapeView4ArgNameAndIndex(const std::string& arg_name,
int32_t arg_index) override {
int32_t arg_index) override {
user_op::Tensor* arg_tensor = Tensor4ArgNameAndIndex(arg_name, arg_index);
CHECK(arg_tensor != nullptr) << "Tensor of arg (" << arg_name << "," << arg_index
<< ") is not found";
Expand Down
5 changes: 3 additions & 2 deletions oneflow/core/register/register_manager.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,9 @@ class RegstMgr final {
RegstMgr() = default;
~RegstMgr() = default;

void AddPlan(const Plan& plan,
const HashMap<std::string, vm::EagerBlobObject*>& variable_op_name2eager_blob_object);
void AddPlan(
const Plan& plan,
const HashMap<std::string, vm::EagerBlobObject*>& variable_op_name2eager_blob_object);
void AddPlan(const Plan& plan);
void NewRegsts(const RegstDescProto& regst_desc_proto, std::function<void(Regst*)> OneRegstDone);
const RtRegstDesc& RegstDesc4RegstDescId(int64_t regst_desc_id) const;
Expand Down
16 changes: 4 additions & 12 deletions oneflow/user/kernels/stateful_local_opkernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,7 @@ class EagerBlobObjectTensorView final : public user_op::Tensor {

DataType data_type() const override { return mut_eager_blob_object_()->data_type(); }

const MemoryCase& mem_case() const override {
return mut_eager_blob_object_()->mem_case();
}
const MemoryCase& mem_case() const override { return mut_eager_blob_object_()->mem_case(); }

const void* raw_dptr() const override { return mut_eager_blob_object_()->dptr(); }

Expand All @@ -83,19 +81,13 @@ class EagerBlobObjectTensorDescView final : public user_op::TensorDesc {

DataType data_type() const override { return mut_eager_blob_object_()->data_type(); }

DataType* mut_data_type() override {
return mut_eager_blob_object_()->mut_data_type();
}
DataType* mut_data_type() override { return mut_eager_blob_object_()->mut_data_type(); }

bool is_dynamic() const override { return mut_eager_blob_object_()->is_dynamic(); }

bool* mut_is_dynamic() override {
return mut_eager_blob_object_()->mut_is_dynamic();
}
bool* mut_is_dynamic() override { return mut_eager_blob_object_()->mut_is_dynamic(); }

void set_is_dynamic(bool val) override {
mut_eager_blob_object_()->set_is_dynamic(val);
}
void set_is_dynamic(bool val) override { mut_eager_blob_object_()->set_is_dynamic(val); }

private:
const std::function<vm::EagerBlobObject*()> mut_eager_blob_object_;
Expand Down

0 comments on commit 18215dc

Please sign in to comment.