In [1]:
import numpy as np
import torch
import torch.nn.functional as F

### Approximating within-softmax QK decomposition to bring dimensionality reduction outside softmax

In [142]:
psd_ratio = 1  # test how error increases as Q becomes less PSD
d = 64
e = 200
l = 400
assert l >= d
arange_d = torch.arange(d, 0, -1)
psd_matrix = torch.Tensor(np.random.rand(d, d))
psd_matrix = torch.matmul(psd_matrix, psd_matrix)
psd_matrix = torch.cat([psd_matrix, torch.ones(l - d, d)], dim=0)
q = torch.Tensor(np.random.rand(l, d)) * (1 - psd_ratio) + psd_matrix * psd_ratio
q_inv = torch.linalg.pinv(q)
k = torch.Tensor(np.random.rand(l, d)) * (1 - psd_ratio) + psd_matrix * psd_ratio

# make sure query vectors are not all zero (creates large error when solving for K_hat, and will never occur in practice)

q, k = q.type(torch.float), k.type(torch.float)
v = torch.Tensor(np.random.rand(l, d))

In [143]:
A = torch.softmax(q @ k.T, dim=-1)
out = A @ v

In [144]:
# decompose A into USD
U, s, D = torch.linalg.svd(A)
S = torch.diag(s)

# take highest singular value vectors
U = U[:,:e]
S = S[:e,:e]
D = D[:e,:]

(A, U @ S @ D)

