# Speculative Decoding using a GPTQ-quantized draft model
- GPTQ: https://arxiv.org/abs/2210.17323
- https://huggingface.co/docs/transformers/v4.34.0/main_classes/quantization

In [4]:
import torch
from tqdm import tqdm
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import numpy as np
import time
from typing import Tuple
import pandas as pd
import matplotlib.pyplot as plt
import random
from transformers import set_seed
import os 
import dotenv
from speculative_sampling_helper import *

In [7]:
from transformers import AutoConfig, AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig, TextStreamer, GPTQConfig

In [6]:
# Tag used as the prefix of every CSV / PNG file name this notebook writes.
model_names = "Mistral-16vs4"

In [None]:
# Hub id of the checkpoint whose tokenizer is loaded here; the same tokenizer
# object is handed to both generation pipelines and to the sampling helpers.
model_id = "mistralai/Mistral-7B-Instruct-v0.1"
tokenizer = AutoTokenizer.from_pretrained(
    model_id,
)

In [None]:
# gptq_config_2 = GPTQConfig(bits=2, tokenizer=tokenizer)
# quantization_4 = GPTQConfig(bits=4, tokenizer=tokenizer)

In [None]:
# Load the 4-bit GPTQ checkpoint used as the draft model and wrap it in a
# text-generation pipeline. Placement is delegated to device_map="auto"
# rather than an explicit .to(device) call.
# NOTE(review): this id names the base Mistral-7B-v0.1 weights, whereas
# `model_id` names the Instruct variant -- confirm the pairing is intended.
gptq_path = "TheBloke/Mistral-7B-v0.1-GPTQ"

# NOTE(review): the checkpoint name suggests it is already quantized; confirm
# whether the calibration `dataset` argument has any effect when loading it.
gptq_config_4 = GPTQConfig(
    bits=4,
    dataset="c4",
    tokenizer=tokenizer,
)
draft_model = AutoModelForCausalLM.from_pretrained(
    gptq_path,
    device_map="auto",
    quantization_config=gptq_config_4,
)
draft_generator = pipeline(
    task="text-generation",
    model=draft_model,
    tokenizer=tokenizer,
)

In [None]:
# Load the checkpoint named by `model_id` (no quantization config is passed)
# as the target model and wrap it in a text-generation pipeline. Placement is
# delegated to device_map="auto" rather than an explicit .to(device) call.
target_model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map="auto",
)
target_generator = pipeline(
    task="text-generation",
    model=target_model,
    tokenizer=tokenizer,
)

In [None]:
# Tokenize a short prompt for the single-shot sampling demos.
# The tensors are moved to the device reported by the target model instead of
# a free-standing `device` variable: this notebook never defines `device`
# itself (it could only leak in through the wildcard helper import), and both
# models are loaded with device_map="auto".
question = 'The quick brown fox'
inputs = tokenizer(question, return_tensors="pt").to(target_model.device)

In [None]:
# One speculative-sampling run on the demo prompt, with the helper's debug
# flag switched on.
tokens = speculative_sampling(
    target_model,
    draft_model,
    initial_prompt_seq=inputs.input_ids,
    max_new_tokens=20,
    lookahead=4,
    tokenizer=tokenizer,
    temperature=0.0,
    debug=True,
)

In [None]:
# Number of tokens appended to the prompt, followed by the decoded sequence
# (each on its own line).
new_tokens = len(tokens[0]) - len(inputs.input_ids[0])
print(new_tokens, tokenizer.decode(tokens[0]), sep="\n")

In [None]:
# Same demo prompt through plain autoregressive sampling on the target model.
# NOTE(review): this run asks for 11 new tokens while the speculative demo
# asks for 20 -- confirm the difference is deliberate.
tokens = autoregressive_sampling(
    target_model,
    initial_prompt_seq=inputs.input_ids,
    max_new_tokens=11,
    temperature=0.0,
)
# Number of tokens appended to the prompt, followed by the decoded sequence.
new_tokens = len(tokens[0]) - len(inputs.input_ids[0])
print(new_tokens, tokenizer.decode(tokens[0]), sep="\n")

---

In [None]:
# Benchmark sweep: for every generation length, run one autoregressive pass
# and one speculative pass per lookahead value, collect what sampling_test
# returns, and persist everything to a CSV file.
prompt = 'United States in the year 2025'
# NOTE(review): `temperature` only appears in output file names and plot
# titles; it is not passed to sampling_test -- confirm the helper's own
# default matches this value.
temperature = 0
max_lengths = [32, 64, 128, 256]  # example max_lengths
lookahead_ks = [1, 2, 3, 4, 5, 7, 8]  # example lookahead_ks

# One row per run, in the column order given to the DataFrame below.
results = []

for max_length in max_lengths:
    # The autoregressive baseline goes first and is recorded with K = 0; it is
    # called without a lookahead_k argument. Speculative runs follow in order.
    runs = [('autoregressive', 0, {})]
    runs += [('speculative', k, {'lookahead_k': k}) for k in lookahead_ks]

    for method, k, extra_kwargs in runs:
        output, elapsed, new_token_count = sampling_test(
            prompt,
            tokenizer,
            method,
            target_model,
            draft_model,
            max_new_tokens=max_length,
            **extra_kwargs,
        )
        results.append([method, max_length, k, elapsed, output, new_token_count])

