# DeepMind Paper Implementation

- To free disk space when finished, delete the downloaded models from the Hugging Face cache:
- `huggingface-cli delete-cache`


- https://github.com/shreyansh26/Speculative-Sampling

In [73]:
import torch
from tqdm import tqdm
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import numpy as np
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
import time
from typing import Tuple
import pandas as pd
import matplotlib.pyplot as plt
import random
from transformers import set_seed
import os 
import dotenv

In [74]:
# Seed every random number generator used in this notebook with the same
# value so that sampling runs are repeatable.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
set_seed(SEED)

In [75]:
import os
import dotenv

# Read the local .env file, then fetch the Hugging Face API token from the
# environment.
dotenv.load_dotenv('/Users/vashisth/Desktop/research-new/speculative-decoding/.env')
hf_api = os.getenv('hf_api')

# Report only whether the token is present, never its value.
print(hf_api is not None)

True


In [52]:
# def sampling(p):
#     output = torch.multinomial(p, 1)
#     return output.reshape(1,-1)

# ' x is the tokenized prompt, max_new_tokens = N'
# def autoregressive_generation(x, model, max_new_tokens):
#     n = len(x)
#     N = max_new_tokens
#     T = len(x) + N
#     while n<T:
#         logits = model_fn(model,x)
#         output = sampling(logits)
#         x = torch.cat((x, output), dim=-1)
#         n += 1
#     return x



# (f)_+ function in the paper
def max_fn(x):
    """Return the positive part of ``x`` normalised along the last dimension.

    Non-positive entries are replaced by zero, the result is cast to float32
    and each slice along the last dimension is divided by its sum plus
    ``1e-8`` (the small constant avoids division by zero when no entry is
    positive, in which case the slice stays all zeros).
    """
    positive_part = torch.where(x > 0, x, 0).float()
    normaliser = positive_part.sum(dim=-1, keepdim=True) + 1e-8
    return positive_part / normaliser

In [12]:
def get_distribution(logits, temperature, epsilon = 1e-8):
    """Convert logits into a probability distribution at the given temperature.

    Args:
        logits: Tensor of unnormalised scores; softmax is taken over the last
            dimension.
        temperature: Sampling temperature. ``epsilon`` is added before
            dividing, so ``temperature=0`` yields a near one-hot (greedy)
            distribution instead of a division by zero.
        epsilon: Small constant added to ``temperature``.

    Returns:
        A new tensor of probabilities with the same shape as ``logits``.
    """
    # Divide out of place: the caller's tensor must stay untouched, otherwise
    # logits that are reused later (e.g. passed through this function a second
    # time) would have the temperature applied more than once.
    scaled_logits = logits / (temperature + epsilon)
    return F.softmax(scaled_logits, dim=-1)

In [13]:
def sample(logits, temperature):
    """Draw one token id from ``logits`` at the given temperature.

    The logits are turned into probabilities with ``get_distribution`` and a
    single index per row is drawn with ``torch.multinomial``; only the first
    row of the result is returned.
    """
    probabilities = get_distribution(logits, temperature)
    return torch.multinomial(probabilities, num_samples=1)[0]

In [14]:

def sample_from_draft_model(model, initial_prompt_seq, new_tokens, temperature=1.0):
    """Autoregressively sample ``new_tokens`` tokens from ``model``.

    Args:
        model: Callable invoked as ``model(input_ids)``; the result must expose
            a ``.logits`` tensor indexed as ``[:, -1, :]`` for the last
            position.
        initial_prompt_seq: Token-id tensor to continue. It is copied, not
            modified.
        new_tokens: Number of tokens to sample.
        temperature: Sampling temperature forwarded to ``sample``.

    Returns:
        A tuple ``(sequence, logits)``: the prompt with the sampled tokens
        appended along the last dimension, and the raw last-position logits of
        every step stacked along dimension 1.
    """
    fin_prompt_seq = initial_prompt_seq.detach().clone()
    out_logits = []

    # Pure inference: skip autograd bookkeeping for the forward passes.
    with torch.no_grad():
        for _ in range(new_tokens):
            sample_token_logits = model(fin_prompt_seq).logits[:, -1, :]
            # Snapshot the raw logits *before* sampling so the returned values
            # are unaffected by any in-place temperature scaling performed
            # while sampling; callers apply the temperature themselves.
            out_logits.append(sample_token_logits.clone())
            sample_token = sample(sample_token_logits, temperature=temperature)
            fin_prompt_seq = torch.concat([fin_prompt_seq, sample_token.unsqueeze(0)], dim=-1)

    out_logits = torch.stack(out_logits, dim=1)
    return fin_prompt_seq, out_logits
    

