
Ignore Resize ops when validating all ID uses are exactly mapped. #64

Merged
1 commit, merged Mar 24, 2023

Conversation

naoyam (Collaborator) commented Mar 23, 2023

Resize ops are not replayed, so they don't need to be exactly mapped

Previously, FusionSliceForNanoGPT3_CUDA was segmented because the resize ops are not exactly mapped, since they have different expansion arguments. Since those resize ops are part of rfactor transformations, they were detected as conflicting rfactor transformations. However, unlike the split and merge ops used by reshape, resize ops are not replayed, so they don't need to be uniform.

This is also part of the fix for #58. It looks like the Python example is no longer segmented, although I suspect there's still something that needs to be done for permute.
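To illustrate the idea, here is a minimal sketch of the skip logic. All names here (ExprKind, RFactorExpr, validateIdUses) are hypothetical stand-ins for this explanation, not the actual nvFuser validation code: when checking that every use of an ID in the rfactor transformations is exactly mapped, uses that are resize ops can simply be skipped, because they are never replayed.

  // Hypothetical sketch: ignore Resize ops when validating exact mapping.
  #include <vector>

  enum class ExprKind { Split, Merge, Resize };

  struct RFactorExpr {
    ExprKind kind;
    bool exactly_mapped;  // result of the exact-map check for this use
  };

  // Returns true if all uses that actually require exact mapping are mapped.
  bool validateIdUses(const std::vector<RFactorExpr>& uses) {
    for (const auto& use : uses) {
      // Resize ops are not replayed, so differing expansion arguments
      // (e.g. two slices with different extents) are allowed.
      if (use.kind == ExprKind::Resize) {
        continue;
      }
      // Split/merge from reshape must still be uniform across all uses.
      if (!use.exactly_mapped) {
        return false;
      }
    }
    return true;
  }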

Commit: Resize ops are not replayed, so they don't need to be exactly mapped
naoyam requested a review from zasdfgbnm on Mar 23, 2023
zasdfgbnm (Collaborator)

Hmmm, if I have multiple slices with different result extents, should I reject the fusion?

  auto tv1 = slice(
      tv0,
      {{IrBuilder::create<Int>(0), IrBuilder::create<Int>(16)},
       {IrBuilder::create<Int>(0), IrBuilder::create<Int>(128)},
       {IrBuilder::create<Int>(0), IrBuilder::create<Int>(1024)}});
  auto tv2 = slice(
      tv0,
      {{IrBuilder::create<Int>(0), IrBuilder::create<Int>(16)},
       {IrBuilder::create<Int>(0), IrBuilder::create<Int>(128)},
       {IrBuilder::create<Int>(1024), IrBuilder::create<Int>(2049)}}); // Note: not 2048

naoyam (Collaborator, Author) commented Mar 23, 2023

For correctness, it doesn't need to be rejected.

%kernel {
T4_l[ iblockIdx.x37{( ceilDiv(( ceilDiv(( 128 * 1024 ), blockDim.x) ), 4) )}, iblockIdx.y39{( ceilDiv(16, 1) )}, iUS40{1}, iS38{4}, ithreadIdx.x36{blockDim.x} ] ca_pos( 5 )
   = slice( T0_g[ iS58{( ceilDiv(( ceilDiv(( i2 * i3 ), blockDim.x) ), 4) )}, iS60{( ceilDiv(i0, 1) )}, iS61{1}, iS59{4}, iS57{blockDim.x} ], { {0, 16, 1} {0, 128, 1} {0, 1024, 1} } )
T1_g[ iblockIdx.x30{( ceilDiv(( ceilDiv(( 128 * 1024 ), blockDim.x) ), 4) )}, iblockIdx.y32{( ceilDiv(16, 1) )}, iUS33{1}, iS31{4}, ithreadIdx.x29{blockDim.x} ] ca_pos( 3 ) produce_pos( 5 )
   = T4_l[ iblockIdx.x37{( ceilDiv(( ceilDiv(( 128 * 1024 ), blockDim.x) ), 4) )}, iblockIdx.y39{( ceilDiv(16, 1) )}, iUS40{1}, iS38{4}, ithreadIdx.x36{blockDim.x} ] ca_pos( 5 );
T5_l[ iblockIdx.x65{( ceilDiv(( ceilDiv(( 128 * 1025 ), blockDim.x) ), 4) )}, iblockIdx.y67{( ceilDiv(16, 1) )}, iUS68{1}, iS66{4}, ithreadIdx.x64{blockDim.x} ] ca_pos( 5 )
   = slice( T0_g[ iS58{( ceilDiv(( ceilDiv(( i2 * i3 ), blockDim.x) ), 4) )}, iS60{( ceilDiv(i0, 1) )}, iS61{1}, iS59{4}, iS57{blockDim.x} ], { {0, 16, 1} {0, 128, 1} {1024, 2049, 1} } )
T2_g[ iblockIdx.x72{( ceilDiv(( ceilDiv(( 128 * 1025 ), blockDim.x) ), 4) )}, iblockIdx.y74{( ceilDiv(16, 1) )}, iUS75{1}, iS73{4}, ithreadIdx.x71{blockDim.x} ] produce_pos( 5 )
   = T5_l[ iblockIdx.x65{( ceilDiv(( ceilDiv(( 128 * 1025 ), blockDim.x) ), 4) )}, iblockIdx.y67{( ceilDiv(16, 1) )}, iUS68{1}, iS66{4}, ithreadIdx.x64{blockDim.x} ] ca_pos( 5 );

The two slices are scheduled in the same way, although each axis may have a different extent.

T4_l[ iblockIdx.x37{( ceilDiv(( ceilDiv(( 128 * 1024 ), blockDim.x) ), 4) )},
T5_l[ iblockIdx.x65{( ceilDiv(( ceilDiv(( 128 * 1025 ), blockDim.x) ), 4) )},

This means that blockIdx.x is no longer unique in ParallelDimensionMap, but it should still work correctly (if not, there must be a bug).
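Conceptually (a hand-written illustration only, not the generated nvFuser kernel; the names out_small and out_large are made up for this sketch), a non-unique parallel extent stays correct because the grid can be sized for the larger slice while each tensor's work is predicated on its own extent:

  // Grid sized for the larger slice (128 * 1025); the smaller slice
  // predicates out the extra threads.
  __global__ void two_slices(const float* in, float* out_small, float* out_large,
                             int extent_small, int extent_large) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < extent_small) {
      out_small[i] = in[i];         // slice [0, 1024)
    }
    if (i < extent_large) {
      out_large[i] = in[1024 + i];  // slice [1024, 2049)
    }
  }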

For performance, I'm not sure whether it's always better to fuse them or reject them. The overall performance is likely determined by the larger slice output, so we might want to pick that as the reference for scheduling.

I'd say it's too early to worry too much about performance. Since it should be fine for correctness, I'd like to fuse them opportunistically and revisit if perf problems are found.

zasdfgbnm (Collaborator) left a comment

It's good to know that there is no correctness issue. Thanks for explaining!
