In [1]:
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset

import torch
import os

In [2]:
# Let float32 matrix multiplications use faster reduced-precision kernels where
# the hardware supports them (see the PyTorch docs for what 'high' permits).
torch.set_float32_matmul_precision('high')

In [3]:
# Checkpoint location handed to both `AutoTokenizer.from_pretrained` and
# `AutoModelForCausalLM.from_pretrained` in the following cells.
# NOTE(review): absolute, machine-specific path — adjust when running elsewhere.
model_id = '/root/synthetic-data-recipes/diversity/ft_models/llama3_1_8B/fulltext_conditioned_10epochs_lr1e-5/epoch_9'

In [4]:
# Load the tokenizer from the same checkpoint path as the model.
tokenizer = AutoTokenizer.from_pretrained(model_id)
# Reuse EOS as the padding token so batches can be padded to a common length;
# presumably this tokenizer ships without a pad token of its own — TODO confirm.
tokenizer.pad_token = tokenizer.eos_token

In [5]:
# Number of leading positions of every sequence excluded from the loss
# (read by `get_shift_labels`).
context_length = 16
# Number of texts scored per forward pass.
batch_size = 4
# Number of texts selected from each dataset split for evaluation.
num_samples = 512

In [6]:
# Load the checkpoint weights as bfloat16 and move the model to GPU 0.
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.bfloat16).to("cuda:0")
# Put the model in evaluation mode; as the cell's last expression this also
# displays the module summary.
model.eval()

Loading checkpoint shards:   0%|          | 0/4 [00:00<?, ?it/s]

