# masks

> Boolean padding and look-ahead masks for transformer attention.

In [None]:
#| default_exp sketch_transformer.masks

In [None]:
#| hide
%load_ext autoreload
%autoreload 2

In [None]:
#| hide
from nbdev.showdoc import *
from fastcore.test import *

In [None]:
#| export
import torch

For attention masking, PyTorch's `nn.MultiheadAttention` accepts either float or boolean masks.

Float masks had a bug that sometimes produced `NaN` values:
- [regression - nn.MultiheadAttention does not respect adding of floating point mask to attention for the fast path · Issue #107084 · pytorch/pytorch](https://github.com/pytorch/pytorch/issues/107084)
- [TransformerEncoderLayer fast path predicts NaN when provided attention bias · Issue #118628 · pytorch/pytorch](https://github.com/pytorch/pytorch/issues/118628)
- [Disable nn.MHA fastpath for floating point masks by mikaylagawarecki · Pull Request #107641 · pytorch/pytorch](https://github.com/pytorch/pytorch/pull/107641)

So I'm using boolean masks instead. Note: when converting to boolean, PyTorch treats `0` as `False` and any nonzero value (such as `1`) as `True`.
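As a minimal sketch of that conversion, the snippet below casts a column of 0/1 flags to a boolean mask. The tensor `flags` is a made-up example, not data from this project.

```python
import torch

# Hypothetical padding flags: 1.0 marks a padded row, 0.0 a real one.
flags = torch.tensor([0.0, 1.0, 0.0, 1.0])

# .bool() maps zero to False and any nonzero value to True.
bool_mask = flags.bool()
print(bool_mask)
```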

From the docs for [MultiheadAttention.forward()](https://pytorch.org/docs/stable/generated/torch.nn.MultiheadAttention.html#torch.nn.MultiheadAttention.forward):
- `key_padding_mask` – If specified, a mask of shape (N,S) indicating which elements within key to ignore for the purpose of attention (i.e. treat as “padding”). For unbatched query, shape should be (S). Binary and float masks are supported. **For a binary mask, a `True` value indicates that the corresponding key value will be ignored for the purpose of attention.** For a float mask, it will be directly added to the corresponding key value.
- `attn_mask` – If specified, a 2D or 3D mask preventing attention to certain positions. Must be of shape (L,S) or (N⋅num_heads,L,S), where N is the batch size, L is the target sequence length, and S is the source sequence length. A 2D mask will be broadcasted across the batch while a 3D mask allows for a different mask for each entry in the batch. Binary and float masks are supported. **For a binary mask, a True value indicates that the corresponding position is not allowed to attend.** For a float mask, the mask values will be added to the attention weight. If both attn_mask and key_padding_mask are supplied, their types should match.
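To make the two mask arguments concrete, here is a small sketch that passes boolean masks of the documented shapes to `nn.MultiheadAttention`. The sizes, the choice of `batch_first=True`, and the decision to mark only the last position as padding are illustrative assumptions, not settings taken from this project.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Illustrative sizes only.
batch_size, seq_len, embed_dim, num_heads = 2, 4, 8, 2
mha = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

x = torch.randn(batch_size, seq_len, embed_dim)

# key_padding_mask has shape (N, S); True marks keys to ignore.
# Here the last position of every sequence is treated as padding.
key_padding_mask = torch.zeros(batch_size, seq_len, dtype=torch.bool)
key_padding_mask[:, -1] = True

# attn_mask has shape (L, S); True marks positions that may not be attended to.
# Here every position j > i is blocked for query position i.
attn_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()

out, weights = mha(x, x, x, key_padding_mask=key_padding_mask, attn_mask=attn_mask)

print(out.shape)      # (batch_size, seq_len, embed_dim)
print(weights.shape)  # (batch_size, seq_len, seq_len), averaged over heads
```

Both masks are boolean, which satisfies the requirement that their types match. No query row is fully masked in this setup, so every row of attention weights still has at least one key to attend to.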


In [None]:
#| export
def create_padding_mask(seq):
    """
    In seq, the 5th entry in the last dimension is the padding column, which will
    be 1 if the row is padding.
    
    Convert to a boolean tensor, indicating 'True' for entries that are padding and should be ignored.

    :param seq: (batch_size, seq_len, 5)
    :return: (batch_size, seq_len)
    """
    return seq[..., -1].bool()


def create_lookahead_mask(seq_len):
    """
    Create an attention mask, with rows representing target position and columns representing source position.

    For row=i, column=j, mask[i][j] is 'True' if the decoder must ignore position j when processing position i.

    An upper triangular matrix (excluding the main diagonal) has 'True' for every j > i.
    
    :param seq_len: sequence length
    :return: (seq_len, seq_len)
    """
    return torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()


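def create_masks(input_seq, target_seq, device='cuda'):
    """
    Build the boolean padding and look-ahead masks and move them to `device`.

    :param input_seq: (batch_size, input_seq_len, 5)
    :param target_seq: (batch_size, target_seq_len, 5)
    :param device: device the returned masks are moved to
    :return: enc_padding_mask, dec_padding_mask, dec_target_padding_mask, look_ahead_mask
    """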
    enc_padding_mask = create_padding_mask(input_seq)

    # Used in the 2nd attention block in the decoder.
    # This padding mask is used to mask the encoder outputs.
    dec_padding_mask = create_padding_mask(input_seq)

    # Used in the 1st attention block in the decoder.
    # It is used to pad and mask future tokens in the input received by
    # the decoder.
    look_ahead_mask = create_lookahead_mask(target_seq.shape[1])
    dec_target_padding_mask = create_padding_mask(target_seq)

    # NOTE: torch nn.MHA takes separate padding & attn masks w/ different shapes,
    #       so use that instead of combining here
    return (
        enc_padding_mask.to(device),
        dec_padding_mask.to(device),
        dec_target_padding_mask.to(device),
        look_ahead_mask.to(device),
    )


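def make_dummy_input(total_seq_len, nattn, batch_size):
    """
    Build an all-zero dummy batch: the first `nattn` rows are real, the rest are flagged as padding.

    :return: (batch_size, total_seq_len, 5)
    """
    nignore = total_seq_len - nattn
    return torch.cat([
        torch.ones(batch_size, nattn, 5) * torch.tensor([0., 0., 0., 0., 0.]),
        torch.ones(batch_size, nignore, 5) * torch.tensor([0., 0., 0., 0., 1.])
    ], dim=1)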
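The helpers above can be exercised with a small dummy batch. The sketch below repeats them in compact, equivalent form only so that it runs on its own; the sizes are arbitrary, and plain `assert` statements stand in for whatever test helpers the notebook would normally use.

```python
import torch

# Compact restatements of the helpers, repeated so the snippet is standalone.
def create_padding_mask(seq):
    return seq[..., -1].bool()

def create_lookahead_mask(seq_len):
    return torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()

def make_dummy_input(total_seq_len, nattn, batch_size):
    real = torch.zeros(batch_size, nattn, 5)
    pad = torch.zeros(batch_size, total_seq_len - nattn, 5)
    pad[..., -1] = 1.0  # set the padding flag in the last column
    return torch.cat([real, pad], dim=1)

# 5 positions per sequence, of which the first 3 are real.
seq = make_dummy_input(total_seq_len=5, nattn=3, batch_size=2)
padding_mask = create_padding_mask(seq)
lookahead_mask = create_lookahead_mask(4)

# Padded rows are True; real rows are False.
assert padding_mask[0].tolist() == [False, False, False, True, True]

# Position i may not attend to any position j > i.
assert lookahead_mask[0].tolist() == [False, True, True, True]
assert lookahead_mask[3].tolist() == [False, False, False, False]
```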

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()