# llama3.2 from scratch

## 准备工作
首先从huggingface下载Llama-3.2-3B-Instruct，具体步骤是：
1. Pip安装huggingface-cli
`pip install -U huggingface_hub`
2. 执行`huggingface-cli login`，输入token(token来自于自己的注册)
3. 然后用transformer接口下载模型，默认路径：~/.cache
[参考](https://github.com/huggingface/transformers/blob/main/src/transformers/models/llama/modeling_llama.py)

In [1]:
from transformers import AutoTokenizer
import json
import os
import torch
from torch import nn

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# 读取并查看模型参数
with open("/HOME/scz0101/run/model_acceleration/models/config.json", "r") as f:
    config = json.load(f)
# print(config)
print(config['rope_scaling']['factor'])

32.0


In [41]:
# 读取模型权重
from safetensors import safe_open
weights_root = "/HOME/scz0101/.cache/huggingface/hub/models--meta-llama--Llama-3.2-3B-Instruct/snapshots/0cb88a4f764b7a12671c53f0838cd831a0843b95"
# 文件路径
file1 = os.path.join(weights_root, "model-00001-of-00002.safetensors")
file2 = os.path.join(weights_root, "model-00002-of-00002.safetensors")

# # 加载第一个文件
with safe_open(file1, framework="pt", device="cpu") as f:
    state_dict1 = {key: f.get_tensor(key) for key in f.keys()}
# # 查看key和size
print(json.dumps(list(state_dict1.keys()), indent=4)) ## layer0~20
# print("model.embed_tokens.weight", state_dict1["model.embed_tokens.weight"].shape) # vob_size*hidden_size
# print("model.layers.0.input_layernorm.weight", state_dict1["model.layers.0.input_layernorm.weight"].shape) #hidden_size
# print("model.layers.0.mlp.down_proj.weight", state_dict1["model.layers.0.mlp.down_proj.weight"].shape) #hidden_size * intermediate_size
# print("model.layers.0.mlp.gate_proj.weight", state_dict1["model.layers.0.mlp.gate_proj.weight"].shape) # intermediate_size * hidden_size 
# print("model.layers.0.mlp.up_proj.weight", state_dict1["model.layers.0.mlp.up_proj.weight"].shape) #intermediate_size * hidden_size 
# print("model.layers.0.post_attention_layernorm.weight",state_dict1["model.layers.0.post_attention_layernorm.weight"].shape) #hidden_size
# print("model.layers.0.self_attn.k_proj.weight",state_dict1["model.layers.0.self_attn.k_proj.weight"].shape) #(head_dim*num_key_value_heads)*hidden_size
# print("model.layers.0.self_attn.o_proj.weight",state_dict1["model.layers.0.self_attn.o_proj.weight"].shape) #(head_dim*num_attention_heads)*hidden_size
# print("model.layers.0.self_attn.q_proj.weight",state_dict1["model.layers.0.self_attn.q_proj.weight"].shape) #(head_dim*num_attention_heads)*hidden_size
# print("model.layers.0.self_attn.v_proj.weight",state_dict1["model.layers.0.self_attn.v_proj.weight"].shape) #(head_dim*num_key_value_heads)*hidden_size

# # # 加载第二个文件(为了节省时间，可以先skip)
with safe_open(file2, framework="pt", device="cpu") as f:
    state_dict2 = {key: f.get_tensor(key) for key in f.keys()}

# # 合并两个状态字典
# state_dict = {**state_dict1, **state_dict2}

# # 查看模型的权重key
print(json.dumps(list(state_dict2.keys()), indent=4)) ## layer21~27

[
    "model.embed_tokens.weight",
    "model.layers.0.input_layernorm.weight",
    "model.layers.0.mlp.down_proj.weight",
    "model.layers.0.mlp.gate_proj.weight",
    "model.layers.0.mlp.up_proj.weight",
    "model.layers.0.post_attention_layernorm.weight",
    "model.layers.0.self_attn.k_proj.weight",
    "model.layers.0.self_attn.o_proj.weight",
    "model.layers.0.self_attn.q_proj.weight",
    "model.layers.0.self_attn.v_proj.weight",
    "model.layers.1.input_layernorm.weight",
    "model.layers.1.mlp.down_proj.weight",
    "model.layers.1.mlp.gate_proj.weight",
    "model.layers.1.mlp.up_proj.weight",
    "model.layers.1.post_attention_layernorm.weight",
    "model.layers.1.self_attn.k_proj.weight",
    "model.layers.1.self_attn.o_proj.weight",
    "model.layers.1.self_attn.q_proj.weight",
    "model.layers.1.self_attn.v_proj.weight",
    "model.layers.10.input_layernorm.weight",
    "model.layers.10.mlp.down_proj.weight",
    "model.layers.10.mlp.gate_proj.weight",
    "model.

In [4]:
#先跳过实现tokenizer，可以直接使用autotokenizer，后续再来实现
model_name_or_path = "meta-llama/Llama-3.2-3B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token

In [5]:
# 创建输入
input_text = "how are you?"
inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True)
print(inputs["input_ids"])
print(tokenizer.decode(inputs["input_ids"][0].tolist(), skip_special_tokens=True))
seq_len = len(inputs["input_ids"][0].tolist()) ## token num
print(seq_len)

tensor([[128000,   5269,    527,    499,     30]])
how are you?
5


In [6]:
#开始创建embeding_layer
embedding_layer = torch.nn.Embedding(config['vocab_size'], config["hidden_size"])
embedding_layer.weight.data.copy_(state_dict1["model.embed_tokens.weight"])

tensor([[ 1.1292e-02,  9.9487e-03,  1.4160e-02,  ..., -3.5706e-03,
         -1.9775e-02,  5.3711e-03],
        [ 1.3245e-02, -3.8385e-05,  2.2461e-02,  ..., -2.6550e-03,
          3.1738e-02, -1.0681e-03],
        [ 1.9775e-02,  2.0020e-02,  2.8687e-02,  ..., -3.5248e-03,
          3.1433e-03, -7.6294e-03],
        ...,
        [-3.0975e-03,  2.1057e-03,  4.8828e-03,  ..., -2.0905e-03,
         -1.2207e-03, -2.8992e-03],
        [-3.0975e-03,  2.1057e-03,  4.8828e-03,  ..., -2.0905e-03,
         -1.2207e-03, -2.8992e-03],
        [-3.0975e-03,  2.1057e-03,  4.8828e-03,  ..., -2.0905e-03,
         -1.2207e-03, -2.8992e-03]])

In [3]:
# 创建rms_norm层，tensor2维，对每一行做归一化
class LlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * hidden_states.to(input_dtype)

In [27]:
# 准备位置编码（RoPE）
import math

def get_inv_freq_llama3(device):
    dim = config["head_dim"]
    rope_theta = config["rope_theta"]
    inv_freq = 1.0 / (rope_theta ** (torch.arange(0, dim, 2, dtype=torch.int64).float().to(device) / dim))
    # print(inv_freq)
    attention_factor = 1.0
    # 对于llama3的特殊处理
    factor = config["rope_scaling"]["factor"]  # `8` in the original implementation
    low_freq_factor = config["rope_scaling"]["low_freq_factor"]  # `1` in the original implementation
    high_freq_factor = config["rope_scaling"]["high_freq_factor"]  # `4` in the original implementation
    old_context_len = config["rope_scaling"]["original_max_position_embeddings"]  # `8192` in the original implementation

    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor



    wavelen = 2 * math.pi / inv_freq

    print(low_freq_wavelen, high_freq_wavelen, wavelen)
    # wavelen < high_freq_wavelen: do nothing
    # wavelen > low_freq_wavelen: divide by factor
    inv_freq_llama = torch.where(wavelen > low_freq_wavelen, inv_freq / factor, inv_freq)
    # otherwise: interpolate between the two, using a smooth factor
    smooth_factor = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
    smoothed_inv_freq = (1 - smooth_factor) * inv_freq_llama / factor + smooth_factor * inv_freq_llama
    is_medium_freq = ~(wavelen < high_freq_wavelen) * ~(wavelen > low_freq_wavelen)
    inv_freq_llama = torch.where(is_medium_freq, smoothed_inv_freq, inv_freq_llama)
    return inv_freq_llama, attention_factor


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.
    参考：https://blog.csdn.net/UCB001/article/details/139511775
    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


In [5]:

class LlamaMLP(nn.Module):
    def __init__(self):
        super().__init__()
        self.hidden_size = config["hidden_size"]
        self.intermediate_size = config["intermediate_size"]
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config["mlp_bias"])
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config["mlp_bias"])
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config["mlp_bias"])
        self.act_fn = nn.SiLU

    def forward(self, x):
        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
        return down_proj

