In [39]:
import math
import torch
import torch.nn as nn
import torch.nn.functional as F

In [113]:
class Pinard(nn.Module):
    """Attention over N learnable key/value rows, used in place of a linear layer.

    The input acts as a query that is scored against ``N`` learnable key rows;
    the normalised, GELU-activated scores then mix ``N`` learnable value rows.
    """

    def __init__(self, in_features, out_features, N, eps=1e-8):
        """
        replaces nn.Linear(in_features, out_features)

        Args:
            in_features: size of the last dimension of the input.
            out_features: size of the last dimension of the output.
            N: number of learnable key/value rows.
            eps: lower bound applied to the score norm before dividing, so an
                all-zero score row yields zeros instead of NaN.
        """
        super().__init__()

        self.eps = eps
        # Keys are created before values so the RNG consumption order is stable.
        self.k = nn.Parameter(torch.randn((N, in_features)))
        self.v = nn.Parameter(torch.randn((N, out_features)))

    def _norm_scores(self, scores):
        """Rescale each score row to L2 norm sqrt(N), then apply GELU.

        Args:
            scores: tensor whose last dimension indexes the N key rows.

        Returns:
            Tensor of the same shape as ``scores``.
        """
        norm = torch.norm(scores, p=2, dim=-1, keepdim=True)
        # Guard against 0/0: a zero score row would otherwise become NaN.
        norm = torch.clamp(norm, min=self.eps)
        norm_outputs = scores / norm * math.sqrt(scores.shape[-1])
        return F.gelu(norm_outputs)

    def forward(self, q):
        """Map ``q`` of shape (..., in_features) to shape (..., out_features)."""
        # Same overall computation as single-head attention where the query is
        # the input and the keys/values are the N learnable rows.
        scores = q @ self.k.T  # (..., N)
        out = self._norm_scores(scores) @ self.v  # (..., out_features)
        return out

In [160]:
# Sizes of the projection being replaced: d1 input features -> d2 output features.
d2 = 768
d1 = 4 * d2
# With d1 = 4 * d2, N = 4 * d2 / 5 keeps N * (d1 + d2) just under d1 * d2.
N = (4 * d2) // 5
B = 4  # number of rows in the random input batch

pattention = Pinard(
    in_features=d1,
    out_features=d2,
    N=N,
)

In [161]:
# Push one random batch of shape (B, d1) through the layer.
x = torch.randn((B, d1))
out = pattention(x)

In [162]:
# Output shape of the layer for the batch above.
out.size()

torch.Size([4, 768])

In [163]:
# (weight count of a d1 x d2 dense matrix, total parameter count of the Pinard layer)
d1 * d2, sum(map(torch.numel, pattention.parameters()))

(2359296, 2357760)