# Tabulate the collected rows.
column_names = [
    'Sampling Method',
    'Max Length',
    'K Values',
    'Time Taken',
    'Text Generated',
    'New Tokens',
]
df = pd.DataFrame(results, columns=column_names)

# Write the table next to the notebook, without the index column.
csv_file_path = f'{model_names}-sampling-times-temp{temperature}.csv'
df.to_csv(csv_file_path, index=False)

In [None]:
# Grouped bar chart: one group per 'Max Length', containing the autoregressive
# bar followed by one bar per speculative K value. Reads `df`, `temperature`
# and `model_names`; saves the figure as a PNG and then shows it.

# Unique N values (Max Length) and K values present in the results
max_lengths_sorted = sorted(df['Max Length'].unique())
lookahead_ks = sorted(df[df['Sampling Method'] == 'speculative']['K Values'].unique())
ind = np.arange(len(max_lengths_sorted))
bar_width = 0.1

fig, ax = plt.subplots(figsize=(10, 6))

# Bar colours
color_speculative = 'lightgreen'
color_autoregressive = 'lightblue'

# The speculative legend entry must be attached to exactly one bar; labelling
# every K bar of the first group would repeat it once per K value.
speculative_in_legend = False

for i, max_length in enumerate(max_lengths_sorted):
    # Rows for the current N (Max Length)
    subset = df[df['Max Length'] == max_length]

    # Autoregressive bar, drawn just left of the group's tick
    autoregressive_time = subset[subset['Sampling Method'] == 'autoregressive']['Time Taken'].values
    if autoregressive_time.size > 0:
        ax.bar(ind[i] - bar_width/2, autoregressive_time, bar_width,
               label='Autoregressive' if i == 0 else "",
               color=color_autoregressive, edgecolor='black')

    # Speculative bars. Iterating the global K list (rather than this group's
    # own rows) keeps each K in the same slot across groups, and a K that is
    # missing for this N simply leaves its slot empty.
    speculative_subset = subset[subset['Sampling Method'] == 'speculative']
    for j, k_value in enumerate(lookahead_ks):
        # Named `elapsed`, not `time`, so the imported `time` module is not
        # overwritten in the notebook namespace.
        elapsed = speculative_subset[speculative_subset['K Values'] == k_value]['Time Taken'].values
        if elapsed.size == 0:
            continue

        label = "" if speculative_in_legend else 'Speculative K'
        speculative_in_legend = True
        bar = ax.bar(ind[i] + (j+0.5)*bar_width, elapsed, bar_width, label=label,
                     color=color_speculative, edgecolor='black')
        # Annotate the K value just above the bar
        ax.annotate(f'K={k_value}',
                    xy=(bar[0].get_x() + bar[0].get_width() / 2, bar[0].get_height()),
                    xytext=(0, 3),  # 3 points vertical offset
                    textcoords="offset points",
                    ha='center', va='bottom')

# Light dashed horizontal grid lines
ax.yaxis.grid(True, linestyle='--', which='major', color='grey', alpha=0.5)

# Axis labels, title and legend
ax.set_xticks(ind)
ax.set_xticklabels(max_lengths_sorted)
ax.set_xlabel('Max Length (N)')
ax.set_ylabel('Time Taken (seconds)')
ax.set_title(f'Time Taken for Different Sampling Methods and K Values (temp = {temperature})')
ax.legend()

plt.tight_layout()

# Save before showing, then display
plot_file_path = f'{model_names}-bar-plot-sampling-times-temp{temperature}.png'
plt.savefig(plot_file_path)
plt.show()

In [None]:
# One line plot per 'Max Length': time taken (scaled by 1000) against K.
# Every row for that length is plotted, so the autoregressive run appears at
# K = 0. Each figure is saved as a PNG and then shown.
max_lengths_sorted = sorted(df['Max Length'].unique())

for max_length in max_lengths_sorted:
    # Filter once and reuse the rows for both axes.
    rows = df[df['Max Length'] == max_length]
    k_values = rows['K Values']
    time_ms = rows['Time Taken'] * 1000

    fig = plt.figure(figsize=(10, 6))
    plt.plot(k_values, time_ms, label=f"Max Length = {max_length}", linestyle='--', marker='x')
    # Overdraw the data points in red so they stand out from the dashed line.
    plt.plot(k_values, time_ms, 'x', color='red')

    plt.grid(True, axis='y', linestyle='--', which='major', color='grey', alpha=0.5)
    plt.xlabel('Number of draft tokens (K)')
    plt.ylabel('Time Taken (m seconds)')
    plt.title(f'Time Taken for K Values for max-length = {max_length} (temp = {temperature})')
    plt.legend()
    plt.savefig(f'{model_names}-time-taken-max-length-{max_length}.png')
    plt.show()
    # Close this figure explicitly. plt.clf() only clears the current figure
    # (and creates a new empty one when none is open), so figures were never
    # released between iterations.
    plt.close(fig)