In [8]:
## 番外，对于普通sdpa和flash atten
import torch
import torch.nn.functional as F
import time

# 设置随机种子以确保可重复性
torch.manual_seed(42)

# 定义输入参数
batch_size = 8
num_heads = 8
seq_len = 1024
d_model = 64

# 随机生成 Q, K, V
Q = torch.randn(batch_size, num_heads, seq_len, d_model)
K = torch.randn(batch_size, num_heads, seq_len, d_model)
V = torch.randn(batch_size, num_heads, seq_len, d_model)

# 标准 SDPA 实现
def standard_sdpa(Q, K, V):
    scale = d_model ** 0.5
    scores = torch.matmul(Q, K.transpose(-2, -1)) / scale  # QK^T / sqrt(d_k)
    attn_weights = F.softmax(scores, dim=-1)  # Softmax
    output = torch.matmul(attn_weights, V)  # 加权求和
    return output

# 使用 PyTorch 的优化实现（基于 Flash Attention）
def flash_attention(Q, K, V):
    return F.scaled_dot_product_attention(Q, K, V)

# 测试标准 SDPA
start_time = time.time()
output_standard = standard_sdpa(Q, K, V)
standard_time = time.time() - start_time

# 测试 Flash Attention
start_time = time.time()
output_flash = flash_attention(Q, K, V)
flash_time = time.time() - start_time

