In [1]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import time 

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# Load the tokenizer and model from ONE checkpoint name, so the two can never
# silently point at different vocabularies.
MODEL_NAME = "gpt2"

model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)

# Switch to inference mode (disables the Dropout layers). eval() returns the
# module itself, so leaving it as the last expression also displays the
# architecture as the cell output.
model.eval()

GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-11): 12 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D(nf=2304, nx=768)
          (c_proj): Conv1D(nf=768, nx=768)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=3072, nx=768)
          (c_proj): Conv1D(nf=768, nx=3072)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(in_features=768, out_features=50257, bias=False)
)

In [3]:
# Prompt shared by both generation experiments below.
prompt = "The future of AI is"

# Tokenize once into a PyTorch tensor of token ids ("pt" = PyTorch tensors).
input_ids = tokenizer.encode(
    prompt,
    return_tensors="pt",
)

In [4]:
# Greedy decoding WITH the KV cache: the prompt is run through the model once
# (prefill); every later step feeds only the newest token together with the
# cached keys/values returned by the previous step.
num_new_tokens = 10  # total new tokens to generate, counting the one from prefill

with torch.no_grad():
    # perf_counter() is the monotonic, high-resolution clock meant for
    # measuring intervals; time.time() can jump if the system clock changes.
    start = time.perf_counter()

    # Prefill: one pass over the full prompt. Its last-position logits already
    # give the 1st new token.
    output = model(input_ids, use_cache=True)
    past_key_values = output.past_key_values
    next_token = torch.argmax(output.logits[:, -1, :], dim=-1).unsqueeze(-1)
    generated = [next_token]

    # Continue generation using KV cache. Prefill produced one token, so only
    # num_new_tokens - 1 further steps are needed; looping num_new_tokens
    # times here would emit one token too many and add an extra forward pass
    # to the measured time.
    for _ in range(num_new_tokens - 1):
        output = model(next_token, past_key_values=past_key_values, use_cache=True)
        past_key_values = output.past_key_values
        next_token = torch.argmax(output.logits[:, -1, :], dim=-1).unsqueeze(-1)
        generated.append(next_token)

    cached_time = time.perf_counter() - start

# Prompt ids followed by every generated token, joined along the sequence axis.
generated_ids = torch.cat([input_ids] + generated, dim=1)
print("With KV Cache:", tokenizer.decode(generated_ids[0]))
print("Time with KV Cache:", cached_time)

With KV Cache: The future of AI is uncertain. The future of AI is uncertain.


Time with KV Cache: 0.49265623092651367


In [5]:
# Greedy decoding WITHOUT the KV cache: every step re-runs the model over the
# whole sequence generated so far.
num_new_tokens = 10  # number of new tokens to generate

# Re-encode the input so this cell does not depend on state left by other cells.
input_ids = tokenizer.encode(prompt, return_tensors="pt")
generated_ids = input_ids.clone()

# perf_counter() is the monotonic, high-resolution clock meant for measuring
# intervals; time.time() can jump if the system clock changes.
start = time.perf_counter()

with torch.no_grad():
    for _ in range(num_new_tokens):
        output = model(generated_ids, use_cache=False)  # No caching
        # Greedy pick from the logits at the last position, kept 2-D so it can
        # be concatenated along the sequence axis.
        next_token = torch.argmax(output.logits[:, -1, :], dim=-1).unsqueeze(-1)
        generated_ids = torch.cat([generated_ids, next_token], dim=1)

no_cache_time = time.perf_counter() - start
print("Without KV Cache:", tokenizer.decode(generated_ids[0]))
print("Time without KV Cache:", no_cache_time)


Without KV Cache: The future of AI is uncertain. The future of AI is uncertain.

Time without KV Cache: 0.9878010749816895
