In [None]:

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch

In [None]:

def apply_embed(inp, sd):
    """Embed a token-id sequence by summing token and position embeddings.

    Args:
        inp: 1-D integer tensor of token ids (its ``len`` gives the number of
            positions).
        sd: mapping that contains ``'transformer.wte.weight'`` (token table)
            and ``'transformer.wpe.weight'`` (position table).

    Returns:
        Tensor with one row per input token: ``wte[token] + wpe[position]``.
    """
    # Positions are simply 0 .. len(inp) - 1, in order.
    positions = torch.arange(0, len(inp))
    token_table = sd['transformer.wte.weight']
    position_table = sd['transformer.wpe.weight']
    return token_table[inp] + position_table[positions]

def layer_norm(inputs, weights, bias, eps=1e-5):
    """Apply layer normalization over the last dimension.

    Computes ``(x - mean) / sqrt(var + eps) * weights + bias`` where ``var`` is
    the *biased* (population) variance. The previous version used the
    Bessel-corrected standard deviation and added ``eps`` outside the square
    root, which does not reproduce ``torch.nn.LayerNorm``.

    Args:
        inputs: tensor whose last dimension is normalized.
        weights: per-feature scale, broadcastable against ``inputs``.
        bias: per-feature shift, broadcastable against ``inputs``.
        eps: small constant added to the variance for numerical stability.

    Returns:
        Tensor of the same shape as ``inputs``.
    """
    mean = inputs.mean(-1, keepdim=True)
    centered = inputs - mean
    # Population variance (divide by N, not N - 1), as LayerNorm defines it.
    var = (centered * centered).mean(-1, keepdim=True)
    return centered / torch.sqrt(var + eps) * weights + bias


def attention(inp, attn_weight, attn_bias, attn_proj_weight, attn_proj_bias, head_dim=128):
    """Causal multi-query self-attention for a single (unbatched) sequence.

    All query heads share one key head and one value head. The fused input
    projection is laid out as ``[query (embed_dim) | key (head_dim) |
    value (head_dim)]`` along its output dimension.

    Compared with the previous version this derives the sequence length and
    head count from the input instead of hard-coding 4 tokens / 64 rows /
    2048 features (which failed as soon as the sequence grew), and applies a
    causal mask so a position cannot attend to later positions.

    Args:
        inp: ``(seq_len, embed_dim)`` tensor.
        attn_weight: ``(embed_dim + 2 * head_dim, embed_dim)`` fused q/k/v
            projection weight, applied as ``inp @ attn_weight.T``.
        attn_bias: ``(embed_dim + 2 * head_dim,)`` fused projection bias.
        attn_proj_weight: output projection weight, applied as
            ``x @ attn_proj_weight.T``.
        attn_proj_bias: output projection bias.
        head_dim: width of each attention head (default 128, the value that
            was previously hard-coded).

    Returns:
        ``(seq_len, out_features)`` tensor after the output projection.

    Raises:
        ValueError: if ``embed_dim`` is not a multiple of ``head_dim``.
    """
    seq_len, embed_dim = inp.shape
    if embed_dim % head_dim != 0:
        raise ValueError(
            f"embed_dim ({embed_dim}) must be divisible by head_dim ({head_dim})"
        )
    num_heads = embed_dim // head_dim

    fused = inp @ attn_weight.T + attn_bias
    query, key, value = fused.split((embed_dim, head_dim, head_dim), dim=-1)

    # Rows of `query` become (position, head) pairs so that every head of every
    # position is scored against the single shared key matrix in one matmul.
    query = query.reshape(seq_len * num_heads, head_dim)
    scores = query @ key.T / head_dim ** 0.5          # (seq_len * num_heads, seq_len)
    scores = scores.reshape(seq_len, num_heads, seq_len)

    # Causal mask: query position t may only see key positions <= t.
    causal = torch.tril(
        torch.ones(seq_len, seq_len, dtype=torch.bool, device=inp.device)
    )
    scores = scores.masked_fill(~causal[:, None, :], float('-inf'))
    probs = torch.nn.functional.softmax(scores, dim=-1)

    # (seq_len, num_heads, seq_len) @ (seq_len, head_dim) broadcasts over positions.
    attn_out = probs @ value                           # (seq_len, num_heads, head_dim)
    attn_out = attn_out.reshape(seq_len, embed_dim)    # concatenate the heads

    return attn_out @ attn_proj_weight.T + attn_proj_bias

