test_rotary.py

import pytest
import torch

from lean_transformer.rotary import RotaryEmbeddings, get_auxiliary_tensors


@pytest.mark.forked
def test_rotary_embeddings():
    batch_size = 11
    seq_len = 50
    nhead = 4
    dim = 1024
    dtype = torch.float32
    device = torch.device("cpu")
    base = 10 ** 4

    torch.use_deterministic_algorithms(True)

    # auxiliary tensors: the precomputed cos/sin tables should match the reference layer's cache
    a, b = get_auxiliary_tensors(seq_len, dim, dtype, device, base)
    x, y = Rotary(dim, base).forward(torch.randn(1, seq_len, dim, device=device))
    assert torch.allclose(a.view_as(x), x, atol=1e-4, rtol=0)
    assert torch.allclose(b.view_as(y), y, atol=1e-4, rtol=0)

    # full layer outputs: the reference layer works on (seq, batch, nhead, dim),
    # hence the permutes; RotaryEmbeddings consumes (batch, seq, nhead, dim) directly
    ref_layer = Rotary(dim, base)
    our_layer = RotaryEmbeddings(dim, base)

    q = torch.randn(batch_size, seq_len, nhead, dim, dtype=dtype, device=device)
    k = torch.randn(batch_size, seq_len, nhead, dim, dtype=dtype, device=device)

    q_ref, k_ref = apply_rotary_pos_emb(q.permute(1, 0, 2, 3), k.permute(1, 0, 2, 3), *ref_layer(k))
    q_our, k_our = our_layer(q), our_layer(k)
    assert torch.allclose(q_ref.permute(1, 0, 2, 3), q_our, atol=1e-4, rtol=0)
    assert torch.allclose(k_ref.permute(1, 0, 2, 3), k_our, atol=1e-4, rtol=0)

    # check rotation equivariance of dot product: rotating q and k by the same
    # per-position angle must not change their per-position inner products
    original_dot = (q[0, :, 0] * k[0, :, 0]).sum(-1)
    rotated_dot = (our_layer(q)[0, :, 0] * our_layer(k)[0, :, 0]).sum(-1)
    assert torch.allclose(original_dot, rotated_dot, atol=1e-4, rtol=0)
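

# Illustrative extra check, not part of the original suite: at position 0 the
# rotation angle is zero (cos = 1, sin = 0), so rotary embeddings should leave
# the first token untouched. This assumes RotaryEmbeddings uses the same
# (batch, seq, nhead, dim) layout as in test_rotary_embeddings above.
@pytest.mark.forked
def test_rotary_identity_at_position_zero():
    dim, seq_len = 64, 8
    layer = RotaryEmbeddings(dim, 10 ** 4)
    x = torch.randn(2, seq_len, 1, dim)
    assert torch.allclose(layer(x)[:, 0], x[:, 0], atol=1e-6)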


class Rotary(torch.nn.Module):
    """Reference implementation by EleutherAI"""

    def __init__(self, dim, base=10000):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.seq_len_cached = None
        self.cos_cached = None
        self.sin_cached = None

    def forward(self, x, seq_dim=1):
        seq_len = x.shape[seq_dim]
        if seq_len != self.seq_len_cached:
            # recompute and cache the cos/sin tables whenever the sequence length changes
            self.seq_len_cached = seq_len
            t = torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.cos_cached = emb.cos()[:, None, None, :]
            self.sin_cached = emb.sin()[:, None, None, :]
        return self.cos_cached, self.sin_cached


def rotate_half(x):
    """Map the halves (x1, x2) of the last dimension to (-x2, x1)."""
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=x1.ndim - 1)  # dim=-1 triggers a bug in torch < 1.8.0


def apply_rotary_pos_emb(q, k, cos, sin):
    return (q * cos) + (rotate_half(q) * sin), (k * cos) + (rotate_half(k) * sin)
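

# Illustrative sketch, not part of the original suite: rotary's defining
# property is that the dot product between rotated q and k depends only on the
# relative offset between their positions. This uses the reference Rotary layer
# and apply_rotary_pos_emb above, with the (seq, batch, nhead, dim) layout they
# expect; the test name and constants are hypothetical.
@pytest.mark.forked
def test_rotary_relative_position_property():
    dim, seq_len, offset = 64, 16, 3
    layer = Rotary(dim, 10 ** 4)
    # broadcast one fixed q vector and one fixed k vector to every position
    q = torch.randn(1, 1, 1, dim).expand(seq_len, 1, 1, dim)
    k = torch.randn(1, 1, 1, dim).expand(seq_len, 1, 1, dim)
    cos, sin = layer(q, seq_dim=0)
    q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
    # dot(q_rot[i], k_rot[i + offset]) should be the same for every i
    dots = (q_rot[: seq_len - offset, 0, 0] * k_rot[offset:, 0, 0]).sum(-1)
    assert torch.allclose(dots, dots[0].expand_as(dots), atol=1e-4, rtol=0)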