add radam op
megemini committed Apr 18, 2024
1 parent b41a902 commit 2752082
Showing 12 changed files with 1,242 additions and 15 deletions.
11 changes: 11 additions & 0 deletions paddle/phi/api/yaml/ops.yaml
@@ -2274,6 +2274,17 @@
func : qr
backward : qr_grad

- op : radam_
args : (Tensor param, Tensor grad, Tensor learning_rate, Tensor beta1_pow, Tensor beta2_pow, Tensor rho, Tensor moment1, Tensor moment2, Tensor master_param, float beta1 = 0.9f, float beta2 = 0.999f, float epsilon = 1.0e-8f, bool multi_precision = false)
output : Tensor(param_out), Tensor(beta1_pow_out), Tensor(beta2_pow_out), Tensor(rho_out), Tensor(moment1_out), Tensor(moment2_out), Tensor(master_param_out)
infer_meta :
func : RAdamInferMeta
kernel :
func : radam
data_type : param
optional : master_param, master_param_out
inplace : (param -> param_out), (beta1_pow -> beta1_pow_out), (beta2_pow -> beta2_pow_out), (rho -> rho_out), (moment1 -> moment1_out), (moment2 -> moment2_out), (master_param -> master_param_out)

- op : real
args : (Tensor x)
output : Tensor (out)
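For reference, the update rule the new radam_ op encodes (the rectified Adam step of Liu et al., written here to match the element-wise math in the GPU kernel later in this commit; the notation is mine, not part of the patch), for step t with gradient g_t, learning rate \eta, and parameter \theta:

m_t = \beta_1 m_{t-1} + (1 - \beta_1) g_t
v_t = \beta_2 v_{t-1} + (1 - \beta_2) g_t^2
\hat{m}_t = m_t / (1 - \beta_1^t)
\rho_\infty = 2 / (1 - \beta_2) - 1, \qquad \rho_t = \rho_\infty - 2\, t \beta_2^t / (1 - \beta_2^t)
\theta_t = \theta_{t-1} - \eta\, \hat{m}_t\, r_t\, \frac{\sqrt{1 - \beta_2^t}}{\sqrt{v_t} + \epsilon} \quad \text{if } \rho_t > 5, \qquad r_t = \sqrt{\frac{(\rho_t - 4)(\rho_t - 2)\,\rho_\infty}{(\rho_\infty - 4)(\rho_\infty - 2)\,\rho_t}}
\theta_t = \theta_{t-1} - \eta\, \hat{m}_t \quad \text{otherwise.}

Here moment1/moment2 carry m_t and v_t, beta1_pow/beta2_pow carry \beta_1^t and \beta_2^t, and rho carries the scalar accumulator behind \rho_t (see the GPU kernel below).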
82 changes: 82 additions & 0 deletions paddle/phi/infermeta/multiary.cc
@@ -3607,6 +3607,88 @@ void QuantizeLinearInferMeta(const MetaTensor& x,
}
}

void RAdamInferMeta(const MetaTensor& param,
const MetaTensor& grad,
const MetaTensor& learning_rate,
const MetaTensor& beta1_pow,
const MetaTensor& beta2_pow,
const MetaTensor& rho,
const MetaTensor& moment1,
const MetaTensor& moment2,
const MetaTensor& master_param,
float beta1,
float beta2,
float epsilon,
bool multi_precision,
MetaTensor* param_out,
MetaTensor* beta1_pow_out,
MetaTensor* beta2_pow_out,
MetaTensor* rho_out,
MetaTensor* moment1_out,
MetaTensor* moment2_out,
MetaTensor* master_param_out) {
auto param_dim = param.dims();
PADDLE_ENFORCE_EQ(param_dim,
moment1.dims(),
phi::errors::InvalidArgument(
"Param and Momentum input of NAdamOp "
"should have the same dimension. But received "
"Param's dim [%s] and Moment1 [%s]",
param_dim,
moment1.dims()));
PADDLE_ENFORCE_EQ(param_dim,
moment2.dims(),
phi::errors::InvalidArgument(
"Param and Momentum input of NAdamOp "
"should have the same dimension. But received "
"Param's dim [%s] and Moment2 [%s]",
param_dim,
moment2.dims()));

auto lr_dim = learning_rate.dims();
PADDLE_ENFORCE_EQ(common::product(lr_dim),
1,
phi::errors::InvalidArgument(
"Learning Rate of NAdamOp should be a scalar. But "
"received LearningRate's dim [%s]",
common::product(lr_dim)));

if (master_param.initialized()) {
PADDLE_ENFORCE_EQ(param_dim,
master_param.dims(),
errors::InvalidArgument(
"Param and MasterParam input of NAdamOp should "
"have same dimension. But "
"received Param dims: [%s], MasterParam dims: [%s].",
param_dim,
master_param.dims()));
}

param_out->set_dims(param_dim);
param_out->set_dtype(param.dtype());

beta1_pow_out->set_dims(beta1_pow.dims());
beta1_pow_out->set_dtype(beta1_pow.dtype());
beta2_pow_out->set_dims(beta2_pow.dims());
beta2_pow_out->set_dtype(beta2_pow.dtype());
rho_out->set_dims(rho.dims());
rho_out->set_dtype(rho.dtype());

moment1_out->set_dims(param_dim);
moment1_out->set_dtype(moment1.dtype());
moment2_out->set_dims(param_dim);
moment2_out->set_dtype(moment2.dtype());

if (multi_precision && master_param.initialized()) {
auto MPType = (param.dtype() == phi::DataType::FLOAT16 ||
param.dtype() == phi::DataType::BFLOAT16)
? phi::DataType::FLOAT32
: param.dtype();
master_param_out->set_dims(param_dim);
master_param_out->set_dtype(MPType);
}
}

void RmsNormInferMeta(const MetaTensor& x,
const MetaTensor& bias,
const MetaTensor& residual,
21 changes: 21 additions & 0 deletions paddle/phi/infermeta/multiary.h
@@ -669,6 +669,27 @@ void QuantizeLinearInferMeta(const MetaTensor& x,
MetaTensor* out_accum,
MetaTensor* out_state);

void RAdamInferMeta(const MetaTensor& param,
const MetaTensor& grad,
const MetaTensor& learning_rate,
const MetaTensor& beta1_pow,
const MetaTensor& beta2_pow,
const MetaTensor& rho,
const MetaTensor& moment1,
const MetaTensor& moment2,
const MetaTensor& master_param,
float beta1,
float beta2,
float epsilon,
bool multi_precision,
MetaTensor* param_out,
MetaTensor* beta1_pow_out,
MetaTensor* beta2_pow_out,
MetaTensor* rho_out,
MetaTensor* moment1_out,
MetaTensor* moment2_out,
MetaTensor* master_param_out);

void RmsNormInferMeta(const MetaTensor& x,
const MetaTensor& bias,
const MetaTensor& residual,
21 changes: 21 additions & 0 deletions paddle/phi/kernels/cpu/radam_kernel.cc
@@ -0,0 +1,21 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/radam_kernel.h"

#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/radam_kernel_impl.h"

PD_REGISTER_KERNEL(radam, CPU, ALL_LAYOUT, phi::RAdamKernel, float, double) {}
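The CPU registration above dispatches into paddle/phi/kernels/impl/radam_kernel_impl.h, which is part of this commit's 12 files but not reproduced in this excerpt. As a rough sketch of the per-element work that implementation has to perform, mirroring the GPU kernel further down — the function name, plain-loop structure, and buffer layout here are illustrative only, not the actual Paddle impl:

#include <cmath>
#include <cstdint>

// Illustrative scalar RAdam step over n elements; T is the compute type.
// beta1_pow, beta2_pow, rho_acc are the scalar accumulators already advanced to step t.
template <typename T>
void RAdamStepCPU(const T* param, const T* grad, T lr,
                  T beta1, T beta2, T epsilon,
                  T beta1_pow, T beta2_pow, T rho_acc, int64_t n,
                  const T* moment1, const T* moment2,
                  T* param_out, T* moment1_out, T* moment2_out) {
  const T rho_inf = T(2) / (T(1) - beta2) - T(1);
  const T rho_t = rho_inf - T(2) * rho_acc;  // rho_acc tracks t * beta2^t / (1 - beta2^t)
  for (int64_t i = 0; i < n; ++i) {
    T m1 = beta1 * moment1[i] + (T(1) - beta1) * grad[i];
    T m2 = beta2 * moment2[i] + (T(1) - beta2) * grad[i] * grad[i];
    T m1_hat = m1 / (T(1) - beta1_pow);
    T p;
    if (rho_t > T(5)) {
      // Variance of the adaptive term is rectifiable: apply the full RAdam step.
      T l_t = std::sqrt(T(1) - beta2_pow) / (std::sqrt(m2) + epsilon);
      T r_t = std::sqrt(((rho_t - T(4)) * (rho_t - T(2)) * rho_inf) /
                        ((rho_inf - T(4)) * (rho_inf - T(2)) * rho_t));
      p = param[i] - lr * m1_hat * r_t * l_t;
    } else {
      // Early steps: fall back to a plain momentum-only update.
      p = param[i] - lr * m1_hat;
    }
    moment1_out[i] = m1;
    moment2_out[i] = m2;
    param_out[i] = p;
  }
}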
10 changes: 0 additions & 10 deletions paddle/phi/kernels/gpu/nadam_kernel.cu
@@ -26,17 +26,12 @@ template <typename T, typename MT>
__global__ void NAdamGPUKernel(const T* param,
const T* grad,
const MT* learning_rate,
const MT* momentum_decay_pow,
const MT* beta2_pow,
const MT* mu_product,
const MT* moment1,
const MT* moment2,
const MT* master_param,
MT beta1,
MT beta2,
MT epsilon,
MT momentum_decay,
MT momentum_decay_base,
MT beta2_pow_scalar,
MT mu_t_scalar,
MT mu_t_1_scalar,
@@ -171,17 +166,12 @@ void NAdamKernel(const Context& dev_ctx,
<<<block, grid, 0, stream>>>(param.data<T>(),
grad.data<T>(),
learning_rate.data<MPDType>(),
momentum_decay_pow.data<MPDType>(),
beta2_pow.data<MPDType>(),
mu_product.data<MPDType>(),
moment1.data<MPDType>(),
moment2.data<MPDType>(),
master_in_data,
beta1_,
beta2_,
epsilon_,
momentum_decay_,
momentum_decay_base_,
beta2_pow_scalar,
mu_t_scalar,
mu_t_1_scalar,
189 changes: 189 additions & 0 deletions paddle/phi/kernels/gpu/radam_kernel.cu
@@ -0,0 +1,189 @@
// Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/radam_kernel.h"

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"

namespace phi {
template <typename T, typename MT>
__global__ void RAdamGPUKernel(const T* param,
const T* grad,
const MT* learning_rate,
const MT* moment1,
const MT* moment2,
const MT* master_param,
MT beta1,
MT beta2,
MT epsilon,
MT beta1_pow_scalar,
MT beta2_pow_scalar,
MT rho_scalar,
MT rho_inf,
int num,
T* param_out,
MT* moment1_out,
MT* moment2_out,
MT* master_param_out) {
MT lr_scalar = static_cast<MT>(learning_rate[0]);

int idx = blockIdx.x * blockDim.x + threadIdx.x;

for (int index = idx; index < num; index += gridDim.x * blockDim.x) {
// load and cast input to MT
MT d_param =
master_param ? master_param[index] : static_cast<MT>(param[index]);
MT d_grad = static_cast<MT>(grad[index]);
MT d_moment1 = static_cast<MT>(moment1[index]);
MT d_moment2 = static_cast<MT>(moment2[index]);
// compute
MT m1_out = beta1 * d_moment1 + (static_cast<MT>(1) - beta1) * d_grad;
MT m2_out =
beta2 * d_moment2 + (static_cast<MT>(1) - beta2) * d_grad * d_grad;

MT m1_hat = m1_out / (static_cast<MT>(1) - beta1_pow_scalar);

MT rho_t = rho_inf - static_cast<MT>(2) * rho_scalar;
MT p_out = static_cast<MT>(0);

if (rho_t > static_cast<MT>(5)) {
MT l_t = std::sqrt((static_cast<MT>(1) - beta2_pow_scalar)) /
(std::sqrt(m2_out) + epsilon);
MT r_t = std::sqrt(((rho_t - static_cast<MT>(4)) *
(rho_t - static_cast<MT>(2)) * rho_inf) /
((rho_inf - static_cast<MT>(4)) *
(rho_inf - static_cast<MT>(2)) * rho_t));

p_out = d_param - lr_scalar * m1_hat * r_t * l_t;
} else {
p_out = d_param - lr_scalar * m1_hat;
}

// store
param_out[index] = static_cast<T>(p_out);
moment1_out[index] = static_cast<MT>(m1_out);
moment2_out[index] = static_cast<MT>(m2_out);

if (master_param_out) {
master_param_out[index] = p_out;
}
}
}

template <typename T, typename Context>
void RAdamKernel(const Context& dev_ctx,
const DenseTensor& param,
const DenseTensor& grad,
const DenseTensor& learning_rate,
const DenseTensor& beta1_pow,
const DenseTensor& beta2_pow,
const DenseTensor& rho,
const DenseTensor& moment1,
const DenseTensor& moment2,
const paddle::optional<DenseTensor>& master_param,
float beta1,
float beta2,
float epsilon,
bool multi_precision,
DenseTensor* param_out,
DenseTensor* beta1_pow_out,
DenseTensor* beta2_pow_out,
DenseTensor* rho_out,
DenseTensor* moment1_out,
DenseTensor* moment2_out,
DenseTensor* master_param_out) {
using MPDType = typename phi::dtype::template MPTypeTrait<T>::Type;
T* param_out_data = dev_ctx.template Alloc<T>(param_out);

MPDType* beta1_pow_out_data =
dev_ctx.template HostAlloc<MPDType>(beta1_pow_out);
MPDType* beta2_pow_out_data =
dev_ctx.template HostAlloc<MPDType>(beta2_pow_out);
MPDType* rho_out_data = dev_ctx.template HostAlloc<MPDType>(rho_out);

MPDType* moment1_out_data = dev_ctx.template Alloc<MPDType>(moment1_out);
MPDType* moment2_out_data = dev_ctx.template Alloc<MPDType>(moment2_out);

const MPDType* master_in_data =
multi_precision ? master_param->data<MPDType>() : nullptr;
MPDType* master_out_data =
multi_precision ? dev_ctx.template Alloc<MPDType>(master_param_out)
: nullptr;

MPDType beta1_ = static_cast<MPDType>(beta1);
MPDType beta2_ = static_cast<MPDType>(beta2);
MPDType epsilon_ = static_cast<MPDType>(epsilon);

MPDType rho_inf =
static_cast<MPDType>(2) / (static_cast<MPDType>(1) - beta2_) -
static_cast<MPDType>(1);

// bring the scalar accumulators (beta powers and rho) to the host, advance them one step, and store them in the host-allocated outputs
DenseTensor beta1_pow_data;
phi::Copy(dev_ctx, beta1_pow, phi::CPUPlace(), false, &beta1_pow_data);
MPDType beta1_pow_scalar = beta1_pow_data.data<MPDType>()[0] * beta1_;
beta1_pow_out_data[0] = beta1_pow_scalar;

DenseTensor beta2_pow_data;
phi::Copy(dev_ctx, beta2_pow, phi::CPUPlace(), false, &beta2_pow_data);
MPDType beta2_pow_scalar = beta2_pow_data.data<MPDType>()[0] * beta2_;
beta2_pow_out_data[0] = beta2_pow_scalar;

DenseTensor rho_data;
phi::Copy(dev_ctx, rho, phi::CPUPlace(), false, &rho_data);
MPDType rho_scalar =
(rho_data.data<MPDType>()[0] * (beta2_ - beta2_pow_scalar) +
beta2_pow_scalar) /
(static_cast<MPDType>(1) - beta2_pow_scalar);
rho_out_data[0] = rho_scalar;

int numel = param.numel();
int block = 512;
int grid = (param.numel() + block - 1) / block;
auto stream = dev_ctx.stream();

RAdamGPUKernel<T, MPDType>
<<<block, grid, 0, stream>>>(param.data<T>(),
grad.data<T>(),
learning_rate.data<MPDType>(),
moment1.data<MPDType>(),
moment2.data<MPDType>(),
master_in_data,
beta1_,
beta2_,
epsilon_,
beta1_pow_scalar,
beta2_pow_scalar,
rho_scalar,
rho_inf,
numel,
param_out_data,
moment1_out_data,
moment2_out_data,
master_out_data);
}
} // namespace phi

PD_REGISTER_KERNEL(radam,
GPU,
ALL_LAYOUT,
phi::RAdamKernel,
float,
double,
phi::dtype::float16) {}
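
A note on the three host-side scalars updated above (this derivation is mine, not part of the patch): beta1_pow_out and beta2_pow_out simply accumulate \beta_1^t and \beta_2^t by one extra multiplication per step, while rho_out keeps the running quantity s_t = t \beta_2^t / (1 - \beta_2^t). The recurrence computed for rho_scalar,

s_t = \frac{s_{t-1} (\beta_2 - \beta_2^t) + \beta_2^t}{1 - \beta_2^t},

reproduces that closed form: substituting s_{t-1} = (t-1)\beta_2^{t-1} / (1 - \beta_2^{t-1}) gives s_{t-1}(\beta_2 - \beta_2^t) = (t-1)\beta_2^t, so the numerator is t \beta_2^t. The device kernel then forms \rho_t = \rho_\infty - 2 s_t, matching the RAdam definition of \rho_t.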
