In [1]:
import bisect
import copy
import logging

import numpy as np
import rich
import rich.logging
import time
import torch
import transformers
import re


LOGGER = logging.getLogger(__name__)


# Route log records through rich; markup is enabled so the `[bold blue]...[/]`
# tags in the debug dumps render as styles. Only this notebook's logger is set
# to DEBUG; the root stays at INFO so third-party libraries' DEBUG records are
# not shown. `force=True` replaces handlers left by an earlier run of this
# cell -- without it, basicConfig silently does nothing once the root logger
# already has a handler, so edits to this cell would not take effect.
logging.basicConfig(
    level=logging.INFO,
    format="%(message)s",
    handlers=[rich.logging.RichHandler(markup=True)],
    force=True,
)
LOGGER.setLevel(logging.DEBUG)


2023-02-14 18:32:33.090205: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2023-02-14 18:32:36.585480: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory; LD_LIBRARY_PATH: /cvmfs/ai.mila.quebec/apps/arch/distro/libffi/3.2.1/lib:/cvmfs/ai.mila.quebec/apps/arch/common/cudnn/11.2-v8.1/lib64:/cvmfs/ai.mila.quebec/apps/arch/common/cudnn/11.2-v8.1/lib:/cvmfs/ai.mila.quebec/apps/arch/common/cuda/11.2/lib64:/cvmfs/ai.mila.quebec/apps/arch/common/nccl/11.2-v2.8/lib:/cvmfs/ai.mila.quebec/apps/arch/distro/openmpi/4.0.4/lib
2023-02-14 18:32:36.586333: W tensorflow/compiler/xla/s

In [2]:
# `extract_match_tokens` needs `return_offsets_mapping`, which only "fast"
# (Rust-backed) tokenizers implement. Request one explicitly and fail here with
# a clear message rather than deep inside the extraction code.
t = transformers.AutoTokenizer.from_pretrained("google/flan-t5-xxl", use_fast=True)
if not t.is_fast:
    raise RuntimeError(
        "google/flan-t5-xxl loaded a slow tokenizer; offset mappings require a "
        "fast tokenizer (is the `tokenizers` package installed?)"
    )

In [65]:
import bisect
def find_lt(a, x):
    """Return the INDEX of the rightmost element of `a` that is < `x`.

    `a` must be sorted in ascending order.

    Raises:
        ValueError: if no element of `a` is less than `x` (including empty `a`).
    """
    i = bisect.bisect_left(a, x)
    if i:
        return i - 1
    raise ValueError(f"no element less than {x!r}")

def find_le(a, x):
    """Return the INDEX of the rightmost element of `a` that is <= `x`.

    `a` must be sorted in ascending order. With duplicates equal to `x`, the
    index of the last duplicate is returned.

    Raises:
        ValueError: if every element of `a` is greater than `x` (or `a` is empty).
    """
    i = bisect.bisect_right(a, x)
    if i:
        return i - 1
    raise ValueError(f"no element less than or equal to {x!r}")

def find_gt(a, x):
    """Return the INDEX of the leftmost element of `a` that is > `x`.

    `a` must be sorted in ascending order.

    Raises:
        ValueError: if no element of `a` is greater than `x` (including empty `a`).
    """
    i = bisect.bisect_right(a, x)
    if i != len(a):
        return i
    raise ValueError(f"no element greater than {x!r}")

def find_ge(a, x):
    """Return the INDEX of the leftmost element of `a` that is >= `x`.

    `a` must be sorted in ascending order. With duplicates equal to `x`, the
    index of the first duplicate is returned.

    Raises:
        ValueError: if every element of `a` is less than `x` (or `a` is empty).
    """
    i = bisect.bisect_left(a, x)
    if i != len(a):
        return i
    raise ValueError(f"no element greater than or equal to {x!r}")

def _compile_regexes(regexes):
    """Compile every pattern in `regexes` into a NEW list (input left untouched).

    `str` entries go through `re.compile`; `re.Pattern` entries pass through.

    Raises:
        TypeError: if an entry is neither `str` nor `re.Pattern`.
    """
    compiled = []
    for regex in regexes:
        if isinstance(regex, str):
            regex = re.compile(regex)
        elif not isinstance(regex, re.Pattern):
            raise TypeError(
                f"regexes must be str or re.Pattern, got {type(regex).mro()}"
            )
        compiled.append(regex)
    return compiled


