Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 73 additions & 30 deletions llama_cpp/llama_grammar.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,9 +81,15 @@ def from_json_schema(
cls,
json_schema: str,
verbose: bool = True,
treat_optional_as_nullable: bool = False,
) -> "LlamaGrammar":
"""Convert a JSON schema to a Llama grammar."""
return cls.from_string(json_schema_to_gbnf(json_schema), verbose=verbose)
return cls.from_string(
json_schema_to_gbnf(
json_schema, treat_optional_as_nullable=treat_optional_as_nullable
),
verbose=verbose,
)

@classmethod
def from_file(cls, file: Union[str, Path], verbose: bool = True) -> "LlamaGrammar":
Expand Down Expand Up @@ -1392,14 +1398,14 @@ def print_grammar(file: TextIO, state: parse_state) -> None:
# Optional single space permitted between JSON tokens.
SPACE_RULE = '" "?'

# GBNF bodies for the JSON primitive types. The bodies deliberately carry no
# trailing `space`: SchemaConverter._add_rule appends " space" itself when it
# is called with with_space=True, so keeping it here too would duplicate it
# (and would put a stray `space` inside the "-or-null" alternation).
PRIMITIVE_RULES = {
    "boolean": '("true" | "false")',
    "number": '("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)?',
    "integer": '("-"? ([0-9] | [1-9] [0-9]*))',
    "string": r""" "\"" (
[^"\\] |
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
)* "\"" """,
    "null": '"null"',
}

# Matches any run of characters other than ASCII letters, digits and "-";
# _add_rule replaces each such run in a rule name with a single "-".
INVALID_RULE_CHARS_RE = re.compile(r"[^a-zA-Z0-9-]+")
Expand All @@ -1408,30 +1414,43 @@ def print_grammar(file: TextIO, state: parse_state) -> None:


class SchemaConverter:
def __init__(self, prop_order, treat_optional_as_nullable: bool = False):
    """Create a converter with only the shared `space` rule registered.

    Args:
        prop_order: Mapping from property name to sort position. `visit`
            reads it with `prop_order.get(name, len(prop_order))` when it
            orders the properties of an object schema.
        treat_optional_as_nullable: When true, `_add_rule` turns every rule
            added for a non-required schema into an "-or-null" rule that
            also accepts `null`.
    """
    self._prop_order = prop_order
    # Rule name -> rule body; format_grammar() emits them in insertion order.
    self._rules = {"space": SPACE_RULE}
    # Definition name -> schema; `visit` indexes it when resolving a `$ref`.
    self._defs: Dict[str, Any] = {}
    self._treat_optional_as_nullable = treat_optional_as_nullable

def _format_literal(self, literal: str):
    """Return `literal` as a double-quoted grammar literal.

    The value is first serialized with `json.dumps`; every match of
    GRAMMAR_LITERAL_ESCAPE_RE in that text is then replaced by its entry in
    GRAMMAR_LITERAL_ESCAPES, and the result is wrapped in double quotes.
    """

    def _escape(match):
        # Look up the replacement for the exact text that matched.
        return GRAMMAR_LITERAL_ESCAPES.get(match.group(0))

    encoded = json.dumps(literal)
    escaped: str = GRAMMAR_LITERAL_ESCAPE_RE.sub(_escape, encoded)
    return '"' + escaped + '"'

def _add_rule(
    self, name: str, rule: str, is_required: bool = True, with_space: bool = False
):
    """Register a rule and return the name it was stored under.

    Args:
        name: Requested rule name; every run of characters outside
            `[a-zA-Z0-9-]` is replaced with "-".
        rule: The rule body.
        is_required: Whether the schema this rule was built for is required.
            Only consulted when the converter was created with
            `treat_optional_as_nullable=True`; in that case a non-required
            rule gets the suffix "-or-null" and its body becomes
            `(<rule> | <null rule>)`.
        with_space: Append " space" to the finished body.

    Returns:
        The key used in `self._rules`. This is the sanitized name, unless
        that name is already bound to a different body, in which case the
        first unused `<name><i>` (i = 0, 1, 2, ...) is used instead.
    """
    esc_name = INVALID_RULE_CHARS_RE.sub("-", name)
    complete_rule = rule

    if self._treat_optional_as_nullable and not is_required:
        esc_name += "-or-null"
        complete_rule = f"({complete_rule} | {PRIMITIVE_RULES['null']})"

    # The trailing space goes outside the null alternation so that both
    # alternatives share it.
    if with_space:
        complete_rule += " space"

    # Compare against the finished body, not the raw `rule`, because the
    # finished body is what gets stored.
    if esc_name not in self._rules or self._rules[esc_name] == complete_rule:
        key = esc_name
    else:
        i = 0
        while f"{esc_name}{i}" in self._rules:
            i += 1
        key = f"{esc_name}{i}"

    self._rules[key] = complete_rule
    return key

