Skip to content

Commit

Permalink
Parameters are now more forgiving about type
Browse files Browse the repository at this point in the history
  • Loading branch information
systay committed Sep 17, 2011
1 parent 67904f6 commit 36b3136
Show file tree
Hide file tree
Showing 5 changed files with 83 additions and 52 deletions.
36 changes: 27 additions & 9 deletions cypher/src/main/scala/org/neo4j/cypher/ExecutionEngine.scala
Expand Up @@ -129,15 +129,7 @@ class ExecutionEngine(graph: GraphDatabaseService) {
indexHits.asScala
})

case NodeById(varName, id) => new StartPipe(lastPipe, varName, m => {
id(m) match {
case x: Traversable[Long] => x.map(graph.getNodeById).toSeq
case x => {
println(x)
throw new Exception("Wut?")
}
}
})
case NodeById(varName, id) => new StartPipe(lastPipe, varName, m => makeLongSeq(id(m), varName).map(graph.getNodeById))
case RelationshipById(varName, ids@_*) => new StartPipe(lastPipe, varName, ids.map(graph.getRelationshipById))
}

Expand All @@ -148,4 +140,30 @@ class ExecutionEngine(graph: GraphDatabaseService) {
util.Properties.versionString)
}
}

private def makeLongSeq(result: Any, name: String): Seq[Long] = {
if (result.isInstanceOf[Int]) {
return Seq(result.asInstanceOf[Int].toLong)
}

if (result.isInstanceOf[Long]) {
return Seq(result.asInstanceOf[Long])
}

def makeLong(x: Any): Long = x match {
case i: Int => i.toLong
case i: Long => i
case i: String => i.toLong
}

if (result.isInstanceOf[java.lang.Iterable[_]]) {
return result.asInstanceOf[java.lang.Iterable[_]].asScala.map(makeLong).toSeq
}

if (result.isInstanceOf[Traversable[_]]) {
return result.asInstanceOf[Traversable[_]].map(makeLong).toSeq
}

throw new ParameterNotFoundException("Expected " + name + " to be a Long, or an Iterable of Long. It was '" + result + "'")
}
}
47 changes: 23 additions & 24 deletions cypher/src/main/scala/org/neo4j/cypher/commands/Value.scala
Expand Up @@ -24,7 +24,7 @@ import scala.collection.JavaConverters._
import org.neo4j.graphdb._
import org.neo4j.cypher.{ParameterNotFoundException, SyntaxException, SymbolTable}

