[Dlight] Use 16x32 spatial x reduction thread extents in GEMV scheduling #17082

Merged

Conversation

csullivan
Contributor

Change to use 16x32 spatial x reduction thread extents regardless of workload size. This works around a lowering bug which I haven't tracked down yet.

Currently, when the spatial dimension is larger than the reduction dimension, the scheduler uses a 4x64 thread layout. This implies two warps in the reduction dimension, corresponding to blockDim.x = 64. An illegal CUDA instruction is encountered in the second warp during the __shfl_down_sync for the remainder portion of the computation (as a result of the rfactor, I believe). AFAICT the mask calculation used for this remainder shuffle is incorrect and is causing the error; specifically, the failure occurs on the first thread of the second warp (there are two warps along x since blockDim.x = 64).

This is the relevant generated CUDA causing the error:

if (((int)threadIdx.x) < 2) {
  red_buf0[0] = red_buf_staging[((((int)threadIdx.y) * 2) + ((int)threadIdx.x))];
}
mask[0] = (__activemask() & ((uint)(3 << (((int)threadIdx.y) * 2)))); // <<< likely the problem
t0[0] = __shfl_down_sync(mask[0], red_buf0[0], 1, 32);
red_buf0[0] = (red_buf0[0] + t0[0]);
if (((int)threadIdx.x) == 0) {
  ((volatile half*)red_result)[((int)threadIdx.y)] = red_buf0[0];
}
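
For reference (background, not part of the change): __shfl_down_sync requires that the threads reaching the call are named in the membermask, and that all non-exited threads named in the mask execute the same intrinsic with the same mask. Below is a minimal, standalone sketch of a legal version of this kind of staging-buffer remainder step; the names, shapes, and float datatype are made up for illustration, and it assumes one warp per block so a full-warp mask covers every lane that reaches the call.

// Illustrative standalone kernel, not the TVM-generated code: names, shapes,
// and float precision are assumptions for the example. Every lane that
// executes __shfl_down_sync is named in the membermask, and every lane named
// in the mask executes the call with that same mask.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void remainder_reduce(const float* staging, float* out) {
  const int lane = threadIdx.x & 31;        // lane id within the warp
  float v = 0.0f;
  if (lane < 2) {
    v = staging[2 * blockIdx.x + lane];     // two staged partials per block
  }
  // Every lane of the warp reaches this call (no early exit), so a full-warp
  // mask is valid; lane 0 pulls lane 1's partial.
  const float t = __shfl_down_sync(0xffffffffu, v, 1, 32);
  if (lane == 0) {
    out[blockIdx.x] = v + t;
  }
}

int main() {
  const int blocks = 4;
  float h_staging[2 * blocks], h_out[blocks];
  for (int i = 0; i < 2 * blocks; ++i) h_staging[i] = static_cast<float>(i);

  float *d_staging = nullptr, *d_out = nullptr;
  cudaMalloc(&d_staging, sizeof(h_staging));
  cudaMalloc(&d_out, sizeof(h_out));
  cudaMemcpy(d_staging, h_staging, sizeof(h_staging), cudaMemcpyHostToDevice);

  remainder_reduce<<<blocks, 32>>>(d_staging, d_out);
  cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);

  for (int b = 0; b < blocks; ++b) {
    printf("block %d: %.1f\n", b, h_out[b]);  // expect 1, 5, 9, 13
  }
  cudaFree(d_staging);
  cudaFree(d_out);
  return 0;
}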

The corresponding SASS where the illegal instruction occurs:

   0x00007d9e97b92490 <+1936>:  WARPSYNC.ALL
   0x00007d9e97b924a0 <+1952>:  BAR.SYNC.DEFER_BLOCKING 0x0
   0x00007d9e97b924b0 <+1968>:  @!P1 VIADD R13, R5, 0x8
   0x00007d9e97b924c0 <+1984>:  @!P1 LEA R7, R17, R14, 0x1
   0x00007d9e97b924d0 <+2000>:  @!P1 PRMT R6, R2, 0x654, R13
   0x00007d9e97b924e0 <+2016>:  @!P1 LEA R7, R7, R6, 0x1
   0x00007d9e97b924f0 <+2032>:  @!P1 LDS.U16 R16, [R7]
   0x00007d9e97b92500 <+2048>:  IMAD.MOV.U32 R6, RZ, RZ, 0x3
   0x00007d9e97b92510 <+2064>:  SHF.L.U32 R17, R17, 0x1, RZ
   0x00007d9e97b92520 <+2080>:  VOTEU.ANY UR4, UPT, PT
   0x00007d9e97b92530 <+2096>:  SHF.L.U32 R3, R6, R17, RZ
   0x00007d9e97b92540 <+2112>:  LOP3.LUT R3, R3, UR4, RZ, 0xc0, !PT
   0x00007d9e97b92550 <+2128>:  ISETP.NE.AND P0, PT, R14, RZ, PT
   0x00007d9e97b92560 <+2144>:  PRMT R2, R2, 0x654, R5
   0x00007d9e97b92570 <+2160>:  PRMT R4, R16, 0x5410, R16
*> 0x00007d9e97b92580 <+2176>:  WARPSYNC R3
=> 0x00007d9e97b92590 <+2192>:  SHFL.DOWN PT, R3, R4, 0x1, 0x1f
   0x00007d9e97b925a0 <+2208>:  IMAD.IADD R17, R17, 0x1, R2
   0x00007d9e97b925b0 <+2224>:  HADD2 R16, R16.H0_H0, R3.H0_H0
   0x00007d9e97b925c0 <+2240>:  @!P0 STS.U16 [R17], R16
   0x00007d9e97b925d0 <+2256>:  WARPSYNC.ALL
   0x00007d9e97b925e0 <+2272>:  BAR.SYNC.DEFER_BLOCKING 0x0

Changing the thread extents to 16x32 (one warp along the reduction dimension) works around the issue. It also improves performance for the tested shapes by ~10%.

Benchmarking with shape (8, 2048, 4096), which avoids the error:

# 4x64
 Time (%)  Total Time (ns)  Instances  Avg (ns)  Med (ns)  Min (ns)  Max (ns)  StdDev (ns)             Name
 --------  ---------------  ---------  --------  --------  --------  --------  -----------  --------------------------
     81.5           612214        101    6061.5    6048.0      5920      7872        188.5  moe_dequantize_gemv_kernel

# 16x32
 Time (%)  Total Time (ns)  Instances  Avg (ns)  Med (ns)  Min (ns)  Max (ns)  StdDev (ns)             Name
 --------  ---------------  ---------  --------  --------  --------  --------  -----------  --------------------------
     79.9           555901        101    5504.0    5472.0      5439      6880        142.7  moe_dequantize_gemv_kernel
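
For concreteness, here is a rough sketch of what the 16x32 (spatial x reduction) mapping looks like at the CUDA level: 16 output rows per block along threadIdx.y and a single full warp along the reduction axis in threadIdx.x, so the tail reduction is one in-warp shuffle tree with a full-warp mask and there is no cross-warp remainder step. This is a hypothetical hand-written kernel for illustration, not the dlight-generated code (which operates on half-precision, dequantized weights); all names and the float datatype are assumptions.

// Hypothetical illustration only, not the dlight-generated kernel.
// threadIdx.y picks one of 16 output rows per block; threadIdx.x is a single
// full warp striding over the reduction axis.
__global__ void gemv_16x32(const float* __restrict__ A,  // [M, K] row-major
                           const float* __restrict__ x,  // [K]
                           float* __restrict__ y,        // [M]
                           int M, int K) {
  const int row = blockIdx.x * blockDim.y + threadIdx.y;
  if (row >= M) return;  // whole warps exit together since blockDim.x == 32

  float acc = 0.0f;
  for (int k = threadIdx.x; k < K; k += 32) {  // warp-strided dot product
    acc += A[row * K + k] * x[k];
  }
  // One full warp along the reduction axis: a plain in-warp shuffle tree with
  // a full-warp mask finishes the reduction, no cross-warp combine needed.
  for (int offset = 16; offset > 0; offset >>= 1) {
    acc += __shfl_down_sync(0xffffffffu, acc, offset, 32);
  }
  if (threadIdx.x == 0) {
    y[row] = acc;  // lane 0 holds the row's dot product
  }
}
// Example launch: gemv_16x32<<<(M + 15) / 16, dim3(32, 16)>>>(A, x, y, M, K);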

@tqchen
Member

tqchen commented Jun 12, 2024

@csullivan please help fix the UT; in this case it seems fine to directly change the expected TIR.

csullivan force-pushed the fix/2024-06-11/dlight-gemv-thread-extents branch from be96146 to 92aa3f8 on June 19, 2024 17:05
csullivan merged commit 269a4f7 into apache:main on June 20, 2024
18 checks passed
csullivan deleted the fix/2024-06-11/dlight-gemv-thread-extents branch on June 20, 2024 05:46