KeyError in BetterFSM::FSMInfo when input FSM alphabet contains UTF-8 characters that end with \xb8\x80 #833

Open
m0g1cian opened this issue Apr 23, 2024 · 2 comments · May be fixed by #904

m0g1cian commented Apr 23, 2024

Describe the issue as clearly as possible:

Update 2

Can confirm there's something wrong with Numba's Typed Dict implementation. Check issue here

Update

After some testing, it is clear that this KeyError occurs for characters whose UTF-8 encoding ends with \xb8\x80 (e.g. "帀", "㸀", "渀", "縀").
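To check whether a given alphabet symbol falls into the affected group, it is enough to look at the tail of its UTF-8 encoding. This is a minimal pure-Python sketch; `ends_with_b880` is a hypothetical helper name, and it encodes only the pattern observed in this report, not numba's actual failure condition.

```python
def ends_with_b880(symbol: str) -> bool:
    """Return True if the UTF-8 encoding of `symbol` ends with the bytes 0xB8 0x80.

    Hypothetical helper: it mirrors the pattern reported in this issue,
    not numba's internal logic.
    """
    return symbol.encode("utf-8").endswith(b"\xb8\x80")


# The first five characters were reported as failing; "二" and "三" were not.
for ch in ["一", "帀", "㸀", "渀", "縀", "二", "三"]:
    print(ch, ch.encode("utf-8"), ends_with_b880(ch))
```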


When outlines builds a BetterFSM from a reference FSM (e.g. one produced by interegular) and the reference FSM contains the Chinese character "一", the numba.typed.Dict used for BetterFSM::alphabet_symbol_map somehow converts this character into an empty string, causing a KeyError whenever __getitem__ is called.

Steps/code to reproduce the bug:

debug_keyerror.py

```python
import interegular
from outlines.fsm.regex import FSMInfo, make_deterministic_fsm

if __name__ == "__main__":
    regex_string = r"(1|2|3|one|two|three|一|二|三)"
    regex_pattern = interegular.parse_pattern(regex_string)
    regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce())
    fsm_info: FSMInfo = regex_fsm.fsm_info
    print(fsm_info)
```

Some insight:

Printing (k, v) for each entry of alphabet_symbol_mapping_items before create_fsm_info() is called (right after outlines.fsm.regex.py::96):

```
e 3
w 9
2 1
h 4
t 8
o 6
r 7
1 0
3 2
n 5
一 10
二 12
三 11
```

Printing (k, v) for each entry of alphabet_symbol_mapping_items inside create_fsm_info(), while alphabet_symbol_map is being built (right after outlines.fsm.regex.py::139), shows that "一" has become an empty string:

```
e 3
w 9
2 1
h 4
t 8
o 6
r 7
1 0
3 2
n 5
 10
二 12
三 11
```
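Comparing the two listings above makes the corruption explicit: every entry survives except that the key for transition key 10 changes from "一" to an empty string. The sketch below reproduces that comparison in plain Python, using data copied from the printed output; `find_mangled_keys` is a hypothetical helper, not part of outlines.

```python
def find_mangled_keys(before, after):
    """Return (original_key, new_key) pairs whose symbol changed for the same value.

    `before` and `after` are lists of (symbol, transition_key) pairs.
    """
    # Index the converted listing by value, since the values survived intact.
    after_by_value = {value: key for key, value in after}
    return [
        (key, after_by_value.get(value))
        for key, value in before
        if after_by_value.get(value) != key
    ]


# Data copied from the two listings above (before / inside create_fsm_info).
before = [("e", 3), ("w", 9), ("2", 1), ("h", 4), ("t", 8), ("o", 6), ("r", 7),
          ("1", 0), ("3", 2), ("n", 5), ("一", 10), ("二", 12), ("三", 11)]
after = [("e", 3), ("w", 9), ("2", 1), ("h", 4), ("t", 8), ("o", 6), ("r", 7),
         ("1", 0), ("3", 2), ("n", 5), ("", 10), ("二", 12), ("三", 11)]

print(find_mangled_keys(before, after))  # [('一', '')]
```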

Expected result:

I was able to get the expected result after tweaking two places:

  1. outlines.fsm.regex.py::112: change `nb_unichar_2_type = numba.types.UnicodeCharSeq(2)` to `nb_unichar_2_type = numba.types.unicode_type`
  2. outlines.fsm.regex.py::89: change alphabet_symbol_mapping_items to a plain Python list: `alphabet_symbol_mapping_items = list((k,v) for k, v in self.alphabet._symbol_mapping.items() if k != anything_else)`

```
FSMInfo(
  initial=0,
  finals={1},
  transitions=DictType[UniTuple(int64 x 2),int64]<iv=None>({(0, 0): 1, (0, 1): 1, (0, 2): 1, (0, 6): 2, (0, 8): 3, (0, 10): 1, (0, 11): 1, (0, 12): 1, (2, 5): 7, (3, 4): 4, (3, 9): 5, (4, 7): 6, (5, 6): 1, (6, 3): 7, (7, 3): 1}),
  trans_key_to_states=DictType[int64,ListType[int64]]<iv=None>({0: [0], 1: [0], 2: [0], 6: [0, 5], 8: [0], 10: [0], 11: [0], 12: [0], 5: [2], 4: [3], 9: [3], 7: [4], 3: [6, 7]}),
  alphabet_anything_value=13,
  alphabet_symbol_mapping=DictType[unicode_type,int64]<iv=None>({2: 1, 1: 0, o: 6, 3: 2, r: 7, 一: 10, n: 5, w: 9, e: 3, h: 4, 三: 11, 二: 12, t: 8})
)
```
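The second tweak above keeps the (symbol, transition key) pairs in an ordinary Python list of `str` keys instead of a numba typed container. A minimal pure-Python sketch of that step follows; `ANYTHING_ELSE` is a hypothetical sentinel standing in for interegular's `anything_else`, and the dict plays the role of `self.alphabet._symbol_mapping` (the value 13 mirrors `alphabet_anything_value=13` in the output above).

```python
# Hypothetical sentinel standing in for interegular's `anything_else`.
ANYTHING_ELSE = object()


def symbol_mapping_items(symbol_mapping, anything_else=ANYTHING_ELSE):
    """Return (symbol, transition_key) pairs as a plain Python list,
    skipping the catch-all `anything_else` entry (same filter as the fix)."""
    return [(k, v) for k, v in symbol_mapping.items() if k != anything_else]


symbol_mapping = {"一": 10, "二": 12, "三": 11, ANYTHING_ELSE: 13}
print(symbol_mapping_items(symbol_mapping))  # [('一', 10), ('二', 12), ('三', 11)]
```

Because the keys stay Python `str` objects, no conversion to a fixed-width character type happens at this step.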

Error message:

```
Traceback (most recent call last):
  File "...\debug_keyerror.py", line 9, in <module>
    print(fsm_info)
  File "...\lib\collections\__init__.py", line 441, in __repr__
    return self.__class__.__name__ + repr_fmt % self
  File "...\lib\site-packages\numba\typed\typeddict.py", line 217, in __repr__
    body = str(self)
  File "...\lib\site-packages\numba\typed\typeddict.py", line 212, in __str__
    for k, v in self.items():
  File "...\lib\_collections_abc.py", line 911, in __iter__
    yield (key, self._mapping[key])
  File "...\lib\site-packages\numba\typed\typeddict.py", line 180, in __getitem__
    return _getitem(self, key)
  File "...\lib\site-packages\numba\typed\dictobject.py", line 783, in impl
    raise KeyError()
KeyError
```
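The traceback shows why merely printing `fsm_info` fails: the typed dict's `__str__` walks `self.items()`, and the `collections.abc` items view looks up every key it iterates. If a key produced by iteration cannot be found again (here, "一" reported as ""), that lookup raises `KeyError`. The pure-Python sketch below models only this failure mode; `LossyKeyMapping` is hypothetical and does not reproduce numba's internals.

```python
from collections.abc import Mapping


class LossyKeyMapping(Mapping):
    """Hypothetical stand-in for the corrupted typed dict: lookups use the
    stored keys, but iteration reports some keys in a mangled form."""

    def __init__(self, data, mangle):
        self._data = dict(data)
        self._mangle = mangle

    def __getitem__(self, key):
        return self._data[key]

    def __iter__(self):
        return (self._mangle(key) for key in self._data)

    def __len__(self):
        return len(self._data)


# Mangle "一" into "" the way the second listing above shows.
mapping = LossyKeyMapping({"二": 12, "一": 10}, lambda k: "" if k == "一" else k)

try:
    # items() iterates the keys and then looks each one up, as __str__ does.
    print(list(mapping.items()))
except KeyError as exc:
    print("KeyError:", exc)
```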

Outlines/Python version information:

Version information

```
> python -c "from outlines import _version; print(_version.version)"
0.0.40

> python -c "import sys; print('Python', sys.version)"
Python 3.10.12 | packaged by conda-forge | (main, Jun 23 2023, 22:34:57) [MSC v.1936 64 bit (AMD64)]

> pip freeze
annotated-types==0.6.0
anyio==4.3.0
attrs==23.2.0
boltons @ file:///home/conda/feedstock_root/build_artifacts/boltons_1677499911949/work
Brotli @ file:///D:/bld/brotli-split_1693583621767/work
build==1.2.1
CacheControl==0.14.0
cached-property @ file:///home/conda/feedstock_root/build_artifacts/cached_property_1615209429212/work
certifi @ file:///home/conda/feedstock_root/build_artifacts/certifi_1707022139797/work/certifi
cffi @ file:///D:/bld/cffi_1671179506518/work
charset-normalizer @ file:///home/conda/feedstock_root/build_artifacts/charset-normalizer_1688813409104/work
cleo==2.1.0
cloudpickle==3.0.0
colorama @ file:///home/conda/feedstock_root/build_artifacts/colorama_1666700638685/work
conda==23.3.1
conda-libmamba-solver @ file:///home/conda/feedstock_root/build_artifacts/conda-libmamba-solver_1680508672016/work/src
conda-package-handling @ file:///home/conda/feedstock_root/build_artifacts/conda-package-handling_1691048088238/work
conda_package_streaming @ file:///home/conda/feedstock_root/build_artifacts/conda-package-streaming_1691009212940/work
crashtest==0.4.1
cryptography @ file:///D:/bld/cryptography-split_1691444290667/work
diskcache==5.6.3
distlib==0.3.8
distro==1.9.0
docopt==0.6.2
dulwich==0.21.7
exceptiongroup==1.2.0
fastjsonschema==2.19.1
filelock==3.13.1
fsspec==2024.2.0
h11==0.14.0
h5py @ file:///D:/bld/h5py_1702471423597/work
httpcore==1.0.5
httpx==0.27.0
huggingface-hub==0.21.3
idna @ file:///home/conda/feedstock_root/build_artifacts/idna_1663625384323/work
importlib_metadata==7.1.0
inquirerpy==0.3.4
installer==0.7.0
interegular==0.3.3
jaraco.classes==3.4.0
Jinja2 @ file:///home/conda/feedstock_root/build_artifacts/jinja2_1704966972576/work
joblib==1.4.0
jsonpatch @ file:///home/conda/feedstock_root/build_artifacts/jsonpatch_1632759296524/work
jsonpointer==2.0
jsonschema==4.21.1
jsonschema-specifications==2023.12.1
keyring==24.3.1
lark==1.1.9
libmambapy @ file:///D:/bld/mamba-split_1680791188848/work/libmambapy
llvmlite==0.42.0
mamba @ file:///D:/bld/mamba-split_1680791188848/work/mamba
MarkupSafe @ file:///D:/bld/markupsafe_1706900062361/work
menuinst @ file:///D:/bld/menuinst_1666839998718/work
more-itertools==10.2.0
mpmath @ file:///home/conda/feedstock_root/build_artifacts/mpmath_1678228039184/work
msgpack==1.0.8
mypy==1.9.0
mypy-extensions==1.0.0
nest-asyncio==1.6.0
networkx @ file:///home/conda/feedstock_root/build_artifacts/networkx_1698504735452/work
numba==0.59.1
numpy @ file:///D:/bld/numpy_1707225570061/work/dist/numpy-1.26.4-cp310-cp310-win_amd64.whl#sha256=6761da75b1528684e6bf4dabdbdded9d1eb4d0e9b299482c7ce152cfb3155106
openai==1.21.2
outlines==0.0.40
packaging @ file:///home/conda/feedstock_root/build_artifacts/packaging_1681337016113/work
parse @ file:///home/conda/feedstock_root/build_artifacts/parse_1706516706584/work
pexpect==4.9.0
pfzy==0.3.4
pipreqs==0.4.13
pkginfo==1.10.0
platformdirs==4.2.0
pluggy @ file:///home/conda/feedstock_root/build_artifacts/pluggy_1693086607691/work
poetry==1.8.2
poetry-core==1.9.0
poetry-plugin-export==1.7.1
prompt-toolkit==3.0.43
ptyprocess==0.7.0
pycosat @ file:///D:/bld/pycosat_1666836675990/work
pycparser @ file:///home/conda/feedstock_root/build_artifacts/pycparser_1636257122734/work
pydantic==2.7.0
pydantic_core==2.18.1
pyOpenSSL @ file:///home/conda/feedstock_root/build_artifacts/pyopenssl_1685514481738/work
pyproject_hooks==1.0.0
PySocks @ file:///D:/bld/pysocks_1661604991356/work
pywin32-ctypes==0.2.2
PyYAML @ file:///D:/bld/pyyaml_1695373629531/work
rapidfuzz==3.8.1
referencing==0.34.0
regex @ file:///D:/bld/regex_1703393598862/work
requests @ file:///home/conda/feedstock_root/build_artifacts/requests_1684774241324/work
requests-toolbelt==1.0.0
rpds-py==0.18.0
ruamel.yaml @ file:///D:/bld/ruamel.yaml_1686994025923/work
ruamel.yaml.clib @ file:///D:/bld/ruamel.yaml.clib_1670412994006/work
safetensors==0.4.3
scipy==1.13.0
sglang==0.1.14
shellingham==1.5.4
sniffio==1.3.1
sympy @ file:///home/conda/feedstock_root/build_artifacts/sympy_1684180539862/work
tokenizers==0.19.1
tomli==2.0.1
tomlkit==0.12.4
toolz @ file:///home/conda/feedstock_root/build_artifacts/toolz_1657485559105/work
torch==2.2.2
tqdm @ file:///home/conda/feedstock_root/build_artifacts/tqdm_1691671248568/work
transformers==4.40.0
trove-classifiers==2024.4.10
typing_extensions @ file:///home/conda/feedstock_root/build_artifacts/typing_extensions_1708904622550/work
urllib3 @ file:///home/conda/feedstock_root/build_artifacts/urllib3_1689789803562/work
virtualenv==20.25.3
wcwidth==0.2.13
win-inet-pton @ file:///D:/bld/win_inet_pton_1667051142467/work
yarg==0.1.9
zipp==3.18.1
zstandard==0.19.0

```


### Context for the issue:

I'm not sure why only the Chinese character "一" breaks everything, while other Chinese characters work fine as far as I can tell.
@m0g1cian m0g1cian added the bug label Apr 23, 2024
@m0g1cian m0g1cian changed the title KeyError in BetterFSM::FSMInfo when input FSM alphabet contains a specific Chinese character KeyError in BetterFSM::FSMInfo when input FSM alphabet contains UTF-8 characters that starts with \xb8\x80 Apr 23, 2024
@m0g1cian m0g1cian changed the title KeyError in BetterFSM::FSMInfo when input FSM alphabet contains UTF-8 characters that starts with \xb8\x80 KeyError in BetterFSM::FSMInfo when input FSM alphabet contains UTF-8 characters that ends with \xb8\x80 Apr 24, 2024
lapp0 (Contributor) commented May 17, 2024

@m0g1cian opened an upstream issue: numba/numba#9542

Per the thread, this appears to be an upstream bug in numba: UnicodeCharSeq has trouble handling a leading null byte (\x00).
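To see what a leading null byte looks like for the affected character: NumPy's `<U` dtype stores each character as a 4-byte little-endian UTF-32 code unit, and for "一" (U+4E00) the lowest byte is 0x00, so its encoding starts with `\x00`; for "二" and "三" it does not. This sketch only inspects Python encodings and does not call numba; `leading_byte` is a hypothetical helper, and it is an assumption, not confirmed here, that this byte layout is exactly what trips UnicodeCharSeq.

```python
def leading_byte(symbol: str, encoding: str = "utf-32-le") -> int:
    """Return the first byte of `symbol` in a fixed-width little-endian encoding."""
    return symbol.encode(encoding)[0]


for ch in ["一", "二", "三"]:
    # Code point, raw little-endian UTF-32 bytes, and whether the first byte is 0x00.
    print(ch, f"U+{ord(ch):04X}", ch.encode("utf-32-le"), leading_byte(ch) == 0)
```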

There are a few options here:

```python
import numba
import numpy as np
from numba.cpython.charseq import unicode_charseq_get_code

@numba.njit
def function():
    s = np.empty(3, dtype="<U1")
    s[0] = "  ^`"
    s[1] = "  ^l"
    s[2] = "  ^i"
    return [unicode_charseq_get_code(item, 0) for item in s]

result = function()
print(result)
```

Output: [19968, 20108, 32]

M0gician commented:

I made a local patch to fix this issue in outlines. It makes numba's typed Dict and List always use unicode_type rather than UnicodeCharSeq.

I'll make a PR soon.

Project status: Todo · 3 participants