RNG output depends on segment execution order #5105
I found this while debugging #4973, which sometimes changes the order in which segments execute. When the segments include RNG methods (rand, randn, etc.), the values we compute depend on the segment execution order. This might not be a problem in practice, but we do seem to be testing that this is the expected behavior: bin/test_rng --gtest_filter=RNGTest.Normal, for example, runs four separate RNG-based kernels, and reordering the segments causes the test to fail.
Lines 357 to 367 in c490444
```cpp
fusion->addInput(size_val);
fusion->addInput(mean);
fusion->addInput(std);
TensorView* tv0 = normal({size_val}, mean, std, DataType::Float);
TensorView* tv1 = normal({size_val}, mean, std, DataType::Double);
TensorView* tv2 = randn({size_val}, DataType::Double);
TensorView* tv3 = randn_like(tv2);
fusion->addOutput(tv0);
fusion->addOutput(tv1);
fusion->addOutput(tv2);
fusion->addOutput(tv3);
```
To repro, try making this change then running bin/test_rng --gtest_filter=RNGTest.Normal:
```diff
diff --git a/csrc/fusion_segmenter.cpp b/csrc/fusion_segmenter.cpp
index 22bf34d30..f08ff9395 100644
--- a/csrc/fusion_segmenter.cpp
+++ b/csrc/fusion_segmenter.cpp
@@ -5369,6 +5369,12 @@ RuntimeWorkSpace prepareRuntimeOrder(const SegmentedFusion& segmented_fusion) {
   }
   runtime_workspace.group_run_order = toposort(segmented_fusion.groups());
+  runtime_workspace.group_run_order = {
+      runtime_workspace.group_run_order[3],
+      runtime_workspace.group_run_order[2],
+      runtime_workspace.group_run_order[1],
+      runtime_workspace.group_run_order[0],
+  };
   return runtime_workspace;
 }
```

Possible fixes:
- We could try to enforce the execution order of RNG factory functions when sorting segments. It might still be possible to reorder RNG methods within a segment.
- We could introduce a preseg pass that sets the seed and offset of each RNG method based on definition order, using expr->name() to determine that order. This is not ideal because it means we would be baking in the random seed.
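
To make the difference between the current behavior and the second option concrete, here is a minimal standalone sketch. It is not nvFuser code: `toyRandom`, `runWithRuntimeOffsets`, and `runWithDefinitionOffsets` are hypothetical names, the hash is a splitmix64-style stand-in rather than the real generator, and it assumes only that each RNG op's output is a pure function of (seed, offset).

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Toy counter-based generator: the result is a pure function of
// (seed, offset). This is a splitmix64-style mixer used as a stand-in;
// it is NOT the generator nvFuser or PyTorch actually uses.
inline std::uint64_t toyRandom(std::uint64_t seed, std::uint64_t offset) {
  std::uint64_t z = seed + 0x9E3779B97F4A7C15ULL * (offset + 1);
  z = (z ^ (z >> 30)) * 0xBF58476D1CE4E5B9ULL;
  z = (z ^ (z >> 27)) * 0x94D049BB133111EBULL;
  return z ^ (z >> 31);
}

// Strategy A (order-dependent): each segment grabs the next offset at the
// moment it runs, so the value it produces depends on its position in
// run_order. run_order must be a permutation of 0..n-1.
std::vector<std::uint64_t> runWithRuntimeOffsets(
    std::uint64_t seed, const std::vector<int>& run_order) {
  std::vector<std::uint64_t> out(run_order.size());
  std::uint64_t next_offset = 0;
  for (int segment : run_order) {
    assert(segment >= 0 && segment < static_cast<int>(out.size()));
    out[segment] = toyRandom(seed, next_offset++);
  }
  return out;
}

// Strategy B (order-independent): each segment's offset is fixed ahead of
// time by its definition index, so execution order no longer matters.
std::vector<std::uint64_t> runWithDefinitionOffsets(
    std::uint64_t seed, const std::vector<int>& run_order) {
  std::vector<std::uint64_t> out(run_order.size());
  for (int segment : run_order) {
    assert(segment >= 0 && segment < static_cast<int>(out.size()));
    out[segment] = toyRandom(seed, static_cast<std::uint64_t>(segment));
  }
  return out;
}
```

Calling both functions with run orders `{0, 1, 2, 3}` and `{3, 2, 1, 0}` shows the contrast: strategy A returns different per-segment values for the two orders, while strategy B returns identical values. In this sketch the seed remains a runtime argument and only the offsets are fixed by definition order; whether a real preseg pass could keep the seed dynamic in the same way is left open here.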
In any case, I think we should clearly document the circumstances in which we expect PyTorch and nvFuser output to match.