# Fine-Tuned Perplexity Analysis Test

Once training is complete, we move on to the perplexity analysis.

In [None]:
import torch
import math
from transformers import AutoModelForCausalLM, LlamaTokenizer
import bitsandbytes as bnb
from datasets import load_dataset

In [None]:
# do the perplexity comparison
def compute_perplexity(model, dataset, tokenizer):
    """Return the token-level perplexity of ``model`` over ``dataset``.

    Args:
        model: A causal language model following the Hugging Face calling
            convention: ``model(input_ids=..., labels=...)`` returns an object
            whose ``.loss`` is the mean negative log-likelihood per predicted
            token, with the one-position label shift done inside the model.
        dataset: Iterable of batches. Each batch is indexed with
            ``'input_ids'`` and must yield a 2-D ``(batch, sequence)`` tensor.
        tokenizer: Unused; kept so existing call sites keep working.

    Returns:
        ``exp`` of the average negative log-likelihood per predicted token,
        as a Python float.

    Raises:
        ValueError: If no batch contains at least one predictable token
            (for example, an empty dataset).

    Note:
        Every position is scored, so batches are assumed to be unpadded
        (or padded consistently across the models being compared).
    """
    model.eval()  # Set the model to evaluation mode
    total_log_likelihood = 0.0
    total_tokens = 0

    with torch.no_grad():
        for batch in dataset:
            input_ids = batch['input_ids'].to(model.device)

            # The model predicts token t+1 from tokens <= t, so each sequence
            # contributes (length - 1) predicted tokens.
            n_predicted = input_ids.size(0) * (input_ids.size(1) - 1)
            if n_predicted <= 0:
                # Nothing to predict here; calling the model would give NaN.
                continue

            # Pass the *unshifted* ids as labels: the model shifts them
            # internally, so shifting here as well would misalign the targets.
            outputs = model(input_ids=input_ids, labels=input_ids)

            # ``loss`` is a per-token mean; weight it by the number of tokens
            # it averages over so batches of different sizes combine correctly.
            total_log_likelihood += outputs.loss.item() * n_predicted
            total_tokens += n_predicted

    if total_tokens == 0:
        raise ValueError("dataset contains no tokens to score")

    perplexity = math.exp(total_log_likelihood / total_tokens)
    return perplexity

In [None]:
# Name or local path of the checkpoint to evaluate.
model_name = 'llama-2-7b-hf'

tokenizer = LlamaTokenizer.from_pretrained(model_name)

# Free memory (whole GiB) on the current CUDA device, kept for reference.
free_in_GB = int(torch.cuda.mem_get_info()[0] / 1024**3)

# Per-GPU budget handed to the loader: each device's own free memory minus
# 2 GiB of headroom. Querying every device individually matters on multi-GPU
# hosts, where the GPUs may not all have the same amount free.
n_gpus = torch.cuda.device_count()
max_memory = {
    i: f'{int(torch.cuda.mem_get_info(i)[0] / 1024**3) - 2}GB'
    for i in range(n_gpus)
}

# NOTE(review): `do_sample` is a text-generation setting and does not affect
# the forward passes used for perplexity; it is kept here in case later cells
# rely on it — confirm whether it is still needed.
model = AutoModelForCausalLM.from_pretrained(
  model_name,
  device_map='auto',
  load_in_4bit=True,
  max_memory=max_memory,
  do_sample=True,
  torch_dtype="auto"
)

In [None]:
# Dataset identifier passed to `load_dataset`.
data_name = "mlabonne/guanaco-llama2-1k"
# NOTE(review): this requests the "test" split but binds it to a name that says
# "training" — confirm that the dataset actually publishes a "test" split and
# that this is the data the perplexity should be measured on.
training_data = load_dataset(data_name, split="test")

In [None]:
import torch
from tqdm import tqdm

# Sliding-window perplexity over the evaluation corpus.
#
# The corpus is tokenized as one long sequence and scored in overlapping
# windows of up to `max_length` tokens that advance by `stride`; in each window
# only the tokens not already scored by the previous window count towards the
# loss, the rest serve as context.

# NOTE(review): assumes the dataset exposes a "text" column — confirm.
encodings = tokenizer("\n\n".join(training_data["text"]), return_tensors="pt")

# Window size is the model's context length. (`config.max_length` is the
# default *generation* length, which is far smaller than `stride` and would
# leave most of the corpus unscored.)
max_length = model.config.max_position_embeddings
stride = 512
seq_len = encodings.input_ids.size(1)
# Send inputs to wherever the model's first parameters were placed rather than
# assuming a fixed device name.
device = model.device

nlls = []          # mean negative log-likelihood of each scored window
token_counts = []  # number of tokens each of those means averages over
prev_end_loc = 0
for begin_loc in tqdm(range(0, seq_len, stride)):
    end_loc = min(begin_loc + max_length, seq_len)
    trg_len = end_loc - prev_end_loc  # may be different from stride on last loop
    input_ids = encodings.input_ids[:, begin_loc:end_loc].to(device)
    target_ids = input_ids.clone()
    # -100 marks context-only positions that the loss must ignore.
    target_ids[:, :-trg_len] = -100

    # The model shifts labels left by one internally, so the label at
    # position 0 is never scored; count what actually enters the loss.
    n_scored = int((target_ids[:, 1:] != -100).sum())

    if n_scored > 0:
        with torch.no_grad():
            outputs = model(input_ids, labels=target_ids)

            # loss is calculated using CrossEntropyLoss which averages over valid labels
            neg_log_likelihood = outputs.loss

        nlls.append(neg_log_likelihood)
        token_counts.append(n_scored)

    prev_end_loc = end_loc
    if end_loc == seq_len:
        break

if not nlls:
    raise ValueError("evaluation text is too short to score any token")

# Weight each window's mean by its token count so the result is the true
# per-token average, not an average of averages over unequal windows.
window_nlls = torch.stack(nlls).float()
window_weights = torch.tensor(
    token_counts, dtype=window_nlls.dtype, device=window_nlls.device
)
ppl = torch.exp((window_nlls * window_weights).sum() / window_weights.sum())

In [None]:
# Echo the perplexity computed above (a notebook cell displays its last expression).
ppl