**MASK GENERATION**

In [1]:
from xformers.ops.fmha.attn_bias import (
    BlockDiagonalCausalMask,
    BlockDiagonalCausalWithOffsetPaddedKeysMask,
    BlockDiagonalMask,
)

In [2]:
# BlockDiagonalCausalMask demo: two packed sequences of length 2, causal
# inside each block, then restricted to a local window of `sliding_window`.
sliding_window = 2
batch_size = 1
seqlens = [2, 2]
total_seq = sum(seqlens)

# build the block-diagonal causal bias first, then narrow it to the window
causal_mask = BlockDiagonalCausalMask.from_seqlens(seqlens)
mask = causal_mask.make_local_attention(sliding_window)

# dense (batch, queries, keys) view: 0. = may attend, -inf = masked out
mask.materialize((batch_size, total_seq, total_seq))

tensor([[[0., -inf, -inf, -inf],
         [0., 0., -inf, -inf],
         [-inf, -inf, 0., -inf],
         [-inf, -inf, 0., 0.]]])

In [3]:
# BlockDiagonalMask demo: each block has 2 queries but 4 keys, and the local
# window is aligned to the bottom-right corner of every block.
sliding_window = 2
batch_size = 1
seqlens = [2, 2]
kv_seqlens = [4, 4]
total_seq = sum(seqlens)

# build the (non-causal) block-diagonal bias, then apply the local window
block_mask = BlockDiagonalMask.from_seqlens(q_seqlen=seqlens, kv_seqlen=kv_seqlens)
mask = block_mask.make_local_attention_from_bottomright(sliding_window)

# dense (batch, queries, keys) view: 0. = may attend, -inf = masked out
mask.materialize((batch_size, total_seq, sum(kv_seqlens)))

tensor([[[-inf, 0., 0., -inf, -inf, -inf, -inf, -inf],
         [-inf, -inf, 0., 0., -inf, -inf, -inf, -inf],
         [-inf, -inf, -inf, -inf, -inf, 0., 0., -inf],
         [-inf, -inf, -inf, -inf, -inf, -inf, 0., 0.]]])

In [4]:
# BlockDiagonalCausalWithOffsetPaddedKeysMask demo: every sequence owns a
# slot of `kv_padding` key positions, of which only `kv_seqlen` are populated.
kv_padding = 9
sliding_window = 2  # not used by this mask; kept so the notebook state matches the cells above
batch_size = 1
seqlens = [2, 2]
kv_seqlens = [4, 4]
total_seq = sum(seqlens)

mask = BlockDiagonalCausalWithOffsetPaddedKeysMask.from_seqlens(
    q_seqlen=seqlens,
    kv_padding=kv_padding,
    kv_seqlen=kv_seqlens,
)

# key axis spans one padded slot per sequence; 0. = may attend, -inf = masked out
mask.materialize((batch_size, total_seq, kv_padding * len(kv_seqlens)))