def apply_mlp(inp, mlp, mlpb, mlp2, mlp2b, approximate="tanh"):
    """Two-layer feed-forward block: linear -> GELU -> linear.

    The previous version chained the two linear layers with no activation in
    between, which collapses the block into a single affine map.

    Args:
        inp: input tensor whose last dimension matches ``mlp``'s input width.
        mlp: first projection weight, applied as ``inp @ mlp.T``.
        mlpb: first projection bias.
        mlp2: second projection weight, applied as ``hidden @ mlp2.T``.
        mlp2b: second projection bias.
        approximate: forwarded to ``torch.nn.functional.gelu``; ``"tanh"`` or
            ``"none"``. NOTE(review): ``"tanh"`` assumes the checkpoint was
            trained with the tanh GELU approximation -- confirm against
            ``model.config.activation_function``.

    Returns:
        Tensor produced by the second projection.
    """
    hidden = inp @ mlp.T + mlpb
    hidden = torch.nn.functional.gelu(hidden, approximate=approximate)
    return hidden @ mlp2.T + mlp2b



def apply_transformer_block(inp, sd, block_num):
    """Run one pre-norm transformer block using weights looked up in ``sd``.

    Structure::

        hidden = inp    + attention(layer_norm(inp))
        out    = hidden + apply_mlp(layer_norm(hidden))

    The previous version returned the bare MLP output and dropped the second
    residual connection.

    Args:
        inp: hidden states entering the block.
        sd: mapping holding the block's parameters under the key prefix
            ``'transformer.h.{block_num}.'``.
        block_num: index of the block whose parameters are used.

    Returns:
        Hidden states leaving the block.
    """
    prefix = f'transformer.h.{block_num}.'

    # --- attention sub-layer with residual connection ---
    normed = layer_norm(inp, sd[prefix+'ln_1.weight'], sd[prefix+'ln_1.bias'])
    attn_out = attention(
        normed,
        sd[prefix+'attn.c_attn.weight'],
        sd[prefix+'attn.c_attn.bias'],
        sd[prefix+'attn.c_proj.weight'],
        sd[prefix+'attn.c_proj.bias'],
    )
    hidden = inp + attn_out

    # --- feed-forward sub-layer with residual connection ---
    normed = layer_norm(hidden, sd[prefix+'ln_2.weight'], sd[prefix+'ln_2.bias'])
    mlp_out = apply_mlp(
        normed,
        sd[prefix+'mlp.c_fc.weight'],
        sd[prefix+'mlp.c_fc.bias'],
        sd[prefix+'mlp.c_proj.weight'],
        sd[prefix+'mlp.c_proj.bias'],
    )
    return hidden + mlp_out


def run_inference(inp, sd, n_layers=24):
    """Run the full forward pass and return vocabulary logits.

    Pipeline: embeddings -> ``n_layers`` transformer blocks -> final layer
    norm -> projection with ``sd['lm_head.weight']``.

    Args:
        inp: token ids, as accepted by ``apply_embed``.
        sd: mapping of parameter names to tensors.
        n_layers: number of transformer blocks to apply, read from
            ``transformer.h.0`` .. ``transformer.h.{n_layers - 1}``. Defaults
            to 24, the value that was previously hard-coded.

    Returns:
        Logits computed as ``hidden @ lm_head.weight.T``.
    """
    hidden = apply_embed(inp, sd)
    for layer_idx in range(n_layers):
        hidden = apply_transformer_block(hidden, sd, layer_idx)
    hidden = layer_norm(hidden, sd['transformer.ln_f.weight'], sd['transformer.ln_f.bias'])
    return hidden @ sd['lm_head.weight'].T


In [None]:
# Load the tokenizer and the reference model for the checkpoint named below,
# then grab the model's parameters as a name -> tensor mapping (`sd`) that the
# hand-written forward pass reads from.
model_id = "bigcode/starcoderbase-1b"
tokenizer = AutoTokenizer.from_pretrained(model_id)

