### Test Toy for Task 2

In [1]:
# Set to True to additionally run every cell against the reference implementation in the `ref` package.
TEST_WITH_REF = False
# TEST_WITH_REF = True

In [2]:
# Prefer "cuda" when a GPU is available; CPU inference works but can be very slow.
device = "cpu"
# device = "cuda"

In [3]:
# Checkpoint directory (config, weights, tokenizer, generation config) and shard count.
# NOTE(review): `num_shards` is not referenced by the cells shown here — confirm it is still needed.
model_dir, num_shards = "./model/llama_3.2_1b_instruct/", 1

#### Step0. set up the environment

In [4]:
# Standard library: path joining and JSON config parsing.
import os
import json

# PyTorch, its functional API, and the tensor-comparison helper.
import torch
import torch.nn.functional as F
from torch.testing import assert_close

# Hugging Face baseline: model, tokenizer, and high-level generation pipeline.
from transformers import LlamaForCausalLM, AutoTokenizer
from transformers import pipeline
# NOTE: this is `transformers.logging`, which shadows the stdlib `logging` name in this notebook.
from transformers import logging
# Only surface transformers errors so cell outputs stay readable.
logging.set_verbosity_error()

In [5]:
if TEST_WITH_REF:
    from ref.modeling import (
        BatchLayout as BatchLayoutRef,
        PaddingSide as PaddingSideRef,
        TruncateSide as TruncateSideRef,
        
        PromptType as PromptTypeRef,
        PromptTemplate as PromptTemplateRef,
    )
    from ref.modeling.models import (
        LlamaConfig as LlamaConfigRef,
        LlamaTokenizer as LlamaTokenizerRef,
        LlamaModel as LlamaModelRef,
    )
    from ref.inference import (
        DecodeStrategy as DecodeStrategyRef,
        InferenceConfig as InferenceConfigRef,
        InferenceAgent as InferenceAgentRef,
    )

In [6]:
from src.modeling import (
    BatchLayout,
    PaddingSide,
    TruncateSide,
    
    PromptType,
    PromptTemplate,
)
from src.modeling.models import (
    LlamaConfig,
    LlamaTokenizer,
    LlamaModel,
)
from src.inference import (
    DecodeStrategy,
    InferenceConfig,
    InferenceAgent,
)

#### Step1. load the pretrained model

In [7]:
# HF-format model config and the single-file safetensors checkpoint inside `model_dir`.
config_file, params_files = (
    os.path.join(model_dir, fname) for fname in ("config.json", "model.safetensors")
)

# Raw HF config as a plain dict.
with open(config_file, "r") as f:
    config = json.loads(f.read())

In [8]:
# Reference-implementation config parsed from the same config.json; None unless TEST_WITH_REF.
llama_config_ref: "LlamaConfigRef | None" = (
    LlamaModelRef.load_config(config_file, param_device=device)
    if TEST_WITH_REF
    else None
)
llama_config_ref

In [9]:
# Parse config.json into our own LlamaConfig (printed below), with `param_device`
# set to the device chosen at the top of the notebook.
llama_config: LlamaConfig = LlamaModel.load_config(config_file, param_device=device)
llama_config

********************   LlamaConfig   ********************
activation_type: MLPActivationType.SILU
apply_qk_norm: False
causal: True
eps: 1e-05
ffh_size: 8192
gate_init_mean: 0.0
gate_init_std: 1.0
group_size: None
head_dim: 64
hidden_size: 2048
init_base_seed: 42
lm_head_tied: True
lora_alpha: None
lora_dropout_rate: 0.0
lora_dropout_seed: 42
lora_init_base_seed: 42
lora_rank: 0
max_seq_len: 8192
moe_topk: 1
norm_init_range: (-1.0, 1.0)
num_experts: None
num_kv_head: 8
num_layers: 16
num_q_head: 32
online_attn_block_size: None
param_device: 'cpu'
param_dtype: torch.bfloat16
process_group: None
proj_init_mean: 0.0
proj_init_seed: 42
proj_init_std: 1.0
qk_norm_group_size: None
qkv_layout: AttnQKVLayout.BSHD
qkv_pack_format: AttnQKVPackFormat.Q_K_V
rank: 0
rope_base: 500000.0
rope_dynamic: False
rope_ratio: 1
softmax_cap: None
softmax_clip_range: (0.0, 1.0)
softmax_dropout_rate: 0.0
softmax_dropout_seed: 42
softmax_scale: None
softmax_temp: 1.0
vocab_init_mean: 0.0
vocab_init_std: 1.0
voc

