reduce more cpu overhead
Signed-off-by: daquexian <daquexian566@gmail.com>
daquexian committed Aug 12, 2023
1 parent aec55ba commit 15eb680
Showing 3 changed files with 66 additions and 54 deletions.
88 changes: 48 additions & 40 deletions rwkv_pip_package/src/rwkv/cuda/att_one_v5.cu
@@ -54,6 +54,20 @@ struct Mix {
                      __hmul(sx_, __hsub(__float2half(1), r_mix_)));
   }
 };
+
+struct ToHalf {
+  const float *x;
+  half *y;
+  __device__ void operator()(int i) const { y[i] = __float2half(x[i]); }
+};
+
+struct InplaceAdd {
+  __device__ __forceinline__ half operator()(int i) const {
+    y[i] = __hadd(x[i], y[i]);
+  }
+  half *y;
+  half *x;
+};
 } // namespace
 
 using torch::Tensor;
@@ -64,50 +78,44 @@ void gemm_cublas(const void *a, const void *b, void *c, int batch, int ori_m,
                  at::ScalarType torch_output_dtype);
 
 Tensor att_one_v5(Tensor x, Tensor sx, Tensor s, Tensor ln_w, Tensor ln_b,
-                  Tensor lx_w, Tensor lx_b, Tensor kvr_mix,
-                  /* imm */ Tensor kvrx, Tensor kvrw, Tensor ow, Tensor t_first,
-                  Tensor t_decay, /* imm */ Tensor kvr, /* imm */ Tensor a,
-                  /* imm */ Tensor buf,
-                  /* imm */ Tensor s1,
-                  /* out */ Tensor x_plus_out, /* out */ Tensor s2) {
+                  Tensor lx_w, Tensor lx_b, Tensor kvr_mix, Tensor kvrw,
+                  Tensor ow, Tensor t_first, Tensor t_decay, Tensor tmp,
+                  Tensor buf, /* out */ Tensor s2_t,
+                  /* out */ Tensor x_plus_out_t) {
   const int x_numel = x.numel();
   Tensor xx = at::layer_norm(x, {x_numel}, ln_w, ln_b);
+  int H = t_decay.size(0);
+  int S = x_numel / H;
+  char *buf_ptr = (char *)buf.data_ptr();
+  half *kvrx = (half *)buf_ptr;
+  float *kvr = (float *)(kvrx + 3 * x_numel);
+  float *a = kvr + 3 * x_numel;
+  half *tmp2 = (half *)(a + H * S * S);
+  float *s1 = (float *)(tmp2 + x_numel);
+  float *s2 = data_ptr<float>(s2_t);
+  half *x_plus_out = data_ptr<half>(x_plus_out_t);
 
   element_wise(Mix{data_ptr<half>(xx), data_ptr<half>(sx),
-                   data_ptr<half>(kvr_mix), static_cast<int>(x_numel),
-                   data_ptr<half>(kvrx)},
+                   data_ptr<half>(kvr_mix), static_cast<int>(x_numel), kvrx},
               x_numel);
 
-  int H = t_decay.size(0);
-  int S = x_numel / H;
-  // gemm_cublas_tensor(at::unsqueeze(kvrx, 1), kvrw, kvr);
-  gemm_cublas(data_ptr<half>(kvrx), data_ptr<half>(kvrw), data_ptr<float>(kvr),
-              3, 1, x_numel, x_numel, at::kHalf, at::kFloat);
-  float* k = data_ptr<float>(kvr);
-  float* v = k + x_numel;
-  float* r = v + x_numel;
-  // Tensor k = at::reshape(kvr[0], {H, S, 1});
-  // Tensor v = at::reshape(kvr[1], {H, 1, S});
-  // Tensor r = at::reshape(kvr[2], {H, 1, S});
-
-  // gemm_cublas_tensor(k, v, a);
-  gemm_cublas(k, v, data_ptr<float>(a), H, S, S, 1, at::kFloat, at::kFloat);
-  // s1 = t_first * a + s
-  // s2 = a + t_decay * s
-  element_wise(Fused1{data_ptr<float>(t_first), data_ptr<float>(t_decay),
-                      data_ptr<float>(a), data_ptr<float>(s),
-                      static_cast<int32_t>(a.size(1) * a.size(2)),
-                      data_ptr<float>(s1), data_ptr<float>(s2)},
-               a.numel());
-
-  // gemm_cublas_tensor(r, s1, buf);
-  gemm_cublas(r, data_ptr<float>(s1), data_ptr<float>(buf), H, 1, S, S,
-              at::kFloat, at::kFloat);
-  buf = at::group_norm(buf, H, lx_w, lx_b);
-  buf = at::_cast_Half(buf);
-
-  // gemm_cublas_tensor(buf, ow, x_plus_out);
-  gemm_cublas(data_ptr<half>(buf), data_ptr<half>(ow), data_ptr<half>(x_plus_out),
-              1, 1, x_numel, x_numel, at::kHalf, at::kHalf);
-  x_plus_out += x;
+  gemm_cublas(kvrx, data_ptr<half>(kvrw), kvr, 3, 1, x_numel, x_numel,
+              at::kHalf, at::kFloat);
+  float *k = kvr;
+  float *v = k + x_numel;
+  float *r = v + x_numel;
+
+  gemm_cublas(k, v, a, H, S, S, 1, at::kFloat, at::kFloat);
+  element_wise(Fused1{data_ptr<float>(t_first), data_ptr<float>(t_decay), a,
+                      data_ptr<float>(s), static_cast<int32_t>(S * S), s1, s2},
+               H * S * S);
+
+  gemm_cublas(r, s1, data_ptr<float>(tmp), H, 1, S, S, at::kFloat, at::kFloat);
+  tmp = at::group_norm(tmp, H, lx_w, lx_b);
+  element_wise(ToHalf{data_ptr<float>(tmp), tmp2}, tmp.numel());
+
+  gemm_cublas(tmp2, data_ptr<half>(ow), x_plus_out, 1, 1, x_numel, x_numel,
+              at::kHalf, at::kHalf);
+  element_wise(InplaceAdd{x_plus_out, data_ptr<half>(x)}, x.numel());
   return xx;
 }
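
Note on this file: instead of receiving kvrx, kvr, a, and s1 as separately allocated tensors, att_one_v5 now carves every intermediate out of the single byte buffer `buf` passed in from Python, cutting several per-call tensor allocations. Below is a minimal sketch of the byte layout implied by the pointer arithmetic above, assuming the usual 2-byte half and 4-byte float; the helper name is purely illustrative and does not exist in the repository.

#include <cuda_fp16.h>
#include <cstddef>

// Sketch: byte layout of the single int8 `buf` carved up at the top of att_one_v5.
// The sum must match the size allocated on the Python side (see model.py below).
size_t att_one_v5_buf_bytes(size_t x_numel, size_t H, size_t S) {
  size_t bytes = 0;
  bytes += 3 * x_numel * sizeof(half);   // kvrx: mixed k/v/r inputs (half)
  bytes += 3 * x_numel * sizeof(float);  // kvr:  k/v/r after the fused GEMM (float)
  bytes += H * S * S * sizeof(float);    // a:    per-head outer product k v^T (float)
  bytes += x_numel * sizeof(half);       // tmp2: group_norm output cast to half
  bytes += H * S * S * sizeof(float);    // s1:   t_first * a + s (float)
  return bytes;
}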
10 changes: 4 additions & 6 deletions rwkv_pip_package/src/rwkv/cuda/wrapper.cpp
@@ -137,12 +137,10 @@ Tensor att_seq(Tensor x, Tensor sx, Tensor ln_w, Tensor ln_b, Tensor k_mix,
                Tensor t_decay, /* imm */ Tensor buf, /* out */ Tensor x_plus_out);
 
 Tensor att_one_v5(Tensor x, Tensor sx, Tensor s, Tensor ln_w, Tensor ln_b,
-                  Tensor lx_w, Tensor lx_b, Tensor kvr_mix,
-                  /* imm */ Tensor kvrx, Tensor kvrw, Tensor ow, Tensor t_first,
-                  Tensor t_decay, /* imm */ Tensor kvr, /* imm */ Tensor a,
-                  /* imm */ Tensor buf,
-                  /* imm */ Tensor s1,
-                  /* out */ Tensor x_plus_out, /* out */ Tensor s2);
+                  Tensor lx_w, Tensor lx_b, Tensor kvr_mix, Tensor kvrw,
+                  Tensor ow, Tensor t_first, Tensor t_decay, Tensor tmp,
+                  Tensor buf, /* out */ Tensor s2_t,
+                  /* out */ Tensor x_plus_out_t);
 
 Tensor ffn_seq(Tensor x, Tensor sx, Tensor ln_w, Tensor ln_b, Tensor k_mix,
                Tensor r_mix, Tensor kw, Tensor vw, Tensor rw,
22 changes: 14 additions & 8 deletions rwkv_pip_package/src/rwkv/model.py
@@ -746,21 +746,27 @@ def cuda_att_one_fp16(self, x, sx, aa, bb, pp, ln_w, ln_b, k_mix, v_mix, r_mix,
 
     @MyFunction
     def cuda_att_one_v5_fp16(self, x, sx, s, ln_w, ln_b, lx_w, lx_b, kvr_mix, t_decay, t_first, kvrw, ow, kmx, krx, kmy, kry, vmx, vrx, vmy, vry, rmx, rrx, rmy, rry, omx, orx, omy, ory):
-        kvrx = torch.empty((3, x.numel()), dtype=x.dtype, device=x.device)
 
         H = t_decay.shape[0]
         S = x.shape[-1] // H
 
-        kvr = torch.empty((3, 1, x.shape[-1]), dtype=torch.float32, device=x.device)
-        a = torch.empty((H, S, S), dtype=torch.float32, device=x.device)
-        buf = torch.empty((1, x.shape[-1]), dtype=torch.float32, device=x.device)
-        s1 = torch.empty((H, S, S), dtype=torch.float32, device=x.device)
+        tmp = torch.empty((1, x.shape[-1]), dtype=torch.float32, device=x.device)
+        buf = torch.empty((3 * x.numel() * 2 + 3 * x.numel() * 4 + H * S * S * 4 + x.numel() * 2 + H * S * S * 4,), dtype=torch.int8, device=x.device)
         # two outputs
         s2 = torch.empty((H, S, S), dtype=torch.float32, device=x.device)
         x_plus_out = torch.empty_like(x)
 
+        # kvrx = torch.empty((3, x.numel()), dtype=x.dtype, device=x.device)
+        # kvr = torch.empty((3, 1, x.shape[-1]), dtype=torch.float32, device=x.device)
+        # a = torch.empty((H, S, S), dtype=torch.float32, device=x.device)
+        # tmp2 = torch.empty((1, x.shape[-1]), dtype=torch.float16, device=x.device)
+        # s1 = torch.empty((H, S, S), dtype=torch.float32, device=x.device)
+        # s2 = torch.empty((H, S, S), dtype=torch.float32, device=x.device)
+        # x_plus_out = torch.empty_like(x)
+
+        # import pdb; pdb.set_trace()
+
-        xx = torch.ops.rwkv.att_one_v5(x, sx, s, ln_w, ln_b, lx_w, lx_b, kvr_mix, kvrx, kvrw, ow, t_first, t_decay, kvr, a, buf, s1, x_plus_out, s2) # type: ignore[reportGeneralTypeIssues]
+        xx = torch.ops.rwkv.att_one_v5(x, sx, s, ln_w, ln_b, lx_w, lx_b, kvr_mix, kvrw, ow, t_first, t_decay, tmp, buf, s2, x_plus_out) # type: ignore[reportGeneralTypeIssues]
+
+        # import pdb; pdb.set_trace()
         return x_plus_out, xx, s2
 
     @MyFunction
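
The two new functors in att_one_v5.cu (ToHalf and InplaceAdd) take over the final half cast and the residual add that the old code routed through ATen (`buf = at::_cast_Half(buf)` and `x_plus_out += x`), presumably to avoid per-operator dispatch on the CPU, in line with the commit message. They are plain index functors driven by the package's `element_wise` helper, which is defined elsewhere in the cuda sources and not shown in this commit; the launcher below is only a hypothetical sketch of the pattern, not the repository's actual implementation.

#include <cuda_fp16.h>

// Hypothetical launcher for index functors such as ToHalf / InplaceAdd.
// Block size and launch details are assumptions, not taken from the repo.
template <typename Func>
__global__ void element_wise_kernel(Func f, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) f(i);  // the functor writes its own output, e.g. y[i] = __float2half(x[i])
}

template <typename Func>
void element_wise(Func f, int n) {
  constexpr int kThreads = 256;
  int blocks = (n + kThreads - 1) / kThreads;
  element_wise_kernel<<<blocks, kThreads>>>(f, n);
}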
