## Import packages

In [1]:
import copy
import json
import math
from collections import defaultdict
from functools import lru_cache
from typing import Callable, DefaultDict, Dict, List, Union

import torch
from lark import Lark
from outlines import grammars
from outlines.caching import cache
from outlines.fsm.guide import CFGGuide, Generate, Guide, RegexGuide, Write
from outlines.fsm.json_schema import build_regex_from_schema
from pydantic import BaseModel
from transformers import PreTrainedTokenizerBase

## regex pattern from BaseModel

In [2]:
class User(BaseModel):
    name: str
    age: int

schema_str = json.dumps(User.model_json_schema())
whitespace_pattern = r"[\n]?"
regex_string = build_regex_from_schema(schema_str, whitespace_pattern)
print(regex_string)

\{[\n]?"name"[\n]?:[\n]?"([^"\\\x00-\x1F\x7F-\x9F]|\\["\\])*"[\n]?,[\n]?"age"[\n]?:[\n]?(-)?(0|[1-9][0-9]*)[\n]?\}


## RegexLogitsProcessor definition

In [3]:
class BaseLogitsProcessor:
    def __init__(self, guide: Guide):
        self._guide: Guide = guide
        self._fsm_state: DefaultDict[int, int] = defaultdict(int)

    def __call__(self, input_ids: List[int], scores: torch.Tensor) -> torch.Tensor:
        """Use the FSM to bias the logits before sampling the next token."""
        seq_id = hash(tuple(input_ids))

        if len(input_ids) > 0:
            last_token = input_ids[-1]
            last_seq_id = hash(tuple(input_ids[:-1]))
            self._fsm_state[seq_id] = self._guide.get_next_state(
                state=self._fsm_state[last_seq_id], token_id=last_token
            )
        else:
            # Note: this is a hack.
            # Lark pickling does not work properly (silent failure),
            # which breaks the RPC (which uses python pickleing).
            # We need to find a better solution.
            # On the first time this is called, we simply re-create
            # the Lark object.
            if isinstance(self._guide, CFGGuide):
                self._guide.parser = Lark(
                    self._guide.cfg_string,
                    parser="lalr",
                    lexer="contextual",
                    propagate_positions=False,
                    maybe_placeholders=False,
                    regex=True,
                    import_paths=[grammars.GRAMMAR_PATH],
                )

        instruction = self._guide.get_next_instruction(state=self._fsm_state[seq_id])

        if type(instruction) == Generate:  # noqa: E721
            allowed_tokens = instruction.tokens
        elif type(instruction) == Write:  # noqa: E721
            # TODO: support fast forward tokens
            allowed_tokens = [instruction.tokens[0]]
        else:
            raise TypeError(f"Unsupported instruction type {type(instruction)}")

        mask = torch.full((scores.shape[-1],), -math.inf, device=scores.device)
        mask[allowed_tokens] = 0
        scores.add_(mask)
        return scores

In [4]:
@lru_cache(maxsize=32)
def _adapt_tokenizer(tokenizer: PreTrainedTokenizerBase):
    """Adapt vLLM's tokenizer to use to compile the FSM.

    The API of Outlines tokenizers is slightly different to that of
    `transformers`. The decoder of outlines, returns a list whereas
    the decode of vLLM returns an str. To sync the vLLM decoder with
    outlines internal api, the decoder should be adapted. In addition
    we need to handle the missing spaces to Llama's tokenizer to be
    able to compile FSMs for this model.

    """
    if getattr(tokenizer, "_outlines_adapted", False):
        return tokenizer

    tokenizer = copy.deepcopy(tokenizer)

    tokenizer.vocabulary = tokenizer.get_vocab()
    tokenizer.special_tokens = set(tokenizer.all_special_tokens)

    def convert_token_to_string(token: str) -> str:
        from transformers.file_utils import SPIECE_UNDERLINE

        string = tokenizer.convert_tokens_to_string([token])

        # A hack to handle missing spaces to HF's Llama tokenizers
        if token.startswith(SPIECE_UNDERLINE) or token == "<0x20>":
            return " " + string

        return string

    def change_decoder(
        decoder: Callable[[List[int]], str],
    ) -> Callable[[List[int]], List[str]]:
        """Sync vLLM's decoder with the outlines by returning list."""

        def new_decoder(inp_tokens: List[int]) -> List[str]:
            return [decoder(inp_tokens)]

        return new_decoder

    tokenizer.convert_token_to_string = convert_token_to_string
    tokenizer.decode = change_decoder(tokenizer.decode)
    setattr(tokenizer, "_outlines_adapted", True)  # noqa: B010

    return tokenizer

