In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

import torch
import re
from transformers import AutoTokenizer, AutoModelForCausalLM
import numpy as np
from peft import PeftModel


# tokenizer = AutoTokenizer.from_pretrained("codellama/CodeLlama-7b-Instruct-hf")
# model = AutoModelForCausalLM.from_pretrained("codellama/CodeLlama-7b-Instruct-hf")

def load_model(model_name, lora_path=None):
    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype=torch.bfloat16,
        device_map="auto",
        trust_remote_code=True
    )
    tokenizer_name = model_name
    tokenizer = AutoTokenizer.from_pretrained(
        tokenizer_name,
        use_fast=False,  # Take care of llama
        trust_remote_code=True,
        padding_side='left'
    )
    # Workaround for LLaMA tokenizers
    if "llama" in model_name.lower():
        tokenizer.pad_token_id = tokenizer.eos_token_id

    if lora_path:
        model = PeftModel.from_pretrained(
            model,
            lora_path,
            torch_dtype=torch.bfloat16,
        )
        print(f"Loaded peft model from {lora_path}")
    return model, tokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
model, tokenizer = load_model("meta-llama/Llama-2-7b-hf")

Loading checkpoint shards: 100%|██████████| 2/2 [00:09<00:00,  4.85s/it]


In [47]:
tokenizer.decode([917, 29896, 29889, 29896, 29922, 29896])

'Tags1.1=1'

In [117]:
def get_probs(text: str, model, tokenizer):
    pattern = re.compile(r'(<(\w+)>(.+?)</\2>)|([^<>]+)')
    
    cumulative_text = ""
    span_names = []
    span_starts = []  # span start token positions
    span_ends = []  # span end token positions
    span_contents = []

    for match in pattern.finditer(text):
        text_before_span = cumulative_text
        
        if match.group(1):
            span_names.append(match.group(2))
            span_text = match.group(3)
            span_contents.append(span_text)
        else:
            span_text = match.group(0)
        
        cumulative_text += span_text

        if match.group(1):
            span_starts.append(len(tokenizer.tokenize(text_before_span)))
            span_ends.append(len(tokenizer.tokenize(cumulative_text)))
    
    # ASSERT THAT THE CUMULATIVE TEXT IS EQUAL TO THE INPUT TEXT WITH OUT TAGS
    assert cumulative_text == re.sub(r'<.+?>|</.+?>', "", text)

    # Forward the model to find the unconditional probabilities of the spans
    uncond_span_probs = {}
    for span_name, span_content in zip(span_names, span_contents):
        uncond_span_input = tokenizer(span_content, return_tensors="pt").to(model.device)
        uncond_logits = model(**uncond_span_input, return_dict=True).logits
        uncond_probs = torch.softmax(uncond_logits, 2)
        target_span_probs = torch.gather(uncond_probs[0, :-1], -1, uncond_span_input.input_ids[0, 1:].unsqueeze(-1))
        uncond_span_probs[span_name] = target_span_probs.mean().item()


    inputs = tokenizer(cumulative_text, return_tensors="pt").to(model.device)
    # Forward the model to find the conditional probabilities of the spans
    outputs = model(**inputs, return_dict=True)
    logits = outputs.logits
    probs = torch.softmax(logits, 2)
    seq_probs = probs[0]
    # Extract scores for each span.
    cond_span_probs = {}
    for span_name, start, end in zip(span_names, span_starts, span_ends):
        # print(start, end)
        # print(seq_probs[start-1:end-1])
        target_probs = torch.gather(seq_probs[start-1:end-1], -1, inputs.input_ids[0, start:end].unsqueeze(-1))
        cond_span_probs[span_name] = target_probs.mean().item()
    
    print("Unconditional probabilities:")
    for span_name in span_names:
        print(f"p({span_name}) = {uncond_span_probs[span_name]}")

    print("\nConditional probabilities:")
    history_span_names = []
    for span_name in span_names:
        if not history_span_names:
            print(f"p({span_name} | <prompt_prefix>) = {cond_span_probs[span_name]}")
        else:
            print(f"p({span_name} | {', '.join(history_span_names)}) = {cond_span_probs[span_name]}")
        history_span_names.append(span_name)



    return uncond_span_probs, cond_span_probs

In [118]:
probs = get_probs("""An<A> apple</A> is an<B> English</B> word.""", model, tokenizer)  # An / A

Unconditional probabilities:
p(A) = 0.00778670608997345
p(B) = 0.007787714712321758

Conditional probabilities:
p(A | <prompt_prefix>) = 0.0022087732795625925
p(B | A) = 0.30144742131233215


In [52]:
tokenizer("An apple is a fruit.")

{'input_ids': [1, 530, 26163, 338, 263, 15774, 29889], 'attention_mask': [1, 1, 1, 1, 1, 1, 1]}