gather-to-size-one domain should have a thread predicate #282

@naoyam

Description

When take_along_axis or resize generates a broadcast ID and that broadcast ID is parallelized, the output tensor should be predicated with the parallel dimension. Alternatively, if the broadcast ID is later resolved, the value should actually be broadcast to the other threads. If the broadcast ID is immediately squeezed, no broadcast is needed.

Example:

auto tv0 = [I1];
auto tv1 = [B];
auto tv2 = [I2];

// [B]
auto tv3 = take_along_axis(tv0, tv1, 0); 

// [I2]
auto tv4 = add(tv2, tv3);

for (auto tv: all_tvs) {
  tv->axis(0)->parallelize(ParallelType::TIDx);
}

Right now, only the thread with TIDx == 0 holds a valid value of tv3, so tv3 needs to be predicated by TIDx. However, the current thread predication does not detect this pattern.

Since tv4 is also parallelized by TIDx, tv3 needs to be available to all threads. We could issue a real parallel broadcast, but in this case we should just let every thread perform the take_along_axis op itself, since the input should be accessible from all threads.
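
To make the difference concrete, here is a minimal host-side sketch in plain C++. It is not nvFuser code and not the generated kernel: each loop iteration stands in for one thread with a given TIDx, and the function names, the `kUnset` sentinel, and the scalar `index` (standing in for the single element of tv1) are all assumptions made for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical model of one thread block: iteration `tid` plays the role of
// the thread with threadIdx.x == tid. tv0 is the gather source [I1], `index`
// is the single element of tv1 [B], and tv2 is the other add operand [I2].

// Sentinel standing in for a per-thread value that was never written.
constexpr float kUnset = -1.0f;

// Variant A: tv3 is produced only under the predicate tid == 0 and is never
// broadcast, so every other thread adds a stale value.
std::vector<float> addWithPredicatedGather(const std::vector<float>& tv0,
                                           std::size_t index,
                                           const std::vector<float>& tv2) {
  assert(index < tv0.size());
  std::vector<float> tv4(tv2.size());
  for (std::size_t tid = 0; tid < tv2.size(); ++tid) {
    float tv3 = kUnset;
    if (tid == 0) {
      tv3 = tv0[index];  // take_along_axis, executed by thread 0 only
    }
    tv4[tid] = tv2[tid] + tv3;
  }
  return tv4;
}

// Variant B: every thread redundantly performs the gather, which is safe
// when tv0 is readable from all threads. No parallel broadcast is needed.
std::vector<float> addWithRedundantGather(const std::vector<float>& tv0,
                                          std::size_t index,
                                          const std::vector<float>& tv2) {
  assert(index < tv0.size());
  std::vector<float> tv4(tv2.size());
  for (std::size_t tid = 0; tid < tv2.size(); ++tid) {
    const float tv3 = tv0[index];  // take_along_axis, executed by all threads
    tv4[tid] = tv2[tid] + tv3;
  }
  return tv4;
}
```

In this model only Variant B gives the intended result at every position. Variant A is correct only at position 0, so it would need either a real broadcast of tv3 or the redundant gather.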

Found while working on #250. See TakeAlongAxisIntermediateTensorReduction4.
