In [43]:
import time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load speculative (faster) model and verification (slower) model
speculative_model_name = "gpt2"  # Faster model
verification_model_name = "meta-llama/Llama-3.2-1B-Instruct"  # Accurate model

speculative_tokenizer = AutoTokenizer.from_pretrained(speculative_model_name)
speculative_model = AutoModelForCausalLM.from_pretrained(speculative_model_name)

verification_tokenizer = AutoTokenizer.from_pretrained(verification_model_name)
verification_model = AutoModelForCausalLM.from_pretrained(verification_model_name)

# Move models to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
speculative_model.to(device)
verification_model.to(device)





LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 2048)
    (layers): ModuleList(
      (0-15): 16 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (k_proj): Linear(in_features=2048, out_features=512, bias=False)
          (v_proj): Linear(in_features=2048, out_features=512, bias=False)
          (o_proj): Linear(in_features=2048, out_features=2048, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (up_proj): Linear(in_features=2048, out_features=8192, bias=False)
          (down_proj): Linear(in_features=8192, out_features=2048, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((2048,), eps=1e-05)
      )
    )
    (norm):

In [38]:
def speculative_selection(speculative_probs, verification_probs):
    # Find first index where verification probability is higher than speculative probability
    first_ind = 0
    break_flag = True
    while(break_flag):
        break_flag = False
        for i in range(first_ind, len(speculative_probs)):
            if verification_probs[i] < speculative_probs[i]:
                first_ind = i
                break_flag = True
                break
            
        # Reject with probability 1 - verification_probs[first_ind]/speculative_probs[first_ind]
        x = torch.rand(1)
        if x > verification_probs[first_ind]/speculative_probs[first_ind]:
            return first_ind

    return len(verification_probs) - 1

def return_speculative_tokens(speculative_probs, speculative_tokens, verification_probs, verification_tokens,accepted_token_count):
    first_ind = speculative_selection(speculative_probs, verification_probs)
    accepted_token_count += first_ind
    print(accepted_token_count)
    # print(verification_tokens.shape)
    if(first_ind == len(verification_probs) - 1):
        speculative_tokens = torch.cat([speculative_tokens, verification_tokens[first_ind]])
    
    else:
        speculative_tokens = speculative_tokens[:first_ind + 1]
        speculative_tokens[first_ind] = verification_tokens[first_ind]
    
    return speculative_tokens,accepted_token_count

def speculative_decoding(prompt, max_length=50, speculative_steps=3):
    # Tokenize input
    speculative_inputs = speculative_tokenizer(prompt, return_tensors="pt").to(device)
    verification_inputs = verification_tokenizer(prompt, return_tensors="pt").to(device)
    
    # Initialize generated sequence with prompt tokens
    generated_tokens = speculative_inputs['input_ids']
    # print(generated_tokens)
    # print(speculative_tokenizer.decode(generated_tokens[0]))
    # Initialize counters for acceptance ratio calculation
    speculative_token_count = 0
    accepted_token_count = 0

    # Start the time for speculative decoding
    start_time = time.time()

    for _ in range(max_length):
        # Speculative model generates multiple tokens (speculative_steps)
        spec_tokens = []
        spec_probs = []
        for _ in range(speculative_steps):
            with torch.no_grad():
                speculative_outputs = speculative_model(generated_tokens, return_dict=True)
                speculative_logits = speculative_outputs.logits
                # Use top-k sampling or any heuristic to select speculative tokens
                speculative_probs = torch.softmax(speculative_logits[:, -1, :], dim=-1)
                # next_speculative_tokens = torch.topk(speculative_probs, k=5, dim=-1).indices
            
                # next_speculative_prob = torch.topk(speculative_probs, k=5, dim=-1).values
                # torch.multinomial
                next_speculative_tokens = torch.multinomial(speculative_probs, num_samples=1)
                next_speculative_prob = speculative_probs[0][next_speculative_tokens[0]]
                spec_tokens.append(next_speculative_tokens[0])
                spec_probs.append(next_speculative_prob[0])
                generated_tokens = torch.cat([generated_tokens, next_speculative_tokens], dim=-1)
                # print(generated_tokens)
                # next_speculative_prob = torch.max(speculative_probs, dim=-1).values
                # next_speculative_tokens = torch.argmax(speculative_probs)
            # print(next_speculative_tokens, next_speculative_prob)
                # print(speculative_tokenizer.decode(next_speculative_tokens[0], skip_special_tokens=True))
            # Track speculative token count
        # print(spec_tokens)
        # print(spec_probs)
        spec_tokens = torch.stack(spec_tokens).reshape(1,-1)
        spec_probs = torch.stack(spec_probs).reshape(1,-1)
        speculative_token_count += 1
        # print(spec_tokens)
        # print(verification_inputs['input_ids'])
        new_verification_input = torch.cat([verification_inputs['input_ids'], spec_tokens[0].reshape(1,-1)], dim=-1)
        # print(new_verification_input)
        # print(new_verification_input.shape)
        llm_output = verification_model(new_verification_input, return_dict=True)
        verification_logits = llm_output.logits
        # print(verification_logits.shape)
        verification_probs = torch.softmax(verification_logits[:,-speculative_steps-1:,:], dim=-1)
        # print(verification_probs.shape)
        next_verification_tokens = torch.tensor([], dtype=torch.long)  
        next_verification_prob = torch.tensor([], dtype=torch.float)   
        for i in range(speculative_steps+1):
            # Sample the next token from the multinomial distribution of the large model's verification probabilities
            sampled_token = torch.multinomial(verification_probs[0, i], num_samples=1)

            # Concatenate the newly sampled token with the previous tokens
            next_verification_tokens = torch.cat((next_verification_tokens, sampled_token), dim=0) 
            
            # Get the probability of the sampled token from the verification model's output
            token_prob = verification_probs[0, i, sampled_token]
            
            # Concatenate the probability of the sampled token
            next_verification_prob = torch.cat((next_verification_prob, token_prob), dim=0)

        # Print shapes of the final tensors
        # print(next_verification_tokens.shape)
        # print(next_verification_prob.shape)
        spec_tokens,accepted_token_count = return_speculative_tokens(spec_probs[0], spec_tokens[0], next_verification_prob, next_verification_tokens,accepted_token_count)

    # End the time measurement for speculative decoding
    # end_time = time.time()

    # Calculate total time taken
    # total_time_taken = end_time - start_time

    # Calculate acceptance ratio
    acceptance_ratio = accepted_token_count / speculative_token_count if speculative_token_count > 0 else 0

    # Decode the generated tokens back to text
    final_output = speculative_tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
    
    return final_output, acceptance_ratio


# Run speculative decoding
prompt = "I am Neil Armstrong and I am going to the"
generated_text, acceptance_ratio = speculative_decoding(prompt,speculative_steps=7)

# Output the results
print(f"Generated text: {generated_text}")
# print(f"Time taken for speculative decoding: {time_taken:.4f} seconds")
print(f"Acceptance ratio: {acceptance_ratio:.4f}")


2
3
4
4
5
5
5
7
10
10
10
10
13
13
13
13
13
16
18
21
21
21
21
23
23
24
24
24
24
25
26
27
27
28
29
29
30
31
32
34
35
35
36
36
36
36
38
39
41
44
Generated text: I am Neil Armstrong and I am going to the Moon in October 2015.", Armstrong wrote in an email to the audience.

Barry Matea

Astronomer but has other wishes

To Tim Omars, Jeffrey Bosch wrote a letter to Armstrong

Rick Perry informs him of news[...]

Welcome Ranger Magnum, Eddie. My name is Eddie Rone & myself as Jim Shannon, Dave Humancol, Greg Rosenhaug brown and Mike Flintfield, you are my name is Rick Rampone Tim Rone, out with dad. I'm meetup now & dad will be meeting with you there........✂ Facepalm. Go to Hedda's. Are you In. Rupert Murdoch & the Australian Times?hl Have I said I don't want to see or have I been hit badscar homicide and China theft? Shirt Bangers ..." I will defend you again and I also need you to share with SDG during your absence we will meet at 8:00 p.m. I'm so worried about your scenario.Any trouble co