## Import Models

In [None]:
import torch
import time
from transformers import AutoModelForCausalLM, AutoTokenizer, QuantoConfig
from sampling.base_decoding import autoregressive_generate
from sampling.speculative_decoding import speculative_generate
from ngram_assisted.ngram_assisted import ngram_assisted_speculative_generate
from ngram_assisted.ngram_storage import NGramStorage
from utils.logits_processor import (
    GreedyProcessor,
    MultinomialProcessor,
    NucleusProcessor,
    TopKNucleusProcessor,
    TopKProcessor,
    MCMCProcessor
)

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(device)

drafter_model = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
# drafter_model = "gpt2"
drafter_quantize = QuantoConfig(
    weights="int8"
)  # QuantoConfig(weights="int8") None

target_model = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
# target_model = "gpt2-medium"
target_quantize = QuantoConfig(
    weights="int8"
)  # QuantoConfig(weights="int8")  None
tokenizer_name = target_model
tokenizer = AutoTokenizer.from_pretrained(
    tokenizer_name, trust_remote_code=True
)

drafter = AutoModelForCausalLM.from_pretrained(
    drafter_model,
    quantization_config=drafter_quantize,
    device_map=device,
    trust_remote_code=True,
)
drafter.eval()

target = AutoModelForCausalLM.from_pretrained(
    target_model,
    quantization_config=target_quantize,
    device_map=device,
    trust_remote_code=True,
)
target.eval()


cpu


LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(32000, 2048)
    (layers): ModuleList(
      (0-21): 22 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): QLinear(in_features=2048, out_features=2048, bias=False)
          (k_proj): QLinear(in_features=2048, out_features=256, bias=False)
          (v_proj): QLinear(in_features=2048, out_features=256, bias=False)
          (o_proj): QLinear(in_features=2048, out_features=2048, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): QLinear(in_features=2048, out_features=5632, bias=False)
          (up_proj): QLinear(in_features=2048, out_features=5632, bias=False)
          (down_proj): QLinear(in_features=5632, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((2048,), eps=1e-05)
    (rotary

In [None]:

ngram = NGramStorage(n=3, vocab_size=target.config.vocab_size)

end_tokens = [
    tokenizer.eos_token_id,
    tokenizer.convert_tokens_to_ids("<|eot_id|>"),
]  # "<|eot_id|>" is the end of turn token for Llama model.




In [20]:
print(end_tokens)

[50256]


In [None]:
prefix = "what is capital of Canada?"
tokenized = tokenizer(prefix, return_tensors="pt").input_ids[0].tolist()

In [27]:
print(tokenized)

[7454, 2402, 257, 640, 220]


### Autoregression Method

In [28]:
start_time = time.time()
output_ids = autoregressive_generate(
    tokenized,
    target,
    # use_cache=self.cache,
    # max_gen_len=self.gen_len,
    eos_tokens_id=end_tokens,
    # logits_processor=self.processor,
    # debug=self.debug,
)
end_time = time.time()
output = tokenizer.decode(output_ids, skip_special_tokens=True)
print(output)
base_throughput = len(output) / (end_time - start_time)
print(f"Throughput: {base_throughput:.1f} tokens/s")


!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
Throughput: 3.1 tokens/s


### Speculative Method

In [29]:
spec_start_time = time.time()
output_ids, accept_rate = speculative_generate(
    tokenized,
    drafter,
    target,
    tokenizer=tokenizer,
    # logits_processor=processor,
    # gamma=gamma,
    # max_gen_len=gen_len,
    # eos_tokens_id=end_tokens,
    # debug=debug,
    # use_cache=cache,
)
spec_end_time = time.time()
spec_output = tokenizer.decode(output_ids, skip_special_tokens=True)
print( spec_output)
print(f"Acceptance rate: {accept_rate:.3f}")
spec_throughput = len(spec_output) / (spec_end_time - spec_start_time)
print(f"Throughput: {spec_throughput:.1f} tokens/s")


!t was a man,!t was a woman,!t was a man,!t was a woman,!t was a man,!t was a woman,!t was!
Acceptance rate: 1.000
Throughput: 3.7 tokens/s


### Ngram specukative Method

In [30]:

ngram_start_time = time.time()
output_ids, accept_rate = ngram_assisted_speculative_generate(
    tokenized,
    ngram,
    target,
    tokenizer=tokenizer,
    # filler_top_k=self.top_k_filler,
    # logits_processor=self.processor,
    # max_gen_len=self.gen_len,
    # eos_tokens_id=self.end_tokens,
    # debug=self.debug,
    # use_cache=self.cache,
    # first_target=True,
    # stop_if_unknown=True,
)
ngram_end_time = time.time()
ngram_output = tokenizer.decode(output_ids, skip_special_tokens=True)
print(ngram_output)
print(f"Acceptance rate: {accept_rate:.3f}")
ngram_throughput = len(ngram_output) / (ngram_end_time - ngram_start_time)
print(f"Throughput: {ngram_throughput:.1f} tokens/s")

!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
Acceptance rate: 0.861
Throughput: 8.6 tokens/s