tensor([[[0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
         [0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
         [-inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, 0., 0., 0., -inf, -inf, -inf, -inf, -inf, -inf],
         [-inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, 0., 0., 0., 0., -inf, -inf, -inf, -inf, -inf]]])

**SLIDING WINDOW ATTENTION**

In [44]:
# Build the toy input: each word of the example sentence becomes its own
# singleton set, so later layers can show which tokens have been mixed in.
example = ["I", "am", "living", "in", "Nepal"]
sequence = [{token} for token in example]
print(sequence)

[{'I'}, {'am'}, {'living'}, {'in'}, {'Nepal'}]


In [45]:
# window size passed to the sliding window attention below: with this value a
# token is paired with itself and at most (sliding_window - 1) earlier tokens
sliding_window = 3

def sliding_window_attention(seq: list[set[str]], w: int):
    """Build a causal, windowed "attention" grid over a sequence of token sets.

    Args:
        seq: one set of tokens per position.
        w: window size; a query at position i is paired with keys at
            positions j where ``0 <= i - j < w``.

    Returns:
        A ``len(seq) x len(seq)`` list of lists. Entry ``[i][j]`` is a new set
        holding the union of ``seq[i]`` and ``seq[j]`` when key j is inside the
        window of query i, and ``None`` otherwise (future keys, or keys that
        are ``w`` or more positions behind).
    """
    positions = range(len(seq))

    def pair_tokens(query_idx: int, key_idx: int):
        distance = query_idx - key_idx
        # negative distance -> key lies in the future; distance >= w -> key
        # has already slid out of the window
        if distance < 0 or distance >= w:
            return None
        return set(seq[query_idx]).union(seq[key_idx])

    return [[pair_tokens(q, k) for k in positions] for q in positions]

# run one windowed attention pass over the toy sequence and show the grid
attn_scores = sliding_window_attention(seq=sequence, w=sliding_window)
attn_scores

[[{'I'}, None, None, None, None],
 [{'I', 'am'}, {'am'}, None, None, None],
 [{'I', 'living'}, {'am', 'living'}, {'living'}, None, None],
 [None, {'am', 'in'}, {'in', 'living'}, {'in'}, None],
 [None, None, {'Nepal', 'living'}, {'Nepal', 'in'}, {'Nepal'}]]

In [46]:
# multiply the attention scores with the values
def multiply_attn_scores_with_values(attn_scores: list[list[set]], v_seq: list[set[str]]) -> list[set[str]]:
    """Combine an attention grid with the value sequence, one output set per position.

    Args:
        attn_scores: grid indexed as ``[query][key]``; an entry is either a set
            of tokens or ``None`` when that key is not attended to.
        v_seq: one set of value tokens per position; its length determines how
            many rows and columns of ``attn_scores`` are read.

    Returns:
        A list with ``len(v_seq)`` sets. Output ``i`` is the union, over every
        key ``j`` whose grid entry is not ``None``, of ``v_seq[j]`` and
        ``attn_scores[i][j]``.
    """
    n_tokens = len(v_seq)
    outputs = []

    for query_idx in range(n_tokens):
        gathered: set[str] = set()
        for key_idx in range(n_tokens):
            pair = attn_scores[query_idx][key_idx]
            if pair is None:
                # masked-out key: contributes nothing to this query
                continue
            gathered.update(v_seq[key_idx], pair)
        outputs.append(gathered)

    return outputs

# the values are the very same token sets that served as queries and keys
v_seq = sequence
result = multiply_attn_scores_with_values(attn_scores=attn_scores, v_seq=v_seq)
result

[{'I'},
 {'I', 'am'},
 {'I', 'am', 'living'},
 {'am', 'in', 'living'},
 {'Nepal', 'in', 'living'}]

In [47]:
# inspect the attention
def print_attention(attn_scores: list[list[set[str]]]):
    """Print the attention grid, one tab-separated line per query row.

    ``None`` entries are printed as the text ``None``; every other entry is
    printed as a list of its tokens ordered by their position in the
    notebook-level ``example`` list (so each token must occur in ``example``).
    """
    for score_row in attn_scores:
        for cell in score_row:
            if cell is not None:
                text = str(sorted(cell, key=example.index))
            else:
                text = "None"
            print(text, end="\t")
        # finish the row
        print()

# show the attention grid with tokens in sentence order
print_attention(attn_scores=attn_scores)

['I']	None	None	None	None	
['I', 'am']	['am']	None	None	None	
['I', 'living']	['am', 'living']	['living']	None	None	
None	['am', 'in']	['living', 'in']	['in']	None	
None	None	['living', 'Nepal']	['in', 'Nepal']	['Nepal']	


In [48]:
# print the sequence:
def print_sequence(seq: list[set[str]]):
    """Print each position of ``seq`` as ``<index>: <tokens>``.

    Tokens are listed in the order they appear in the notebook-level
    ``example`` list, so each token must occur in ``example``.
    """
    for position, tokens in enumerate(seq):
        ordered = sorted(tokens, key=example.index)
        print(f"{position}: {ordered}")

# show the untouched input sequence, one position per line
print_sequence(seq=sequence)

0: ['I']
1: ['am']
2: ['living']
3: ['in']
4: ['Nepal']


In [50]:
# print the layer output:
def print_layer_output(input: list[set[str]], layer: int):
    """Run one sliding-window attention layer and print every stage of it.

    Prints the layer input, the attention grid computed with the
    notebook-level ``sliding_window`` size, and the resulting output sequence.

    Args:
        input: one set of tokens per position (the name shadows the builtin
            ``input``; it is kept so existing keyword callers keep working).
        layer: layer number, used only in the printed headings.

    Returns:
        The layer's output sequence, one token set per position.
    """
    # stage 1: what goes in
    print(f"Layer {layer} input:")
    print_sequence(input)
    print()

    # stage 2: which positions attend to which
    layer_scores = sliding_window_attention(input, sliding_window)
    print(f"Layer {layer} attention scores:")
    print_attention(layer_scores)
    print()

    # stage 3: what comes out (the input doubles as the value sequence)
    layer_output = multiply_attn_scores_with_values(layer_scores, input)
    print(f"Layer {layer} output")
    print_sequence(layer_output)
    return layer_output

# run and print the first layer on the raw sequence
layer1_output = print_layer_output(input=sequence, layer=1)

Layer 1 input:
0: ['I']
1: ['am']
2: ['living']
3: ['in']
4: ['Nepal']

Layer 1 attention scores:
['I']	None	None	None	None	
['I', 'am']	['am']	None	None	None	
['I', 'living']	['am', 'living']	['living']	None	None	
None	['am', 'in']	['living', 'in']	['in']	None	
None	None	['living', 'Nepal']	['in', 'Nepal']	['Nepal']	

Layer 1 output
0: ['I']
1: ['I', 'am']
2: ['I', 'am', 'living']
3: ['am', 'living', 'in']
4: ['living', 'in', 'Nepal']
