In [3]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache
MODEL = "meta-llama/Llama-3.2-1B"


In [42]:
from utils import get_best_device


tokenizer = AutoTokenizer.from_pretrained(MODEL)
model = AutoModelForCausalLM.from_pretrained(MODEL, torch_dtype=torch.float16)
device = get_best_device()
model.to(device)

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 2048)
    (layers): ModuleList(
      (0-15): 16 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=512, bias=False)
          (v_proj): Linear(in_features=2048, out_features=512, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (up_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (down_proj): Linear(in_features=8192, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((2048,), eps=1e-05)
    (rotary_emb):

In [43]:
# print(tokenizer)
print(model)

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 2048)
    (layers): ModuleList(
      (0-15): 16 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=512, bias=False)
          (v_proj): Linear(in_features=2048, out_features=512, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (up_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (down_proj): Linear(in_features=8192, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((2048,), eps=1e-05)
    (rotary_emb):

In [108]:
current = tokenizer("1+2=", return_tensors="pt").to(device)
print(current)

{'input_ids': tensor([[128000,     16,     10,     17,     28]], device='mps:0'), 'attention_mask': tensor([[1, 1, 1, 1, 1]], device='mps:0')}


In [119]:
past_key_values = DynamicCache()
cache_position = torch.arange(current.input_ids.shape[1], dtype=torch.int64, device=model.device)
outputs = model(
            input_ids=current.input_ids,
            past_key_values=past_key_values,
            use_cache=True,
            cache_position=cache_position
        )
probs = torch.softmax(outputs.logits, dim=-1)

top_probs, _ = torch.max(probs, dim=-1)
print("top_probs:", top_probs,)

print(outputs.logits.shape)
next_token_logits = outputs.logits[:, -1, :]
print(outputs)
next_token_id = torch.argmax(next_token_logits, dim=-1, keepdim=True)
print(next_token_id)
pred = tokenizer.decode(next_token_id[0], skip_special_tokens=True)
print(pred)
past_key_values = outputs.past_key_values

top_probs: tensor([[0.3066, 0.3652, 0.4668, 0.2539, 0.4766]], device='mps:0',
       dtype=torch.bfloat16, grad_fn=<MaxBackward0>)
torch.Size([1, 5, 128256])
CausalLMOutputWithPast(loss=None, logits=tensor([[[ 7.0938,  9.0625, 13.3750,  ..., -3.7656, -3.7656, -3.7656],
         [ 8.7500, 11.0625,  9.6875,  ..., -2.4219, -2.4219, -2.4219],
         [10.3750,  7.0000,  8.7500,  ..., -1.2578, -1.2578, -1.2656],
         [12.8750, 12.3125, 11.8750,  ..., -0.6797, -0.6797, -0.6797],
         [10.9375,  7.8750, 10.3750,  ..., -0.5742, -0.5742, -0.5742]]],
       device='mps:0', dtype=torch.bfloat16, grad_fn=<LinearBackward0>), past_key_values=DynamicCache(), hidden_states=None, attentions=None)
tensor([[18]], device='mps:0')
3


In [76]:
next_token = tokenizer.encode("son", return_tensors="pt", add_special_tokens=False).to(device)
print(next_token)
# append the new token to the input
current = torch.cat([current, next_token], dim=-1)
current

tensor([[942]], device='mps:0')


tensor([[128000,     16,     10,     17,     28,    942]], device='mps:0')

In [5]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache

model_id = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(model_id)

past_key_values = DynamicCache()
messages = [{"role": "user", "content": "Hello, what's your name."}]
inputs = tokenizer.apply_chat_template(messages, add_generation_prompt=True, return_tensors="pt", return_dict=True).to(model.device)

generated_ids = inputs.input_ids
cache_position = torch.arange(inputs.input_ids.shape[1], dtype=torch.int64, device=model.device)
max_new_tokens = 10
print(tokenizer.decode(generated_ids[0], skip_special_tokens=True), flush=True, end="")


for _ in range(max_new_tokens):
    outputs = model(**inputs, cache_position=cache_position, past_key_values=past_key_values, use_cache=True)
    # Greedily sample one next token
    next_token_ids = outputs.logits[:, -1:].argmax(-1)
    generated_ids = torch.cat([generated_ids, next_token_ids], dim=-1)

    # Prepare inputs for the next generation step by leaaving unprocessed tokens, in our case we have only one new token
    # and expanding attn mask for the new token, as explained above
    attention_mask = inputs["attention_mask"]
    attention_mask = torch.cat([attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1)
    inputs = {"input_ids": next_token_ids, "attention_mask": attention_mask}
    cache_position = cache_position[-1:] + 1 # add one more position for the next token
    print(tokenizer.decode(next_token_ids[0], skip_special_tokens=True), flush=True, end="")

# print(tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0])

<|user|>
Hello, what's your name. 
<|assistant|>
MynameisSarah.
<|

In [18]:
EXAMPLE = [
    {"role": "user", "content": "Hi and welcome"},
    {"role": "assistant", "content": "waiting"},
    {"role": "user", "content": " to tech support."},
    {"role": "assistant", "content": "waiting"},
    {"role": "user", "content": " For sales,"},
    {"role": "assistant", "content": "waiting"},
    {"role": "user", "content": " please press 1."},
    {"role": "assistant", "content": "Hello world"},
    {"role": "user", "content": " Press 2 for"},
    {"role": "assistant", "content": "Hello world"},
    {"role": "user", "content": " for support."},
    {"role": "assistant", "content": "waiting"},
    {"role": "user", "content": " If you require"},
    {"role": "assistant", "content": "waiting"},
    {"role": "user", "content": " support with your billing."},
    {"role": "assistant", "content": "waiting"},
    {"role": "user", "content": " Please press 3."},
    {"role": "assistant", "content": "Hello world"},
]
formatted_chat = tokenizer.apply_chat_template(EXAMPLE, tokenize=True, return_dict=True, add_generation_prompt=False)
chat_text = tokenizer.decode(formatted_chat.input_ids, skip_special_tokens=True)
print(chat_text)

<|user|>
Hi and welcome 
<|assistant|>
waiting 
<|user|>
 to tech support. 
<|assistant|>
waiting 
<|user|>
 For sales, 
<|assistant|>
waiting 
<|user|>
 please press 1. 
<|assistant|>
Hello world 
<|user|>
 Press 2 for 
<|assistant|>
Hello world 
<|user|>
 for support. 
<|assistant|>
waiting 
<|user|>
 If you require 
<|assistant|>
waiting 
<|user|>
 support with your billing. 
<|assistant|>
waiting 
<|user|>
 Please press 3. 
<|assistant|>
Hello world 

