Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 16 additions & 11 deletions src/graphforge/parser/cypher.lark
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,28 @@ script: query+
query: union_query | single_part_query | multi_part_query | with_query

// UNION queries - combines multiple query results
union_query: (single_part_query | multi_part_query) (union_clause (single_part_query | multi_part_query))+
union_query: (single_part_query | multi_part_query | with_query) (union_clause (single_part_query | multi_part_query | with_query))+

union_clause: "UNION"i "ALL"i -> union_all
| "UNION"i -> union_distinct

// Multi-part queries with WITH clause
// Allows: reading clauses, writing clauses, or both before WITH
multi_part_query: reading_or_writing_clauses+ with_clause+ final_query_part
// A multi_part_query is one or more segments ending in WITH, followed by a final part.
// Each segment may contain any mix of reading, writing, or updating clauses.
multi_part_query: multi_part_segment+ final_query_part

// Allow any combination of reading and writing clauses before WITH
reading_or_writing_clauses: reading_clause | writing_clause
// One segment: zero or more clauses of any type, ending with a WITH clause
multi_part_segment: segment_clause* with_clause

// Writing clauses that can precede WITH
writing_clause: create_clause
// Any clause that can appear before a WITH clause
segment_clause: reading_clause
| create_clause
| merge_clause
| set_clause
| remove_clause
| delete_clause

// Queries starting with WITH clause
// Queries starting with WITH clause (no preceding read/write clauses)
with_query: with_clause+ final_query_part

// Single-part queries (without WITH)
Expand Down Expand Up @@ -209,8 +214,8 @@ POW_OP: "^"

label_predicate: IDENTIFIER (":" IDENTIFIER)+

exists_expr: "EXISTS"i "{" single_part_query "}"
count_expr: "COUNT"i "{" single_part_query "}"
exists_expr: "EXISTS"i "{" query "}"
count_expr: "COUNT"i "{" query "}"

quantifier_expr: "ALL"i "(" variable "IN"i expression "WHERE"i expression ")" -> all_quantifier
| "ANY"i "(" variable "IN"i expression "WHERE"i expression ")" -> any_quantifier
Expand Down Expand Up @@ -294,7 +299,7 @@ IDENTIFIER: /[a-zA-Z_][a-zA-Z0-9_]*/
FUNCTION_NAME: /relationships|percentiledisc|percentilecont|localdatetime|substring|toboolean|tointeger|tofloat|tostring|toupper|tolower|truncate|datetime|localtime|duration|distance|dangerous|coalesce|collect|replace|isempty|length|minute|second|reverse|ltrim|rtrim|exists|count|month|lower|upper|point|nodes|split|stdevp|stdev|sqrt|trim|year|type|date|time|hour|tail|head|last|right|left|floor|round|range|rand|ceil|sign|pow|day|sum|avg|min|max|abs|size|labels|id/i

INT: /0[xX][0-9a-fA-F]+|0[oO][0-7]+|[0-9]+/
FLOAT: /[0-9]+\.[0-9]+/
FLOAT: /[0-9]+\.[0-9]+([eE][+-]?[0-9]+)?|\.[0-9]+([eE][+-]?[0-9]+)?|[0-9]+[eE][+-]?[0-9]+/
STRING: /"([^"\\]|\\.)*"/ | /'([^'\\]|\\.)*'/

// Whitespace
Expand Down
62 changes: 35 additions & 27 deletions src/graphforge/parser/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
"""

from functools import lru_cache
import math
from pathlib import Path
from typing import cast

Expand Down Expand Up @@ -134,27 +135,49 @@ def single_part_query(self, items):
def multi_part_query(self, items):
"""Transform multi-part query (with WITH clauses).

Structure: reading_or_writing_clauses+ with_clause+ final_query_part
Each reading_or_writing_clauses is a list of clauses (MATCH, WHERE, CREATE, MERGE)
Each with_clause is a WithClause
Structure: multi_part_segment+ final_query_part
Each multi_part_segment is a list of clauses ending with a WithClause
final_query_part is a CypherQuery
"""
# Flatten all clauses from reading/writing clauses, with clauses, and final query
# Flatten all clauses from segments and the final query part
all_clauses = []

for item in items:
if isinstance(item, list):
# reading_or_writing_clauses returns a list of clauses
# multi_part_segment returns a list of clauses
all_clauses.extend(item)
elif isinstance(item, WithClause):
# with_clause returns a single WithClause
all_clauses.append(item)
elif isinstance(item, CypherQuery):
# final_query_part returns a CypherQuery
all_clauses.extend(item.clauses)

return CypherQuery(clauses=all_clauses)

def multi_part_segment(self, items):
"""Transform one segment of a multi-part query.

