In [1]:
from modeling_gpt_neox import GPTNeoXForCausalLM
model = GPTNeoXForCausalLM.from_pretrained("trl-internal-testing/tiny-random-GPTNeoXForCausalLM")
from modeling_gpt_neox import GPTNeoXAttention
attention = GPTNeoXAttention(config=model.config)

In [2]:
new_state_dict = {}
for k, v in model.state_dict().items():
    if k.startswith("gpt_neox.layers.1.attention."):
        new_state_dict[k.replace("gpt_neox.layers.1.attention.", "")] = v

attention.load_state_dict(new_state_dict)

<All keys matched successfully>

In [3]:
import torch
inputState = torch.randn(9, 5, model.config.hidden_size)

attention(inputState)

(tensor([[[-2.4565e-02,  1.0580e-03,  2.1217e-03,  ..., -1.8174e-03,
            1.3156e-02,  1.6827e-03],
          [ 1.4614e-03,  5.8233e-03,  3.3597e-03,  ..., -2.6782e-03,
            1.2068e-02,  5.5921e-03],
          [ 7.4161e-03, -6.1869e-03, -3.0464e-03,  ..., -5.6295e-03,
            3.9184e-03,  1.0624e-02],
          [ 1.1426e-02, -5.5984e-03, -2.1029e-03,  ..., -5.5380e-03,
            1.0034e-03,  9.9084e-03],
          [ 1.0451e-02, -1.0159e-02, -4.2054e-03,  ..., -1.1213e-03,
           -1.1776e-03,  7.0760e-03]],
 
         [[ 8.6722e-04, -7.8111e-03, -3.4572e-04,  ..., -9.0739e-03,
            1.0469e-02,  8.3143e-04],
          [ 7.0545e-03, -1.0884e-02,  8.1043e-03,  ..., -6.9769e-03,
            5.4900e-03,  7.9795e-03],
          [-1.1418e-05, -4.8175e-03,  8.2175e-03,  ..., -5.4572e-03,
            1.3758e-02,  6.2788e-03],
          [-1.1667e-03, -8.8485e-03,  9.1316e-03,  ..., -5.6465e-03,
            4.1825e-03,  3.7440e-03],
          [ 5.2492e-04, -1.0774e-0

In [4]:
print(model.config.hidden_size)
print(model.config.num_attention_heads)
print(model.config.max_position_embeddings)
print(model.config.hidden_size // model.config.num_attention_heads)

32
4
512
8


In [9]:
from torch.nn import functional as F
import math
import torch
import torch.nn as nn
from rotary_embedding_torch import RotaryEmbedding


class CausalSelfAttention(nn.Module):

    def __init__(self, config):
        super().__init__()
        assert config.hidden_size % config.num_attention_heads == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(config.hidden_size, 3 * config.hidden_size, bias=True)#config.bias)
        # output projection
        self.c_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=True) #n_embd = hidden_size?
        # regularization
        self.attn_dropout = nn.Dropout(0.0)
        self.resid_dropout = nn.Dropout(0.0)
        self.n_head = config.num_attention_heads #config.n_head
        self.n_embd = config.hidden_size
        self.dropout = 0.0
        # flash attention make GPU go brrrrr but support is only in PyTorch nightly and still a bit scary
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and self.dropout == 0.0
        
        self.rotary_emb = RotaryEmbedding(
            dim = int((config.hidden_size // config.num_attention_heads) * config.rotary_pct)
        )
        
        if not self.flash:
            print("WARNING: using slow attention. Flash Attention atm needs PyTorch nightly and dropout=0.0")
            # causal mask to ensure that attention is only applied to the left in the input sequence
            self.register_buffer("bias", torch.tril(torch.ones(config.max_position_embeddings, config.max_position_embeddings))
                                        .view(1, 1, config.max_position_embeddings, config.max_position_embeddings)) #block_size? -> am I using this wrong?

    def forward(self, x):
        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
        head_size = C // self.n_head
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        qkv = self.c_attn(x)
        
        # GPT has things arranged as k,k,k,k,v,v,v,v,q,q,q,q whereas NeoX has them as k,q,v,k,q,v,k,q,v,k,q,v
        qkv = qkv.view(B, T, self.n_head, 3 * head_size)
        q, k, v = qkv.split(head_size, dim=3)
        
        k = k.view(B, T, self.n_head, head_size).transpose(1, 2) # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, head_size).transpose(1, 2) # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, head_size).transpose(1, 2) # (B, nh, T, hs)

        q = self.rotary_emb.rotate_queries_or_keys(q)
        k = self.rotary_emb.rotate_queries_or_keys(k)

        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        if self.flash:
            # efficient attention using Flash Attention CUDA kernels
            y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout, is_causal=True)
        else:
            # manual implementation of attention
            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
            att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
            att = F.softmax(att, dim=-1)
            att = self.attn_dropout(att)
            y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side

        # output projection
        y = self.resid_dropout(self.c_proj(y))
        return y

attention2 = CausalSelfAttention(model.config)
new_state_dict = {
    "c_proj.weight": model.state_dict()["gpt_neox.layers.1.attention.dense.weight"],
    "c_proj.bias": model.state_dict()["gpt_neox.layers.1.attention.dense.bias"],
    "c_attn.weight": model.state_dict()["gpt_neox.layers.1.attention.query_key_value.weight"],
    "c_attn.bias": model.state_dict()["gpt_neox.layers.1.attention.query_key_value.bias"],
    "bias": model.state_dict()["gpt_neox.layers.1.attention.bias"],
    "rotary_emb.freqs": attention2.state_dict()["rotary_emb.freqs"]
}
# print([(k, v.size()) for k,v in attention2.state_dict().items()])
# print([(k, v.size()) for k,v in new_state_dict.items()])

attention2.load_state_dict(new_state_dict)

x = attention2(inputState)
y = attention(inputState)[0]
torch.allclose(x,y)
# z = gpt2Attention(inputState)[0]



True

In [6]:

x[0][0][0], y[0][0][0]#, z[0][0][0]

(tensor(-0.0246, grad_fn=<SelectBackward0>),
 tensor(-0.0246, grad_fn=<SelectBackward0>))

In [7]:
torch.allclose(x,y)

True

# SUCCESS
It worked, we figured it out!

So, the *two* differneces between the self attentions are:
1. How they each unpack the `qkv`'s
2. Addition of rotary embeddings

It is *completely* unclear to me why they each do these different variations for the unpacking. If I'm lucky I might be able to figure out how to reshape things to get the same results.

I think that's the next test I should do in that case... isolate out each of these pieces and try to decipher how they manipulate the various weights
If I can end up minimising the difference between them, then *fuck yes*.


In [8]:
# Now time for rotary embeddings

# Seems like https://github.com/lucidrains/rotary-embedding-torch/blob/main/README.md is a clean implementation
# First step is to pull it in and see if it works

def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, offset: int = 0):
    cos = cos[..., offset : q.shape[-2] + offset, :]
    sin = sin[..., offset : q.shape[-2] + offset, :]
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

class RotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings, base=10000, device=None):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        self.register_buffer("inv_freq", inv_freq)

        # Build here to make `torch.jit.trace` work.
        self.max_seq_len_cached = max_position_embeddings
        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.cos_cached = emb.cos()[None, None, :, :]
        self.sin_cached = emb.sin()[None, None, :, :]

    # x: [bs, num_attention_heads, seq_len, head_size]
    def forward(self, x, seq_len=None):
        if seq_len > self.max_seq_len_cached:
            raise ValueError("seq_len {} is larger than max_position_embeddings {}".format(seq_len, self.max_seq_len_cached))

        return self.cos_cached[:seq_len, ...].to(x.device), self.sin_cached[:seq_len, ...].to(x.device)

query_rot = query[..., : self.rotary_ndims]
query_pass = query[..., self.rotary_ndims :]
key_rot = key[..., : self.rotary_ndims]
key_pass = key[..., self.rotary_ndims :]
seq_len = key.shape[-2]
cos, sin = self.rotary_emb(value, seq_len=seq_len)

query, key = apply_rotary_pos_emb(query_rot, key_rot, cos, sin, offset=offset)

query = torch.cat((query, query_pass), dim=-1)
key = torch.cat((key, key_pass), dim=-1)


NameError: name 'query' is not defined

In [None]:
!pip install rotary-embedding-torch

Collecting rotary-embedding-torch
  Downloading rotary_embedding_torch-0.2.1-py3-none-any.whl (4.5 kB)
Collecting einops>=0.3
  Downloading einops-0.6.0-py3-none-any.whl (41 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m41.6/41.6 kB[0m [31m3.3 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: einops, rotary-embedding-torch
Successfully installed einops-0.6.0 rotary-embedding-torch-0.2.1
