diff --git a/c10/cuda/CUDACachingAllocator.cpp b/c10/cuda/CUDACachingAllocator.cpp
index 4a1e4654f9207..052a5201d53e8 100644
--- a/c10/cuda/CUDACachingAllocator.cpp
+++ b/c10/cuda/CUDACachingAllocator.cpp
@@ -3261,6 +3261,15 @@ class DeviceCachingAllocator {
   }
 };
 
+static bool zeroAllocations() {
+  static auto has_cuda_env =
+      c10::utils::check_env("PYTORCH_CUDA_MEMORY_CACHING_MEMSET_ZEROS") == true;
+  static auto has_rocm_env =
+      c10::utils::check_env("PYTORCH_HIP_MEMORY_CACHING_MEMSET_ZEROS") == true;
+  static bool zeros = has_cuda_env || has_rocm_env;
+  return zeros;
+}
+
 // Returns whether to force all allocations to bypass the caching allocator and
 // go straight to cudaMalloc. This setting is useful when debugging GPU memory
 // errors, since the caching allocator foils cuda-memcheck.
@@ -3652,6 +3661,10 @@ class NativeCachingAllocator : public CUDAAllocator {
       TORCH_SDT_WITH_SEMAPHORE(malloc, devPtr, device, size, stream.id());
     }
 
+    if (zeroAllocations()) {
+      C10_CUDA_CHECK(cudaMemsetAsync(devPtr, 0, size, stream));
+    }
+
     return {devPtr, devPtr, deleteFunc, Device(DeviceType::CUDA, device)};
   }
   DeleterFnPtr raw_deleter() const override {
@@ -3734,6 +3747,12 @@ class NativeCachingAllocator : public CUDAAllocator {
       C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
       malloc(&r, device, nbytes, cuda::getCurrentCUDAStream(device));
     }
+    if (zeroAllocations()) {
+      c10::DeviceIndex device = 0;
+      C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
+      C10_CUDA_CHECK(
+          cudaMemsetAsync(r, 0, nbytes, cuda::getCurrentCUDAStream(device)));
+    }
     return r;
   }
 
@@ -3749,6 +3768,9 @@ class NativeCachingAllocator : public CUDAAllocator {
       C10_CUDA_CHECK(c10::cuda::GetDevice(&device));
      malloc(&r, device, nbytes, stream);
     }
+    if (zeroAllocations()) {
+      C10_CUDA_CHECK(cudaMemsetAsync(r, 0, nbytes, stream));
+    }
     return r;
   }
 
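For reference, below is a minimal standalone sketch (not part of this PR) of the behavior the new gate enables: when the environment flag is set, a freshly returned allocation is asynchronously zero-filled on the stream it was requested on before the pointer is handed back. It uses only the CUDA runtime API; the helper names mirror the shape of the diff's `raw_alloc_with_stream` path but are illustrative, and the env-var check is simplified to "variable exists" rather than `c10::utils::check_env(...) == true`.

```cpp
// Sketch only: demonstrates env-gated zeroing of new device allocations.
// Build with: nvcc -o zero_demo zero_demo.cu
#include <cuda_runtime.h>
#include <cstdio>
#include <cstdlib>
#include <vector>

#define CUDA_CHECK(expr)                                                   \
  do {                                                                     \
    cudaError_t err__ = (expr);                                            \
    if (err__ != cudaSuccess) {                                            \
      std::fprintf(stderr, "CUDA error %s at %s:%d\n",                     \
                   cudaGetErrorString(err__), __FILE__, __LINE__);          \
      std::exit(1);                                                        \
    }                                                                      \
  } while (0)

// Mirrors the intent of zeroAllocations(): read the env var once, cache it.
// (Simplified: treats the flag as set if the variable exists at all.)
static bool zeroAllocations() {
  static bool zeros =
      std::getenv("PYTORCH_CUDA_MEMORY_CACHING_MEMSET_ZEROS") != nullptr;
  return zeros;
}

// Allocate `nbytes` and, if the flag is set, scrub the memory on `stream`
// before returning the pointer (analogous to the raw_alloc_with_stream path).
void* raw_alloc_with_stream(size_t nbytes, cudaStream_t stream) {
  void* ptr = nullptr;
  CUDA_CHECK(cudaMalloc(&ptr, nbytes));
  if (zeroAllocations()) {
    CUDA_CHECK(cudaMemsetAsync(ptr, 0, nbytes, stream));
  }
  return ptr;
}

int main() {
  constexpr size_t nbytes = 1 << 20; // 1 MiB
  cudaStream_t stream;
  CUDA_CHECK(cudaStreamCreate(&stream));

  void* ptr = raw_alloc_with_stream(nbytes, stream);

  // Wait for the (optional) memset to finish, then copy back and inspect.
  CUDA_CHECK(cudaStreamSynchronize(stream));
  std::vector<unsigned char> host(nbytes);
  CUDA_CHECK(cudaMemcpy(host.data(), ptr, nbytes, cudaMemcpyDeviceToHost));

  size_t nonzero = 0;
  for (unsigned char b : host) nonzero += (b != 0);
  std::printf("non-zero bytes: %zu (flag %s)\n", nonzero,
              zeroAllocations() ? "set" : "unset");

  CUDA_CHECK(cudaFree(ptr));
  CUDA_CHECK(cudaStreamDestroy(stream));
  return 0;
}
```

Because the scrub is issued with cudaMemsetAsync on the same stream the allocation is associated with, it is ordered before any work subsequently enqueued on that stream, so same-stream consumers observe zeroed memory without a device-wide synchronization.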