In [None]:

%matplotlib notebook

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import itertools
import math
import numpy as np
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import os


# Presumably read by the Hugging Face `tokenizers` library to decide whether to use
# its own parallelism; overwritten unconditionally here (use "true" to enable it).
os.environ.update(TOKENIZERS_PARALLELISM="false")

def sort_lists(x, y):
    """
    Group ``y`` by the distinct values of ``x`` and sum each group.

    Despite the name, nothing is sorted by hand: ``torch.unique`` (default
    ``sorted=True``) returns the distinct keys in ascending order, and the group
    sums are returned aligned with them.

    :param x: Tensor of keys (e.g. tokenization lengths).
    :param y: Tensor of values with the same number of elements as ``x``.
    :return: ``(unique_x, summed_y)`` where ``summed_y[i]`` is the float64 sum of
        every ``y`` entry whose key equals ``unique_x[i]``.
    """
    # ``inverse`` maps each element of x to the index of its unique value, so one
    # index_add_ accumulates every group at once. The previous version built one
    # boolean mask per unique value, which is O(n * k).
    unique_x, inverse = torch.unique(x, return_inverse=True)
    summed_y = torch.zeros(unique_x.shape[0], dtype=torch.float64, device=unique_x.device)
    summed_y.index_add_(0, inverse.reshape(-1), y.reshape(-1).to(summed_y))
    return unique_x, summed_y

# Function to plot the KDE
def plot_kde(x_values, probabilities, num_bins=20, bw_adjust=1.0, save_path="plot.png", dpi=1000):
    """
    Plot a weighted KDE over a histogram whose unit-width bins are centred on integers.

    :param x_values: 1-D array-like of x values (intended to be integers, e.g. tokenization lengths).
    :param probabilities: 1-D array-like of non-negative weights, one per x value; normalised here.
    :param num_bins: Unused. The histogram always uses one bin per integer between
        min(x_values) and max(x_values). Kept for backward compatibility.
    :param bw_adjust: Bandwidth multiplier forwarded to ``seaborn.kdeplot``.
    :param save_path: File the figure is written to.
    :param dpi: Resolution used when saving the figure.
    :raises ValueError: If the inputs differ in shape, or the weights do not sum to a
        positive finite value.
    """
    # Convert to float explicitly so normalisation also works for integer weights.
    probabilities = np.asarray(probabilities, dtype=np.float64)
    x_values = np.asarray(x_values)
    if x_values.shape != probabilities.shape:
        raise ValueError("x_values and probabilities must have the same shape")
    total = probabilities.sum()
    if not np.isfinite(total) or total <= 0:
        raise ValueError("probabilities must sum to a positive finite value")
    probabilities = probabilities / total

    # Integer-based bins: edges at k - 0.5 so each bin is centred on an integer k.
    bin_edges = np.arange(x_values.min() - 0.5, x_values.max() + 1.5, 1)
    bin_centers = bin_edges[:-1] + 0.5

    hist_values, _ = np.histogram(x_values, bins=bin_edges, weights=probabilities, density=True)
    plt.bar(bin_centers, hist_values, width=1, alpha=0.3, color='gray', edgecolor='black', label="Histogram")

    # Seaborn takes the weights as a DataFrame column.
    df = pd.DataFrame({'x': x_values, 'prob': probabilities})
    sns.kdeplot(data=df, x='x', weights='prob', fill=True, color='blue', alpha=0.5, label="KDE", bw_adjust=bw_adjust)
    plt.yscale("log")

    plt.title("Kernel Density Estimation (KDE) with Centered Histogram", fontsize=16)
    plt.xlabel("X values", fontsize=14)
    plt.ylabel("Density", fontsize=14)
    plt.legend()

    # Save before showing. pyplot.show() may close the figure (e.g. in scripts with
    # non-interactive backends), and a later savefig would then write an empty image.
    plt.savefig(save_path, dpi=dpi)
    plt.show()

def find_tokenizations(sentence, tokenizer, memo=None, encode=False):
    """
    Enumerate every way to split ``sentence`` into pieces that each encode to one token.

    A piece is accepted when ``tokenizer.encode(piece, add_special_tokens=False)``
    returns exactly one id. String results for every suffix are cached in ``memo``,
    keyed by the suffix, so a memo dict can be shared across calls.

    :param sentence: Text to split.
    :param tokenizer: Object exposing ``encode(text, add_special_tokens=False)``.
    :param memo: Optional cache dict (suffix -> list of string tokenizations); created if omitted.
    :param encode: If True, return flat lists of token ids instead of string pieces.
    :return: List of tokenizations, each a list of string pieces (``encode=False``)
        or a flat list of token ids (``encode=True``). The empty string yields ``[[]]``.
    """
    if memo is None:
        memo = {}
    tokenizations = _find_string_tokenizations(sentence, tokenizer, memo)

    if not encode:
        # Hand back fresh lists so callers cannot mutate the cached entries.
        return [list(pieces) for pieces in tokenizations]

    # Encode each distinct piece once. This runs even when the memo was already
    # populated, so ``encode=True`` always yields ids.
    piece_ids = {}
    encoded = []
    for pieces in tokenizations:
        ids = []
        for piece in pieces:
            if piece not in piece_ids:
                piece_ids[piece] = tokenizer.encode(piece, add_special_tokens=False)
            ids.extend(piece_ids[piece])
        encoded.append(ids)
    return encoded


