In [1]:
from datasets import load_dataset
from tqdm import tqdm
from model_chunking.models.qwen2 import Qwen2ChunkingForCausalLM, Qwen2ChunkingConfig, Qwen2Tokenizer
import torch

  from .autonotebook import tqdm as notebook_tqdm


In [16]:
model_name = "Qwen/Qwen2.5-0.5B-Instruct"
config = Qwen2ChunkingConfig.from_pretrained(
    model_name,
    num_layers_per_chunk=24, 
    chunking_mode="prune",
    layers_to_prune=[23,24],
    aggregation_mode="mean", 
    use_adapters=False
)
model = Qwen2ChunkingForCausalLM.from_pretrained(
    model_name,
    config=config,
    torch_dtype="auto",
    device_map="auto",
)
tokenizer = Qwen2Tokenizer.from_pretrained(model_name)

You are using a model of type qwen2 to instantiate a model of type qwen2_chunking. This is not supported for all configurations of models and can yield errors.


In [3]:
dataset = load_dataset("Salesforce/wikitext", "wikitext-103-v1", split="validation")

In [4]:
# Concatenate the validation set into a single long text and then tokenize it
# NOTE: Taken from https://huggingface.co/docs/transformers/en/perplexity
encoded_inputs = tokenizer("\n\n".join(dataset['text']), return_tensors="pt")

max_length = 2048 # Memory reasons
stride = 256
length_of_input = encoded_inputs['input_ids'].shape[-1]

Token indices sequence length is longer than the specified maximum sequence length for this model (262363 > 131072). Running this sequence through the model will result in indexing errors


In [17]:
losses : list[float] = []
last_end = 0

for start in tqdm(range(0, length_of_input, stride)):
    end = min(start + max_length, length_of_input)
    target_length = end - last_end
    input_ids = encoded_inputs['input_ids'][:, start : end].to(model.device)
    target_ids = input_ids.clone()
    target_ids[:, :-target_length] = -100

    with torch.no_grad():
        outputs = model(input_ids, labels=target_ids)
        loss = outputs.loss

    losses.append(loss)
    last_end = end
    break

    del target_ids, input_ids

  0%|          | 0/1025 [00:00<?, ?it/s]


In [19]:
print(outputs.past_key_values)

AttributeError: 'CausalLMOutputWithPast' object has no attribute 'last_hidden_state'

In [9]:
perplexity_score = torch.exp(torch.stack(losses).mean())
print(perplexity_score)