In [10]:
# HF baseline model, loaded in the same dtype as our config (bfloat16 per the
# config printout) and moved to `device`.
llama_hf = LlamaForCausalLM.from_pretrained(
    model_dir,
    torch_dtype=llama_config.param_dtype,
).to(device)
llama_hf

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 2048)
    (layers): ModuleList(
      (0-15): 16 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=512, bias=False)
          (v_proj): Linear(in_features=2048, out_features=512, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (up_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (down_proj): Linear(in_features=8192, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
      )
    )
    (norm):

In [11]:
# Reference model with weights from the same safetensors file; None unless TEST_WITH_REF.
if TEST_WITH_REF:
    llama_model_ref = LlamaModelRef(llama_config_ref)
    llama_model_ref.load_parameters(params_files)
else:
    llama_model_ref = None
llama_model_ref

In [12]:
# Our model: build it from the parsed config, then load the pretrained weights.
llama_model = LlamaModel(llama_config)
llama_model.load_parameters(params_files)  # weights from model.safetensors
llama_model

LlamaModel(
  (block): TransformerDecoderBlock(
    (vocab_emb): ParallelVocabEmbedding()
    (layers): ModuleList(
      (0-15): 16 x TransformerDecoderLayer(
        (attn_pre_norm): GroupRMSNorm()
        (rope): NTKAwareRoPE()
        (attn): OfflineSlidingWindowAttn(
          (softmax_dropout): Dropout(p=0.0, inplace=False)
        )
        (mlp_pre_norm): GroupRMSNorm()
        (mlp): DenseMLPWithLoRA()
      )
    )
    (kv_cache): TransformerDecoderKVCache()
    (final_norm): GroupRMSNorm()
    (lm_head): Linear(in_features=2048, out_features=128256, bias=False)
  )
)

#### Step2. load the pretrained tokenizer

In [13]:
# Tokenizer vocabulary and tokenizer config shipped with the checkpoint.
tokenizer_file, tokenizer_config_file = (
    os.path.join(model_dir, fname) for fname in ("tokenizer.json", "tokenizer_config.json")
)

In [14]:
# One shared prefix with two continuations of different lengths, used to
# exercise batched encode/decode.
query = "The key to life is"
response1 = " not to be afraid to take risks and try new"
response2 = " not to be happy, but to be content"
prompt1, prompt2 = (f"{query}{resp}" for resp in (response1, response2))
prompt1, prompt2

('The key to life is not to be afraid to take risks and try new',
 'The key to life is not to be happy, but to be content')

In [15]:
# HF baseline tokenizer. EOS is reused as the pad token (the checkpoint
# presumably ships without one — confirm); HF pads on the right by default.
llama_hf_tokenizer = AutoTokenizer.from_pretrained(model_dir)
llama_hf_tokenizer.pad_token_id = llama_hf_tokenizer.eos_token_id
print(*(
    getattr(llama_hf_tokenizer, attr)
    for attr in ("pad_token", "pad_token_id", "eos_token", "eos_token_id", "bos_token", "bos_token_id")
))

# Round-trip both prompts: text -> token ids -> text (special tokens dropped).
encoded_ids_hf = llama_hf_tokenizer([prompt1, prompt2]).input_ids
print(encoded_ids_hf)

decoded_prompts = llama_hf_tokenizer.batch_decode(encoded_ids_hf, skip_special_tokens=True)
decoded_prompts

<|eot_id|> 128009 <|eot_id|> 128009 <|begin_of_text|> 128000
[[128000, 791, 1401, 311, 2324, 374, 539, 311, 387, 16984, 311, 1935, 15635, 323, 1456, 502], [128000, 791, 1401, 311, 2324, 374, 539, 311, 387, 6380, 11, 719, 311, 387, 2262]]


['The key to life is not to be afraid to take risks and try new',
 'The key to life is not to be happy, but to be content']

In [16]:
if TEST_WITH_REF:
    # Reference tokenizer built from the same vocab/config files.
    llama_tokenizer_ref = LlamaTokenizerRef(vocab_file=tokenizer_file, config_file=tokenizer_config_file)
    print(llama_tokenizer_ref)

    # Batched round-trip, for comparison with the HF and src tokenizers.
    encoded_ids_ref = llama_tokenizer_ref.encode([prompt1, prompt2])
    print(encoded_ids_ref)
    decoded_prompts_ref = llama_tokenizer_ref.decode(encoded_ids_ref)
    print(decoded_prompts_ref)

In [17]:
# Our tokenizer, loaded from the same vocab/config files as the HF one.
llama_tokenizer = LlamaTokenizer(vocab_file=tokenizer_file, config_file=tokenizer_config_file)
print(llama_tokenizer)

# Per the recorded output, `encode` yields one id tensor per prompt (starting
# with BOS 128000), matching the HF ids.
encoded_ids = llama_tokenizer.encode([prompt1, prompt2])
print(encoded_ids)

decoded_prompts = llama_tokenizer.decode(encoded_ids)
decoded_prompts

LlamaTokenizer()
[tensor([128000,    791,   1401,    311,   2324,    374,    539,    311,    387,
         16984,    311,   1935,  15635,    323,   1456,    502]), tensor([128000,    791,   1401,    311,   2324,    374,    539,    311,    387,
          6380,     11,    719,    311,    387,   2262])]


['The key to life is not to be afraid to take risks and try new',
 'The key to life is not to be happy, but to be content']

#### Step3. test PromptTemplate

In [18]:
if TEST_WITH_REF:
    # Same template and defaults as the src cell, using the reference PromptTemplate.
    prompt_template_ref = PromptTemplateRef(
        template_str="I am a {profession} named {name}, with the age of {age}.",
    )
    prompt_template_ref.set_default(name="John", age="24")
    print(prompt_template_ref.keys())

    prompt_str_ref = prompt_template_ref(profession="programmer")
    print(prompt_str_ref)

In [19]:
# Template with three placeholders. `name` and `age` get defaults, so only
# `profession` is supplied at call time.
prompt_template = PromptTemplate(template_str="I am a {profession} named {name}, with the age of {age}.")
prompt_template.set_default(name="John", age="24")
print(prompt_template.keys())  # per the output, keys without a default map to None

prompt_str = prompt_template(profession="programmer")
prompt_str

{'profession': None, 'name': 'John', 'age': '24'}


'I am a programmer named John, with the age of 24.'

#### Step4. load the InferenceConfig

In [20]:
# Checkpoint generation config, plus the token budget and RNG seed used by the
# inference configs in this notebook.
generation_config_file = os.path.join(model_dir, "generation_config.json")
max_new_tokens, sampling_seed = 20, 42

In [21]:
if TEST_WITH_REF:
    # Reference InferenceConfig from generation_config.json with explicit overrides.
    inf_config_loadded_ref = InferenceAgentRef.load_generation_config(
        generation_config_file,
        device=device,
        max_new_tokens=max_new_tokens,
        sampling_seed=sampling_seed,
    )
    print(inf_config_loadded_ref)

In [22]:
# Build our InferenceConfig from generation_config.json, overriding the token
# budget, seed, and device. Per the printed output, the checkpoint defaults
# select SAMPLING with temperature 0.6 and top_p 0.9.
# NOTE: the `loadded` spelling is kept because other cells refer to this name.
inf_config_loadded = InferenceAgent.load_generation_config(
    generation_config_file,
    device=device,
    max_new_tokens=max_new_tokens,
    sampling_seed=sampling_seed,
)
inf_config_loadded

********************   InferenceConfig   ********************
batch_layout: BatchLayout.STACK
decode_strategy: DecodeStrategy.SAMPLING
device: 'cpu'
max_new_tokens: 20
pad_to_multiple_of: 1
padding_side: PaddingSide.LEFT
sampling_seed: 42
streaming: False
temperature: 0.6
top_k: 50
top_p: 0.9
truncate_length: None
truncate_side: TruncateSide.RIGHT
*************************************************************

In [23]:
if TEST_WITH_REF:
    # Reference counterpart of `inf_config`: greedy decoding, left padding.
    inf_config_ref = InferenceConfigRef(
        # decoding
        decode_strategy=DecodeStrategyRef.GREEDY,
        max_new_tokens=max_new_tokens,
        sampling_seed=sampling_seed,
        temperature=1.0,
        top_k=50,
        top_p=0.9,
        # batching
        padding_side=PaddingSideRef.LEFT,
        pad_to_multiple_of=1,
        truncate_length=None,
        truncate_side=TruncateSideRef.RIGHT,
        # placement
        device=device,
    )
    print(inf_config_ref)

In [24]:
# Greedy-decoding config for our agent. The sampling fields are set explicitly
# but presumably unused under GREEDY; batches are left-padded.
inf_config = InferenceConfig(
    # decoding
    decode_strategy=DecodeStrategy.GREEDY,
    max_new_tokens=max_new_tokens,
    sampling_seed=sampling_seed,
    temperature=1.0,
    top_k=50,
    top_p=0.9,
    # batching (truncate_length=None: presumably no truncation)
    padding_side=PaddingSide.LEFT,
    pad_to_multiple_of=1,
    truncate_length=None,
    truncate_side=TruncateSide.RIGHT,
    # placement
    device=device,
)
inf_config

********************   InferenceConfig   ********************
batch_layout: BatchLayout.STACK
decode_strategy: DecodeStrategy.GREEDY
device: 'cpu'
max_new_tokens: 20
pad_to_multiple_of: 1
padding_side: PaddingSide.LEFT
sampling_seed: 42
streaming: False
temperature: 1.0
top_k: 50
top_p: 0.9
truncate_length: None
truncate_side: TruncateSide.RIGHT
*************************************************************

#### Step5. load the InferenceAgent

In [25]:
# HF baseline text-generation pipeline mirroring `inf_config`. The sampling
# knobs (temperature / top_p / top_k) only take effect when `do_sample` is True.
pipe_hf = pipeline(
    "text-generation",
    model=llama_hf,
    tokenizer=llama_hf_tokenizer,
    device=device,
    # budget and strategy
    max_new_tokens=inf_config.max_new_tokens,
    do_sample=(inf_config.decode_strategy == DecodeStrategy.SAMPLING),
    # sampling hyper-parameters
    top_k=inf_config.top_k,
    top_p=inf_config.top_p,
    temperature=inf_config.temperature,
)
pipe_hf

<transformers.pipelines.text_generation.TextGenerationPipeline at 0x7f31f8cf54f0>

In [26]:
if TEST_WITH_REF:
    # Reference agent wrapping the reference model and tokenizer.
    inf_agent_ref = InferenceAgentRef(
        config=inf_config_ref,  # alternative: inf_config_loadded_ref (checkpoint defaults)
        model=llama_model_ref,
        tokenizer=llama_tokenizer_ref,
    )
    print(inf_agent_ref)

In [27]:
# Our agent wrapping the src model and tokenizer.
inf_agent = InferenceAgent(
    config=inf_config,  # alternative: inf_config_loadded (checkpoint defaults)
    model=llama_model,
    tokenizer=llama_tokenizer,
)
inf_agent

InferenceAgent(
  (model): LlamaModel(
    (block): TransformerDecoderBlock(
      (vocab_emb): ParallelVocabEmbedding()
      (layers): ModuleList(
        (0-15): 16 x TransformerDecoderLayer(
          (attn_pre_norm): GroupRMSNorm()
          (rope): NTKAwareRoPE()
          (attn): OfflineSlidingWindowAttn(
            (softmax_dropout): Dropout(p=0.0, inplace=False)
          )
          (mlp_pre_norm): GroupRMSNorm()
          (mlp): DenseMLPWithLoRA()
        )
      )
      (kv_cache): TransformerDecoderKVCache()
      (final_norm): GroupRMSNorm()
      (lm_head): Linear(in_features=2048, out_features=128256, bias=False)
    )
  )
  (tokenizer): LlamaTokenizer()
  (_system_prompt_template): PromptTemplate()
  (_context_prompt_template): PromptTemplate()
)

#### Step6. test InferenceAgent on text generation

In [28]:
# System and context templates with one placeholder each.
# NOTE: "assitant" is misspelled inside the template string; it is kept
# byte-for-byte so generations stay comparable with the recorded outputs.
system_prompt_template = PromptTemplate(template_str="You're a helpful assitant on {subject}.\n")
context_prompt_template = PromptTemplate(template_str="Fill the sentence below for you to make it {adjective}.\n")

subject, adjective = "life", "reasonable"

system_prompt_str = system_prompt_template(subject=subject)
context_prompt_str = context_prompt_template(adjective=adjective)
system_prompt_str, context_prompt_str

("You're a helpful assitant on life.\n",
 'Fill the sentence below for you to make it reasonable.\n')

In [29]:
querys = [
    "The key to life is",
    "The only thing we have to fear is",
    "The cat jumped on the keyboard and accidentally",
]

# Full prompts for the HF pipeline: system + context + query concatenated,
# the same layout the agent reports as its PROMPT in the recorded output.
prompts = [f"{system_prompt_str}{context_prompt_str}{q}" for q in querys]
prompts

["You're a helpful assitant on life.\nFill the sentence below for you to make it reasonable.\nThe key to life is",
 "You're a helpful assitant on life.\nFill the sentence below for you to make it reasonable.\nThe only thing we have to fear is",
 "You're a helpful assitant on life.\nFill the sentence below for you to make it reasonable.\nThe cat jumped on the keyboard and accidentally"]

In [30]:
# Run the HF baseline. For a list input, each entry is a list of candidate
# dicts; the first one holds the prompt plus its continuation.
prompt_dicts = pipe_hf(prompts)
for i, prompt_dict in enumerate(prompt_dicts):
    print("\n" + "=" * 25 + f" The {i}-th sample in the batch " + "=" * 25)
    generated_text = prompt_dict[0]["generated_text"]
    print(f"[generated_text]: {generated_text}")




[generated_text]: You're a helpful assitant on life.
Fill the sentence below for you to make it reasonable.
The key to life is to be happy and content with what you have, and to never let anyone else's opinions dictate your

[generated_text]: You're a helpful assitant on life.
Fill the sentence below for you to make it reasonable.
The only thing we have to fear is fear itself. (Franklin D. Roosevelt)
This is a famous quote from the 32nd

[generated_text]: You're a helpful assitant on life.
Fill the sentence below for you to make it reasonable.
The cat jumped on the keyboard and accidentally typed out a message that said "I'm purr-fectly happy to be stuck in this


In [31]:
if TEST_WITH_REF:
    # Register the system/context templates on the reference agent.
    inf_agent_ref.set_prompt(prompt_template=system_prompt_template, prompt_type=PromptTypeRef.SYSTEM)
    inf_agent_ref.set_prompt(prompt_template=context_prompt_template, prompt_type=PromptTypeRef.CONTEXT)

    # Generate for every query; template placeholders are filled from the kwargs.
    prompt_dicts = inf_agent_ref(querys, subject=subject, adjective=adjective)
    for i, prompt_dict in enumerate(prompt_dicts):
        print("\n" + "=" * 25 + f" The {i}-th sample in the batch " + "=" * 25)
        for prompt_type, prompt_text in prompt_dict.items():
            print(f"\n[{prompt_type}]: {prompt_text}")

In [32]:
# Register the system/context templates on our agent.
inf_agent.set_prompt(prompt_template=system_prompt_template, prompt_type=PromptType.SYSTEM)
inf_agent.set_prompt(prompt_template=context_prompt_template, prompt_type=PromptType.CONTEXT)

# Generate for every query; template placeholders are filled from the kwargs.
# NOTE(review): in the recorded outputs, the first query's greedy RESPONSE differs
# from the HF pipeline's, while the second query matches. This is possibly a
# left-padding / position-handling difference in batched decoding — TODO confirm.
prompt_dicts = inf_agent(querys, subject=subject, adjective=adjective)
for i, prompt_dict in enumerate(prompt_dicts):
    print("\n" + "=" * 25 + f" The {i}-th sample in the batch " + "=" * 25)
    for prompt_type, prompt_text in prompt_dict.items():
        print(f"\n[{prompt_type}]: {prompt_text}")



[PromptType.SYSTEM]: You're a helpful assitant on life.


[PromptType.CONTEXT]: Fill the sentence below for you to make it reasonable.


[PromptType.QUERY]: The key to life is

[PromptType.RESPONSE]:  to be a good friend to others. 
The best way to do this is to be a good

[PromptType.PROMPT]: You're a helpful assitant on life.
Fill the sentence below for you to make it reasonable.
The key to life is

[PromptType.ALL]: You're a helpful assitant on life.
Fill the sentence below for you to make it reasonable.
The key to life is to be a good friend to others. 
The best way to do this is to be a good


[PromptType.SYSTEM]: You're a helpful assitant on life.


[PromptType.CONTEXT]: Fill the sentence below for you to make it reasonable.


[PromptType.QUERY]: The only thing we have to fear is

[PromptType.RESPONSE]:  fear itself. (Franklin D. Roosevelt)
This is a famous quote from the 32nd

[PromptType.PROMPT]: You're a helpful assitant on life.
Fill the sentence below for you to make it re