In [1]:
from modeling_gpt_neox import GPTNeoXForCausalLM
model = GPTNeoXForCausalLM.from_pretrained("trl-internal-testing/tiny-random-GPTNeoXForCausalLM")
from modeling_gpt_neox import GPTNeoXAttention
attention = GPTNeoXAttention(config=model.config)

In [2]:
new_state_dict = {}
for k, v in model.state_dict().items():
    if k.startswith("gpt_neox.layers.1.attention."):
        new_state_dict[k.replace("gpt_neox.layers.1.attention.", "")] = v

attention.load_state_dict(new_state_dict)

<All keys matched successfully>

In [3]:
import torch
inputState = torch.randn(9, 5, model.config.hidden_size)

attention(inputState)

(tensor([[[-6.9567e-03,  3.3929e-03,  1.0490e-02,  ..., -8.9337e-03,
            1.2362e-02,  1.7269e-03],
          [ 1.2703e-02,  1.9357e-03,  7.1906e-03,  ..., -2.1879e-03,
            1.2637e-02, -8.9477e-04],
          [ 5.2897e-03,  4.8720e-03,  3.7390e-03,  ...,  1.5283e-04,
            1.4059e-02, -3.1914e-03],
          [ 2.1649e-03,  5.0285e-03,  3.8143e-03,  ...,  2.2147e-03,
            1.0520e-02, -2.4457e-03],
          [ 2.3755e-03,  3.5694e-03,  4.6525e-03,  ..., -1.5534e-03,
            3.6072e-03, -1.9443e-03]],
 
         [[-2.9194e-03,  7.6916e-03,  4.2001e-03,  ...,  8.7902e-03,
            2.7098e-02, -1.8283e-03],
          [-1.5174e-03, -3.0349e-03, -1.1576e-02,  ...,  9.2955e-03,
            1.7590e-02, -3.2332e-03],
          [ 3.4629e-03, -2.9270e-03, -1.1274e-02,  ...,  1.0710e-02,
            6.0601e-03,  9.1808e-04],
          [ 2.9374e-03, -6.2745e-03, -1.2581e-02,  ...,  1.0311e-02,
            2.1893e-03, -1.6492e-03],
          [-7.6566e-04,  2.4798e-0

In [4]:
print(model.config.hidden_size)
print(model.config.num_attention_heads)
print(model.config.max_position_embeddings)
print(model.config.hidden_size // model.config.num_attention_heads)

32
4
512
8


In [15]:
from torch.nn import functional as F
import math
import torch
import torch.nn as nn
class CausalSelfAttention(nn.Module):

    def __init__(self, config):
        super().__init__()
        assert config.hidden_size % config.num_attention_heads == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(config.hidden_size, 3 * config.hidden_size, bias=True)#config.bias)
        # output projection
        self.c_proj = nn.Linear(config.hidden_size, config.hidden_size, bias=True) #n_embd = hidden_size?
        # regularization
        self.attn_dropout = nn.Dropout(0.0)
        self.resid_dropout = nn.Dropout(0.0)
        self.n_head = config.num_attention_heads #config.n_head
        self.n_embd = config.hidden_size
        self.dropout = 0.0
        # flash attention make GPU go brrrrr but support is only in PyTorch nightly and still a bit scary
        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention') and self.dropout == 0.0
        if not self.flash:
            print("WARNING: using slow attention. Flash Attention atm needs PyTorch nightly and dropout=0.0")
            # causal mask to ensure that attention is only applied to the left in the input sequence
            self.register_buffer("bias", torch.tril(torch.ones(config.max_position_embeddings, config.max_position_embeddings))
                                        .view(1, 1, config.max_position_embeddings, config.max_position_embeddings)) #block_size? -> am I using this wrong?

    def forward(self, x):
        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
        head_size = C // self.n_head
        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        qkv = self.c_attn(x)
        
        qkv = qkv.view(B, T, self.n_head, 3 * head_size)
        q, k, v = qkv.split(head_size, dim=3)
        
        k = k.view(B, T, self.n_head, head_size).transpose(1, 2) # (B, nh, T, hs)
        q = q.view(B, T, self.n_head, head_size).transpose(1, 2) # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, head_size).transpose(1, 2) # (B, nh, T, hs)

        # TODO: Alexs addition -> works for 1,1 -> now figure out how to make it work for BIGGER
        # q, k ,v  = qkv.view(B,T,4,3,8).transpose(0,3)#qkv.split(self.n_embd, dim=2)#
        # print(q.size()) #SAME
        # print('q1', q)
        # print('k', k)
        # # print(q.shape, k.shape, v.shape)
        # <OLD>
        # q, k, v = qkv.split(self.n_embd, dim=2)
        # k = k.view(B, T, self.n_head, head_size).transpose(1, 2) # (B, nh, T, hs)
        # q = q.view(B, T, self.n_head, head_size).transpose(1, 2) # (B, nh, T, hs)
        # v = v.view(B, T, self.n_head, head_size).transpose(1, 2) # (B, nh, T, hs)
        # print('k', k)
        # print('qkv1', qkv)
        # </OLD>
        
        # <NEW>
        
        # qkv = qkv.view(B, T, self.n_head, 3 * head_size)
        
        # # [batch, seq_len, num_attention_heads, 3 * head_size] --> 3 [batch, num_attention_heads, seq_len, head_size]
        # q = qkv[..., : head_size].transpose(1, 2)
        # k = qkv[..., head_size : 2 * head_size].transpose(1, 2)
        # v = qkv[..., 2 * head_size :].transpose(1, 2)
        # print('q2', q)
        # </NEW>
        # <NEWNEW>
        # qkv = qkv.view(B, T, self.n_head, 3 * head_size)
        # q, k, v = qkv.split(head_size, dim=-1)
        # q = q.transpose(1, 2)
        # k = k.transpose(1, 2)
        # v = v.transpose(1, 2)
        # </NEWNEW>
        # <NEWNEWNEW>
        qkv = qkv.view(B, T, self.n_head, 3 * head_size)
        q, k, v = qkv.split(head_size, dim=3)
        q = q.view(B, T, self.n_head, head_size).transpose(1, 2)
        k = k.view(B, T, self.n_head, head_size).transpose(1, 2)
        v = v.view(B, T, self.n_head, head_size).transpose(1, 2)
        # </NEWNEWNEW>
        
        # print(q.shape, k.shape, v.shape)
        # print(q[0][0][0][1], k[0][0][0][1], v[0][0][0][1])
        # print(q[0][0][0][0], k[0][0][0][0], v[0][0][0][0])

        # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
        if self.flash:
            # efficient attention using Flash Attention CUDA kernels
            y = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout, is_causal=True)
        else:
            # manual implementation of attention
            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
            att = att.masked_fill(self.bias[:,:,:T,:T] == 0, float('-inf'))
            att = F.softmax(att, dim=-1)
            att = self.attn_dropout(att)
            y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
        # print(y.shape)
        # output projection
        # print(y[0][0][0])
        y = self.resid_dropout(self.c_proj(y))
        return y

attention2 = CausalSelfAttention(model.config)
new_state_dict = {
    "c_proj.weight": model.state_dict()["gpt_neox.layers.1.attention.dense.weight"],
    "c_proj.bias": model.state_dict()["gpt_neox.layers.1.attention.dense.bias"],
    "c_attn.weight": model.state_dict()["gpt_neox.layers.1.attention.query_key_value.weight"],
    "c_attn.bias": model.state_dict()["gpt_neox.layers.1.attention.query_key_value.bias"],
    "bias": model.state_dict()["gpt_neox.layers.1.attention.bias"],
}
# print([(k, v.size()) for k,v in attention2.state_dict().items()])
# print([(k, v.size()) for k,v in new_state_dict.items()])

attention2.load_state_dict(new_state_dict)

from modeling_gpt import GPT2Attention
gpt2Attention = GPT2Attention(model.config)

new_state_dict['c_attn.weight'] = new_state_dict['c_attn.weight'].t()
new_state_dict['c_proj.weight'] = new_state_dict['c_proj.weight'].t()
new_state_dict['masked_bias'] = gpt2Attention.state_dict()['masked_bias']
gpt2Attention.load_state_dict(new_state_dict)


x = attention2(inputState)
y = attention(inputState)[0]
torch.allclose(x,y)
# z = gpt2Attention(inputState)[0]



True

In [None]:

x[0][0][0], y[0][0][0]#, z[0][0][0]

(tensor(0.0066, grad_fn=<SelectBackward0>),
 tensor(0.0002, grad_fn=<SelectBackward0>))

In [None]:
torch.allclose(x,y)

False

# SUCCESS
It worked, we figured it out!

So, the *two* differneces between the self attentions are:
1. How they each unpack the `qkv`'s
2. Addition of rotary embeddings

It is *completely* unclear to me why they each do these different variations for the unpacking. If I'm lucky I might be able to figure out how to reshape things to get the same results.

I think that's the next test I should do in that case... isolate out each of these pieces and try to decipher how they manipulate the various weights
If I can end up minimising the difference between them, then *fuck yes*.
