In [1]:
import torch

def make_layer_mask(attention_mask: torch.Tensor) -> torch.Tensor:
    """Build a boolean mask that is False on the diagonal and True elsewhere.

    The mask is allocated on the same device as ``attention_mask`` and
    broadcast (as a non-copying ``expand`` view) to its shape, so it can be
    concatenated with ``attention_mask`` directly.

    Args:
        attention_mask: mask for the attention operation,
            [batch_size, seq_len, seq_len] (any rank >= 2 works; the last two
            dims may differ). Only its shape and device are read; its values
            are ignored.

    Returns:
        Bool tensor with the same shape as ``attention_mask``: False where the
        row and column index coincide (i == j), True everywhere else.
        NOTE(review): whether True means "may attend" or "blocked" is decided
        by the attention op that consumes the combined mask — confirm there.
    """
    *_, num_rows, num_cols = attention_mask.shape
    # Negated identity: off-diagonal entries are True. Allocating on
    # attention_mask.device avoids a CPU/GPU mismatch when this mask is later
    # concatenated with attention_mask.
    off_diagonal = ~torch.eye(
        num_rows, num_cols, dtype=torch.bool, device=attention_mask.device
    )
    # expand_as prepends the leading (batch) dims without copying memory.
    return off_diagonal.expand_as(attention_mask)

In [10]:

def make_combined_att_mask(
    attention_mask: torch.Tensor, layer_mask: torch.Tensor
) -> torch.Tensor:
    """Combine the attention mask and the layer mask into one 2x2 block mask.

    The result is laid out as::

        [[attention_mask, all-ones],
         [attention_mask, layer_mask]]

    where "all-ones" is ``torch.ones_like(layer_mask)`` (True for a bool
    ``layer_mask``).

    Args:
        attention_mask: mask for attention operation, [batch_size, seq_len, seq_len]
        layer_mask: mask for other layers, [batch_size, seq_len, seq_len]

    Returns:
        tensor: [batch_size, seq_len * 2, seq_len * 2]
    """
    # NOTE(review): the top-right block is all ones, so the first seq_len rows
    # are unrestricted over the second seq_len columns — confirm this is intended.
    # Upper half of the rows: [batch_size, seq_len, seq_len * 2]
    top_rows = torch.cat((attention_mask, torch.ones_like(layer_mask)), dim=-1)
    # Lower half of the rows: [batch_size, seq_len, seq_len * 2]
    bottom_rows = torch.cat((attention_mask, layer_mask), dim=-1)
    # Stack the halves along the row axis: [batch_size, seq_len * 2, seq_len * 2]
    return torch.cat((top_rows, bottom_rows), dim=1)

In [5]:
# Causal-style boolean mask: True on and below the diagonal, False above it.
# NOTE(review): True is assumed to mean "may attend" — confirm against the
# attention op that consumes this mask.
BATCH_SIZE, SEQ_LEN = 2, 5
attention_mask = torch.tril(
    torch.ones((BATCH_SIZE, SEQ_LEN, SEQ_LEN), dtype=torch.bool, device="cpu")
)

attention_mask

tensor([[[ True, False, False, False, False],
         [ True,  True, False, False, False],
         [ True,  True,  True, False, False],
         [ True,  True,  True,  True, False],
         [ True,  True,  True,  True,  True]],

        [[ True, False, False, False, False],
         [ True,  True, False, False, False],
         [ True,  True,  True, False, False],
         [ True,  True,  True,  True, False],
         [ True,  True,  True,  True,  True]]])

In [7]:
layer_mask = make_layer_mask(attention_mask)
# Sanity checks: same shape as the attention mask, and False only on the diagonal.
assert layer_mask.shape == attention_mask.shape
assert not layer_mask.diagonal(dim1=-2, dim2=-1).any()

layer_mask

tensor([[[False,  True,  True,  True,  True],
         [ True, False,  True,  True,  True],
         [ True,  True, False,  True,  True],
         [ True,  True,  True, False,  True],
         [ True,  True,  True,  True, False]],

        [[False,  True,  True,  True,  True],
         [ True, False,  True,  True,  True],
         [ True,  True, False,  True,  True],
         [ True,  True,  True, False,  True],
         [ True,  True,  True,  True, False]]])

In [12]:
combined_mask = make_combined_att_mask(attention_mask, layer_mask)
# Sanity checks: both trailing dims double, and the bottom-right block is the layer mask.
assert combined_mask.shape[-2:] == (2 * attention_mask.size(-2), 2 * attention_mask.size(-1))
assert torch.equal(
    combined_mask[:, attention_mask.size(-2):, attention_mask.size(-1):], layer_mask
)

combined_mask

tensor([[[ True, False, False, False, False,  True,  True,  True,  True,  True],
         [ True,  True, False, False, False,  True,  True,  True,  True,  True],
         [ True,  True,  True, False, False,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True, False,  True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True,  True,  True],
         [ True, False, False, False, False, False,  True,  True,  True,  True],
         [ True,  True, False, False, False,  True, False,  True,  True,  True],
         [ True,  True,  True, False, False,  True,  True, False,  True,  True],
         [ True,  True,  True,  True, False,  True,  True,  True, False,  True],
         [ True,  True,  True,  True,  True,  True,  True,  True,  True, False]],

        [[ True, False, False, False, False,  True,  True,  True,  True,  True],
         [ True,  True, False, False, False,  True,  True,  True,  True,  True],
         [ True,  True,  T