In [16]:
from transformers import AutoTokenizer
from typing import List

In [5]:
tokenizer_name = "reshinthadith/codegen_350M_list_manip_5_len"

# The helpers below request character offsets (return_offsets_mapping=True),
# which only Hugging Face "fast" tokenizers provide -- fail early, with a clear
# message, rather than deep inside the first offset lookup.
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
assert tokenizer.is_fast, f"{tokenizer_name} did not load a fast tokenizer; offset mappings are unavailable"

def find_overlapping_tokens(text, char_range, max_length=512, tok=None):
    """
    Return the indices of tokens whose character spans overlap ``char_range``.

    Args:
        text: The string to tokenize.
        char_range: Half-open ``(start, end)`` character offsets into ``text``
            (i.e. ``text[start:end]``), e.g. one span from ``get_matching_ranges``.
        max_length: Truncation length passed to the tokenizer. Tokens beyond this
            limit are never reported. Defaults to 512.
        tok: Tokenizer to use. Defaults to the module-level ``tokenizer``. It must
            accept ``return_offsets_mapping=True`` (a Hugging Face "fast" tokenizer).

    Returns:
        Ascending list of 0-based indices into the encoding ``tok(text)``. That
        encoding uses the tokenizer's default special-token setting, so the
        indices may be shifted relative to ``tok.tokenize(text)`` if the tokenizer
        adds special tokens.
    """
    if tok is None:
        tok = tokenizer

    encodings = tok(text, return_offsets_mapping=True, truncation=True, max_length=max_length)

    range_start, range_end = char_range
    # Half-open interval overlap: the token starts before the range ends and
    # ends after the range starts. A zero-width range only hits a token that
    # strictly contains that position.
    return [
        token_idx
        for token_idx, (token_start, token_end) in enumerate(encodings.offset_mapping)
        if range_start < token_end and range_end > token_start
    ]

# Quick check: which tokens cover "Hello World"?
text = "Hello World! How are you?"
char_range = (0, len("Hello World"))  # (0, 11): text[0:11] == "Hello World"
print(find_overlapping_tokens(text, char_range))  # -> [0, 1] with the CodeGen tokenizer loaded above


[0, 1]


In [10]:
# Substring whose character spans we locate below, inside a sample nested-call expression.
search = "reverse("
text = "sub_n(reverse(take(sort_des(sort_asc([0, -5, -5, 5, 4, 0])),24)),-3)"

def get_matching_ranges(text: str, search: str) -> List[tuple]:
    """
    Return the character spans of every occurrence of ``search`` in ``text``.

    Occurrences may overlap: after a hit at index ``i`` the scan resumes at
    ``i + 1``, so ``get_matching_ranges("aaa", "aa")`` gives ``[(0, 2), (1, 3)]``.

    Args:
        text: String to scan.
        search: Substring to look for. An empty string matches nothing and
            yields ``[]``. Otherwise ``str.find`` would report a zero-width hit
            at every position of ``text``.

    Returns:
        Half-open ``(start, end)`` tuples in ascending order, each satisfying
        ``text[start:end] == search``.
    """
    if not search:
        return []

    spans = []
    hit = text.find(search)
    while hit != -1:
        spans.append((hit, hit + len(search)))
        hit = text.find(search, hit + 1)
    return spans

# Spans of every occurrence of `search` ("reverse(").
all_match_indices = get_matching_ranges(text, search)
print(all_match_indices)

# Spans of every "(" in the expression. This rebinds all_match_indices, so the
# token lookup right after uses the FIRST "(" (the one following "sub_n"),
# not the "reverse(" span printed above.
all_match_indices = get_matching_ranges(text, "(")
print(all_match_indices)
print(find_overlapping_tokens(text, all_match_indices[0]))

# Tokens covering the first "sub_n(" occurrence.
print(find_overlapping_tokens(text, get_matching_ranges(text, "sub_n(")[0]))

[(6, 14)]
[(5, 6), (13, 14), (18, 19), (27, 28), (36, 37)]
[3]
[0, 1, 2, 3]


In [14]:
# Show how the tokenizer splits `text`; a leading 'Ġ' in a token marks that it begins with a space.
print(tokenizer.tokenize(text))

['sub', '_', 'n', '(', 'reverse', '(', 'take', '(', 'sort', '_', 'des', '(', 'sort', '_', 'asc', '([', '0', ',', 'Ġ-', '5', ',', 'Ġ-', '5', ',', 'Ġ5', ',', 'Ġ4', ',', 'Ġ0', '])', '),', '24', ')),', '-', '3', ')']


In [23]:
def reward_substring_matches(text: str, searches: List[str], max_reward: float) -> List[float]:
    """
    Build a per-token reward vector marking tokens that overlap any search string.

    For each string in ``searches``, every (possibly overlapping) occurrence in
    ``text`` is located with ``get_matching_ranges``. Each token whose character
    span intersects an occurrence receives ``max_reward``, and every other token
    gets 0.0. Empty search strings are ignored.

    Args:
        text: The string to score.
        searches: Substrings whose covering tokens should be rewarded.
        max_reward: Value assigned to every overlapping token.

    Returns:
        One float per token of ``text``. ``text`` is encoded once, without
        special tokens and without truncation. That single encoding supplies
        both the list length and the offsets used for matching, so every
        rewarded index is guaranteed to be inside the list. For Hugging Face
        fast tokenizers this is the same token sequence as
        ``tokenizer.tokenize(text)``.
    """
    # Encode once and reuse the offsets for every match. Skipping special tokens
    # keeps indices aligned with the returned list: a tokenizer that prepends one
    # (e.g. BERT's [CLS]) would otherwise shift every index by one.
    offsets = tokenizer(text, return_offsets_mapping=True, add_special_tokens=False).offset_mapping
    rewards = [0.0] * len(offsets)

    for search in searches:
        if not search:
            continue  # guard against zero-width matches from an empty search string
        for match_start, match_end in get_matching_ranges(text, search):
            for token_idx, (token_start, token_end) in enumerate(offsets):
                # Half-open interval overlap test.
                if match_start < token_end and match_end > token_start:
                    rewards[token_idx] = max_reward
    return rewards

# Reward vectors for three search sets, printed one per line (1.0 on matched tokens).
print(
    *(
        reward_substring_matches(text, query_set, 1.0)
        for query_set in (["reverse("], ["sub_n("], ["reverse(", "sub_n("])
    ),
    sep="\n",
)

[0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
[1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
