In [1]:
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from accelerate import Accelerator
from typing import List
from datasets import load_dataset

In [2]:
model_id = 'meta-llama/Llama-3.1-8B-Instruct'

In [3]:
tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token

In [4]:
accelerator = Accelerator()
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16)
model = accelerator.prepare(model)
model.eval()

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
      )
    )
    (norm): LlamaRMSNorm((4096,), eps=1e-05)
    (rotary_

In [5]:
def batch_compute_perplexity(texts, context_length, batch_size):
    perplexities = []
    nll_sum = 0
    n_tokens = 0
    for i in range(0, len(texts), batch_size):
        print(f"processing batch: {i} out of {len(texts) // batch_size}")
        # Tokenize batch
        batch = texts[i:i + batch_size]
        encodings = tokenizer(
            batch,
            padding='longest',
            truncation=True,
            return_tensors="pt",
        ).to(accelerator.device)
        
        # Create labels (-100 for context and padding)
        labels = encodings.input_ids.clone()
        labels[labels == tokenizer.pad_token_id] = -100
        labels[:, :context_length] = -100
        
        # Forward pass and compute loss
        with torch.no_grad():
            shift_logits = model(
                input_ids=encodings.input_ids,
                attention_mask=encodings.attention_mask
            ).logits[:, :-1]
            
            shift_labels = labels[:, 1:]
            
            loss = torch.nn.functional.cross_entropy(
                shift_logits.reshape(-1, shift_logits.size(-1)),
                shift_labels.reshape(-1),
                ignore_index=-100
            )
            print(loss)

            num_valid_tokens = (shift_labels != -100).sum().item()
            nll_sum += loss.item() * num_valid_tokens
            n_tokens += num_valid_tokens
            
        # Calculate perplexity per sequence
    avg_nll = nll_sum / n_tokens  # average negative log-likelihood per token
    ppl = torch.exp(avg_nll)
        
    return avg_nll, ppl

In [6]:
ds = load_dataset('amang1802/synthetic_data_unconditioned_L3.1_70B')['train'].select(range(128))

In [7]:
context_length = 16
batch_size = 4

In [8]:
texts = ds['synthetic_content']
batch_compute_perplexity(texts, context_length, batch_size)

processing batch: 0 out of 32
tensor(0.5283, device='cuda:0', dtype=torch.float16)
processing batch: 4 out of 32
tensor(0.4246, device='cuda:0', dtype=torch.float16)
processing batch: 8 out of 32


OutOfMemoryError: CUDA out of memory. Tried to allocate 3.88 GiB. GPU 0 has a total capacity of 23.69 GiB of which 1.08 GiB is free. Including non-PyTorch memory, this process has 21.98 GiB memory in use. Of the allocated memory 19.27 GiB is allocated by PyTorch, and 2.41 GiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)