In [1]:
import torch

In [45]:
import torch.nn.functional as F

In [21]:
# Implemented by https://github.com/BreaGG/Attention_Is_All_You_Need_From_Scratch/blob/main/transformer_model.py

# Padding Mask (prevents attention to padding tokens)
def create_padding_mask(seq, pad_token_id=0):
    """Build a boolean mask that flags padding positions in a tensor of token ids.

    Args:
        seq: Tensor of token ids. For a (batch_size, seq_len) input the
            result has shape (batch_size, 1, 1, seq_len).
        pad_token_id: Id that marks a padding position. Defaults to 0, the
            value that was previously hard-coded.

    Returns:
        Bool tensor that is True wherever ``seq == pad_token_id``, with
        singleton axes inserted at positions 1 and 2.
    """
    is_pad = seq == pad_token_id
    # Insert two singleton axes: (batch_size, seq_len) -> (batch_size, 1, 1, seq_len)
    return is_pad.unsqueeze(1).unsqueeze(2)

# Look-Ahead Mask (prevents attention to future tokens)
def create_look_ahead_mask(size):
    """Return a (size, size) uint8 matrix with ones strictly above the diagonal.

    Entry [i, j] is 1 when j > i and 0 otherwise, so a 1 marks a column that
    lies after the row index (a "future" position).
    """
    all_ones = torch.ones(size, size)
    # diagonal=1 keeps only the entries above the main diagonal.
    strictly_upper = all_ones.triu(diagonal=1)
    return strictly_upper.to(torch.uint8)  # Shape (size, size)

In [23]:
# Look-ahead mask for a length-10 sequence; a 1 marks a position that the
# softmax cells below fill with a large negative value.
mask = create_look_ahead_mask(size=10)
mask

tensor([[0, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        [0, 0, 1, 1, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 1, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 0, 1, 1, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 1, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 0, 1, 1, 1, 1],
        [0, 0, 0, 0, 0, 0, 0, 1, 1, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 1, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], dtype=torch.uint8)

In [40]:
# Lower-triangular 0/1 matrix (diagonal included): 1 marks a position that is
# kept, i.e. the opposite convention to the look-ahead mask.
mask_test = torch.tril(torch.ones(10, 10)).to(torch.uint8)
mask_test

tensor([[1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], dtype=torch.uint8)

In [29]:
# Transposing the look-ahead mask gives the strictly lower triangle (zeros on
# the diagonal), which differs from the lower-triangular mask_test.
mask.t()

tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 0, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 0, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 0, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 0, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 0, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 0, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 0, 0],
        [1, 1, 1, 1, 1, 1, 1, 1, 1, 0]], dtype=torch.uint8)

In [24]:
# Random stand-in for attention scores: 3 matrices of shape (10, 10).
# A dedicated seeded generator makes the draw reproducible on every re-run
# without altering the global RNG state. The outputs stored in this notebook
# came from an earlier unseeded draw, so the printed values will change once
# the cells are re-executed.
scores = torch.randn((3, 10, 10), generator=torch.Generator().manual_seed(0))

In [42]:
# Inspect the second (10, 10) score matrix, the one used in the softmax cells.
scores[1]

