### Test Toy for Task 1

In [1]:
TEST_WITH_REF = False # NOTE: toggle this flag to `True` to enable testing with running the cells with ref
# TEST_WITH_REF = True

In [None]:
device = "cpu" # NOTE: you had better use "cuda", otherwise it might be very slow
# device = "cuda"

In [3]:
model_dir = "./model/llama_3.2_1b_instruct/"
num_shards = 1

#### Step0. set up the environment

In [4]:
import os
import json

import torch
import torch.nn.functional as F
from torch.testing import assert_close

from transformers import LlamaForCausalLM, AutoTokenizer
from transformers import logging
logging.set_verbosity_error()

In [5]:
if TEST_WITH_REF:
    from ref.modeling.models import (
        LlamaConfig as LlamaConfigRef,
        LlamaModel as LlamaModelRef,
    )

In [6]:
from src.modeling.models import (
    LlamaConfig,
    LlamaModel,
)

#### Step1. test LlamaConfig loading

In [7]:
config_file = os.path.join(model_dir, "config.json")
params_files = os.path.join(model_dir, "model.safetensors")
config_file, params_files

('./model/llama_3.2_1b_instruct/config.json',
 './model/llama_3.2_1b_instruct/model.safetensors')

In [8]:
with open(config_file, "r") as f:
    config = json.load(f)
config

{'architectures': ['LlamaForCausalLM'],
 'attention_bias': False,
 'attention_dropout': 0.0,
 'bos_token_id': 128000,
 'eos_token_id': [128001, 128008, 128009],
 'head_dim': 64,
 'hidden_act': 'silu',
 'hidden_size': 2048,
 'initializer_range': 0.02,
 'intermediate_size': 8192,
 'max_position_embeddings': 131072,
 'mlp_bias': False,
 'model_type': 'llama',
 'num_attention_heads': 32,
 'num_hidden_layers': 16,
 'num_key_value_heads': 8,
 'pretraining_tp': 1,
 'rms_norm_eps': 1e-05,
 'rope_scaling': {'factor': 32.0,
  'high_freq_factor': 4.0,
  'low_freq_factor': 1.0,
  'original_max_position_embeddings': 8192,
  'rope_type': 'llama3'},
 'rope_theta': 500000.0,
 'tie_word_embeddings': True,
 'torch_dtype': 'bfloat16',
 'transformers_version': '4.45.0.dev0',
 'use_cache': True,
 'vocab_size': 128256}

In [9]:
llama_config_ref = None
if TEST_WITH_REF:
    llama_config_ref: LlamaConfigRef = LlamaModelRef.load_config(
        config_file, 
        param_device=device,
    )
llama_config_ref

In [10]:
llama_config: LlamaConfig = LlamaModel.load_config(
    config_file, 
    param_device=device,
)
llama_config

********************   LlamaConfig   ********************
activation_type: MLPActivationType.SILU
apply_qk_norm: False
causal: True
eps: 1e-05
ffh_size: 8192
gate_init_mean: 0.0
gate_init_std: 1.0
group_size: None
head_dim: 64
hidden_size: 2048
init_base_seed: 42
lm_head_tied: True
lora_alpha: None
lora_dropout_rate: 0.0
lora_dropout_seed: 42
lora_init_base_seed: 42
lora_rank: 0
max_seq_len: 8192
moe_topk: 1
norm_init_range: (-1.0, 1.0)
num_experts: None
num_kv_head: 8
num_layers: 16
num_q_head: 32
online_attn_block_size: None
param_device: 'cpu'
param_dtype: torch.bfloat16
process_group: None
proj_init_mean: 0.0
proj_init_seed: 42
proj_init_std: 1.0
qk_norm_group_size: None
qkv_layout: AttnQKVLayout.BSHD
qkv_pack_format: AttnQKVPackFormat.Q_K_V
rank: 0
rope_base: 500000.0
rope_dynamic: False
rope_ratio: 1
softmax_cap: None
softmax_clip_range: (0.0, 1.0)
softmax_dropout_rate: 0.0
softmax_dropout_seed: 42
softmax_scale: None
softmax_temp: 1.0
vocab_init_mean: 0.0
vocab_init_std: 1.0
voc

#### Step2. test LlamaModel loading

In [11]:
llama_tokenizer = AutoTokenizer.from_pretrained(model_dir)
llama_tokenizer.pad_token_id = llama_tokenizer.eos_token_id

llama_hf = LlamaForCausalLM.from_pretrained(model_dir, torch_dtype=llama_config.param_dtype).to(device)
print(llama_hf)