In [5]:
class RegexLogitsProcessor(BaseLogitsProcessor):
    @classmethod
    @cache()
    def _get_guide(cls, regex_string: str, tokenizer: PreTrainedTokenizerBase) -> Guide:
        tokenizer = _adapt_tokenizer(tokenizer)
        return RegexGuide(regex_string, tokenizer)

    def __init__(self, regex_string: str, tokenizer: PreTrainedTokenizerBase):
        """Compile the FSM that drives the regex-structured generation.

        Parameters
        ----------
        regex_string
            A string that represents a regular expression
        tokenizer
            The model's tokenizer

        """
        super().__init__(RegexLogitsProcessor._get_guide(regex_string, tokenizer))

In [6]:
from transformers import AutoTokenizer
from outlines.fsm.guide import Generate, Guide, RegexGuide, Write

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-72B-Instruct-AWQ")

regex_processor = RegexLogitsProcessor(regex_string, tokenizer)

Compiling FSM index for all state transitions: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 35/35 [00:00<00:00, 36.50it/s]


This processor has two properties:
* `self._guide: Guide = guide`: a generation guide in the language of regular pattern.
* `self._fsm_state: DefaultDict[int, int] = defaultdict(int)`: a dictionary for storage of `id:state`.

Since `self._guide`'s `get_next_state` and `get_next_instruction` these two methods are mainly used, we will mainly look into these two methods, along with the created states map.

In [7]:
regex_processor._guide.states_to_token_maps

