[Need approval] Add AdamW-CPU FP32 JIT assembly kernel #42522

Merged: 4 commits, May 9, 2022
1 change: 1 addition & 0 deletions paddle/fluid/operators/jit/gen/CMakeLists.txt
@@ -33,5 +33,6 @@ USE_JITKERNEL_GEN(kHMax)
USE_JITKERNEL_GEN(kHSum)
USE_JITKERNEL_GEN(kEmbSeqPool)
USE_JITKERNEL_GEN(kAdam)
USE_JITKERNEL_GEN(kAdamW)
USE_JITKERNEL_GEN(kSgd)
USE_JITKERNEL_GEN(kVBroadcast)
6 changes: 3 additions & 3 deletions paddle/fluid/operators/jit/gen/adam.cc
@@ -80,19 +80,19 @@ void AdamJitCode::mainCode() {
// beta2 * mom2 + (1 - beta2) * g * g
vmulps(ymm7 | k1, ymm7, ymm7);
vmulps(ymm7 | k1, ymm_one_sub_beta2, ymm7);
vfmadd231ps(ymm7 | k1, ymm1, ptr[reg_mom2_ptr + reg_offset]);
vfmadd231ps(ymm7 | k1, ymm_beta2, ptr[reg_mom2_ptr + reg_offset]);

// store mom1 and mom2
vmovups(ptr[reg_mom1_out_ptr + reg_offset] | k1, ymm8);
vmovups(ptr[reg_mom2_out_ptr + reg_offset] | k1, ymm7);

// sqrt(mom2) + eps
vsqrtps(ymm7 | k1, ymm7);
vaddps(ymm7 | k1, ymm7, ymm3);
vaddps(ymm7 | k1, ymm7, ymm_eps);

// p + (-lr) * (mom1 / (sqrt(mom2) + eps))
vdivps(ymm7 | k1, ymm8, ymm7);
vfmadd213ps(ymm7 | k1, ymm2, ptr[reg_param_ptr + reg_offset]);
vfmadd213ps(ymm7 | k1, ymm_lr, ptr[reg_param_ptr + reg_offset]);

// store p
vmovups(ptr[reg_param_out_ptr + reg_offset] | k1, ymm7);
165 changes: 165 additions & 0 deletions paddle/fluid/operators/jit/gen/adamw.cc
@@ -0,0 +1,165 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */

#include "paddle/fluid/operators/jit/gen/adamw.h"

#include <stddef.h> // offsetof

#include "paddle/fluid/operators/jit/registry.h"
#include "paddle/fluid/platform/cpu_info.h"

namespace paddle {
namespace operators {
namespace jit {
namespace gen {

void AdamWJitCode::loadArgs() {
static constexpr int32_t one_as_float = 0x3f800000;
static constexpr int32_t mask_all_ones = 0xFFFFFFFF;
static constexpr int64_t mask_8_divisible = 0xFFFFFFFFFFFFFFF8;
static constexpr int64_t abi_pushes_offset = num_g_abi_regs * 8;

mov(reg_mom2_out_ptr, ptr[rsp + (abi_pushes_offset + 8)]);
mov(reg_param_out_ptr, ptr[rsp + (abi_pushes_offset + 16)]);
mov(eax, one_as_float);
movd(xmm_one, eax);

vbroadcastss(ymm_one, xmm_one); // 1
vbroadcastss(ymm_beta1, xmm_beta1); // beta1
vbroadcastss(ymm_beta2, xmm_beta2); // beta2
vbroadcastss(ymm_lr, xmm_lr); // -lr
vbroadcastss(ymm_eps, xmm_eps); // eps
vbroadcastss(ymm_old_lr, xmm_old_lr); // old lr
vbroadcastss(ymm_lr_ratio, xmm_lr_ratio); // lr_ratio
vbroadcastss(ymm_coeff, xmm_coeff); // coeff
vsubps(ymm_one_sub_beta1, ymm_one, ymm_beta1); // 1 - beta1
vsubps(ymm_one_sub_beta2, ymm_one, ymm_beta2); // 1 - beta2

mov(reg_numel_without_tail, reg_numel);
and_(reg_numel_without_tail, mask_8_divisible); // make it 8-divisible

shl(reg_numel_without_tail, 2); // * 4 to convert the element count to a byte offset
shl(reg_numel, 2);

mov(eax, mask_all_ones);
kmovw(k1, eax);

xor_(reg_offset, reg_offset);
}

void AdamWJitCode::setTailOpmask() {
mov(r13, rcx);

mov(rcx, reg_numel);
sub(rcx, reg_offset); // remaining tail size in bytes
shr(rcx, 2); // convert to element count
mov(r14, 1);
shl(r14, cl); // 2 ^ elements
dec(r14); // 2 ^ elements - 1, so the lowest 'elements' bits are set to 1
kmovw(k1, r14d);

mov(rcx, r13);
}

void AdamWJitCode::mainCode() {
// load p
vmovups(ymm10 | k1, ptr[reg_param_ptr + reg_offset]);

// (old_lr * lr_ratio) * coeff
vmulps(ymm11 | k1, ymm_old_lr, ymm_lr_ratio);
vmulps(ymm11 | k1, ymm11, ymm_coeff);

// -((old_lr * lr_ratio) * coeff) * p + p
// the result (p after weight decay) is kept in ymm11
vfnmadd132ps(ymm11 | k1, ymm10, ymm10);

// load grad
vmovups(ymm10 | k1, ptr[reg_grad_ptr + reg_offset]);

// beta1 * mom1 + (1 - beta1) * g
vmulps(ymm12 | k1, ymm_one_sub_beta1, ymm10);
vfmadd231ps(ymm12 | k1, ymm_beta1, ptr[reg_mom1_ptr + reg_offset]);

// beta2 * mom2 + (1 - beta2) * g * g
vmulps(ymm10 | k1, ymm10, ymm10);
vmulps(ymm10 | k1, ymm_one_sub_beta2, ymm10);
vfmadd231ps(ymm10 | k1, ymm_beta2, ptr[reg_mom2_ptr + reg_offset]);

// store mom1 and mom2
vmovups(ptr[reg_mom1_out_ptr + reg_offset] | k1, ymm12);
vmovups(ptr[reg_mom2_out_ptr + reg_offset] | k1, ymm10);

// sqrt(mom2) + eps
vsqrtps(ymm10 | k1, ymm10);
vaddps(ymm10 | k1, ymm10, ymm_eps);

// p + (-lr) * (mom1 / (sqrt(mom2) + eps))
vdivps(ymm10 | k1, ymm12, ymm10);
vfmadd213ps(ymm10 | k1, ymm_lr, ymm11);

// store p
vmovups(ptr[reg_param_out_ptr + reg_offset] | k1, ymm10);
}

void AdamWJitCode::genCode() {
static constexpr int64_t main_loop_elems_size =
8 * sizeof(float); // 8 floats in YMM
static constexpr int64_t offset_increment = main_loop_elems_size;
preCode();
loadArgs();

cmp(reg_numel, main_loop_elems_size);
jl("process_tail");

L("main_loop");
{
mainCode();
add(reg_offset, offset_increment);
cmp(reg_numel_without_tail, reg_offset);
jg("main_loop");
}

cmp(reg_numel, reg_offset);
je("end", T_NEAR); // size between jmp and label is larger than 127 byte,
// T_NEAR allow long jump

L("process_tail");
{
setTailOpmask();
mainCode();
}

L("end");
postCode();
}

class AdamWCreator : public JitCodeCreator<int> {
public:
bool CanBeUsed(const int& attr) const override {
return platform::MayIUse(platform::avx512f);
}
size_t CodeSize(const int& attr) const override { return 96 + 32 * 8; }
std::unique_ptr<GenBase> CreateJitCode(const int& attr) const override {
return make_unique<AdamWJitCode>(attr, CodeSize(attr));
}
};

} // namespace gen
} // namespace jit
} // namespace operators
} // namespace paddle

namespace gen = paddle::operators::jit::gen;

REGISTER_JITKERNEL_GEN(kAdamW, gen::AdamWCreator);
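For orientation, a hedged scalar C++ sketch of the control flow that genCode() emits and of the opmask that setTailOpmask() builds: a main loop over the 8-divisible prefix, then one masked iteration for the remainder. All names below are illustrative and not part of this patch.

#include <cstdint>

// Scalar model of the emitted control flow. The real kernel works on byte
// offsets and on YMM registers predicated by opmask k1.
inline uint16_t TailOpmask(int64_t numel) {
  const int64_t tail = numel & 7;                   // elements left after main_loop
  return static_cast<uint16_t>((1u << tail) - 1u);  // setTailOpmask: 2^tail - 1
}

inline void LoopStructureSketch(int64_t numel) {
  const int64_t numel_without_tail = numel & ~int64_t(7);  // mask_8_divisible
  int64_t i = 0;
  for (; i < numel_without_tail; i += 8) {
    // main_loop: mainCode() with all 8 lanes of k1 active
  }
  if (i < numel) {
    // process_tail: k1 = TailOpmask(numel), then one more masked mainCode()
  }
}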
81 changes: 81 additions & 0 deletions paddle/fluid/operators/jit/gen/adamw.h
@@ -0,0 +1,81 @@
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */

#pragma once

#include <string>

#include "glog/logging.h"
#include "paddle/fluid/operators/jit/gen/jitcode.h"
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace operators {
namespace jit {
namespace gen {

class AdamWJitCode : public JitCode {
public:
explicit AdamWJitCode(const int& attr, size_t code_size = 256 * 1024,
void* code_ptr = nullptr)
: JitCode(code_size, code_ptr) {
this->genCode();
}

DECLARE_JIT_CODE(AdamWJitCode);
void genCode() override;
void loadArgs();
void setTailOpmask();
void mainCode();

private:
reg64_t reg_numel{abi_param1};
reg64_t reg_grad_ptr{abi_param2};
reg64_t reg_mom1_ptr{abi_param3};
reg64_t reg_mom2_ptr{abi_param4};
reg64_t reg_param_ptr{abi_param5};
reg64_t reg_mom1_out_ptr{abi_param6};

xmm_t xmm_beta1 = xmm_t(0);
xmm_t xmm_beta2 = xmm_t(1);
xmm_t xmm_lr = xmm_t(2);
xmm_t xmm_eps = xmm_t(3);
xmm_t xmm_old_lr = xmm_t(4);
xmm_t xmm_lr_ratio = xmm_t(5);
xmm_t xmm_coeff = xmm_t(6);
xmm_t xmm_one_sub_beta1 = xmm_t(7);
xmm_t xmm_one_sub_beta2 = xmm_t(8);
xmm_t xmm_one = xmm_t(9);

ymm_t ymm_beta1 = ymm_t(0);
ymm_t ymm_beta2 = ymm_t(1);
ymm_t ymm_lr = ymm_t(2);
ymm_t ymm_eps = ymm_t(3);
ymm_t ymm_old_lr = ymm_t(4);
ymm_t ymm_lr_ratio = ymm_t(5);
ymm_t ymm_coeff = ymm_t(6);
ymm_t ymm_one_sub_beta1 = ymm_t(7);
ymm_t ymm_one_sub_beta2 = ymm_t(8);
ymm_t ymm_one = ymm_t(9);

reg64_t reg_mom2_out_ptr{r10};
reg64_t reg_param_out_ptr{r11};
reg64_t reg_numel_without_tail{r12};
reg64_t reg_offset{rax};
};

} // namespace gen
} // namespace jit
} // namespace operators
} // namespace paddle
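A hedged reading of the register choices above, assuming the SysV x86-64 calling convention and the AdamWTuple<float>::func_type signature added in kernel_base.h: the seven float hyper-parameters arrive in xmm0-xmm6, the first six integer/pointer arguments in abi_param1-abi_param6, and the last two output pointers on the stack, which is why loadArgs() fetches them from [rsp + abi_pushes_offset + 8] and + 16.

#include <cstdint>

// Hedged argument-to-register map (SysV x86-64 assumption, illustrative only):
//   beta1    -> xmm0 (xmm_beta1)     numel     -> abi_param1 (reg_numel)
//   beta2    -> xmm1 (xmm_beta2)     grad_ptr  -> abi_param2 (reg_grad_ptr)
//   -lr      -> xmm2 (xmm_lr)        mom1_ptr  -> abi_param3 (reg_mom1_ptr)
//   eps      -> xmm3 (xmm_eps)       mom2_ptr  -> abi_param4 (reg_mom2_ptr)
//   old_lr   -> xmm4 (xmm_old_lr)    param_ptr -> abi_param5 (reg_param_ptr)
//   lr_ratio -> xmm5 (xmm_lr_ratio)  mom1_out  -> abi_param6 (reg_mom1_out_ptr)
//   coeff    -> xmm6 (xmm_coeff)     mom2_out, param_out -> stack (read in loadArgs)
using AdamWFuncSketch = void (*)(float beta1, float beta2, float lr, float eps,
                                 float old_lr, float lr_ratio, float coeff,
                                 int64_t numel, const float* grad,
                                 const float* mom1, const float* mom2,
                                 const float* param, float* mom1_out,
                                 float* mom2_out, float* param_out);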
1 change: 1 addition & 0 deletions paddle/fluid/operators/jit/helper.cc
@@ -59,6 +59,7 @@ const char* to_string(KernelType kt) {
ONE_CASE(kMatMul);
ONE_CASE(kHMax);
ONE_CASE(kAdam);
ONE_CASE(kAdamW);
ONE_CASE(kHSum);
ONE_CASE(kStrideASum);
ONE_CASE(kSoftmax);
10 changes: 10 additions & 0 deletions paddle/fluid/operators/jit/kernel_base.h
@@ -25,6 +25,7 @@ typedef enum {
kNone = 0,
// sort by alphabet
kAdam = 1,
kAdamW,
kCRFDecoding,
kEmbSeqPool,
kGRUH1,
@@ -285,6 +286,15 @@ struct AdamTuple {
const T*, T*, T*, T*);
};

template <typename T>
struct AdamWTuple {
static constexpr KernelType kernel_type = kAdamW;
typedef T data_type;
typedef int attr_type;
typedef void (*func_type)(T, T, T, T, T, T, T, int64_t, const T*, const T*,
const T*, const T*, T*, T*, T*);
};

typedef struct matmul_attr_s {
int m, n, k;
void* packed_weight{nullptr};
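With kAdamW registered under the AdamWTuple above, a caller can reach either the JIT or the reference implementation through the usual kernel lookup. A hedged usage sketch follows, assuming the KernelFuncs<Tuple, Place>::Cache().At(attr) helper that other jit-backed operators use; the attr value 1 is a placeholder, not taken from this PR.

#include <cstdint>

#include "paddle/fluid/operators/jit/kernels.h"

// Hedged sketch: fetch the best available kAdamW kernel (JIT if AVX-512F is
// usable, otherwise the reference one) and apply it to `numel` floats.
// Argument names mirror AdamWTuple<float>::func_type.
void ApplyAdamWSketch(float beta1, float beta2, float neg_lr, float eps,
                      float old_lr, float lr_ratio, float coeff, int64_t numel,
                      const float* grad, const float* mom1, const float* mom2,
                      const float* param, float* mom1_out, float* mom2_out,
                      float* param_out) {
  namespace jit = paddle::operators::jit;
  auto adamw = jit::KernelFuncs<jit::AdamWTuple<float>,
                                paddle::platform::CPUPlace>::Cache()
                   .At(/*attr=*/1);  // placeholder attribute value
  adamw(beta1, beta2, neg_lr, eps, old_lr, lr_ratio, coeff, numel, grad, mom1,
        mom2, param, mom1_out, mom2_out, param_out);
}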
1 change: 1 addition & 0 deletions paddle/fluid/operators/jit/refer/CMakeLists.txt
@@ -37,5 +37,6 @@ USE_JITKERNEL_REFER(kStrideASum)
USE_JITKERNEL_REFER(kSoftmax)
USE_JITKERNEL_REFER(kEmbSeqPool)
USE_JITKERNEL_REFER(kAdam)
USE_JITKERNEL_REFER(kAdamW)
USE_JITKERNEL_REFER(kSgd)
USE_JITKERNEL_REFER(kVBroadcast)
1 change: 1 addition & 0 deletions paddle/fluid/operators/jit/refer/refer.cc
@@ -56,6 +56,7 @@ REGISTER_REFER_KERNEL(StrideASum);
REGISTER_REFER_KERNEL(Softmax);
REGISTER_REFER_KERNEL(EmbSeqPool);
REGISTER_REFER_KERNEL(Adam);
REGISTER_REFER_KERNEL(AdamW);
REGISTER_REFER_KERNEL(Sgd);
REGISTER_REFER_KERNEL(VBroadcast);

16 changes: 16 additions & 0 deletions paddle/fluid/operators/jit/refer/refer.h
@@ -565,6 +565,21 @@ void Adam(T beta1, T beta2, T lr, T eps, int64_t numel, const T* grad_ptr,
}
}

template <typename T>
void AdamW(T beta1, T beta2, T lr, T eps, T old_lr, T lr_ratio, T coeff,
int64_t numel, const T* grad_ptr, const T* mom1_ptr,
const T* mom2_ptr, const T* param_ptr, T* mom1_out_ptr,
T* mom2_out_ptr, T* param_out_ptr) {
for (int64_t i = 0; i < numel; ++i) {
auto param_tmp = param_ptr[i] - old_lr * lr_ratio * coeff * param_ptr[i];
mom1_out_ptr[i] = beta1 * mom1_ptr[i] + (1 - beta1) * grad_ptr[i];
mom2_out_ptr[i] =
beta2 * mom2_ptr[i] + (1 - beta2) * grad_ptr[i] * grad_ptr[i];
param_out_ptr[i] =
param_tmp + lr * (mom1_out_ptr[i] / (sqrt(mom2_out_ptr[i]) + eps));
}
}

#define DECLARE_REFER_KERNEL(name) \
template <typename T> \
class name##Kernel : public ReferKernel<name##Tuple<T>> { \
@@ -617,6 +632,7 @@ DECLARE_REFER_KERNEL(MatMul);
DECLARE_REFER_KERNEL(Softmax);
DECLARE_REFER_KERNEL(EmbSeqPool);
DECLARE_REFER_KERNEL(Adam);
DECLARE_REFER_KERNEL(AdamW);
DECLARE_REFER_KERNEL(Sgd);
DECLARE_REFER_KERNEL(VBroadcast);

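Finally, a hedged, self-contained driver for the reference AdamW above, assuming the refer functions live in the paddle::operators::jit::refer namespace as the other refer kernels do. Sample values are arbitrary, and lr is passed negated because the JIT comments mark xmm_lr as -lr.

#include <cstdint>
#include <cstdio>

#include "paddle/fluid/operators/jit/refer/refer.h"

int main() {
  const int64_t numel = 3;
  float grad[]  = {0.1f, -0.2f, 0.3f};
  float mom1[]  = {0.0f, 0.0f, 0.0f};
  float mom2[]  = {0.0f, 0.0f, 0.0f};
  float param[] = {1.0f, 2.0f, 3.0f};
  float mom1_out[3], mom2_out[3], param_out[3];

  // beta1, beta2, -lr, eps, old_lr, lr_ratio, coeff, numel, in/out buffers
  paddle::operators::jit::refer::AdamW<float>(
      0.9f, 0.999f, -0.001f, 1e-8f, 0.001f, 1.0f, 0.01f, numel, grad, mom1,
      mom2, param, mom1_out, mom2_out, param_out);

  for (int64_t i = 0; i < numel; ++i) {
    std::printf("param_out[%lld] = %f\n", static_cast<long long>(i),
                param_out[i]);
  }
  return 0;
}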