In [None]:
%matplotlib inline

import pickle
from transformers import AutoTokenizer, AutoModelForCausalLM

import os
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import glob
import numpy as np
from IPython.display import clear_output
from collections import Counter
os.chdir("../")

from src import utils
from matplotlib.lines import Line2D
from matplotlib.patches import Patch
from matplotlib.ticker import LogLocator
import torch
from termcolor import colored  # Use termcolor for colored output in the terminal



In [193]:
# Split and verify functions
def verify_sampling_conditions(tokens, prompt_length, top_k=None, top_p=None, model=None, tokenizer=None, temp = 1.0):
    # Convert tokens to tensor and run the model
    input_ids = torch.tensor([tokens]).to("cpu")
    with torch.no_grad():
        outputs = model(input_ids)
    
    logits = outputs.logits

    all_top_k_met = True
    all_top_p_met = True

    # Evaluate only on tokens after the prompt
    for i in range(prompt_length, len(tokens)):  # Start from tokens after the prompt
        previous_logits = logits[0, i - 1]  # Logits for predicting the current token
        probabilities = torch.softmax(previous_logits / temp, dim=-1)  # Convert logits to probabilities

        # Get current token
        
        
        current_token = tokens[i]
        token_probability = probabilities[current_token].item()

        # Check top-k condition
        top_k_condition = False
        if top_k is not None:
            top_k_indices = torch.topk(probabilities, k=top_k).indices
            top_k_condition = current_token in top_k_indices.tolist()
            all_top_k_met = all_top_k_met and top_k_condition  # Update overall status

        # Check top-p condition
        top_p_condition = False
        if top_p is not None:
            sorted_probs, sorted_indices = torch.sort(probabilities, descending=True)
            cumulative_probs = torch.cumsum(sorted_probs, dim=-1)
            top_p_indices = sorted_indices[cumulative_probs <= top_p]
            # Include the first token that pushes cumulative probability over top_p
            if len(top_p_indices) < len(sorted_probs):
                top_p_indices = torch.cat([top_p_indices, sorted_indices[len(top_p_indices):len(top_p_indices) + 1]])
            top_p_condition = current_token in top_p_indices.tolist()
            all_top_p_met = all_top_p_met and top_p_condition  # Update overall status



    return {
        "all_top_k_met": all_top_k_met if top_k is not None else None,
        "all_top_p_met": all_top_p_met if top_p is not None else None,
    }



def split_token(sequence, tokenizer, vocab):
    """
    Heuristic function to splits a token into two subtokens based on the sum of their indices in the vocabulary.
    """
    
    #print("Sequence: ", sequence)
    # Reverse mapping: ID -> Token
    id_to_token = {v: k for k, v in vocab.items()}

    # Select a token ID to split (heuristic: pick the lowest ID token)
    #print("Sequence: ", sequence)
    
    #Get all token IDs in the sequence that have at least two characters
    #print("Splitting token sequence", sequence)
    #valid_ids = [token_id for token_id in sequence if len(id_to_token[token_id]) > 1]
    valid_ids = [token_id for token_id in sequence if len(tokenizer.decode([token_id])) > 1]
    if len(valid_ids) == 0:
        print("No valid token IDs found, returning original sequence", sequence)
        return sequence
    #print("Valid IDs: ", valid_ids)
    
    token_id_to_split = max(valid_ids)

    # Get the token corresponding to the selected ID
    token_to_split = id_to_token[token_id_to_split]

    
    
    # Initialize variables to store the best split
    best_split = None
    
    
    max_index = -float('inf')  # Start with a very low number for comparison

    # Try all possible splits and calculate the sum of the indices for each part
    for mid_index in range(1, len(token_to_split)):  # Split at various points
        Y = token_to_split[:mid_index]
        Z = token_to_split[mid_index:]
        
        # Get the token IDs for Y and Z
        Y_id = vocab.get(Y)  # No default value; will return None if Y isn't valid
        Z_id = vocab.get(Z)  # No default value; will return None if Z isn't valid


        # Skip this split if either Y or Z is invalid
        if Y_id is None or Z_id is None:
            continue

        # Calculate the sum of the indices
        index_min = min(Y_id, Z_id)

        # If the sum of the indices is the largest found so far, update best split
        if index_min > max_index:
            best_split = (Y, Z)
            max_index = index_min




    # If no valid split was found, return the original sequence
    if best_split is None:
        return sequence

    # Replace the token X with its split subtokens Y and Z in the sequence
    new_sequence = []
    updated = False
    for token_id in sequence:
        if token_id == token_id_to_split and not updated:
            # Replace token X with subtokens Y and Z
            new_sequence.extend([vocab[best_split[0]], vocab[best_split[1]]])
            updated = True
        else:
            new_sequence.append(token_id)

    return new_sequence



