In [1]:
from transformers import AutoTokenizer, AutoModelForCausalLM



In [2]:
# Hub identifiers of the two checkpoints (both published under the "meta-llama" organisation).
target_model_name = "meta-llama/Llama-3.2-3B-Instruct"    # 3B parameters: the model we want to accelerate
drafter_model_name = "meta-llama/Llama-3.2-1B-Instruct"   # 1B parameters: the smaller drafter model

# A single tokenizer, loaded from the target checkpoint, is used for the whole notebook.
# NOTE(review): assumes drafter and target share the same vocabulary -- confirm.
tokenizer = AutoTokenizer.from_pretrained(target_model_name)

# Model weights: the target (verifies the drafts) and the drafter (proposes them).
target = AutoModelForCausalLM.from_pretrained(target_model_name)
drafter = AutoModelForCausalLM.from_pretrained(drafter_model_name)

2024-11-14 12:19:43.389204: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-11-14 12:19:43.421964: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-11-14 12:19:43.421999: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-11-14 12:19:43.422809: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-11-14 12:19:43.428254: I tensorflow/core/platform/cpu_feature_guar

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [3]:
prefix = "Translate to English: Je m'appelle Romain. N'hésitez pas à contribuer à mon projet !"

# Render the prompt with the chat template bundled with the loaded tokenizer rather
# than a hand-written one: the previous hard-coded string used the Gemma turn markers
# (<start_of_turn>/<end_of_turn>), which do not match the Llama-3.2 checkpoints
# loaded above.
messages = [{"role": "user", "content": prefix}]
chat_templated = tokenizer.apply_chat_template(
    messages,
    tokenize=False,              # get the rendered string back, tokenize it below
    add_generation_prompt=True,  # end the prompt with the header opening the model's reply
)

# Chat templates normally emit the special tokens they need (BOS included), so the
# tokenizer must not add its own on top of them.
input_ids = tokenizer(chat_templated, return_tensors="pt", add_special_tokens=False).input_ids
input_ids = input_ids[0].tolist() # Generation methods require a list of ids

In [4]:
from base_decoding import autoregressive_generate
from speculative_decoding import speculative_generate
from utils.logits_processors import NucleusProcessor

# Parameters
gen_len = 100       # Maximum number of tokens generated (may be slightly exceeded with speculative decoding)
gamma = 4           # Number of drafts generated by the drafter model at each step
logits_processor = NucleusProcessor(temperature=.6, top_p=.9) # Nucleus sampling with p=0.9 and T=0.6

# Special token ids forwarded to both generation helpers.
# A tokenizer that defines no pad token reports pad_token_id as None; fall back to
# the EOS id so the helpers always receive an integer.
# NOTE(review): presumably the helpers use pad_token_id to pre-fill the output buffer -- confirm.
eos_token_id = tokenizer.eos_token_id
pad_token_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else eos_token_id

# Generate text using the classic auto-regressive decoding (slow)
output_ids_ar = autoregressive_generate( # or autoregressive_generate_encoder_decoder for encoder-decoder models
                input_ids,
                target,
                logits_processor=logits_processor,
                max_gen_len=gen_len,
                eos_tokens_id=eos_token_id,
                pad_token_id=pad_token_id,
            )
output_ar = tokenizer.decode(output_ids_ar, skip_special_tokens=True)

# Generate text using the speculative decoding (faster)
output_ids_sd, alpha = speculative_generate( # or speculative_generate_encoder_decoder for encoder-decoder models
                input_ids,
                drafter,
                target,
                logits_processor=logits_processor,
                gamma=gamma,
                max_gen_len=gen_len,
                eos_tokens_id=eos_token_id,
                pad_token_id=pad_token_id,
            )
output_sd = tokenizer.decode(output_ids_sd, skip_special_tokens=True)

print("Auto-regressive decoding:", output_ar)
print("Speculative decoding:", output_sd)
print("Acceptance rate:", alpha) # Number of drafts accepted by the target model divided by the number of drafts generated

ModuleNotFoundError: No module named 'utils.logits_processor'

Data loaded
torch.Size([90000000]) torch.Size([5000000])


Datasets created
tensor([116, 101, 114,  93,  93,  32, 110,  97, 109, 101, 100,  32,  91,  91,
         65, 114, 105, 117, 115,  93,  93,  32,  98, 101, 103,  97, 110,  32,
        116, 101,  97,  99, 104, 105, 110, 103,  32, 116, 104,  97, 116,  32,
        116, 104, 101, 114, 101,  32, 119,  97, 115,  32,  97,  32, 116, 105,
        109, 101,  32,  98, 101, 102, 111, 114, 101,  32,  91,  91,  71, 111,
        100,  93,  93,  32, 116, 104, 101,  32,  70,  97, 116, 104, 101, 114,
         32,  98, 101, 103,  97, 116,  32,  91,  91,  74, 101, 115, 117, 115,
         93,  93,  32, 119, 104, 101, 110,  32, 116, 104, 101,  32, 108,  97,
        116, 116, 101, 114,  32, 100, 105, 100,  32, 110, 111, 116,  32, 101,
        120, 105, 115, 116,  46,  32,  65, 116, 104,  97, 110,  97, 115, 105,
        117, 115,  32,  97,  99,  99, 111, 109, 112,  97, 110, 105, 101, 100,
         32,  91,  91,  65, 108, 101, 120,  97, 110, 100, 101, 114,  32, 111,
        102,  32,  65, 108, 101, 120,  97, 110,