In [1]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load model and tokenizer
model_name = "deepseek-ai/DeepSeek-R1-Distill-Qwen-7B"
tokenizer = AutoTokenizer.from_pretrained(model_name)


In [3]:
lm_model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
lm_model = lm_model.to("cuda")

Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  8.51it/s]


In [29]:
print(lm_model.config)

Qwen2Config {
  "architectures": [
    "Qwen2ForCausalLM"
  ],
  "attention_dropout": 0.0,
  "bos_token_id": 151643,
  "eos_token_id": 151643,
  "hidden_act": "silu",
  "hidden_size": 3584,
  "initializer_range": 0.02,
  "intermediate_size": 18944,
  "max_position_embeddings": 131072,
  "max_window_layers": 28,
  "model_type": "qwen2",
  "num_attention_heads": 28,
  "num_hidden_layers": 28,
  "num_key_value_heads": 4,
  "rms_norm_eps": 1e-06,
  "rope_scaling": null,
  "rope_theta": 10000,
  "sliding_window": 4096,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.52.4",
  "use_cache": true,
  "use_mrope": false,
  "use_sliding_window": false,
  "vocab_size": 152064
}



In [4]:
lm_model

Qwen2ForCausalLM(
  (model): Qwen2Model(
    (embed_tokens): Embedding(152064, 3584)
    (layers): ModuleList(
      (0-27): 28 x Qwen2DecoderLayer(
        (self_attn): Qwen2Attention(
          (q_proj): Linear(in_features=3584, out_features=3584, bias=True)
          (k_proj): Linear(in_features=3584, out_features=512, bias=True)
          (v_proj): Linear(in_features=3584, out_features=512, bias=True)
          (o_proj): Linear(in_features=3584, out_features=3584, bias=False)
        )
        (mlp): Qwen2MLP(
          (gate_proj): Linear(in_features=3584, out_features=18944, bias=False)
          (up_proj): Linear(in_features=3584, out_features=18944, bias=False)
          (down_proj): Linear(in_features=18944, out_features=3584, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): Qwen2RMSNorm((3584,), eps=1e-06)
        (post_attention_layernorm): Qwen2RMSNorm((3584,), eps=1e-06)
      )
    )
    (norm): Qwen2RMSNorm((3584,), eps=1e-06)
    (rotary_emb):

In [5]:
text = "Hello world"
inputs = tokenizer(text, return_tensors="pt")
inputs = inputs["input_ids"].to("cuda")

In [6]:
inputs

tensor([[151646,   9707,   1879]], device='cuda:0')

In [7]:
for i in range(28):
    print(lm_model.model.layers[i])

Qwen2DecoderLayer(
  (self_attn): Qwen2Attention(
    (q_proj): Linear(in_features=3584, out_features=3584, bias=True)
    (k_proj): Linear(in_features=3584, out_features=512, bias=True)
    (v_proj): Linear(in_features=3584, out_features=512, bias=True)
    (o_proj): Linear(in_features=3584, out_features=3584, bias=False)
  )
  (mlp): Qwen2MLP(
    (gate_proj): Linear(in_features=3584, out_features=18944, bias=False)
    (up_proj): Linear(in_features=3584, out_features=18944, bias=False)
    (down_proj): Linear(in_features=18944, out_features=3584, bias=False)
    (act_fn): SiLU()
  )
  (input_layernorm): Qwen2RMSNorm((3584,), eps=1e-06)
  (post_attention_layernorm): Qwen2RMSNorm((3584,), eps=1e-06)
)
Qwen2DecoderLayer(
  (self_attn): Qwen2Attention(
    (q_proj): Linear(in_features=3584, out_features=3584, bias=True)
    (k_proj): Linear(in_features=3584, out_features=512, bias=True)
    (v_proj): Linear(in_features=3584, out_features=512, bias=True)
    (o_proj): Linear(in_features=

In [28]:
print(lm_model.model.layers[0].self_attn.rotary_emb)

AttributeError: 'Qwen2Attention' object has no attribute 'rotary_emb'

In [8]:

# Dictionary to store activations
activations = {}

def get_activation(name):
    def hook(model, input, output):
        #print(name, output, type(output))
        if("layer" in name):
            activations[name] = output[0].detach()  # For transformer layers, output is a tuple (last_hidden_state, past_key_values)
        else:
            activations[name] = output.detach()
    return hook

# Register hooks for all transformer layers
# DeepSeek-R1-Distill-Qwen-7B has 32 layers (layers 0-31)
for i in range(28):
    lm_model.model.layers[i].register_forward_hook(get_activation(f'layer_{i}'))
    lm_model.model.layers[i].self_attn.register_forward_hook(get_activation(f'layer_{i}_attention'))
    lm_model.model.layers[i].input_layernorm.register_forward_hook(get_activation(f'layer_{i}_input_layernorm'))
    lm_model.model.layers[i].self_attn.q_proj.register_forward_hook(get_activation(f'layer_{i}_q_proj'))
    lm_model.model.layers[i].self_attn.k_proj.register_forward_hook(get_activation(f'layer_{i}_k_proj'))
    lm_model.model.layers[i].self_attn.v_proj.register_forward_hook(get_activation(f'layer_{i}_v_proj'))

"""
# You can also register hooks for specific components within layers
# For example, attention and MLP outputs:
for i in range(28):
    model.layers[i].self_attn.register_forward_hook(get_activation(f'layer_{i}_attention'))
    model.layers[i].mlp.register_forward_hook(get_activation(f'layer_{i}_mlp'))
"""
# Register hook for embeddings
lm_model.model.embed_tokens.register_forward_hook(get_activation('embeddings'))

# Forward pass
outputs = lm_model(input_ids=inputs)

# Access activations
print(f"Embeddings shape: {activations['embeddings'].shape}")
print(f"Layer 0 shape: {activations['layer_0'][0].shape}")  # Note: output is tuple, take first element
print(f"Layer 15 shape: {activations['layer_15'][0].shape}")
print(f"Layer 27 shape: {activations['layer_27'][0].shape}")

# Print all available activation keys
print("\nAll captured activations:")
for key in activations.keys():
    if isinstance(activations[key], tuple):
        print(f"{key}: {activations[key][0].shape}")
    else:
        print(f"{key}: {activations[key].shape}")

Embeddings shape: torch.Size([1, 3, 3584])
Layer 0 shape: torch.Size([3, 3584])
Layer 15 shape: torch.Size([3, 3584])
Layer 27 shape: torch.Size([3, 3584])

All captured activations:
embeddings: torch.Size([1, 3, 3584])
layer_0_input_layernorm: torch.Size([3, 3584])
layer_0_q_proj: torch.Size([3, 3584])
layer_0_k_proj: torch.Size([3, 512])
layer_0_v_proj: torch.Size([3, 512])
layer_0_attention: torch.Size([1, 3, 3584])
layer_0: torch.Size([1, 3, 3584])
layer_1_input_layernorm: torch.Size([3, 3584])
layer_1_q_proj: torch.Size([3, 3584])
layer_1_k_proj: torch.Size([3, 512])
layer_1_v_proj: torch.Size([3, 512])
layer_1_attention: torch.Size([1, 3, 3584])
layer_1: torch.Size([1, 3, 3584])
layer_2_input_layernorm: torch.Size([3, 3584])
layer_2_q_proj: torch.Size([3, 3584])
layer_2_k_proj: torch.Size([3, 512])
layer_2_v_proj: torch.Size([3, 512])
layer_2_attention: torch.Size([1, 3, 3584])
layer_2: torch.Size([1, 3, 3584])
layer_3_input_layernorm: torch.Size([3, 3584])
layer_3_q_proj: torch.

In [9]:
print(activations.keys())

dict_keys(['embeddings', 'layer_0_input_layernorm', 'layer_0_q_proj', 'layer_0_k_proj', 'layer_0_v_proj', 'layer_0_attention', 'layer_0', 'layer_1_input_layernorm', 'layer_1_q_proj', 'layer_1_k_proj', 'layer_1_v_proj', 'layer_1_attention', 'layer_1', 'layer_2_input_layernorm', 'layer_2_q_proj', 'layer_2_k_proj', 'layer_2_v_proj', 'layer_2_attention', 'layer_2', 'layer_3_input_layernorm', 'layer_3_q_proj', 'layer_3_k_proj', 'layer_3_v_proj', 'layer_3_attention', 'layer_3', 'layer_4_input_layernorm', 'layer_4_q_proj', 'layer_4_k_proj', 'layer_4_v_proj', 'layer_4_attention', 'layer_4', 'layer_5_input_layernorm', 'layer_5_q_proj', 'layer_5_k_proj', 'layer_5_v_proj', 'layer_5_attention', 'layer_5', 'layer_6_input_layernorm', 'layer_6_q_proj', 'layer_6_k_proj', 'layer_6_v_proj', 'layer_6_attention', 'layer_6', 'layer_7_input_layernorm', 'layer_7_q_proj', 'layer_7_k_proj', 'layer_7_v_proj', 'layer_7_attention', 'layer_7', 'layer_8_input_layernorm', 'layer_8_q_proj', 'layer_8_k_proj', 'layer_8

In [10]:
activations

{'embeddings': tensor([[[-0.0006,  0.0005, -0.0049,  ..., -0.0001, -0.0004, -0.0010],
          [-0.0060, -0.0042,  0.0084,  ...,  0.0510,  0.0008, -0.0086],
          [-0.0225, -0.0293,  0.0107,  ..., -0.0100,  0.0104, -0.0571]]],
        device='cuda:0', dtype=torch.bfloat16),
 'layer_0_input_layernorm': tensor([[-0.0330,  0.0282, -0.2754,  ..., -0.0080, -0.0244, -0.0520],
         [-0.0435, -0.0325,  0.0591,  ...,  0.3574,  0.0059, -0.0566],
         [-0.1572, -0.2197,  0.0737,  ..., -0.0688,  0.0723, -0.3652]],
        device='cuda:0', dtype=torch.bfloat16),
 'layer_0_q_proj': tensor([[ 1.6641,  3.5312, -0.9570,  ...,  0.3555, -0.1641,  0.5820],
         [ 1.5234,  2.8438, -0.7656,  ..., -1.0703, -0.7930,  0.3828],
         [ 0.6250,  2.1250, -1.3750,  ...,  0.1187, -0.3633,  0.3320]],
        device='cuda:0', dtype=torch.bfloat16),
 'layer_0_k_proj': tensor([[-0.4316,  1.1250,  0.5703,  ...,  2.9688, -5.7812,  8.5000],
         [-0.9180,  1.1875,  0.5117,  ...,  3.4219, -6.4062,  

In [14]:
outputs

CausalLMOutputWithPast(loss=None, logits=tensor([[[-0.2520,  1.5391,  0.8359,  ..., -2.0625, -2.0625, -2.0625],
         [ 0.4375,  2.5938,  1.3359,  ..., -2.4062, -2.4062, -2.4062],
         [25.3750,  9.1875,  3.6094,  ...,  2.3438,  2.3438,  2.3438]]],
       device='cuda:0', dtype=torch.bfloat16, grad_fn=<UnsafeViewBackward0>), past_key_values=<transformers.cache_utils.DynamicCache object at 0x76fd0cd03d00>, hidden_states=None, attentions=None)

In [17]:
logits = outputs.logits
last_token_logits = logits[0, -1, :] 
predicted_token_id = torch.argmax(last_token_logits, dim=-1)
predicted_token = tokenizer.decode(predicted_token_id)
print(f"\nPredicted next token ID: {predicted_token_id}")
print(f"Predicted next token: '{predicted_token}'")


Predicted next token ID: 0
Predicted next token: '!'


In [12]:
embeddings_hf = activations['embeddings']

In [17]:
print("Embeddings:", activations['embeddings'], activations['embeddings'].shape, activations['embeddings'].dtype)

Embeddings: tensor([[[-0.0006,  0.0005, -0.0049,  ..., -0.0001, -0.0004, -0.0010],
         [-0.0060, -0.0042,  0.0084,  ...,  0.0510,  0.0008, -0.0086],
         [ 0.0299,  0.0116,  0.0133,  ..., -0.0177, -0.0156, -0.0669]]],
       device='cuda:0', dtype=torch.bfloat16) torch.Size([1, 3, 3584]) torch.bfloat16


In [18]:
print("Layer 0:", activations['layer_0'])

Layer 0: tensor([[[-1.6078,  1.7398,  0.9056,  ...,  2.0966, -0.3973, -0.7354],
         [-0.4516,  1.4271,  0.0635,  ...,  0.9967,  0.3587, -1.0365],
         [ 0.3759,  0.4770,  0.5341,  ...,  0.9871,  0.5478,  0.1407]]],
       device='cuda:0')


In [19]:
print("Layer 0:", activations['layer_0_attention'])

Layer 0: tensor([[[-0.4713,  0.8825,  0.1737,  ...,  0.5539, -0.0437, -0.2646],
         [ 0.0797,  0.8864, -0.3313,  ...,  0.1187,  0.4419, -0.5363],
         [ 0.0925,  0.0708,  0.4678,  ...,  0.3931,  0.7030, -0.1090]]],
       device='cuda:0')


In [20]:
print("Layer 0:", activations['layer_0_input_layernorm'])

Layer 0: tensor([[-0.0330,  0.0282, -0.2753,  ..., -0.0080, -0.0245, -0.0518],
        [-0.0433, -0.0326,  0.0590,  ...,  0.3580,  0.0059, -0.0567],
        [-0.1573, -0.2197,  0.0736,  ..., -0.0686,  0.0723, -0.3675]],
       device='cuda:0')


In [1]:
import itertools
import sys
import time
from pathlib import Path
from typing import Optional, Tuple, Union

import torch
import torch._dynamo.config
import torch._inductor.config
from torch.nn.attention.flex_attention import BlockMask, create_block_mask

from torch.nn import functional as F

In [2]:
from model import Transformer
from tokenizer import get_tokenizer

In [3]:
checkpoint_path = Path("checkpoints/deepseek-ai/DeepSeek-R1-Distill-Qwen-7B/model.pth")
tokenizer_path = checkpoint_path.parent / "tokenizer.json"

In [4]:
tokenizer = get_tokenizer(tokenizer_path, checkpoint_path)

In [None]:
def _load_model(checkpoint_path, device, precision, use_tp):
    use_cuda = 'cuda' in device
    model = Transformer.from_name(checkpoint_path.parent.name).to_empty(device=device, dtype=precision)

    if "int8" in str(checkpoint_path):
        print("Using int8 weight-only quantization!")
        from quantize import WeightOnlyInt8QuantHandler
        simple_quantizer = WeightOnlyInt8QuantHandler(model)
        model = simple_quantizer.convert_for_runtime()

    if "int4" in str(checkpoint_path):
        print("Using int4 weight-only quantization!")
        path_comps = checkpoint_path.name.split(".")
        groupsize = int(path_comps[-2][1:])
        from quantize import WeightOnlyInt4QuantHandler
        simple_quantizer = WeightOnlyInt4QuantHandler(model, groupsize)
        model = simple_quantizer.convert_for_runtime()

    checkpoint = torch.load(str(checkpoint_path), mmap=True, weights_only=True)
    if "model" in checkpoint and "stories" in str(checkpoint_path):
        checkpoint = checkpoint["model"]
    model.load_state_dict(checkpoint, assign=True)

    if use_tp:
        from tp import apply_tp
        print("Applying tensor parallel to model ...")
        apply_tp(model)

    model = model.to(device=device, dtype=precision)
    return model.eval()

device = "cuda"
precision = torch.bfloat16
use_tp = False

model = _load_model(checkpoint_path, device, precision, use_tp)
with torch.device(device):
    model.setup_caches(max_batch_size=1, max_seq_length=212)
model.causal_mask.device

NotImplementedError: Cannot copy out of meta tensor; no data! Please use torch.nn.Module.to_empty() instead of torch.nn.Module.to() when moving module from meta to a different device.

In [6]:
model.layers[0].attention_norm.weight

Parameter containing:
tensor([0.1797, 0.1924, 0.1768,  ..., 0.1758, 0.1777, 0.1650], device='cuda:0',
       dtype=torch.bfloat16, requires_grad=True)

In [7]:
model.tok_embeddings.weight.data

tensor([[ 5.6152e-03, -2.1606e-02, -8.9111e-03,  ..., -6.9580e-03,
          3.6621e-02, -1.3245e-02],
        [ 4.4922e-02,  4.7363e-02,  1.6602e-02,  ...,  7.2021e-03,
          2.9564e-04, -1.6235e-02],
        [-1.6724e-02, -6.4392e-03,  1.4400e-04,  ...,  1.0824e-04,
         -2.1973e-02,  2.3804e-03],
        ...,
        [-1.1755e-37,  1.1755e-37,  1.1755e-37,  ..., -1.1755e-37,
          1.1755e-37,  1.1755e-37],
        [ 1.1755e-37, -1.1755e-37,  1.1755e-37,  ..., -1.1755e-37,
          1.1755e-37, -1.1755e-37],
        [ 1.1755e-37, -1.1755e-37, -1.1755e-37,  ...,  1.1755e-37,
          1.1755e-37, -1.1755e-37]], device='cuda:0', dtype=torch.bfloat16)

In [8]:
def encode_tokens(tokenizer, string, bos=True, device="cuda"):
    tokens = tokenizer.encode(string)
    if bos:
        tokens = [tokenizer.bos_id()] + tokens
    return torch.tensor(tokens, dtype=torch.int, device=device)
encoded = encode_tokens(tokenizer, "Hello World", bos=True, device=device)

In [9]:
prompt = encoded.view(1, -1).repeat(1, 1)

In [10]:
print(prompt)

tensor([[151646,   9707,   4337]], device='cuda:0', dtype=torch.int32)


In [11]:
prompt[0][2] = 1879
prompt

tensor([[151646,   9707,   1879]], device='cuda:0', dtype=torch.int32)

In [12]:
prompt_length = prompt.size(-1)

In [13]:
input_pos = torch.arange(0, prompt_length, device=device)

In [14]:
input_pos.shape

torch.Size([3])

In [15]:
print(lm_model.model.layers[0].self_attn.q_proj.weight)
print(model.layers[0].attention.q_proj.weight)
print(torch.allclose(lm_model.model.layers[0].self_attn.q_proj.weight, model.layers[0].attention.q_proj.weight, rtol=1e-5, atol=1e-8))

NameError: name 'lm_model' is not defined

In [None]:
import torch

# Get the two tensors
a = lm_model.model.layers[0].self_attn.q_proj.weight
b = model.layers[0].attention.q_proj.weight

# Check if they are close numerically
print("Shapes:", a.shape, b.shape)
print("Same shape:", a.shape == b.shape)

# Check elementwise closeness
equal = torch.allclose(a, b, rtol=1e-5, atol=1e-8)
print("Allclose:", equal)

# If not equal, show diagnostic
if not equal:
    abs_diff = (a - b).abs()
    print("Max abs diff:", abs_diff.max().item())
    print("Mean abs diff:", abs_diff.mean().item())
    print("Std abs diff:", abs_diff.std().item())

    # Optionally, see where they differ most
    idx = torch.argmax(abs_diff)
    i, j = divmod(idx.item(), a.shape[1])
    print(f"Largest diff at position ({i}, {j}): {a[i,j].item()} vs {b[i,j].item()}")


Shapes: torch.Size([3584, 3584]) torch.Size([3584, 3584])
Same shape: True
Allclose: True


In [16]:
activations = {}

def get_activation(name):
    def hook(model, input, output):
        if isinstance(output, tuple):
            activations[name] = output[0].detach()
        else:
            activations[name] = output.detach()
    return hook

model.tok_embeddings.register_forward_hook(get_activation('embeddings'))
for i, block in enumerate(model.layers):
    block.register_forward_hook(get_activation(f'layer_{i}_output'))
    block.attention.register_forward_hook(get_activation(f'layer_{i}_attention'))
    block.attention_norm.register_forward_hook(get_activation(f'layer_{i}_attn_norm'))
    block.attention.q_proj.register_forward_hook(get_activation(f'layer_{i}_q_proj'))
    block.attention.k_proj.register_forward_hook(get_activation(f'layer_{i}_k_proj'))
    block.attention.v_proj.register_forward_hook(get_activation(f'layer_{i}_v_proj'))
    

In [17]:

def multinomial_sample_one_no_sync(probs_sort): # Does multinomial sampling without a cuda synchronization
    q = torch.empty_like(probs_sort).exponential_(1)
    return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)

def logits_to_probs(logits, temperature: float = 1.0, top_k: Optional[int] = None):
    logits = logits / max(temperature, 1e-5)

    if top_k is not None:
        v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        pivot = v.select(-1, -1).unsqueeze(-1)
        logits = torch.where(logits < pivot, -float("Inf"), logits)
    probs = torch.nn.functional.softmax(logits, dim=-1)
    return probs
def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
    probs = logits_to_probs(logits[:, -1], temperature, top_k)
    idx_next = multinomial_sample_one_no_sync(probs)
    return idx_next, probs
def prefill(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> torch.Tensor:
    # input_pos: [B, S]
    logits = model(x, input_pos)
    return sample(logits, **sampling_kwargs)[0]
print("shapes", prompt.shape, input_pos.shape)
next_token = prefill(model, prompt, input_pos, temperature=0, top_k=1).clone()

shapes torch.Size([1, 3]) torch.Size([3])


In [18]:
activations

{'embeddings': tensor([[[-0.0006,  0.0005, -0.0049,  ..., -0.0001, -0.0004, -0.0010],
          [-0.0060, -0.0042,  0.0084,  ...,  0.0510,  0.0008, -0.0086],
          [-0.0225, -0.0293,  0.0107,  ..., -0.0100,  0.0104, -0.0571]]],
        device='cuda:0', dtype=torch.bfloat16),
 'layer_0_attn_norm': tensor([[[-0.0330,  0.0282, -0.2754,  ..., -0.0080, -0.0244, -0.0520],
          [-0.0435, -0.0325,  0.0591,  ...,  0.3574,  0.0059, -0.0566],
          [-0.1572, -0.2197,  0.0737,  ..., -0.0688,  0.0723, -0.3652]]],
        device='cuda:0', dtype=torch.bfloat16),
 'layer_0_q_proj': tensor([[[ 1.6641,  1.5234,  0.2852,  ...,  0.1426, -0.2324,  0.5820],
          [ 1.5234,  1.5938, -0.3984,  ..., -0.4883, -0.3789,  0.3828],
          [ 0.6250,  2.2031, -1.1250,  ..., -0.0559,  0.4629,  0.3320]]],
        device='cuda:0', dtype=torch.bfloat16),
 'layer_0_k_proj': tensor([[[-0.4316,  1.1406,  0.4277,  ...,  3.0156, -5.6250,  8.5000],
          [-0.9180,  1.0234,  0.4961,  ...,  2.3906, -5.218

In [33]:
activations["layer_0_wqkv"].split([3584, 512, 512], dim=-1)

(tensor([[[ 1.6641,  1.5234,  0.2852,  ...,  0.1426, -0.2324,  0.5820],
          [ 1.5234,  1.5938, -0.3984,  ..., -0.4883, -0.3789,  0.3828],
          [ 0.6250,  2.2031, -1.1250,  ..., -0.0559,  0.4629,  0.3320]]],
        device='cuda:0', dtype=torch.bfloat16),
 tensor([[[-0.4316,  1.1406,  0.4277,  ...,  3.0156, -5.6250,  8.5000],
          [-0.9180,  1.0234,  0.4961,  ...,  2.3906, -5.2188,  8.3750],
          [-0.4551,  1.2344,  0.3711,  ...,  3.7656, -7.5938,  9.0625]]],
        device='cuda:0', dtype=torch.bfloat16),
 tensor([[[-6.9824e-02, -7.3047e-01, -8.5156e-01,  ...,  1.1406e+00,
            1.0312e+00,  9.4238e-02],
          [-1.2305e-01, -1.2695e-01, -1.3086e-01,  ...,  1.2779e-04,
            4.0234e-01,  1.1279e-01],
          [ 1.4258e-01, -1.4465e-02,  1.5820e-01,  ..., -1.7188e-01,
           -6.3672e-01,  1.7090e-01]]], device='cuda:0', dtype=torch.bfloat16))

In [34]:
q, k, v = activations["layer_0_wqkv"].split([3584, 512, 512], dim=-1)
print("q", q.shape, "k", k.shape, "v", v.shape)

q torch.Size([1, 3, 3584]) k torch.Size([1, 3, 512]) v torch.Size([1, 3, 512])


In [29]:
embeddings_custom = activations['embeddings']

In [35]:
print("Embeddings dtype:", model.tok_embeddings.weight.dtype)

Embeddings dtype: torch.bfloat16


In [None]:
torch.allclose(embeddings_custom, embeddings_hf, rtol=1e-5, atol=1e-8)

RuntimeError: BFloat16 did not match Float

In [33]:
embeddings_custom.dtype, embeddings_hf.dtype

(torch.bfloat16, torch.float32)

In [20]:
next_token.shape

torch.Size([1, 1])

In [19]:
input_pos = torch.tensor([prompt_length], device=device, dtype=torch.int)

In [20]:
logits = model(next_token, input_pos)

In [21]:
next_token, next_prob = sample(logits, temperature=0, top_k=1)

In [22]:
next_token.shape
next_token

tensor([[220]], device='cuda:0', dtype=torch.int32)

In [23]:
print(tokenizer.decode(next_token[0].cpu().numpy()))

 
