In [1]:
import numpy as np
import torch

In [2]:
def scaled_dot_product(Q, K, V, mask=None):
    d_k = Q.shape[-1] # dimension of the key/query vector
    K_transposed = K.transpose(0, 2, 1)
    scores = np.matmul(Q, K_transposed) / np.sqrt(d_k)

    if mask is not None:
        scores = np.where(mask == 0, -1e9, scores)

    # apply softmax
    attention_weights = np.exp(scores - np.max(scores, axis=1, keepdims=True))
    attention_weights = attention_weights / np.sum(attention_weights, axis=-1, keepdims=True)
    output = np.matmul(attention_weights, V)
    return output, attention_weights

In [3]:
embeddings = torch.tensor([[-1.0888, -0.2824,  0.4090,  0.6687],
        [ 0.9902,  0.1060, -0.1460,  0.9678],
        [ 1.1455,  0.5316, -0.0126,  0.1891],
        [-0.4129, -0.5683,  0.2142, -0.6977],
        [-1.2059,  0.6726, -0.5572,  0.1176],
        [ 0.7144, -2.4413,  0.4394,  1.0270],
        [-1.6094, -0.5242,  0.4525, -0.1211],
        [-0.1593, -0.2036,  0.9511, -0.4252],
        [ 0.7152, -0.6197,  0.7950,  1.1820],
        [-0.0077, -1.6696, -1.3080,  0.3963],
        [-0.8808,  0.1096,  1.1021,  2.3211],
        [-1.5365,  1.3074,  2.4265, -1.3851],
        [-0.3820,  2.0571,  0.5493, -0.5722],
        [ 1.8188,  0.0133, -0.3563, -0.2530],
        [ 1.5999,  0.5760,  0.3622, -1.4171],
        [ 1.0916, -1.0916, -0.5442, -0.4348],
        [ 1.9268, -0.1019, -0.1726, -0.9178],
        [ 0.2730,  0.1792, -0.5092, -0.1255],
        [-1.0621,  0.8124, -1.2130, -1.7982],
        [ 0.5199,  0.4318, -0.6762, -0.2020],
        [-1.1380, -0.7472,  0.6408,  0.0838],
        [-1.0337, -1.3030,  0.4298, -1.6550],
        [ 0.3106, -0.7349, -0.2298,  0.7607],
        [-0.3178, -0.0850, -0.1859,  1.2961],
        [ 0.5256,  1.7584,  0.6722,  1.9115],
        [ 0.9362,  1.6854, -0.3124, -0.2484],
        [ 1.7544,  0.4660, -0.1214,  0.8796],
        [ 0.0111,  2.6949, -0.9221, -0.2964],
        [ 0.3798, -0.4846, -0.6751, -0.8592],
        [ 0.0222, -0.0067,  0.9053,  0.8688],
        [ 1.1010,  0.1080,  0.7099,  1.1744],
        [ 0.9672, -1.9824,  1.5767,  0.4583],
        [ 0.9511,  1.5553,  1.3151, -0.3592],
        [-0.1451, -0.7237,  0.5969, -1.1076],
        [ 0.2883,  0.6982, -0.0380, -0.0168]])

In [4]:
position_embeddings = torch.tensor([[ 0.0000,  1.0000,  0.0000,  1.0000],
        [ 0.8415,  0.5403,  0.0100,  0.9999],
        [ 0.9093, -0.4161,  0.0200,  0.9998],
        [ 0.1411, -0.9900,  0.0300,  0.9996],
        [-0.7568, -0.6536,  0.0400,  0.9992],
        [-0.9589,  0.2837,  0.0500,  0.9988],
        [-0.2794,  0.9602,  0.0600,  0.9982],
        [ 0.6570,  0.7539,  0.0699,  0.9976],
        [ 0.9894, -0.1455,  0.0799,  0.9968],
        [ 0.4121, -0.9111,  0.0899,  0.9960],
        [-0.5440, -0.8391,  0.0998,  0.9950],
        [-1.0000,  0.0044,  0.1098,  0.9940],
        [-0.5366,  0.8439,  0.1197,  0.9928],
        [ 0.4202,  0.9074,  0.1296,  0.9916],
        [ 0.9906,  0.1367,  0.1395,  0.9902],
        [ 0.6503, -0.7597,  0.1494,  0.9888],
        [-0.2879, -0.9577,  0.1593,  0.9872],
        [-0.9614, -0.2752,  0.1692,  0.9856],
        [-0.7510,  0.6603,  0.1790,  0.9838],
        [ 0.1499,  0.9887,  0.1889,  0.9820],
        [ 0.9129,  0.4081,  0.1987,  0.9801],
        [ 0.8367, -0.5477,  0.2085,  0.9780],
        [-0.0089, -1.0000,  0.2182,  0.9759],
        [-0.8462, -0.5328,  0.2280,  0.9737],
        [-0.9056,  0.4242,  0.2377,  0.9713],
        [-0.1324,  0.9912,  0.2474,  0.9689],
        [ 0.7626,  0.6469,  0.2571,  0.9664],
        [ 0.9564, -0.2921,  0.2667,  0.9638],
        [ 0.2709, -0.9626,  0.2764,  0.9611],
        [-0.6636, -0.7481,  0.2860,  0.9582],
        [-0.9880,  0.1543,  0.2955,  0.9553],
        [-0.4040,  0.9147,  0.3051,  0.9523],
        [ 0.5514,  0.8342,  0.3146,  0.9492],
        [ 0.9999, -0.0133,  0.3240,  0.9460],
        [ 0.5291, -0.8486,  0.3335,  0.9428]])

