diff --git a/src/infiniop/ops/causal_softmax/kunlun/kernel.h b/src/infiniop/ops/causal_softmax/kunlun/kernel.h index fbeb7ca64..659a93d29 100644 --- a/src/infiniop/ops/causal_softmax/kunlun/kernel.h +++ b/src/infiniop/ops/causal_softmax/kunlun/kernel.h @@ -54,7 +54,7 @@ __device__ void causalSoftmaxBlock( // Apply softmax for (size_t col = core_id(); col < width; col += BLOCK_SIZE) { if (sum_ != 0) { - y[col] = to(to(y[col]) / sum_); + y[col] = Tdata(Tcompute(y[col]) / sum_); } else { y[col] = Tdata(0); } diff --git a/src/infiniop/ops/rms_norm/kunlun/kernel.h b/src/infiniop/ops/rms_norm/kunlun/kernel.h index 4fb695e39..6372d0ce2 100644 --- a/src/infiniop/ops/rms_norm/kunlun/kernel.h +++ b/src/infiniop/ops/rms_norm/kunlun/kernel.h @@ -27,7 +27,7 @@ __device__ void rmsnormBlock( for (size_t i = core_id(); i < dim; i += BLOCK_SIZE) { Tdata xi = x[i]; Tweight wi = w[i]; - y[i] = static_cast(to(xi) * to(wi) * rms); + y[i] = Tdata(Tcompute(xi) * Tcompute(wi) * rms); } sync_cluster(); } diff --git a/src/infiniop/ops/rms_norm/kunlun/rms_norm_kunlun.xpu b/src/infiniop/ops/rms_norm/kunlun/rms_norm_kunlun.xpu index ee598e071..be331c0c8 100644 --- a/src/infiniop/ops/rms_norm/kunlun/rms_norm_kunlun.xpu +++ b/src/infiniop/ops/rms_norm/kunlun/rms_norm_kunlun.xpu @@ -95,10 +95,14 @@ infiniStatus_t launchKernel( if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F16) { LAUNCH_KERNEL(half, half, float); + } else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_BF16) { + LAUNCH_KERNEL(half, bfloat16_t, float); } else if (atype == INFINI_DTYPE_F16 && wtype == INFINI_DTYPE_F32) { LAUNCH_KERNEL(half, float, float); } else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_BF16) { LAUNCH_KERNEL(bfloat16_t, bfloat16_t, float); + } else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F16) { + LAUNCH_KERNEL(bfloat16_t, half, float); } else if (atype == INFINI_DTYPE_BF16 && wtype == INFINI_DTYPE_F32) { LAUNCH_KERNEL(bfloat16_t, float, float); } else if (atype == INFINI_DTYPE_F32 && wtype == INFINI_DTYPE_F32) { diff --git a/src/infiniop/reduce/kunlun/reduce_kunlun.h b/src/infiniop/reduce/kunlun/reduce_kunlun.h index 7244302dc..34f23829b 100644 --- a/src/infiniop/reduce/kunlun/reduce_kunlun.h +++ b/src/infiniop/reduce/kunlun/reduce_kunlun.h @@ -14,12 +14,12 @@ __device__ inline Tcompute sumSquared(__shared_ptr__ const Tdata *data_ptr, size for (size_t i = core_id(); i < count; i += BLOCK_SIZE) { Tdata xi = data_ptr[i]; - ss += to(xi) * to(xi); + ss += Tcompute(xi) * Tcompute(xi); } __shared__ Tcompute temp_storage; if (core_id() == 0) { - temp_storage = to(0.f); + temp_storage = Tcompute(0.f); } sync_cluster(); @@ -36,12 +36,12 @@ __device__ inline Tcompute sum(__shared_ptr__ const Tdata *data_ptr, size_t coun for (size_t i = core_id(); i < count; i += BLOCK_SIZE) { Tdata xi = data_ptr[i]; - ss += to(xi); + ss += Tcompute(xi); } __shared__ Tcompute temp_storage; if (core_id() == 0) { - temp_storage = to(0.f); + temp_storage = Tcompute(0.f); } sync_cluster(); @@ -58,7 +58,7 @@ __device__ inline Tdata max(__shared_ptr__ const Tdata *data_ptr, size_t count) for (size_t i = core_id(); i < count; i += BLOCK_SIZE) { Tdata xi = data_ptr[i]; - max_val = fmax(max_val, to(xi)); + max_val = fmax(max_val, Tdata(xi)); } __shared__ Tdata temp_storage;