# 检查输出是否一致（允许微小误差）
print("输出是否接近:", torch.allclose(output_standard, output_flash, atol=1e-5))
print(f"标准 SDPA 时间: {standard_time:.6f} 秒")
print(f"Flash Attention 时间: {flash_time:.6f} 秒")

输出是否接近: True
标准 SDPA 时间: 3.444230 秒
Flash Attention 时间: 0.586423 秒


In [23]:
from typing import Callable, List, Optional, Tuple, Union
import torch
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)


def sdpa_attention_forward(
    module: torch.nn.Module,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    attention_mask: Optional[torch.Tensor],
    dropout: float = 0.0,
    scaling: Optional[float] = None,
    is_causal: Optional[bool] = None
) -> Tuple[torch.Tensor, None]:
    if hasattr(module, "num_key_value_groups"):
        key = repeat_kv(key, module.num_key_value_groups)
        value = repeat_kv(value, module.num_key_value_groups)

    causal_mask = attention_mask
    if attention_mask is not None:
        causal_mask = causal_mask[:, :, :, : key.shape[-2]]

    # SDPA with memory-efficient backend is bugged with non-contiguous inputs and custom attn_mask for some torch versions
    # Reference: https://github.com/pytorch/pytorch/issues/112577.
    query = query.contiguous()
    key = key.contiguous()
    value = value.contiguous()

    # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
    # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
    if is_causal is None:
        is_causal = causal_mask is None and query.shape[2] > 1

    # Shapes (e.g. query.shape[2]) are tensors during jit tracing, resulting in `is_causal` being a tensor.
    # We convert it to a bool for the SDPA kernel that only accepts bools.
    if torch.jit.is_tracing() and isinstance(is_causal, torch.Tensor):
        is_causal = is_causal.item()

    attn_output = torch.nn.functional.scaled_dot_product_attention(
        query,
        key,
        value,
        attn_mask=causal_mask,
        dropout_p=dropout,
        scale=scaling,
        is_causal=is_causal,
    )
    attn_output = attn_output.transpose(1, 2).contiguous()

    return attn_output, None

class LlamaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.head_dim = config["head_dim"]
        self.num_key_value_groups = config["num_attention_heads"] // config["num_key_value_heads"]
        self.scaling = self.head_dim**-0.5
        self.attention_dropout = config['attention_dropout']
        self.is_causal = True

        self.q_proj = nn.Linear(
            config['hidden_size'], config['num_attention_heads'] * self.head_dim, bias=config['attention_bias']
        )
        self.k_proj = nn.Linear(
            config['hidden_size'], config['num_key_value_heads'] * self.head_dim, bias=config['attention_bias']
        )
        self.v_proj = nn.Linear(
            config['hidden_size'], config['num_key_value_heads'] * self.head_dim, bias=config['attention_bias']
        )
        self.o_proj = nn.Linear(
            config['num_attention_heads'] * self.head_dim, config['hidden_size'], bias=config['attention_bias']
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor]
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        attn_output, attn_weights = sdpa_attention_forward(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask,
            dropout=0.0 if not self.training else self.attention_dropout,
            scaling=self.scaling
        )

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights

