[ROCm] revert cat operator performance work-around (#987)
Revert d5ca53c (pytorch#46097). The changes only affect ROCm. This reverts a work-around for a compiler performance issue that is no longer needed.

Benchmark results collected with `python -m pt.cat_test --tag_filter all --device cuda`:

```

OLD Forward Execution Time (us) : 48.833
NEW Forward Execution Time (us) : 8.318

OLD Forward Execution Time (us) : 54.508
NEW Forward Execution Time (us) : 23.824

OLD Forward Execution Time (us) : 52.117
NEW Forward Execution Time (us) : 14.942

OLD Forward Execution Time (us) : 98.790
NEW Forward Execution Time (us) : 74.334

OLD Forward Execution Time (us) : 102.063
NEW Forward Execution Time (us) : 76.008

OLD Forward Execution Time (us) : 167.786
NEW Forward Execution Time (us) : 123.679

OLD Forward Execution Time (us) : 98.320
NEW Forward Execution Time (us) : 67.436

OLD Forward Execution Time (us) : 91.484
NEW Forward Execution Time (us) : 59.230

OLD Forward Execution Time (us) : 109.569
NEW Forward Execution Time (us) : 76.557

OLD Forward Execution Time (us) : 106.603
NEW Forward Execution Time (us) : 87.635

OLD Forward Execution Time (us) : 106.693
NEW Forward Execution Time (us) : 88.902

OLD Forward Execution Time (us) : 110.881
NEW Forward Execution Time (us) : 94.361

OLD Forward Execution Time (us) : 122.925
NEW Forward Execution Time (us) : 123.046

OLD Forward Execution Time (us) : 272.442
NEW Forward Execution Time (us) : 271.932

OLD Forward Execution Time (us) : 457.329
NEW Forward Execution Time (us) : 456.767

OLD Forward Execution Time (us) : 117.688
NEW Forward Execution Time (us) : 87.133

OLD Forward Execution Time (us) : 873.764
NEW Forward Execution Time (us) : 865.075

OLD Forward Execution Time (us) : 1746.831
NEW Forward Execution Time (us) : 1730.252

OLD Forward Execution Time (us) : 2619.303
NEW Forward Execution Time (us) : 2598.717

OLD Forward Execution Time (us) : 52.063
NEW Forward Execution Time (us) : 7.904

OLD Forward Execution Time (us) : 52.275
NEW Forward Execution Time (us) : 8.118

OLD Forward Execution Time (us) : 51.896
NEW Forward Execution Time (us) : 7.938

OLD Forward Execution Time (us) : 51.745
NEW Forward Execution Time (us) : 7.922

OLD Forward Execution Time (us) : 52.575
NEW Forward Execution Time (us) : 13.299

OLD Forward Execution Time (us) : 52.090
NEW Forward Execution Time (us) : 8.015
```
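For a quick spot check outside the `pt.cat_test` harness, a minimal timing sketch like the one below can be used to compare before/after builds. It is a sketch under stated assumptions: the shapes, iteration count, and the `time_cat` helper are illustrative only and do not correspond to any of the configurations listed above.

```python
# Minimal standalone sketch (not the operator-benchmark harness used above).
# Shapes and iteration counts are illustrative assumptions only.
import torch

def time_cat(shapes, dim=0, iters=100):
    # Build inputs on the current CUDA (or ROCm/HIP) device.
    inputs = [torch.randn(*s, device="cuda") for s in shapes]
    # Warm up so one-time setup costs do not skew the measurement.
    for _ in range(10):
        torch.cat(inputs, dim=dim)
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        torch.cat(inputs, dim=dim)
    end.record()
    torch.cuda.synchronize()
    # elapsed_time() returns milliseconds; report microseconds per call.
    return start.elapsed_time(end) * 1000.0 / iters

print(f"Forward Execution Time (us) : {time_cat([(512, 512), (512, 512)]):.3f}")
```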
Pull Request resolved: pytorch#74129
Approved by: https://github.com/ngimel
jeffdaily authored and jithunnair-amd committed Sep 28, 2022
1 parent 3af29a2 commit da4170e
Showing 1 changed file with 0 additions and 178 deletions: aten/src/ATen/native/cuda/Shape.cu
@@ -14,11 +14,7 @@
namespace at {
namespace native {

#ifdef __HIP_PLATFORM_HCC__
constexpr int CAT_ARRAY_BATCH_SIZE = 1024;
#else
constexpr int CAT_ARRAY_BATCH_SIZE = 128;
#endif
constexpr int CAT_ARRAY_MAX_INPUT_DIMS = 4;

namespace {
@@ -85,45 +81,6 @@ struct TensorSizeStride {
*/


// Use pinned memory and pass the struct by pointer on ROCm
template <typename T, typename IndexType>
struct CatArrInputTensor {
T* input;
IndexType offset;
IndexType dimSize;
IndexType nElements;
};

template <typename T, typename IndexType, int Dims>
C10_LAUNCH_BOUNDS_1(512)
__global__ void HIP_CatArrayBatchedCopy(
T* output,
CatArrInputTensor<T, IndexType>* inputs,
TensorSizeStride<IndexType, CAT_ARRAY_MAX_INPUT_DIMS> os,
const int concatDim,
IndexType dimStride) {

IndexType tid = blockIdx.x * blockDim.x + threadIdx.x;
IndexType nElements = inputs[blockIdx.y].nElements;

if(tid >= nElements) return;

T* data = inputs[blockIdx.y].input;
IndexType offset = inputs[blockIdx.y].offset;
IndexType dimSize = inputs[blockIdx.y].dimSize;
IndexType dataOffset = offset * dimStride;

IndexType stride = gridDim.x * blockDim.x;

while( tid < nElements){
IndexType elementOffset = CatArrIndexToOffset<IndexType, Dims>::compute(
os.tensorSize, os.tensorStride, dimSize, concatDim, tid);
output[dataOffset + elementOffset] = data[tid];

tid += stride;
}
}

// Pass metadata directly through kernel arguments instead of pinned memory.
// In the contiguous case stride_size is not needed; it is set to 1 as a
// placeholder so the code compiles.
@@ -173,127 +130,6 @@ __global__ void CatArrayBatchedCopy(
}
}

template <typename scalar_t>
void hip_parallel_cat(Tensor &out, const TensorList &inputs, int64_t dimension,
int nDims, c10::MemoryFormat memory_format) {
// First, let's set up our kernel parameters. We start with a raw pointer to
// the storage for the output Tensor.
scalar_t *data = out.data_ptr<scalar_t>();

// Kernel Parameter
long tensorMetadataSize =
sizeof(CatArrInputTensor<scalar_t, unsigned int>) * CAT_ARRAY_BATCH_SIZE;
auto d_inputs_storage = at::empty(
{tensorMetadataSize}, out.options().dtype(at::kByte));
auto d_inputs = static_cast<CatArrInputTensor<scalar_t, unsigned int> *>(
d_inputs_storage.data_ptr());

TensorSizeStride<unsigned int, CAT_ARRAY_MAX_INPUT_DIMS> outputParam;

// Next, let's initialize the size, stride arrays for the output Tensor.
if (memory_format == c10::MemoryFormat::Contiguous) {
for (int i = 0; i < nDims; ++i) {
outputParam.tensorSize[i] = at::native::size(out, i);
outputParam.tensorStride[i] = out.stride(i);
}
} else if (memory_format == c10::MemoryFormat::ChannelsLast || memory_format == c10::MemoryFormat::ChannelsLast3d) {
// permute the semantics of dims from NCHW to NHWC so that the input
// tensor is now contiguous
outputParam.tensorSize[0] = at::native::size(out, 0);
outputParam.tensorStride[0] = out.stride(0);
for (int i = 1; i < nDims - 1; ++i) {
outputParam.tensorSize[i] = at::native::size(out, i + 1);
outputParam.tensorStride[i] = out.stride(i + 1);
}
outputParam.tensorSize[nDims - 1] = at::native::size(out, 1);
outputParam.tensorStride[nDims - 1] = out.stride(1);
} else {
TORCH_CHECK(false, "unsupported memory format");
}

at::cuda::CUDAStream stream = at::cuda::getCurrentCUDAStream();

// Now we loop
int batchCounter = 0;
int64_t offset = 0;
for (int i = 0; i < inputs.size() ; i += CAT_ARRAY_BATCH_SIZE) {
// Re-allocate stackInputs every iteration to avoid read-after-write hazard
{
auto stackInputs_storage = at::empty({tensorMetadataSize},
out.options().dtype(at::kByte).device(at::kCPU).pinned_memory(true));
auto stackInputs =
static_cast<CatArrInputTensor<scalar_t, unsigned int> *>(
stackInputs_storage.data_ptr());
for (batchCounter = 0;
batchCounter < CAT_ARRAY_BATCH_SIZE &&
(i+batchCounter) < inputs.size();
++batchCounter) {
int64_t dimSize = 0;
// There is a legacy case where a 1-D empty tensor can be concatenated with a
// high-dimensional tensor
if (inputs[i+batchCounter].numel() > 0) {
dimSize = at::native::size(inputs[i+batchCounter], dimension);
}

stackInputs[batchCounter].input =
inputs[i+batchCounter].data_ptr<scalar_t>();
stackInputs[batchCounter].offset = offset;
stackInputs[batchCounter].dimSize = dimSize;
stackInputs[batchCounter].nElements = inputs[i+batchCounter].numel();

// update offset
offset += dimSize;
}
at::native::copy_(d_inputs_storage, stackInputs_storage,
/* non_blocking= */ true);
}

// Next, let's consider how we set our kernel launch parameters.
// We borrow from THCApply, which the kernel's internal indexing
// is based on.
dim3 applyBlock = dim3(32*16);

//Get a grid where the x dim fills half the GPU and the y dim is the number of tensors.
//Concatenating two tensors then fills the entire grid, while preventing
//many threads from needlessly loading metadata when the inputs are small.
dim3 catGrid;
getCatGrid(batchCounter, catGrid);

if (memory_format != c10::MemoryFormat::Contiguous) {
switch (dimension) {
case 0:
break;
case 1:
dimension = nDims - dimension;
break;
default:
dimension--;
}
}
// Template Declarations for dim = 1, 2, 3, 4
#define HANDLE_CASE(DIMS) \
HIP_CatArrayBatchedCopy<scalar_t, unsigned int, DIMS><<<\
catGrid, applyBlock, 0, stream.stream()>>>(\
data, d_inputs, outputParam, dimension, outputParam.tensorStride[dimension]); \
C10_CUDA_KERNEL_LAUNCH_CHECK();
switch (nDims) {
case 1:
HANDLE_CASE(1);
break;
case 2:
HANDLE_CASE(2);
break;
case 3:
HANDLE_CASE(3);
break;
case 4:
HANDLE_CASE(4);
break;
}
#undef HANDLE_CASE
}
}

template <typename scalar_t, int batch_size, int stride_size>
void parallel_cat(Tensor &out, const TensorList &inputs, int64_t dimension,
int nDims, c10::MemoryFormat memory_format) {
@@ -546,19 +382,6 @@ Tensor& cat_out_cuda(TensorList inputs, int64_t dimension, Tensor& out) {
});
allSameType = allSameType && (out.scalar_type() == firstType);

#ifdef __HIP_PLATFORM_HCC__
if (inputs.size() > 1 &&
out.dim() <= CAT_ARRAY_MAX_INPUT_DIMS &&
at::cuda::detail::canUse32BitIndexMath(out) &&
allContiguous &&
all32BitIndexable &&
allSameType) {
AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND3(
at::ScalarType::Half, at::ScalarType::Bool, at::ScalarType::BFloat16,
out.scalar_type(), "cat_cuda", [&]() {
hip_parallel_cat<scalar_t>(out, inputs, dimension, nDims, memory_format);
});
#else
// We support the contiguous inputs and non-contiguous input (<=4 dims) in different ways
// For contiguous input, we don't need to pass stride meta data to cuda kernel through constant
// memory. Therefore, we could pass more inputs to cuda threads.
@@ -587,7 +410,6 @@ Tensor& cat_out_cuda(TensorList inputs, int64_t dimension, Tensor& out) {
out.scalar_type(), "cat_cuda", [&]() {
parallel_cat<scalar_t, CAT_ARRAY_BATCH_SIZE/2, CAT_ARRAY_BATCH_SIZE/2>(out, inputs, dimension, nDims, memory_format);
});
#endif
} else {
int64_t offset = 0;
for (int j = 0; j < inputs.size(); j++)
