if constexpr (!Is_causal) {  // Just masking based on col
    if (int(get<1>(tScS(i))) >= int(seqlen_k - n_block * kBlockN)) {
        tSrS(i) = -INFINITY;
        printf("seqlen_k - n_block * kBlockN=%d- %d* %d = %d\n", seqlen_k, n_block, kBlockN, seqlen_k - n_block * kBlockN);
    }
}
I noticed that with seqlen_k = seqlen_q = 256, this prints:
seqlen_k - n_block * kBlockN=256- 1* 176 = 80
That is not even the size of a block. (If the block size were 128 and seq_len were 127, of course we would want to set position 128 to -inf.) So why do we have a mask here?
This just says that for the 2nd block (columns 176 -> 351), we keep the first 80 columns (176 -> 255) and mask out the rest of the columns (256 -> 351) as -infinity.
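To make the arithmetic concrete, here is a minimal host-side sketch of the same bound check. The helper names (`num_n_blocks`, `valid_cols`, `is_masked`) are made up for illustration and are not part of the library; only the expression `seqlen_k - n_block * kBlockN` comes from the snippet above.

```cpp
#include <algorithm>
#include <cassert>

// Hypothetical helper: number of column blocks of width block_n
// needed to cover seqlen_k key columns.
inline int num_n_blocks(int seqlen_k, int block_n) {
    assert(seqlen_k >= 0 && block_n > 0);
    return (seqlen_k + block_n - 1) / block_n;
}

// Hypothetical helper: how many columns of block n_block are in range.
// Same bound as the kernel snippet, clamped to [0, block_n].
inline int valid_cols(int seqlen_k, int n_block, int block_n) {
    assert(n_block >= 0 && block_n > 0);
    return std::max(0, std::min(block_n, seqlen_k - n_block * block_n));
}

// Same predicate as the non-causal branch: a block-local column index
// is masked when it is >= seqlen_k - n_block * block_n.
inline bool is_masked(int local_col, int seqlen_k, int n_block, int block_n) {
    return local_col >= seqlen_k - n_block * block_n;
}
```

With seqlen_k = 256 and a block width of 176, `num_n_blocks` gives 2 and `valid_cols(256, 1, 176)` gives 80, so local columns 0 to 79 of the 2nd block are kept and local columns 80 to 175 are masked.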
Oh, I understand now. Because we use an unusual block size of 128 x 176 (M x N), even with seq_len = 256 the 2nd block does a lot of useless computation!
But why is kBlockN 176? I noticed it is fixed. Maybe it is chosen to use as much smem as possible? (Though it is hard to imagine how that maps onto WGMMA's tile sizes.)
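The "useless calculation" can be counted directly. This is a rough sketch with a hypothetical helper, `padded_cols`, that counts the columns computed and then masked for a given block width. It says nothing about why 176 was chosen or whether the padding costs anything measurable in practice.

```cpp
#include <cassert>

// Hypothetical helper: number of columns that are computed but masked out
// when seqlen_k key columns are covered by blocks of width block_n.
inline int padded_cols(int seqlen_k, int block_n) {
    assert(seqlen_k >= 0 && block_n > 0);
    int n_blocks = (seqlen_k + block_n - 1) / block_n;  // ceil division
    return n_blocks * block_n - seqlen_k;
}
```

For seqlen_k = 256, `padded_cols(256, 176)` is 96 (2 * 176 - 256), while `padded_cols(256, 128)` is 0 because 128 divides 256 evenly.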