In [25]:
class LlamaDecoderLayer(nn.Module):
    def __init__(self, layer_idx: int):
        super().__init__()
        self.hidden_size = config["hidden_size"]

        self.self_attn = LlamaAttention(layer_idx=layer_idx)

        self.mlp = LlamaMLP()
        self.input_layernorm = LlamaRMSNorm(config['hidden_size'], eps=config['rms_norm_eps'])
        self.post_attention_layernorm = LlamaRMSNorm(config['hidden_size'], eps=config['rms_norm_eps'])

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None # necessary, but kept here for BC
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_embeddings=position_embeddings
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)
        # if output_attentions:
        #     outputs += (self_attn_weights,)

        return outputs

In [9]:
class LlamaRotaryEmbedding(nn.Module):
    def __init__(self, device=None):
        super().__init__()
        # BC: "rope_type" was originally "type"

        self.rope_type = config['rope_scaling']['rope_type']
        self.max_seq_len_cached = config['max_position_embeddings']
        self.original_max_seq_len = config['max_position_embeddings']

        self.config = config
        self.rope_init_fn = get_inv_freq_llama3

        inv_freq, self.attention_scaling = self.rope_init_fn(device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    def forward(self, x, position_ids):
        # Core RoPE block
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        position_ids_expanded = position_ids[:, None, :].float()
        # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
        device_type = x.device.type
        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False): #enabled=False表示在此上下文中禁用混合精度，所有计算将使用默认的浮点精度（通常是FP32）
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos()
            sin = emb.sin()

        # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
        cos = cos * self.attention_scaling
        sin = sin * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)