tensor([[ 0.7658,  0.7277, -0.0079,  0.8647, -1.1751,  1.2150, -1.8414, -1.3888,
          0.0637, -0.7425],
        [-0.7507, -0.2582, -1.7429, -0.4782,  0.9306, -0.7092, -2.3939, -0.7612,
         -0.7418,  1.0319],
        [ 0.8907, -0.5042, -0.4075,  1.3381, -0.0196, -1.4255,  0.4109,  0.0600,
         -0.8547, -0.0805],
        [-1.2025, -1.3239, -3.1932, -1.6969, -0.5859, -0.9290,  1.4267,  2.8155,
         -0.6488, -0.8480],
        [-0.9908, -1.5148, -0.9942, -1.5934,  1.2252, -0.5018,  1.6984, -1.7580,
         -1.9725, -0.5985],
        [ 1.2378,  0.2936, -0.8252, -1.0608,  0.8540, -0.0779,  1.1706,  0.6345,
         -1.9783,  2.5372],
        [-0.2369,  0.1764, -0.7012, -0.5631,  0.9140, -1.5367,  0.8090,  0.3354,
         -0.9780,  1.3212],
        [-0.8529,  0.1341,  1.2273, -0.2964,  0.8344,  0.0899, -0.9421,  0.4613,
         -2.2921, -0.4909],
        [-0.5026,  1.3091, -0.2799,  0.0561,  0.6623,  0.8269,  0.3279,  1.3837,
         -1.0062, -0.4322],
        [-1.2741, -

In [47]:
# Variant 1: overwrite the flagged (future) positions with -inf before the
# softmax, so exp(-inf) = 0 and those positions receive zero weight.
F.softmax(
    scores[1].masked_fill(mask == 1, float("-inf")),
    dim=-1,
)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.3793, 0.6207, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.6575, 0.1630, 0.1795, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.3799, 0.3365, 0.0519, 0.2317, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0813, 0.0481, 0.0810, 0.0445, 0.7452, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.3897, 0.1516, 0.0495, 0.0391, 0.2655, 0.1046, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0986, 0.1491, 0.0620, 0.0712, 0.3117, 0.0269, 0.2806, 0.0000, 0.0000,
         0.0000],
        [0.0384, 0.1030, 0.3074, 0.0670, 0.2075, 0.0986, 0.0351, 0.1429, 0.0000,
         0.0000],
        [0.0376, 0.2301, 0.0470, 0.0657, 0.1205, 0.1421, 0.0863, 0.2480, 0.0227,
         0.0000],
        [0.0267, 0.0232, 0.2043, 0.0860, 0.2545, 0.0300, 0.0224, 0.1200, 0.0715,
         0.1613]])

In [61]:
# Variant 2: use a large finite negative (-1e9) instead of -inf. For these
# scores the printed weights match the -inf variant to four decimals.
F.softmax(
    scores[1].masked_fill(mask == 1, -1e9),
    dim=-1,
)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.3793, 0.6207, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.6575, 0.1630, 0.1795, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.3799, 0.3365, 0.0519, 0.2317, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0813, 0.0481, 0.0810, 0.0445, 0.7452, 0.0000, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.3897, 0.1516, 0.0495, 0.0391, 0.2655, 0.1046, 0.0000, 0.0000, 0.0000,
         0.0000],
        [0.0986, 0.1491, 0.0620, 0.0712, 0.3117, 0.0269, 0.2806, 0.0000, 0.0000,
         0.0000],
        [0.0384, 0.1030, 0.3074, 0.0670, 0.2075, 0.0986, 0.0351, 0.1429, 0.0000,
         0.0000],
        [0.0376, 0.2301, 0.0470, 0.0657, 0.1205, 0.1421, 0.0863, 0.2480, 0.0227,
         0.0000],
        [0.0267, 0.0232, 0.2043, 0.0860, 0.2545, 0.0300, 0.0224, 0.1200, 0.0715,
         0.1613]])

In [60]:
# Contrast case: multiplying by the 0/1 mask_test sets future scores to 0
# rather than removing them. A score of 0 still contributes exp(0) = 1 to the
# softmax, so future positions keep a nonzero weight (every row but the last
# in the output).
F.softmax(scores[1] * mask_test, dim=-1)

tensor([[0.1929, 0.0897, 0.0897, 0.0897, 0.0897, 0.0897, 0.0897, 0.0897, 0.0897,
         0.0897],
        [0.0511, 0.0836, 0.1082, 0.1082, 0.1082, 0.1082, 0.1082, 0.1082, 0.1082,
         0.1082],
        [0.2276, 0.0564, 0.0621, 0.0934, 0.0934, 0.0934, 0.0934, 0.0934, 0.0934,
         0.0934],
        [0.0442, 0.0392, 0.0060, 0.0270, 0.1473, 0.1473, 0.1473, 0.1473, 0.1473,
         0.1473],
        [0.0388, 0.0230, 0.0387, 0.0212, 0.3558, 0.1045, 0.1045, 0.1045, 0.1045,
         0.1045],
        [0.2684, 0.1044, 0.0341, 0.0269, 0.1828, 0.0720, 0.0778, 0.0778, 0.0778,
         0.0778],
        [0.0717, 0.1084, 0.0451, 0.0518, 0.2267, 0.0195, 0.2041, 0.0909, 0.0909,
         0.0909],
        [0.0325, 0.0873, 0.2605, 0.0568, 0.1758, 0.0835, 0.0298, 0.1211, 0.0763,
         0.0763],
        [0.0354, 0.2167, 0.0442, 0.0619, 0.1135, 0.1338, 0.0812, 0.2334, 0.0214,
         0.0585],
        [0.0267, 0.0232, 0.2043, 0.0860, 0.2545, 0.0300, 0.0224, 0.1200, 0.0715,
         0.1613]])