In [2]:
import torch

  cpu = _conversion_method_template(device=torch.device("cpu"))


In [None]:
# This is a neat trick for creating any attention mask in a generalized fashion
# https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/pi0/modeling_pi0.py#L91

# Docstring of the function linked above, quoted for reference:
"""Copied from big_vision.

Tokens can attend to valid input tokens which have a cumulative mask_ar
smaller or equal to theirs. This way `mask_ar` int[B, N] can be used to
setup several types of attention, for example:

    [[1 1 1 1 1 1]]: pure causal attention.

    [[0 0 0 1 1 1]]: prefix-lm attention. The first 3 tokens can attend between
        themselves and the last 3 tokens have a causal attention. The first
        entry could also be a 1 without changing behaviour.

    [[1 0 1 0 1 0 0 1 0 0]]: causal attention between 4 blocks. Tokens of a
        block can attend all previous blocks and all tokens on the same block.

Args:
    input_mask: bool[B, N] true if it's part of the input, false if padding.
    mask_ar: int32[B, N] mask that's 1 where previous tokens cannot depend on
    it and 0 where it shares the same attention mask as the previous token.
"""


In [9]:
# Three example `mask_ar` rows (one per batch entry), each covering 10 tokens.
# Per the quoted docstring: 1 marks a token that earlier tokens cannot depend on,
# 0 marks a token sharing the attention mask of the previous token.
# `torch.tensor` is the recommended constructor; float32 is pinned so the values
# and dtype match what the legacy `torch.Tensor(...)` call produced.
att_masks = torch.tensor(
    [
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],  # causal
        [0, 0, 0, 0, 0, 1, 1, 1, 1, 1],  # prefix-lm attention
        [1, 0, 1, 0, 1, 0, 0, 1, 0, 0],  # flexible block attention - relevant for VLAs
    ],
    dtype=torch.float32,
)

In [10]:
# Running total of the mask flags along the token axis (dim=1); tokens that end
# up with the same total share the same attention pattern.
cumsum = att_masks.cumsum(dim=1)
cumsum

tensor([[ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10.],
        [ 0.,  0.,  0.,  0.,  0.,  1.,  2.,  3.,  4.,  5.],
        [ 1.,  1.,  2.,  2.,  3.,  3.,  3.,  4.,  4.,  4.]])

In [11]:
# Broadcast a (B, 1, N) view against a (B, N, 1) view of the cumulative sums.
# Entry [b, i, j] is True when cumsum[b, j] <= cumsum[b, i], i.e. token i may
# attend to token j.
att_2d_masks = cumsum.unsqueeze(1) <= cumsum.unsqueeze(2)
att_2d_masks

tensor([[[ True, False, False, False, False, False, False, False, False, False],
         [ True,  True, False, False, False, False, False, False, False, False],
         [ True,  True,  True, False, False, False, False, False, False, False],
         [ True,  True,  True,  True, False, False, False, False, False, False],
         [ True,  True,  True,  True,  True, False, False, False, False, False],
         [ True,  True,  True,  True,  True,  True, False, False, False, False],
         [ True,  True,  True,  True,  True,  True,  True, False, False, False],
         [ True,  True,  True,  True,  True,  True,  True,  True, False, False],
         [ True,  True,  True,  True,  True,  True,  True,  True,  True, False],
         [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True]],

        [[ True,  True,  True,  True,  True, False, False, False, False, False],
         [ True,  True,  True,  True,  True, False, False, False, False, False],
         [ True,  True,  T

In [12]:
# Example padding mask for a single sequence of six positions
# (presumably 1 = real token, 0 = padding, mirroring `input_mask` in the quoted docstring).
# `torch.tensor` is the recommended constructor; float32 is pinned so the result
# matches what the legacy `torch.Tensor(...)` call produced.
pad_masks = torch.tensor([[1, 1, 1, 0, 0, 0]], dtype=torch.float32)
# Broadcasted product of a (B, 1, N) and a (B, N, 1) view: entry [b, i, j] is 1
# only when both position i and position j are non-zero in `pad_masks`.
pad_2d_masks = pad_masks[:, None, :] * pad_masks[:, :, None]
pad_2d_masks

tensor([[[1., 1., 1., 0., 0., 0.],
         [1., 1., 1., 0., 0., 0.],
         [1., 1., 1., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0.]]])

In [5]:
# Build a (sequence_length, target_length) tensor in which every entry is the
# most negative finite value representable in the chosen dtype.
sequence_length = 3
target_length = 4
dtype = torch.bfloat16
min_dtype = torch.finfo(dtype).min
causal_mask = torch.full((sequence_length, target_length), min_dtype, dtype=dtype)
causal_mask

tensor([[-3.3895e+38, -3.3895e+38, -3.3895e+38, -3.3895e+38],
        [-3.3895e+38, -3.3895e+38, -3.3895e+38, -3.3895e+38],
        [-3.3895e+38, -3.3895e+38, -3.3895e+38, -3.3895e+38]],
       dtype=torch.bfloat16)

In [6]:
# Keep only the entries strictly above the main diagonal (diagonal=1); everything
# on or below the diagonal becomes 0, the rest keeps the large negative fill value.
# NOTE(review): `causel_mask` looks like a misspelling of `causal_mask`; the name is
# kept unchanged because other cells may reference it.
causel_mask = causal_mask.triu(diagonal=1)
causel_mask

tensor([[ 0.0000e+00, -3.3895e+38, -3.3895e+38, -3.3895e+38],
        [ 0.0000e+00,  0.0000e+00, -3.3895e+38, -3.3895e+38],
        [ 0.0000e+00,  0.0000e+00,  0.0000e+00, -3.3895e+38]],
       dtype=torch.bfloat16)