In [1]:
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import math

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
model_name = "meta-llama/Llama-3.1-8B"

device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16 if device == "cuda" else torch.float32,
    device_map="auto"
)
model.eval()

Loading checkpoint shards: 100%|██████████| 4/4 [00:37<00:00,  9.46s/it]


LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((4096,), eps=1e-05)
    (rotary_

In [3]:
def compute_perplexity(prompt, max_new_tokens=64):
    inputs = tokenizer(prompt, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,  # greedy decoding
            return_dict_in_generate=True,
            output_scores=True
        )
    
    generated_ids = outputs.sequences[0]
    prompt_len = inputs.input_ids.shape[1]
    generated_text = tokenizer.decode(generated_ids, skip_special_tokens=True)

    # Compute logits for loss
    with torch.no_grad():
        logits = model(generated_ids.unsqueeze(0)).logits

    shift_logits = logits[:, :-1, :].contiguous()
    shift_labels = generated_ids.unsqueeze(0)[:, 1:].contiguous()
    
    # Start computing loss from the end of the prompt
    shift_logits = shift_logits[:, prompt_len-1:, :]
    shift_labels = shift_labels[:, prompt_len-1:]

    loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
    loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)),
                    shift_labels.view(-1))
    loss_per_token = loss.view(shift_labels.size())

    # Per-token perplexity
    ppl_per_token = torch.exp(loss_per_token).squeeze().tolist()

    # Global perplexity
    global_ppl = math.exp(loss.mean().item())

    return {
        "prompt": prompt,
        "generated_text": generated_text,
        "per_token_ppl": ppl_per_token,
        "global_ppl": global_ppl
    }

In [4]:
prompts = [
    "2 + 2 = ",
    "The capital of China is ",
    "Once upon a time, there was a ",
    "The message is: qwerf23jdaf0klsaf",
    "The secret of the pen is ",
    "The password for the ancient library is "
]

results = []
for p in prompts:
    res = compute_perplexity(p)
    results.append(res)

for r in results:
    print("="*60)
    print(f"Prompt: {r['prompt']}")
    print(f"Generated: {r['generated_text'][:200]}...")
    print(f"Global perplexity: {r['global_ppl']:.2f}")
    print(f"First 10 per-token PPL: {r['per_token_ppl'][:10]}")

Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:128001 for open-end generation.


Prompt: 2 + 2 = 
Generated: 2 + 2 = 4. 2 + 2 = 4. 2 + 2 = 4. 2 + 2 = 4. 2 + 2 = 4. 2 + 2 = 4. 2 + 2 = 4. 2 + 2 = 4...
Global perplexity: 1.21
First 10 per-token PPL: [2.498046875, 6.828125, 13.40625, 4.53515625, 2.779296875, 1.0556640625, 1.734375, 1.451171875, 1.12109375, 2.40625]
Prompt: The capital of China is 
Generated: The capital of China is 1,500 miles from the nearest ocean. The city is located in the north of the country, in the middle of the North China Plain. The city is located in the north of the country, in...
Global perplexity: 2.07
First 10 per-token PPL: [10.296875, 1.3212890625, 6.765625, 2.853515625, 2.294921875, 3.28515625, 9.921875, 2.884765625, 2.697265625, 9.140625]
Prompt: Once upon a time, there was a 
Generated: Once upon a time, there was a 3-year-old boy named Jack. He was a very happy boy, and he loved to play with his toys. One day, Jack was playing with his toy car when he accidentally knocked over a vas...
Global perplexity: 2.22
First 10 per-token PPL: