In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
tokenizer = AutoTokenizer.from_pretrained("bigcode/starcoder2-3b")
hf_model = AutoModelForCausalLM.from_pretrained("./checkpoints/bigcode/starcoder2-3b", torch_dtype=torch.float16).to("cuda")

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
from model import ModelArgs, Transformer
from util import load_model
from pathlib import Path

ckpt_path = Path("./checkpoints/bigcode/starcoder2-3b/model.pth")
with torch.device("meta"):
    model = Transformer.from_name(ckpt_path.parent.name)
model = load_model(model, "./checkpoints/bigcode/starcoder2-3b/model.pth", device="cuda", precision=torch.bfloat16)

In [19]:
from torch.nn.attention.flex_attention import create_block_mask

def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx


def prefill(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> torch.Tensor:
    # input_pos: [B, S]
    mask = create_block_mask(causal_mask, 1, 1, input_pos.shape[0], model.max_seq_length, device=x.device)
    logits = model(mask, x, input_pos)
    return logits

device = torch.device('cuda')
with device:
    with torch.inference_mode():
        input_ids = tokenizer.encode("def foo(x):")
        input_ids = torch.tensor(input_ids).unsqueeze(0)
        input_pos = torch.arange(input_ids.shape[1])
        model.setup_caches(1, 4096)
        logits = prefill(model, input_ids, input_pos)
logits[0, 0]

tensor([ 5.0312, -0.1787,  1.3203,  ..., -3.7969, -2.5469, -1.5000],
       device='cuda:0', dtype=torch.bfloat16)

In [20]:
with torch.inference_mode():
    hf_logits = hf_model(input_ids).logits
hf_logits[0, 0]

tensor([ 5.1953, -0.0202,  1.5137,  ..., -3.6777, -2.3906, -1.3330],
       device='cuda:0', dtype=torch.float16)

In [8]:
hf_model.model.layers[0].self_attn.v_proj.bias

Parameter containing:
tensor([-0.1064, -0.0704, -0.0350,  0.0284,  0.3872,  0.0245, -0.0554, -0.0607,
        -0.0540, -0.0268,  0.0265,  0.0748, -0.0239, -0.0193, -0.0831,  0.0295,
         0.0307,  0.0203,  0.0144, -0.0967,  0.0245,  0.0349,  0.0264,  0.0864,
         0.0653,  0.0370,  0.0433, -0.0028,  0.0972, -0.0634,  0.0167, -0.0210,
        -0.0120,  0.0177, -0.0409, -0.0268, -0.0420,  0.0989,  0.0662,  0.0100,
         0.0153,  0.0403,  0.0277,  0.0042,  0.0643,  0.0005, -0.0150,  0.0175,
        -0.0486,  0.0706, -0.0342,  0.0209,  0.0337, -0.0156, -0.0322, -0.0441,
         0.0263,  0.0255,  0.0206,  0.0764,  0.4221,  0.0052,  0.1537,  0.0377,
         0.0268,  0.0210, -0.0186, -0.0291, -0.0569, -0.0666,  0.0228, -0.0198,
         0.0040, -0.0020,  0.0191, -0.0294, -0.0268, -0.0848, -0.1925, -0.0322,
         0.0199, -0.0363, -0.0210,  0.0505, -0.4062,  0.0842,  0.0570,  0.0236,
        -0.0049, -0.0291,  0.0443, -0.0729,  0.0069,  0.0616,  0.0599,  0.1005,
        -0.0626,  

In [9]:
model.layers[0].attention.wqkv.bias[]

SyntaxError: invalid syntax (1333213819.py, line 1)