In [None]:
import torch
import torch.nn as nn
from transformers import AutoModelForCausalLM, AutoTokenizer
from collections import OrderedDict
from utils import *
from tqdm.auto import tqdm
import gc


# Single checkpoint identifier so the weights and the tokenizer always match.
checkpoint = "huggyllama/llama-7b"

# Load the causal LM weights in half precision (float16).
model = AutoModelForCausalLM.from_pretrained(checkpoint, torch_dtype=torch.float16)

tokenizer = AutoTokenizer.from_pretrained(checkpoint)

# Settings for the evaluation data: dataset name, per-sample sequence length
# and batch size, forwarded verbatim to the loader helper.
eval_settings = {
    "dataset_name": "ptb",
    "seq_len": 2048,
    "batch_size": 8,
}

# NOTE(review): load_eval_tokenized_dataset presumably comes from the
# wildcard `utils` import; the evaluation code iterates its result
# batch-by-batch and calls len() on it -- confirm against utils.
data_loader = load_eval_tokenized_dataset(tokenizer=tokenizer, **eval_settings)

# Evaluate perplexity (ppl): exp of the mean per-token negative log-likelihood
# over every batch yielded by `data_loader`.

model.eval()

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

model.to(DEVICE)

# The criterion is stateless, so build it once instead of once per batch.
# reduction='none' keeps one loss value per token so that the final mean is
# taken over all tokens of all batches rather than over per-batch means.
loss_fct = nn.CrossEntropyLoss(reduction='none')

nlls = []  # per-token negative log-likelihoods, one 1-D CPU tensor per batch
with torch.no_grad():
    for batch in tqdm(data_loader, desc="Evaluating perplexity", total=len(data_loader)):
        batch = batch.to(DEVICE)
        logits = model(batch, use_cache=False).logits

        # Guard clause: a batch with inf/NaN logits would poison the mean.
        if not torch.isfinite(logits).all():
            print("Bad logits detected, skipping batch.")
            continue

        # Position t predicts token t + 1: drop the last logit and the first label.
        # The model is loaded in float16 earlier in this script; upcast to
        # float32 so the log-softmax inside the loss and the mean below are
        # not rounded to half precision (which visibly quantizes the result).
        shift_logits = logits[:, :-1, :].float()
        shift_labels = batch[:, 1:]
        loss = loss_fct(
            shift_logits.reshape(-1, shift_logits.size(-1)),
            shift_labels.reshape(-1),
        )
        nlls.append(loss.cpu())

# Without this check torch.cat fails with an opaque "non-empty list" error
# when the loader is empty or every batch was skipped.
if not nlls:
    raise RuntimeError(
        "Cannot compute perplexity: no batch produced finite logits."
    )

ppl = torch.exp(torch.cat(nlls, dim=-1).mean()).item()
# Very large perplexities are reported as integers; the decimals carry no information.
if ppl > 1000:
    ppl = int(ppl)

print(f"Perplexity: {ppl}")

  from .autonotebook import tqdm as notebook_tqdm
Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00,  7.93it/s]
You are using the default legacy behaviour of the <class 'transformers.models.llama.tokenization_llama_fast.LlamaTokenizerFast'>. This is expected, and simply means that the `legacy` (previous) behavior will be used so nothing changes for you. If you want to use the new behaviour, set `legacy=False`. This should only be set if you understand what it means, and thoroughly read the reason why this was added as explained in https://github.com/huggingface/transformers/pull/24565 - if you loaded a llama tokenizer from a GGUF file you can ignore this message.
Token indices sequence length is longer than the specified maximum sequence length for this model (105835 > 2048). Running this sequence through the model will result in indexing errors
Evaluating perplexity: 100%|██████████| 7/7 [00:11<00:00,  1.71s/it]

Perplexity: 36.4375





: 