Permalink
Browse files

Parameters are now more forgiving about type

  • Loading branch information...
1 parent 67904f6 commit 36b31368286db3d86111d2fb8b7c9dff0639e611 @systay systay committed Sep 17, 2011
@@ -129,15 +129,7 @@ class ExecutionEngine(graph: GraphDatabaseService) {
indexHits.asScala
})
- case NodeById(varName, id) => new StartPipe(lastPipe, varName, m => {
- id(m) match {
- case x: Traversable[Long] => x.map(graph.getNodeById).toSeq
- case x => {
- println(x)
- throw new Exception("Wut?")
- }
- }
- })
+ case NodeById(varName, id) => new StartPipe(lastPipe, varName, m => makeLongSeq(id(m), varName).map(graph.getNodeById))
case RelationshipById(varName, ids@_*) => new StartPipe(lastPipe, varName, ids.map(graph.getRelationshipById))
}
@@ -148,4 +140,30 @@ class ExecutionEngine(graph: GraphDatabaseService) {
util.Properties.versionString)
}
}
+
+ private def makeLongSeq(result: Any, name: String): Seq[Long] = {
+ if (result.isInstanceOf[Int]) {
+ return Seq(result.asInstanceOf[Int].toLong)
+ }
+
+ if (result.isInstanceOf[Long]) {
+ return Seq(result.asInstanceOf[Long])
+ }
+
+ def makeLong(x: Any): Long = x match {
+ case i: Int => i.toLong
+ case i: Long => i
+ case i: String => i.toLong
+ }
+
+ if (result.isInstanceOf[java.lang.Iterable[_]]) {
+ return result.asInstanceOf[java.lang.Iterable[_]].asScala.map(makeLong).toSeq
+ }
+
+ if (result.isInstanceOf[Traversable[_]]) {
+ return result.asInstanceOf[Traversable[_]].map(makeLong).toSeq
+ }
+
+ throw new ParameterNotFoundException("Expected " + name + " to be a Long, or an Iterable of Long. It was '" + result + "'")
+ }
}
@@ -24,7 +24,7 @@ import scala.collection.JavaConverters._
import org.neo4j.graphdb._
import org.neo4j.cypher.{ParameterNotFoundException, SyntaxException, SymbolTable}
-abstract sealed class Value extends (Map[String,Any]=>Any) {
+abstract sealed class Value extends (Map[String, Any] => Any) {
def identifier: Identifier
def checkAvailable(symbols: SymbolTable)
@@ -38,21 +38,20 @@ case class Literal(v: Any) extends Value {
def checkAvailable(symbols: SymbolTable) {}
}
-abstract case class FunctionValue(functionName : String, arguments: Value*) extends Value {
+abstract case class FunctionValue(functionName: String, arguments: Value*) extends Value {
- def identifier: Identifier = ValueIdentifier(functionName +"(" + arguments.map(_.identifier.name).mkString(",")+")");
+ def identifier: Identifier = ValueIdentifier(functionName + "(" + arguments.map(_.identifier.name).mkString(",") + ")");
def checkAvailable(symbols: SymbolTable) {
- arguments.foreach( _.checkAvailable(symbols))
+ arguments.foreach(_.checkAvailable(symbols))
}
}
-
abstract class AggregationValue(functionName: String, inner: Value) extends Value {
def apply(m: Map[String, Any]) = m(identifier.name)
- def identifier: Identifier = AggregationIdentifier(functionName+"("+inner.identifier.name+")")
+ def identifier: Identifier = AggregationIdentifier(functionName + "(" + inner.identifier.name + ")")
def checkAvailable(symbols: SymbolTable) {
inner.checkAvailable(symbols)
@@ -61,23 +60,23 @@ abstract class AggregationValue(functionName: String, inner: Value) extends Valu
def createAggregationFunction: AggregationFunction
}
-case class Count(anInner: Value) extends AggregationValue("count",anInner) {
+case class Count(anInner: Value) extends AggregationValue("count", anInner) {
def createAggregationFunction = new CountFunction(anInner)
}
-case class Sum(anInner: Value) extends AggregationValue("sum",anInner) {
+case class Sum(anInner: Value) extends AggregationValue("sum", anInner) {
def createAggregationFunction = new SumFunction(anInner)
}
-case class Min(anInner: Value) extends AggregationValue("min",anInner) {
+case class Min(anInner: Value) extends AggregationValue("min", anInner) {
def createAggregationFunction = new MinFunction(anInner)
}
-case class Max(anInner: Value) extends AggregationValue("max",anInner) {
+case class Max(anInner: Value) extends AggregationValue("max", anInner) {
def createAggregationFunction = new MaxFunction(anInner)
}
-case class Avg(anInner: Value) extends AggregationValue("avg",anInner) {
+case class Avg(anInner: Value) extends AggregationValue("avg", anInner) {
def createAggregationFunction = new AvgFunction(anInner)
}
@@ -104,55 +103,55 @@ case class PropertyValue(entity: String, property: String) extends Value {
}
}
-case class RelationshipTypeValue(relationship: Value) extends FunctionValue("TYPE",relationship) {
+case class RelationshipTypeValue(relationship: Value) extends FunctionValue("TYPE", relationship) {
def apply(m: Map[String, Any]): Any = relationship(m).asInstanceOf[Relationship].getType.name()
override def checkAvailable(symbols: SymbolTable) {
symbols.assertHas(RelationshipIdentifier(relationship.identifier.name))
}
}
-case class ArrayLengthValue(inner: Value) extends FunctionValue("LENGTH",inner) {
+case class ArrayLengthValue(inner: Value) extends FunctionValue("LENGTH", inner) {
def apply(m: Map[String, Any]): Any = inner(m) match {
- case path:Path => path.length()
+ case path: Path => path.length()
case x => throw new SyntaxException("Expected " + inner.identifier.name + " to be an iterable, but it is not.")
}
}
-case class IdValue(inner: Value) extends FunctionValue("ID",inner) {
+case class IdValue(inner: Value) extends FunctionValue("ID", inner) {
def apply(m: Map[String, Any]): Any = inner(m) match {
- case node:Node => node.getId
- case rel:Relationship => rel.getId
+ case node: Node => node.getId
+ case rel: Relationship => rel.getId
case x => throw new SyntaxException("Expected " + inner.identifier.name + " to be a node or relationship.")
}
}
-case class PathNodesValue(path: EntityValue) extends FunctionValue("NODES",path) {
+case class PathNodesValue(path: EntityValue) extends FunctionValue("NODES", path) {
def apply(m: Map[String, Any]): Any = path(m) match {
- case p : Path => p.nodes().asScala.toSeq
+ case p: Path => p.nodes().asScala.toSeq
case x => throw new SyntaxException("Expected " + path.identifier.name + " to be a path.")
}
}
-case class PathRelationshipsValue(path: EntityValue) extends FunctionValue("RELATIONSHIPS",path) {
+case class PathRelationshipsValue(path: EntityValue) extends FunctionValue("RELATIONSHIPS", path) {
def apply(m: Map[String, Any]): Any = path(m) match {
- case p : Path => p.relationships().asScala.toSeq
+ case p: Path => p.relationships().asScala.toSeq
case x => throw new SyntaxException("Expected " + path.identifier.name + " to be a path.")
}
}
-case class EntityValue(entityName:String) extends Value {
+case class EntityValue(entityName: String) extends Value {
def apply(m: Map[String, Any]): Any = m.getOrElse(entityName, throw new NotFoundException)
def identifier: Identifier = Identifier(entityName)
def checkAvailable(symbols: SymbolTable) {
- symbols.assertHas(Identifier(entityName))
+ symbols.assertHas(Identifier(entityName))
}
}
-case class ParameterValue(parameterName:String) extends Value {
+case class ParameterValue(parameterName: String) extends Value {
def apply(m: Map[String, Any]): Any = m.getOrElse(parameterName, throw new ParameterNotFoundException("Expected a parameter named " + parameterName))
def identifier: Identifier = Identifier(parameterName)
@@ -19,17 +19,15 @@ package org.neo4j.cypher.parser
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/
+
import org.neo4j.cypher.commands._
import scala.util.parsing.combinator._
+
trait StartClause extends JavaTokenParsers with Tokens {
- def start: Parser[Start] = ignoreCase("start") ~> rep1sep(nodeByParam | nodeByIds | nodeByIndex | nodeByIndexQuery | relsByIds | relsByIndex , ",") ^^ (Start(_: _*))
-
- def nodeByParam = identity ~ "=" ~ "(" ~ "::" ~ identity ~ ")" ^^ {
- case varName ~ "=" ~ "(" ~ "::" ~ paramName ~ ")" => {
- println(varName)
- println(paramName)
- NodeById(varName, ParameterValue(paramName))
- }
+ def start: Parser[Start] = ignoreCase("start") ~> rep1sep(nodeByParam | nodeByIds | nodeByIndex | nodeByIndexQuery | relsByIds | relsByIndex, ",") ^^ (Start(_: _*))
+
+ def nodeByParam = identity ~ "=" ~ "(" ~ "::" ~ identity ~ ")" ^^ {
+ case varName ~ "=" ~ "(" ~ "::" ~ paramName ~ ")" => NodeById(varName, ParameterValue(paramName))
}
@@ -20,18 +20,13 @@
package org.neo4j.cypher.pipes
import org.neo4j.cypher.SymbolTable
-import org.neo4j.graphdb.{Relationship, Node, PropertyContainer}
-import org.neo4j.cypher.commands.{Identifier, RelationshipIdentifier, NodeIdentifier}
+import org.neo4j.graphdb.PropertyContainer
+import org.neo4j.cypher.commands.{Identifier, NodeIdentifier}
class StartPipe[T <: PropertyContainer](inner: Pipe, name: String, createSource: Map[String,Any] => Iterable[T]) extends Pipe {
def this(inner: Pipe, name: String, sourceIterable: Iterable[T]) = this(inner, name, m => sourceIterable)
-
val symbolType: Identifier =NodeIdentifier(name)
-// source match {
-// case nodes: Iterable[Node] => NodeIdentifier(name)
-// case rels: Iterable[Relationship] => RelationshipIdentifier(name)
-// }
val symbols: SymbolTable = inner.symbols.add(Seq(symbolType))
@@ -715,16 +715,37 @@ class ExecutionEngineTest extends ExecutionEngineHelper {
), result.columnAs[Path]("p").toList)
}
- @Test def shouldBeAbleToTakeParams() {
+ @Test def shouldBeAbleToTakeParamsInDifferentTypes() {
+ createNodes("A", "B", "C", "D", "E")
+
+ val query = Query.
+ start(
+ NodeById("pA", ParameterValue("a")),
+ NodeById("pB", ParameterValue("b")),
+ NodeById("pC", ParameterValue("c")),
+ NodeById("pD", ParameterValue("d")),
+ NodeById("pE", ParameterValue("e"))).
+ returns(ValueReturnItem(EntityValue("pA")), ValueReturnItem(EntityValue("pB")), ValueReturnItem(EntityValue("pC")), ValueReturnItem(EntityValue("pD")), ValueReturnItem(EntityValue("pE")))
+
+ val result = execute(query,
+ "a" -> Seq[Long](1),
+ "b" -> 2,
+ "c" -> Seq(3L).asJava,
+ "d" -> Seq(4).asJava,
+ "e" -> List(5)
+ )
+
+ assertEquals(1, result.toList.size)
+ }
+
+ @Test(expected = classOf[ParameterNotFoundException]) def parameterTypeErrorShouldBeNicelyExplained() {
createNodes("A")
val query = Query.
start(NodeById("pA", ParameterValue("a"))).
returns(ValueReturnItem(EntityValue("pA")))
- val result = execute(query, "a" -> Seq[Long](1))
-
- assertEquals(List(Map("pA" -> node("A"))), result.toList)
+ execute(query, "a" -> "Andres").toList
}
@Test def shouldBeAbleToTakeParamsFromParsedStuff() {

0 comments on commit 36b3136

Please sign in to comment.