In [275]:
def autoregressive_sampling(model, initial_prompt_seq, max_new_tokens, temperature=1.0):
    """Sample ``max_new_tokens`` tokens from ``model`` one at a time.

    Args:
        model: Callable invoked as ``model(input_ids)``; the result must expose
            a ``.logits`` tensor indexed as ``[:, -1, :]`` for the last
            position.
        initial_prompt_seq: Token-id tensor to continue. It is copied, not
            modified.
        max_new_tokens: Number of tokens to append.
        temperature: Sampling temperature forwarded to ``sample``.

    Returns:
        The prompt with ``max_new_tokens`` sampled tokens appended along the
        last dimension.
    """
    fin_prompt_seq = initial_prompt_seq.detach().clone()

    # Pure inference: no gradients are needed, so avoid building autograd
    # graphs for every forward pass.
    with torch.no_grad():
        for _ in range(max_new_tokens):
            sample_token_logits = model(fin_prompt_seq).logits[:, -1, :]
            sample_token = sample(sample_token_logits, temperature=temperature)
            # unsqueeze(0) restores the batch dimension dropped by sample().
            fin_prompt_seq = torch.concat([fin_prompt_seq, sample_token.unsqueeze(0)], dim=-1)

    return fin_prompt_seq

In [278]:
def speculative_sampling(target_model, draft_model, initial_prompt_seq, max_new_tokens, tokenizer, lookahead=3, temperature=1.0, debug=True):
    """Generate tokens with speculative sampling (draft proposes, target verifies).

    Each round the draft model proposes ``lookahead`` tokens and the target
    model scores the whole proposal in one forward pass. A proposed token is
    accepted when ``r < min(1, p_target / p_draft)`` for ``r ~ Uniform[0, 1)``.
    The first rejected token is replaced by a sample from the normalised
    positive part of ``p_target - p_draft`` and the round ends. If every
    proposal is accepted, one extra token is drawn from the target
    distribution at the last position.

    Args:
        target_model: Callable invoked as ``model(input_ids)`` whose result
            exposes ``.logits``; used to verify proposals.
        draft_model: Model passed to ``sample_from_draft_model`` to generate
            proposals.
        initial_prompt_seq: Token-id tensor whose first dimension must be 1.
            It is copied, not modified.
        max_new_tokens: Number of tokens to generate.
        tokenizer: Only used to decode tokens for the debug printout.
        lookahead: Number of draft tokens proposed per round.
        temperature: Temperature applied to both models' logits.
        debug: When true, print the proposals and the accept/reject decisions.

    Returns:
        The prompt followed by exactly ``max_new_tokens`` generated tokens.

    Raises:
        AssertionError: If the batch size of ``initial_prompt_seq`` is not 1.
    """
    assert initial_prompt_seq.shape[0] == 1, 'Batch size should be 1'

    n = initial_prompt_seq.shape[-1]
    target_len = n + max_new_tokens
    fin_prompt_seq = initial_prompt_seq.detach().clone()

    # Pure inference: no gradients are needed for any forward pass below.
    with torch.no_grad():
        while n < target_len:
            if debug:
                print('____________________')
                print('n: ', n)
            n_orig = n
            N = fin_prompt_seq.shape[-1]
            draft_outputs, draft_logits = sample_from_draft_model(draft_model, fin_prompt_seq, new_tokens=lookahead, temperature=temperature)

            if debug:
                print(f"Possible continuations: {tokenizer.decode(draft_outputs[0,n_orig:], skip_special_tokens=True)}")

            # Keep the last lookahead + 1 positions: one distribution per
            # proposed token plus one for the token following the proposal.
            target_logits = target_model(draft_outputs).logits[:, -lookahead-1:, :]

            target_model_distribution = get_distribution(target_logits, temperature)
            # NOTE(review): this assumes draft_logits are raw, un-tempered
            # logits; if the sampling helpers scale them in place, the draft
            # distribution ends up tempered twice -- verify.
            draft_model_distribution = get_distribution(draft_logits, temperature)

            accepted_flag = True

            for t in range(lookahead):
                numerator = target_model_distribution[:, t, draft_outputs[0, N+t]]
                denominator = draft_model_distribution[:, t, draft_outputs[0, N+t]]
                ratio = (numerator / denominator)
                r = torch.rand_like(numerator)  # r ~ Uniform[0, 1)
                ones_tensor = torch.ones_like(numerator)

                # Rejection Sampling
                ## Acceptance
                if (r < torch.min(ones_tensor, ratio)).any():
                    fin_prompt_seq = torch.concat([fin_prompt_seq, draft_outputs[:, N+t].unsqueeze(dim=-1)], dim=-1)
                    n += 1

                    if debug:
                        accepted_token = tokenizer.decode(draft_outputs[0, N+t])
                        print(f"Accepted token: ''{accepted_token}'' ")

                ## Rejection
                else:
                    # Resample from the normalised positive part of
                    # (target - draft), then abandon the rest of the proposal.
                    new_dist = (target_model_distribution[:, t, :] - draft_model_distribution[:, t, :])
                    new_dist = max_fn(new_dist)
                    token_id = torch.multinomial(new_dist, num_samples=1)[0]
                    fin_prompt_seq = torch.concat([fin_prompt_seq, token_id.unsqueeze(0)], dim=-1)

                    accepted_flag = False

                    if debug:
                        rejected_token = tokenizer.decode(draft_outputs[0, N+t])
                        new_token = tokenizer.decode(token_id)
                        print(f"Rejected token: ''{rejected_token}'', Replaced with: {new_token}")
                    break

            # Print the full sentence after the accept/reject pass.
            if debug:
                full_sentence = tokenizer.decode(fin_prompt_seq[0], skip_special_tokens=True)
                print(f"Full sentence: {full_sentence}")

            if accepted_flag:
                # Every proposal was accepted: draw one bonus token from the
                # target distribution at the final position. Sampling from the
                # already computed distribution (rather than re-deriving it
                # from target_logits) guarantees the temperature is applied
                # exactly once.
                sample_token = torch.multinomial(target_model_distribution[:, -1, :], num_samples=1)
                fin_prompt_seq = torch.concat([fin_prompt_seq, sample_token], dim=-1)

            if debug:
                print(f"Accepted continuations: {tokenizer.decode(fin_prompt_seq[0,n_orig:], skip_special_tokens=True)}")

            # Accounts for the resampled token (rejection) or the bonus token
            # (full acceptance) appended this round.
            n += 1

    # A round can append up to lookahead + 1 tokens, so the last round may
    # overshoot; trim to honour max_new_tokens exactly.
    return fin_prompt_seq[:, :target_len]