def _monotone_boundaries(offsets, attention_masks):
    """Build sorted per-token left/right character boundaries for each sequence.

    Each token's left boundary is raised to at least the previous token's right
    boundary, and its right boundary to at least its own left boundary. Both
    lists are therefore non-decreasing and consecutive spans never overlap, so
    they can be binary-searched. Offsets that point backwards (HF fast
    tokenizers typically report special tokens such as `</s>` as (0, 0)) become
    zero-width spans at the current position.

    Padding positions (mask == 0) are kept as zero-width spans instead of being
    dropped, so index k of the returned lists always refers to token k of
    `input_ids`, whichever side the tokenizer pads on.
    """
    left_boundaries = []
    right_boundaries = []
    for offset, mask in zip(offsets, attention_masks):
        left_local = []
        right_local = []
        largest = 0
        for (start, end), keep in zip(offset, mask):
            if keep != 0:
                # int() so tensor/ndarray offsets (return_tensors="pt"/"np")
                # yield plain ints.
                left = max(largest, int(start))
                largest = max(left, int(end))
            else:
                left = largest
            left_local.append(left)
            right_local.append(largest)
        left_boundaries.append(left_local)
        right_boundaries.append(right_local)
    return left_boundaries, right_boundaries


def _log_match_debug(
    *, tokenizer, toks, l_b, r_b, str_, string_no, n_strings, match_no,
    n_matches, start_char, end_char, start_idx, end_idx, token_text,
):
    """Log a detailed diagnostic dump for one regex match.

    Only called when `LOGGER` is enabled for DEBUG. The `torch.searchsorted`
    values are printed for comparison only; they do not affect the indices
    returned by `extract_match_tokens`.
    """
    l_b_tensor = torch.tensor(l_b)
    r_b_tensor = torch.tensor(r_b)
    lb_right = torch.searchsorted(l_b_tensor, start_char, side="right")
    lb_left = torch.searchsorted(l_b_tensor, start_char, side="left")
    rb_right = torch.searchsorted(r_b_tensor, end_char, side="right")
    rb_left = torch.searchsorted(r_b_tensor, end_char, side="left")

    LOGGER.debug(
        "\n" +
        "-" * 40 + "\n" +
        f"[bold green]String {string_no}/{n_strings} Match {match_no}/{n_matches}:[/]\n" +
        "-" * 40 + "\n" +
        f"[bold blue]string match:[/]     `{str_[start_char:end_char]}`" + "\n" +
        f"[bold blue]token match:[/]      `{token_text}`" + "\n" +
        f"[bold blue]start_char:[/]        {start_char}" + "\n" +
        f"[bold blue]end_char:[/]          {end_char}" + "\n" +
        f"[bold blue]start_idx:[/]         {start_idx}" + "\n" +
        f"[bold blue]end_idx:[/]           {end_idx}" + "\n" +
        f"[bold blue]l_b_right:[/]         {lb_right}" + "\n" +
        f"[bold blue]l_b_left:[/]          {lb_left}" + "\n" +
        f"[bold blue]r_b_right:[/]         {rb_right}" + "\n" +
        f"[bold blue]r_b_left:[/]          {rb_left}" + "\n" +
        "[bold blue]l_b:[/]               " + str([(k, int(b)) for k, b in enumerate(l_b)]) + "\n" +
        "[bold blue]r_b:[/]               " + str([(k, int(b)) for k, b in enumerate(r_b)]) + "\n" +
        "[bold blue]both boundaries:[/]   " + str([(k, (int(l), int(r))) for k, (l, r) in enumerate(zip(l_b, r_b))]) + "\n" +
        "[bold blue]tokens:[/]            " + str([(k, tokenizer.decode([tok], skip_special_tokens=False)) for k, tok in enumerate(toks)]) + "\n" +
        "[bold blue]token ids:[/]         " + str([(k, int(tok)) for k, tok in enumerate(toks)]) + "\n" +
        "-" * 40 + "\n"
    )


