Add support for torch.gather #252

Merged (8 commits, Apr 24, 2024)
6 changes: 6 additions & 0 deletions thunder/clang/__init__.py
@@ -1047,6 +1047,12 @@ def take_along_axis(a: TensorProxy, /, indices: TensorProxy, dim: int) -> TensorProxy:
    return prims.take_along_axis(a, indices, dim)


@clangop()
def gather(a: TensorProxy, /, indices: TensorProxy, dim: int) -> TensorProxy:
    dim = utils.canonicalize_dim(a.ndim, dim)
    return prims.gather(a, indices, dim)


@clangop()
def scatter_add(a: TensorProxy, /, indices: TensorProxy, value: TensorProxy, dim: int) -> TensorProxy:
    dim = utils.canonicalize_dim(a.ndim, dim)
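The clang entry point canonicalizes dim before dispatching to the prim, so torch.gather's negative-dim convention (e.g. dim=-1 for the last axis) keeps working. A minimal sketch of what that canonicalization does, illustrative only rather than Thunder's actual utils.canonicalize_dim:

def _canonicalize_dim_sketch(ndim: int, dim: int) -> int:
    # Map a possibly negative dim into the range [0, ndim)
    if dim < 0:
        dim += ndim
    if not 0 <= dim < ndim:
        raise ValueError(f"dim {dim} is out of range for a rank-{ndim} tensor")
    return dim

assert _canonicalize_dim_sketch(3, -1) == 2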
24 changes: 24 additions & 0 deletions thunder/core/prims.py
Expand Up @@ -236,6 +236,7 @@ class PrimIDs(Enum):
    ARGMIN = auto()
    TOPK = auto()
    # Scatter and gather prims (Experimental!)
    GATHER = auto()
    INDEX_ADD = auto()
    INDEX_PUT = auto()
    SCATTER_ADD = auto()
@@ -3082,6 +3083,29 @@ def take_along_axis_meta(a: TensorProxy, /, index: TensorProxy, dim: int) -> TensorProxy:
take_along_axis = make_prim(PrimIDs.TAKE_ALONG_AXIS, "take_along_axis", meta=take_along_axis_meta)


def gather_meta(a: TensorProxy, /, index: TensorProxy, dim: int) -> TensorProxy:
    utils.check_type(a, TensorProxy)
    utils.check_type(index, TensorProxy)
    utils.check_type(dim, int)
    utils.check_same_device(a, index)
    utils.check(utils.is_integer_dtype(index.dtype), lambda: f"index dtype={index.dtype} was not an integer dtype")
    utils.check(
        index.ndim == a.ndim, lambda: f"Expected index (rank={index.ndim}) to have the same rank as a (rank={a.ndim})"
    )
    utils.validate_idx(a.ndim, dim)

    for idx, l in enumerate(index.shape):
        if idx != dim:
            utils.check(
                l <= a.shape[idx],
                lambda: f"Expected index.shape[d] <= a.shape[d] for all d != dim, but at dim {idx} 'index' has size {l} and 'a' has size {a.shape[idx]}",
            )
    return TensorProxy(like=a, shape=index.shape)


gather = make_prim(PrimIDs.GATHER, "gather", meta=gather_meta)


def scatter_add_meta(a: TensorProxy, /, index: TensorProxy, value: TensorProxy, dim: int) -> TensorProxy:
    utils.check_type(a, TensorProxy)
    utils.check_type(index, TensorProxy)
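For reference, the constraints gather_meta enforces mirror torch.gather's documented shape rules: index must have the same rank as a, and index.shape[d] <= a.shape[d] for every d != dim, while along dim the index may be any size. A quick eager-mode illustration of an accepted and a rejected case:

import torch

a = torch.arange(4 * 5 * 3, dtype=torch.float32).reshape(4, 5, 3)

# OK: same rank, and index.shape[d] <= a.shape[d] for d != 1;
# along dim=1 the index may be larger than a
index = torch.zeros(3, 7, 2, dtype=torch.long)
out = torch.gather(a, 1, index)
assert out.shape == index.shape  # the output always takes index's shape

# Rejected: index.shape[0] == 5 > a.shape[0] == 4 while dim == 1
bad = torch.zeros(5, 2, 2, dtype=torch.long)
try:
    torch.gather(a, 1, bad)
except RuntimeError as e:
    print("rejected:", e)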
19 changes: 18 additions & 1 deletion thunder/core/transforms.py
@@ -815,12 +815,29 @@ def _take_prim_grad(a: TensorProxy, index: TensorProxy, dim: int) -> TensorProxy:
register_grad(pids.TAKE, _take_prim_grad)


@torchctx
def _gather_prim_grad(a: TensorProxy, index: TensorProxy, dim: int) -> TensorProxy:
    fwd = prims.gather(a, index, dim)

    g = get_grad(fwd)
    # NOTE Intentionally not calling zeros_like to avoid preserving TensorProxy a.
    # TODO Update to call ltorch.zeros
    zeros = prims.full(a.shape, fill_value=0, device=a.device, dtype=a.dtype)
    a_grad = prims.scatter_add(zeros, index, g, dim)
    put_grad(a, a_grad)

    return fwd


register_grad(pids.GATHER, _gather_prim_grad)


@torchctx
def _take_along_axis_prim_grad(a: TensorProxy, index: TensorProxy, dim: int) -> TensorProxy:
    fwd = prims.take_along_axis(a, index, dim)

    g = get_grad(fwd)
-    # NOTE Intentionally not calling zeros_like to avoid preserving a
+    # NOTE Intentionally not calling zeros_like to avoid preserving TensorProxy a.
    # TODO Update to call ltorch.zeros
    zeros = prims.full(a.shape, fill_value=0, device=a.device, dtype=a.dtype)
    a_grad = prims.scatter_add(zeros, index, g, dim)
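The gradient registered here is the standard adjoint of gather: scatter-add the incoming gradient back to the positions that were read. A small eager-mode sanity check of that identity (a sketch, independent of Thunder's own tests):

import torch

a = torch.randn(4, 5, requires_grad=True)
index = torch.randint(0, 5, (4, 3))
g = torch.randn(4, 3)

torch.gather(a, 1, index).backward(g)

# gather's backward == scatter_add of the output grad into zeros
expected = torch.zeros_like(a).scatter_add(1, index, g)
assert torch.allclose(a.grad, expected)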
13 changes: 13 additions & 0 deletions thunder/executors/torchex.py
@@ -1099,6 +1099,7 @@ def _topk_transform(
# Scatter and gather operations
#

gather = _register_torch_operation("gather")
index_add = _register_torch_operation("index_add")
index_put = _register_torch_operation("index_put")
scatter_add = _register_torch_operation("scatter_add")
@@ -1117,6 +1118,16 @@ def _index_put_prim_transform(
    return index_put(a, indices, values, accumulate)


@langctx(Languages.TORCH)
def _gather_prim_transform(a: TensorProxy, /, index: TensorProxy, dim: int) -> TensorProxy:
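    # NOTE The gather prim takes (a, index, dim), but torch.gather takes (a, dim, index)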
    return gather(a, dim, index)


@langctx(Languages.TORCH)
def _gather_transform(a: TensorLike, /, dim: int, index: TensorLike) -> TensorLike:
    return gather(a, dim, index)


# NOTE torch.compile has a compilation issue with scatter add in bfloat16,
# hence the special case here.
# NOTE The scatter add transforms must set the torch language context explicitly so the .to() method
@@ -1152,6 +1163,7 @@ def _take_along_axis_prim_transform(a: TensorProxy, /, index: TensorProxy, dim: int) -> TensorProxy:
    return take_along_dim(a, index, dim)


_register_implementation(prims.gather, checker=_always_executable, execution_transform=_gather_prim_transform)
_register_implementation(prims.index_add, checker=_always_executable, execution_transform=_index_add_prim_transform)
_register_implementation(prims.index_put, checker=_always_executable, execution_transform=_index_put_prim_transform)
_register_implementation(prims.scatter_add, checker=_always_executable, execution_transform=_scatter_add_prim_transform)
@@ -1160,6 +1172,7 @@ def _take_along_axis_prim_transform(a: TensorProxy, /, index: TensorProxy, dim: int) -> TensorProxy:
    prims.take_along_axis, checker=_always_executable, execution_transform=_take_along_axis_prim_transform
)

_register_implementation(ltorch.gather, checker=_always_executable, execution_transform=_gather_transform)
_register_implementation(ltorch.index_add, index_add, checker=_always_executable)
_register_implementation(ltorch.index_put, index_put, checker=_always_executable)
_register_implementation(ltorch.index_select, index_select, checker=_always_executable)
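Both gather transforms reorder arguments because the prim follows take_along_axis's (a, index, dim) signature while torch.gather (and ltorch.gather) put dim before index. When index broadcasts against a, the two PyTorch ops agree, which gives a quick eager-mode check of the reordering:

import torch

a = torch.randn(4, 5)
index = torch.randint(0, 5, (4, 2))

out_gather = torch.gather(a, 1, index)       # dim before index
out_tad = torch.take_along_dim(a, index, 1)  # index before dim, as the prim orders them

assert torch.equal(out_gather, out_tad)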
36 changes: 36 additions & 0 deletions thunder/tests/opinfos.py
@@ -4345,6 +4345,42 @@ def take_along_axis_sample_generator(op, device, dtype, requires_grad, **kwargs):
shape_ops.append(take_along_axis_opinfo)


def gather_sample_generator(op, device, dtype, requires_grad, **kwargs):
    make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    # torch.gather requires index to have dtype long (int64), not int (int32)
    # index is not differentiable, so requires_grad is False
    make_index = partial(make_tensor, device=device, dtype=torch.long, requires_grad=False)

    for shape_a, dim, shape_b in take_along_axis_cases:
        canonicalized_dim = dim if dim >= 0 else dim + len(shape_a)
        a = make(shape_a)
        b = make_index(shape_b, low=0, high=shape_a[canonicalized_dim])
        yield SampleInput(a, index=b, dim=dim)

    # NOTE Unlike take_along_axis, gather has no broadcast requirement; it only
    # requires a.shape[i] >= index.shape[i] for i != dim
    #
    # a.shape, dim, index.shape
    gather_cases = (
        ((4, 5, 3), 0, (3, 2, 3)),
        ((4, 5, 3), 1, (3, 5, 2)),
        ((4, 5, 3), 2, (3, 2, 8)),
    )
    for shape_a, dim, shape_b in gather_cases:
        a = make(shape_a)
        b = make_index(shape_b, low=0, high=shape_a[dim])
        yield SampleInput(a, index=b, dim=dim)


gather_opinfo = OpInfo(
    ltorch.gather,
    supports_grad=True,
    sample_input_generator=gather_sample_generator,
    torch_reference=torch.gather,
)
shape_ops.append(gather_opinfo)


def scatter_add_sample_generator(op, device, dtype, requires_grad, **kwargs):
    make = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    # torch.scatter_add requires index to have dtype long (int64), not int (int32)
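The extra cases exercise exactly the flexibility the generator's comment describes: gather accepts index shapes that merely satisfy index.shape[i] <= a.shape[i] for i != dim, even when they would not broadcast for take_along_dim. For instance, an eager-mode sketch:

import torch

a = torch.randn(4, 5, 3)
index = torch.randint(0, 4, (3, 2, 3))  # entries index dim 0, so they stay < a.shape[0]

out = torch.gather(a, 0, index)  # fine: 2 <= 5 and 3 <= 3 on the non-dim axes
assert out.shape == (3, 2, 3)

try:
    torch.take_along_dim(a, index, 0)  # fails: (3, 2, 3) does not broadcast with (4, 5, 3)
except RuntimeError as e:
    print("take_along_dim rejected:", e)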
5 changes: 5 additions & 0 deletions thunder/torch/__init__.py
@@ -1900,6 +1900,11 @@ def index_select(a: TensorLike, /, dim: int, index: TensorLike) -> TensorLike:
    return clang.take(a, index, dim)


@torchsymbol(torch.gather)
def gather(a: TensorLike, /, dim: int, index: TensorLike) -> TensorLike:
    return clang.gather(a, indices=index, dim=dim)


# NOTE PyTorch's scatter_add has a parameter named 'src', not 'source'
@torchsymbol(torch.scatter_add)
def scatter_add(a: TensorLike, /, dim: int, index: TensorLike, src: TensorLike) -> TensorLike:
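With the torchsymbol in place, torch.gather inside a jitted function traces through Thunder end to end. A hedged usage sketch, assuming the standard thunder.jit entry point:

import torch
import thunder

def fn(a, index):
    return torch.gather(a, 1, index)

jfn = thunder.jit(fn)

a = torch.randn(4, 5)
index = torch.randint(0, 5, (4, 2))
assert torch.equal(jfn(a, index), fn(a, index))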