def visit(self, schema: Dict[str, Any], name: str) -> str:
def visit(self, schema: Dict[str, Any], name: str, is_required: bool = True) -> str:
rule_name = name or "root"

if "$defs" in schema:
Expand All @@ -1448,14 +1467,16 @@ def visit(self, schema: Dict[str, Any], name: str) -> str:
)
)
)
return self._add_rule(rule_name, rule)
return self._add_rule(rule_name, rule, is_required, False)

elif "const" in schema:
return self._add_rule(rule_name, self._format_literal(schema["const"]))
return self._add_rule(
rule_name, self._format_literal(schema["const"]), is_required, False
)

elif "enum" in schema:
rule = " | ".join((self._format_literal(v) for v in schema["enum"]))
return self._add_rule(rule_name, rule)
return self._add_rule(rule_name, rule, is_required, False)

elif "$ref" in schema:
ref = schema["$ref"]
Expand All @@ -1465,56 +1486,78 @@ def visit(self, schema: Dict[str, Any], name: str) -> str:
def_schema = self._defs[def_name]
return self.visit(def_schema, f'{name}{"-" if name else ""}{def_name}')


schema_type: Optional[str] = schema.get("type") # type: ignore
schema_type: Optional[str] = schema.get("type") # type: ignore
assert isinstance(schema_type, str), f"Unrecognized schema: {schema}"

if schema_type == "object" and "properties" in schema:
# TODO: `required` keyword
prop_order = self._prop_order
prop_pairs = sorted(
schema["properties"].items(),
# sort by position in prop_order (if specified) then by key
key=lambda kv: (prop_order.get(kv[0], len(prop_order)), kv[0]),
)

rule = '"{" space'
rule = ""
previous_is_prop_required = False
for i, (prop_name, prop_schema) in enumerate(prop_pairs):
is_prop_required = (
"required" not in schema or prop_name in schema["required"]
)
prop_rule_name = self.visit(
prop_schema, f'{name}{"-" if name else ""}{prop_name}'
prop_schema,
f'{name}{"-" if name else ""}{prop_name}',
is_prop_required,
)
if i > 0:
rule += ' "," space'
rule += rf' {self._format_literal(prop_name)} space ":" space {prop_rule_name}'
rule += ' "}" space'
prop_rule = rf'{self._format_literal(prop_name)} space ":" space {prop_rule_name}'
if i == 0:
rule += prop_rule
else:
if self._treat_optional_as_nullable or (
previous_is_prop_required and is_prop_required
):
rule = f'{rule} "," space {prop_rule}'
elif previous_is_prop_required and not is_prop_required:
rule = f'{rule} ("," space {prop_rule})?'
elif not previous_is_prop_required and is_prop_required:
rule = f'({rule} "," space)? {prop_rule}'
elif not previous_is_prop_required and not is_prop_required:
rule = f'({rule} | {prop_rule} | {rule} "," space {prop_rule})'

return self._add_rule(rule_name, rule)
previous_is_prop_required |= is_prop_required

rule = '"{" space ' + rule + ' space "}"'

return self._add_rule(rule_name, rule, is_required, True)

elif schema_type == "array" and "items" in schema:
# TODO `prefixItems` keyword
item_rule_name = self.visit(
schema["items"], f'{name}{"-" if name else ""}item'
)
rule = (
f'"[" space ({item_rule_name} ("," space {item_rule_name})*)? "]" space'
)
return self._add_rule(rule_name, rule)
rule = f'"[" space ({item_rule_name} ("," space {item_rule_name})*)? "]"'
return self._add_rule(rule_name, rule, is_required, True)

else:
assert schema_type in PRIMITIVE_RULES, f"Unrecognized schema: {schema}"
return self._add_rule(
"root" if rule_name == "root" else schema_type,
PRIMITIVE_RULES[schema_type],
is_required,
True,
)

def format_grammar(self):
    """Render every registered rule as `name ::= body`, one rule per line.

    Rules appear in the iteration order of `self._rules`; the lines are
    joined with "\n" and no trailing newline is added.
    """
    lines = []
    for rule_name, rule_body in self._rules.items():
        lines.append(f"{rule_name} ::= {rule_body}")
    return "\n".join(lines)


def json_schema_to_gbnf(
    schema: str,
    prop_order: Optional[List[str]] = None,
    treat_optional_as_nullable: bool = False,
):
    """Convert a JSON schema document into GBNF grammar text.

    Args:
        schema: The JSON schema, serialized as a JSON string.
        prop_order: Property names in the order they should appear in
            generated object rules. Names that are not listed sort after
            the listed ones (SchemaConverter.visit breaks ties by key).
        treat_optional_as_nullable: Forwarded to SchemaConverter.

    Returns:
        The text produced by SchemaConverter.format_grammar() after visiting
        the parsed schema as the root rule.
    """
    parsed_schema = json.loads(schema)
    # SchemaConverter wants name -> position, not the ordered list itself.
    order_index = {name: idx for idx, name in enumerate(prop_order or [])}
    converter = SchemaConverter(order_index, treat_optional_as_nullable)
    converter.visit(parsed_schema, "")
    return converter.format_grammar()