Structure: segment_clause* with_clause
Returns a flat list of clauses (segment clauses + the WITH clause).
"""
all_clauses = []
for item in items:
if isinstance(item, list):
all_clauses.extend(item)
elif isinstance(item, WithClause):
all_clauses.append(item)
else:
all_clauses.append(item)
return all_clauses

def segment_clause(self, items):
"""Transform a single clause that can appear before WITH.

Returns a list for consistent flattening in multi_part_segment.
"""
item = items[0]
if isinstance(item, list):
return item
return [item]

def with_query(self, items):
"""Transform query starting with WITH clause (Issue #172).

Expand Down Expand Up @@ -182,24 +205,6 @@ def reading_clause(self, items):
"""
return list(items)

def writing_clause(self, items):
"""Transform writing clause (CREATE/MERGE).

Returns the clause directly (single clause, no WHERE allowed).
"""
return items[0]

def reading_or_writing_clauses(self, items):
"""Transform reading or writing clauses for multi_part_query.

Returns a list of clauses for flattening.
"""
# reading_clause returns a list, writing_clause returns a single clause
if isinstance(items[0], list):
return items[0]
else:
return [items[0]]

def final_query_part(self, items):
"""Transform final part of multi-part query.

Expand Down Expand Up @@ -1032,7 +1037,10 @@ def int_literal(self, items):

def float_literal(self, items):
"""Transform float literal."""
return Literal(value=float(items[0]))
val = float(items[0])
if math.isinf(val):
raise ValueError(f"Floating point overflow: {items[0]}")
return Literal(value=val)

def string_literal(self, items):
"""Transform string literal."""
Expand Down
108 changes: 108 additions & 0 deletions tests/integration/test_with_clause_positions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
"""Integration tests for WITH clause position support (Issue #257).

Verify that multi-clause queries with WITH after SET, REMOVE, DELETE, and in
chained multi-part queries produce correct results end-to-end.
"""

from graphforge import GraphForge


class TestWithClausePositionExecution:
"""Execution-level tests — verify multi-WITH patterns produce correct results."""

def test_match_set_with_return(self):
"""MATCH ... SET ... WITH ... RETURN propagates updated property."""
gf = GraphForge()
gf.execute("CREATE (:T {x: 0})")
result = gf.execute("MATCH (n:T) SET n.x = 99 WITH n RETURN n.x AS val")
assert len(result) == 1
assert result[0]["val"].value == 99

def test_match_with_set_with_return(self):
"""MATCH ... WITH ... SET ... WITH ... RETURN: two WITH segments."""
gf = GraphForge()
gf.execute("CREATE (:T {x: 0, y: 0})")
result = gf.execute(
"MATCH (n:T) WITH n SET n.x = 1 WITH n SET n.y = 2 WITH n RETURN n.x AS x, n.y AS y"
)
assert len(result) == 1
assert result[0]["x"].value == 1
assert result[0]["y"].value == 2

def test_match_with_match_with_return(self):
"""MATCH ... WITH ... MATCH ... WITH ... RETURN: chained multi-part."""
gf = GraphForge()
gf.execute("CREATE (:A {v: 1})")
gf.execute("CREATE (:B {v: 2})")
result = gf.execute("MATCH (a:A) WITH a MATCH (b:B) WITH a, b RETURN a.v AS av, b.v AS bv")
assert len(result) == 1
assert result[0]["av"].value == 1
assert result[0]["bv"].value == 2

def test_three_segment_chain(self):
"""Three-segment multi-part query works correctly."""
gf = GraphForge()
gf.execute("CREATE (:A)-[:R]->(:B)-[:R]->(:C)")
result = gf.execute(
"MATCH (a:A) WITH a "
"MATCH (a)-[:R]->(b:B) WITH a, b "
"MATCH (b)-[:R]->(c:C) WITH a, b, c "
"RETURN a, b, c"
)
assert len(result) == 1

def test_create_with_set_with_return(self):
"""CREATE ... WITH ... SET ... WITH ... RETURN creates and updates node."""
gf = GraphForge()
result = gf.execute("CREATE (n:T {x: 0}) WITH n SET n.x = 42 WITH n RETURN n.x AS val")
assert len(result) == 1
assert result[0]["val"].value == 42

def test_set_with_produces_correct_row_count(self):
"""SET ... WITH produces one row per matched node."""
gf = GraphForge()
gf.execute("CREATE (:T {v: 1})")
gf.execute("CREATE (:T {v: 2})")
gf.execute("CREATE (:T {v: 3})")
result = gf.execute("MATCH (n:T) SET n.updated = true WITH n RETURN n.v AS v ORDER BY v")
assert len(result) == 3
assert [r["v"].value for r in result] == [1, 2, 3]

def test_match_remove_with_return(self):
"""MATCH ... REMOVE ... WITH ... RETURN: label removed, node still returned."""
gf = GraphForge()
gf.execute("CREATE (:A:B {x: 1})")
result = gf.execute("MATCH (n:A:B) REMOVE n:B WITH n RETURN n.x AS x")
assert len(result) == 1
assert result[0]["x"].value == 1
# Verify label was removed
check = gf.execute("MATCH (n:B) RETURN n")
assert len(check) == 0

def test_existing_with_at_start_still_works(self):
"""Regression: WITH at query start (issue #172) still works."""
gf = GraphForge()
result = gf.execute("WITH 1 AS x, 2 AS y RETURN x + y AS sum")
assert len(result) == 1
assert result[0]["sum"].value == 3

def test_existing_match_with_return_still_works(self):
"""Regression: simple MATCH ... WITH ... RETURN still works."""
gf = GraphForge()
gf.execute("CREATE (:Person {name: 'Alice'})")
result = gf.execute("MATCH (n:Person) WITH n RETURN n.name AS name")
assert len(result) == 1
assert result[0]["name"].value == "Alice"

def test_with_where_filters_in_segment(self):
"""WITH clause with WHERE correctly filters rows in multi-segment query."""
gf = GraphForge()
gf.execute("CREATE (:T {v: 1})")
gf.execute("CREATE (:T {v: 5})")
gf.execute("CREATE (:T {v: 10})")
result = gf.execute(
"MATCH (n:T) WITH n WHERE n.v > 3 MATCH (m:T) WHERE m.v = n.v RETURN m.v AS val"
)
# Only nodes with v > 3 pass the WITH WHERE filter
vals = sorted(r["val"].value for r in result)
assert vals == [5, 10]
41 changes: 41 additions & 0 deletions tests/unit/parser/test_with_clause_positions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""Unit tests for WITH clause position support (Issue #257).

Tests that WITH is accepted after all valid clause combinations in multi-clause
queries, including after SET, REMOVE, DELETE, and in chained multi-part queries.
"""

import pytest

from graphforge.parser.parser import parse_cypher

_WITH_POSITION_CASES = [
("MATCH (n) SET n.x = 1 WITH n RETURN n", "MATCH SET WITH RETURN"),
("MATCH (n:T) REMOVE n:T WITH n RETURN n", "MATCH REMOVE WITH RETURN"),
("MATCH (n:T) DELETE n WITH 1 AS done RETURN done", "MATCH DELETE WITH RETURN"),
("MATCH (n) WITH n SET n.x = 1 WITH n RETURN n", "two WITH segments"),
("MATCH (n:T) WITH n REMOVE n:T WITH n RETURN n", "WITH REMOVE WITH"),
("MATCH (a) WITH a MATCH (b) WITH a, b RETURN a, b", "triple part"),
(
"MATCH (a) WITH a MATCH (b) WITH a, b MATCH (c) WITH a, b, c RETURN a, b, c",
"three WITH segments",
),
("CREATE (n:T) WITH n SET n.x = 1 WITH n RETURN n.x AS val", "CREATE WITH SET WITH"),
("MATCH (n) SET n.x = 1 RETURN n", "SET without subsequent WITH"),
("MATCH (n) WITH n WITH n AS m RETURN m", "consecutive WITH clauses"),
(
"MATCH (n) WHERE EXISTS { MATCH (n)-[]->(m) WITH m RETURN m } RETURN n",
"EXISTS subquery with WITH inside",
),
("MATCH (n) WHERE n.x > 1 WITH n RETURN n", "WHERE WITH regression"),
("UNWIND [1, 2, 3] AS x WITH x MATCH (n {id: x}) RETURN n", "UNWIND WITH MATCH"),
]


class TestWithClausePositionParsing:
"""Parse-level tests — verify grammar accepts all WITH positions."""

@pytest.mark.parametrize("query,description", _WITH_POSITION_CASES)
def test_with_clause_positions(self, query: str, description: str) -> None:
"""Verify grammar accepts WITH in various positions."""
ast = parse_cypher(query)
assert ast is not None, f"Failed to parse: {description}"
Loading