In [1]:
from transformers import LlamaForCausalLM, LlamaConfig, LlamaTokenizer, GenerationConfig
import torch
from copy import deepcopy
from time import time
from tqdm import tqdm
import random 

  from .autonotebook import tqdm as notebook_tqdm


In [2]:

prompt = "typing import List\ndef bucket_sort(A: List):"

checkpoint = "facebook/layerskip-llama2-7B"
device = "cuda" if torch.cuda.is_available() else "cpu"

max_new_tokens = 512
do_sample = True
top_p = 0.9
temperature = 0.6

warmup = 2
repeat = 10

# Range for dynamic early exit
min_exit = 2  # <-- minimum layer to exit
max_exit = 12  # <-- maximum layer to exit

config = LlamaConfig.from_pretrained(checkpoint)
model = LlamaForCausalLM.from_pretrained(checkpoint, config=config, torch_dtype=torch.float16)

model.to(device)

tokenizer = LlamaTokenizer.from_pretrained(checkpoint, use_fast=False)
inputs = tokenizer(prompt, return_tensors="pt").to(device)

generation_config = {
    "max_new_tokens": max_new_tokens,
    "do_sample": do_sample,
    "top_p": top_p, 
    "temperature": temperature,
    "pad_token_id": tokenizer.eos_token_id,
}

Downloading shards: 100%|██████████| 3/3 [01:06<00:00, 22.22s/it]
Loading checkpoint shards: 100%|██████████| 3/3 [00:00<00:00,  4.69it/s]


In [3]:
def create_dynamic_assistant_model(base_model, exit_layer):
    weights_memo = {id(w): w for w in base_model.parameters()}
    assistant_model = deepcopy(base_model, memo=weights_memo)
    assistant_model.model.layers = assistant_model.model.layers[:exit_layer]
    del assistant_model.model.layers[exit_layer:]
    return assistant_model

In [4]:
# Warmup
print("Warmup")
for i in tqdm(range(warmup)):
    _ = model.generate(**inputs, **generation_config)
    early_exit = random.randint(min_exit, max_exit)  # Random early exit
    assistant_model = create_dynamic_assistant_model(model, early_exit).to(device)
    _ = model.generate(**inputs, **generation_config, assistant_model=assistant_model)

print("Autoregressive Decoding (no early exit)")
total_time = 0
total_tokens = 0
for i in tqdm(range(repeat)):
    start = time()
    outputs = model.generate(**inputs, **generation_config)
    total_time += time() - start
    total_tokens += outputs.numel()
print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
print("\n\t=========================")
print(f"\tAverage Generation Time: {total_time / repeat:.2f} s")
print(f"\tAverage Tokens per Second: {total_tokens / total_time:.2f} tokens per sec\n\n")

print("Self-Speculative Decoding (dynamic early exit)")
total_time = 0
total_tokens = 0
for i in tqdm(range(repeat)):
    early_exit = random.randint(min_exit, max_exit)  # <-- Dynamic early exit for each generation
    assistant_model = create_dynamic_assistant_model(model, early_exit).to(device)
    start = time()
    outputs = model.generate(**inputs, **generation_config, assistant_model=assistant_model)
    total_time += time() - start
    total_tokens += outputs.numel()
print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])
print("\n\t=========================")
print(f"\tAverage Generation Time: {total_time / repeat:.2f} s")
print(f"\tAverage Tokens per Second: {total_tokens / total_time:.2f} tokens per sec")

Warmup


  0%|          | 0/2 [00:00<?, ?it/s]Starting from v4.46, the `logits` model output will have the same type as the model (except at train time, where it will always be FP32)
From v4.47 onwards, when a model cache is to be returned, `generate` will return a `Cache` instance instead by default (as opposed to the legacy tuple of tuples format). If you want to keep returning the legacy format, please set `return_legacy_cache=True`.
100%|██████████| 2/2 [01:01<00:00, 30.77s/it]


Autoregressive Decoding (no early exit)


100%|██████████| 10/10 [01:50<00:00, 11.03s/it]


typing import List
def bucket_sort(A: List):
    """
    Sorts a list using bucket sort.
    """
    # initialize bucket array
    buckets = [0] * len(A)
    # initialize index
    index = 0
    # initialize max
    max = 0
    # loop through array
    for i in range(len(A)):
        # if value is larger than max
        if A[i] > max:
            # set max to value
            max = A[i]
            # increment index
            index += 1
        # set bucket value
        buckets[index] = A[i]
    # sort buckets
    for i in range(len(buckets)):
        # if bucket value is greater than max
        if buckets[i] > max:
            # swap max and bucket
            max = buckets[i]
            buckets[i] = max
    # return sorted array
    return buckets


# test code


# test case 1
A = [4, 3, 1, 2, 5]
print(bucket_sort(A))


# test case 2
A = [1, 2, 3, 4, 5, 6]
print(bucket_sort(A))


	Average Generation Time: 11.03 s
	Average Tokens per Second: 31.12 tokens per sec


Self-Speculat

100%|██████████| 10/10 [01:46<00:00, 10.66s/it]

typing import List
def bucket_sort(A: List):
    """
    Bucket sort is a sorting algorithm that divides the input list into a number of buckets, sorts each bucket, and then merges the sorted buckets.
    The input list is split into a number of buckets, each bucket is sorted, and the sorted buckets are merged together to form the output.
    :param A: List
    :return: List
    """
    # bucket_size = int(len(A)/2)
    bucket_size = 1
    buckets = []
    for i in range(len(A)):
        if i%bucket_size == 0:
            buckets.append([])
        buckets[-1].append(A[i])
    for i in range(len(buckets)-1):
        buckets[i+1] = sorted(buckets[i])
    buckets[-1] = sorted(buckets[-1])
    return buckets


if __name__ == "__main__":
    A = [3, 1, 2, 5, 4, 7, 6, 9, 8]
    B = bucket_sort(A)
    print(B)


	Average Generation Time: 10.62 s
	Average Tokens per Second: 36.27 tokens per sec





: 