{0: {90: 1, 515: 2, 4913: 3},
 1: {198: 2, 1: 3, 31486: 7},
 2: {1: 3, 31486: 7},
 3: {3376: 5, 12400: 6, 606: 7, 77: 4},
 4: {64: 5, 373: 7, 309: 6},
 5: {76: 6, 2660: 7},
 6: {68: 7},
 7: {60767: 22,
  71835: 12,
  698: 9,
  3252: 12,
  62366: 17,
  4660: 11,
  1: 8,
  58528: 12,
  70318: 12,
  42398: 21,
  788: 10},
 8: {91920: 12,
  95740: 12,
  2974: 12,
  79389: 21,
  90220: 12,
  198: 9,
  69034: 12,
  86789: 12,
  28798: 12,
  79729: 12,
  12147: 12,
  83131: 20,
  510: 11,
  25: 10},
 9: {91920: 12,
  95740: 12,
  2974: 12,
  79389: 21,
  90220: 12,
  69034: 12,
  86789: 12,
  28798: 12,
  79729: 12,
  12147: 12,
  83131: 20,
  510: 11,
  25: 10},
 10: {67779: 12,
  30337: 12,
  46316: 12,
  40622: 12,
  21608: 12,
  59849: 12,
  34600: 12,
  27533: 12,
  14345: 12,
  14129: 12,
  57439: 12,
  45140: 12,
  86555: 12,
  10713: 12,
  24011: 12,
  40124: 12,
  65120: 12,
  9749: 12,
  70959: 12,
  34802: 12,
  32139: 12,
  32973: 12,
  13123: 12,
  52570: 12,
  51725: 12,
  83405

In [8]:
def decode_token_maps(tokenizer, states_to_token_maps):
    decoded_maps = {}
    
    for state, token_map in states_to_token_maps.items():
        decoded_maps[state] = {}
        for token_id, next_state in token_map.items():
            # Decode single token_id to string
            decoded_text = tokenizer.decode([token_id])
            decoded_maps[state][decoded_text] = next_state
            
    return decoded_maps

decoded_states = decode_token_maps(tokenizer, regex_processor._guide.states_to_token_maps)

for state, token_map in decoded_states.items():
    print(f"\nState {state}:")
    inner_state_counter = 0
    for text, next_state in token_map.items():
        if inner_state_counter > 30:
            break
        print(f"  '{text}' → State {next_state}")
        inner_state_counter += 1
        


State 0:
  '{' → State 1
  '{
' → State 2
  '{"' → State 3

State 1:
  '
' → State 2
  '"' → State 3
  '"name' → State 7

State 2:
  '"' → State 3
  '"name' → State 7

State 3:
  'na' → State 5
  'nam' → State 6
  'name' → State 7
  'n' → State 4

State 4:
  'a' → State 5
  'ame' → State 7
  'am' → State 6

State 5:
  'm' → State 6
  'me' → State 7

State 6:
  'e' → State 7

State 7:
  '":"","' → State 22
  '":"'' → State 12
  '"
' → State 9
  '":"' → State 12
  '":""' → State 17
  '":
' → State 11
  '"' → State 8
  '":"/' → State 12
  '":"+' → State 12
  '":"",
' → State 21
  '":' → State 10

State 8:
  ':".$' → State 12
  ':".' → State 12
  ':"' → State 12
  ':"",
' → State 21
  ':")' → State 12
  '
' → State 9
  ':"<<' → State 12
  ':"-' → State 12
  ':"+' → State 12
  ':"#' → State 12
  ':",' → State 12
  ':"",' → State 20
  ':
' → State 11
  ':' → State 10

State 9:
  ':".$' → State 12
  ':".' → State 12
  ':"' → State 12
  ':"",
' → State 21
  ':")' → State 12
  ':"<<' → State 1

In [9]:
import interegular
import re
from functools import lru_cache
from typing import (
    Dict,
    FrozenSet,
    Iterable,
    List,
    Optional,
    Sequence,
    Set,
    Tuple,
    Union,
    cast,
)

from interegular.fsm import (
    FSM,
    Alphabet,
    State,
    TransitionKey,
    _AnythingElseCls,
    anything_else,
)

from typing import Dict, List, Optional, Set, Tuple

class Vocabulary:
    """
    Vocabulary of an LLM.
    """

    @staticmethod
    def from_dict(map: Dict[str, List[int]]) -> "Vocabulary":
        """
        Creates a vocabulary from a dictionary of tokens to token IDs.
        """
        ...
    def __repr__(self) -> str:
        """
        Gets the debug string representation of the vocabulary.
        """
        ...
    def __str__(self) -> str:
        """
        Gets the string representation of the vocabulary.
        """
        ...

class Index:
    def get_allowed_tokens(self, state: int) -> Optional[List[int]]:
        """Returns allowed tokens in this state."""
        ...
    def get_next_state(self, state: int, token_id: int) -> Optional[int]:
        """Updates the state."""
        ...
    def is_final_state(self, state: int) -> bool:
        """Determines whether the current state is a final state."""
        ...
    def get_index_dict(self) -> Dict[int, Dict[int, int]]:
        """Returns the Index as a Python Dict object."""
        ...
    def get_initial_state(self) -> int:
        """Returns the ID of the initial state of the input FSM automata."""
        ...


class FSMInfo:
    initial: int
    finals: Set[int]
    transitions: Dict[Tuple[int, int], int]
    alphabet_anything_value: int
    alphabet_symbol_mapping: Dict[str, int]

    def __init__(
        self,
        initial: int,
        finals: Set[int],
        transitions: Dict[Tuple[int, int], int],
        alphabet_anything_value: int,
        alphabet_symbol_mapping: Dict[str, int],
    ) -> None: ...

def build_regex_from_schema(
    json: str, whitespace_pattern: Optional[str] = None
) -> str: ...
def to_regex(json: Dict, whitespace_pattern: Optional[str] = None) -> str: ...
def _walk_fsm(
    fsm_transitions: Dict[Tuple[int, int], int],
    fsm_initial: int,
    fsm_finals: Set[int],
    token_transition_keys: List[int],
    start_state: int,
    full_match: bool,
) -> List[int]: ...
def state_scan_tokens(
    fsm_transitions: Dict[Tuple[int, int], int],
    fsm_initial: int,
    fsm_finals: Set[int],
    vocabulary: Vocabulary,
    vocabulary_transition_keys: Dict[str, List[int]],
    start_state: int,
) -> Set[Tuple[int, int]]: ...
def get_token_transition_keys(
    alphabet_symbol_mapping: Dict[str, int],
    alphabet_anything_value: int,
    token_str: str,
) -> List[int]: ...
def get_vocabulary_transition_keys(
    alphabet_symbol_mapping: Dict[str, int],
    alphabet_anything_value: int,
    vocabulary: Vocabulary,
    frozen_tokens: Set[str],
) -> Dict[str, List[int]]: ...
def create_fsm_index_end_to_end(
    fsm_info: FSMInfo,
    vocabulary: Vocabulary,
    frozen_tokens: frozenset[str],
) -> Dict[int, Dict[int, int]]: ...

BOOLEAN: str
DATE: str
DATE_TIME: str
INTEGER: str
NULL: str
NUMBER: str
STRING: str
STRING_INNER: str
TIME: str
UUID: str
WHITESPACE: str

In [10]:
class BetterAlphabet(Alphabet):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        assert anything_else in self._symbol_mapping
        self.anything_value = self._symbol_mapping[anything_else]

    def __getitem__(self, item):
        return self._symbol_mapping.get(item, self.anything_value)

    def copy(self):
        return BetterAlphabet(self._symbol_mapping.copy())


class BetterFSM(FSM):
    flat_transition_map: Dict[Tuple[int, int], int]

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        if not isinstance(self.alphabet, BetterAlphabet):
            self.__dict__["alphabet"] = BetterAlphabet(self.alphabet._symbol_mapping)

        flat_transition_map = {}
        for from_state, trans_map in self.map.items():
            for trans_key, to_state in trans_map.items():
                flat_transition_map[(from_state, trans_key)] = to_state

        self.__dict__["flat_transition_map"] = flat_transition_map
        self.__dict__["_fsm_info"] = None

    def copy(self):
        return BetterFSM(
            alphabet=self.alphabet.copy(),
            states=self.states.copy(),
            initial=self.initial,
            finals=self.finals.copy(),
            map=self.map.copy(),
            __no_validation__=True,
        )

    @property
    def fsm_info(self):
        if self._fsm_info is None:
            anything_value = self.alphabet.anything_value
            self.__dict__["_fsm_info"] = FSMInfo(
                self.initial,
                self.finals,
                self.flat_transition_map,
                anything_value,
                # TODO FIXME: Perform this conversion in Rust?
                {
                    k: v
                    for k, v in self.alphabet._symbol_mapping.items()
                    if k != anything_else
                },
            )

        return self._fsm_info


TransitionTrie = Dict[TransitionKey, "Union[TransitionTrie, State, None]"]


def add_to_transition_trie(
    trie: TransitionTrie,
    key_seq: Sequence[TransitionKey],
    value: Union[State, None],
):
    for key in key_seq[:-1]:
        trie = cast(TransitionTrie, trie.setdefault(key, {}))
        assert isinstance(trie, dict), "key sequence of incompatible length"
    trie[key_seq[-1]] = value


# merge default_trie into the trie, only updating entries not present in the trie
def transition_trie_setdefault(
    trie: TransitionTrie,
    default_trie: TransitionTrie,
):
    for key, default_value in default_trie.items():
        dest_value = trie.get(key)
        if isinstance(dest_value, dict) and isinstance(default_value, dict):
            transition_trie_setdefault(dest_value, default_value)
        elif key not in trie:
            trie[key] = default_value


def byte_symbol(byte: int) -> str:
    return f"\x00{byte:02X}" if byte >= 0x80 else chr(byte)


def make_byte_level_fsm(
    fsm: FSM, keep_utf8: bool = False, frozen_tokens: List[str] = []
) -> FSM:
    """Convert an FSM to a byte-level FSM, expanding multi-byte characters as
    sequences of single-byte transitions.

    Parameters
    ----------
    fsm: (`interegular.FSM`):
        The token-level FSM to convert to a byte-level FSM.
    keep_utf8: (`bool`, *optional*):
        If set to True, the original utf-8 characters are kept as-is. Defaults to
        False. NOTE: we're representing bytes as strings to keep it type-compatible.
    frozen_tokens: (`List[str]`, *optional*):
        A list of tokens that should be kept as-is in the byte-level FSM. That is,
        these tokens will not be expanded into byte-level transitions. Defaults to
        an empty list.

    Returns
    -------
    `interegular.FSM`: A byte-level FSM.
    """

    anything_else_key = fsm.alphabet[anything_else]
    symbol_mapping: Dict[Union[str, _AnythingElseCls], TransitionKey] = {}
    map: Dict[State, Dict[TransitionKey, State]] = {}
    states: List[State] = list(fsm.states)

    # identify all multi-byte characters in the alphabet and build a mapping
    # from the original transition keys to sequences of new keys for each byte
    key_to_key_seqs: Dict[TransitionKey, Set[Tuple[TransitionKey, ...]]] = {}
    all_key_seqs: Set[Tuple[TransitionKey, ...]] = set()
    all_bytes: Set[int] = set()
    max_key = max(fsm.alphabet.values())
    for symbol, transition_key in fsm.alphabet.items():
        assert symbol == anything_else or symbol in frozen_tokens or len(symbol) == 1
        if symbol == anything_else or symbol in frozen_tokens or ord(symbol) < 0x80:
            symbol_mapping[symbol] = transition_key
        else:
            if keep_utf8:
                symbol_mapping[symbol] = transition_key
            key_list: List[TransitionKey] = []
            for byte in symbol.encode("utf-8"):
                symbol = byte_symbol(byte)
                if symbol not in symbol_mapping:
                    symbol_mapping[symbol] = max_key = TransitionKey(max_key + 1)
                    all_bytes.add(byte)
                key_list.append(symbol_mapping[symbol])
            key_seq = tuple(key_list)
            key_to_key_seqs.setdefault(transition_key, set()).add(key_seq)
            all_key_seqs.add(key_seq)

    # add all remaining multi-byte utf-8 bytes to the alphabet
    # (this is required to represent `anything_else`)
    utf8_ranges = {
        1: (0x80, 0xC0),  # continuation bytes
        2: (0xC0, 0xE0),  # 2-byte sequences
        3: (0xE0, 0xF0),  # 3-byte sequences
        4: (0xF0, 0xF8),  # 4-byte sequences
    }
    utf8_all_keys: Dict[int, Set[TransitionKey]] = {
        n: set() for n in utf8_ranges.keys()
    }
    for n, (start, end) in utf8_ranges.items():
        range_key = max_key = TransitionKey(max_key + 1)
        for byte in range(start, end):
            byte_key = symbol_mapping.setdefault(byte_symbol(byte), range_key)
            utf8_all_keys[n].add(byte_key)

    # cache of intermediate transition states by transitions from that state
    state_cache: Dict[FrozenSet[Tuple[TransitionKey, State]], State] = {}

    # helper function to create multi-step transitions between states
    max_state = max(fsm.states)

    def create_seq_transitions(
        seq_transitions_trie: TransitionTrie,
    ) -> Dict[TransitionKey, State]:
        nonlocal max_state
        result: Dict[TransitionKey, State] = {}

        for next_key, next_trie in seq_transitions_trie.items():
            if isinstance(next_trie, dict):
                next_transitions = create_seq_transitions(next_trie)
                if not next_transitions:
                    continue
                cache_key = frozenset(next_transitions.items())
                next_state = state_cache.get(cache_key)
                if next_state is None:
                    next_state = max_state = State(max_state + 1)
                    map[next_state] = next_transitions
                    state_cache[cache_key] = next_state
                    states.append(next_state)
                result[next_key] = next_state
            elif next_trie is not None:
                result[next_key] = next_trie

        return result

    # create new states and transitions
    for state, transitions in fsm.map.items():
        seq_transitions_trie: TransitionTrie = {}
        state_map: Dict[TransitionKey, State] = {}

        for transition_key, to_state in transitions.items():
            if transition_key in key_to_key_seqs:
                if keep_utf8:
                    state_map[transition_key] = to_state
                for key_seq in key_to_key_seqs[transition_key]:
                    add_to_transition_trie(seq_transitions_trie, key_seq, to_state)
            else:  # keep single-byte transitions as is
                state_map[transition_key] = to_state

        # handle multi-byte anything_else sequences
        if anything_else_key in transitions:
            for key_seq in all_key_seqs:
                add_to_transition_trie(seq_transitions_trie, key_seq, None)

            anything_else_trie: TransitionTrie = {}
            cont_trie: Union[TransitionTrie, State] = transitions[anything_else_key]
            for n in range(2, 5):
                cont_trie = {key: cont_trie for key in utf8_all_keys[1]}
                for key in utf8_all_keys[n]:
                    anything_else_trie[key] = cont_trie

            transition_trie_setdefault(seq_transitions_trie, anything_else_trie)

        # create new states and transitions
        next_transitions = create_seq_transitions(seq_transitions_trie)
        state_map.update(next_transitions)
        map[state] = state_map

    return FSM(
        alphabet=Alphabet(symbol_mapping),
        states=states,
        initial=fsm.initial,
        finals=fsm.finals,
        map=map,
    )


def make_deterministic_fsm(fsm: FSM) -> Tuple[BetterFSM, Dict[int, int]]:
    """Construct an equivalent FSM with deterministic state labels."""
    old_to_new_trans_keys = {
        trans_key: i
        for i, (trans_key, _) in enumerate(
            sorted(fsm.alphabet.by_transition.items(), key=lambda x: sorted(x[1]))
        )
    }

    new_symbol_mapping = {
        symbol: old_to_new_trans_keys[trans_key]
        for symbol, trans_key in fsm.alphabet._symbol_mapping.items()
    }

    new_alphabet = BetterAlphabet(new_symbol_mapping)

    new_map = {
        from_state: {
            old_to_new_trans_keys[trans_key]: to_state
            for trans_key, to_state in trans_map.items()
        }
        for from_state, trans_map in fsm.map.items()
    }

    old_to_new_states = {}
    old_to_new_states[fsm.initial] = 0

    i = 0
    seen = {fsm.initial}
    old_state_queue = [fsm.initial]
    while old_state_queue:
        old_state = old_state_queue.pop(-1)
        transitions = new_map[old_state]
        sorted_transitions = sorted(transitions.items(), key=lambda v: v[0])
        for _, old_state in sorted_transitions:
            if old_state not in seen:
                old_state_queue.append(old_state)
                seen.add(old_state)
            if old_state not in old_to_new_states:
                i += 1
                old_to_new_states[old_state] = i

    new_map = dict(
        sorted(
            (
                (
                    old_to_new_states[from_state],
                    dict(
                        sorted(
                            (
                                (trans_key, old_to_new_states[to_state])
                                for trans_key, to_state in trans_map.items()
                            ),
                            key=lambda v: v[0],
                        )
                    ),
                )
                for from_state, trans_map in new_map.items()
            ),
            key=lambda v: v[0],
        )
    )

    new_initial = 0
    new_finals = frozenset(
        sorted(old_to_new_states[old_state] for old_state in fsm.finals)
    )
    new_states = frozenset(sorted(new_map.keys()))

    new_fsm = BetterFSM(new_alphabet, new_states, new_initial, new_finals, new_map)

    return new_fsm, old_to_new_states


re_llama_byte_token = re.compile(r"^<0x[0-9A-F]{2}>$")

# The "▁*" prefix is required to handle Gemma and GPT-SW3 tokenizers, and the "\.*"
# suffix is required to handle the NorwAI tokenizer.
re_replacement_seq = re.compile(r"^▁*�+\.*$")


# Copied from transformers.models.gpt2.tokenization_gpt2.bytes_to_unicode
@lru_cache()
def gpt2_bytes_to_unicode():
    """
    Returns list of utf-8 byte and a mapping to unicode strings. We specifically avoids mapping to whitespace/control
    characters the bpe code barfs on.

    The reversible bpe codes work on unicode strings. This means you need a large # of unicode characters in your vocab
    if you want to avoid UNKs. When you're at something like a 10B token dataset you end up needing around 5K for
    decent coverage. This is a significant percentage of your normal, say, 32K bpe vocab. To avoid that, we want lookup
    tables between utf-8 bytes and unicode strings.
    """
    bs = (
        list(range(ord("!"), ord("~") + 1))
        + list(range(ord("¡"), ord("¬") + 1))
        + list(range(ord("®"), ord("ÿ") + 1))
    )
    cs = bs[:]
    n = 0
    for b in range(2**8):
        if b not in bs:
            bs.append(b)
            cs.append(2**8 + n)
            n += 1
    cs = [chr(n) for n in cs]
    return dict(zip(bs, cs))


@lru_cache()
def gpt2_unicode_to_bytes():
    return {v: k for k, v in gpt2_bytes_to_unicode().items()}


@lru_cache
def reduced_vocabulary(
    tokenizer,
) -> Tuple[Dict[str, List[int]], Set[int]]:
    """Create a map from decoded vocabulary tokens to lists of equivalent token ids."""
    # TODO FIXME: See if we can get the underlying Rust tokenizers from HF and
    # do all this in Rust
    empty_token_ids = set()
    vocabulary: Dict[str, List[int]] = {}
    for token, token_idx in tokenizer.vocab.items():
    # for token, token_idx in tokenizer.vocabulary.items():
        if token in tokenizer.special_tokens:
            continue

        token_str: Union[str, Tuple[str, ...]] = tokenizer.convert_token_to_string(
            token
        )

        if token_str:
            if isinstance(token, bytes):
                # Handle BPE tokenizers where the tokens are directly stored as bytes
                # https://github.com/QwenLM/Qwen/blob/main/tokenization_note.md#regular-tokens
                token_str = "".join(byte_symbol(b) for b in token)

            elif "\ufffd" in token_str and not re_replacement_seq.match(token):
                # invalid utf-8 sequences are replaced with � (\ufffd), but there
                # might also be tokens specifically for �, ��, ���, etc.

                if re_llama_byte_token.match(token):
                    # llama-like tokenizers have <0xXX> tokens for all
                    # bytes >= 0x80 and represent all incomplete utf-8
                    # sequences using such tokens
                    token_bytes = [int(token[3:5], 16)]
                else:
                    # gpt2-like tokenizers have multi-byte tokens that can
                    # have a mix of full and incomplete utf-8 characters,
                    # for example, b` \xf0` can be one token; these tokenizers
                    # map each byte to a valid utf-8 character
                    token_bytes = cast(
                        List[int], [gpt2_unicode_to_bytes().get(c) for c in token]
                    )
                    if None in token_bytes:
                        raise RuntimeError(
                            f"Cannot convert token `{token}` ({token_idx}) to bytes: {token_str}"
                        )
                token_str = "".join(byte_symbol(b) for b in token_bytes)

            assert isinstance(token_str, str)

            vocabulary.setdefault(token_str, []).append(token_idx)
        else:
            empty_token_ids.add(token_idx)

    return vocabulary, empty_token_ids


def create_fsm_index_tokenizer(
    fsm: BetterFSM,
    tokenizer,
    frozen_tokens: Optional[Iterable[str]] = None,
) -> Tuple[Index, Set[int]]:
    """Construct an FMS index from a tokenizer.

    This uses the end-to-end approach of `create_fsm_index_end_to_end`.

    Parameters
    ----------
    fsm: (`BetterFSM`):
        A cache-friendly FSM. Other interegular FSMs can also be used, but caching
        may not work as expected.
    tokenizer: (`Tokenizer`):
        The model's tokenizer.
    frozen_tokens: (`List[str]`, *optional*):
        A list of tokens that should be kept as-is when expanding the token-level
        FSM into a byte-level FSM. Defaults to an empty list.

    Returns
    -------
    states_to_token_maps: (`Dict[int, Dict[int, int]]`):
        A mapping from states to a mapping from token ids originating from that state
        to the next state to transition to given that token. The structure is as follows:
        (origin_state -> (token_id -> next_state))
    empty_token_ids: (`Set[int]`):
        A set of token ids that correspond to empty strings.

    .. warning::

        `fsm` needs to be deterministically ordered so that future caching makes sense.
    """
    tokens_to_token_ids, empty_token_ids = reduced_vocabulary(tokenizer)

    states_to_token_subsets = Index(  # type: ignore
        fsm.fsm_info,
        Vocabulary.from_dict(tokens_to_token_ids),
        tokenizer.eos_token_id,
        frozenset(frozen_tokens) if frozen_tokens is not None else frozenset(),
    )

    return states_to_token_subsets, empty_token_ids

In [11]:
def create_states_mapping(
    regex_string: str,
    tokenizer,
    regex_parser: Callable[[str], interegular.Pattern] = interegular.parse_pattern,
    frozen_tokens: List[str] = [],
) -> Tuple[Dict[int, Dict[int, int]], Set[int], Set[int]]:
    """Create the variables related to the mapping between states and tokens from a regex string.

    The parameters of the function are used for caching purpose.

    Parameters
    ----------
    regex_string:
        The regular expression string to generate a states mapping for.
    tokenizer:
        The model's tokenizer.
    regex_parser:
        A function that parses a regex string into an `interegular` Pattern object.
    frozen_tokens:
        A list of tokens that should be kept as-is when expanding the token-level FSM
        into a byte-level FSM. Defaults to an empty list.

    Returns
    -------
    states_to_token_maps:
        A mapping from states to a mapping from token ids originating from that state
        to the next state to transition to given that token. The structure is as follows:
        (origin_state -> (token_id -> next_state))
    empty_token_ids:
        A set of token ids that correspond to empty strings.
    final_states:
        A set of final states in the FSM.
    """
    regex_fsm = regex_parser(regex_string).to_fsm()
    return create_states_mapping_from_fsm(regex_fsm, tokenizer, frozen_tokens)

In [12]:
def create_states_mapping_from_fsm(
    fsm: interegular.fsm.FSM,
    tokenizer,
    frozen_tokens: List[str] = [],
) -> Tuple[Dict[int, Dict[int, int]], Set[int], Set[int]]:
    """Create the variables related to the mapping between states and tokens from an FSM.

    The parameters of the function are used for caching purpose.

    Parameters
    ----------
    fsm:
        An FSM for the regular expression.
    tokenizer:
        The model's tokenizer.
    frozen_tokens:
        A list of tokens that should be kept as-is when expanding the token-level FSM
        into a byte-level FSM. Defaults to an empty list.

    Returns
    -------
    states_to_token_maps:
        A mapping from states to a mapping from token ids originating from that state
        to the next state to transition to given that token. The structure is as follows:
        (origin_state -> (token_id -> next_state))
    empty_token_ids:
        A set of token ids that correspond to empty strings.
    final_states:
        A set of final states in the FSM.
    """
    byte_fsm = make_byte_level_fsm(
        fsm.reduce(), keep_utf8=True, frozen_tokens=frozen_tokens
    )
    regex_fsm, _ = make_deterministic_fsm(byte_fsm)
    states_to_token_maps, empty_token_ids = create_fsm_index_tokenizer(
        regex_fsm, tokenizer
    )

    # We make sure that it is possible to generate strings in the language
    # of the regular expression with the tokens present in the model's
    # vocabulary.
    if not any(
        regex_fsm.finals.intersection(v.values()) for v in states_to_token_maps.values()
    ):
        raise ValueError(
            "The vocabulary does not allow us to build a sequence that matches the input regex"
        )

    return states_to_token_maps, empty_token_ids, regex_fsm.finals