In [272]:
# Load the small draft model (GPT-2) and its tokenizer.
draft_tokenizer = AutoTokenizer.from_pretrained("gpt2")
# Place the model on the same device the prompt tensors are moved to;
# otherwise a CUDA machine hits a CPU/GPU device mismatch on the first call.
draft_model = AutoModelForCausalLM.from_pretrained("gpt2").to(device)
# Pass the device explicitly so the pipeline does not relocate the model.
draft_generator = pipeline('text-generation', model=draft_model, tokenizer=draft_tokenizer, device=device)

In [273]:
# Load the large target model (GPT-2 large) and its tokenizer.
target_tokenizer = AutoTokenizer.from_pretrained("gpt2-large")
# Place the model on the same device the prompt tensors are moved to;
# otherwise a CUDA machine hits a CPU/GPU device mismatch on the first call.
target_model = AutoModelForCausalLM.from_pretrained("gpt2-large").to(device)
# Pass the device explicitly so the pipeline does not relocate the model.
target_generator = pipeline('text-generation', model=target_model, tokenizer=target_tokenizer, device=device)
# The target tokenizer is the one used for encoding prompts and decoding output.
tokenizer = target_tokenizer

In [256]:
# Prompt text that both sampling methods below will continue.
question = 'the quick brown fox'

In [257]:
# Tokenize the prompt into PyTorch tensors and move them to the selected device.
inputs = tokenizer(question, return_tensors="pt").to(device)

In [279]:
# Run speculative sampling for 11 new tokens with 3 draft tokens per round.
# temperature=0. becomes 0 + 1e-8 inside get_distribution, which makes the
# softmax near one-hot, i.e. effectively greedy decoding.
tokens = speculative_sampling(target_model, draft_model, initial_prompt_seq=inputs.input_ids, max_new_tokens= 11, lookahead=3, tokenizer=tokenizer, temperature=0., debug=True)

____________________
n:  4
Possible continuations: es, and
Rejected token: ''es'', Replaced with:  jumps
Full sentence: the quick brown fox jumps
Accepted continuations:  jumps
____________________
n:  5
Possible continuations:  up and down
Rejected token: '' up'', Replaced with:  over
Full sentence: the quick brown fox jumps over
Accepted continuations:  over
____________________
n:  6
Possible continuations:  the fence and
Accepted token: '' the'' 
Rejected token: '' fence'', Replaced with:  lazy
Full sentence: the quick brown fox jumps over the lazy
Accepted continuations:  the lazy
____________________
n:  8
Possible continuations: , lazy fox
Rejected token: '','', Replaced with:  dog
Full sentence: the quick brown fox jumps over the lazy dog
Accepted continuations:  dog
____________________
n:  9
Possible continuations:  and runs off
Rejected token: '' and'', Replaced with: "
Full sentence: the quick brown fox jumps over the lazy dog"
Accepted continuations: "
____________________

In [280]:
# Count the tokens generated beyond the prompt, then show the decoded text.
prompt_len = inputs.input_ids.shape[-1]
new_tokens = tokens.shape[-1] - prompt_len
print(new_tokens)
print(tokenizer.decode(tokens[0]))

11
the quick brown fox jumps over the lazy dog" is a common saying.


In [276]:
# Baseline: plain autoregressive decoding with the target model and the same
# settings, for comparison with the speculative-sampling output.
tokens = autoregressive_sampling(
    target_model,
    initial_prompt_seq=inputs.input_ids,
    max_new_tokens=11,
    temperature=0.,
)

# Count the tokens generated beyond the prompt, then show the decoded text.
prompt_len = inputs.input_ids.shape[-1]
new_tokens = tokens.shape[-1] - prompt_len
print(new_tokens)
print(tokenizer.decode(tokens[0]))

11
the quick brown fox jumps over the lazy dog" is a common saying.
