In [None]:
%%capture
# Environment setup cell; %%capture hides the output of everything below.
# Presumably installs the package that provides the `pip-autoremove` command used on the next line.
!pip install pip3-autoremove
# Remove any existing torch / torchvision / torchaudio installs (-y presumably skips the confirmation prompt).
!pip-autoremove torch torchvision torchaudio -y
# Reinstall torch, torchvision and torchaudio, plus xformers, from the PyTorch CUDA 12.1 wheel index.
!pip install torch torchvision torchaudio xformers --index-url https://download.pytorch.org/whl/cu121
# NOTE(review): none of these installs are version-pinned, so a later rerun may resolve different versions.
!pip install unsloth

In [None]:
from unsloth import FastLanguageModel
import torch

In [None]:
from datasets import load_dataset
# --- Loader configuration -------------------------------------------------
# Sequence length handed to the model loader (the original note says RoPE
# scaling is handled internally by the library).
max_seq_length = 2048
# None asks the loader to auto-detect; the original note suggests Float16 for
# Tesla T4 / V100 and Bfloat16 for Ampere or newer.
dtype = None
# 4-bit quantization to reduce memory usage; may be set to False.
load_in_4bit = True

# --- Evaluation data --------------------------------------------------------
dataset = load_dataset("BurgerTruck/counsel-chat", split="test")

# --- Model and tokenizer ----------------------------------------------------
from unsloth import FastLanguageModel

model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="BurgerTruck/LLAMA-Counsel-chat",
    max_seq_length=max_seq_length,
    dtype=dtype,
    load_in_4bit=load_in_4bit,
)
# Switch the model into Unsloth's inference mode (described upstream as
# "native 2x faster inference").
FastLanguageModel.for_inference(model)


In [None]:
from unsloth.chat_templates import get_chat_template

# Attach the "llama-3.1" chat template to the tokenizer.
# The mapping translates ShareGPT-style keys/speakers ("from", "value",
# "speaker0", "speaker1") to the template's role/content/user/assistant.
# Original author's note: for some reason this mapping does not work when the
# chat template is llama.
# NOTE(review): formatting_prompts_func renames these keys by hand before
# calling apply_chat_template, so this mapping may be redundant or conflicting
# - confirm against the Unsloth documentation.
tokenizer = get_chat_template(
    tokenizer,
    chat_template="llama-3.1",
    mapping={
        "role": "from",
        "content": "value",
        "user": "speaker0",
        "assistant": "speaker1",
    },
)
# Enable Unsloth's inference mode on the model.
FastLanguageModel.for_inference(model)





In [None]:
# 128006: <|start_header_id|>
# 78191: assistant
# 128007: <|end_header_id|>
# 271: \n\n
# 128009: <|eot_id|>

# 128006: <|start_header_id|>
# 882: user
# 128007: <|end_header_id|>
import json
def formatting_prompts_func(examples):
    """Yield one chat-template encoding per conversation in ``examples``.

    ``examples["conversations"]`` must be an iterable of JSON strings, each
    decoding to a list of turn dicts with a ``"from"`` and a ``"value"`` key.
    Every turn is rewritten in place to ``"role"`` / ``"content"``: the
    speaker ``"speaker0"`` becomes ``"user"`` and any other speaker becomes
    ``"assistant"``.

    All conversations are decoded and normalised before the first item is
    yielded. Each yielded item is whatever the module-level ``tokenizer``
    returns from ``apply_chat_template`` (called with ``tokenize=True``,
    ``add_generation_prompt=False`` and ``return_tensors='pt'``).

    Raises:
        KeyError: if a turn lacks the ``"from"`` or ``"value"`` key.
    """
    conversations = list(map(json.loads, examples["conversations"]))

    # Normalise every turn up front so a malformed record fails before any
    # tokenization work is done.
    for turns in conversations:
        for turn in turns:
            speaker = turn.pop("from")
            turn["role"] = "user" if speaker == "speaker0" else "assistant"
            turn["content"] = turn.pop("value")

    for turns in conversations:
        yield tokenizer.apply_chat_template(
            turns,
            tokenize=True,
            add_generation_prompt=False,
            return_tensors='pt',
        )
