This notebook compares full attention with ProbSparse attention. ProbSparse is the sparse attention technique used by the Informer model for long time-series problems, where it makes the attention computation cheaper.

## Full Attention


$$
\text{Full Attention} = \text{softmax} \bigg( \frac{Q \cdot K^T}{\sqrt{d_q}} \bigg) \cdot V
$$

Let's start by defining the query, key and value vectors.

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from math import sqrt

In [None]:
n_heads = 1 # Assume single head attention
batch_size = 1 # Assume single batch size
sequence_length = L_Q = L_K = L_V = 10 # Number of datapoints passed in parallel
d_model = 4 # Number of features per data point

d_k = d_model // n_heads
d_v = d_model // n_heads
d_q = d_model // n_heads

Q = torch.randn( (batch_size, L_Q, n_heads, d_q) )
K = torch.randn( (batch_size, L_K, n_heads, d_k) )
V = torch.randn( (batch_size, L_V, n_heads, d_v) )

In [None]:
Q.shape, K.shape, V.shape

(torch.Size([1, 10, 1, 4]),
 torch.Size([1, 10, 1, 4]),
 torch.Size([1, 10, 1, 4]))

In [None]:
scores = torch.einsum("blhe,bshe->bhls", Q, K)
scores.shape

torch.Size([1, 1, 10, 10])

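`torch.einsum` can express sums, products and rearrangements of tensor axes in a single call, and the original Informer code uses it here. The subscripts `"blhe,bshe->bhls"` multiply queries and keys along the shared feature axis `e` and sum it out, giving an `L_Q x L_K` score matrix for every batch and head.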
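As a quick check, the same scores can be obtained by moving the head axis in front of the sequence axis and using a batched `matmul`. This sketch uses fresh random tensors with an arbitrary seed and extra batch/head entries, not the notebook's values:

```python
import torch

torch.manual_seed(0)  # arbitrary seed, only for reproducibility
B, L, H, E = 2, 10, 3, 4  # batch, sequence length, heads, features per head
Q = torch.randn(B, L, H, E)
K = torch.randn(B, L, H, E)

# einsum: for every (b, h), score[l, s] = sum over e of Q[b, l, h, e] * K[b, s, h, e]
scores_einsum = torch.einsum("blhe,bshe->bhls", Q, K)

# Same computation with batched matmul: (b, h, l, e) @ (b, h, e, s) -> (b, h, l, s)
scores_matmul = Q.permute(0, 2, 1, 3) @ K.permute(0, 2, 3, 1)

print(scores_einsum.shape)  # torch.Size([2, 3, 10, 10])
print(torch.allclose(scores_einsum, scores_matmul, atol=1e-6))  # True
```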

In [None]:
scores

tensor([[[[ 0.9837,  1.0102, -1.4656, -1.8151, -0.5829,  0.0494,  0.7742,
           -1.3344,  1.6809,  0.0954],
          [-1.2752, -0.3817, -0.3368,  1.9013,  0.3895, -0.1814, -1.2206,
           -0.5959, -2.7997,  0.4464],
          [-1.3294, -1.3175,  1.4667,  1.1584,  1.3219,  0.5197,  0.1335,
            0.7488, -1.5054,  0.9336],
          [ 0.2817,  0.3830,  1.2102, -0.4990, -0.8077,  0.0072,  0.1867,
            1.2884,  1.6660, -0.7949],
          [ 0.1928,  0.0379,  2.0728,  1.9656, -1.4251, -0.9910, -1.8786,
            3.0749, -0.0796, -2.4746],
          [ 2.7199, -1.1410,  2.3034, -1.9675,  0.5378,  0.0364,  1.9027,
            3.4813,  4.1183, -1.5084],
          [-0.3402,  0.4161, -3.1723,  0.3304,  0.4489, -0.1852, -0.5669,
           -3.2813, -2.4455,  1.2121],
          [-4.0858, -0.1368,  2.9327,  0.5041,  0.4458,  1.8001,  0.7948,
           -0.0374, -0.7481,  2.5106],
          [ 2.5345,  1.6985,  0.1427, -0.0807, -3.0892, -1.6868, -1.6755,
            2.3248,  2


`scores`: element (i, j) of this matrix is the affinity (dot product) of query i with key j.

This matrix multiplication requires 4 multiplications and 3 additions (one dot product of length `d_q = 4`) for each of the 10 x 10 entries.

Hence, this matrix multiplication performs `O(L_Q * L_K)` multiplications for a fixed feature dimension.

This is quadratic in the number of samples in the time-series sequence.

`scores` can be computed with the equivalent matrix multiplication of the query and key vectors

In [None]:
torch.matmul(Q.squeeze(0, 2), K.squeeze(0, 2).T)

tensor([[ 0.9837,  1.0102, -1.4656, -1.8151, -0.5829,  0.0494,  0.7742, -1.3344,
          1.6809,  0.0954],
        [-1.2752, -0.3817, -0.3368,  1.9013,  0.3895, -0.1814, -1.2206, -0.5959,
         -2.7997,  0.4464],
        [-1.3294, -1.3175,  1.4667,  1.1584,  1.3219,  0.5197,  0.1335,  0.7488,
         -1.5054,  0.9336],
        [ 0.2817,  0.3830,  1.2102, -0.4990, -0.8077,  0.0072,  0.1867,  1.2884,
          1.6660, -0.7949],
        [ 0.1928,  0.0379,  2.0728,  1.9656, -1.4251, -0.9910, -1.8786,  3.0749,
         -0.0796, -2.4746],
        [ 2.7199, -1.1410,  2.3034, -1.9675,  0.5378,  0.0364,  1.9027,  3.4813,
          4.1183, -1.5084],
        [-0.3402,  0.4161, -3.1723,  0.3304,  0.4489, -0.1852, -0.5669, -3.2813,
         -2.4455,  1.2121],
        [-4.0858, -0.1368,  2.9327,  0.5041,  0.4458,  1.8001,  0.7948, -0.0374,
         -0.7481,  2.5106],
        [ 2.5345,  1.6985,  0.1427, -0.0807, -3.0892, -1.6868, -1.6755,  2.3248,
          2.6323, -3.8829],
        [-1.9989, -

`squeeze()` without arguments removes every dimension of size 1:

In [None]:
Q.shape, Q.squeeze().shape

(torch.Size([1, 10, 1, 4]), torch.Size([10, 4]))

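We scale the scores by $1/\sqrt{d_q}$.

Without this scaling, the dot products grow with the feature dimension and push the softmax into saturated regions where its gradients are vanishingly small; scaling keeps training stable.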
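A small numerical sketch of why the scaling matters. It assumes query and key components drawn independently from a standard normal distribution (an illustrative assumption; learned projections are not exactly like this) and a hypothetical larger dimension `d = 64`. The dot product of two such vectors has variance `d`, and multiplying by `1/sqrt(d)` brings it back to about 1:

```python
import torch
from math import sqrt

torch.manual_seed(0)  # arbitrary seed
d = 64       # hypothetical feature dimension, larger than d_q = 4 to make the effect visible
n = 50_000   # number of random (query, key) pairs

q = torch.randn(n, d)
k = torch.randn(n, d)
dots = (q * k).sum(dim=-1)  # one dot product per (query, key) pair
scaled = dots / sqrt(d)     # the 1 / sqrt(d) scaling used in attention

print(f"variance of raw dot products:    {dots.var().item():.1f}")    # close to d = 64
print(f"variance of scaled dot products: {scaled.var().item():.2f}")  # close to 1
```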

In [None]:
scale = 1/sqrt(d_q)
scale

0.5

In [None]:
scale * scores

tensor([[[[ 0.4918,  0.5051, -0.7328, -0.9076, -0.2914,  0.0247,  0.3871,
           -0.6672,  0.8404,  0.0477],
          [-0.6376, -0.1908, -0.1684,  0.9507,  0.1948, -0.0907, -0.6103,
           -0.2979, -1.3998,  0.2232],
          [-0.6647, -0.6587,  0.7333,  0.5792,  0.6609,  0.2599,  0.0668,
            0.3744, -0.7527,  0.4668],
          [ 0.1408,  0.1915,  0.6051, -0.2495, -0.4038,  0.0036,  0.0934,
            0.6442,  0.8330, -0.3974],
          [ 0.0964,  0.0190,  1.0364,  0.9828, -0.7126, -0.4955, -0.9393,
            1.5374, -0.0398, -1.2373],
          [ 1.3599, -0.5705,  1.1517, -0.9838,  0.2689,  0.0182,  0.9513,
            1.7407,  2.0592, -0.7542],
          [-0.1701,  0.2080, -1.5861,  0.1652,  0.2244, -0.0926, -0.2835,
           -1.6406, -1.2228,  0.6060],
          [-2.0429, -0.0684,  1.4664,  0.2521,  0.2229,  0.9000,  0.3974,
           -0.0187, -0.3741,  1.2553],
          [ 1.2672,  0.8493,  0.0714, -0.0403, -1.5446, -0.8434, -0.8377,
            1.1624,  1

Then we apply a softmax across the key dimension (the last axis).

As a result, each row sums to 1.

In [None]:
A = torch.softmax(scale * scores, dim=-1)
A

tensor([[[[0.1447, 0.1466, 0.0425, 0.0357, 0.0661, 0.0907, 0.1303, 0.0454,
           0.2051, 0.0928],
          [0.0545, 0.0852, 0.0871, 0.2668, 0.1253, 0.0942, 0.0560, 0.0765,
           0.0254, 0.1289],
          [0.0404, 0.0407, 0.1637, 0.1403, 0.1522, 0.1019, 0.0840, 0.1143,
           0.0370, 0.1254],
          [0.0912, 0.0960, 0.1451, 0.0617, 0.0529, 0.0795, 0.0870, 0.1509,
           0.1823, 0.0533],
          [0.0734, 0.0679, 0.1879, 0.1781, 0.0327, 0.0406, 0.0260, 0.3101,
           0.0640, 0.0193],
          [0.1447, 0.0210, 0.1175, 0.0139, 0.0486, 0.0378, 0.0962, 0.2117,
           0.2912, 0.0175],
          [0.0970, 0.1416, 0.0235, 0.1356, 0.1439, 0.1048, 0.0866, 0.0223,
           0.0339, 0.2108],
          [0.0076, 0.0547, 0.2540, 0.0754, 0.0733, 0.1442, 0.0872, 0.0575,
           0.0403, 0.2057],
          [0.2210, 0.1455, 0.0668, 0.0598, 0.0133, 0.0268, 0.0269, 0.1990,
           0.2321, 0.0089],
          [0.0351, 0.0762, 0.0402, 0.2683, 0.1646, 0.0989, 0.0528, 0.0272

In [None]:
A.shape

torch.Size([1, 1, 10, 10])


In other words, the affinities of each query i across all keys j now sum to 1.


In [None]:
A.sum(-1)

tensor([[[1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
          1.0000, 1.0000]]])
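For illustration, here is a hand-written row-wise softmax (the notebook itself uses `torch.softmax`; the helper name `softmax_rows` is our own). Subtracting the row maximum before exponentiating does not change the result but avoids overflow:

```python
import torch

def softmax_rows(x: torch.Tensor) -> torch.Tensor:
    """Softmax over the last axis (the key axis of an attention score matrix)."""
    shifted = x - x.max(dim=-1, keepdim=True).values  # for numerical stability
    exp = shifted.exp()
    return exp / exp.sum(dim=-1, keepdim=True)

torch.manual_seed(0)  # arbitrary seed
scores = torch.randn(1, 1, 10, 10)      # (batch, heads, L_Q, L_K), like `scores` above
A_manual = softmax_rows(0.5 * scores)   # 0.5 = 1 / sqrt(d_q) for d_q = 4

print(torch.allclose(A_manual, torch.softmax(0.5 * scores, dim=-1), atol=1e-6))  # True
print(A_manual.sum(dim=-1))  # every row sums to 1, up to float rounding
```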

`v` is the product of the attention matrix and the value matrix: each output row is a weighted average of the value vectors, weighted by the corresponding row of `A`.


In [None]:
v = torch.einsum("bhls,bshd->blhd", A, V)
v

tensor([[[[ 0.2121, -0.6374,  0.1187,  0.2179]],

         [[ 0.4360, -0.4998, -0.3921, -0.2926]],

         [[ 0.4121, -0.2484,  0.0373, -0.0361]],

         [[ 0.2643, -0.3929,  0.2992,  0.1209]],

         [[ 0.4075, -0.3508,  0.2866, -0.1352]],

         [[ 0.2616, -0.3893,  0.3912,  0.1565]],

         [[ 0.3146, -0.5593, -0.1781, -0.0128]],

         [[ 0.5706,  0.0008,  0.1774,  0.0079]],

         [[ 0.1821, -0.5857,  0.3143,  0.0102]],

         [[ 0.4354, -0.4904, -0.5414, -0.3290]]]])

This is equivalent to the following matrix multiplication:


In [None]:
torch.matmul(A.squeeze(),V.squeeze())

tensor([[ 0.2121, -0.6374,  0.1187,  0.2179],
        [ 0.4360, -0.4998, -0.3921, -0.2926],
        [ 0.4121, -0.2484,  0.0373, -0.0361],
        [ 0.2643, -0.3929,  0.2992,  0.1209],
        [ 0.4075, -0.3508,  0.2866, -0.1352],
        [ 0.2616, -0.3893,  0.3912,  0.1565],
        [ 0.3146, -0.5593, -0.1781, -0.0128],
        [ 0.5706,  0.0008,  0.1774,  0.0079],
        [ 0.1821, -0.5857,  0.3143,  0.0102],
        [ 0.4354, -0.4904, -0.5414, -0.3290]])

This product is also quadratic in the input sequence length: each of the 10 x 4 output entries needs 10 multiplications and 9 additions, i.e. `O(L_Q * L_K)` multiplications for a fixed feature dimension.

Overall, the two main matrix operations (multiplying the query and key matrices, then multiplying the attention matrix with the value matrix) are quadratic in the input sequence length in both time and space, since the full `L_Q x L_K` attention matrix is materialized.

This becomes a bottleneck for long sequences.
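Putting the steps of this section together, a compact full-attention function could look like the sketch below. It keeps the notebook's `(batch, length, heads, features)` layout; the function name `full_attention` is our own, not taken from the Informer code:

```python
import torch
from math import sqrt

def full_attention(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor) -> torch.Tensor:
    """Scaled dot-product attention over all query-key pairs.

    Q: (B, L_Q, H, d_q), K: (B, L_K, H, d_k), V: (B, L_V, H, d_v) with L_K == L_V.
    Returns a tensor of shape (B, L_Q, H, d_v).
    """
    scale = 1.0 / sqrt(Q.shape[-1])
    scores = torch.einsum("blhe,bshe->bhls", Q, K)  # (B, H, L_Q, L_K): quadratic in length
    A = torch.softmax(scale * scores, dim=-1)       # each row sums to 1
    return torch.einsum("bhls,bshd->blhd", A, V)    # weighted averages of value vectors

torch.manual_seed(0)  # arbitrary seed
Q, K, V = (torch.randn(1, 10, 1, 4) for _ in range(3))
out = full_attention(Q, K, V)
print(out.shape)  # torch.Size([1, 10, 1, 4])
```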

## ProbAttention


To deal with this, we can reduce the number of multiplications by letting only a subset of the queries, $\bar Q \subseteq Q$, attend to the keys:

$$
\text{ProbSparse Attention} = \text{softmax} \bigg( \frac{\bar Q \cdot K^T}{\sqrt{d_q}} \bigg) \cdot V
$$

In [None]:
n_heads = 1 # Assume single head attention
batch_size = 1 # Assume single batch size
sequence_length = L_Q = L_K = L_V = 10 # Number of datapoints passed in parallel
d_model = 4 # Number of features per data point

d_k = d_model // n_heads
d_v = d_model // n_heads
d_q = d_model // n_heads

Q = torch.randn( (batch_size, L_Q, n_heads, d_q) )
K = torch.randn( (batch_size, L_K, n_heads, d_k) )
V = torch.randn( (batch_size, L_V, n_heads, d_v) )

Q.shape, K.shape, V.shape

(torch.Size([1, 10, 1, 4]),
 torch.Size([1, 10, 1, 4]),
 torch.Size([1, 10, 1, 4]))

In [None]:
Q = Q.transpose(2, 1)
K = K.transpose(2, 1)
V = V.transpose(2, 1)
Q.shape, K.shape, V.shape

(torch.Size([1, 1, 10, 4]),
 torch.Size([1, 1, 10, 4]),
 torch.Size([1, 1, 10, 4]))

Let's determine the sizes of the sampled subsets used in ProbSparse attention.

- `L_Q_bar`: the number of query vectors, out of `L_Q`, that are selected to attend to the keys.
- `L_K_bar`: the number of key vectors sampled for each query. This sample is used only to decide which queries to select, in sub-quadratic time and space. The selected queries still attend to all keys.

In [None]:
factor = 2 # sampling factor c
L_K_bar = factor * np.ceil(np.log(L_K)).astype('int').item() # U_part = factor * ceil(ln(L_K))
L_Q_bar = factor * np.ceil(np.log(L_Q)).astype('int').item() # u = factor * ceil(ln(L_Q))
L_Q_bar, L_K_bar

(6, 6)

For very short sequences, `factor * ceil(ln L)` can exceed the sequence length. The sizes are then capped at `L_Q` and `L_K`, which amounts to full attention.

For long sequences, only a subset of the queries attends to all keys.

In [None]:
L_K_bar = L_K_bar if L_K_bar < L_K else L_K
L_Q_bar = L_Q_bar if L_Q_bar < L_Q else L_Q
L_Q_bar, L_K_bar

(6, 6)
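To see how slowly the sample size grows, this sketch applies the same formula, `factor * ceil(ln L)` capped at `L`, to a few sequence lengths. The helper name `sample_size` is hypothetical, and `factor = 2` as above:

```python
import numpy as np

def sample_size(L: int, factor: int = 2) -> int:
    """factor * ceil(ln L), capped at L (the cap amounts to full attention)."""
    n = factor * int(np.ceil(np.log(L)))
    return min(n, L)

for L in [4, 10, 100, 1_000, 10_000, 100_000]:
    print(f"L = {L:>7}: L_Q_bar = L_K_bar = {sample_size(L)}")
# prints 4, 6, 10, 14, 20, 24 for the six lengths
```

Going from 10 to 100,000 data points multiplies the sequence length by 10,000 but the sample size only by 4.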

From this point, we perform operations that will help us select the appropriate query vectors $\bar Q$.

`K_expand`: the same 10 x 4 key matrix, repeated once for every query.

This is the tensor from which we will later pick the sampled key vectors for each query. `expand` returns a view, so no data is copied.

In [None]:
B, H, L_K, E = K.shape
_, _, L_Q, _ = Q.shape
# unsqueeze adds a query axis; expand broadcasts it to L_Q copies (a view, no data copied)
K_expand = K.unsqueeze(-3).expand(B, H, L_Q, L_K, E)
K_expand

tensor([[[[[ 1.1493, -0.8073, -0.4246, -1.2295],
           [ 1.1640,  0.2684, -0.0041,  1.7503],
           [-0.8284, -0.2130, -1.0211, -0.2565],
           [-1.5922,  0.2786,  0.4121,  0.6504],
           [ 0.3582,  0.3248, -2.1272, -1.5119],
           [ 0.4669,  0.5410, -1.4750, -1.9943],
           [ 0.1041,  0.2502,  1.1637,  0.1986],
           [ 1.1851,  0.3053, -0.1947, -1.3754],
           [ 0.4741, -0.8833, -0.6060, -0.6919],
           [-0.0110,  1.5004, -0.1369,  0.9347]],

          [[ 1.1493, -0.8073, -0.4246, -1.2295],
           [ 1.1640,  0.2684, -0.0041,  1.7503],
           [-0.8284, -0.2130, -1.0211, -0.2565],
           [-1.5922,  0.2786,  0.4121,  0.6504],
           [ 0.3582,  0.3248, -2.1272, -1.5119],
           [ 0.4669,  0.5410, -1.4750, -1.9943],
           [ 0.1041,  0.2502,  1.1637,  0.1986],
           [ 1.1851,  0.3053, -0.1947, -1.3754],
           [ 0.4741, -0.8833, -0.6060, -0.6919],
           [-0.0110,  1.5004, -0.1369,  0.9347]],

          [[ 1.1

In [None]:
K_expand.shape

torch.Size([1, 1, 10, 10, 4])

For every query vector, we now select `L_K_bar` random key vectors.

`index_sample`: an `L_Q x L_K_bar` matrix of key indices, one row per query, drawn uniformly at random with replacement. Row i lists the keys that query i is scored against.

In [None]:
index_sample = torch.randint(L_K, (L_Q, L_K_bar))
index_sample

tensor([[4, 7, 2, 0, 8, 6],
        [6, 9, 9, 4, 6, 2],
        [8, 5, 9, 5, 9, 3],
        [9, 2, 1, 9, 7, 2],
        [9, 3, 4, 6, 8, 2],
        [5, 6, 3, 2, 5, 9],
        [6, 4, 7, 1, 1, 0],
        [3, 0, 3, 0, 2, 7],
        [6, 2, 2, 3, 9, 6],
        [4, 7, 9, 2, 7, 8]])

In [None]:
index_sample.shape

torch.Size([10, 6])

Using `index_sample`, we now keep only the sampled key vectors for each query.

`K_sample`: Matrix where for each query, we select a subset of `L_K_bar` key vectors.

In [None]:
K_sample = K_expand[:, :, torch.arange(L_Q).unsqueeze(1), index_sample, :]
K_sample.shape

torch.Size([1, 1, 10, 6, 4])
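Because `expand` creates a view, `K_expand` costs no extra memory. The same `K_sample` can also be obtained by indexing the key axis of `K` directly with the `L_Q x L_K_bar` index matrix, which is a handy way to check the fancy indexing above. This sketch uses fresh random tensors, not the notebook's values:

```python
import torch

torch.manual_seed(0)  # arbitrary seed
B, H, L_Q, L_K, E, L_K_bar = 1, 1, 10, 10, 4, 6
K = torch.randn(B, H, L_K, E)

# One row of L_K_bar random key indices (with replacement) per query
index_sample = torch.randint(L_K, (L_Q, L_K_bar))

# Notebook approach: one key matrix per query, then pick keys row by row
K_expand = K.unsqueeze(-3).expand(B, H, L_Q, L_K, E)
K_sample = K_expand[:, :, torch.arange(L_Q).unsqueeze(1), index_sample, :]

# Direct approach: index the key axis of K with the (L_Q, L_K_bar) matrix
K_sample_direct = K[:, :, index_sample, :]

print(K_sample.shape, K_sample_direct.shape)   # both torch.Size([1, 1, 10, 6, 4])
print(torch.equal(K_sample, K_sample_direct))  # True
print(K_expand.data_ptr() == K.data_ptr())     # True: expand shares K's storage
```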

In [None]:
K_sample

tensor([[[[[ 0.3582,  0.3248, -2.1272, -1.5119],
           [ 1.1851,  0.3053, -0.1947, -1.3754],
           [-0.8284, -0.2130, -1.0211, -0.2565],
           [ 1.1493, -0.8073, -0.4246, -1.2295],
           [ 0.4741, -0.8833, -0.6060, -0.6919],
           [ 0.1041,  0.2502,  1.1637,  0.1986]],

          [[ 0.1041,  0.2502,  1.1637,  0.1986],
           [-0.0110,  1.5004, -0.1369,  0.9347],
           [-0.0110,  1.5004, -0.1369,  0.9347],
           [ 0.3582,  0.3248, -2.1272, -1.5119],
           [ 0.1041,  0.2502,  1.1637,  0.1986],
           [-0.8284, -0.2130, -1.0211, -0.2565]],

          [[ 0.4741, -0.8833, -0.6060, -0.6919],
           [ 0.4669,  0.5410, -1.4750, -1.9943],
           [-0.0110,  1.5004, -0.1369,  0.9347],
           [ 0.4669,  0.5410, -1.4750, -1.9943],
           [-0.0110,  1.5004, -0.1369,  0.9347],
           [-1.5922,  0.2786,  0.4121,  0.6504]],

          [[-0.0110,  1.5004, -0.1369,  0.9347],
           [-0.8284, -0.2130, -1.0211, -0.2565],
           [ 1

`Q_K_sample`: For each query, we now determine an affinity score with each of the selected key vectors.

In [None]:
Q_K_sample = torch.matmul(Q.unsqueeze(-2), K_sample.transpose(-2, -1)).squeeze(-2)
Q_K_sample.shape

torch.Size([1, 1, 10, 6])

In detail, each 1 x 4 query vector is multiplied by its own 4 x 6 matrix of sampled keys, giving a 1 x 6 vector of query-key affinities.

In [None]:
Q.unsqueeze(-2).shape, K_sample.transpose(-2, -1).shape

(torch.Size([1, 1, 10, 1, 4]), torch.Size([1, 1, 10, 4, 6]))

In [None]:
Q_K_sample

tensor([[[[-0.1436,  1.3637, -0.7733,  1.7698,  0.9381,  0.4367],
          [-2.3545,  1.2845,  1.2845,  3.7222, -2.3545,  2.6977],
          [-0.0084,  0.4724,  0.7017,  0.4724,  0.7017, -1.4189],
          [-1.3828, -0.1384, -3.6623, -1.3828,  1.2203, -0.1384],
          [ 1.5805,  2.8538,  0.6118, -1.8568, -0.7366,  2.6788],
          [-1.1198,  0.8741, -1.7703, -1.8456, -1.1198,  0.1116],
          [ 2.8242, -4.4661,  0.3125, -1.4249, -1.4249, -0.2983],
          [-0.8213,  2.5528, -0.8213,  2.5528,  1.4454,  1.9796],
          [ 0.0983,  0.6171,  0.6171,  1.9412, -1.6207,  0.0983],
          [-1.8467, -3.0208,  2.9274,  0.6800, -3.0208, -2.4974]]]])

How we get this shape:
- `Q.unsqueeze(-2)` inserts a size-1 axis before the last one: 1 x 1 x 10 x 1 x 4
- `K_sample.transpose(-2, -1)`: 1 x 1 x 10 x 4 x 6
- `matmul` multiplies the last two dimensions, batched over the others: 1 x 1 x 10 x 1 x 6
- `squeeze(-2)` removes the size-1 axis: 1 x 1 x 10 x 6
- This operation takes `O(L_Q.L_K_bar)` time.
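The `unsqueeze`/`matmul`/`squeeze` combination is equivalent to a single `einsum` in which each query `l` is dotted only with its own sampled keys `s`. A sketch with random tensors of the notebook's shapes:

```python
import torch

torch.manual_seed(0)  # arbitrary seed
B, H, L_Q, L_K_bar, E = 1, 1, 10, 6, 4
Q = torch.randn(B, H, L_Q, E)
K_sample = torch.randn(B, H, L_Q, L_K_bar, E)  # stands in for the sampled keys

# Notebook version: (..., 1, E) @ (..., E, L_K_bar) -> (..., 1, L_K_bar), then drop the 1
Q_K_sample = torch.matmul(Q.unsqueeze(-2), K_sample.transpose(-2, -1)).squeeze(-2)

# einsum version: the query index l is shared with the sampled keys, e is summed out
Q_K_sample_einsum = torch.einsum("bhle,bhlse->bhls", Q, K_sample)

print(Q_K_sample.shape)  # torch.Size([1, 1, 10, 6])
print(torch.allclose(Q_K_sample, Q_K_sample_einsum, atol=1e-6))  # True
```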

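`M`: the sparsity measurement of each query i, i.e. the gap between its largest sampled affinity and its average affinity. With $S_i$ the set of `L_K_bar` key indices sampled for query i (row i of `index_sample`):

$$ M(q_i) = \max_{j \in S_i} q_i k_j^T - \frac{1}{L_K} \sum_{j \in S_i} q_i k_j^T $$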

In [None]:
M = Q_K_sample.max(-1)[0] - torch.div(Q_K_sample.sum(-1), L_K)
M

tensor([[[1.4107, 3.2942, 0.6096, 1.7688, 2.3407, 1.3611, 3.2719, 1.8640,
          1.7660, 3.6053]]])

In [None]:
1.7698 - 0.3591  # query 0: max sampled score minus (sum of sampled scores) / L_K

1.4107

In [None]:
M.shape

torch.Size([1, 1, 10])

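More on the operation:
- The goal is to find the top-k queries according to this sparsity measurement.
- `Q_K_sample.max(-1)[0]`: the largest sampled affinity for every query.
- `Q_K_sample.sum(-1) / L_K`: the sum of the sampled affinities divided by `L_K`. Note that the code divides by `L_K`, not `L_K_bar`, so this is not exactly the mean of the sampled scores.
- Each of these steps involves `O(L_Q.L_K_bar)` operations.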
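The measurement can be checked by hand on the first row of `Q_K_sample` printed above. This sketch (the helper name `sparsity_measure` is hypothetical) mirrors the notebook's line, including the division by `L_K` rather than `L_K_bar`:

```python
import torch

def sparsity_measure(Q_K_sample: torch.Tensor, L_K: int) -> torch.Tensor:
    """Max sampled score minus (sum of sampled scores) / L_K, per query."""
    return Q_K_sample.max(dim=-1).values - Q_K_sample.sum(dim=-1) / L_K

# First row of Q_K_sample from the output above (6 sampled keys, L_K = 10)
row0 = torch.tensor([[-0.1436, 1.3637, -0.7733, 1.7698, 0.9381, 0.4367]])
M0 = sparsity_measure(row0, L_K=10)
print(M0)  # tensor([1.4107]), i.e. 1.7698 - 3.5914 / 10
```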

`M_top`: indices of the `L_Q_bar` queries with the largest affinity spreads `M`.

In [None]:
M.topk(L_Q_bar, sorted=False)

torch.return_types.topk(
values=tensor([[[3.2942, 3.6053, 2.3407, 3.2719, 1.8640, 1.7688]]]),
indices=tensor([[[1, 9, 4, 6, 7, 3]]]))

In [None]:
M_top = M.topk(L_Q_bar, sorted=False)[1]
M_top.shape

torch.Size([1, 1, 6])

In [None]:
M_top

tensor([[[1, 9, 4, 6, 7, 3]]])

`Q_bar`: the `L_Q_bar` query vectors selected by `M_top`.

In [None]:
Q_bar = Q[
    torch.arange(B)[:, None, None],
    torch.arange(H)[None, :, None],
    M_top,
    :
]
Q_bar

tensor([[[[-0.8949,  0.4277, -2.0982,  0.3700],
          [-1.4002,  1.1498, -0.0857,  1.2573],
          [-1.6022, -0.1021, -1.7010,  1.5870],
          [-0.1238, -0.0730,  2.5755, -0.7146],
          [-0.5048, -0.6817, -0.3681, -1.9735],
          [-0.7324,  0.1899,  1.1001, -1.6319]]]])

In [None]:
Q_bar.size()

torch.Size([1, 1, 6, 4])
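The indexing with broadcast `arange` tensors picks, for every batch and head, the query rows listed in `M_top`. As a sketch, the same selection can be written with `torch.gather`. The tensors here are random, and `M_top` is simply the index set from the output above:

```python
import torch

torch.manual_seed(0)  # arbitrary seed
B, H, L_Q, E = 1, 1, 10, 4
Q = torch.randn(B, H, L_Q, E)
M_top = torch.tensor([[[1, 9, 4, 6, 7, 3]]])  # (B, H, L_Q_bar) query indices

# Notebook version: broadcast batch/head indices against M_top
Q_bar = Q[torch.arange(B)[:, None, None], torch.arange(H)[None, :, None], M_top, :]

# gather version: repeat each index across the feature axis, gather along the query axis
Q_bar_gather = torch.gather(Q, 2, M_top.unsqueeze(-1).expand(-1, -1, -1, E))

print(Q_bar.shape)  # torch.Size([1, 1, 6, 4])
print(torch.equal(Q_bar, Q_bar_gather))  # True
```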

`Q_bar_K`: Each entry is an affinity score between the subset of queries (that have the highest affinity spreads) and all keys.

In [None]:
Q_bar_K = torch.matmul(Q_bar, K.transpose(-2, -1))
Q_bar_K

tensor([[[[-0.9378, -0.2707,  2.6977,  0.9200,  3.7222,  2.1705, -2.3545,
           -1.0304,  0.2136,  1.2845],
          [-4.0469,  0.8800,  0.6800,  3.3320, -1.8467, -2.4127,  0.2918,
           -3.0208, -2.4974,  2.9274],
          [-2.9880,  0.8924,  2.6788,  2.8538,  0.6118, -1.4594, -1.8568,
           -3.7816, -0.7366,  1.5805],
          [-0.2983, -1.4249, -2.3284,  0.7732, -4.4661, -2.4709,  2.8242,
            0.3125, -1.0606, -1.1285],
          [ 2.5528, -4.2234,  1.4454, -0.8213,  3.3646,  3.8742, -1.0433,
            1.9796,  1.9514, -2.8115],
          [ 0.5441, -3.6623, -0.1384,  0.6110, -0.0735,  1.3926,  0.9275,
            1.2203, -0.0525, -1.3828]]]])

In [None]:
Q_bar.shape, K.shape, Q_bar_K.shape

(torch.Size([1, 1, 6, 4]),
 torch.Size([1, 1, 10, 4]),
 torch.Size([1, 1, 6, 10]))

This matrix multiplication needs 4 multiplications and 3 additions for each of the 6 x 10 entries.

Hence it performs `O(L_Q_bar.L_K)` operations.

Scale by multiplying with $\frac{1}{\sqrt{d_q}}$:

In [None]:
Q_bar_K = 1./sqrt(d_q) * Q_bar_K
Q_bar_K

tensor([[[[-0.4689, -0.1353,  1.3489,  0.4600,  1.8611,  1.0853, -1.1772,
           -0.5152,  0.1068,  0.6422],
          [-2.0235,  0.4400,  0.3400,  1.6660, -0.9234, -1.2063,  0.1459,
           -1.5104, -1.2487,  1.4637],
          [-1.4940,  0.4462,  1.3394,  1.4269,  0.3059, -0.7297, -0.9284,
           -1.8908, -0.3683,  0.7902],
          [-0.1491, -0.7124, -1.1642,  0.3866, -2.2331, -1.2355,  1.4121,
            0.1563, -0.5303, -0.5643],
          [ 1.2764, -2.1117,  0.7227, -0.4107,  1.6823,  1.9371, -0.5217,
            0.9898,  0.9757, -1.4057],
          [ 0.2721, -1.8312, -0.0692,  0.3055, -0.0367,  0.6963,  0.4637,
            0.6102, -0.0263, -0.6914]]]])

We now apply the softmax across the key dimension to turn these scores into the attention matrix `attn`.

In [None]:
attn = torch.softmax(Q_bar_K, dim=-1)
attn

tensor([[[[0.0309, 0.0431, 0.1903, 0.0782, 0.3176, 0.1462, 0.0152, 0.0295,
           0.0550, 0.0939],
          [0.0088, 0.1031, 0.0933, 0.3512, 0.0264, 0.0199, 0.0768, 0.0147,
           0.0190, 0.2869],
          [0.0149, 0.1038, 0.2536, 0.2768, 0.0902, 0.0320, 0.0263, 0.0100,
           0.0460, 0.1464],
          [0.0864, 0.0492, 0.0313, 0.1477, 0.0108, 0.0292, 0.4119, 0.1173,
           0.0591, 0.0571],
          [0.1438, 0.0049, 0.0826, 0.0266, 0.2157, 0.2784, 0.0238, 0.1079,
           0.1064, 0.0098],
          [0.1128, 0.0138, 0.0802, 0.1166, 0.0828, 0.1724, 0.1366, 0.1581,
           0.0837, 0.0430]]]])

In [None]:
attn.shape

torch.Size([1, 1, 6, 10])

Next we build the context, i.e. the output vector for every query. First, the value matrix `V`:

In [None]:
V, V.shape

(tensor([[[[ 2.7265, -0.9099, -0.4123,  0.2838],
           [-0.1987, -0.1942, -0.8162,  0.7390],
           [-1.3244, -1.2526,  0.6507, -0.7998],
           [-1.2643, -0.2841,  1.3642,  0.1140],
           [-0.7225, -2.2770, -1.2280, -1.1679],
           [-1.3774, -1.1384, -0.7864, -0.5385],
           [ 0.1696,  1.4029, -0.4873,  1.9040],
           [ 0.4020, -0.0779,  0.4436, -0.9891],
           [ 0.5908, -0.2860,  1.2081,  1.6977],
           [-0.2286, -0.9445, -0.6943, -0.2456]]]]),
 torch.Size([1, 1, 10, 4]))

We now initialize the context. Every query that is not selected in `M_top` receives the mean of the value vectors, which is exactly what full attention would output if that query attended uniformly to all keys.

In [None]:
V_mean = V.mean(dim=-2)
V_mean, V_mean.shape

(tensor([[[-0.1227, -0.5962, -0.0758,  0.0998]]]), torch.Size([1, 1, 4]))
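Why the mean? If a query's scores were identical for every key, the softmax would give uniform weights `1 / L_K`, and full attention would return exactly the mean of the value vectors. This sketch checks that with a random `V`, not the notebook's values:

```python
import torch

torch.manual_seed(0)  # arbitrary seed
B, H, L_K, d_v = 1, 1, 10, 4
V = torch.randn(B, H, L_K, d_v)

flat_scores = torch.zeros(B, H, 1, L_K)            # one query, equal scores for all keys
attn_uniform = torch.softmax(flat_scores, dim=-1)  # every weight is 1 / L_K = 0.1
context = torch.matmul(attn_uniform, V)            # (B, H, 1, d_v)

print(torch.allclose(context.squeeze(-2), V.mean(dim=-2), atol=1e-6))  # True
```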

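`values`: the same mean value vector, repeated for every query. `.clone()` turns the expanded view into a tensor with its own memory, so that individual rows can be overwritten below.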

In [None]:
values = V_mean.unsqueeze(-2).expand(B, H, L_Q, V_mean.shape[-1]).clone()
values

tensor([[[[-0.1227, -0.5962, -0.0758,  0.0998],
          [-0.1227, -0.5962, -0.0758,  0.0998],
          [-0.1227, -0.5962, -0.0758,  0.0998],
          [-0.1227, -0.5962, -0.0758,  0.0998],
          [-0.1227, -0.5962, -0.0758,  0.0998],
          [-0.1227, -0.5962, -0.0758,  0.0998],
          [-0.1227, -0.5962, -0.0758,  0.0998],
          [-0.1227, -0.5962, -0.0758,  0.0998],
          [-0.1227, -0.5962, -0.0758,  0.0998],
          [-0.1227, -0.5962, -0.0758,  0.0998]]]])

In [None]:
values.shape

torch.Size([1, 1, 10, 4])

`values`: only the rows selected by `M_top` are overwritten with `attn @ V`. All other rows keep the mean value vector.

In [None]:
values[
    torch.arange(B)[:, None, None],
    torch.arange(H)[None, :, None],
    M_top,
    :
] = torch.matmul(attn, V).type_as(values)

In [None]:
values

tensor([[[[-0.1227, -0.5962, -0.0758,  0.0998],
          [-0.6807, -1.2721, -0.3155, -0.4823],
          [-0.1227, -0.5962, -0.0758,  0.0998],
          [-0.1197, -0.5089, -0.0089,  0.0370],
          [-0.7728, -0.7873,  0.2613, -0.1308],
          [-0.1227, -0.5962, -0.0758,  0.0998],
          [ 0.0886,  0.2708, -0.0070,  0.7789],
          [-0.1833, -1.0657, -0.2990, -0.3037],
          [-0.1227, -0.5962, -0.0758,  0.0998],
          [-0.6459, -0.4971,  0.1969,  0.0962]]]])

In [None]:
M_top, attn.shape, V.shape, torch.matmul(attn, V).shape, values.shape

(tensor([[[1, 9, 4, 6, 7, 3]]]),
 torch.Size([1, 1, 6, 10]),
 torch.Size([1, 1, 10, 4]),
 torch.Size([1, 1, 6, 4]),
 torch.Size([1, 1, 10, 4]))

The matrix multiplication of `attn` and `V` requires `O(L_Q_bar.L_K)` multiplications.

Notice that only the rows indexed by `M_top` have been overwritten.
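Putting the ProbSparse steps of this notebook together gives the function below. It is a sketch for the `(batch, heads, length, features)` layout of this section, not the Informer implementation itself, and it omits anything beyond the steps shown here (for example, masking). The names `prob_sparse_attention` and `full_attention` are our own, and the sampled keys are gathered with direct indexing instead of `K_expand`. When `L_Q_bar` reaches `L_Q`, every query is selected, so the result must match full attention, which gives a simple sanity check:

```python
import torch
import numpy as np
from math import sqrt

def full_attention(Q, K, V):
    """Reference: softmax(Q K^T / sqrt(d)) V for (B, H, L, d) tensors."""
    A = torch.softmax(torch.matmul(Q, K.transpose(-2, -1)) / sqrt(Q.shape[-1]), dim=-1)
    return torch.matmul(A, V)

def prob_sparse_attention(Q, K, V, factor=2):
    """ProbSparse attention sketch following this notebook's steps.

    Q, K, V: (B, H, L, d). Returns the context, shape (B, H, L_Q, d_v).
    """
    B, H, L_K, E = K.shape
    _, _, L_Q, _ = Q.shape
    L_K_bar = min(factor * int(np.ceil(np.log(L_K))), L_K)
    L_Q_bar = min(factor * int(np.ceil(np.log(L_Q))), L_Q)

    # 1) Score every query against L_K_bar randomly sampled keys: O(L_Q * L_K_bar)
    index_sample = torch.randint(L_K, (L_Q, L_K_bar))
    K_sample = K[:, :, index_sample, :]                   # (B, H, L_Q, L_K_bar, E)
    Q_K_sample = torch.einsum("bhle,bhlse->bhls", Q, K_sample)

    # 2) Sparsity measurement, then keep the top L_Q_bar queries
    M = Q_K_sample.max(dim=-1).values - Q_K_sample.sum(dim=-1) / L_K
    M_top = M.topk(L_Q_bar, sorted=False).indices         # (B, H, L_Q_bar)

    # 3) Attention of the selected queries over all keys: O(L_Q_bar * L_K)
    b_idx = torch.arange(B)[:, None, None]
    h_idx = torch.arange(H)[None, :, None]
    Q_bar = Q[b_idx, h_idx, M_top, :]                     # (B, H, L_Q_bar, E)
    attn = torch.softmax(torch.matmul(Q_bar, K.transpose(-2, -1)) / sqrt(E), dim=-1)

    # 4) Every other query keeps the mean value vector
    values = V.mean(dim=-2, keepdim=True).expand(B, H, L_Q, V.shape[-1]).clone()
    values[b_idx, h_idx, M_top, :] = torch.matmul(attn, V).type_as(values)
    return values

torch.manual_seed(0)  # arbitrary seed
Q, K, V = (torch.randn(2, 3, 10, 4) for _ in range(3))

# factor = 5 gives min(5 * ceil(ln 10), 10) = 10 = L_Q, so every query is selected
print(torch.allclose(prob_sparse_attention(Q, K, V, factor=5), full_attention(Q, K, V), atol=1e-5))  # True

# factor = 2 selects 6 of the 10 queries; the output shape is unchanged
print(prob_sparse_attention(Q, K, V, factor=2).shape)  # torch.Size([2, 3, 10, 4])
```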

In [None]:
out = values.transpose(2, 1).contiguous()
out

tensor([[[[-0.1227, -0.5962, -0.0758,  0.0998]],

         [[-0.6807, -1.2721, -0.3155, -0.4823]],

         [[-0.1227, -0.5962, -0.0758,  0.0998]],

         [[-0.1197, -0.5089, -0.0089,  0.0370]],

         [[-0.7728, -0.7873,  0.2613, -0.1308]],

         [[-0.1227, -0.5962, -0.0758,  0.0998]],

         [[ 0.0886,  0.2708, -0.0070,  0.7789]],

         [[-0.1833, -1.0657, -0.2990, -0.3037]],

         [[-0.1227, -0.5962, -0.0758,  0.0998]],

         [[-0.6459, -0.4971,  0.1969,  0.0962]]]])

In [None]:
out.shape

torch.Size([1, 10, 1, 4])

In [None]:
out = out.view(batch_size, sequence_length, -1)

`out`: one output embedding per data point, with the heads concatenated along the last axis, i.e. shape (batch, length, n_heads * d_v).

In [None]:
out.shape

torch.Size([1, 10, 4])

## Cost Analysis

The cost of operations is the following:
- Cost of generating `Q_K_sample`: O(L_Q.L_K_bar)
- Cost of generating `M_top`: O(L_Q.L_K_bar)
- Cost of generating `Q_bar_K`: O(L_Q_bar.L_K)
- Cost of `attn.V`: O(L_Q_bar.L_K)


Total cost = O(2 L_Q L_K_bar + 2 L_Q_bar L_K)

Now let L = L_Q = L_K.

Then L_Q_bar = L_K_bar = c ceil(ln L) = O(log L), where c is `factor`.

Then total cost = O(2 L log L + 2 L log L) = O(4 L log L)

**Total cost = O( L log L )**

Hence ProbSparse attention approximates full attention with lower time and space complexity: O(L log L) instead of O(L^2).
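For a rough feel of the gap, this back-of-the-envelope sketch (not a benchmark) plugs numbers into the totals above, ignoring the feature dimension. Full attention performs two `L x L` products (`Q K^T` and `A V`), counted as `2 * L**2`, against `2 * L * L_K_bar + 2 * L_Q_bar * L` for ProbSparse attention with `factor = 2`. The helper `sample_size` is hypothetical and repeats the notebook's formula:

```python
import numpy as np

def sample_size(L: int, factor: int = 2) -> int:
    """factor * ceil(ln L), capped at L (same formula as L_Q_bar and L_K_bar above)."""
    return min(factor * int(np.ceil(np.log(L))), L)

for L in [100, 1_000, 10_000, 100_000]:
    u = sample_size(L)               # L_Q_bar = L_K_bar
    full = 2 * L * L                 # Q K^T and A V
    sparse = 2 * L * u + 2 * u * L   # Q_K_sample and M_top, then Q_bar K^T and attn V
    print(f"L = {L:>7}: full = {full:>15,}  sparse = {sparse:>11,}  ratio = {full / sparse:8.1f}")
```

The ratio grows roughly like `L / log L`, which is the practical meaning of O(L^2) versus O(L log L).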