-
Notifications
You must be signed in to change notification settings - Fork 5.5k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #14 from PaddlePaddle/develop
merge to local
- Loading branch information
Showing
39 changed files
with
2,910 additions
and
601 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#include "paddle/fluid/operators/range_op.h" | ||
|
||
namespace paddle { | ||
namespace operators { | ||
|
||
class RangeOp : public framework::OperatorWithKernel { | ||
public: | ||
using framework::OperatorWithKernel::OperatorWithKernel; | ||
|
||
void InferShape(framework::InferShapeContext *ctx) const override { | ||
if (ctx->HasInput("Start")) { | ||
auto s_dims = ctx->GetInputDim("Start"); | ||
PADDLE_ENFORCE((s_dims.size() == 1) && (s_dims[0] == 1), | ||
"The shape of Input(Start) should be [1]."); | ||
} | ||
if (ctx->HasInput("End")) { | ||
auto e_dims = ctx->GetInputDim("End"); | ||
PADDLE_ENFORCE((e_dims.size() == 1) && (e_dims[0] == 1), | ||
"The shape of Input(End) should be [1]."); | ||
} | ||
if (ctx->HasInput("Step")) { | ||
auto step_dims = ctx->GetInputDim("Step"); | ||
PADDLE_ENFORCE((step_dims.size() == 1) && (step_dims[0] == 1), | ||
"The shape of Input(Step) should be [1]."); | ||
} | ||
ctx->SetOutputDim("Out", {-1}); | ||
} | ||
}; | ||
|
||
class RangeOpMaker : public framework::OpProtoAndCheckerMaker { | ||
public: | ||
void Make() override { | ||
AddInput("Start", | ||
"Start of interval. The interval includes this value. It is a " | ||
"tensor with shape=[1]."); | ||
AddInput("End", | ||
"End of interval. The interval does not include this value, " | ||
"except in some cases where step is not an integer and floating " | ||
"point round-off affects the length of out. It is a tensor with " | ||
"shape=[1]."); | ||
AddInput("Step", "Spacing between values. It is a tensor with shape=[1]."); | ||
AddOutput("Out", "A sequence of numbers."); | ||
AddComment(R"DOC( | ||
Return evenly spaced values within a given interval. Values are generated within the half-open interval [start, stop) (in other words, the interval including start but excluding stop). Like arange function of numpy. | ||
)DOC"); | ||
} | ||
}; | ||
} // namespace operators | ||
} // namespace paddle | ||
|
||
namespace ops = paddle::operators; | ||
REGISTER_OP_WITHOUT_GRADIENT(range, ops::RangeOp, ops::RangeOpMaker); | ||
REGISTER_OP_CPU_KERNEL(range, ops::CPURangeKernel<int>, | ||
ops::CPURangeKernel<float>, ops::CPURangeKernel<double>, | ||
ops::CPURangeKernel<int64_t>); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#include "paddle/fluid/framework/op_registry.h" | ||
#include "paddle/fluid/operators/range_op.h" | ||
#include "paddle/fluid/platform/cuda_primitives.h" | ||
|
||
namespace paddle { | ||
namespace operators { | ||
|
||
#define CUDA_1D_KERNEL_LOOP(i, n) \ | ||
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \ | ||
i += blockDim.x * gridDim.x) | ||
|
||
template <typename T> | ||
__global__ void RangeKernel(T start, T step, int64_t size, T* out) { | ||
CUDA_1D_KERNEL_LOOP(index, size) { out[index] = start + step * index; } | ||
} | ||
|
||
template <typename T> | ||
class CUDARangeKernel : public framework::OpKernel<T> { | ||
public: | ||
void Compute(const framework::ExecutionContext& context) const override { | ||
auto* start_t = context.Input<framework::Tensor>("Start"); | ||
auto* end_t = context.Input<framework::Tensor>("End"); | ||
auto* step_t = context.Input<framework::Tensor>("Step"); | ||
auto* out = context.Output<framework::Tensor>("Out"); | ||
|
||
framework::Tensor n; | ||
framework::TensorCopy(*start_t, platform::CPUPlace(), &n); | ||
T start = n.data<T>()[0]; | ||
framework::TensorCopy(*end_t, platform::CPUPlace(), &n); | ||
T end = n.data<T>()[0]; | ||
framework::TensorCopy(*step_t, platform::CPUPlace(), &n); | ||
T step = n.data<T>()[0]; | ||
|
||
int64_t size = 0; | ||
GetSize(start, end, step, &size); | ||
out->Resize(framework::make_ddim({size})); | ||
T* out_data = out->mutable_data<T>(context.GetPlace()); | ||
|
||
auto stream = context.cuda_device_context().stream(); | ||
int block = 512; | ||
int grid = (size + block - 1) / block; | ||
RangeKernel<T><<<grid, block, 0, stream>>>(start, step, size, out_data); | ||
} | ||
}; | ||
|
||
} // namespace operators | ||
} // namespace paddle | ||
|
||
namespace ops = paddle::operators; | ||
REGISTER_OP_CUDA_KERNEL(range, ops::CUDARangeKernel<int>, | ||
ops::CUDARangeKernel<int64_t>, | ||
ops::CUDARangeKernel<float>, | ||
ops::CUDARangeKernel<double>); |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,56 @@ | ||
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. | ||
Licensed under the Apache License, Version 2.0 (the "License"); | ||
you may not use this file except in compliance with the License. | ||
You may obtain a copy of the License at | ||
http://www.apache.org/licenses/LICENSE-2.0 | ||
Unless required by applicable law or agreed to in writing, software | ||
distributed under the License is distributed on an "AS IS" BASIS, | ||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
See the License for the specific language governing permissions and | ||
limitations under the License. */ | ||
|
||
#pragma once | ||
#include <functional> | ||
#include "paddle/fluid/framework/op_registry.h" | ||
#include "paddle/fluid/operators/math/math_function.h" | ||
|
||
namespace paddle { | ||
namespace operators { | ||
|
||
template <typename T> | ||
void GetSize(T start, T end, T step, int64_t* size) { | ||
PADDLE_ENFORCE(!std::equal_to<T>()(step, 0), | ||
"The step of range op should not be 0."); | ||
PADDLE_ENFORCE(((start < end) && (step > 0)) || ((start > end) && (step < 0)), | ||
"The step should be greater than 0 while start < end. And the " | ||
"step should be less than 0 while start > end."); | ||
*size = std::is_integral<T>::value | ||
? ((std::abs(end - start) + std::abs(step) - 1) / std::abs(step)) | ||
: std::ceil(std::abs((end - start) / step)); | ||
} | ||
|
||
template <typename T> | ||
class CPURangeKernel : public framework::OpKernel<T> { | ||
public: | ||
void Compute(const framework::ExecutionContext& context) const override { | ||
T start = context.Input<framework::Tensor>("Start")->data<T>()[0]; | ||
T end = context.Input<framework::Tensor>("End")->data<T>()[0]; | ||
T step = context.Input<framework::Tensor>("Step")->data<T>()[0]; | ||
auto* out = context.Output<framework::Tensor>("Out"); | ||
int64_t size = 0; | ||
GetSize(start, end, step, &size); | ||
out->Resize(framework::make_ddim({size})); | ||
T* out_data = out->mutable_data<T>(context.GetPlace()); | ||
T value = start; | ||
for (int64_t i = 0; i < size; ++i) { | ||
out_data[i] = value; | ||
value += step; | ||
} | ||
} | ||
}; | ||
|
||
} // namespace operators | ||
} // namespace paddle |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
Oops, something went wrong.