for name, param in llama_hf.named_parameters():
    print(name, param.shape, param.dtype)

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 2048)
    (layers): ModuleList(
      (0-15): 16 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=512, bias=False)
          (v_proj): Linear(in_features=2048, out_features=512, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (up_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (down_proj): Linear(in_features=8192, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
      )
    )
    (norm):

In [12]:
if TEST_WITH_REF:
    llama_model_ref = LlamaModelRef(llama_config_ref)
    llama_model_ref.load_parameters(params_files)
    print(llama_model_ref)

    for name, param in llama_model_ref.named_parameters():
        print(name, param.shape, param.dtype)

In [13]:
llama_model = LlamaModel(llama_config)
llama_model.load_parameters(params_files)
print(llama_model)

for name, param in llama_model.named_parameters():
    print(name, param.shape, param.dtype)

LlamaModel(
  (block): TransformerDecoderBlock(
    (vocab_emb): ParallelVocabEmbedding()
    (layers): ModuleList(
      (0-15): 16 x TransformerDecoderLayer(
        (attn_pre_norm): GroupRMSNorm()
        (rope): NTKAwareRoPE()
        (attn): OfflineSlidingWindowAttn(
          (softmax_dropout): Dropout(p=0.0, inplace=False)
        )
        (mlp_pre_norm): GroupRMSNorm()
        (mlp): DenseMLPWithLoRA()
      )
    )
    (kv_cache): TransformerDecoderKVCache()
    (final_norm): GroupRMSNorm()
    (lm_head): Linear(in_features=2048, out_features=128256, bias=False)
  )
)
block.vocab_emb.weight torch.Size([128256, 2048]) torch.bfloat16
block.layers.0.qkv_proj torch.Size([2048, 3072]) torch.bfloat16
block.layers.0.out_proj torch.Size([2048, 2048]) torch.bfloat16
block.layers.0.attn_pre_norm.weight torch.Size([1, 1, 1, 2048]) torch.bfloat16
block.layers.0.mlp_pre_norm.weight torch.Size([1, 1, 1, 2048]) torch.bfloat16
block.layers.0.mlp.up_proj torch.Size([2048, 8192]) torch.bfloat1

#### Step3. test LlamaModel statistics APIs

In [14]:
punit, munit = "B", "GB"
pmap = {
    "B": 1000**3,
    "M": 1000**2,
    "K": 1000,
    "1": 1,
}
mmap = {
    "GB": 1024**3,
    "MB": 1024**2,
    "KB": 1024,
    "1": 1,
}

In [15]:
print(f"Total parameters: {sum(p.numel() for p in llama_hf.parameters()) / pmap[punit]:.2f} {punit}")
print(f"Memory footprint: {llama_hf.get_memory_footprint() / mmap[munit]:.2f} {munit}")

Total parameters: 1.24 B
Memory footprint: 2.30 GB


In [16]:
if TEST_WITH_REF:
    total_params_b, memory_gb = llama_model_ref.num_parameters(unit=punit), llama_model_ref.num_memory_footprint(unit=munit)
    print(f"Total parameters: {total_params_b:.2f} {punit}")
    print(f"Memory footprint: {memory_gb:.2f} {munit}")

In [17]:
total_params_b, memory_gb = llama_model.num_parameters(unit=punit), llama_model.num_memory_footprint(unit=munit)
print(f"Total parameters: {total_params_b:.2f} {punit}")
print(f"Memory footprint: {memory_gb:.2f} {munit}")

Total parameters: 1.24 B
Memory footprint: 2.30 GB


#### Step4. test LlamaModel forward in evaluation mode

In [18]:
query = "The key to life is"
input_ids = llama_tokenizer(query, return_tensors="pt").input_ids.to(device)
input_ids.shape, input_ids

(torch.Size([1, 6]),
 tensor([[128000,    791,   1401,    311,   2324,    374]]))

In [19]:
llama_hf.eval()
with torch.no_grad():
    outpu_hf = llama_hf.model(input_ids, return_dict=False)[0]
    logits_hf = llama_hf.lm_head(outpu_hf)
logits_hf = logits_hf[:, -1, :]
probs_hf = F.softmax(logits_hf, dim=-1)

probs_hf.shape, probs_hf.dtype, probs_hf

(torch.Size([1, 128256]),
 torch.bfloat16,
 tensor([[1.0133e-06, 6.7428e-07, 1.4901e-08,  ..., 1.3733e-10, 1.3733e-10,
          1.3733e-10]], dtype=torch.bfloat16))

In [20]:
if TEST_WITH_REF:
    llama_model_ref.eval()
    llama_model_ref.reset_kv_cache()

    with torch.no_grad():
        probs_ref = llama_model_ref(input_ids)

    print(probs_ref.shape, probs_ref.dtype, probs_ref)

In [21]:
llama_model.eval()
llama_model.reset_kv_cache()

with torch.no_grad():
    probs = llama_model(input_ids)

probs.shape, probs.dtype, probs

(torch.Size([1, 128256]),
 torch.bfloat16,
 tensor([[9.5367e-07, 6.3330e-07, 1.3795e-08,  ..., 1.3097e-10, 1.3097e-10,
          1.3097e-10]], dtype=torch.bfloat16))

In [22]:
try: assert_close(probs, probs_hf)
except Exception as e: print(e)

Tensor-likes are not close!

Mismatched elements: 126 / 128256 (0.1%)
Greatest absolute difference: 0.02734375 at index (0, 539) (up to 1e-05 allowed)
Greatest relative difference: 0.12255859375 at index (0, 1202) (up to 0.016 allowed)


In [23]:
if TEST_WITH_REF:
    try: assert_close(probs, probs_ref)
    except Exception as e: print(e)

#### Step5. test LlamaModel forward in training mode

In [24]:
query = "The key to life is to be happy"
input_ids = llama_tokenizer(query, return_tensors="pt").input_ids.to(device)
labels = input_ids.clone()
input_ids.shape, input_ids, labels.shape, labels

(torch.Size([1, 9]),
 tensor([[128000,    791,   1401,    311,   2324,    374,    311,    387,   6380]]),
 torch.Size([1, 9]),
 tensor([[128000,    791,   1401,    311,   2324,    374,    311,    387,   6380]]))

In [25]:
llama_hf.train()
loss_hf = llama_hf(input_ids, labels=labels).loss
loss_hf

tensor(3.3770, grad_fn=<NllLossBackward0>)

In [26]:
loss_ref = None
if TEST_WITH_REF:
    llama_model_ref.train()

    with torch.enable_grad():
        loss_ref = llama_model_ref(input_ids, labels=labels)

loss_ref

In [27]:
llama_model.train()

with torch.enable_grad():
    loss = llama_model(input_ids, labels=labels)

loss

tensor(3.3784, grad_fn=<NllLossBackward0>)