abstract sealed class Value extends (Map[String,Any]=>Any) {
abstract sealed class Value extends (Map[String, Any] => Any) {
def identifier: Identifier

def checkAvailable(symbols: SymbolTable)
Expand All @@ -38,21 +38,20 @@ case class Literal(v: Any) extends Value {
def checkAvailable(symbols: SymbolTable) {}
}

abstract case class FunctionValue(functionName : String, arguments: Value*) extends Value {
abstract case class FunctionValue(functionName: String, arguments: Value*) extends Value {

def identifier: Identifier = ValueIdentifier(functionName +"(" + arguments.map(_.identifier.name).mkString(",")+")");
def identifier: Identifier = ValueIdentifier(functionName + "(" + arguments.map(_.identifier.name).mkString(",") + ")");

def checkAvailable(symbols: SymbolTable) {
arguments.foreach( _.checkAvailable(symbols))
arguments.foreach(_.checkAvailable(symbols))
}
}



abstract class AggregationValue(functionName: String, inner: Value) extends Value {
def apply(m: Map[String, Any]) = m(identifier.name)

def identifier: Identifier = AggregationIdentifier(functionName+"("+inner.identifier.name+")")
def identifier: Identifier = AggregationIdentifier(functionName + "(" + inner.identifier.name + ")")

def checkAvailable(symbols: SymbolTable) {
inner.checkAvailable(symbols)
Expand All @@ -61,23 +60,23 @@ abstract class AggregationValue(functionName: String, inner: Value) extends Valu
def createAggregationFunction: AggregationFunction
}

case class Count(anInner: Value) extends AggregationValue("count",anInner) {
case class Count(anInner: Value) extends AggregationValue("count", anInner) {
def createAggregationFunction = new CountFunction(anInner)
}

case class Sum(anInner: Value) extends AggregationValue("sum",anInner) {
case class Sum(anInner: Value) extends AggregationValue("sum", anInner) {
def createAggregationFunction = new SumFunction(anInner)
}

case class Min(anInner: Value) extends AggregationValue("min",anInner) {
case class Min(anInner: Value) extends AggregationValue("min", anInner) {
def createAggregationFunction = new MinFunction(anInner)
}

case class Max(anInner: Value) extends AggregationValue("max",anInner) {
case class Max(anInner: Value) extends AggregationValue("max", anInner) {
def createAggregationFunction = new MaxFunction(anInner)
}

case class Avg(anInner: Value) extends AggregationValue("avg",anInner) {
case class Avg(anInner: Value) extends AggregationValue("avg", anInner) {
def createAggregationFunction = new AvgFunction(anInner)
}

Expand All @@ -104,55 +103,55 @@ case class PropertyValue(entity: String, property: String) extends Value {
}
}

case class RelationshipTypeValue(relationship: Value) extends FunctionValue("TYPE",relationship) {
case class RelationshipTypeValue(relationship: Value) extends FunctionValue("TYPE", relationship) {
def apply(m: Map[String, Any]): Any = relationship(m).asInstanceOf[Relationship].getType.name()

override def checkAvailable(symbols: SymbolTable) {
symbols.assertHas(RelationshipIdentifier(relationship.identifier.name))
}
}

case class ArrayLengthValue(inner: Value) extends FunctionValue("LENGTH",inner) {
case class ArrayLengthValue(inner: Value) extends FunctionValue("LENGTH", inner) {
def apply(m: Map[String, Any]): Any = inner(m) match {
case path:Path => path.length()
case path: Path => path.length()
case x => throw new SyntaxException("Expected " + inner.identifier.name + " to be an iterable, but it is not.")
}
}


case class IdValue(inner: Value) extends FunctionValue("ID",inner) {
case class IdValue(inner: Value) extends FunctionValue("ID", inner) {
def apply(m: Map[String, Any]): Any = inner(m) match {
case node:Node => node.getId
case rel:Relationship => rel.getId
case node: Node => node.getId
case rel: Relationship => rel.getId
case x => throw new SyntaxException("Expected " + inner.identifier.name + " to be a node or relationship.")
}
}

case class PathNodesValue(path: EntityValue) extends FunctionValue("NODES",path) {
case class PathNodesValue(path: EntityValue) extends FunctionValue("NODES", path) {
def apply(m: Map[String, Any]): Any = path(m) match {
case p : Path => p.nodes().asScala.toSeq
case p: Path => p.nodes().asScala.toSeq
case x => throw new SyntaxException("Expected " + path.identifier.name + " to be a path.")
}
}

case class PathRelationshipsValue(path: EntityValue) extends FunctionValue("RELATIONSHIPS",path) {
case class PathRelationshipsValue(path: EntityValue) extends FunctionValue("RELATIONSHIPS", path) {
def apply(m: Map[String, Any]): Any = path(m) match {
case p : Path => p.relationships().asScala.toSeq
case p: Path => p.relationships().asScala.toSeq
case x => throw new SyntaxException("Expected " + path.identifier.name + " to be a path.")
}
}

case class EntityValue(entityName:String) extends Value {
case class EntityValue(entityName: String) extends Value {
def apply(m: Map[String, Any]): Any = m.getOrElse(entityName, throw new NotFoundException)

def identifier: Identifier = Identifier(entityName)

def checkAvailable(symbols: SymbolTable) {
symbols.assertHas(Identifier(entityName))
symbols.assertHas(Identifier(entityName))
}
}

case class ParameterValue(parameterName:String) extends Value {
case class ParameterValue(parameterName: String) extends Value {
def apply(m: Map[String, Any]): Any = m.getOrElse(parameterName, throw new ParameterNotFoundException("Expected a parameter named " + parameterName))

def identifier: Identifier = Identifier(parameterName)
Expand Down
14 changes: 6 additions & 8 deletions cypher/src/main/scala/org/neo4j/cypher/parser/StartClause.scala
Expand Up @@ -19,17 +19,15 @@ package org.neo4j.cypher.parser
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <http://www.gnu.org/licenses/>.
*/

import org.neo4j.cypher.commands._
import scala.util.parsing.combinator._

trait StartClause extends JavaTokenParsers with Tokens {
def start: Parser[Start] = ignoreCase("start") ~> rep1sep(nodeByParam | nodeByIds | nodeByIndex | nodeByIndexQuery | relsByIds | relsByIndex , ",") ^^ (Start(_: _*))

def nodeByParam = identity ~ "=" ~ "(" ~ "::" ~ identity ~ ")" ^^ {
case varName ~ "=" ~ "(" ~ "::" ~ paramName ~ ")" => {
println(varName)
println(paramName)
NodeById(varName, ParameterValue(paramName))
}
def start: Parser[Start] = ignoreCase("start") ~> rep1sep(nodeByParam | nodeByIds | nodeByIndex | nodeByIndexQuery | relsByIds | relsByIndex, ",") ^^ (Start(_: _*))

def nodeByParam = identity ~ "=" ~ "(" ~ "::" ~ identity ~ ")" ^^ {
case varName ~ "=" ~ "(" ~ "::" ~ paramName ~ ")" => NodeById(varName, ParameterValue(paramName))
}


Expand Down
9 changes: 2 additions & 7 deletions cypher/src/main/scala/org/neo4j/cypher/pipes/StartPipe.scala
Expand Up @@ -20,18 +20,13 @@
package org.neo4j.cypher.pipes

import org.neo4j.cypher.SymbolTable
import org.neo4j.graphdb.{Relationship, Node, PropertyContainer}
import org.neo4j.cypher.commands.{Identifier, RelationshipIdentifier, NodeIdentifier}
import org.neo4j.graphdb.PropertyContainer
import org.neo4j.cypher.commands.{Identifier, NodeIdentifier}

class StartPipe[T <: PropertyContainer](inner: Pipe, name: String, createSource: Map[String,Any] => Iterable[T]) extends Pipe {
def this(inner: Pipe, name: String, sourceIterable: Iterable[T]) = this(inner, name, m => sourceIterable)


val symbolType: Identifier =NodeIdentifier(name)
// source match {
// case nodes: Iterable[Node] => NodeIdentifier(name)
// case rels: Iterable[Relationship] => RelationshipIdentifier(name)
// }

val symbols: SymbolTable = inner.symbols.add(Seq(symbolType))

Expand Down
29 changes: 25 additions & 4 deletions cypher/src/test/scala/org/neo4j/cypher/ExecutionEngineTest.scala
Expand Up @@ -715,16 +715,37 @@ class ExecutionEngineTest extends ExecutionEngineHelper {
), result.columnAs[Path]("p").toList)
}

@Test def shouldBeAbleToTakeParams() {
@Test def shouldBeAbleToTakeParamsInDifferentTypes() {
createNodes("A", "B", "C", "D", "E")

val query = Query.
start(
NodeById("pA", ParameterValue("a")),
NodeById("pB", ParameterValue("b")),
NodeById("pC", ParameterValue("c")),
NodeById("pD", ParameterValue("d")),
NodeById("pE", ParameterValue("e"))).
returns(ValueReturnItem(EntityValue("pA")), ValueReturnItem(EntityValue("pB")), ValueReturnItem(EntityValue("pC")), ValueReturnItem(EntityValue("pD")), ValueReturnItem(EntityValue("pE")))

val result = execute(query,
"a" -> Seq[Long](1),
"b" -> 2,
"c" -> Seq(3L).asJava,
"d" -> Seq(4).asJava,
"e" -> List(5)
)

assertEquals(1, result.toList.size)
}

@Test(expected = classOf[ParameterNotFoundException]) def parameterTypeErrorShouldBeNicelyExplained() {
createNodes("A")

val query = Query.
start(NodeById("pA", ParameterValue("a"))).
returns(ValueReturnItem(EntityValue("pA")))

val result = execute(query, "a" -> Seq[Long](1))

assertEquals(List(Map("pA" -> node("A"))), result.toList)
execute(query, "a" -> "Andres").toList
}

@Test def shouldBeAbleToTakeParamsFromParsedStuff() {
Expand Down

0 comments on commit 36b3136

Please sign in to comment.