In [98]:
import torch
#from transformers import GPT2LMHeadModel, GPT2Tokenizer
from transformers import AutoTokenizer, AutoModelForCausalLM
import time
import matplotlib.pyplot as plt
from tqdm import tqdm


# Отключаем градиенты глобально
torch.set_grad_enabled(False)

# Настройки моделей и параметров
# MAIN_MODEL = 'EleutherAI/gpt-neo-1.3B' # gpt2-xl'
# DRAFT_MODEL = 'EleutherAI/gpt-neo-125M' # 'gpt2'
MAIN_MODEL = 'facebook/opt-1.3b'
DRAFT_MODEL = 'facebook/opt-350m'
# MAIN_MODEL = 'tiiuae/falcon-rw-1b'
# DRAFT_MODEL = 'tiiuae/falcon-rw-0.3b'
MAX_NEW_TOKENS = 12
DRAFT_TOKENS_LEN = 3
PROMPT = "Once upon a time"

In [99]:
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'

In [100]:
# Загрузка токенизатора и моделей
tokenizer = AutoTokenizer.from_pretrained(MAIN_MODEL)
main_model = AutoModelForCausalLM.from_pretrained(MAIN_MODEL).half().to(DEVICE).eval()
draft_model = AutoModelForCausalLM.from_pretrained(DRAFT_MODEL).half().to(DEVICE).eval()

tokenizer_config.json:   0%|          | 0.00/685 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/653 [00:00<?, ?B/s]

vocab.json: 0.00B [00:00, ?B/s]

merges.txt: 0.00B [00:00, ?B/s]