LlamaForCausalLM(
  (model): LlamaModel(
    (embed_tokens): Embedding(128256, 4096)
    (layers): ModuleList(
      (0-31): 32 x LlamaDecoderLayer(
        (self_attn): LlamaSdpaAttention(
          (q_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (k_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (v_proj): Linear(in_features=4096, out_features=1024, bias=False)
          (o_proj): Linear(in_features=4096, out_features=4096, bias=False)
          (rotary_emb): LlamaRotaryEmbedding()
        )
        (mlp): LlamaMLP(
          (gate_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (up_proj): Linear(in_features=4096, out_features=14336, bias=False)
          (down_proj): Linear(in_features=14336, out_features=4096, bias=False)
          (act_fn): SiLU()
        )
        (input_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
        (post_attention_layernorm): LlamaRMSNorm((4096,), eps=1e-05)
      )
    )
    (n

In [7]:
def get_tokens_batch(batch):
    """Tokenize a batch of texts into padded PyTorch tensors on ``cuda:0``.

    Uses the module-level ``tokenizer``. Every sequence is padded to the
    longest one in ``batch``, and the tokenizer output is moved to the GPU
    before being returned.
    """
    tokenized = tokenizer(batch, padding='longest', return_tensors="pt")
    return tokenized.to("cuda:0")

def get_shift_labels(encodings, num_context_tokens=None):
    """Build next-token labels with padding and the leading context masked out.

    Args:
        encodings: Tokenizer output exposing ``input_ids`` and
            ``attention_mask`` tensors of identical 2-D shape.
        num_context_tokens: How many leading *real* tokens of each sequence
            are excluded from the loss. Defaults to the module-level
            ``context_length`` when omitted.

    Returns:
        A tensor with one column fewer than ``input_ids`` (the first column is
        dropped so that column ``t`` holds the target for the logits at
        position ``t``). Ignored positions hold ``-100``.
    """
    if num_context_tokens is None:
        num_context_tokens = context_length

    is_real = encodings.attention_mask.bool()
    labels = encodings.input_ids.clone()

    # Mask padding via the attention mask rather than by token id: the pad
    # token is aliased to EOS, so an id comparison would also drop genuine EOS
    # tokens from the loss.
    labels[~is_real] = -100

    # Zero-based index of each token among the real tokens of its row. Masking
    # on this rank instead of on the raw column keeps the context mask correct
    # regardless of which side the tokenizer pads on.
    token_rank = is_real.long().cumsum(dim=1) - 1
    labels[token_rank < num_context_tokens] = -100

    return labels[:, 1:]

In [8]:
def batch_compute_perplexity(texts, context_length, batch_size):
    """Compute token-level perplexity of the module-level ``model`` over ``texts``.

    Texts are tokenized and scored ``batch_size`` at a time. The negative
    log-likelihood of every non-ignored target token is summed across all
    batches and divided by the number of such tokens.

    Args:
        texts: Sequence of strings to score (sliced with ``texts[i:i + batch_size]``).
        context_length: Intended number of leading tokens excluded from the loss.
            NOTE(review): this argument is not forwarded; the masking is done
            by ``get_shift_labels``, which reads the module-level
            ``context_length``. Keep the two values equal.
        batch_size: Number of texts per forward pass.

    Returns:
        Tuple ``(nll_sum, n_tokens, avg_nll, ppl)`` of 0-d tensors on
        ``cuda:0``: the summed NLL (float64), the number of scored tokens
        (int64), their ratio, and ``exp`` of that ratio. If no token is
        scored, ``avg_nll`` and ``ppl`` are NaN.
    """
    device = "cuda:0"
    nll_sum = torch.zeros((), dtype=torch.float64, device=device)
    n_tokens = torch.zeros((), dtype=torch.int64, device=device)

    for start in range(0, len(texts), batch_size):
        batch = texts[start:start + batch_size]
        encodings = get_tokens_batch(batch)

        with torch.no_grad():
            # Drop the last position: it has no next token to predict.
            shift_logits = model(
                input_ids=encodings.input_ids,
                attention_mask=encodings.attention_mask
            ).logits[:, :-1]

        shift_labels = get_shift_labels(encodings)

        # Sum (not mean) the per-token losses:
        #  * a batch whose labels are all -100 yields 0 here, whereas the mean
        #    reduction yields NaN and would poison the running total;
        #  * no precision is lost by rounding a per-batch mean and multiplying
        #    it back by the token count.
        # The logits are upcast to float32 first in case the model returns
        # them in bfloat16.
        batch_nll = torch.nn.functional.cross_entropy(
            shift_logits.reshape(-1, shift_logits.size(-1)).float(),
            shift_labels.reshape(-1),
            ignore_index=-100,
            reduction="sum",
        )

        nll_sum += batch_nll.to(torch.float64)
        n_tokens += (shift_labels != -100).sum()

    avg_nll = nll_sum / n_tokens
    ppl = torch.exp(avg_nll)

    return nll_sum, n_tokens, avg_nll, ppl

In [9]:
# Dataset identifiers passed to `load_dataset` by the evaluation loop; the
# commented-out entries are currently skipped.
ds_list = [
    # 'amang1802/synthetic_data_topic_conditioned_L3.3_70B_deduped',
    # 'amang1802/synthetic_data_prefix_conditioned_L3.3_70B_deduped',
    'amang1802/synthetic_data_fulltext_conditioned_L3.3_70B_deduped'
]

In [None]:
# Score a fixed-seed random sample of every split of every dataset and print
# the (nll_sum, n_tokens, avg_nll, ppl) tuple for each.
for ds_id in ds_list:
    ds = load_dataset(ds_id)
    for split in ['train', 'test']:
        split_ds = ds[split]
        # Clamp the sample size so a split with fewer than `num_samples` rows
        # is evaluated in full instead of failing inside `select`.
        n_eval = min(num_samples, len(split_ds))
        if n_eval < num_samples:
            print(f"{ds_id} {split}: only {n_eval} rows available (requested {num_samples})")
        texts = split_ds.shuffle(seed=1998).select(range(n_eval))['synthetic_content']
        metrics = batch_compute_perplexity(texts, context_length, batch_size)
        print(ds_id, split, metrics)

amang1802/synthetic_data_fulltext_conditioned_L3.3_70B_deduped train (tensor(340549.3350, device='cuda:0', dtype=torch.float64), tensor(545041, device='cuda:0'), tensor(0.6248, device='cuda:0', dtype=torch.float64), tensor(1.8679, device='cuda:0', dtype=torch.float64))
