Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 24 additions & 1 deletion aten/src/ATen/native/cuda/Copy.cu
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,19 @@ void bfloat16_copy_kernel_cuda(TensorIteratorBase &iter) {
});
}

#ifdef USE_ROCM
void bfloat16tofloat32_copy_kernel_cuda(TensorIteratorBase &iter) {
gpu_kernel_nocast(iter, [] GPU_LAMBDA(at::BFloat16 value) {
return static_cast<float>(value);
});
}
void float16tofloat32_copy_kernel_cuda(TensorIteratorBase &iter) {
gpu_kernel_nocast(iter, [] GPU_LAMBDA(at::Half value) {
return static_cast<float>(value);
});
}
#endif

void float8_copy_kernel_cuda(TensorIteratorBase &iter) {
ScalarType dtype = iter.dtype(0);
ScalarType other_dtype = iter.dtype(1);
Expand Down Expand Up @@ -187,7 +200,17 @@ void direct_copy_kernel_cuda(TensorIteratorBase &iter) {
} else {
float16_copy_kernel_cuda(iter);
}
} else if (isBitsType(dtype)) {
}
#ifdef USE_ROCM
else if ((iter.dtype(1) == kBFloat16 || iter.dtype(1) == kHalf) && dtype == kFloat) {
if (iter.dtype(1) == kBFloat16) {
bfloat16tofloat32_copy_kernel_cuda(iter);
} else {
float16tofloat32_copy_kernel_cuda(iter);
}
}
#endif
else if (isBitsType(dtype)) {
TORCH_CHECK(dtype == iter.dtype(1), "copy_() does not support casting "
"bits types to different bits types. Source dtype is ", iter.dtype(1), "target dtype is ", dtype);
AT_DISPATCH_BIT_TYPES(dtype, "copy_", [&] {
Expand Down