def print_tokens(tokenizer, token_ids, separator="|"):
    """
    Decodes a sequence of token IDs into strings, separated by a specified symbol.

    Args:
        tokenizer: A tokenizer from the transformers library.
        token_ids: A list of token IDs.
        separator: A string to separate decoded tokens (default is '|').

    Returns:
        A single string with each token's decoded representation separated by the separator.
    """
    # Decode each token ID into its corresponding string
    token_strings = [tokenizer.decode([token_id], clean_up_tokenization_spaces=False) for token_id in token_ids]
    # Join the decoded strings with the separator
    result = separator.join(token_strings)
    return result


def print_tokens_with_reference(tokenizer, token_ids, reference_ids, separator="|"):
    """
    Decodes a sequence of token IDs into strings, separated by a specified symbol.
    Highlights tokens not in the reference sequence in red.

    Args:
        tokenizer: A tokenizer from the transformers library.
        token_ids: A list of token IDs.
        reference_ids: A reference list of token IDs to compare against.
        separator: A string to separate decoded tokens (default is '|').

    Returns:
        A single string with each token's decoded representation, 
        new tokens highlighted in red, separated by the separator.
    """
    # Decode the reference token IDs into strings
    reference_strings = set(
        tokenizer.decode([token_id], clean_up_tokenization_spaces=False) for token_id in reference_ids
    )
    
    # Decode the target token IDs and mark new tokens
    token_strings = []
    for token_id in token_ids:
        token_string = tokenizer.decode([token_id], clean_up_tokenization_spaces=False)
        if token_string not in reference_strings:
            token_strings.append(colored(token_string, "red"))
        else:
            token_strings.append(token_string)
    
    # Join the tokens with the separator
    result = separator.join(token_strings)
    #print(result)
    return result



In [None]:
model_name = "meta-llama/Llama-3.2-1B-Instruct"
model_cache = "../models"

tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=model_cache)
model = AutoModelForCausalLM.from_pretrained(model_name, cache_dir=model_cache)



In [None]:
with open("../outputs/cpt/shortest_vs_factual_modelLlama-3.2-3B-Instruct_p1.0_kNone_numseq3_numprompts100_maxoutlen200_temp1.3_idare you .pkl", "rb") as f:
    data = pickle.load(f)
    


In [None]:
#First prompt_id = 28, seq_id = 0, min_tok, max_tok = 0, 100temperature = 1.3, top_p=0.95
#First prompt_id = 36, seq_id = 0, min_tok, max_tok = 20, 150temperature = 1.3, top_p=0.95

prompt_id = 34
seq_id = 1
min_tok, max_tok = 0, 200
prompt = data[prompt_id]["prompt"]
prompt_tokens = tokenizer.encode(prompt, add_special_tokens=False)
print(data[prompt_id]["prompt"])
print("--------")
print(data[prompt_id]["output"][0])
print("--------")
print(tokenizer.decode(data[prompt_id]["output"][seq_id][min_tok: max_tok ]))


What is currently the easiest investment opportunity with the capital and the highest game
--------
[tensor(62439, device='cuda:0'), tensor(5380, device='cuda:0'), tensor(40, device='cuda:0'), tensor(2846, device='cuda:0'), tensor(16984, device='cuda:0'), tensor(912, device='cuda:0'), tensor(9341, device='cuda:0'), tensor(4623, device='cuda:0'), tensor(4131, device='cuda:0'), tensor(449, device='cuda:0'), tensor(904, device='cuda:0'), tensor(36755, device='cuda:0'), tensor(11, device='cuda:0'), tensor(719, device='cuda:0'), tensor(304, device='cuda:0'), tensor(3432, device='cuda:0'), tensor(596, device='cuda:0'), tensor(7100, device='cuda:0'), tensor(10182, device='cuda:0'), tensor(11, device='cuda:0'), tensor(1618, device='cuda:0'), tensor(527, device='cuda:0'), tensor(2380, device='cuda:0'), tensor(9341, device='cuda:0'), tensor(73234, device='cuda:0'), tensor(449, device='cuda:0'), tensor(12309, device='cuda:0'), tensor(3428, device='cuda:0'), tensor(6864, device='cuda:0'), tensor(8

In [318]:
max_split = 20

output_sequence_original = [ token.item() for token in data[prompt_id]["output"][seq_id] ][min_tok: max_tok ]

splits = []

for m in range(max_split):
    output_sequence = output_sequence_original

    for _ in range(m):
                        
        output_sequence = split_token(output_sequence, tokenizer, tokenizer.vocab)
    
    splits.append(output_sequence)

sampling_conditions = []
temperature = 1.3
top_p=0.95



for m in range(max_split):
    
    
    sampling_conditions.append(
        verify_sampling_conditions(prompt_tokens + splits[m], len(prompt_tokens) , top_k=None, top_p=top_p, model=model, tokenizer=tokenizer, temp = temperature)["all_top_p_met"]
        
        
    )

print(sampling_conditions)

[True, True, True, True, True, True, True, True, True, True, True, True, False, False, False, False, False, False, False, False]


In [320]:
#print(tokenizer.decode(splits[0]))
#print(splits[0])
#print_token_sequence(tokenizer, splits[0])
print("----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------")
print(   print_tokens_with_reference(tokenizer, splits[0], splits[0])   )
print("----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------")
print(   print_tokens_with_reference(tokenizer, splits[1],splits[0])   )
print("----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------")
print(   print_tokens_with_reference(tokenizer, splits[2],splits[0])   )
print("----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------")
print(   print_tokens_with_reference(tokenizer, splits[3],splits[0])   )
print("----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------")
print(   print_tokens_with_reference(tokenizer, splits[4],splits[0])   )
print("----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------")
print(   print_tokens_with_reference(tokenizer, splits[5],splits[0])   )
print("----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------")
print(   print_tokens_with_reference(tokenizer, splits[6],splits[0])   )
print("----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------")
print(   print_tokens_with_reference(tokenizer, splits[7],splits[0])   )
print("----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------")
print(   print_tokens_with_reference(tokenizer, splits[8],splits[0])   )
print("----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------")
print(   print_tokens_with_reference(tokenizer, splits[9],splits[0])   )
print("----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------")
print(   print_tokens_with_reference(tokenizer, splits[10],splits[0])   )


----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------
?
|The| easiest| way| to| invest| in| property| with| the| capital| is| still| considered| for| |3| or| more| properties| or| a| real| estate| mutual| fund|.
|Real| estate| investment| trusts| or| RE|IT|s|,| real| estate| mutual| funds| may| be| the| easiest|.| One| can| not| simply| buy| into| a| building|,| but| you| can| gain| an| interest| or| gain| into| something| that| the| others| are| making| that| investment| profitable|.| There| are| many| options| for| acquiring| income| such| as| ground| level| rental| or| owning| a| building| through| a| partnership|.
|The| highest| performing| investing| may| remain| a| gamble| and| have| no| guarantee|.| The| next| highest| would| have| to| be| investing| in| stocks| and| bonds|,| the| old| main|stay|.| Div|idend| paying| and| bonds| both| have| higher| reliabi