special_tokens_map.json:   0%|          | 0.00/441 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/2.63G [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/2.63G [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/137 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/644 [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/663M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/662M [00:00<?, ?B/s]

generation_config.json:   0%|          | 0.00/137 [00:00<?, ?B/s]

In [101]:
# Подготовка входа
input_ids = tokenizer(PROMPT, return_tensors='pt').input_ids.to(DEVICE)
batch_size = input_ids.size(0)

In [132]:
def truncate_past_key_values(past_key_values, n_tokens_to_remove):
    if n_tokens_to_remove <= 0:
        return past_key_values
    truncated = []
    for layer_past in past_key_values:
        key, value = layer_past
        truncated_key = key[:, :, :-n_tokens_to_remove, :]
        truncated_value = value[:, :, :-n_tokens_to_remove, :]
        truncated.append((truncated_key, truncated_value))
    return tuple(truncated)

In [133]:
def autoregressive_decode(input_ids, max_tokens, model=main_model):

    # prefill
    outputs = model(input_ids, use_cache=True)
    past = outputs.past_key_values
    next_token = torch.argmax(outputs.logits[:,-1], dim=-1,keepdim=True)
    # decoding
    tokens = [next_token]
    for i in tqdm(range(max_tokens-1)):
        outputs = model(input_ids=next_token, past_key_values=past, use_cache=True)
        past = outputs.past_key_values
        logits = outputs.logits[:,-1]
        next_token = torch.argmax(logits, dim=-1,keepdim=True)

        tokens.append(next_token)

    generated_tokens_history = torch.cat(tokens, dim=-1)
    return generated_tokens_history

In [149]:
def speculative_decode(input_ids, max_tokens, draft_tokens_len):
    # original_len = input_ids.shape[1]
    generated_tokens_history = torch.zeros(1,0, dtype=torch.int) #input_ids.clone()
    n_accepted_list = []

    # prefill
    draft_outputs = draft_model(input_ids, use_cache=True)
    draft_past = draft_outputs.past_key_values
    main_outputs = main_model(input_ids, use_cache=True)
    main_past = main_outputs.past_key_values
    next_token = torch.argmax(main_outputs.logits[:,-1], dim=-1,keepdim=True)
    # speculative decoding
    with tqdm(total=max_tokens) as pbar:
        while (generated_tokens_history.shape[-1]) < max_tokens:
            # draft generation
            draft_tokens = [next_token]
            for _ in range(draft_tokens_len):
                draft_outputs = draft_model(input_ids=next_token, past_key_values=draft_past, use_cache=True)
                draft_past = draft_outputs.past_key_values
                draft_logits = draft_outputs.logits[:,-1]
                next_token = torch.argmax(draft_logits, dim=-1,keepdim=True)

                draft_tokens.append(next_token)
            # [-draft_tokens_len-1:] for cases with full correct draft with + 1 token at the beginning
            draft_tokens = torch.cat(draft_tokens, dim=-1)[:, -draft_tokens_len-1:]

            # target prefill
            main_outputs = main_model(draft_tokens, past_key_values=main_past, use_cache=True)
            main_logits = main_outputs.logits
            main_tokens = torch.argmax(main_logits, dim=-1)

            # verification
            # equal = (main_tokens[:,:-1] == draft_tokens[:,1:]) * 1
            # if torch.all(equal):
            #     n_accepted = main_tokens.shape[-1]-1
            #     print(n_accepted)
            #     main_past = main_outputs.past_key_values
            # else:
            #     # Индекс первого несовпадения (или длина, если совпадают полностью)
            #     n_accepted = torch.argmax(~equal).item()


            equal = main_tokens[:, :-1] == draft_tokens[:, 1:]  # bool tensor
            # считаем, сколько токенов подряд совпало, до первого False
            if equal.all():
                n_accepted = equal.size(1)
                main_past = main_outputs.past_key_values
                # два токена
                next_token = main_tokens[:, -2:]
            else:
                # для batch_size=1 безопасно:
                eq_list = equal[0].tolist()
                n_accepted = eq_list.index(False)
                n_rejected = draft_tokens_len - n_accepted
                main_past = truncate_past_key_values(main_outputs.past_key_values, n_rejected)
                draft_past = truncate_past_key_values(draft_past, n_rejected - 1)
                next_token = main_tokens[:, n_accepted].unsqueeze(1)

            n_accepted_list.append(n_accepted)
            generated_tokens_history = torch.cat([generated_tokens_history, draft_tokens[:,:n_accepted+1]], dim=-1)
            pbar.update(n_accepted+1)
    return generated_tokens_history, n_accepted_list

In [157]:
# autoregressive_decode
start = time.time()
normal_ids = autoregressive_decode(input_ids, max_tokens=MAX_NEW_TOKENS)
time_normal = time.time() - start
print()
print(time_normal, normal_ids.shape, normal_ids, tokenizer.batch_decode(normal_ids), sep="\n")

100%|██████████| 11/11 [00:05<00:00,  1.98it/s]


7.0043580532073975
torch.Size([1, 12])
tensor([[   6,   89,   21,   10,  313,   54,   21,   10,  182,  205, 1441,    9]])
[', there was a man who was a very good friend of']





In [153]:
# autoregressive_decode
start = time.time()
draft_ids = autoregressive_decode(input_ids, max_tokens=MAX_NEW_TOKENS, model=draft_model)
time_draft = time.time() - start
print()
print(time_draft, draft_ids.shape, draft_ids, tokenizer.batch_decode(draft_ids), sep="\n")

100%|██████████| 11/11 [00:01<00:00,  5.72it/s]


2.4934303760528564
torch.Size([1, 12])
tensor([[    6,    89,    21,    10,   313,    54,    21,    10,   205,   313,
             4, 50118]])
[', there was a man who was a good man.\n']





In [156]:
# speculative_decode
start = time.time()
spec_ids, n_accepted_list = speculative_decode(input_ids, max_tokens=MAX_NEW_TOKENS, draft_tokens_len=1)
time_spec = time.time() - start
print()
print(time_spec, spec_ids.shape, spec_ids, tokenizer.batch_decode(spec_ids), n_accepted_list, sep="\n")

100%|██████████| 12/12 [00:05<00:00,  2.08it/s]


7.687418699264526
torch.Size([1, 12])
tensor([[   6,   89,   21,   10,  313,   54,   21,   10,  182,  205, 1441,    9]])
[', there was a man who was a very good friend of']
[1, 1, 1, 1, 1, 1]





In [160]:
start = time.time()
hf_ids = main_model.generate(input_ids, assistant_model=draft_model, max_new_tokens=MAX_NEW_TOKENS)
time_hf = time.time() - start
print()
print(time_hf, hf_ids.shape, hf_ids, tokenizer.batch_decode(hf_ids), sep="\n")


7.005308151245117
torch.Size([1, 17])
tensor([[    2, 11475,  2115,    10,    86,     6,    89,    21,    10,   313,
            54,    21,    10,   182,   205,  1441,     9]])
['</s>Once upon a time, there was a man who was a very good friend of']


In [161]:
start = time.time()
hf_ids = main_model.generate(input_ids, max_new_tokens=MAX_NEW_TOKENS)
time_hf = time.time() - start
print()
print(time_hf, hf_ids.shape, hf_ids, tokenizer.batch_decode(hf_ids), sep="\n")


6.436024188995361
torch.Size([1, 17])
tensor([[    2, 11475,  2115,    10,    86,     6,    89,    21,    10,   313,
            54,    21,    10,   182,   205,  1441,     9]])
['</s>Once upon a time, there was a man who was a very good friend of']


In [None]:
# Декодирование
text_normal = tokenizer.decode(normal_output[0], skip_special_tokens=True)
text_speculative = tokenizer.decode(generated_ids[0], skip_special_tokens=True)

# Метрики
speedup = time_normal / time_speculative
accept_rate = accepted / MAX_NEW_TOKENS * 100

# Вывод результатов
print(f"Обычное время: {time_normal:.2f}s")
print(f"Спекулятивное время: {time_speculative:.2f}s")
print(f"Ускорение: {speedup:.2f}x")
print(f"Acceptance Rate: {accept_rate:.2f}%\n")
print("--- Вывод обычной генерации ---")
print(text_normal, "\n")
print("--- Вывод со спекулятивным декодированием ---")
print(text_speculative)

# Визуализация
plt.figure()
plt.bar(['Normal', 'Speculative'], [time_normal, time_speculative])
plt.ylabel('Time (s)')
plt.title('Сравнение времени генерации')
plt.show()

plt.figure()
plt.bar(['Accepted', 'Rejected'], [accepted, MAX_NEW_TOKENS - accepted])
plt.ylabel('Token Count')
plt.title('Accepted vs Rejected Tokens')
plt.show()
