From d03ca070f7b1a9c89ad6a25eeddcb515911446fb Mon Sep 17 00:00:00 2001
From: Jerry Mannil <65309407+jerrymannil@users.noreply.github.com>
Date: Thu, 25 Sep 2025 16:43:58 -0700
Subject: [PATCH] [ROCm] Implement float32 copy kernel

cherry-pick of https://github.com/pytorch/pytorch/pull/163869
---
 aten/src/ATen/native/cuda/Copy.cu | 25 ++++++++++++++++++++++++-
 1 file changed, 24 insertions(+), 1 deletion(-)

diff --git a/aten/src/ATen/native/cuda/Copy.cu b/aten/src/ATen/native/cuda/Copy.cu
index 402ed1fbd554..425cfe273785 100644
--- a/aten/src/ATen/native/cuda/Copy.cu
+++ b/aten/src/ATen/native/cuda/Copy.cu
@@ -42,6 +42,19 @@ void bfloat16_copy_kernel_cuda(TensorIteratorBase &iter) {
   });
 }
 
+#ifdef USE_ROCM
+void bfloat16tofloat32_copy_kernel_cuda(TensorIteratorBase &iter) {
+  gpu_kernel_nocast(iter, [] GPU_LAMBDA(at::BFloat16 value) {
+    return static_cast<float>(value);
+  });
+}
+void float16tofloat32_copy_kernel_cuda(TensorIteratorBase &iter) {
+  gpu_kernel_nocast(iter, [] GPU_LAMBDA(at::Half value) {
+    return static_cast<float>(value);
+  });
+}
+#endif
+
 void float8_copy_kernel_cuda(TensorIteratorBase &iter) {
   ScalarType dtype = iter.dtype(0);
   ScalarType other_dtype = iter.dtype(1);
@@ -187,7 +200,17 @@ void direct_copy_kernel_cuda(TensorIteratorBase &iter) {
     } else {
       float16_copy_kernel_cuda(iter);
     }
-  } else if (isBitsType(dtype)) {
+  }
+#ifdef USE_ROCM
+  else if ((iter.dtype(1) == kBFloat16 || iter.dtype(1) == kHalf) && dtype == kFloat) {
+    if (iter.dtype(1) == kBFloat16) {
+      bfloat16tofloat32_copy_kernel_cuda(iter);
+    } else {
+      float16tofloat32_copy_kernel_cuda(iter);
+    }
+  }
+#endif
+  else if (isBitsType(dtype)) {
     TORCH_CHECK(dtype == iter.dtype(1), "copy_() does not support casting "
                 "bits types to different bits types. Source dtype is ", iter.dtype(1), "target dtype is ", dtype);
     AT_DISPATCH_BIT_TYPES(dtype, "copy_", [&] {
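---

Note (not part of the patch): a minimal standalone sketch of the copy path this
change specializes, assuming a ROCm build of libtorch with this patch applied.
A Half -> Float device-to-device copy reaches direct_copy_kernel_cuda with
dtype(0) == kFloat and dtype(1) == kHalf, which the new branch routes to
float16tofloat32_copy_kernel_cuda instead of the generic casting path. The
file name sketch.cpp is hypothetical.

    // sketch.cpp -- compile and link against libtorch (hypothetical example)
    #include <ATen/ATen.h>

    int main() {
      // Source tensor in half precision on the GPU (kCUDA also covers HIP
      // devices on ROCm builds).
      at::Tensor src = at::randn({1 << 20}, at::device(at::kCUDA).dtype(at::kHalf));
      // Destination preallocated as float32 so copy_ performs a cross-dtype copy.
      at::Tensor dst = at::empty(src.sizes(), src.options().dtype(at::kFloat));
      // On a ROCm build with this patch, this dispatches to the new
      // float16tofloat32_copy_kernel_cuda; a BFloat16 source would take the
      // bfloat16tofloat32_copy_kernel_cuda branch instead.
      dst.copy_(src);
      return 0;
    }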