(tensor([[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]),
 tensor([[ 6.1396e-09, -3.1815e-08,  0.0000e+00,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [-5.8241e-08,  2.2453e-05,  1.4986e-13,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 3.1224e-09,  1.0219e-06,  7.5583e-08,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         ...,
         [ 3.1224e-09,  1.0219e-06,  1.9642e-14,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 3.1224e-09,  1.0219e-06,  1.9642e-14,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00],
         [ 3.1224e-09,  1.0219e-06,  1.9642e-14,  ...,  0.0000e+00,
           0.0000e+00,  0.0000e+00]]))

In [145]:
# compute pseudo A
US = U @ S

# compute log offset
offset = 1 - US.min()
US_ = US + offset
M_offset = torch.linalg.lstsq(US_, US).solution

# make rows of pseudo A sum to 1
rowsum = US_.sum(dim=-1)
A_hat = US_ / rowsum[:,None]

# compute V_hat
V_hat = M_offset @ D @ v
(A, US_ @ M_offset @ D) 

(tensor([[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]),
 tensor([[2.9765e-09, 1.0729e-06, 2.3208e-15,  ..., 0.0000e+00, 0.0000e+00,
          0.0000e+00],
         [2.9765e-09, 1.0729e-06, 2.3208e-15,  ..., 0.0000e+00, 0.0000e+00,
          0.0000e+00],
         [2.9765e-09, 1.0729e-06, 2.3208e-15,  ..., 0.0000e+00, 0.0000e+00,
          0.0000e+00],
         ...,
         [2.9765e-09, 1.0729e-06, 2.3208e-15,  ..., 0.0000e+00, 0.0000e+00,
          0.0000e+00],
         [2.9765e-09, 1.0729e-06, 2.3208e-15,  ..., 0.0000e+00, 0.0000e+00,
          0.0000e+00],
         [2.9765e-09, 1.0729e-06, 2.3208e-15,  ..., 0.0000e+00, 0.0000e+00,
          0.0000e+00]]))

In [146]:
# solve for K_hat
K_hat = torch.linalg.lstsq(q, torch.log(US_)).solution.T
(torch.log(US_), q @ K_hat.T)  # overdetermined system when l > d (always)

(tensor([[ 6.9315e-01, -5.9605e-08,  1.0729e-06,  ...,  1.0729e-06,
           1.0729e-06,  1.0729e-06],
         [ 6.9315e-01,  2.2411e-05,  1.0729e-06,  ...,  1.0729e-06,
           1.0729e-06,  1.0729e-06],
         [ 6.9315e-01,  1.0729e-06,  9.5367e-07,  ...,  1.0729e-06,
           1.0729e-06,  1.0729e-06],
         ...,
         [ 6.9315e-01,  1.0729e-06,  1.0729e-06,  ...,  1.0729e-06,
           1.0729e-06,  1.0729e-06],
         [ 6.9315e-01,  1.0729e-06,  1.0729e-06,  ...,  1.0729e-06,
           1.0729e-06,  1.0729e-06],
         [ 6.9315e-01,  1.0729e-06,  1.0729e-06,  ...,  1.0729e-06,
           1.0729e-06,  1.0729e-06]]),
 tensor([[6.9587e-01, 1.4325e-07, 1.0771e-06,  ..., 1.0774e-06, 1.0774e-06,
          1.0774e-06],
         [6.8875e-01, 2.2121e-05, 1.0661e-06,  ..., 1.0656e-06, 1.0655e-06,
          1.0655e-06],
         [6.9417e-01, 1.1348e-06, 9.5517e-07,  ..., 1.0745e-06, 1.0745e-06,
          1.0745e-06],
         ...,
         [6.9314e-01, 1.0727e-06, 1.0729e-0

In [147]:
# recompute A prime
A_hat_p = torch.softmax(q @ K_hat.T, dim=-1)
(A_hat, A_hat_p)

(tensor([[0.0100, 0.0050, 0.0050,  ..., 0.0050, 0.0050, 0.0050],
         [0.0100, 0.0050, 0.0050,  ..., 0.0050, 0.0050, 0.0050],
         [0.0100, 0.0050, 0.0050,  ..., 0.0050, 0.0050, 0.0050],
         ...,
         [0.0100, 0.0050, 0.0050,  ..., 0.0050, 0.0050, 0.0050],
         [0.0100, 0.0050, 0.0050,  ..., 0.0050, 0.0050, 0.0050],
         [0.0100, 0.0050, 0.0050,  ..., 0.0050, 0.0050, 0.0050]]),
 tensor([[0.0100, 0.0050, 0.0050,  ..., 0.0050, 0.0050, 0.0050],
         [0.0099, 0.0050, 0.0050,  ..., 0.0050, 0.0050, 0.0050],
         [0.0100, 0.0050, 0.0050,  ..., 0.0050, 0.0050, 0.0050],
         ...,
         [0.0100, 0.0050, 0.0050,  ..., 0.0050, 0.0050, 0.0050],
         [0.0100, 0.0050, 0.0050,  ..., 0.0050, 0.0050, 0.0050],
         [0.0100, 0.0050, 0.0050,  ..., 0.0050, 0.0050, 0.0050]]))

In [148]:
# reconstruct original attention matrix
A_p = (A_hat @ M_offset @ D) * rowsum[:,None]
A_pp = (A_hat_p @ M_offset @ D) * rowsum[:,None]
(A, A_p, A_pp)

(tensor([[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]]),
 tensor([[2.9765e-09, 1.0729e-06, 2.3208e-15,  ..., 0.0000e+00, 0.0000e+00,
          0.0000e+00],
         [2.9765e-09, 1.0729e-06, 2.3208e-15,  ..., 0.0000e+00, 0.0000e+00,
          0.0000e+00],
         [2.9765e-09, 1.0729e-06, 2.3208e-15,  ..., 0.0000e+00, 0.0000e+00,
          0.0000e+00],
         ...,
         [2.9765e-09, 1.0729e-06, 2.3208e-15,  ..., 0.0000e+00, 0.0000e+00,
          0.0000e+00],
         [2.9765e-09, 1.0729e-06, 2.3208e-15,  ..., 0.0000e+00, 0.0000e+00,
          0.0000e+00],
         [2.9765e-09, 1.0729e-06, 2.3208e-15,  ..., 0.0000e+00, 0.0000e+00,
          0.0000e+00]]),
 tensor([[2.9766e-09, 1.0729e-06, 2.3209e-15,  ..., 0.0000e+00, 0.0000e+00,
          0.0000e+00],
         [2.9764e-09, 1.0728e

In [149]:
# compute new outputs
out_p = (A_hat @ V_hat) * rowsum[:,None]
out_pp = (A_hat_p @ V_hat) * rowsum[:,None]
(out, out_p, out_pp)
(((out - out_pp) ** 2).sum(dim=-1) ** 0.5).mean()

tensor(3.2497e-05)

### Potential issues
- Recomputing softmax for A_pp without original scaling factor of A_p
- Scaling in alpha and inverse alpha at different points where they dont cancel
- All-zero query vectors will make solving for K_hat ineffective
- Q x K_hat = A_hat is overdetermined for K_hat