enable cat for cuda bits types (pytorch#115044)
It was already working on CPU, so this brings parity.
It also slightly reduces the number of compiled kernels by using OpaqueType.

Pull Request resolved: pytorch#115044
Approved by: https://github.com/malfet
ngimel authored and ZhiweiYan-96 committed Dec 22, 2023
1 parent f8872af commit e95e266
Showing 3 changed files with 69 additions and 17 deletions.
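
Aside: the OpaqueType trick mentioned in the commit message works because cat never interprets element values, it only moves bytes. Templating the kernels on a self-aligned stand-in of the element's size, instead of on the real dtype, lets every dtype of a given width share a single instantiation. Below is a minimal, standalone C++ sketch of the idea; names such as copy_elements are hypothetical and this is not the PyTorch kernel code.

// Minimal sketch of the OpaqueType idea (illustrative only).
#include <cstring>
#include <type_traits>

// Self-aligned byte blob: it carries only size and alignment, no value semantics.
template <unsigned N> struct alignas(N) OpaqueType { char data[N]; };

// Stand-in for a byte-moving kernel: it never inspects values, so it only
// depends on sizeof(T), not on the real dtype.
template <typename T>
void copy_elements(T* dst, const T* src, unsigned n) {
  std::memcpy(dst, src, n * sizeof(T));
}

int main() {
  // Every 2-byte dtype (int16, half, bfloat16, bits16, ...) funnels into the
  // same instantiation once it is mapped to OpaqueType<2>.
  static_assert(std::is_same<OpaqueType<sizeof(short)>, OpaqueType<2>>::value,
                "one instantiation per element size");

  short src[4] = {1, 2, 3, 4};
  short dst[4] = {0, 0, 0, 0};
  using dtype = OpaqueType<sizeof(short)>;  // mirrors `using dtype = OpaqueType<sizeof(scalar_t)>;`
  copy_elements(reinterpret_cast<dtype*>(dst),
                reinterpret_cast<const dtype*>(src), 4);
  return dst[3] == 4 ? 0 : 1;
}

This is also why the first hunk below switches to raw-pointer casts: inside parallel_cat the template parameter scalar_t becomes OpaqueType<N> rather than the tensor's real dtype, so the typed accessors like mutable_data_ptr<scalar_t>() would no longer match the tensor's dtype check.
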
47 changes: 35 additions & 12 deletions aten/src/ATen/native/cuda/Shape.cu
@@ -239,7 +239,7 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i
int nDims, c10::MemoryFormat memory_format) {
// First, let's set up our kernel parameters. We start with a raw pointer to
// the storage for the output Tensor.
-  scalar_t *data = out.mutable_data_ptr<scalar_t>();
+  scalar_t *data = (scalar_t *)(out.mutable_data_ptr());
CatArrInputTensorMetadata<scalar_t, unsigned int, batch_size, stride_size> catMetaData;
TensorSizeStride<unsigned int, CAT_ARRAY_MAX_INPUT_DIMS> outputParam;

@@ -289,7 +289,7 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i
dimSize = inputs[i+batchCounter].get().size(dimension);
}

-      catMetaData.input[batchCounter] = inputs[i+batchCounter].get().const_data_ptr<scalar_t>();
+      catMetaData.input[batchCounter] = (scalar_t*)(inputs[i+batchCounter].get().const_data_ptr());
catMetaData.offset[batchCounter] = offset;
catMetaData.dimSize[batchCounter] = dimSize;
catMetaData.nElements[batchCounter] = inputs[i+batchCounter].get().numel();
@@ -375,6 +375,10 @@ void parallel_cat(const Tensor &out, const MaterializedITensorListRef& inputs, i
#undef HANDLE_CASE
}
}
+// The kernels are templated on an opaque, self-aligned type of the correct
+// size to avoid redundant kernels for different types of the same size.
+template <unsigned N> struct alignas(N) OpaqueType { char data[N]; };

} // namespace

TORCH_IMPL_FUNC(cat_out_cuda)
@@ -412,29 +416,48 @@ TORCH_IMPL_FUNC(cat_out_cuda)
// memory. Therefore, we could pass more inputs to cuda threads.
// For non-contiguous, we reduce the number of inputs passed to cuda kernel due to the limitation
// of constant memory.



if (materialized.size() > 1 &&
result.dim() <= CAT_ARRAY_MAX_INPUT_DIMS &&
at::cuda::detail::canUse32BitIndexMath(result) &&
all_contiguous &&
all32BitIndexable &&
all_same_dtype) {
-    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
-        kComplexHalf, kHalf, kBool, kBFloat16,
-        result.scalar_type(), "cat_cuda", [&]() {
-      parallel_cat<scalar_t, CAT_ARRAY_BATCH_SIZE, 1>(result, materialized, dim, nDims, memory_format);
-    });
+    if (isBitsType(result.scalar_type())) {
+      AT_DISPATCH_BIT_TYPES(result.scalar_type(), "cat_cuda", [&]() {
+        using dtype = OpaqueType<sizeof(scalar_t)>;
+        parallel_cat<dtype, CAT_ARRAY_BATCH_SIZE, 1>(result, materialized, dim, nDims, memory_format);
+      });
+    } else {
+      AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
+          kComplexHalf, kHalf, kBool, kBFloat16,
+          result.scalar_type(), "cat_cuda", [&]() {
+        using dtype = OpaqueType<sizeof(scalar_t)>;
+        parallel_cat<dtype, CAT_ARRAY_BATCH_SIZE, 1>(result, materialized, dim, nDims, memory_format);
+      });
+    }
} else if (materialized.size() > 1 &&
result.dim() <= CAT_ARRAY_MAX_INPUT_DIMS &&
at::cuda::detail::canUse32BitIndexMath(result) &&
nDims <= CAT_ARRAY_MAX_INPUT_DIMS &&
all32BitIndexable &&
all_same_dtype &&
memory_format == c10::MemoryFormat::Contiguous) {
-    AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
-        kComplexHalf, kHalf, kBool, kBFloat16,
-        result.scalar_type(), "cat_cuda", [&]() {
-      parallel_cat<scalar_t, CAT_ARRAY_BATCH_SIZE/2, CAT_ARRAY_BATCH_SIZE/2>(result, materialized, dim, nDims, memory_format);
-    });
+    if (isBitsType(result.scalar_type())) {
+      AT_DISPATCH_BIT_TYPES(result.scalar_type(), "cat_cuda", [&]() {
+        using dtype = OpaqueType<sizeof(scalar_t)>;
+        parallel_cat<dtype, CAT_ARRAY_BATCH_SIZE/2, CAT_ARRAY_BATCH_SIZE/2>(result, materialized, dim, nDims, memory_format);
+      });
+    } else {
+      AT_DISPATCH_ALL_TYPES_AND_COMPLEX_AND4(
+          kComplexHalf, kHalf, kBool, kBFloat16,
+          result.scalar_type(), "cat_cuda", [&]() {
+        using dtype = OpaqueType<sizeof(scalar_t)>;
+        parallel_cat<dtype, CAT_ARRAY_BATCH_SIZE/2, CAT_ARRAY_BATCH_SIZE/2>(result, materialized, dim, nDims, memory_format);
+      });
+    }
} else {
int64_t offset = 0;
for (const Tensor& t : materialized) {
30 changes: 26 additions & 4 deletions test/quantization/core/experimental/test_bits.py
@@ -1,10 +1,14 @@
# Owner(s): ["oncall: quantization"]

import torch
+from torch.testing._internal.common_device_type import instantiate_device_type_tests

from torch.testing._internal.common_utils import run_tests, TestCase
from torch.utils._mode_utils import no_dispatch
from torch.utils._pytree import tree_map

+import itertools

class Int16Tensor(torch.Tensor):
def __new__(cls, elem):
assert elem.dtype == torch.bits16
@@ -41,24 +45,42 @@ def __repr__(self) -> str:


class TestBits(TestCase):
-    def test_types(self):
+    def test_types(self, device):
bits_types = [torch.bits1x8, torch.bits2x4, torch.bits4x2, torch.bits8, torch.bits16]
for bits_type in bits_types:
-            _ = torch.zeros(20, dtype=torch.int32).view(bits_type)
-            _ = torch.empty(20, dtype=bits_type)
-            x = torch.randint(100, (20, 20), dtype=torch.int8).view(bits_type)
+            _ = torch.zeros(20, dtype=torch.int32, device=device).view(bits_type)
+            _ = torch.empty(20, dtype=bits_type, device=device)
+            x = torch.randint(100, (20, 20), dtype=torch.int8, device=device).view(bits_type)
y = x.t().contiguous()
view_type = torch.int8 if x.element_size() == 1 else torch.int16
self.assertEqual(x.t().view(view_type), y.view(view_type))
y = x.t().clone()
self.assertEqual(x.t().view(view_type), y.view(view_type))

+    def test_cat(self, device):
+        bits_types = [torch.bits1x8, torch.bits2x4, torch.bits4x2, torch.bits8, torch.bits16]
+        for bits_type in bits_types:
+            view_type = torch.int8 if bits_type.itemsize == 1 else torch.int16
+            x_int = torch.randint(100, (512, 512), dtype=view_type, device=device)
+            x = x_int.view(bits_type)
+            y_int = torch.randint(100, (512, 512), dtype=view_type, device=device)
+            y = y_int.view(bits_type)
+            for dim, transpose in itertools.product(range(x_int.ndim), (True, False)):
+                y_ref = y_int.t() if transpose else y_int
+                y_b = y.t() if transpose else y
+                z_ref = torch.cat([x_int, y_ref], dim=dim)
+                z = torch.cat([x, y_b], dim=dim)
+                self.assertEqual(z_ref, z.view(view_type))


def test_subclass(self):
t = torch.zeros(20, dtype=torch.int16).view(torch.bits16)
s = Int16Tensor(t)
s = s + 1 - 1
self.assertTrue(torch.allclose(s, torch.zeros(20, dtype=torch.bits16)))

+instantiate_device_type_tests(TestBits, globals())


if __name__ == '__main__':
run_tests()
9 changes: 8 additions & 1 deletion test/test_quantization.py
@@ -153,7 +153,14 @@
logging.warning(e)

# Experimental functionality
-from quantization.core.experimental.test_bits import TestBits  # noqa: F401
+try:
+    from quantization.core.experimental.test_bits import TestBitsCPU  # noqa: F401
+except ImportError as e:
+    logging.warning(e)
+try:
+    from quantization.core.experimental.test_bits import TestBitsCUDA  # noqa: F401
+except ImportError as e:
+    logging.warning(e)
try:
from quantization.core.experimental.test_float8 import TestFloat8DtypeCPU # noqa: F401
except ImportError as e:
