In [6]:
import torch

In [7]:
def _compression_make_causal_mask(
    input_ids_shape: torch.Size,
    dtype: torch.dtype,
    device: torch.device,
    past_key_values_length: int = 0,
    window_size=4,
):
    bsz, tgt_len = input_ids_shape
    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
    mask_cond = torch.arange(mask.size(-1), device=device)
    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
    mask = mask.to(dtype)
    if past_key_values_length > 0:
        mask = torch.cat(
            [
                torch.zeros(
                    tgt_len, past_key_values_length, dtype=dtype, device=device
                ),
                mask,
            ],
            dim=-1,
        )

    block_mask = (
        torch.arange(past_key_values_length + tgt_len, device=device).unsqueeze(0)
        // window_size
    )
    block_mask = block_mask != block_mask.T
    casual_block_mask = torch.logical_or(
        mask, block_mask[past_key_values_length : past_key_values_length + tgt_len, :]
    )
    mask = torch.where(casual_block_mask, torch.finfo(dtype).min, 0)
    mask = mask[None, None, :, :].expand(bsz, 1, -1, -1)
    return mask

In [15]:
def _speculation_make_causal_mask(
    input_ids_shape: torch.Size,
    dtype: torch.dtype,
    device: torch.device,
    past_seen_compress_token: int = 0,
    window_size: int = 4,
    key_value_length: int = 0,
    speculation_length: int = 0,
):
    bsz, tgt_len = input_ids_shape
    mask = torch.full((tgt_len, tgt_len), torch.finfo(dtype).min, device=device)
    mask_cond = torch.arange(mask.size(-1), device=device)
    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
    mask = mask.to(dtype)
    if past_seen_compress_token < key_value_length:
        # prefill, TODO: need to optimize
        block_mask = torch.arange(tgt_len, device=device).unsqueeze(0) // window_size
        block_mask = block_mask != block_mask.T
        casual_block_mask = torch.logical_or(mask, block_mask)
        mask = torch.where(casual_block_mask, torch.finfo(dtype).min, 0)
    if speculation_length > 0:
        speculaiton_mask = torch.zeros(
            (tgt_len, speculation_length), dtype=dtype, device=device
        )
        mask = torch.cat([speculaiton_mask, mask], dim=-1)
    key_value_length += past_seen_compress_token
    if key_value_length > 0:
        kv_mask = torch.full(
            (tgt_len, key_value_length), torch.finfo(dtype).min, device=device
        )
        q_position_id = torch.arange(
            past_seen_compress_token * window_size,
            past_seen_compress_token * window_size + tgt_len,
            device=device,
        )
        kv_position_id = torch.arange(
            window_size - 1,
            key_value_length * window_size,
            step=window_size,
            device=device,
        )
        kv_mask.masked_fill_(q_position_id.view(-1, 1) > kv_position_id, 0).to(dtype)
        mask = torch.cat([kv_mask, mask], dim=-1)
    mask = mask[None, None, :, :].expand(bsz, 1, -1, -1)
    return mask


# prefill: three query tokens with past_seen_compress_token (default 0) below
# key_value_length, so the block-local branch of the mask builder is taken.
_speculation_make_causal_mask(
    input_ids_shape=torch.Size([1, 3]),
    dtype=torch.float32,
    device="cpu",
    window_size=2,
    key_value_length=1,
)

# decode: a single query token with past_seen_compress_token=1 and no
# additional key/value or speculation columns.
_speculation_make_causal_mask(
    input_ids_shape=torch.Size([1, 1]),
    dtype=torch.float32,
    device="cpu",
    past_seen_compress_token=1,
    window_size=2,
)

# tree_decode: three query tokens, past_seen_compress_token=1 and one
# speculation column. As the last expression of the cell, this result is the
# one the notebook displays.
_speculation_make_causal_mask(
    input_ids_shape=torch.Size([1, 3]),
    dtype=torch.float32,
    device="cpu",
    past_seen_compress_token=1,
    window_size=2,
    speculation_length=1,
)

tensor([[[[ 0.0000e+00,  0.0000e+00,  0.0000e+00, -3.4028e+38, -3.4028e+38],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00, -3.4028e+38],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00]]]])