In [5]:
combined_embeddings = embeddings + position_embeddings

In [15]:
combined_embeddings

tensor([[-1.0888,  0.7176,  0.4090,  1.6687],
        [ 1.8317,  0.6463, -0.1360,  1.9677],
        [ 2.0548,  0.1155,  0.0074,  1.1889],
        [-0.2718, -1.5583,  0.2442,  0.3019],
        [-1.9627,  0.0190, -0.5172,  1.1168],
        [-0.2445, -2.1576,  0.4894,  2.0258],
        [-1.8888,  0.4360,  0.5125,  0.8771],
        [ 0.4977,  0.5503,  1.0210,  0.5724],
        [ 1.7046, -0.7652,  0.8749,  2.1788],
        [ 0.4044, -2.5807, -1.2181,  1.3923],
        [-1.4248, -0.7295,  1.2019,  3.3161],
        [-2.5365,  1.3118,  2.5363, -0.3911],
        [-0.9186,  2.9010,  0.6690,  0.4206],
        [ 2.2390,  0.9207, -0.2267,  0.7386],
        [ 2.5905,  0.7127,  0.5017, -0.4269],
        [ 1.7419, -1.8513, -0.3948,  0.5540],
        [ 1.6389, -1.0596, -0.0133,  0.0694],
        [-0.6884, -0.0960, -0.3400,  0.8601],
        [-1.8131,  1.4727, -1.0340, -0.8144],
        [ 0.6698,  1.4205, -0.4873,  0.7800],
        [-0.2251, -0.3391,  0.8395,  1.0639],
        [-0.1970, -1.8507,  0.6383

In [6]:
Q = combined_embeddings.unsqueeze(0).numpy()
K = combined_embeddings.unsqueeze(0).numpy()
V = combined_embeddings.unsqueeze(0).numpy()

In [7]:
Q

array([[[-1.0888    ,  0.7176    ,  0.409     ,  1.6687    ],
        [ 1.8317    ,  0.6463    , -0.13599999,  1.9677    ],
        [ 2.0548    ,  0.1155    ,  0.0074    ,  1.1889    ],
        [-0.27179998, -1.5583    ,  0.2442    ,  0.30189997],
        [-1.9626999 ,  0.01899999, -0.5172    ,  1.1168    ],
        [-0.24449998, -2.1576    ,  0.4894    ,  2.0258    ],
        [-1.8888    ,  0.436     ,  0.5125    ,  0.8771    ],
        [ 0.4977    ,  0.5503    ,  1.021     ,  0.57240003],
        [ 1.7046001 , -0.7652    ,  0.8749    ,  2.1788    ],
        [ 0.4044    , -2.5807    , -1.2181    ,  1.3923    ],
        [-1.4248    , -0.7295    ,  1.2019    ,  3.3161001 ],
        [-2.5365    ,  1.3118    ,  2.5363002 , -0.3911    ],
        [-0.91859996,  2.901     ,  0.669     ,  0.4206    ],
        [ 2.2389998 ,  0.9207    , -0.2267    ,  0.7386    ],
        [ 2.5904999 ,  0.7127    ,  0.5017    , -0.42689997],
        [ 1.7419    , -1.8513    , -0.3948    ,  0.554     ],
        

In [8]:
K

array([[[-1.0888    ,  0.7176    ,  0.409     ,  1.6687    ],
        [ 1.8317    ,  0.6463    , -0.13599999,  1.9677    ],
        [ 2.0548    ,  0.1155    ,  0.0074    ,  1.1889    ],
        [-0.27179998, -1.5583    ,  0.2442    ,  0.30189997],
        [-1.9626999 ,  0.01899999, -0.5172    ,  1.1168    ],
        [-0.24449998, -2.1576    ,  0.4894    ,  2.0258    ],
        [-1.8888    ,  0.436     ,  0.5125    ,  0.8771    ],
        [ 0.4977    ,  0.5503    ,  1.021     ,  0.57240003],
        [ 1.7046001 , -0.7652    ,  0.8749    ,  2.1788    ],
        [ 0.4044    , -2.5807    , -1.2181    ,  1.3923    ],
        [-1.4248    , -0.7295    ,  1.2019    ,  3.3161001 ],
        [-2.5365    ,  1.3118    ,  2.5363002 , -0.3911    ],
        [-0.91859996,  2.901     ,  0.669     ,  0.4206    ],
        [ 2.2389998 ,  0.9207    , -0.2267    ,  0.7386    ],
        [ 2.5904999 ,  0.7127    ,  0.5017    , -0.42689997],
        [ 1.7419    , -1.8513    , -0.3948    ,  0.554     ],
        

In [9]:
V

array([[[-1.0888    ,  0.7176    ,  0.409     ,  1.6687    ],
        [ 1.8317    ,  0.6463    , -0.13599999,  1.9677    ],
        [ 2.0548    ,  0.1155    ,  0.0074    ,  1.1889    ],
        [-0.27179998, -1.5583    ,  0.2442    ,  0.30189997],
        [-1.9626999 ,  0.01899999, -0.5172    ,  1.1168    ],
        [-0.24449998, -2.1576    ,  0.4894    ,  2.0258    ],
        [-1.8888    ,  0.436     ,  0.5125    ,  0.8771    ],
        [ 0.4977    ,  0.5503    ,  1.021     ,  0.57240003],
        [ 1.7046001 , -0.7652    ,  0.8749    ,  2.1788    ],
        [ 0.4044    , -2.5807    , -1.2181    ,  1.3923    ],
        [-1.4248    , -0.7295    ,  1.2019    ,  3.3161001 ],
        [-2.5365    ,  1.3118    ,  2.5363002 , -0.3911    ],
        [-0.91859996,  2.901     ,  0.669     ,  0.4206    ],
        [ 2.2389998 ,  0.9207    , -0.2267    ,  0.7386    ],
        [ 2.5904999 ,  0.7127    ,  0.5017    , -0.42689997],
        [ 1.7419    , -1.8513    , -0.3948    ,  0.554     ],
        

In [10]:
output, attention_weights = scaled_dot_product(Q, K, V)

In [11]:
output

array([[[-0.40483868,  0.11095251,  0.2696042 ,  1.0122806 ],
        [ 1.1211162 ,  0.22585152,  0.20820767,  0.99253434],
        [ 1.1494709 , -0.11064901,  0.2543017 ,  0.76504236],
        [ 0.21812643, -1.0103799 ,  0.3362688 ,  0.44373697],
        [-0.8966201 , -0.06305939,  0.01196374,  0.898434  ],
        [ 0.16425262, -1.1337931 ,  0.36351457,  1.008104  ],
        [-0.78413   ,  0.04194552,  0.19172506,  0.8753997 ],
        [ 0.51976854, -0.06619034,  0.4661802 ,  0.67317516],
        [ 0.8972922 , -0.5365232 ,  0.48026976,  1.0298588 ],
        [ 0.45933005, -1.4242238 , -0.04203663,  0.77141064],
        [-0.46203268, -0.4865973 ,  0.5434941 ,  1.5313141 ],
        [-1.2547109 ,  0.6436388 ,  0.86473614,  0.46626607],
        [-0.17788833,  1.5645952 ,  0.18443987,  0.6535972 ],
        [ 1.2906549 ,  0.36953446,  0.16597107,  0.7104645 ],
        [ 1.4328204 ,  0.14682475,  0.3310614 ,  0.37457493],
        [ 0.9129127 , -1.0840325 ,  0.14538339,  0.4775942 ],
        

In [12]:
output.shape

(1, 35, 4)

In [13]:
attention_weights

array([[[0.08250054, 0.00772808, 0.00627255, ..., 0.00284996,
         0.03678321, 0.06093527],
        [0.00998801, 0.08095191, 0.07939018, ..., 0.00872614,
         0.05301088, 0.1136087 ],
        [0.0048787 , 0.04777716, 0.07566947, ..., 0.00606615,
         0.10011036, 0.11447173],
        ...,
        [0.00921999, 0.02184267, 0.02523153, ..., 0.13945882,
         0.05544952, 0.05426313],
        [0.00466399, 0.00520075, 0.01632025, ..., 0.00217328,
         0.24101266, 0.0791427 ],
        [0.01304652, 0.01882047, 0.0315111 , ..., 0.0035912 ,
         0.13363755, 0.11458109]]], dtype=float32)

In [14]:
attention_weights.shape

(1, 35, 35)