def _find_string_tokenizations(sentence, tokenizer, memo):
    """Recursive memoized worker: all splits of ``sentence`` into single-token string pieces."""
    if sentence in memo:
        return memo[sentence]
    if not sentence:
        return [[]]

    tokenizations = []
    for i in range(1, len(sentence) + 1):
        prefix = sentence[:i]
        # Only prefixes that are a single token can start a valid tokenization.
        if len(tokenizer.encode(prefix, add_special_tokens=False)) != 1:
            continue
        for rest_tokenization in _find_string_tokenizations(sentence[i:], tokenizer, memo):
            tokenizations.append([prefix] + rest_tokenization)

    memo[sentence] = tokenizations
    return tokenizations

def compute_tokenization_probability(tokenization, prompt, tokenizer, model, logit=False):
    """
    Score one tokenization of the text that follows ``prompt`` under a causal language model.

    The model is run once on ``prompt_ids + tokenization``. The logits at sequence
    position ``t`` are the model's prediction for the token at position ``t + 1``.
    ``tokenization[i]`` sits at position ``len(prompt_ids) + i``, so it is scored with
    the logits at position ``len(prompt_ids) + i - 1``.

    :param tokenization: Sequence of integer token ids to score.
    :param prompt: Conditioning text. It must encode to at least one token, because the
        first id of ``tokenization`` is predicted from the last prompt position.
    :param tokenizer: Object with ``encode(text, add_special_tokens=False)`` returning ids.
    :param model: Callable on a ``(1, seq)`` id tensor whose result has a ``logits``
        attribute indexed as ``[batch, position, vocab]`` (Hugging Face style).
    :param logit: If True, return the sum of the raw (unnormalised) logits of the chosen
        tokens. Otherwise, return the product of their softmax probabilities.
    :return: Python float. An empty tokenization scores 1.0 (probability) or 0.0 (logit sum).
    :raises ValueError: If ``prompt`` encodes to no tokens.
    """
    token_ids = [int(token) for token in tokenization]
    if not token_ids:
        # Empty product / empty sum; no model call needed.
        return 0.0 if logit else 1.0

    prompt_ids = list(tokenizer.encode(prompt, add_special_tokens=False))
    if not prompt_ids:
        raise ValueError("prompt must encode to at least one token to condition the first scored token")

    input_ids = torch.tensor([prompt_ids + token_ids], dtype=torch.long)
    # Models loaded with device_map="auto" may sit on an accelerator; keep inputs alongside.
    device = getattr(model, "device", None)
    if device is not None:
        input_ids = input_ids.to(device)

    with torch.no_grad():
        # Upcast before the softmax: in float16, small probabilities flush to zero.
        logits = model(input_ids).logits[0].float()

    # Row p of the logits predicts the token at position p + 1.
    first = len(prompt_ids) - 1
    positions = torch.arange(first, first + len(token_ids), device=logits.device)
    targets = torch.tensor(token_ids, dtype=torch.long, device=logits.device)

    if logit:
        return logits[positions, targets].double().sum().item()

    # Sum log-probabilities and exponentiate once, in float64, to limit underflow.
    log_probs = torch.log_softmax(logits, dim=-1)[positions, targets]
    return math.exp(log_probs.double().sum().item())


print("Initializing script...")
# Local Hugging Face model cache; adjust for your machine.
custom_cache_dir = "/NL/token-pricing/work/models"

# Load tokenizer and model
model_name = "meta-llama/Llama-3.2-1B"
tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=custom_cache_dir)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map="auto",
    cache_dir=custom_cache_dir,
)

print("Model loaded...")
# Prompt that conditions the scored text, and the text whose tokenizations are enumerated.
prompt = " "
text = "Speechlessly"

# Each element is one tokenization, given as a flat list of token ids.
tokenizations = find_tokenizations(text, tokenizer, encode=True)
print("Tokenizations found...")
list_lengths = []
list_prob = []

for idx, tokenization in enumerate(tokenizations):
    prob = compute_tokenization_probability(tokenization, prompt, tokenizer, model, logit=False)
    list_lengths.append(len(tokenization))
    list_prob.append(prob)

    readable_tokenization = ' '.join(
        tokenizer.decode([token_id], skip_special_tokens=True) for token_id in tokenization
    )
    # Scientific notation: these probabilities are often far below 1e-15.
    print(f"Tokenization {idx + 1}: {readable_tokenization} | Probability/logit: {prob:.6e}")

# Normalise once. Recomputing the sum per element was O(n^2), and an all-zero
# total would otherwise divide by zero.
total_prob = float(np.sum(list_prob))
if total_prob <= 0:
    raise ValueError("All tokenization probabilities are zero; nothing to normalise or plot")
list_prob = [prob / total_prob for prob in list_prob]

# Aggregate probability mass per tokenization length, then plot it.
lengths, probs = sort_lists(
    torch.tensor(list_lengths, dtype=torch.float64),
    torch.tensor(list_prob, dtype=torch.float64),
)

plot_kde(lengths, probs)