88 changes: 84 additions & 4 deletions tests/test_llama_grammar.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class B(BaseModel):
"a": {"$ref": "#/$defs/A"},
"b": {"title": "B", "type": "integer"},
},
"required": ["a", "b"],
"required": ["a"],
"title": "B",
"type": "object",
}
Expand All @@ -51,9 +51,30 @@ class B(BaseModel):

assert grammar.grammar is not None

assert (
llama_cpp.llama_grammar.json_schema_to_gbnf(
json.dumps(schema), treat_optional_as_nullable=False
)
== r"""space ::= " "?
integer ::= ("-"? ([0-9] | [1-9] [0-9]*)) space
a-A ::= "{" space "\"a\"" space ":" space integer space "}" space
root ::= "{" space "\"a\"" space ":" space a-A ("," space "\"b\"" space ":" space integer)? space "}" space"""
)

assert (
llama_cpp.llama_grammar.json_schema_to_gbnf(
json.dumps(schema), treat_optional_as_nullable=True
)
== r"""space ::= " "?
integer ::= ("-"? ([0-9] | [1-9] [0-9]*)) space
a-A ::= "{" space "\"a\"" space ":" space integer space "}" space
integer-or-null ::= (("-"? ([0-9] | [1-9] [0-9]*)) | "null") space
root ::= "{" space "\"a\"" space ":" space a-A "," space "\"b\"" space ":" space integer-or-null space "}" space"""
)


def test_grammar_anyof():
sch = {
schema = {
"properties": {
"temperature": {
"description": "The temperature mentioned",
Expand All @@ -73,6 +94,65 @@ def test_grammar_anyof():
"type": "object",
}

grammar = llama_cpp.LlamaGrammar.from_json_schema(json.dumps(sch))
grammar = llama_cpp.LlamaGrammar.from_json_schema(json.dumps(schema))

assert grammar.grammar is not None

assert (
llama_cpp.llama_grammar.json_schema_to_gbnf(json.dumps(schema), None)
== r"""space ::= " "?
number ::= ("-"? ([0-9] | [1-9] [0-9]*)) ("." [0-9]+)? ([eE] [-+]? [0-9]+)? space
unit-0 ::= "\"celsius\"" | "\"fahrenheit\""
null ::= "null" space
unit ::= unit-0 | null
root ::= "{" space "\"temperature\"" space ":" space number "," space "\"unit\"" space ":" space unit space "}" space"""
)


def test_grammar_nested_object():
    """Check the GBNF generated for an object that nests another object.

    The outer object requires only "test"; the nested object declares an
    empty `required` list. Both values of `treat_optional_as_nullable` are
    exercised.
    """
    schema = {
        "type": "object",
        "properties": {
            "test": {"type": "string"},
            "nested": {
                "type": "object",
                "properties": {"other": {"type": "string"}},
                "required": [],
            },
        },
        "required": ["test"],
    }
    schema_json = json.dumps(schema)

    grammar = llama_cpp.LlamaGrammar.from_json_schema(schema_json)

    assert grammar.grammar is not None

    # NOTE(review): "other" is not listed as required, yet the expected
    # `nested` rule below contains it unconditionally -- confirm this is the
    # intended output for treat_optional_as_nullable=False.
    strict_gbnf = llama_cpp.llama_grammar.json_schema_to_gbnf(
        schema_json, treat_optional_as_nullable=False
    )
    assert (
        strict_gbnf
        == r"""space ::= " "?
string ::= "\"" (
[^"\\] |
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
)* "\"" space
nested ::= "{" space "\"other\"" space ":" space string space "}" space
root ::= "{" space ("\"nested\"" space ":" space nested "," space)? "\"test\"" space ":" space string space "}" space"""
    )

    assert grammar.grammar is not None

    # With the flag on, optional values become "-or-null" rules and every
    # key is emitted unconditionally.
    nullable_gbnf = llama_cpp.llama_grammar.json_schema_to_gbnf(
        schema_json, treat_optional_as_nullable=True
    )
    assert (
        nullable_gbnf
        == r"""space ::= " "?
string-or-null ::= ( "\"" (
[^"\\] |
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
)* "\"" | "null") space
nested-or-null ::= ("{" space "\"other\"" space ":" space string-or-null space "}" | "null") space
string ::= "\"" (
[^"\\] |
"\\" (["\\/bfnrt] | "u" [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F] [0-9a-fA-F])
)* "\"" space
root ::= "{" space "\"nested\"" space ":" space nested-or-null "," space "\"test\"" space ":" space string space "}" space"""
    )