Inputs:
T0_g[ iS467{1024}, iS468{8} ], __bfloat
T1_g[ iS471{1024}, iS472{8} ], __bfloat
T2_g[ iS463{2}, iS464{1024}, iS6{i7} ], __bfloat
Outputs:
T60_g[ iS324{2}, iS328{( 4 * 4 )}rf, iS326{1024}, iS327{32} ], __bfloat
T39_g[ iS390{2}, iS391{16}, iS188{1024}, iS189{32} ], __bfloat
T58_g[ iS461{2}, iS462{16}, iS272{1024}, iS273{32} ], __bfloat
%kernel_math {
T59_l[ iS465{2}, iS466{1024}, iS278{4}rf, iS280{6}rf, iS281{( ceilDiv(( ceilDiv(i7, 4) ), 6) )}rf ] = view( T2_g[ iS463{2}, iS464{1024}, iS6{i7} ] )
T4_g[ iS306{2}, iS308{4}, iS309{6}, iS307{1024}, iS310{32} ]
= Set.Permute( T59_l[ iS465{2}, iS466{1024}, iS278{4}rf, iS280{6}rf, iS281{( ceilDiv(( ceilDiv(i7, 4) ), 6) )}rf ], cache_op=Streaming )
T7_g[ iS311{2}rf, iS312{4}rf, bS35{1}rf, iS314{1024}rf, iS315{32}rf ]
= slice( T4_g[ iS306{2}, iS308{4}, iS309{6}, iS307{1024}, iS310{32} ], { {0, 2, 1} {0, 4, 1} {5, 6, 1} {0, 1024, 1} {0, 32, 1} } )
T10_l[ iS316{2}, iS317{4}, bS50{1}, iS318{1024}, iS319{32} ]
= Set( T7_g[ iS311{2}rf, iS312{4}rf, bS35{1}rf, iS314{1024}rf, iS315{32}rf ], cache_op=Streaming )
T11_g[ iS320{2}, iS321{4}, bS55{1 ex 4}, iS322{1024}, iS323{32} ] = expand( T10_l[ iS316{2}, iS317{4}, bS50{1}, iS318{1024}, iS319{32} ], {2, 4, 4, 1024, 32} )
T60_g[ iS324{2}, iS328{( 4 * 4 )}rf, iS326{1024}, iS327{32} ] = view( T11_g[ iS320{2}, iS321{4}, bS55{1 ex 4}, iS322{1024}, iS323{32} ] )
T5_g[ iS329{2}rf, iS330{4}rf, iS23{4}rf, iS332{1024}rf, iS333{32}rf ]
= slice( T4_g[ iS306{2}, iS308{4}, iS309{6}, iS307{1024}, iS310{32} ], { {0, 2, 1} {0, 4, 1} {0, 4, 1} {0, 1024, 1} {0, 32, 1} } )
T61_g[ iS334{2}, iS338{( 4 * 4 )}rf, iS336{1024}, iS337{32} ] = view( T5_g[ iS329{2}rf, iS330{4}rf, iS23{4}rf, iS332{1024}rf, iS333{32}rf ] )
i243 = 4 * 4;
i281 = fmin(i243, 16);
i283 = fmax(0, i281);
T15_g[ iS339{2}rf, iS340{( 4 * 4 )}rf, iS341{1024}rf, iS86{8}rf ]
= slice( T61_g[ iS334{2}, iS338{( 4 * 4 )}rf, iS336{1024}, iS337{32} ], { {0, 2, 1} {0, i283, 1} {0, 1024, 1} {0, 8, 1} } )
T26_g[ iS343{2}, iS344{( 4 * 4 )}, iS345{1024}, iS134{8} ]
= __bfloat2float(T15_g[ iS339{2}rf, iS340{( 4 * 4 )}rf, iS341{1024}rf, iS86{8}rf ]);
T24_l[ bS123{1}, bS124{1}, iS469{1024}, iS470{8} ]
= broadcast( T0_g[ iS467{1024}, iS468{8} ] )
T25_g[ bS127{1 ex 2}, bS128{1 ex 16}, iS129{1024}, iS130{8} ] = expand( T24_l[ bS123{1}, bS124{1}, iS469{1024}, iS470{8} ], {2, 16, 1024, 8} )
T27_g[ bS135{1 ex 2}, bS136{1 ex 16}, iS137{1024}, iS138{8} ]
= __bfloat2float(T25_g[ bS127{1 ex 2}, bS128{1 ex 16}, iS129{1024}, iS130{8} ]);
T28_g[ iS346{2}, iS347{( 4 * 4 )}, iS141{1024}, iS142{8} ]
= T26_g[ iS343{2}, iS344{( 4 * 4 )}, iS345{1024}, iS134{8} ]
* T27_g[ bS135{1 ex 2}, bS136{1 ex 16}, iS137{1024}, iS138{8} ];
i387 = fmin(i243, 16);
i389 = fmax(0, i387);
T17_g[ iS348{2}rf, iS349{( 4 * 4 )}rf, iS350{1024}rf, iS96{4}rf ]
= slice( T15_g[ iS339{2}rf, iS340{( 4 * 4 )}rf, iS341{1024}rf, iS86{8}rf ], { {0, 2, 1} {0, i389, 1} {0, 1024, 1} {4, 8, 1} } )
T18_l[ iS351{2}, iS352{( 4 * 4 )}, iS353{1024}, iS100{4} ]
= __bfloat2float(T17_g[ iS348{2}rf, iS349{( 4 * 4 )}rf, iS350{1024}rf, iS96{4}rf ]);
T19_g[ iS354{2}, iS355{( 4 * 4 )}, iS356{1024}, iS104{4} ]
= -T18_l[ iS351{2}, iS352{( 4 * 4 )}, iS353{1024}, iS100{4} ];
T20_g[ iS357{2}, iS358{( 4 * 4 )}, iS359{1024}, iS108{4} ]
= __float2bfloat(T19_g[ iS354{2}, iS355{( 4 * 4 )}, iS356{1024}, iS104{4} ]);
T21_l[ iS360{2}, iS361{( 4 * 4 )}, iS362{1024}, iS113{8}rf ]
= pad( T20_g[ iS357{2}, iS358{( 4 * 4 )}, iS359{1024}, iS108{4} ], {0, 0, 0, 0, 0, 0, 0, 4} )
i334 = fmin(i243, 16);
i336 = fmax(0, i334);
T16_g[ iS363{2}rf, iS364{( 4 * 4 )}rf, iS365{1024}rf, iS91{4}rf ]
= slice( T15_g[ iS339{2}rf, iS340{( 4 * 4 )}rf, iS341{1024}rf, iS86{8}rf ], { {0, 2, 1} {0, i336, 1} {0, 1024, 1} {0, 4, 1} } )
i448 = 0 + 4;
T22_g[ iS366{2}, iS367{( 4 * 4 )}, iS368{1024}, iS118{( ( 0 + 4 ) + 4 )}rf ]
= pad( T16_g[ iS363{2}rf, iS364{( 4 * 4 )}rf, iS365{1024}rf, iS91{4}rf ], {0, 0, 0, 0, 0, 0, i448, 0} )
...
g{(pointwise)
inputs:
T15_g[ iS339{2}rf, iS340{( 4 * 4 )}rf, iS341{1024}rf, iS86{8}rf ] __bfloat
outputs:
T16_g[ iS363{2}rf, iS364{( 4 * 4 )}rf, iS365{1024}rf, iS91{4}rf ] __bfloat
i334 = fmin(i243, 16);
(29)
i336 = fmax(0, i334);
(30)
T16_g[ iS363{2}rf, iS364{( 4 * 4 )}rf, iS365{1024}rf, iS91{4}rf ]
= slice( T15_g[ iS339{2}rf, iS340{( 4 * 4 )}rf, iS341{1024}rf, iS86{8}rf ], { {0, 2, 1} {0, i336, 1} {0, 1024, 1} {0, 4, 1} } )
(34)
}
The issue here is that we might not always segment at the right position. The current alias analysis runs as a pre-segmentation pass, so it misses aliasing opportunities that only appear within segments, and we unfortunately end up generating kernels for the example above.
import torch
import thunder
# Prepares q, k and v from the fused projection before they are sent into rope.
def split_qkv(x, n_head, n_query_groups, head_size):
    """Split a fused qkv tensor into per-head query, key and value tensors.

    ``x`` is unpacked as a 3-D tensor whose last dimension is viewed as
    ``(n_query_groups, n_head // n_query_groups + 2, head_size)``: for every
    query group there are ``q_per_kv`` query slots followed by one key slot
    and one value slot. The key and value slots are expanded so that each
    query head in the group gets a matching entry.

    Returns:
        A ``(q, k, v)`` tuple; each tensor is reshaped to
        ``(B, -1, T, head_size)`` where ``B, T`` are the first two sizes of ``x``.
    """
    batch, seq_len, _ = x.size()
    queries_per_group = n_head // n_query_groups

    # One key slot and one value slot follow the query slots of each group.
    slots_per_group = queries_per_group + 2
    grouped = x.view(batch, seq_len, n_query_groups, slots_per_group, head_size)
    # Reorder to (B, n_query_groups, slots_per_group, T, head_size).
    grouped = grouped.permute(0, 2, 3, 1, 4)

    queries, keys, values = grouped.split((queries_per_group, 1, 1), dim=2)

    # Repeat the single key/value slot across all query heads of its group.
    per_group_shape = (batch, n_query_groups, queries_per_group, seq_len, head_size)
    keys = keys.expand(*per_group_shape)
    values = values.expand(*per_group_shape)

    # Fold the (group, slot) dimensions into a single head dimension.
    per_head_shape = (batch, -1, seq_len, head_size)
    queries = queries.reshape(*per_head_shape)
    keys = keys.reshape(*per_head_shape)
    values = values.reshape(*per_head_shape)
    return queries, keys, values
def rope_one_entry(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, rope_n_elem: int) -> torch.Tensor:
    """Apply rotary position embedding to the leading part of the last dim.

    Only the first ``rope_n_elem`` elements of the last dimension are rotated;
    the remaining elements are passed through unchanged.

    Args:
        x: Input tensor; the rotation acts on ``x[..., :rope_n_elem]``.
        cos: Cosine table, multiplied elementwise with ``x[..., :rope_n_elem]``
            (so it must broadcast against that slice).
        sin: Sine table, multiplied elementwise with the rotated slice.
        rope_n_elem: Number of trailing-dimension elements to rotate.

    Returns:
        A tensor with the same shape and dtype as ``x``.
    """
    x_rope = x[..., :rope_n_elem]
    half = rope_n_elem // 2
    x1 = x_rope[..., :half]  # first half of the rotary slice
    x2 = x_rope[..., half:]  # second half of the rotary slice
    rotated = torch.cat((-x2, x1), dim=-1)
    roped = (x_rope * cos) + (rotated * sin)
    # Tensor.to() is out-of-place: the result must be assigned back, otherwise
    # a higher-precision cos/sin would silently promote the returned tensor.
    roped = roped.to(dtype=x.dtype)
    return torch.cat((roped, x[..., rope_n_elem:]), dim=-1)
# Full rope pipeline: split the fused projection, then rotate q and k.
def rope_it_all(x, cos, sin, n_head, n_query_groups, head_size, rope_n_elem):
    """Compute q, k, v from ``x`` and apply rope to q and k.

    ``x`` is first split by ``split_qkv``; the resulting query and key tensors
    are each passed through ``rope_one_entry`` with the same ``cos``, ``sin``
    and ``rope_n_elem``. The value tensor is returned as produced by the split.

    Returns:
        The ``(q, k, v)`` tuple, with q and k rotated.
    """
    queries, keys, values = split_qkv(x, n_head, n_query_groups, head_size)
    rotated_queries = rope_one_entry(queries, cos, sin, rope_n_elem)
    rotated_keys = rope_one_entry(keys, cos, sin, rope_n_elem)
    return rotated_queries, rotated_keys, values
# Tensor options used for every input of the repro.
dtype, device = torch.bfloat16, "cuda"

# Problem sizes.
bsz, block_size = 2, 1024
n_head, head_size = 16, 32
n_query_groups = 4
rope_n_elem = 8

# Fused qkv input: every query group carries its query heads plus one key and
# one value slot, hence (n_head + 2 * n_query_groups) slots of head_size each.
x = torch.randn(
    bsz,
    block_size,
    (n_head + 2 * n_query_groups) * head_size,
    device=device,
    dtype=dtype,
)
cos = torch.randn(block_size, rope_n_elem, device=device, dtype=dtype)
sin = torch.randn(block_size, rope_n_elem, device=device, dtype=dtype)

# Compile with bookend disabled, run once, then dump the last trace.
jit_rope = thunder.jit(rope_it_all, nv_enable_bookend=False)
q, k, v = jit_rope(x, cos, sin, n_head, n_query_groups, head_size, rope_n_elem)
print(thunder.last_traces(jit_rope)[-1])
In the full RoPE example, we have a fusion that starts with something like this:
nvfuser unfortunately ended up segmenting on the slice, which gives us a segment like the one below: we miss the chance to leverage alias analysis and instead generate a kernel for a single slice op. The issue here is that we might not always segment at the right position. The current alias analysis runs as a pre-segmentation pass, so it misses aliasing opportunities that only appear within segments, and we unfortunately end up generating kernels for the example above.
repro code
Note that this needs to run with bookend disabled in thunder, as in @wujingyue's PR: Lightning-AI/lightning-thunder#731