In [31]:
class LlamaModel(nn.Module):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]

    Args:
        config: LlamaConfig
    """

    def __init__(self):
        super().__init__()
        # self.padding_idx = config['pad_token_id']
        self.vocab_size = config['vocab_size']

        self.embed_tokens = nn.Embedding(self.vocab_size, config['hidden_size'])
        self.layers = nn.ModuleList(
            [LlamaDecoderLayer(layer_idx) for layer_idx in range(config['num_hidden_layers'])]
        )
        self.norm = LlamaRMSNorm(config['hidden_size'], eps=config['rms_norm_eps'])
        self.rotary_emb = LlamaRotaryEmbedding()

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None
    ):

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        inputs_embeds = self.embed_tokens(input_ids)

        hidden_states = inputs_embeds

        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
            layer_outputs = decoder_layer(
                hidden_states,
                attention_mask=attention_mask,
                position_embeddings=position_embeddings
            )

            hidden_states = layer_outputs[0]

        hidden_states = self.norm(hidden_states)

        return hidden_states

In [33]:
class LlamaForCausalLM(nn.Module):
    def __init__(self):
        super().__init__()
        self.model = LlamaModel()
        self.vocab_size = config['vocab_size']
        self.lm_head = nn.Linear(config['hidden_size'], config['vocab_size'], bias=False)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        logits_to_keep: Union[int, torch.Tensor] = 0
    ):

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids
        )

        hidden_states = outputs[0]
        # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
        slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
        logits = self.lm_head(hidden_states[:, slice_indices, :])

        return logits

In [37]:
# from torchsummary import summary
model = LlamaForCausalLM()

# for name, module in model.named_modules():
#     print(f"层名称：{name}")
#     print(f"参数数量：{sum(p.numel() for p in module.parameters())}")
from safetensors.torch import load_file
# 加载第一个分片
model_part1 = load_file(file1)

# 加载第二个分片
model_part2 = load_file(file2)

# 合并模型参数
model_state_dict = {**model_part1, **model_part2}

# 加载合并后的状态字典
model.load_state_dict(model_state_dict)


8192.0 2048.0 tensor([6.2832e+00, 7.7131e+00, 9.4683e+00, 1.1623e+01, 1.4268e+01, 1.7515e+01,
        2.1501e+01, 2.6394e+01, 3.2400e+01, 3.9774e+01, 4.8825e+01, 5.9936e+01,
        7.3576e+01, 9.0320e+01, 1.1087e+02, 1.3611e+02, 1.6708e+02, 2.0510e+02,
        2.5178e+02, 3.0907e+02, 3.7941e+02, 4.6575e+02, 5.7174e+02, 7.0185e+02,
        8.6158e+02, 1.0576e+03, 1.2983e+03, 1.5938e+03, 1.9565e+03, 2.4017e+03,
        2.9483e+03, 3.6192e+03, 4.4429e+03, 5.4540e+03, 6.6951e+03, 8.2187e+03,
        1.0089e+04, 1.2385e+04, 1.5203e+04, 1.8663e+04, 2.2911e+04, 2.8124e+04,
        3.4525e+04, 4.2381e+04, 5.2026e+04, 6.3866e+04, 7.8400e+04, 9.6241e+04,
        1.1814e+05, 1.4503e+05, 1.7803e+05, 2.1855e+05, 2.6828e+05, 3.2934e+05,
        4.0428e+05, 4.9629e+05, 6.0923e+05, 7.4787e+05, 9.1806e+05, 1.1270e+06,
        1.3835e+06, 1.6983e+06, 2.0848e+06, 2.5592e+06])


RuntimeError: Error(s) in loading state_dict for LlamaForCausalLM:
	Missing key(s) in state_dict: "lm_head.weight". 

In [None]:
def sample(logits, temperature: float = 1.0):
    """
    Samples a token from the logits using temperature scaling.

    Args:
        logits (torch.Tensor): The logits tensor for token predictions.
        temperature (float, optional): Temperature for scaling logits. Defaults to 1.0.

    Returns:
        torch.Tensor: The sampled token.
    """
    logits = logits / max(temperature, 1e-5)
    probs = torch.softmax(logits, dim=-1)
    return probs.div_(torch.empty_like(probs).exponential_(1)).argmax(dim=-1)

@torch.inference_mode()
def generate(
    model: LlamaForCausalLM,
    prompt_tokens: List[List[int]],
    max_new_tokens: int,
    eos_id: int,
    temperature: float = 1.0
) -> List[List[int]]:
    """
    Generates new tokens based on the given prompt tokens using the specified model.

    Args:
        model (Transformer): The transformer model used for token generation.
        prompt_tokens (List[List[int]]): A list of lists containing the prompt tokens for each sequence.
        max_new_tokens (int): The maximum number of new tokens to generate.
        eos_id (int): The end-of-sequence token ID.
        temperature (float, optional): The temperature value for sampling. Defaults to 1.0.

    Returns:
        List[List[int]]: A list of lists containing the generated tokens for each sequence.
    """
    prompt_lens = [len(t) for t in prompt_tokens]
    assert max(prompt_lens) <= model.max_seq_len
    total_len = min(model.max_seq_len, max_new_tokens + max(prompt_lens))
    tokens = torch.full((len(prompt_tokens), total_len), -1, dtype=torch.long, device="cuda")
    for i, t in enumerate(prompt_tokens):
        tokens[i, :len(t)] = torch.tensor(t, dtype=torch.long, device="cuda")
    prev_pos = 0
    finished = torch.tensor([False] * len(prompt_tokens), device="cuda")
    prompt_mask = tokens != -1
    for cur_pos in range(min(prompt_lens), total_len):
        logits = model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
        if temperature > 0:
            next_token = sample(logits, temperature)
        else:
            next_token = logits.argmax(dim=-1)
        next_token = torch.where(prompt_mask[:, cur_pos], tokens[:, cur_pos], next_token)
        tokens[:, cur_pos] = next_token
        finished |= torch.logical_and(~prompt_mask[:, cur_pos], next_token == eos_id)
        prev_pos = cur_pos
        if finished.all():
            break
    completion_tokens = []
    for i, toks in enumerate(tokens.tolist()):
        toks = toks[prompt_lens[i]:prompt_lens[i]+max_new_tokens]
        if eos_id in toks:
            toks = toks[:toks.index(eos_id)]
        completion_tokens.append(toks)
    return completion_tokens