def create_assistant_mask(
    input_ids,
    start_header_id=128006,
    assistant_id=78191,
    end_header_id=128007,
    newline=271,
    eos_token=128009,
):
    """Build a 0/1 mask marking the tokens that belong to assistant turns.

    An assistant turn starts right after the four-token header
    ``start_header_id, assistant_id, end_header_id, newline`` and runs up to
    and including the next ``eos_token``. The header tokens themselves are
    not marked. A turn that is never closed by ``eos_token`` is marked through
    the end of the sequence.

    Args:
        input_ids: integer tensor of token ids, shaped ``(seq,)`` or
            ``(batch, seq)``. Each row is scanned independently.
        start_header_id, assistant_id, end_header_id, newline, eos_token:
            token ids describing the chat format. The defaults are the ids
            this notebook uses for ``<|start_header_id|>``, ``assistant``,
            ``<|end_header_id|>``, ``\\n\\n`` and ``<|eot_id|>``.

    Returns:
        A ``torch.long`` tensor on the same device as ``input_ids``. The
        result is always 2-D: ``(1, seq)`` for a 1-D input, otherwise the
        same shape as the input.

    Raises:
        ValueError: if ``input_ids`` is neither 1-D nor 2-D.
    """
    # Treat a bare sequence as a batch of one; the result is always 2-D.
    if input_ids.dim() == 1:
        input_ids = input_ids.unsqueeze(0)
    elif input_ids.dim() != 2:
        raise ValueError(
            f"input_ids must be 1-D or 2-D, got shape {tuple(input_ids.shape)}"
        )

    # Scan plain Python ints: comparing tensor elements one at a time would
    # force a device synchronisation per token when input_ids lives on a GPU.
    mask_rows = []
    for token_ids in input_ids.tolist():
        row_mask = [0] * len(token_ids)
        in_assistant_utterance = False
        for i in range(3, len(token_ids)):
            if in_assistant_utterance:
                row_mask[i] = 1
                # The closing token is part of the assistant turn.
                if token_ids[i] == eos_token:
                    in_assistant_utterance = False
            elif (
                token_ids[i - 3] == start_header_id
                and token_ids[i - 2] == assistant_id
                and token_ids[i - 1] == end_header_id
                and token_ids[i] == newline
            ):
                # Header found: marking starts with the next token.
                in_assistant_utterance = True
        mask_rows.append(row_mask)

    mask = torch.tensor(mask_rows, dtype=torch.long, device=input_ids.device)
    # reshape keeps the (batch, seq) shape even for empty batches or sequences.
    return mask.reshape(input_ids.shape)


In [None]:
# NOTE(review): prints the literal 4000.0; nothing else in the visible notebook
# uses this value - looks like leftover scratch output, confirm before removing.
print(4000.0)

In [None]:
import torch.nn.functional as F

# Assistant-only perplexity over the evaluation conversations.
#
# For each conversation we compute the mean negative log-likelihood (NLL) of
# the assistant tokens only, average those per-conversation means, and
# exponentiate the average.
model.eval()

perplexity = 0.0  # running SUM of per-conversation mean NLL (despite the name)
count = 0         # number of conversations that contributed to the sum
with torch.no_grad():
    for input_ids in formatting_prompts_func(dataset):
        input_ids = input_ids.to("cuda")
        logits = model(input_ids).logits

        # A causal LM's logits at position t score the token at position
        # t + 1, so drop the last logit and the first token to align them.
        shift_logits = logits[:, :-1, :]
        target_ids = input_ids[:, 1:]

        # The mask is indexed by token position, so it must be shifted the
        # same way as the targets. Without the shift, the first assistant
        # token is never scored and the token after <|eot_id|> is scored
        # instead.
        scored = create_assistant_mask(input_ids)[:, 1:].to(logits.device).bool()

        if not scored.any():
            # No assistant tokens: the mean would be 0 / 0.
            print("Skipping a conversation with no assistant tokens to score.")
            continue

        # Cross-entropy over the scored positions only, in float32: this
        # avoids a full-vocabulary log-softmax for every position and the
        # precision loss of half-precision logits.
        nll = F.cross_entropy(
            shift_logits[scored].float(),
            target_ids[scored],
            reduction="mean",
        )

        if not torch.isfinite(nll):
            # One non-finite value would poison the whole aggregate.
            print("Found nan or inf in target_log_probs!")
            continue

        perplexity += nll
        count += 1

if count == 0:
    raise RuntimeError(
        "No assistant tokens were scored; check the dataset and chat template."
    )

final_perplexity = torch.exp(perplexity / count)
print(perplexity)
print(final_perplexity)