model = AutoModelForCausalLM.from_pretrained(model_id)
sd = model.state_dict() 

In [None]:
# Tokenize the prompt; `return_tensors="pt"` yields PyTorch tensors, and
# indexing `[0]` takes the first entry of 'input_ids' as the working sequence.
text = "Hello my name is"
inputs = tokenizer(text, return_tensors="pt")
seq = inputs['input_ids'][0]



In [None]:
# Greedy decoding for 10 steps: re-run the full forward pass on the sequence
# so far, take the arg-max token at the last position, and append it.
# NOTE: `seq` is extended in place, so re-running this cell keeps growing it.
for _ in range(10):
    logits = run_inference(seq, sd)
    next_token = logits[-1].argmax(dim=-1)
    seq = torch.cat((seq, next_token.unsqueeze(0)), dim=0)




In [None]:
# Show the token ids after generation. (The cell previously contained the
# undefined name `se`, which raises NameError; `seq` is the intended variable.)
seq

In [None]:
# NOTE(review): `start` is not assigned in any visible cell, so this raises
# NameError on a fresh kernel -- leftover scratch cell; define it or delete.
start

In [None]:
# Shape of the logits from the last `run_inference` call in the generation loop.
logits.shape

In [None]:
# NOTE(review): in the visible cells `attn_out` is only assigned inside
# `attention()`, so at notebook scope this raises NameError on a fresh kernel.
attn_out.shape

In [None]:
# NOTE(review): in the visible cells `value` is only assigned inside
# `attention()`, so at notebook scope this raises NameError on a fresh kernel.
value.shape

In [None]:
# Display the attention sub-module of the first transformer block of `model`.
model.transformer.h[0].attn

In [None]:
# NOTE(review): scratch experiment. `atten` and `ln1_out` are not assigned in
# any visible cell, and `atten` is overwritten by its own reshape, so this cell
# is neither runnable on a fresh kernel nor safe to run twice.
# It treats the projection as 16 heads that each carry their own q/k/v of
# width 48, and applies no causal mask.

# (tokens, 16 heads, per-head features) -> heads first
atten = atten.reshape(ln1_out.shape[0], 16, -1).permute(1, 0, 2)
# per-head features are cut into chunks of 48: queries, keys, values
q, k, v = atten.split(48, dim=-1)

# scaled dot-product scores, normalized over the key axis
attn_score = (q @ k.transpose(-1, -2)) / torch.sqrt(torch.tensor(48.0))
attn_score = torch.nn.functional.softmax(attn_score, dim=-1)

# weighted sum of values, then heads folded back into the feature axis
attn_output = (attn_score @ v).permute(1, 0, 2).reshape(ln1_out.shape)


In [None]:
# Shape probes; only the last expression is displayed by the notebook.
# NOTE(review): `ln1_out`, `attn_weight` and `attn_bias` are not assigned at
# notebook scope in any visible cell (the latter two are locals of
# `apply_transformer_block`), so this fails on a fresh kernel.
ln1_out.shape # 4, 2048
attn_weight.shape # 2304, 2048
attn_bias.shape # 2304

# n_heads, seq_len, multi_query = 16, 4, True

In [None]:
# NOTE(review): `ln1_out` is not assigned in any visible cell (hidden kernel state).
ln1_out.shape

In [None]:
# Shape of `attn_output` from the scratch attention cell.
attn_output.shape

In [None]:
# NOTE(review): a 3-axis permute needs a 3-D `attn_output`; the scratch cell
# that assigns it ends by reshaping to `ln1_out.shape`, so this presumably was
# run on the intermediate (pre-reshape) value -- confirm or delete.
attn_output.permute(1, 0, 2).shape

In [None]:
# 2304 / 16 / 3 evaluates to 48.0 -- the chunk width used in the scratch attention cell.
2304/16/3

In [None]:
# NOTE(review): repeat of an earlier probe; `ln1_out` is not assigned in any visible cell.
ln1_out.shape

In [None]:
# Names of all state-dict entries whose key starts with the layer-0 prefix.
layer_0_keys = [name for name in sd if name.startswith("transformer.h.0")]

In [None]:
# Display the collected layer-0 parameter names.
layer_0_keys