Skip to content

alias analysis missing out opportunities on aliasing within fusion segments #2577

@jjsjann123

Description

@jjsjann123

In the full RoPE example, we have a fusion that starts with something like this:

Inputs:
  T0_g[ iS467{1024}, iS468{8} ], __bfloat
  T1_g[ iS471{1024}, iS472{8} ], __bfloat
  T2_g[ iS463{2}, iS464{1024}, iS6{i7} ], __bfloat
Outputs:
  T60_g[ iS324{2}, iS328{( 4 * 4 )}rf, iS326{1024}, iS327{32} ], __bfloat
  T39_g[ iS390{2}, iS391{16}, iS188{1024}, iS189{32} ], __bfloat
  T58_g[ iS461{2}, iS462{16}, iS272{1024}, iS273{32} ], __bfloat

%kernel_math {
T59_l[ iS465{2}, iS466{1024}, iS278{4}rf, iS280{6}rf, iS281{( ceilDiv(( ceilDiv(i7, 4) ), 6) )}rf ] = view( T2_g[ iS463{2}, iS464{1024}, iS6{i7} ] )
T4_g[ iS306{2}, iS308{4}, iS309{6}, iS307{1024}, iS310{32} ]
   = Set.Permute( T59_l[ iS465{2}, iS466{1024}, iS278{4}rf, iS280{6}rf, iS281{( ceilDiv(( ceilDiv(i7, 4) ), 6) )}rf ], cache_op=Streaming )
T7_g[ iS311{2}rf, iS312{4}rf, bS35{1}rf, iS314{1024}rf, iS315{32}rf ]
   = slice( T4_g[ iS306{2}, iS308{4}, iS309{6}, iS307{1024}, iS310{32} ], { {0, 2, 1} {0, 4, 1} {5, 6, 1} {0, 1024, 1} {0, 32, 1} } )
T10_l[ iS316{2}, iS317{4}, bS50{1}, iS318{1024}, iS319{32} ]
   = Set( T7_g[ iS311{2}rf, iS312{4}rf, bS35{1}rf, iS314{1024}rf, iS315{32}rf ], cache_op=Streaming )
T11_g[ iS320{2}, iS321{4}, bS55{1 ex 4}, iS322{1024}, iS323{32} ] = expand( T10_l[ iS316{2}, iS317{4}, bS50{1}, iS318{1024}, iS319{32} ], {2, 4, 4, 1024, 32} )
T60_g[ iS324{2}, iS328{( 4 * 4 )}rf, iS326{1024}, iS327{32} ] = view( T11_g[ iS320{2}, iS321{4}, bS55{1 ex 4}, iS322{1024}, iS323{32} ] )
T5_g[ iS329{2}rf, iS330{4}rf, iS23{4}rf, iS332{1024}rf, iS333{32}rf ]
   = slice( T4_g[ iS306{2}, iS308{4}, iS309{6}, iS307{1024}, iS310{32} ], { {0, 2, 1} {0, 4, 1} {0, 4, 1} {0, 1024, 1} {0, 32, 1} } )
T61_g[ iS334{2}, iS338{( 4 * 4 )}rf, iS336{1024}, iS337{32} ] = view( T5_g[ iS329{2}rf, iS330{4}rf, iS23{4}rf, iS332{1024}rf, iS333{32}rf ] )
i243 = 4 * 4;
i281 = fmin(i243, 16);
i283 = fmax(0, i281);
T15_g[ iS339{2}rf, iS340{( 4 * 4 )}rf, iS341{1024}rf, iS86{8}rf ]
   = slice( T61_g[ iS334{2}, iS338{( 4 * 4 )}rf, iS336{1024}, iS337{32} ], { {0, 2, 1} {0, i283, 1} {0, 1024, 1} {0, 8, 1} } )
T26_g[ iS343{2}, iS344{( 4 * 4 )}, iS345{1024}, iS134{8} ]
   = __bfloat2float(T15_g[ iS339{2}rf, iS340{( 4 * 4 )}rf, iS341{1024}rf, iS86{8}rf ]);
T24_l[ bS123{1}, bS124{1}, iS469{1024}, iS470{8} ]
   = broadcast( T0_g[ iS467{1024}, iS468{8} ] )
T25_g[ bS127{1 ex 2}, bS128{1 ex 16}, iS129{1024}, iS130{8} ] = expand( T24_l[ bS123{1}, bS124{1}, iS469{1024}, iS470{8} ], {2, 16, 1024, 8} )
T27_g[ bS135{1 ex 2}, bS136{1 ex 16}, iS137{1024}, iS138{8} ]
   = __bfloat2float(T25_g[ bS127{1 ex 2}, bS128{1 ex 16}, iS129{1024}, iS130{8} ]);
T28_g[ iS346{2}, iS347{( 4 * 4 )}, iS141{1024}, iS142{8} ]
   = T26_g[ iS343{2}, iS344{( 4 * 4 )}, iS345{1024}, iS134{8} ]
   * T27_g[ bS135{1 ex 2}, bS136{1 ex 16}, iS137{1024}, iS138{8} ];
i387 = fmin(i243, 16);
i389 = fmax(0, i387);
T17_g[ iS348{2}rf, iS349{( 4 * 4 )}rf, iS350{1024}rf, iS96{4}rf ]
   = slice( T15_g[ iS339{2}rf, iS340{( 4 * 4 )}rf, iS341{1024}rf, iS86{8}rf ], { {0, 2, 1} {0, i389, 1} {0, 1024, 1} {4, 8, 1} } )
T18_l[ iS351{2}, iS352{( 4 * 4 )}, iS353{1024}, iS100{4} ]
   = __bfloat2float(T17_g[ iS348{2}rf, iS349{( 4 * 4 )}rf, iS350{1024}rf, iS96{4}rf ]);
T19_g[ iS354{2}, iS355{( 4 * 4 )}, iS356{1024}, iS104{4} ]
   = -T18_l[ iS351{2}, iS352{( 4 * 4 )}, iS353{1024}, iS100{4} ];
T20_g[ iS357{2}, iS358{( 4 * 4 )}, iS359{1024}, iS108{4} ]
   = __float2bfloat(T19_g[ iS354{2}, iS355{( 4 * 4 )}, iS356{1024}, iS104{4} ]);
T21_l[ iS360{2}, iS361{( 4 * 4 )}, iS362{1024}, iS113{8}rf ]
   = pad( T20_g[ iS357{2}, iS358{( 4 * 4 )}, iS359{1024}, iS108{4} ], {0, 0, 0, 0, 0, 0, 0, 4} )
i334 = fmin(i243, 16);
i336 = fmax(0, i334);
T16_g[ iS363{2}rf, iS364{( 4 * 4 )}rf, iS365{1024}rf, iS91{4}rf ]
   = slice( T15_g[ iS339{2}rf, iS340{( 4 * 4 )}rf, iS341{1024}rf, iS86{8}rf ], { {0, 2, 1} {0, i336, 1} {0, 1024, 1} {0, 4, 1} } )
i448 = 0 + 4;
T22_g[ iS366{2}, iS367{( 4 * 4 )}, iS368{1024}, iS118{( ( 0 + 4 ) + 4 )}rf ]
   = pad( T16_g[ iS363{2}rf, iS364{( 4 * 4 )}rf, iS365{1024}rf, iS91{4}rf ], {0, 0, 0, 0, 0, 0, i448, 0} )
...

Unfortunately, nvFuser ended up segmenting at the slice, which gives us a segment like the one below: instead of leveraging alias analysis, we generate a kernel for a single slice op.

g{(pointwise)
inputs:
T15_g[ iS339{2}rf, iS340{( 4 * 4 )}rf, iS341{1024}rf, iS86{8}rf ] __bfloat
outputs:
T16_g[ iS363{2}rf, iS364{( 4 * 4 )}rf, iS365{1024}rf, iS91{4}rf ] __bfloat


i334 = fmin(i243, 16);
(29)
i336 = fmax(0, i334);
(30)
T16_g[ iS363{2}rf, iS364{( 4 * 4 )}rf, iS365{1024}rf, iS91{4}rf ]
   = slice( T15_g[ iS339{2}rf, iS340{( 4 * 4 )}rf, iS341{1024}rf, iS86{8}rf ], { {0, 2, 1} {0, i336, 1} {0, 1024, 1} {0, 4, 1} } )
(34)
}

The issue here is that we might not always segment at the right position. Alias analysis currently runs only as a pre-segmentation (preseg) pass, so it misses aliasing opportunities that only appear within segments, and we unfortunately end up generating kernels for cases like the example above.

repro code

import torch
import thunder

# operations to prepare q, k before sending it into rope
def split_qkv(x, n_head, n_query_groups, head_size):
    """Split a fused QKV projection into per-head q, k and v tensors.

    ``x`` must be 3-D; its last dimension is viewed as
    ``(n_query_groups, n_head // n_query_groups + 2, head_size)``, i.e. each
    query group holds its query heads followed by one key and one value head.
    The single key/value head of each group is expanded to match the number of
    query heads in that group, and all three results are reshaped to
    ``(size0, -1, size1, head_size)`` where ``size0``/``size1`` are the first
    two dimensions of ``x``.
    """
    batch, seq_len, _ = x.size()
    queries_per_group = n_head // n_query_groups
    slots_per_group = queries_per_group + 2  # query heads + one key + one value

    # Group the features, then move the sequence axis next to head_size.
    grouped = x.view(batch, seq_len, n_query_groups, slots_per_group, head_size)
    grouped = grouped.permute(0, 2, 3, 1, 4)
    q, k, v = grouped.split((queries_per_group, 1, 1), dim=2)

    # Repeat the shared key/value head once per query head in its group.
    per_group_shape = (batch, n_query_groups, queries_per_group, seq_len, head_size)
    k = k.expand(*per_group_shape)
    v = v.expand(*per_group_shape)

    # Fold (group, head-in-group) into a single head axis.
    per_head_shape = (batch, -1, seq_len, head_size)
    q = q.reshape(*per_head_shape)
    k = k.reshape(*per_head_shape)
    v = v.reshape(*per_head_shape)
    return q, k, v

def rope_one_entry(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, rope_n_elem: int) -> torch.Tensor:
    """Apply rotary position embedding to the leading features of ``x``.

    Only the first ``rope_n_elem`` entries of the last dimension are rotated;
    the remaining entries are passed through unchanged and concatenated back.

    Args:
        x: Input tensor; the rotation acts on its last dimension.
        cos: Cosine table; must broadcast against ``x[..., :rope_n_elem]``.
        sin: Sine table; must broadcast against ``x[..., :rope_n_elem]``.
        rope_n_elem: Number of leading last-dimension features to rotate.

    Returns:
        A tensor with the same shape and dtype as ``x``.
    """
    x_rope = x[..., :rope_n_elem]
    half = rope_n_elem // 2
    x1 = x_rope[..., :half]  # first half of the rotated features
    x2 = x_rope[..., half:]  # second half of the rotated features
    rotated = torch.cat((-x2, x1), dim=-1)  # same last-dim size as x_rope
    roped = (x_rope * cos) + (rotated * sin)
    # Tensor.to returns a new tensor rather than converting in place, so the
    # result has to be re-bound; otherwise a higher-precision cos/sin would
    # leak its promoted dtype into the output.
    roped = roped.to(dtype=x.dtype)
    return torch.cat((roped, x[..., rope_n_elem:]), dim=-1)

# full rope and surrounding operations to compute q, k, v
def rope_it_all(x, cos, sin, n_head, n_query_groups, head_size, rope_n_elem):
    """Split ``x`` into q, k, v and apply rotary embedding to q and k.

    ``v`` is returned exactly as produced by ``split_qkv``; only the query and
    key tensors go through ``rope_one_entry``.
    """
    queries, keys, values = split_qkv(x, n_head, n_query_groups, head_size)
    roped_queries = rope_one_entry(queries, cos, sin, rope_n_elem)
    roped_keys = rope_one_entry(keys, cos, sin, rope_n_elem)
    return roped_queries, roped_keys, values

# Tensor options shared by every input of the repro.
dtype = torch.bfloat16
device = "cuda"

# Problem sizes.
bsz, block_size = 2, 1024
n_head, n_query_groups = 16, 4
head_size = 32
rope_n_elem = 8

# Width of the fused QKV feature axis: every query head plus one key head and
# one value head per query group.
qkv_features = (n_head + 2 * n_query_groups) * head_size

x = torch.randn(bsz, block_size, qkv_features, device=device, dtype=dtype)
cos = torch.randn(block_size, rope_n_elem, device=device, dtype=dtype)
sin = torch.randn(block_size, rope_n_elem, device=device, dtype=dtype)

# Bookend is disabled via the nv_enable_bookend option for this repro.
jit_rope = thunder.jit(rope_it_all, nv_enable_bookend=False)
q, k, v = jit_rope(x, cos, sin, n_head, n_query_groups, head_size, rope_n_elem)
print(thunder.last_traces(jit_rope)[-1])

Note that this needs to run with bookend disabled in thunder, as in @wujingyue's PR: Lightning-AI/lightning-thunder#731

Metadata

Metadata

Assignees

No one assigned

    Labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions