In [1]:
import torch
from transformers import AutoTokenizer, AutoModel
import numpy as np
import matplotlib.pyplot as plt

if torch.cuda.is_available():
    torch.set_default_tensor_type('torch.cuda.FloatTensor')
    print('----------- Float tensor set --------------')
else:
    print('---------------------- No CUDA -----------------')

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model_name = "BAAI/bge-large-en-v1.5"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name).to(device)

  _C._set_default_tensor_type(t)


----------- Float tensor set --------------


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


In [2]:
valid_token_ids = []
for id in range(tokenizer.vocab_size):
    token = tokenizer.convert_ids_to_tokens(id)
    # Filter for actual words - you may need to adjust these conditions for BGE
    if (
        not token.startswith('[') and  # Skip special tokens
        not token.startswith('##') and  # Skip subword pieces
        not any(c in token for c in '〜་『』«»‰―⟩（') and  # Skip special characters
        # len(token) > 1 and  # Skip single characters
        token.isascii()   # Only keep ASCII tokens
        # not token.isnumeric()
    ):
        valid_token_ids.append(id)

print(f'Number of valid tokens: {len(valid_token_ids)}')

Number of valid tokens: 22748


In [3]:
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel
from tqdm import tqdm
from itertools import product

def find_matching_tokens(model, tokenizer, target_embedding, valid_token_ids, num_tokens=1, batch_size=32, similarity_threshold=0.99999):
    """Find token combinations that produce embeddings matching the target"""
    device = next(model.parameters()).device
    target_embedding = F.normalize(torch.tensor(target_embedding.clone().detach(), device=device).unsqueeze(0), p=2, dim=1)

    def process_batch(token_combos):
        input_ids = torch.tensor([[101] + list(combo) + [102] for combo in token_combos], device=device)
        attention_mask = torch.ones_like(input_ids)
        token_type_ids = torch.zeros_like(input_ids)

        with torch.no_grad():
            embeddings = model(input_ids, attention_mask, token_type_ids).pooler_output
            similarities = torch.mm(F.normalize(embeddings, p=2, dim=1), target_embedding.T).squeeze()

            matches = []
            for idx in torch.where(similarities >= similarity_threshold)[0]:
                token_ids = token_combos[idx]
                matches.append({
                    'tokens': [tokenizer.convert_ids_to_tokens(tid) for tid in token_ids],
                    'token_ids': token_ids,
                    'similarity': similarities[idx].item()
                })
            return matches

    total_combinations = len(valid_token_ids) ** num_tokens
    print(f"Searching through {total_combinations} combinations...")

    matches = []
    current_batch = []

    for combo in tqdm(product(valid_token_ids, repeat=num_tokens), total=total_combinations):
        current_batch.append(combo) # keeps looping until current batch is batch size full
        if len(current_batch) == batch_size:
            matches.extend(process_batch(current_batch))
            current_batch = []

    if current_batch:  # Process any remaining combinations
        matches.extend(process_batch(current_batch))

    return sorted(matches, key=lambda x: x['similarity'], reverse=True)

def printMatches(matchesList):
    print(f'\n\nFound {len(matchesList)} matching tokens')
    for tok in matchesList:
        print(f"Token: {tok['tokens']}, ID: {tok['token_ids']}, Similarity: {tok['similarity']:.6f}")

def processTokens(nTokenMatches):
    return list(map(lambda a: a['token_ids'][0], nTokenMatches))

def processTokensForStrings(nTokenMatches):
    return list(map(lambda a: a['tokens'][0], nTokenMatches))

In [4]:
targetText = 'magic is real'
print(tokenizer(targetText, return_tensors='pt'))
target_embedding = model(**tokenizer(targetText, return_tensors='pt'), output_hidden_states=True).pooler_output[0]
print(target_embedding)

{'input_ids': tensor([[ 101, 3894, 2003, 2613,  102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1]])}
tensor([-0.9450, -0.8097, -0.8362,  ...,  0.4859,  0.9817, -0.8716],
       grad_fn=<SelectBackward0>)


In [5]:
oneTokenMatches = find_matching_tokens(model, tokenizer, target_embedding, valid_token_ids, num_tokens=1, batch_size=32, similarity_threshold=0)
print(f'One token matches: {len(oneTokenMatches)}')

  target_embedding = F.normalize(torch.tensor(target_embedding.clone().detach(), device=device).unsqueeze(0), p=2, dim=1)


Searching through 22748 combinations...


100%|██████████| 22748/22748 [00:22<00:00, 1005.97it/s]

One token matches: 22748





In [6]:
printMatches(oneTokenMatches[:100])



Found 100 matching tokens
Token: ['shared'], ID: (4207,), Similarity: 0.961855
Token: ['wand'], ID: (23967,), Similarity: 0.961072
Token: ['mcgregor'], ID: (23023,), Similarity: 0.960443
Token: ['andy'], ID: (5557,), Similarity: 0.959906
Token: ['called'], ID: (2170,), Similarity: 0.957359
Token: ['apparently'], ID: (4593,), Similarity: 0.957178
Token: ['manifested'], ID: (24906,), Similarity: 0.956952
Token: ['reza'], ID: (26323,), Similarity: 0.956810
Token: ['lenny'], ID: (19065,), Similarity: 0.956553
Token: ['zee'], ID: (23727,), Similarity: 0.956521
Token: ['boss'], ID: (5795,), Similarity: 0.956143
Token: ['magician'], ID: (16669,), Similarity: 0.955991
Token: ['chong'], ID: (24008,), Similarity: 0.955913
Token: ['donkey'], ID: (20325,), Similarity: 0.955821
Token: ['garth'], ID: (21523,), Similarity: 0.955528
Token: ['scrolls'], ID: (23074,), Similarity: 0.954058
Token: ['sheng'], ID: (25981,), Similarity: 0.953799
Token: ['nedra'], ID: (28240,), Similarity: 0.953260
Token: [

In [7]:
processedTokens = processTokens(oneTokenMatches)
processedTokenStrings = processTokensForStrings(oneTokenMatches)

In [8]:
idx = processedTokens.index(3894)
print(f"The index of token magic is: {idx} -> {processedTokens[idx]} - {oneTokenMatches[idx]}")

idx = processedTokens.index(2003)
print(f"The index of token is is: {idx} -> {processedTokens[idx]} - {oneTokenMatches[idx]}")

idx = processedTokens.index(2613)
print(f"The index of token real is: {idx} -> {processedTokens[idx]} - {oneTokenMatches[idx]}")

idx = processedTokens.index(16669)
print(f"The index of token magician is: {idx} -> {processedTokens[idx]} - {oneTokenMatches[idx]}")

The index of token magic is: 198 -> 3894 - {'tokens': ['magic'], 'token_ids': (3894,), 'similarity': 0.9440378546714783}
The index of token is is: 9418 -> 2003 - {'tokens': ['is'], 'token_ids': (2003,), 'similarity': 0.888039767742157}
The index of token real is: 15094 -> 2613 - {'tokens': ['real'], 'token_ids': (2613,), 'similarity': 0.8671032190322876}
The index of token magician is: 11 -> 16669 - {'tokens': ['magician'], 'token_ids': (16669,), 'similarity': 0.9559913873672485}


In [9]:
import numpy as np

pooledEmbeddingsOfValidTokens = np.load('/content/pooledEmbeddingsOfAllValidTokens.npy')
pooledEmbeddingsOfValidTokens = torch.tensor(pooledEmbeddingsOfValidTokens, device=device)
pooledEmbeddingsOfValidTokens

tensor([[-0.0344, -0.0173, -0.0039,  ..., -0.0270,  0.0351, -0.0341],
        [-0.0367, -0.0316,  0.0099,  ..., -0.0207,  0.0373, -0.0367],
        [-0.0384, -0.0321, -0.0104,  ..., -0.0187,  0.0386, -0.0375],
        ...,
        [-0.0377, -0.0252, -0.0004,  ..., -0.0097,  0.0391, -0.0365],
        [-0.0356, -0.0156, -0.0149,  ..., -0.0068,  0.0382, -0.0358],
        [-0.0389, -0.0378, -0.0303,  ...,  0.0041,  0.0406, -0.0389]])

In [10]:
similarity_threshold = 0.9
last_idx = -1
for idx, tok in enumerate(oneTokenMatches):
    if tok['similarity'] > similarity_threshold:
        last_idx = idx
print(f"The index of the last token with similarity above {similarity_threshold} is {last_idx}")

The index of the last token with similarity above 0.9 is 6281


In [11]:
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModel
from tqdm import tqdm

def get_top_k_similar_tokens(input_tokens, valid_embeddings, k=5):
    """
    Find top k most cosine similar tokens from valid_token_ids for each input token
    based on the final pooled embeddings from BGE model.

    Args:
        input_token_ids (List[int]): List of input token IDs
        valid_token_ids (List[int]): List of valid token IDs to compare against
        model_name (str): Name of the BGE model to use
        k (int): Number of similar tokens to return for each input token

    Returns:
        List[List[Tuple[int, float]]]: For each input token, returns list of (token_id, similarity_score) tuples
    """
    # Get embeddings for input tokens
    input_embeddings = []
    for tok in input_tokens:
        # Create input with single token
        inputs = tokenizer(tok, return_tensors='pt').to(device)

        with torch.no_grad():
            outputs = model(**inputs)
            # Use pooled output for token representation
            pooled_output = outputs.pooler_output
            input_embeddings.append(pooled_output)

    # Stack all input embeddings
    input_embeddings = torch.cat(input_embeddings, dim=0)

    # Normalize embeddings
    input_embeddings = F.normalize(input_embeddings, p=2, dim=1)

    # Calculate cosine similarity
    similarity_matrix = torch.mm(input_embeddings, valid_embeddings.t())

    # Get top k similar tokens for each input token
    results = []
    top_k_values, top_k_indices = torch.topk(similarity_matrix, k=k, dim=1)

    for i in range(len(input_tokens)):
        token_results = []
        for j in range(k):
            similar_token_id = valid_token_ids[top_k_indices[i][j].item()]
            similarity_score = top_k_values[i][j].item()
            token_results.append((similar_token_id, similarity_score, tokenizer.decode([similar_token_id])))
        results.append(token_results)

    return results

results = [j for i in get_top_k_similar_tokens(processedTokenStrings[:50], pooledEmbeddingsOfValidTokens, k=1) for j in i]
resultsWithTokenIds = [i[0] for i in results]
for i in results:
    print(i)

(4207, 1.0000001192092896, 'shared')
(23967, 0.9999997615814209, 'wand')
(23023, 1.0000001192092896, 'mcgregor')
(5557, 1.0, 'andy')
(2170, 0.9999998807907104, 'called')
(4593, 1.000000238418579, 'apparently')
(24906, 0.9999995231628418, 'manifested')
(26323, 1.0000001192092896, 'reza')
(19065, 1.0000001192092896, 'lenny')
(23727, 1.0, 'zee')
(5795, 1.0000001192092896, 'boss')
(16669, 0.9999997019767761, 'magician')
(24008, 0.9999998807907104, 'chong')
(20325, 1.0000003576278687, 'donkey')
(21523, 1.0000003576278687, 'garth')
(23074, 1.0, 'scrolls')
(25981, 0.9999999403953552, 'sheng')
(28240, 0.9999997615814209, 'nedra')
(16980, 0.9999999403953552, 'elves')
(12523, 0.9999998211860657, 'garion')
(5421, 0.9999997615814209, 'translated')
(22050, 1.0000001192092896, 'symbolism')
(18850, 1.0000003576278687, 'isaiah')
(14682, 1.0000004768371582, 'hunted')
(22788, 0.9999999403953552, 'mummy')
(28163, 0.9999997615814209, 'aldo')
(26511, 0.9999996423721313, 'lear')
(18773, 0.9999998807907104, 

In [12]:
import torch
from sklearn.metrics.pairwise import cosine_similarity

def calculate_similarity(item, target_embedding, num_times=3):
    try:
        with torch.no_grad():
            input_embedding = model(**tokenizer(' '.join([item[2]]*num_times), return_tensors='pt')).pooler_output
            input_embedding = F.normalize(input_embedding, p=2, dim=1)
            target_norm = F.normalize(target_embedding.unsqueeze(0), p=2, dim=1)
            similarity = torch.mm(input_embedding, target_norm.T).item()
            return similarity
    except Exception as e:
        print(f"Error calculating similarity for item {item}: {e}")
        return -1

def sort_list_by_similarity(data_list, target_embedding, model, tokenizer):
    def calculate_similarity_helper(item):
        return calculate_similarity(item, target_embedding)
    sorted_list = sorted(data_list, key=calculate_similarity_helper, reverse=True)
    return sorted_list

similarTokensToTopMatches = [j for i in get_top_k_similar_tokens(processedTokenStrings[:50], pooledEmbeddingsOfValidTokens, k=10) for j in i]
sortedList = sort_list_by_similarity(similarTokensToTopMatches, target_embedding, model, tokenizer)

seen = {}
for item in sortedList:
    text = item[2]
    if text not in seen:
        seen[text] = calculate_similarity(item, target_embedding)

sortedListWithoutDupes = sorted(seen.items(), key=lambda x: x[1], reverse=True)

# Print in sorted order
for text, score in sortedListWithoutDupes:
    print(f'{text} - [{score}]')

mana - [0.9600640535354614]
arthur - [0.9551794528961182]
manifested - [0.9551074504852295]
orb - [0.9541681408882141]
marco - [0.9530236124992371]
scrolls - [0.9522126913070679]
nedra - [0.9516910314559937]
intuition - [0.9515164494514465]
trained - [0.9514984488487244]
hunted - [0.9512765407562256]
wand - [0.9511296153068542]
owns - [0.9510325193405151]
abby - [0.9508976340293884]
magic - [0.9505640268325806]
sleeves - [0.9504505395889282]
eric - [0.9502708911895752]
chosen - [0.9501370191574097]
lenny - [0.9501152038574219]
called - [0.9498265981674194]
enoch - [0.9494744539260864]
karim - [0.9494023323059082]
mcgregor - [0.9493308067321777]
garth - [0.9490499496459961]
owner - [0.9488882422447205]
astrid - [0.9487721920013428]
xiao - [0.9487250447273254]
translated - [0.9485800862312317]
ark - [0.9484646916389465]
translates - [0.9478991031646729]
garion - [0.9477963447570801]
manny - [0.9477595686912537]
magician - [0.9473398327827454]
turned - [0.9472270607948303]
isaac - [0.9472

In [13]:
threeTokenMatches = find_matching_tokens(model, tokenizer, target_embedding, resultsWithTokenIds[:50], num_tokens=3, batch_size=10_000, similarity_threshold=0.0)

  target_embedding = F.normalize(torch.tensor(target_embedding.clone().detach(), device=device).unsqueeze(0), p=2, dim=1)


Searching through 125000 combinations...


100%|██████████| 125000/125000 [01:58<00:00, 1055.87it/s]


In [14]:
printMatches(threeTokenMatches[:400])



Found 400 matching tokens
Token: ['gratitude', 'amar', 'magician'], ID: (15531, 23204, 16669), Similarity: 0.977725
Token: ['gratitude', 'nedra', 'magician'], ID: (15531, 28240, 16669), Similarity: 0.977165
Token: ['magician', 'wand', 'gratitude'], ID: (16669, 23967, 15531), Similarity: 0.976124
Token: ['trained', 'nedra', 'magician'], ID: (4738, 28240, 16669), Similarity: 0.976040
Token: ['gratitude', 'shared', 'magician'], ID: (15531, 4207, 16669), Similarity: 0.976032
Token: ['magician', 'trained', 'trained'], ID: (16669, 4738, 4738), Similarity: 0.975762
Token: ['magician', 'intuition', 'mummy'], ID: (16669, 26406, 22788), Similarity: 0.975667
Token: ['magician', 'gratitude', 'amar'], ID: (16669, 15531, 23204), Similarity: 0.975258
Token: ['magician', 'wand', 'trained'], ID: (16669, 23967, 4738), Similarity: 0.974909
Token: ['magician', 'intuition', 'nedra'], ID: (16669, 26406, 28240), Similarity: 0.974773
Token: ['trained', 'aldo', 'magician'], ID: (4738, 28163, 16669), Similari

In [15]:
twoTokenMatches = find_matching_tokens(model, tokenizer, target_embedding, resultsWithTokenIds[:100], num_tokens=2, batch_size=10_000, similarity_threshold=0.0)

  target_embedding = F.normalize(torch.tensor(target_embedding.clone().detach(), device=device).unsqueeze(0), p=2, dim=1)


Searching through 2500 combinations...


100%|██████████| 2500/2500 [00:00<00:00, 1218849.24it/s]


In [16]:
printMatches(twoTokenMatches[:100])



Found 100 matching tokens
Token: ['magician', 'trained'], ID: (16669, 4738), Similarity: 0.976125
Token: ['magician', 'gratitude'], ID: (16669, 15531), Similarity: 0.975093
Token: ['trained', 'magician'], ID: (4738, 16669), Similarity: 0.973437
Token: ['manifested', 'lear'], ID: (24906, 26511), Similarity: 0.971552
Token: ['magician', 'manifested'], ID: (16669, 24906), Similarity: 0.971087
Token: ['wand', 'apparently'], ID: (23967, 4593), Similarity: 0.970904
Token: ['elves', 'intuition'], ID: (16980, 26406), Similarity: 0.970645
Token: ['aldo', 'magician'], ID: (28163, 16669), Similarity: 0.970627
Token: ['elves', 'magician'], ID: (16980, 16669), Similarity: 0.970465
Token: ['nedra', 'magician'], ID: (28240, 16669), Similarity: 0.970464
Token: ['magician', 'intuition'], ID: (16669, 26406), Similarity: 0.970413
Token: ['mundo', 'manifested'], ID: (25989, 24906), Similarity: 0.970172
Token: ['elves', 'collaborated'], ID: (16980, 8678), Similarity: 0.970118
Token: ['sleeves', 'magician

In [39]:
import torch
from transformers import AutoTokenizer, AutoModel
import numpy as np
import torch.nn.functional as F

def cosine_similarity(embedding1, embedding2):
  """
  Calculates the cosine similarity between two embeddings.

  Args:
    embedding1: The first embedding (PyTorch tensor).
    embedding2: The second embedding (PyTorch tensor).

  Returns:
    The cosine similarity between the two embeddings (float).
  """

  # Normalize the embeddings
  embedding1_normalized = F.normalize(embedding1, p=2, dim=0)
  embedding2_normalized = F.normalize(embedding2, p=2, dim=0)

  # Calculate the dot product
  dot_product = torch.dot(embedding1_normalized, embedding2_normalized)

  return dot_product.item()

text1 = 'magic is real'
text2 = 'magic are real'
print(tokenizer(text1, return_tensors='pt'))
print(tokenizer(text2, return_tensors='pt'))
cosine_similarity(model(**tokenizer(text1, return_tensors='pt'), output_hidden_states=True).pooler_output[0], model(**tokenizer(text2, return_tensors='pt'), output_hidden_states=True).pooler_output[0])

{'input_ids': tensor([[ 101, 3894, 2003, 2613,  102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1]])}
{'input_ids': tensor([[ 101, 3894, 2024, 2613,  102]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1]])}


0.9960455894470215

In [18]:
from typing import Dict, List, Tuple, Union
import torch
import torch.nn as nn

def create_biased_token_parameter(
    valid_token_ids: List[int],
    position_biases: Dict[int, Dict[int, float]],
    semantic_neighbors: Dict[int, Dict[int, float]] = None,
    default_value: float = 0.001,
    num_positions: int = 3
) -> nn.Parameter:
    """
    Create a biased initialization for token selection learning.

    Args:
        valid_token_ids: List of valid token IDs
        position_biases: Dict mapping position -> {token_id: weight}
                        for explicit position-based token preferences
        semantic_neighbors: Dict mapping token_id -> {neighbor_id: similarity}
                          for semantic similarity-based scaling
        default_value: Base logit value for non-preferred tokens
        num_positions: Number of positions to initialize

    Returns:
        nn.Parameter: Initialized parameter tensor with biased weights
    """
    # Initialize base tensor
    learned_hidden = torch.full((num_positions, len(valid_token_ids)), default_value)

    # Helper to get index in valid_tokens list
    def get_token_idx(token_id: int) -> int:
        matches = (torch.tensor(valid_token_ids) == token_id).nonzero()
        if len(matches) == 0:
            print(f"Warning: token_id {token_id} not found in valid_token_ids")
            return None
        return matches.item()

    # Apply position-based biases
    for position, token_weights in position_biases.items():
        for token_id, weight in token_weights.items():
            idx = get_token_idx(token_id)
            learned_hidden[position, idx] = weight
            # print(f'{position} {idx} {weight}')

            # Apply semantic neighbor scaling if provided
            if semantic_neighbors and token_id in semantic_neighbors:
                base_weight = weight
                for neighbor_id, similarity in semantic_neighbors[token_id].items():
                    try:
                        neighbor_idx = get_token_idx(neighbor_id)
                        # Scale weight by similarity
                        neighbor_weight = base_weight * similarity
                        # Only set if it would increase the weight
                        if neighbor_weight > learned_hidden[position, neighbor_idx]:
                            learned_hidden[position, neighbor_idx] = neighbor_weight
                    except ValueError:
                        # Skip if neighbor_id not in valid_token_ids
                        continue

    return learned_hidden
    # return nn.Parameter(learned_hidden, requires_grad=True)

from collections import defaultdict
from typing import List, Dict, Any
import torch
from torch import nn
import numpy as np

def analyze_token_patterns(
    matches: List[Dict[str, Any]],
    frequency_weight: float = 1.5,
    similarity_weight: float = 2.5,
    similarity_threshold: float = 0.97,
) -> Dict[int, Dict[int, float]]:
    """
    Analyze token patterns from matching results and generate position-based biases.

    Args:
        matches: List of dictionaries from find_matching_tokens containing 'tokens',
                'token_ids', and 'similarity' keys
        top_k: Only consider top k matches (None for all)
        frequency_weight: Weight for frequency-based scoring (default: 1.0)
        similarity_weight: Weight for similarity-based scoring (default: 1.0)

    Returns:
        Dict[int, Dict[int, float]]: Position biases suitable for create_biased_token_parameter
    """
    # Initialize counters for each position
    position_counts = defaultdict(lambda: defaultdict(list))

    filtered_matches = [m for m in matches if m['similarity'] >= similarity_threshold]

    # Collect token occurrences and similarities by position
    for i, m in enumerate(filtered_matches):
        similarity = m['similarity']
        for pos, token_id in enumerate(m['token_ids']):
            position_counts[pos][token_id].append(similarity)

    # for k1,v1 in position_counts.items():
    #     print(f'Position {k1} {v1}')
    #     for k2,v2 in v1.items():
    #         print(f'{k2} - {len(v2)}')

    # Calculate weighted scores for each token in each position
    position_biases = {}
    for pos in position_counts:
        token_scores = {}
        max_freq = max(len(sims) for sims in position_counts[pos].values())

        for token_id, similarities in position_counts[pos].items():
            token_scores[token_id] = ((len(similarities)/max_freq)**frequency_weight) * (np.mean(similarities)**similarity_weight) * 10  # Scale factor can be adjusted

        position_biases[pos] = token_scores

    return position_biases

In [19]:
from typing import List, Dict, Tuple

def transform_list(data: List[List[Tuple[int, float, str]]]) -> Dict[int, Dict[int, float]]:
    result = {}
    for sublist in data:
        key = sublist[0][0]  # First element's first item
        result[key] = {
            item[0]: item[1]  # Create sub-dictionary with first and second items
            for item in sublist[1:]  # Skip the first item when creating sub-dict
        }
    return result

In [20]:
topSimilarTokens = [i for i in get_top_k_similar_tokens(processedTokenStrings[:50], pooledEmbeddingsOfValidTokens, k=4)]
print(topSimilarTokens)
similarWordsToPositionBias = transform_list(topSimilarTokens)

[[(4207, 1.0000001192092896, 'shared'), (8678, 0.976253867149353, 'collaborated'), (2170, 0.9750165939331055, 'called'), (2921, 0.9737001657485962, 'kept')], [(23967, 0.9999997615814209, 'wand'), (19607, 0.9738000631332397, 'orb'), (4880, 0.9729925990104675, 'cape'), (16148, 0.9719747304916382, 'forearm')], [(23023, 1.0000001192092896, 'mcgregor'), (20545, 0.981041431427002, 'conor'), (14093, 0.9763465523719788, 'abdullah'), (5557, 0.974487841129303, 'andy')], [(5557, 1.0, 'andy'), (25998, 0.9798793196678162, 'mikey'), (20545, 0.9773740172386169, 'conor'), (4463, 0.9771308302879333, 'jason')], [(2170, 0.9999998807907104, 'called'), (6153, 0.9771158695220947, 'backed'), (2921, 0.9765416383743286, 'kept'), (2056, 0.9764332175254822, 'said')], [(4593, 1.000000238418579, 'apparently'), (15329, 0.9897410273551941, 'evidently'), (10743, 0.9875698089599609, 'supposedly'), (3849, 0.9820659160614014, 'seems')], [(24906, 0.9999995231628418, 'manifested'), (19676, 0.9799712896347046, 'manifest'),

In [21]:
position_weighting = analyze_token_patterns(threeTokenMatches)
intialParam = create_biased_token_parameter(valid_token_ids, position_weighting, similarWordsToPositionBias)
print(position_weighting)
intialParam

{0: {15531: 5.594437405172545, 16669: 7.054537035666199, 4738: 2.1211520408910296, 23967: 5.585073604825677, 10743: 0.1642890269760727, 28240: 0.46492996690152194, 16980: 9.297872048580878, 26406: 3.541247007105273, 22788: 0.4646257663462889, 6135: 0.748853720897264, 4207: 1.4381539808618946, 19065: 0.3801976878461878, 25989: 0.2294348316325748, 23204: 0.30149461214727846, 24906: 1.3118056904342148, 21268: 0.10659413437134099, 21544: 0.05809524926933624, 26323: 0.1639208142436112, 15114: 0.22917820218792867, 8678: 0.9603885887127339, 16320: 0.16413711960368282, 26511: 0.7474715179106527, 24008: 0.020541966178182058, 23727: 0.1639772055341837, 4593: 0.10645041985450819, 5076: 0.05796692085875358, 20325: 0.02052874553230158, 21196: 0.05793684092940462, 19607: 0.020506825260333823, 23023: 0.3009839674238398, 18850: 0.02050435208562632, 2170: 0.02050313758952922, 12523: 0.05797339796437387, 28163: 0.10634621328525577, 22399: 0.05787355824461913, 25981: 0.02046454477802349, 21523: 0.0204629

tensor([[0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
        [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010],
        [0.0010, 0.0010, 0.0010,  ..., 0.0010, 0.0010, 0.0010]])

In [22]:
normalizedLearnedHidden = intialParam / intialParam.sum(dim=1, keepdim=True)
normalizedLearnedHidden

tensor([[5.6349e-06, 5.6349e-06, 5.6349e-06,  ..., 5.6349e-06, 5.6349e-06,
         5.6349e-06],
        [6.7519e-06, 6.7519e-06, 6.7519e-06,  ..., 6.7519e-06, 6.7519e-06,
         6.7519e-06],
        [8.9070e-06, 8.9070e-06, 8.9070e-06,  ..., 8.9070e-06, 8.9070e-06,
         8.9070e-06]])

In [23]:
selected_indices = torch.multinomial(normalizedLearnedHidden, num_samples=1)
selected_values = normalizedLearnedHidden[torch.arange(len(normalizedLearnedHidden)), selected_indices.squeeze()]
selected_values

tensor([5.6349e-06, 6.1322e-02, 8.2929e-02])

In [24]:
values, indices = torch.max(normalizedLearnedHidden, dim=1)
[(tokenizer.decode(valid_token_ids[i]), values[idx]) for idx,i in enumerate(indices)]

[('elves', tensor(0.0524)),
 ('intuition', tensor(0.0628)),
 ('magician', tensor(0.0829))]

In [25]:
magic_output = model(**tokenizer(targetText, return_tensors='pt'), output_hidden_states=True)
# print(tokenizer)

In [26]:
# import torch
# import torch.nn as nn
# import torch.optim as optim
# import math
# import torch.nn.functional as F

# targetText = 'magic is real'
# # Get the magic embedding and fantasy hidden states
# with torch.no_grad():
#     # Get magic's pooled embedding
#     magic_output = model(**tokenizer(targetText, return_tensors='pt'), output_hidden_states=True)
#     magic_embedding = magic_output.pooler_output

#     # Get fantasy's last hidden states
#     numTokensInput = (2) + 3 # start and end token + how many actual tokens

# # Create learnable hidden states initialized with fantasy's hidden states
# valid_tokens = torch.tensor(valid_token_ids)

# learned_hidden = normalizedLearnedHidden
# # Setup optimizer
# # optimizer = optim.Adam([learned_hidden], lr=0.1)

# # learned_hidden_checkpoints = [learned_hidden.cpu().detach().clone()]
# tokensWithSimilarity = []
# # Training loop
# n_steps = 100_000
# for step in tqdm(range(n_steps+1)):
#     # initial_temp = 2.0
#     # final_temp = 0.1
#     # temperature = initial_temp * (final_temp / initial_temp) ** (step / n_steps)
#     # # Apply Gumbel-Softmax to get differentiable discrete tokens
#     # gumbel_probs = F.gumbel_softmax(learned_hidden, tau=temperature, hard=True, dim=-1)
#     # selected_tokens = (gumbel_probs @ valid_tokens.float()).long()

#     selected_indices = torch.multinomial(normalizedLearnedHidden, num_samples=1)
#     selected_tokens = torch.tensor([valid_token_ids[i] for i in selected_indices])

#     input_ids = torch.cat([
#         torch.tensor([101], dtype=torch.long),
#         selected_tokens.long(),
#         torch.tensor([102], dtype=torch.long)
#     ]).unsqueeze(0)

#     with torch.no_grad():
#         layer_output = model(
#             input_ids = input_ids,
#             attention_mask=torch.tensor([[1] * numTokensInput]),
#             output_hidden_states=True
#         ).pooler_output

#     # Compute MSE loss between pooled result and magic embedding
#     # mseLoss = nn.MSELoss()(layer_output, magic_embedding)
#     cosineLoss = 1 - nn.CosineSimilarity(dim=1)(layer_output, magic_embedding)

#     # Backward pass and optimization
#     # optimizer.zero_grad()
#     # cosineLoss.backward()
#     # optimizer.step()

#     if cosineLoss.item() < 1e-12:
#         print("Reached zero error!")
#         break

#     if cosineLoss.item() < 0.03:
#         tokensWithSimilarity.append((tokenizer.convert_ids_to_tokens(selected_tokens.tolist()), cosineLoss))

#     if cosineLoss.item() < 0.05:
#         value_to_add = (0.06 - cosineLoss.item())/10
#         row_indices = torch.arange(len(selected_indices))
#         normalizedLearnedHidden[row_indices, selected_indices.squeeze()] += value_to_add
#         normalizedLearnedHidden = normalizedLearnedHidden / normalizedLearnedHidden.sum(dim=1, keepdim=True)

#     # # Print progress every 100 steps
#     # if (step + 1) % 100 == 0:
#     #     print(f'Step {step + 1}/{n_steps}, MSE: {mseLoss.item()}, Cosine: {cosineLoss.item()}')
#     #     tokens = tokenizer.convert_ids_to_tokens(selected_tokens.tolist())
#     #     print(f'Current tokens: {tokens}')
#     #     # learned_hidden_checkpoints.append(learned_hidden.cpu().detach().clone())

In [27]:
import torch
import torch.nn as nn
import torch.optim as optim
import math
import torch.nn.functional as F

targetText = 'magic is real'
# Get the magic embedding and fantasy hidden states
with torch.no_grad():
    # Get magic's pooled embedding
    magic_output = model(**tokenizer(targetText, return_tensors='pt'), output_hidden_states=True)
    magic_embedding = magic_output.pooler_output

    # Get fantasy's last hidden states
    numTokensInput = (2) + 3 # start and end token + how many actual tokens

# Create learnable hidden states initialized with fantasy's hidden states
valid_tokens = torch.tensor(valid_token_ids)

learned_hidden = normalizedLearnedHidden
# Setup optimizer
# optimizer = optim.Adam([learned_hidden], lr=0.1)

# learned_hidden_checkpoints = [learned_hidden.cpu().detach().clone()]
tokensWithSimilarity = []
# Training loop
n_steps = 100
for step in tqdm(range(n_steps+1)):
    batch_size = 5_000  # Adjust based on your memory constraints

    # normalizedLearnedHidden shape: [3, num_possible_tokens]
    selected_indices = torch.vstack([
        torch.multinomial(normalizedLearnedHidden[i], num_samples=batch_size, replacement=True)
        for i in range(3)
    ])  # Shape: [3, batch_size]

    # Convert indices to tokens for each row
    selected_tokens = torch.vstack([
        torch.tensor([valid_token_ids[i] for i in row])
        for row in selected_indices
    ])  # Shape: [3, batch_size]

    # Transpose to get [batch_size, 3]
    selected_tokens = selected_tokens.t()

    # Create batched input tensor
    batch_start_tokens = torch.full((batch_size, 1), 101, dtype=torch.long)
    batch_end_tokens = torch.full((batch_size, 1), 102, dtype=torch.long)
    input_ids = torch.cat([
        batch_start_tokens,
        selected_tokens.long(),
        batch_end_tokens
    ], dim=1)  # Final shape: [batch_size, 5] (start + 3 tokens + end)

    attention_mask = torch.ones_like(input_ids)

    with torch.no_grad():
        layer_output = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            output_hidden_states=True
        ).pooler_output

    # Compute MSE loss between pooled result and magic embedding
    cosineLoss = 1 - nn.CosineSimilarity(dim=1)(layer_output, magic_embedding)

    # Handle collecting tokens with similarity < 0.03
    good_samples_mask = cosineLoss < 0.03
    if good_samples_mask.any():
        good_tokens = selected_tokens[good_samples_mask]
        good_losses = cosineLoss[good_samples_mask]
        tokensWithSimilarity.extend([
            (tokenizer.convert_ids_to_tokens(tokens.tolist()), loss)
            for tokens, loss in zip(good_tokens, good_losses)
        ])

    # Handle probability updates for samples < 0.05
    update_mask = cosineLoss < 0.05
    if update_mask.any():
        values_to_add = (0.06 - cosineLoss[update_mask])/10

        # Update each row's probabilities in one operation
        for row in range(3):
            indices = selected_indices[row, update_mask]
            normalizedLearnedHidden[row].index_add_(0, indices, values_to_add)

    # Renormalize after all updates
    normalizedLearnedHidden = normalizedLearnedHidden / normalizedLearnedHidden.sum(dim=1, keepdim=True)

100%|██████████| 101/101 [08:58<00:00,  5.33s/it]


In [30]:
unique_items = []
seen_first_items = set()

for item in tokensWithSimilarity:
  first_item = tuple(item[0]) # Convert the list to a tuple for hashing
  if first_item not in seen_first_items:
    unique_items.append(item)
    seen_first_items.add(first_item)

print(f"Number of unique items: {len(unique_items)}")

Number of unique items: 2830


In [31]:
unique_items.sort(key=lambda x: x[1])
print(f'Length: {len(unique_items)}')
unique_items

Length: 2830


[(['magic', 'magic', 'manifested'], tensor(0.0154)),
 (['magic', 'lives', 'magic'], tensor(0.0166)),
 (['shared', 'undoubtedly', 'magic'], tensor(0.0167)),
 (['mana', 'lives', 'magic'], tensor(0.0169)),
 (['reared', 'magic', 'manifested'], tensor(0.0175)),
 (['magic', 'reared', 'manifested'], tensor(0.0175)),
 (['workings', 'manifest', 'magic'], tensor(0.0175)),
 (['magic', 'reared', 'shared'], tensor(0.0179)),
 (['magic', 'lives', 'manifest'], tensor(0.0179)),
 (['shared', 'lives', 'magic'], tensor(0.0179)),
 (['mana', 'thank', 'magic'], tensor(0.0180)),
 (['thankful', 'magic', 'shared'], tensor(0.0181)),
 (['intuition', 'raised', 'magic'], tensor(0.0182)),
 (['workings', 'profound', 'magic'], tensor(0.0183)),
 (['magic', 'instinctively', 'nedra'], tensor(0.0183)),
 (['manifest', 'magic', 'manifested'], tensor(0.0183)),
 (['magic', 'magic', 'trained'], tensor(0.0185)),
 (['magic', 'raised', 'manifested'], tensor(0.0185)),
 (['magic', 'reared', 'manifest'], tensor(0.0186)),
 (['magic',

In [None]:
def display_top_tokens_per_position(hidden, tokenizer, valid_token_ids, top_k=10):
    # Convert to probabilities
    logits = F.softmax(hidden, dim=-1)

    # For each position
    for pos in range(hidden.shape[0]):
        # Get top k values and indices
        values, indices = torch.topk(logits[pos], k=top_k)

        print(f"\nPosition {pos} top {top_k} tokens:")
        print("-" * 40)

        # Convert indices back to token IDs and then to tokens
        for i, (value, idx) in enumerate(zip(values, indices)):
            token_id = valid_token_ids[idx]
            token = tokenizer.convert_ids_to_tokens([token_id])[0]
            print(f"{i+1:2d}. {token:15s} (prob: {value:.4f})")

for idx,i in enumerate(learned_hidden_checkpoints):
    print(f'Checkpoint {idx}')
    print(display_top_tokens_per_position(i, tokenizer, valid_token_ids))
    print('*' * 100)

In [None]:
# Check if checkpoints are different
for i in range(len(learned_hidden_checkpoints)-1):
    diff = (learned_hidden_checkpoints[i] - learned_hidden_checkpoints[i+1]).abs().sum()
    print(f"Difference between checkpoint {i} and {i+1}: {diff.item()}")