In [1]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import torch.nn.functional as F
import random
import time

In [2]:
tokenizer_q = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-135M-Instruct")
model_q = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-135M-Instruct")
model_q.to("cuda")

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(49152, 576, padding_idx=2)
    (layers): ModuleList(
      (0-29): 30 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=576, out_features=576, bias=False)
          (k_proj): Linear(in_features=576, out_features=192, bias=False)
          (v_proj): Linear(in_features=576, out_features=192, bias=False)
          (o_proj): Linear(in_features=576, out_features=576, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=576, out_features=1536, bias=False)
          (up_proj): Linear(in_features=576, out_features=1536, bias=False)
          (down_proj): Linear(in_features=1536, out_features=576, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((576,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((576,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((576,), eps=1e-05)
    (rotary_emb)

In [3]:
tokenizer_p = AutoTokenizer.from_pretrained("HuggingFaceTB/SmolLM2-360M-Instruct")
model_p = AutoModelForCausalLM.from_pretrained("HuggingFaceTB/SmolLM2-360M-Instruct")
model_p.to("cuda")

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(49152, 960, padding_idx=2)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=960, out_features=960, bias=False)
          (k_proj): Linear(in_features=960, out_features=320, bias=False)
          (v_proj): Linear(in_features=960, out_features=320, bias=False)
          (o_proj): Linear(in_features=960, out_features=960, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=960, out_features=2560, bias=False)
          (up_proj): Linear(in_features=960, out_features=2560, bias=False)
          (down_proj): Linear(in_features=2560, out_features=960, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((960,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((960,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((960,), eps=1e-05)
    (rotary_emb)

In [4]:
def predict_next_k_tokens_model_q(tokens, k):
  k_tokens = []
  k_probs = []
  for i in range(k):
    outputs = model_q(input_ids=tokens).logits[:,-1,:]
    outputs = F.softmax(outputs, dim=-1)  #1, vocab_size
    q_index= torch.argmax(outputs, dim=-1)[0].item()
    k_tokens.append(q_index)
    k_probs.append(outputs[0][q_index].item())
    tokens = torch.cat([tokens, torch.tensor([[q_index]], device=tokens.device)], dim=1)
  return k_tokens, k_probs

In [5]:
def parallel_forward_pass_model_p(tokens_plus_next_k, k):
  outputs = model_p(input_ids=tokens_plus_next_k).logits[:,-(k+1):,:] #if we were to predict with 5 tokens and k = 5.
  #if we woul not have done k+1 we would have missed 0th of the generated token because eahc logits predicts the next token.
  #so for k+1 -> it was taking position k
  #and [:, -1,:] predicts a fresh new token given all; in that sense, [:, -2,:] predicts or verifies the last token the small model generated
  outputs = F.softmax(outputs, dim=-1)
  return outputs

In [6]:
def verify(p_probs, k_next_probs, k_next_tokens):
  random.seed(0)
  p_probs = p_probs.squeeze(0)
  accepted_list = []
  for index, (q_prob, q_token_id) in enumerate(zip(k_next_probs, k_next_tokens)): #p(x)>=q(x)
    distribution = p_probs[index]
    p_x = distribution[q_token_id].item()
    if q_prob > p_x:
      rejection_criteria = 1 - (p_x/q_prob)
      r = random.random()
      if r > rejection_criteria:
        accepted_list.append(q_token_id)
      else:
        return accepted_list, distribution
    else:
      accepted_list.append(q_token_id)
  return accepted_list, p_probs[-1]

In [7]:
def main(prompt, k = 5):
  inputs = tokenizer_q(prompt, return_tensors="pt").to("cuda").input_ids
  for _ in range(10):
    k_next_tokens, k_next_probs = predict_next_k_tokens_model_q(inputs, k)
    tokens_plus_next_k = torch.tensor(k_next_tokens, dtype=torch.long, device='cuda').unsqueeze(0)
    tokens_plus_next_k = torch.cat([inputs, tokens_plus_next_k], dim=-1)
    p_probs = parallel_forward_pass_model_p(tokens_plus_next_k, k)
    accepted_list, distribution_p = verify(p_probs, k_next_probs, k_next_tokens)
    next_token_id_p = torch.argmax(distribution_p, dim=-1).item()
    accepted_list.append(next_token_id_p)
    accepted_list = torch.tensor(accepted_list, dtype=torch.long,  device='cuda').unsqueeze(0)
    inputs = torch.cat([inputs, accepted_list], dim=1)
    input_list = inputs[0].tolist()
    decoded_text = tokenizer_q.decode(input_list)
  return decoded_text

In [14]:
start = time.time()
x = main("I look forward to")
time.time() - start

2.75476336479187

In [9]:
num_tokens = len(tokenizer_q(x)["input_ids"])
num_tokens

57

In [10]:
print(x)

I look forward to your response.

Best regards,
Emily<|im_end|>
<|im_start|>assistant
Dear Alex,

I hope this message finds you well. I am reaching out to you regarding a recent project I am working on, which involves analyzing data from a large dataset.


In [15]:
start = time.time()
inputs = tokenizer_p("I look forward to", return_tensors="pt").to("cuda")

output_ids = model_p.generate(
    inputs.input_ids,
    max_new_tokens=num_tokens,
    do_sample=False,
    eos_token_id=None,
    pad_token_id=tokenizer_p.pad_token_id,
)

decoded_text = tokenizer_p.decode(output_ids[0], skip_special_tokens=False)
time.time() - start

2.5592286586761475

In [12]:
print(decoded_text)

I look forward to your thoughts.

Best regards,
Emily<|im_end|>
<|im_start|>assistant
Emily is pleased to share that the proposed project on the impact of climate change on the Arctic region has been accepted for funding. She is eager to discuss the project's potential and the implications of the research.
