RNG output depends on segment execution order #5105

@jacobhinkle

Description

I found this when debugging #4973. That PR sometimes changes the segment execution order. When the segments include RNG methods (rand, randn, etc.), the values we compute depend on the order in which the segments run. This might not be a problem in practice, but we do seem to be testing that deterministic output is the expected behavior: bin/test_rng --gtest_filter=RNGTest.Normal runs 4 separate RNG-based kernels, and reordering the segments causes the test to fail.

fusion->addInput(size_val);
fusion->addInput(mean);
fusion->addInput(std);
TensorView* tv0 = normal({size_val}, mean, std, DataType::Float);
TensorView* tv1 = normal({size_val}, mean, std, DataType::Double);
TensorView* tv2 = randn({size_val}, DataType::Double);
TensorView* tv3 = randn_like(tv2);
fusion->addOutput(tv0);
fusion->addOutput(tv1);
fusion->addOutput(tv2);
fusion->addOutput(tv3);

To reproduce, apply this change and then run bin/test_rng --gtest_filter=RNGTest.Normal:

diff --git a/csrc/fusion_segmenter.cpp b/csrc/fusion_segmenter.cpp
index 22bf34d30..f08ff9395 100644
--- a/csrc/fusion_segmenter.cpp
+++ b/csrc/fusion_segmenter.cpp
@@ -5369,6 +5369,12 @@ RuntimeWorkSpace prepareRuntimeOrder(const SegmentedFusion& segmented_fusion) {
   }
 
   runtime_workspace.group_run_order = toposort(segmented_fusion.groups());
+  runtime_workspace.group_run_order = {
+    runtime_workspace.group_run_order[3],
+    runtime_workspace.group_run_order[2],
+    runtime_workspace.group_run_order[1],
+    runtime_workspace.group_run_order[0],
+  };
 
   return runtime_workspace;
 }

Possible fixes:

  • We could try to enforce the execution order of RNG factory functions when sorting segments. Even then, RNG methods might still be reordered within a segment.
  • We could introduce a preseg pass that sets the seed and offset of each RNG method based on definition order, using expr->name() to determine that order. This is not ideal because it means we would be baking in the random seed.

In any case, I think we should clearly document the circumstances under which we expect PyTorch and nvFuser RNG output to match.