def extract_match_tokens(*, regexes, strings, tokenizer, tokenizer_kwargs=None):
    """Find regex matches in each string and map them to token index ranges.

    Args:
        regexes: one pattern (`str` or compiled `re.Pattern`) per entry of
            `strings`; ``regexes[i]`` is searched in ``strings[i]``. The
            sequence is not modified.
        strings: batch (list) of texts to tokenize and search.
        tokenizer: a Hugging Face-style tokenizer that supports
            ``return_offsets_mapping=True`` (in `transformers`, a "fast"
            tokenizer) and provides ``decode``.
        tokenizer_kwargs: extra keyword arguments for the tokenizer call. The
            dict is not modified. ``return_offsets_mapping`` is always forced
            to True; explicitly passing a falsy value is an error.

    Returns:
        ``(tok_output, outputs)``. `tok_output` is the raw tokenizer output.
        ``outputs[i]`` holds one ``(start_idx, end_idx, text)`` tuple per match
        in ``strings[i]``. ``start_idx`` and ``end_idx`` are INCLUSIVE token
        indices of the tokens whose character span overlaps the match, and
        ``text`` is ``tokenizer.decode`` of that token slice. If a match
        overlaps no token (e.g. it lies entirely in characters no token's
        offsets cover), then ``start_idx > end_idx`` and ``text`` decodes an
        empty slice.

    Raises:
        ValueError: if ``return_offsets_mapping`` is explicitly disabled, or if
            the number of regexes differs from the number of strings.
        TypeError: if a regex is neither `str` nor `re.Pattern`.
    """
    ########################################################################
    # Preliminary checks and setup
    ########################################################################
    tokenizer_kwargs = dict(tokenizer_kwargs or {})
    if not tokenizer_kwargs.get("return_offsets_mapping", True):
        raise ValueError("`return_offsets_mapping` is required.")
    # Set the flag inside the kwargs instead of passing it next to
    # **tokenizer_kwargs. Doing both raised "got multiple values for keyword
    # argument" whenever the caller had already passed True.
    tokenizer_kwargs["return_offsets_mapping"] = True

    compiled_regexes = _compile_regexes(regexes)
    if len(compiled_regexes) != len(strings):
        # Without this check, zip() below would silently drop the unmatched tail.
        raise ValueError(
            f"Got {len(compiled_regexes)} regexes for {len(strings)} strings; "
            "expected exactly one regex per string."
        )

    ########################################################################
    # Tokenize and build binary-searchable character boundaries
    ########################################################################
    tok_output = tokenizer(strings, **tokenizer_kwargs)
    tokens = tok_output["input_ids"]
    offsets = tok_output["offset_mapping"]
    masks = tok_output.get("attention_mask")
    if masks is None:  # e.g. the caller passed return_attention_mask=False
        masks = [[1] * len(offset) for offset in offsets]
    left_boundaries, right_boundaries = _monotone_boundaries(offsets, masks)

    # The dumps below decode every token, so only build them when they will be shown.
    debug = LOGGER.isEnabledFor(logging.DEBUG)
    if debug:
        LOGGER.debug(
            f"[bold blue]offsets:[/]           {offsets}\n"
            f"[bold blue]left_boundaries:[/]   {left_boundaries}\n"
            f"[bold blue]right_boundaries:[/]  {right_boundaries}\n"
            f"[bold blue]pairs:[/]             {[list(zip(a, b)) for a, b in zip(left_boundaries, right_boundaries)]}\n"
        )

    ########################################################################
    # Map each regex match to the run of tokens overlapping it
    ########################################################################
    outputs = []
    for i, (toks, l_b, r_b, str_, regex) in enumerate(
        zip(tokens, left_boundaries, right_boundaries, strings, compiled_regexes)
    ):
        matches = list(regex.finditer(str_))
        per_str_output = []
        for j, match in enumerate(matches):
            start_char, end_char = match.span()
            # Token k is kept iff r_b[k] > start_char and l_b[k] < end_char.
            # Both lists are sorted, so the kept tokens form one contiguous run.
            # Stepping back one token from find_le(l_b, start_char) is avoided on
            # purpose: it gave -1 (wrapping to the last token) for a match in the
            # first token and otherwise dragged in the preceding token.
            start_idx = bisect.bisect_right(r_b, start_char)
            # bisect_left on l_b excludes zero-width tokens sitting at the match
            # end, e.g. a trailing `</s>` whose boundaries collapse onto the last
            # character.
            end_idx = bisect.bisect_left(l_b, end_char) - 1
            token_text = tokenizer.decode(toks[start_idx:end_idx + 1])

            if debug:
                _log_match_debug(
                    tokenizer=tokenizer, toks=toks, l_b=l_b, r_b=r_b, str_=str_,
                    string_no=i + 1, n_strings=len(strings),
                    match_no=j + 1, n_matches=len(matches),
                    start_char=start_char, end_char=end_char,
                    start_idx=start_idx, end_idx=end_idx, token_text=token_text,
                )

            per_str_output.append((start_idx, end_idx, token_text))
        outputs.append(per_str_output)
    return tok_output, outputs

strings = [
    "This is a string with 32362513213 potatoes and 15 apples",
    "222",
    "222 ",
]

# One digit-run pattern per string. The last two strings end in a match,
# one with trailing whitespace and one without.
tok_outputs, outputs = extract_match_tokens(
    strings=strings,
    regexes=[r"\d+" for _ in strings],
    tokenizer=t,
    tokenizer_kwargs={},
)


# tok_outputs_copy = copy.deepcopy(tok_outputs)
# for i, (tok_output, output) in enumerate(zip(tok_outputs, outputs)):
#     print(f"{output = }")
#     for output_ in output:
#         print(f"{output_ = }")
#         tok_outputs_copy["input_ids"][i][output_[0] : output_[1] + 1] = 0
#     print(t.decode(tok_outputs_copy["input_ids"][i]))

