Description
When take_along_axis or resize generates a broadcast ID and that broadcast ID is parallelized, the output tensor should be predicated with the parallel dimension. If the broadcast ID is later resolved, the value must actually be broadcast to the other threads. If the broadcast ID is immediately squeezed, no broadcast is needed.
Example:
auto tv0 = [I1];
auto tv1 = [B];
auto tv2 = [I2];
// [B]
auto tv3 = take_along_axis(tv0, tv1, 0);
// [I2]
auto tv4 = add(tv2, tv3);
for (auto tv : all_tvs) {
  tv->axis(0)->parallelize(ParallelType::TIDx);
}
Right now, only the thread with TIDx == 0 holds a valid value of tv3, so tv3 needs to be predicated by TIDx. However, the current thread predication doesn't detect this pattern.
Since tv4 is also parallelized by TIDx, tv3 needs to be available to all threads. We could issue a real parallel broadcast, but in this case we should just let all threads perform the take_along_axis op, since the input should be accessible from all threads.
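To make the two behaviors concrete, here is a minimal sketch in plain C++ that simulates the threads of a block with a loop. It is not nvFuser code: the function names, the per-thread "register" vectors, and the `kUnwritten` sentinel are all hypothetical and only illustrate why predicating tv3 without a broadcast breaks tv4, while a redundant gather on every thread does not.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical sentinel standing in for a register that a thread never wrote.
constexpr float kUnwritten = -1000.0f;

// Variant A: tv3 is predicated by TIDx == 0 and never broadcast.
// tv0 is [I1], tv1 is [B] (a single index), tv2 is [I2]; one thread per tv2 element.
std::vector<float> predicatedWithoutBroadcast(
    const std::vector<float>& tv0,
    const std::vector<std::size_t>& tv1,
    const std::vector<float>& tv2) {
  assert(tv1.size() == 1 && tv1[0] < tv0.size());
  const std::size_t num_threads = tv2.size();
  // One private tv3 value per simulated thread.
  std::vector<float> tv3(num_threads, kUnwritten);
  std::vector<float> tv4(num_threads);
  for (std::size_t tidx = 0; tidx < num_threads; ++tidx) {
    if (tidx == 0) {  // thread predicate on the parallelized broadcast ID
      tv3[tidx] = tv0[tv1[0]];
    }
  }
  for (std::size_t tidx = 0; tidx < num_threads; ++tidx) {
    // Threads other than 0 read a value they never produced.
    tv4[tidx] = tv2[tidx] + tv3[tidx];
  }
  return tv4;
}

// Variant B: every thread redundantly performs the take_along_axis,
// which is valid because tv0 and tv1 are readable from all threads.
std::vector<float> redundantGather(
    const std::vector<float>& tv0,
    const std::vector<std::size_t>& tv1,
    const std::vector<float>& tv2) {
  assert(tv1.size() == 1 && tv1[0] < tv0.size());
  const std::size_t num_threads = tv2.size();
  std::vector<float> tv4(num_threads);
  for (std::size_t tidx = 0; tidx < num_threads; ++tidx) {
    const float tv3 = tv0[tv1[0]];  // same gather on every thread
    tv4[tidx] = tv2[tidx] + tv3;
  }
  return tv4;
}
```

In this simulation, variant A yields a correct tv4 only at index 0, whereas variant B is correct everywhere and needs no synchronization or shared-memory traffic. A real parallel broadcast would also be correct, at the cost of that extra communication.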
Found while working on #250. See TakeAlongAxisIntermediateTensorReduction4.