In [1]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.models.qwen2 import modeling_qwen2
from torch.nn.attention import SDPBackend, sdpa_kernel

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load model and tokenizer
model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"
tokenizer = AutoTokenizer.from_pretrained(model_name)


In [3]:
lm_model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
lm_model = lm_model.to("cuda")

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 50.64it/s]


In [4]:
print(lm_model.config)

Qwen2Config {
  "architectures": [
    "Qwen2ForCausalLM"
  ],
  "attention_dropout": 0.0,
  "bos_token_id": 151643,
  "eos_token_id": 151643,
  "hidden_act": "silu",
  "hidden_size": 3584,
  "initializer_range": 0.02,
  "intermediate_size": 18944,
  "max_position_embeddings": 131072,
  "max_window_layers": 28,
  "model_type": "qwen2",
  "num_attention_heads": 28,
  "num_hidden_layers": 28,
  "num_key_value_heads": 4,
  "rms_norm_eps": 1e-06,
  "rope_scaling": null,
  "rope_theta": 10000,
  "sliding_window": 4096,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.52.4",
  "use_cache": true,
  "use_mrope": false,
  "use_sliding_window": false,
  "vocab_size": 152064
}



In [5]:
lm_model

Qwen2ForCausalLM(
  (model): Qwen2Model(
    (embed_tokens): Embedding(152064, 3584)
    (layers): ModuleList(
      (0-27): 28 x Qwen2DecoderLayer(
        (self_attn): Qwen2Attention(
          (q_proj): Linear(in_features=3584, out_features=3584, bias=True)
          (k_proj): Linear(in_features=3584, out_features=512, bias=True)
          (v_proj): Linear(in_features=3584, out_features=512, bias=True)
          (o_proj): Linear(in_features=3584, out_features=3584, bias=False)
        )
        (mlp): Qwen2MLP(
          (gate_proj): Linear(in_features=3584, out_features=18944, bias=False)
          (up_proj): Linear(in_features=3584, out_features=18944, bias=False)
          (down_proj): Linear(in_features=18944, out_features=3584, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen2RMSNorm((3584,), eps=1e-06)
        (post_attention_layernorm): Qwen2RMSNorm((3584,), eps=1e-06)
      )
    )
    (norm): Qwen2RMSNorm((3584,), eps=1e-06)
    (rotary_emb):

In [6]:
text = "The fifth smallest factor of 2012 is"
inputs = tokenizer(text, return_tensors="pt")
inputs = inputs["input_ids"].to("cuda")

In [7]:
inputs

tensor([[151646,    785,  17702,  24632,   8168,    315,    220,     17,     15,
             16,     17,    374]], device='cuda:0')

In [8]:
for i in range(28):
    print(lm_model.model.layers[i])

Qwen2DecoderLayer(
  (self_attn): Qwen2Attention(
    (q_proj): Linear(in_features=3584, out_features=3584, bias=True)
    (k_proj): Linear(in_features=3584, out_features=512, bias=True)
    (v_proj): Linear(in_features=3584, out_features=512, bias=True)
    (o_proj): Linear(in_features=3584, out_features=3584, bias=False)
  )
  (mlp): Qwen2MLP(
    (gate_proj): Linear(in_features=3584, out_features=18944, bias=False)
    (up_proj): Linear(in_features=3584, out_features=18944, bias=False)
    (down_proj): Linear(in_features=18944, out_features=3584, bias=False)
    (act_fn): SiLU()
  )
  (input_layernorm): Qwen2RMSNorm((3584,), eps=1e-06)
  (post_attention_layernorm): Qwen2RMSNorm((3584,), eps=1e-06)
)
Qwen2DecoderLayer(
  (self_attn): Qwen2Attention(
    (q_proj): Linear(in_features=3584, out_features=3584, bias=True)
    (k_proj): Linear(in_features=3584, out_features=512, bias=True)
    (v_proj): Linear(in_features=3584, out_features=512, bias=True)
    (o_proj): Linear(in_features=

In [9]:
if not hasattr(modeling_qwen2, "_orig_apply_rope"):
    modeling_qwen2._orig_apply_rope = modeling_qwen2.apply_rotary_pos_emb  # keep a handle

saved_hf_cos = None
saved_hf_sin = None

q_out_arr, k_out_arr = [], []
def _rope_debug_hf(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    global saved_hf_cos, saved_hf_sin, q_out_arr, k_out_arr
    if saved_hf_cos is None or sin.shape[1] > saved_hf_cos.shape[1]:
        saved_hf_cos = cos.clone()
        saved_hf_sin = sin.clone()
    
    print(f"\n[HF] RoPE args:")
    print("q.shape       =", tuple(q.shape), " q[0,0,:2,:4] =", q[0, 0, :2, :4])
    print("k.shape       =", tuple(k.shape), " k[0,0,:2,:4] =", k[0, 0, :2, :4])
    print("cos.shape     =", tuple(cos.shape), " cos[0,:4]   =", cos[0, :4])
    print("sin.shape     =", tuple(sin.shape), " sin[0,:4]   =", sin[0, :4])
    print("position_ids  =", position_ids if position_ids is not None else "None")
    q_out, k_out = modeling_qwen2._orig_apply_rope(
        q, k, cos, sin, position_ids, unsqueeze_dim
    )
    q_out_arr.append(q_out)
    k_out_arr.append(k_out)
    print("\n[HF] q after RoPE   :", q_out[0, 0, :2, :8])
    print("[HF] k after RoPE   :", k_out[0, 0, :2, :8])
    return q_out, k_out

modeling_qwen2.apply_rotary_pos_emb = _rope_debug_hf

In [10]:
import torch.nn.functional as F


if not hasattr(F, "_orig_sdpa"):
    F._orig_sdpa = F.scaled_dot_product_attention

q_sdpa_out_arr, k_sdpa_out_arr, v_sdpa_out_arr, sdpa_out_arr = [], [], [], []

def _debug_sdpa(q, k, v, attn_mask=None, dropout_p=0.0, scale=None, is_causal=False):
    print("\n[SDPA] Input Shapes:")
    print("q.shape =", q.shape, " q[0, 0, :2, :4] =", q[0, 0, :2, :4])
    print("k.shape =", k.shape, " k[0, 0, :2, :4] =", k[0, 0, :2, :4])
    print("v.shape =", v.shape, " v[0, 0, :2, :4] =", v[0, 0, :2, :4])
    print("attn_mask =", "None" if attn_mask is None else attn_mask.shape)
    print("dropout_p =", dropout_p)
    print("is_causal =", is_causal)
    if scale is not None:
        print("scale =", scale)
    else:
        print("scale = None (default scaling)")

    q_sdpa_out_arr.append(q.clone())
    k_sdpa_out_arr.append(k.clone())
    v_sdpa_out_arr.append(v.clone())
    with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
        out = F._orig_sdpa(q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)

    print("[SDPA] Output[0, 0, :2, :4] =", out[0, 0, :2, :4])
    sdpa_out_arr.append(out.clone())

    return out

F.scaled_dot_product_attention = _debug_sdpa


In [11]:

# Dictionary to store activations
activations_hf = {}

def get_activation(name):
    def hook(model, input, output):
        #print(name, output, type(output))
        if("layer" in name):
            activations_hf[name] = output[0].detach()  # For transformer layers, output is a tuple (last_hidden_state, past_key_values)
        else:
            activations_hf[name] = output.detach()
    return hook

# Register hooks for all transformer layers
# DeepSeek-R1-Distill-Qwen-7B has 32 layers (layers 0-31)
for i in range(28):
    lm_model.model.layers[i].register_forward_hook(get_activation(f'layer_{i}'))
    lm_model.model.layers[i].self_attn.register_forward_hook(get_activation(f'layer_{i}_attention'))
    lm_model.model.layers[i].input_layernorm.register_forward_hook(get_activation(f'layer_{i}_input_layernorm'))
    lm_model.model.layers[i].self_attn.q_proj.register_forward_hook(get_activation(f'layer_{i}_q_proj'))
    lm_model.model.layers[i].self_attn.k_proj.register_forward_hook(get_activation(f'layer_{i}_k_proj'))
    lm_model.model.layers[i].self_attn.v_proj.register_forward_hook(get_activation(f'layer_{i}_v_proj'))

"""
# You can also register hooks for specific components within layers
# For example, attention and MLP outputs:
for i in range(28):
    model.layers[i].self_attn.register_forward_hook(get_activation(f'layer_{i}_attention'))
    model.layers[i].mlp.register_forward_hook(get_activation(f'layer_{i}_mlp'))
"""
# Register hook for embeddings
lm_model.model.embed_tokens.register_forward_hook(get_activation('embeddings'))

# Forward pass
outputs = lm_model(input_ids=inputs)

# Access activations
print(f"Embeddings shape: {activations_hf['embeddings'].shape}")
print(f"Layer 0 shape: {activations_hf['layer_0'][0].shape}")  # Note: output is tuple, take first element
print(f"Layer 15 shape: {activations_hf['layer_15'][0].shape}")
print(f"Layer 27 shape: {activations_hf['layer_27'][0].shape}")

# Print all available activation keys
print("\nAll captured activations:")
for key in activations_hf.keys():
    if isinstance(activations_hf[key], tuple):
        print(f"{key}: {activations_hf[key][0].shape}")
    else:
        print(f"{key}: {activations_hf[key].shape}")


[HF] RoPE args:
q.shape       = (1, 28, 12, 128)  q[0,0,:2,:4] = tensor([[ 1.6641,  3.5312, -0.9570,  3.0000],
        [ 0.6055,  1.7969, -0.3477,  1.7656]], device='cuda:0',
       dtype=torch.bfloat16, grad_fn=<SliceBackward0>)
k.shape       = (1, 4, 12, 128)  k[0,0,:2,:4] = tensor([[-0.4316,  1.1250,  0.5703,  0.3926],
        [-0.7812, -0.2080,  1.9688, -0.3086]], device='cuda:0',
       dtype=torch.bfloat16, grad_fn=<SliceBackward0>)
cos.shape     = (1, 12, 128)  cos[0,:4]   = tensor([[ 1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,
          1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,
          1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,
          1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,
          1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,
          1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,
          1.0000,  1.0000,

In [12]:
activations_hf

{'embeddings': tensor([[[-0.0006,  0.0005, -0.0049,  ..., -0.0001, -0.0004, -0.0010],
          [ 0.0168, -0.0157,  0.0140,  ..., -0.0033,  0.0210, -0.0420],
          [ 0.0033, -0.0175, -0.0221,  ..., -0.0092,  0.0140,  0.0275],
          ...,
          [ 0.0041,  0.0007,  0.0008,  ..., -0.0002, -0.0049, -0.0045],
          [-0.0042, -0.0027, -0.0135,  ..., -0.0029, -0.0064,  0.0028],
          [ 0.0098, -0.0019, -0.0090,  ..., -0.0132, -0.0011,  0.0065]]],
        device='cuda:0', dtype=torch.bfloat16),
 'layer_0_input_layernorm': tensor([[-0.0330,  0.0282, -0.2754,  ..., -0.0080, -0.0244, -0.0520],
         [ 0.1216, -0.1216,  0.0996,  ..., -0.0236,  0.1504, -0.2793],
         [ 0.0240, -0.1348, -0.1562,  ..., -0.0654,  0.1001,  0.1816],
         ...,
         [ 0.0359,  0.0067,  0.0070,  ..., -0.0021, -0.0432, -0.0366],
         [-0.0364, -0.0254, -0.1147,  ..., -0.0247, -0.0549,  0.0222],
         [ 0.0767, -0.0156, -0.0703,  ..., -0.1025, -0.0082,  0.0474]],
        device='cuda:

In [None]:
import torch


for i in range(28):
    is_equal = torch.allclose(q_out_arr[i], q_out_arr_custom[i], atol=1e-2)
    print(f"Layer {i} q_out match:", is_equal)


Layer 0 q_out match: True
Layer 1 q_out match: False
Layer 2 q_out match: False
Layer 3 q_out match: False
Layer 4 q_out match: False
Layer 5 q_out match: False
Layer 6 q_out match: False
Layer 7 q_out match: False
Layer 8 q_out match: False
Layer 9 q_out match: False
Layer 10 q_out match: False
Layer 11 q_out match: False
Layer 12 q_out match: False
Layer 13 q_out match: False
Layer 14 q_out match: False
Layer 15 q_out match: False
Layer 16 q_out match: False
Layer 17 q_out match: False
Layer 18 q_out match: False
Layer 19 q_out match: False
Layer 20 q_out match: False
Layer 21 q_out match: False
Layer 22 q_out match: False
Layer 23 q_out match: False
Layer 24 q_out match: False
Layer 25 q_out match: False
Layer 26 q_out match: False
Layer 27 q_out match: False


In [None]:
print(q_out_arr[1])

tensor([[[[ 2.0410e-01,  5.3438e+00, -6.8750e-01,  ...,  4.1504e-02,
            1.2109e-01, -4.5166e-02],
          [ 3.3750e+00,  4.0625e+00,  6.1719e-01,  ...,  1.2012e-01,
            2.8125e-01, -1.6309e-01],
          [ 4.4375e+00, -1.2422e+00,  1.7109e+00,  ..., -1.1768e-01,
            2.4512e-01, -9.2285e-02],
          ...,
          [-7.4219e-02,  4.1016e-02,  4.3945e-01,  ...,  5.2734e-02,
            2.0215e-01, -1.5137e-01],
          [-1.9688e+00, -1.7656e+00,  2.0781e+00,  ...,  8.9844e-02,
            2.9688e-01, -2.0386e-02],
          [-2.2969e+00, -2.8750e+00,  2.2500e+00,  ..., -8.3984e-02,
            9.9609e-02,  4.4141e-01]],

         [[-1.7109e+00,  2.0605e-01, -1.6641e+00,  ...,  2.6250e+00,
            6.2891e-01, -1.2969e+00],
          [-9.2188e-01,  1.2188e+00, -1.5469e+00,  ...,  2.5938e+00,
            2.8320e-01, -1.1406e+00],
          [ 4.7461e-01,  2.9688e-01, -7.3047e-01,  ...,  1.6641e+00,
            1.9434e-01,  2.6562e-01],
          ...,
     

In [None]:
#print(q_out_arr_custom[1])
diff = torch.abs(q_out_arr[1] - q_out_arr_custom[1])
max_diff = diff.max()
print(f"Not equal at index {1}, max difference: {max_diff.item()}")
print("Indices with large diffs (>1e-2):", (diff > 1e-2).nonzero(as_tuple=True))
print("Values from q_out_arr:       ", q_out_arr[1][diff > 1e-2])
print("Values from q_out_arr_custom:", q_out_arr_custom[1][diff > 1e-2])



Not equal at index 1, max difference: 0.0625
Indices with large diffs (>1e-2): (tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], device='cuda:0'), tensor([ 0,  0,  0,  0,  0,  1,  1,  1,  2,  2,  2,  2,  2,  3,  4,  5,  5,  6,
         6,  6,  6,  6,  7,  8, 10, 10, 10, 11, 11, 12, 12, 12, 12, 12, 14, 15,
        16, 20, 21, 21, 21, 22, 22, 23, 23, 24, 24, 24, 25, 25, 25, 26, 26, 26,
        26, 27, 27, 27, 27], device='cuda:0'), tensor([ 3,  3,  3,  5,  7,  3,  3,  8,  8, 11, 11, 11, 11, 11,  3,  3,  9,  3,
         5,  5,  9, 11,  7,  7,  4,  5,  5,  5, 11,  5,  6,  8,  9, 10,  6,  9,
        11,  3,  5,  5,  8,  5,  7,  3, 11,  5,  5, 11,  6, 11, 11,  3,  7, 11,
        11,  3,  3,  3,  8], device='cuda:0'), tensor([  8,  45,  72, 111,  47,  38, 108,  38,  48,  15,  33,  40,  79,   0,
         59,  33,  48,  52,  49, 120,  48,  40, 

In [None]:
activations_hf

{'embeddings': tensor([[[-0.0006,  0.0005, -0.0049,  ..., -0.0001, -0.0004, -0.0010],
          [ 0.0168, -0.0157,  0.0140,  ..., -0.0033,  0.0210, -0.0420],
          [ 0.0033, -0.0175, -0.0221,  ..., -0.0092,  0.0140,  0.0275],
          ...,
          [ 0.0041,  0.0007,  0.0008,  ..., -0.0002, -0.0049, -0.0045],
          [-0.0042, -0.0027, -0.0135,  ..., -0.0029, -0.0064,  0.0028],
          [ 0.0098, -0.0019, -0.0090,  ..., -0.0132, -0.0011,  0.0065]]],
        device='cuda:0', dtype=torch.bfloat16),
 'layer_0_input_layernorm': tensor([[-0.0330,  0.0282, -0.2754,  ..., -0.0080, -0.0244, -0.0520],
         [ 0.1216, -0.1216,  0.0996,  ..., -0.0236,  0.1504, -0.2793],
         [ 0.0240, -0.1348, -0.1562,  ..., -0.0654,  0.1001,  0.1816],
         ...,
         [ 0.0359,  0.0067,  0.0070,  ..., -0.0021, -0.0432, -0.0366],
         [-0.0364, -0.0254, -0.1147,  ..., -0.0247, -0.0549,  0.0222],
         [ 0.0767, -0.0156, -0.0703,  ..., -0.1025, -0.0082,  0.0474]],
        device='cuda:

In [None]:
import torch


for i in range(28):
    is_equal = torch.allclose(sdpa_out_arr_custom[i], v_sdpa_out_arr_custom[i][:, :, :12, :], atol=1e-2)
    print(f"Layer {i} q_out match:", is_equal)
    diff = torch.abs(v_sdpa_out_arr[i] - v_sdpa_out_arr_custom[i][:, :, :12, :])
    max_diff = diff.max()
    print(f"Not equal at index {i}, max difference: {max_diff.item()}")


Layer 0 q_out match: True
Not equal at index 0, max difference: 0.0078125
Layer 1 q_out match: True
Not equal at index 1, max difference: 0.00390625
Layer 2 q_out match: False
Not equal at index 2, max difference: 0.015625
Layer 3 q_out match: False
Not equal at index 3, max difference: 0.01171875
Layer 4 q_out match: False
Not equal at index 4, max difference: 0.05078125
Layer 5 q_out match: False
Not equal at index 5, max difference: 0.015625
Layer 6 q_out match: False
Not equal at index 6, max difference: 0.015625
Layer 7 q_out match: False
Not equal at index 7, max difference: 0.03125
Layer 8 q_out match: False
Not equal at index 8, max difference: 0.0234375
Layer 9 q_out match: False
Not equal at index 9, max difference: 0.03515625
Layer 10 q_out match: False
Not equal at index 10, max difference: 0.046875
Layer 11 q_out match: False
Not equal at index 11, max difference: 0.046875
Layer 12 q_out match: False
Not equal at index 12, max difference: 0.0625
Layer 13 q_out match: False

In [None]:
is_equal = torch.allclose(activations_hf["layer_27"], activations["layer_27_output"], atol=1e-2)
is_equal

True

In [None]:
activations["layer_0_v_proj"].dtype

torch.bfloat16

In [None]:
activations_hf

{'embeddings': tensor([[[-0.0006,  0.0005, -0.0049,  ..., -0.0001, -0.0004, -0.0010],
          [ 0.0168, -0.0157,  0.0140,  ..., -0.0033,  0.0210, -0.0420],
          [ 0.0033, -0.0175, -0.0221,  ..., -0.0092,  0.0140,  0.0275],
          ...,
          [ 0.0041,  0.0007,  0.0008,  ..., -0.0002, -0.0049, -0.0045],
          [-0.0042, -0.0027, -0.0135,  ..., -0.0029, -0.0064,  0.0028],
          [ 0.0098, -0.0019, -0.0090,  ..., -0.0132, -0.0011,  0.0065]]],
        device='cuda:0', dtype=torch.bfloat16),
 'layer_0_input_layernorm': tensor([[-0.0330,  0.0282, -0.2754,  ..., -0.0080, -0.0244, -0.0520],
         [ 0.1216, -0.1216,  0.0996,  ..., -0.0236,  0.1504, -0.2793],
         [ 0.0240, -0.1348, -0.1562,  ..., -0.0654,  0.1001,  0.1816],
         ...,
         [ 0.0359,  0.0067,  0.0070,  ..., -0.0021, -0.0432, -0.0366],
         [-0.0364, -0.0254, -0.1147,  ..., -0.0247, -0.0549,  0.0222],
         [ 0.0767, -0.0156, -0.0703,  ..., -0.1025, -0.0082,  0.0474]],
        device='cuda:

In [None]:
diff = torch.abs(activations_hf["layer_0_attention"] - activations["layer_0_attention"])
max_diff = diff.max()
max_diff

tensor(0.0156, device='cuda:0', dtype=torch.bfloat16)

In [None]:
logits = outputs.logits
last_token_logits = logits[0, -1, :] 
predicted_token_id = torch.argmax(last_token_logits, dim=-1)
predicted_token = tokenizer.decode(predicted_token_id)
print(f"\nPredicted next token ID: {predicted_token_id}")
print(f"Predicted next token: '{predicted_token}'")


Predicted next token ID: 23754
Predicted next token: ' ?

'


In [None]:
embeddings_hf = activations['embeddings']

In [None]:
print("Embeddings:", activations['embeddings'], activations['embeddings'].shape, activations['embeddings'].dtype)

Embeddings: tensor([[[-0.0006,  0.0005, -0.0049,  ..., -0.0001, -0.0004, -0.0010],
         [-0.0060, -0.0042,  0.0084,  ...,  0.0510,  0.0008, -0.0086],
         [-0.0225, -0.0293,  0.0107,  ..., -0.0100,  0.0104, -0.0571]]],
       device='cuda:0', dtype=torch.bfloat16) torch.Size([1, 3, 3584]) torch.bfloat16


In [None]:
print("Layer 0:", activations['layer_0'])

Layer 0: tensor([[[-1.6078,  1.7398,  0.9056,  ...,  2.0966, -0.3973, -0.7354],
         [-0.4516,  1.4271,  0.0635,  ...,  0.9967,  0.3587, -1.0365],
         [ 0.3759,  0.4770,  0.5341,  ...,  0.9871,  0.5478,  0.1407]]],
       device='cuda:0')


In [None]:
print("Layer 0:", activations['layer_0_attention'])

Layer 0: tensor([[[-0.4713,  0.8825,  0.1737,  ...,  0.5539, -0.0437, -0.2646],
         [ 0.0797,  0.8864, -0.3313,  ...,  0.1187,  0.4419, -0.5363],
         [ 0.0925,  0.0708,  0.4678,  ...,  0.3931,  0.7030, -0.1090]]],
       device='cuda:0')


In [None]:
print("Layer 0:", activations['layer_0_input_layernorm'])

Layer 0: tensor([[-0.0330,  0.0282, -0.2753,  ..., -0.0080, -0.0245, -0.0518],
        [-0.0433, -0.0326,  0.0590,  ...,  0.3580,  0.0059, -0.0567],
        [-0.1573, -0.2197,  0.0736,  ..., -0.0686,  0.0723, -0.3675]],
       device='cuda:0')


In [1]:
import itertools
import sys
import time
from pathlib import Path
from typing import Optional, Tuple, Union 

import torch
import torch._dynamo.config
import torch._inductor.config
from torch.nn.attention.flex_attention import BlockMask, create_block_mask

from torch.nn import functional as F

In [2]:
from model import Transformer
from tokenizer import get_tokenizer
import model as qwen2_model

In [3]:
orig_rope_mine = qwen2_model.apply_rotary_pos_emb            # keep original

saved_custom_sin = None
saved_custom_cos = None

q_out_arr_custom, k_out_arr_custom = [], []
def _rope_debug_mine(q, k, cos, sin, unsqueeze_dim=1):
    global saved_custom_sin, saved_custom_cos, q_out_arr_custom, k_out_arr_custom
    print(f"\n[MINE] RoPE args:")
    print("q.shape       =", tuple(q.shape), " q[0,0,:2,:4] =", q[0, 0, :2, :4])
    print("k.shape       =", tuple(k.shape), " k[0,0,:2,:4] =", k[0, 0, :2, :4])
    print("cos.shape     =", tuple(cos.shape), " cos[0,:4]   =", cos[0, :4])
    print("sin.shape     =", tuple(sin.shape), " sin[0,:4]   =", sin[0, :4])
    saved_custom_sin = sin
    saved_custom_cos = cos
    q_out, k_out = orig_rope_mine(q, k, cos, sin, unsqueeze_dim)
    q_out_arr_custom.append(q_out)
    k_out_arr_custom.append(k_out)
    print("\n[MINE] q after RoPE :", q_out[0, 0, :2, :8])
    print("[MINE] k after RoPE :", k_out[0, 0, :2, :8])
    return q_out, k_out
qwen2_model.apply_rotary_pos_emb = _rope_debug_mine
qwen2_model.Attention.apply_rotary_pos_emb = staticmethod(_rope_debug_mine)

In [4]:
import torch.nn.functional as F

if not hasattr(F, "_orig_sdpa"):
    F._orig_sdpa = F.scaled_dot_product_attention

q_sdpa_out_arr_custom, k_sdpa_out_arr_custom, v_sdpa_out_arr_custom, sdpa_out_arr_custom = [], [], [], []

def _debug_sdpa(q, k, v, attn_mask=None, dropout_p=0.0, scale=None, is_causal=False):
    print("\n[SDPA] Input Shapes:")
    print("q", q.shape, q.dtype, q.device)
    print("k", k.shape, k.dtype, k.device)
    print("v", v.shape, v.dtype, v.device)
    if attn_mask is None:
        print("attn_mask =", "None" )
    else:
        print("attn_mask =", attn_mask.shape, " attn_mask[0, :2, :4] =", attn_mask)
    
    print("dropout_p =", dropout_p)
    print("is_causal =", is_causal)
    if scale is not None:
        print("scale =", scale)
    else:
        print("scale = None (default scaling)")

    q_sdpa_out_arr_custom.append(q.clone())
    k_sdpa_out_arr_custom.append(k.clone())
    v_sdpa_out_arr_custom.append(v.clone())
    with sdpa_kernel(SDPBackend.MATH):
        out = F._orig_sdpa(q, k, v, attn_mask=attn_mask, dropout_p=dropout_p, is_causal=is_causal)
    # out = F._orig_sdpa(q, k, v, dropout_p=dropout_p, is_causal=True)
    print("[SDPA] Output[0, 0, :2, :4] =", out[0, 0, :2, :4])
    sdpa_out_arr_custom.append(out.clone())

    return out

F.scaled_dot_product_attention = _debug_sdpa


In [5]:
checkpoint_path = Path("checkpoints/deepseek-ai/DeepSeek-R1-Distill-Qwen-7B/model.pth")
tokenizer_path = checkpoint_path.parent / "tokenizer.json"

In [6]:
tokenizer = get_tokenizer(tokenizer_path, checkpoint_path)

In [7]:
def _load_model(checkpoint_path, device, precision, use_tp):
    use_cuda = 'cuda' in device
    with torch.device('meta'):
        model = Transformer.from_name(checkpoint_path.parent.name)
    model = model.to_empty(device=torch.device(device))
    model = model.to(dtype=precision)
    if "int8" in str(checkpoint_path):
        print("Using int8 weight-only quantization!")
        from quantize import WeightOnlyInt8QuantHandler
        simple_quantizer = WeightOnlyInt8QuantHandler(model)
        model = simple_quantizer.convert_for_runtime()

    if "int4" in str(checkpoint_path):
        print("Using int4 weight-only quantization!")
        path_comps = checkpoint_path.name.split(".")
        groupsize = int(path_comps[-2][1:])
        from quantize import WeightOnlyInt4QuantHandler
        simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
        model = simple_quantizer.convert_for_runtime()

    checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
    if "model" in checkpoint and "stories" in str(checkpoint_path):
        checkpoint = checkpoint["model"]
    model.load_state_dict(checkpoint, assign=True)

    if use_tp:
        from tp import apply_tp
        print("Applying tensor parallel to model ...")
        apply_tp(model)
    model = model.to(device)
    return model.eval()

device = "cuda"
precision = torch.bfloat16
use_tp = False

model = _load_model(checkpoint_path, device, precision, use_tp)
with torch.device(device):
    model.setup_caches(max_batch_size=1, max_seq_length=212)
model.causal_mask.device

device(type='cuda', index=0)

In [8]:
model.causal_mask.shape

torch.Size([216, 216])

In [9]:
def encode_tokens(tokenizer, string, bos=True, device="cuda"):
    tokens = tokenizer.encode(string)
    if bos:
        tokens = [tokenizer.bos_id()] + tokens
    return torch.tensor(tokens, dtype=torch.int, device=device)
text = "The fifth smallest factor of 2012 is"
encoded = encode_tokens(tokenizer, text, bos=True, device=device)

In [10]:
prompt = encoded.view(1, -1).repeat(1, 1)
print(prompt)

tensor([[151646,    785,  17702,  24632,   8168,    315,    220,     17,     15,
             16,     17,    374]], device='cuda:0', dtype=torch.int32)


In [11]:
prompt_length = prompt.size(-1)
input_pos = torch.arange(0, prompt_length, device=device)

In [12]:
activations = {}

def get_activation(name):
    def hook(model, input, output):
        if isinstance(output, tuple):
            activations[name] = output[0].detach()
        else:
            activations[name] = output.detach()
    return hook

model.tok_embeddings.register_forward_hook(get_activation('embeddings'))
for i, block in enumerate(model.layers):
    block.register_forward_hook(get_activation(f'layer_{i}_output'))
    block.attention.register_forward_hook(get_activation(f'layer_{i}_attention'))
    block.attention_norm.register_forward_hook(get_activation(f'layer_{i}_attn_norm'))
    block.attention.q_proj.register_forward_hook(get_activation(f'layer_{i}_q_proj'))
    block.attention.k_proj.register_forward_hook(get_activation(f'layer_{i}_k_proj'))
    block.attention.v_proj.register_forward_hook(get_activation(f'layer_{i}_v_proj'))
    

In [13]:

def multinomial_sample_one_no_sync(probs_sort): # Does multinomial sampling without a cuda synchronization
    q = torch.empty_like(probs_sort).exponential_(1)
    return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)

def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None):
    logits = logits / max(temperature, 1e-5)

    if top_k is not None:
        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        pivot = v.select(-1, -1).unsqueeze(-1)
        logits = torch.where(logits < pivot, -float("Inf"), logits)
    probs = torch.nn.functional.softmax(logits, dim=-1)
    return probs
def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
    probs = logits_to_probs(logits[:, -1], temperature, top_k)
    idx_next = multinomial_sample_one_no_sync(probs)
    return idx_next, probs
def prefill(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> torch.Tensor:
    # input_pos: [B, S]
    logits = model(x, input_pos)
    return sample(logits, **sampling_kwargs)[0]
print("shapes", prompt.shape, input_pos.shape)
next_token = prefill(model, prompt, input_pos, temperature=0, top_k=1).clone()

shapes torch.Size([1, 12]) torch.Size([12])
cuda:0
cuda:0



[MINE] RoPE args:
q.shape       = (1, 28, 12, 128)  q[0,0,:2,:4] = tensor([[ 1.6641,  3.5312, -0.9570,  3.0000],
        [ 0.6055,  1.7969, -0.3477,  1.7656]], device='cuda:0',
       dtype=torch.bfloat16, grad_fn=<SliceBackward0>)
k.shape       = (1, 4, 12, 128)  k[0,0,:2,:4] = tensor([[-0.4316,  1.1250,  0.5703,  0.3926],
        [-0.7812, -0.2080,  1.9688, -0.3086]], device='cuda:0',
       dtype=torch.bfloat16, grad_fn=<SliceBackward0>)
cos.shape     = (1, 131072, 128)  cos[0,:4]   = tensor([[ 1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,
          1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,
          1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,
          1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,
          1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,
          1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,
          1.0000,  1

NameError: name 'sdpa_kernel' is not defined

In [None]:
print(saved_custom_sin.dtype)

torch.bfloat16


In [None]:
activations

{'embeddings': tensor([[[-0.0006,  0.0005, -0.0049,  ..., -0.0001, -0.0004, -0.0010],
          [ 0.0168, -0.0157,  0.0140,  ..., -0.0033,  0.0210, -0.0420],
          [ 0.0033, -0.0175, -0.0221,  ..., -0.0092,  0.0140,  0.0275],
          ...,
          [ 0.0041,  0.0007,  0.0008,  ..., -0.0002, -0.0049, -0.0045],
          [-0.0042, -0.0027, -0.0135,  ..., -0.0029, -0.0064,  0.0028],
          [ 0.0098, -0.0019, -0.0090,  ..., -0.0132, -0.0011,  0.0065]]],
        device='cuda:0', dtype=torch.bfloat16),
 'layer_0_attn_norm': tensor([[[-0.0330,  0.0282, -0.2754,  ..., -0.0080, -0.0244, -0.0520],
          [ 0.1216, -0.1216,  0.0996,  ..., -0.0236,  0.1504, -0.2793],
          [ 0.0240, -0.1348, -0.1562,  ..., -0.0654,  0.1001,  0.1816],
          ...,
          [ 0.0359,  0.0067,  0.0070,  ..., -0.0021, -0.0432, -0.0366],
          [-0.0364, -0.0254, -0.1147,  ..., -0.0247, -0.0549,  0.0222],
          [ 0.0767, -0.0156, -0.0703,  ..., -0.1025, -0.0082,  0.0474]]],
        device='cud

In [None]:
activations['layer_27_output'].shape, activations_hf['']

In [None]:
next_token.shape

torch.Size([1, 1])

In [None]:
input_pos = torch.tensor([prompt_length], device=device, dtype=torch.int)

In [None]:
logits = model(next_token, input_pos)

3 1

[MINE] RoPE args:
q.shape       = (1, 28, 1, 128)  q[0,0,:2,:4] = tensor([[0.4082, 2.1250, 0.2441, 1.2969]], device='cuda:0',
       dtype=torch.bfloat16, grad_fn=<SliceBackward0>)
k.shape       = (1, 4, 1, 128)  k[0,0,:2,:4] = tensor([[-1.1797,  0.7305,  1.3672, -0.1553]], device='cuda:0',
       dtype=torch.bfloat16, grad_fn=<SliceBackward0>)
cos.shape     = (1, 4, 128)  cos[0,:4]   = tensor([[ 1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,
          1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,
          1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,
          1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,
          1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,
          1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,
          1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,  1.0000,
          1.0000,  1.0000,  1.0000,  1

In [None]:
next_token, next_prob = sample(logits, temperature=0, top_k=1)

In [None]:
next_token.shape
next_token

tensor([[32313]], device='cuda:0', dtype=torch.int32)

In [None]:
print(tokenizer.decode(next_token[0].cpu().numpy()))

!




In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM

# Load model and tokenizer
model_id = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"  # or another chat-based model
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto")

# Define your prompt
prompt = "What is the capital of France?"

# Construct the chat template messages
messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": prompt}
]

# Use the tokenizer's chat template (requires tokenizer config to include `chat_template`)
input_ids = tokenizer.apply_chat_template(messages, return_tensors="pt")

# Run generation
outputs = model.generate(input_ids=input_ids, max_new_tokens=100)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))