# 001. Why can a linear layer be replaced by an attention layer?

In [1]:
import math

import torch
import torch.nn as nn
import torch.nn.functional as F

# Seed the global torch RNG so the random draws in later cells (x, fc_weight,
# the logit shift, k) are reproducible on Restart & Run All.
# The trailing `;` suppresses the uninformative Generator repr returned by manual_seed.
SEED = 42
torch.manual_seed(SEED);

<torch._C.Generator at 0x7bf46c0dc270>

In [2]:
# Reference: a bias-free linear layer, y = x @ W^T, with W laid out as
# (out_features, in_features) like `nn.Linear.weight`.
# Draw order (x first, then fc_weight) is fixed so seeded results are stable.
batch_size, in_features, out_features = 4, 8, 32
x = torch.randn(batch_size, in_features)
fc_weight = torch.randn(out_features, in_features)
y1 = x @ fc_weight.T  # (batch_size, out_features)

In [None]:
# Rewrite the linear layer as one attention query over the `in_features` axis:
#   y[b, o] = sum_j x[b, j] * W[o, j]
#           = sum_j (|x[b, j]| / S_b) * (S_b * sign(x[b, j]) * W[o, j]),
#   where S_b = sum_j |x[b, j]|.
# The first factor is non-negative and sums to 1 over j (a valid softmax output).
# The second factor is a per-sample value vector.
x_abs = x.abs()
x_sum = x_abs.sum(dim=-1, keepdim=True)  # (bsz, 1)
attn_weights = (x_abs / x_sum)[:, None]  # (bsz, 1, in_features): one query per sample
# (1, in, out) * (bsz, in, 1) -> (bsz, in_features, out_features) == (bsz, seq_len, out_dim)
v = fc_weight.T[None] * (x_sum[..., None] * x.sign()[..., None])
assert v.shape == (*x.shape, fc_weight.shape[0])
_y = (attn_weights @ v)[:, 0]  # (bsz, out_features)

# Both sides are float32 but accumulate rounding differently.
# The default atol=1e-8 can reject near-zero outputs spuriously, so set an explicit atol.
assert torch.allclose(_y, y1, atol=1e-5)


# Recover one valid choice of `logits` (the would-be `q @ k^T / sqrt(d_k)`).
# Each row of `attn_weights` sums to 1, so softmax(log(p) + c) == p for any constant c.
# Subtracting a random shift shows the logits are only determined up to a constant.
logits = torch.log(attn_weights)
logits = logits - torch.randn(1)

# Verify that softmax of the recovered logits reproduces `attn_weights`.
softmax_result = torch.softmax(logits, dim=-1)

assert torch.allclose(softmax_result, attn_weights)

In [None]:
# Recover one random pair (`q`, `k`) whose scaled scores reproduce `logits`.
# Attention scores are q @ k^T / sqrt(d_k), so we need q @ k^T == s,
# where s = logits * sqrt(d_k).
#
# Pick a random k of shape (..., seq_len, d_k) and solve for q. When d_k >= seq_len,
# k^T (d_k x seq_len) has full column rank almost surely. Then pinv(k^T) @ k^T == I,
# and q = s @ pinv(k^T) is an exact solution.
# With d_k < seq_len, q @ k^T has rank <= d_k and cannot match arbitrary logits.
seq_len = logits.shape[-1]  # == in_features
d_k = 8
assert d_k >= seq_len, "need d_k >= seq_len for an exact q @ k^T factorization"

s = logits * math.sqrt(d_k)

k = torch.randn(*logits.shape[:-2], seq_len, d_k)  # (bsz, seq_len, d_k)
# pinv equals inv in the square case but also covers d_k > seq_len.
q = s @ torch.linalg.pinv(k.transpose(-2, -1))  # (bsz, 1, d_k)
logits_rec = q @ k.transpose(-2, -1) / math.sqrt(d_k)

assert torch.allclose(logits_rec, logits, atol=1e-3)

In [5]:
# Feed the recovered (q, k, v) through PyTorch's attention kernel and compare
# with the plain linear layer output.
# SDPA's default scale is 1/sqrt(q.size(-1)), and q's last dimension is d_k here.
# That matches the sqrt(d_k) factor used when building `s`.
attn_out = F.scaled_dot_product_attention(q, k, v)  # one query per sample
y2 = attn_out.select(1, 0)  # drop the query axis, same as [:, 0]

assert torch.allclose(y2, y1, atol=1e-3)

In fact, explicitly recovering `q`, `k`, and `logits` is not strictly necessary.

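Fold the signs of the input, and the normalization that `softmax` requires, into `v`. Then the input to a linear layer, divided by its L1 norm, plays the role of the attention layer's `attn_weights`. The transposed weight matrix `fc_weight.T` plays the role of the value matrix `v`; note that `fc_weight` itself has shape `(out_features, in_features)`. The `sequence length` equals `in_features`. The identity is exact, up to floating-point rounding:

$$y_o = \sum_j x_j W_{oj} = \sum_j \underbrace{\frac{|x_j|}{\lVert x\rVert_1}}_{p_j}\,\underbrace{\lVert x\rVert_1\,\operatorname{sign}(x_j)\,W_{oj}}_{v_{jo}}$$

Here $p$ is `attn_weights` and $v$ is `v`. This `v` is input-dependent: each sample's copy of `fc_weight.T` is scaled by $\lVert x\rVert_1$, and each row is scaled by the sign of the matching input entry.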