In [1]:
import torch
import torch.nn as nn
from einops import rearrange

In [2]:
# Fill a fresh 7x8 tensor from a normal distribution truncated to [-3, 3]
# (mean/std are left at trunc_normal_'s defaults), then wrap it as a Parameter.
# trunc_normal_ works in place and hands back the same tensor it was given.
weight = nn.init.trunc_normal_(torch.empty(7, 8), a=-3, b=3)
w = nn.Parameter(weight)
w

Parameter containing:
tensor([[-6.2718e-01, -2.3692e-01, -6.5170e-01, -3.7918e-01,  1.8936e+00,
         -1.6325e+00,  7.1764e-01, -1.9833e+00],
        [-5.2158e-01, -2.0997e-01,  7.1299e-01, -6.0251e-01,  5.0125e-04,
         -8.6180e-01,  1.1143e+00,  5.6994e-01],
        [ 4.7557e-01, -9.7971e-02, -7.8141e-01, -2.1923e+00,  8.2116e-02,
          6.2901e-01, -1.6155e+00,  8.3228e-01],
        [-1.1955e+00, -3.8827e-01, -1.2004e+00, -1.6852e-01, -1.0342e-01,
         -1.2304e+00,  5.3759e-01,  1.2281e+00],
        [-1.2974e+00, -7.4639e-01, -7.6854e-01,  2.0425e-01,  5.0526e-01,
         -1.2049e+00,  6.2339e-01,  1.0362e+00],
        [-4.7334e-01, -4.2351e-01,  1.4082e-01, -4.3955e-02,  1.0506e+00,
         -2.1156e-01,  1.2013e+00,  1.4378e+00],
        [-4.0865e-01,  2.8389e-01,  5.7553e-01, -4.6749e-01, -7.3301e-01,
         -1.9630e+00, -2.1937e+00,  1.2181e+00]], requires_grad=True)

In [3]:
# Row indices to pull out of `w`; indexing with an int64 tensor gathers
# whole rows in the order the ids are listed.
token_ids = torch.tensor([1, 4, 5, 0], dtype=torch.long)
w[token_ids]

tensor([[-5.2158e-01, -2.0997e-01,  7.1299e-01, -6.0251e-01,  5.0125e-04,
         -8.6180e-01,  1.1143e+00,  5.6994e-01],
        [-1.2974e+00, -7.4639e-01, -7.6854e-01,  2.0425e-01,  5.0526e-01,
         -1.2049e+00,  6.2339e-01,  1.0362e+00],
        [-4.7334e-01, -4.2351e-01,  1.4082e-01, -4.3955e-02,  1.0506e+00,
         -2.1156e-01,  1.2013e+00,  1.4378e+00],
        [-6.2718e-01, -2.3692e-01, -6.5170e-01, -3.7918e-01,  1.8936e+00,
         -1.6325e+00,  7.1764e-01, -1.9833e+00]], grad_fn=<IndexBackward0>)

In [4]:
# Largest value along the last axis of `w`; keepdim=True leaves a trailing
# axis of size 1. Element 0 of the (values, indices) result is the values.
torch.max(w, dim=-1, keepdim=True)[0]

tensor([[1.8936],
        [1.1143],
        [0.8323],
        [1.2281],
        [1.0362],
        [1.4378],
        [1.2181]], grad_fn=<MaxBackward0>)

In [5]:
# Use the MPS backend when it is available; otherwise fall back to the CPU.
if torch.backends.mps.is_available():
    device = 'mps'
else:
    device = 'cpu'

# Toy settings used by the cells below.
d_k, max_seq_len, theta = 6, 10, 0.5

# Even indices 0, 2, ... below d_k, each divided by d_k.
kd = torch.arange(0, d_k, 2, device=device) / d_k
print(kd)
kd.shape

tensor([0.0000, 0.3333, 0.6667], device='mps:0')


torch.Size([3])

In [6]:
# Integer positions 0 .. max_seq_len - 1 arranged as a column
# (a trailing axis of size 1 is appended).
positions = torch.arange(max_seq_len, device=device)[:, None]
print(positions)
positions.shape

tensor([[0],
        [1],
        [2],
        [3],
        [4],
        [5],
        [6],
        [7],
        [8],
        [9]], device='mps:0')


torch.Size([10, 1])

In [7]:
# theta ** kd yields one divisor per entry of kd; dividing the column of
# positions by that row broadcasts to one angle per (position, divisor) pair.
divisors = theta ** kd
angles = positions / divisors
print(angles)
angles.shape

tensor([[ 0.0000,  0.0000,  0.0000],
        [ 1.0000,  1.2599,  1.5874],
        [ 2.0000,  2.5198,  3.1748],
        [ 3.0000,  3.7798,  4.7622],
        [ 4.0000,  5.0397,  6.3496],
        [ 5.0000,  6.2996,  7.9370],
        [ 6.0000,  7.5595,  9.5244],
        [ 7.0000,  8.8194, 11.1118],
        [ 8.0000, 10.0794, 12.6992],
        [ 9.0000, 11.3393, 14.2866]], device='mps:0')


torch.Size([10, 3])

In [8]:
# Model width and head count for the attention experiment.
d_model, num_heads = 20, 5

# A single linear layer emits three d_model-wide slices side by side.
# The layer is created before the random input so the RNG is consumed in the
# same order as before.
qkv_proj = nn.Linear(d_model, 3 * d_model)
x = torch.randn(10, 10, d_model)
qkv = qkv_proj(x)
qkv.shape

torch.Size([10, 10, 60])

In [9]:
# Cut axis 2 of the fused projection into consecutive d_model-wide pieces;
# unpacking into three names requires exactly three pieces.
q, k, v = torch.split(qkv, d_model, dim=2)
assert q.shape == k.shape == v.shape
q.shape

torch.Size([10, 10, 20])

In [10]:
# Split the last axis of `q` into `h` heads and move the head axis ahead of
# the sequence axis. `rearrange` comes from the notebook's import cell.
# NOTE: `d_k` inside the pattern is only an einops axis label (its size is
# inferred from the last axis divided by h); it is not the Python variable
# `d_k` assigned in an earlier cell.
# NOTE: this cell rebinds `q`, so running it a second time feeds it the
# already-rearranged tensor, which no longer matches the 3-axis input pattern.
q = rearrange(q, "batch seq_len (h d_k) -> batch h seq_len d_k", h=num_heads)
q.shape

torch.Size([10, 5, 10, 4])