Merge branch 'master' of https://github.com/HazyResearch/ddlog into ddlog-query
netj committed Feb 1, 2016
2 parents fd509a7 + 2438f2d commit fe75974
Showing 21 changed files with 1,108 additions and 177 deletions.
10 changes: 9 additions & 1 deletion src/main/scala/org/deepdive/ddlog/DeepDiveLog.scala
@@ -23,6 +23,7 @@ object DeepDiveLog {
, inputFiles: List[String] = List()
, query: String = null
, mode: Mode = ORIGINAL
, skipDesugar: Boolean = false
)
val commandLine = new scopt.OptionParser[Config]("ddlog") {
val commonProgramOpts = List(
@@ -38,6 +39,7 @@
opt[Unit]('i', "incremental") optional() action { (_, c) => c.copy(mode = INCREMENTAL) } text("Whether to derive delta rules")
opt[Unit]("materialization") optional() action { (_, c) => c.copy(mode = MATERIALIZATION) } text("Whether to materialize origin data")
opt[Unit]("merge") optional() action { (_, c) => c.copy(mode = MERGE) } text("Whether to merge delta data")
opt[Unit]("skip-desugar") optional() action { (_, c) => c.copy(skipDesugar = true) } text("Whether to skip desugaring and assume no sugar")
arg[String]("FILE...") minOccurs(0) unbounded() action { (f, c) => c.copy(inputFiles = c.inputFiles ++ List(f)) } text("Path to DDLog program files")
checkConfig { c =>
if (c.handler == null) failure("No command specified")
@@ -64,8 +66,14 @@ trait DeepDiveLogHandler {
def run(config: DeepDiveLog.Config): Unit = try {
// parse each file into a single program
val parsedProgram = parseFiles(config.inputFiles)

// desugar unless explicitly told to skip it
val programToRun =
if (config.skipDesugar) parsedProgram
else DeepDiveLogDesugarRewriter.derive(parsedProgram)

// run handler with the parsed program
run(parsedProgram, config)
run(programToRun, config)
} catch {
case e: RuntimeException =>
if (sys.env contains "DDLOG_STACK_TRACE") throw e
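To make the handler change concrete, here is a minimal, self-contained sketch of the gating pattern introduced above: the parsed program goes through a desugaring rewrite unless a skip flag is set. All names below (SkipDesugarSketch, Program, desugar, runHandler) are illustrative stand-ins, not the actual DeepDiveLog APIs; only the skipDesugar flag and the if/else gating mirror the diff.

object SkipDesugarSketch {
  type Program = List[String]                     // stand-in for DeepDiveLog.Program
  case class Config(skipDesugar: Boolean = false) // mirrors the new skipDesugar option

  // stand-in for DeepDiveLogDesugarRewriter.derive
  def desugar(program: Program): Program = program map (_.trim)

  // stand-in for the handler's run(program, config)
  def runHandler(program: Program, config: Config): Unit = program foreach println

  def main(args: Array[String]): Unit = {
    val config = Config(skipDesugar = args contains "--skip-desugar")
    val parsedProgram: Program = List("  a(x) :- b(x).  ")
    // desugar unless explicitly asked to skip it
    val programToRun =
      if (config.skipDesugar) parsedProgram
      else desugar(parsedProgram)
    runHandler(programToRun, config)
  }
}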
36 changes: 22 additions & 14 deletions src/main/scala/org/deepdive/ddlog/DeepDiveLogCompiler.scala
@@ -350,7 +350,8 @@ class QueryCompiler(cq : ConjunctiveQuery, ss: CompilationState) {
}

// resolve an expression
def compileExpr(e: Expr) : String = {
def compileExpr(e: Expr) : String = compileExpr(e, 0)
def compileExpr(e: Expr, level: Int) : String = {
e match {
case VarExpr(name) => compileVariable(name)
case NullConst() => "NULL"
@@ -364,30 +365,37 @@
resolved
}
case BinaryOpExpr(lhs, op, rhs) => {
val resovledLhs = compileExpr(lhs)
val resovledRhs = compileExpr(rhs)
s"(${resovledLhs} ${op} ${resovledRhs})"
val resovledLhs = compileExpr(lhs, level + 1)
val resovledRhs = compileExpr(rhs, level + 1)
val sql = s"${resovledLhs} ${op} ${resovledRhs}"
if (level == 0) sql else s"(${sql})"
}
case TypecastExpr(lhs, rhs) => {
val resovledLhs = compileExpr(lhs)
s"(${resovledLhs} :: ${rhs})"
}
case IfThenElseExpr(ifCondThenExprPairs, optElseExpr) => {
(ifCondThenExprPairs map {
case (ifCond, thenExpr) => s"WHEN ${compileCond(ifCond)} THEN ${compileExpr(thenExpr)}"
}) ++ List(optElseExpr map compileExpr mkString("ELSE ", "", ""))
} mkString("\nCASE ", "\n ", "\nEND")
}
}

// resolve a condition
def compileCond(cond: Cond) : String = {
def compileCond(cond: Cond) : String = compileCond(cond, 0)
def compileCond(cond: Cond, level: Int) : String = {
cond match {
case ComparisonCond(lhs, op, rhs) =>
s"${compileExpr(lhs)} ${op} ${compileExpr(rhs)}"
case NegationCond(c) => s"(NOT ${compileCond(c)})"
case ExprCond(e) => compileExpr(e)
case NegationCond(c) => s"NOT ${compileCond(c, level + 1)}"
case CompoundCond(lhs, op, rhs) => {
val resolvedLhs = s"${compileCond(lhs)}"
val resolvedRhs = s"${compileCond(rhs)}"
op match {
case LogicOperator.AND => s"(${resolvedLhs} AND ${resolvedRhs})"
case LogicOperator.OR => s"(${resolvedLhs} OR ${resolvedRhs})"
val resolvedLhs = s"${compileCond(lhs, level + 1)}"
val resolvedRhs = s"${compileCond(rhs, level + 1)}"
val sql = op match {
case LogicOperator.AND => s"${resolvedLhs} AND ${resolvedRhs}"
case LogicOperator.OR => s"${resolvedLhs} OR ${resolvedRhs}"
}
if (level == 0) sql else s"(${sql})"
}
case _ => ""
}
@@ -611,7 +619,7 @@ object DeepDiveLogCompiler extends DeepDiveLogHandler {
if (stmt.supervision != None) {
if (stmt.q.bodies.length > 1) ss.error(s"Scoping rule does not allow disjunction.\n")
val headStr = qc.generateSQLHead(NoAlias)
val labelCol = qc.compileVariable(stmt.supervision.get)
val labelCol = qc.compileExpr(stmt.supervision.get)
inputQueries += s"""SELECT DISTINCT ${ headStr }, 0 AS id, ${labelCol} AS label
${ qc.generateSQLBody(cqBody) }
"""
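As a side note on the parenthesization change above, the following self-contained sketch shows the idea of threading a nesting level through the expression compiler so that parentheses are emitted only around nested binary operations, dropping the redundant outermost pair. The Var/BinOp AST here is a simplified assumption for illustration, not ddlog's Expr hierarchy.

object ParenLevelSketch {
  sealed trait Expr
  case class Var(name: String)                       extends Expr
  case class BinOp(lhs: Expr, op: String, rhs: Expr) extends Expr

  // Parenthesize only when already inside another expression (level > 0).
  def compile(e: Expr, level: Int = 0): String = e match {
    case Var(n) => n
    case BinOp(l, op, r) =>
      val sql = s"${compile(l, level + 1)} ${op} ${compile(r, level + 1)}"
      if (level == 0) sql else s"(${sql})"
  }

  def main(args: Array[String]): Unit = {
    val e = BinOp(Var("a"), "+", BinOp(Var("b"), "*", Var("c")))
    println(compile(e))   // a + (b * c) -- no redundant outer parentheses
  }
}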
86 changes: 86 additions & 0 deletions src/main/scala/org/deepdive/ddlog/DeepDiveLogDesugarRewriter.scala
@@ -0,0 +1,86 @@
package org.deepdive.ddlog

object DeepDiveLogDesugarRewriter {

// Rewrite function call rules whose output coincides with normal rules.
def desugarUnionsImpliedByFunctionCallRules(program: DeepDiveLog.Program) = {
def indexByFirst[a,b](pairs: Seq[(a,b)]): Map[a,List[b]] =
pairs groupBy { _._1 } mapValues { _ map (_._2) toList }

val schemaByName = indexByFirst(program collect {
case decl: SchemaDeclaration => decl.a.name -> decl
}) mapValues (_ head)
val rulesWithIndexByName = indexByFirst(program.zipWithIndex collect {
case (fncall: FunctionCallRule, i) => fncall.output -> (fncall, i)
case (rule : ExtractionRule , i) => rule.headName -> (rule , i)
})
val relationNamesUsedInProgram = program collect {
case decl: SchemaDeclaration => decl.a.name
case fncall: FunctionCallRule => fncall.output
case rule: ExtractionRule => rule.headName
} toSet

// find names that have multiple function calls or mixed type rules
val relationsToDesugar = rulesWithIndexByName flatMap {
case (name, allRules) =>
val (fncalls, rules) = allRules map {_._1} partition {_.isInstanceOf[FunctionCallRule]}
if ((fncalls size) > 1 || ((fncalls size) > 0 && (rules size) > 0)) {
Some(name)
} else None
}
val rulesToRewrite = relationsToDesugar flatMap rulesWithIndexByName map {
_._1} filter {_.isInstanceOf[FunctionCallRule]} toList

// determine a separator that does not create name clashes with existing heads for each relation to rewrite
val prefixForRelation: Map[String, String] = relationsToDesugar map { name =>
name -> (
Stream.from(1) map { n => s"${name}${"_" * n}"
} dropWhile { prefix =>
relationNamesUsedInProgram exists {_ startsWith prefix}
} head
)
} toMap

// how to make names unique
def makeUnique(name: String, ordLocal: Int, ordGlobal: Int): String = {
s"${prefixForRelation(name)}${ordLocal}"
}

// plan the rewrite
val rewritePlan : Map[Statement, List[Statement]] =
program collect {
// only function call rule needs to be rewritten
case fncall: FunctionCallRule if rulesToRewrite contains fncall =>
val relationName: String = fncall.output
val rulesForTheRelationToRewriteOrdered = rulesWithIndexByName(relationName
) sortBy {_._2} filter {_._1.isInstanceOf[FunctionCallRule]}
val orderAmongRulesToRewrite = rulesForTheRelationToRewriteOrdered map {_._1} indexOf(fncall)
val orderInProgram = rulesForTheRelationToRewriteOrdered(orderAmongRulesToRewrite)._2
val nameUnique: String = makeUnique(relationName, orderAmongRulesToRewrite, orderInProgram)
val schema = schemaByName(relationName)
fncall -> List(
schema.copy(a = schema.a.copy(name = nameUnique)),
fncall.copy(output = nameUnique),
ExtractionRule(
headName = relationName,
q = ConjunctiveQuery(
headTerms = schema.a.terms map VarExpr,
bodies = List(List(BodyAtom(name = nameUnique, terms = schema.a.terms map VarPattern)))
)
)
)
// TODO add union after the last or first or somewhere
} toMap

// apply rewrite plan
program flatMap { case rule => rewritePlan getOrElse(rule, List(rule)) }
}


def derive(program: DeepDiveLog.Program): DeepDiveLog.Program = {
(List(
desugarUnionsImpliedByFunctionCallRules(_)
) reduce (_.compose(_))
)(program)
}
}
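For a concrete feel of the prefix selection behind makeUnique above, here is a small standalone sketch (with made-up relation names) of choosing a clash-free prefix by appending underscores until no existing relation name starts with the candidate, mirroring the Stream.from(1)/dropWhile logic in the rewriter.

object UniquePrefixSketch {
  // Keep appending underscores to `name` until no existing relation name starts with the candidate prefix.
  def prefixFor(name: String, usedNames: Set[String]): String =
    Stream.from(1)
      .map(n => name + ("_" * n))
      .dropWhile(prefix => usedNames.exists(_.startsWith(prefix)))
      .head

  def main(args: Array[String]): Unit = {
    val used   = Set("q", "q_", "q_parts", "r")
    val prefix = prefixFor("q", used)
    println(prefix)       // q__  ("q_" and "q_parts" rule out the shorter candidate)
    println(prefix + 0)   // q__0 -- the kind of unique relation name given to a rewritten rule
  }
}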
82 changes: 50 additions & 32 deletions src/main/scala/org/deepdive/ddlog/DeepDiveLogParser.scala
@@ -23,6 +23,7 @@ case class NullConst extends ConstExpr
case class FuncExpr(function: String, args: List[Expr], isAggregation: Boolean) extends Expr
case class BinaryOpExpr(lhs: Expr, op: String, rhs: Expr) extends Expr
case class TypecastExpr(lhs: Expr, rhs: String) extends Expr
case class IfThenElseExpr(ifCondThenExprPairs: List[(Cond, Expr)], elseExpr: Option[Expr]) extends Expr

sealed trait Pattern
case class VarPattern(name: String) extends Pattern
@@ -39,7 +40,7 @@ case class OuterModifier extends BodyModifier
case class AllModifier extends BodyModifier

case class Attribute(name : String, terms : List[String], types : List[String], annotations : List[List[Annotation]])
case class ConjunctiveQuery(headTerms: List[Expr], bodies: List[List[Body]], isDistinct: Boolean, limit: Option[Int],
case class ConjunctiveQuery(headTerms: List[Expr], bodies: List[List[Body]], isDistinct: Boolean = false, limit: Option[Int] = None,
// optional annotations for head terms
headTermAnnotations: List[List[Annotation]] = List.empty,
// XXX This flag is not ideal, but minimizes the impact of query treatment when compared to creating another case class
@@ -57,7 +58,7 @@ case class RuleAnnotation(name: String, args: List[String])

// condition
sealed trait Cond extends Body
case class ComparisonCond(lhs: Expr, op: String, rhs: Expr) extends Cond
case class ExprCond(expr: Expr) extends Cond
case class NegationCond(cond: Cond) extends Cond
case class CompoundCond(lhs: Cond, op: LogicOperator.LogicOperator, rhs: Cond) extends Cond

@@ -105,7 +106,7 @@ case class SchemaDeclaration( a : Attribute
) extends Statement // atom and whether this is a query relation.
case class FunctionDeclaration( functionName: String, inputType: FunctionInputOutputType,
outputType: FunctionInputOutputType, implementations: List[FunctionImplementationDeclaration]) extends Statement
case class ExtractionRule(headName: String, q : ConjunctiveQuery, supervision: Option[String] = None) extends Statement // Extraction rule
case class ExtractionRule(headName: String, q : ConjunctiveQuery, supervision: Option[Expr] = None) extends Statement // Extraction rule
case class FunctionCallRule(output: String, function: String, q : ConjunctiveQuery, mode: Option[String], parallelism: Option[Int]) extends Statement // Extraction rule
case class InferenceRule(head: InferenceRuleHead, q : ConjunctiveQuery, weights : FactorWeight, mode: Option[String] = None) extends Statement // Weighted rule

@@ -186,7 +187,8 @@ class DeepDiveLogParser extends JavaTokenParsers {
}
}

def operator = "||" | "+" | "-" | "*" | "/" | "&" | "%"
def operator =
( "||" | "+" | "-" | "*" | "/" | "&" | "%" )
def typeOperator = "::"
val aggregationFunctions = Set("MAX", "SUM", "MIN", "ARRAY_ACCUM", "ARRAY_AGG", "COUNT")

@@ -197,21 +199,35 @@
| lexpr
)

def lexpr : Parser[Expr] =
( functionName ~ "(" ~ rep1sep(expr, ",") ~ ")" ^^ {
case (name ~ _ ~ args ~ _) => FuncExpr(name, args, (aggregationFunctions contains name))
def cexpr =
( expr ~ compareOperator ~ expr ^^ { case (lhs ~ op ~ rhs) => BinaryOpExpr(lhs, op, rhs) }
| expr
)

def lexpr =
( "if" ~> (cond ~ ("then" ~> expr) ~ rep(elseIfExprs) ~ opt("else" ~> expr)) <~ "end" ^^ {
case (ifCond ~ thenExpr ~ elseIfs ~ optElseExpr) =>
IfThenElseExpr((ifCond, thenExpr) :: elseIfs, optElseExpr)
}
| stringLiteralAsString ^^ { StringConst(_) }
| double ^^ { DoubleConst(_) }
| integer ^^ { IntConst(_) }
| ("TRUE" | "FALSE") ^^ { x => BooleanConst(x.toBoolean) }
| "NULL" ^^ { _ => new NullConst }
| functionName ~ "(" ~ rep1sep(expr, ",") ~ ")" ^^ {
case (name ~ _ ~ args ~ _) => FuncExpr(name, args, (aggregationFunctions contains name))
}
| variableName ^^ { VarExpr(_) }
| "(" ~> expr <~ ")"
)

def elseIfExprs =
("else" ~> "if" ~> cond) ~ ("then" ~> expr) ^^ {
case (ifCond ~ thenExpr) => (ifCond, thenExpr)
}

// conditional expressions
def compareOperator = "LIKE" | ">" | "<" | ">=" | "<=" | "!=" | "=" | "IS NOT" | "IS"
def compareOperator = "LIKE" | ">" | "<" | ">=" | "<=" | "!=" | "=" | "IS" ~ "NOT" ^^ { _ => "IS NOT" } | "IS"

def cond : Parser[Cond] =
( acond ~ (";") ~ cond ^^ { case (lhs ~ op ~ rhs) =>
@@ -229,9 +245,7 @@ class DeepDiveLogParser extends JavaTokenParsers {
| bcond
)
def bcond : Parser[Cond] =
( expr ~ compareOperator ~ expr ^^ { case (lhs ~ op ~ rhs) =>
ComparisonCond(lhs, op, rhs)
}
( cexpr ^^ ExprCond
| "[" ~> cond <~ "]"
)

@@ -292,48 +306,52 @@ class DeepDiveLogParser extends JavaTokenParsers {
FunctionDeclaration(a, inTy, outTy, implementationDecls)
}

def cqBody: Parser[Body] = cond | quantifiedBody | atom
def cqBody: Parser[Body] = quantifiedBody | atom | cond

def cqConjunctiveBody: Parser[List[Body]] = rep1sep(cqBody, ",")

def cqHeadTerms = "(" ~> rep1sep(expr, ",") <~ ")"

def conjunctiveQueryBody : Parser[ConjunctiveQuery] =
opt("*") ~ opt("|" ~> decimalNumber) ~ ":-" ~ rep1sep(cqConjunctiveBody, ";") ^^ {
case (isDistinct ~ limit ~ ":-" ~ disjunctiveBodies) =>
ConjunctiveQuery(List.empty, disjunctiveBodies, isDistinct != None, limit map (_.toInt))
}

def conjunctiveQuery : Parser[ConjunctiveQuery] =
// TODO fill headTermAnnotations as done in queryWithOptionalHeadTerms to support @order_by
cqHeadTerms ~ opt("*") ~ opt("|" ~> decimalNumber) ~ ":-" ~ rep1sep(cqConjunctiveBody, ";") ^^ {
case (head ~ isDistinct ~ limit ~ ":-" ~ disjunctiveBodies) =>
ConjunctiveQuery(head, disjunctiveBodies, isDistinct != None, limit map (_.toInt))
}
cqHeadTerms ~ conjunctiveQueryBody ^^ {
case (head ~ cq) => cq.copy(headTerms = head)
}

def functionMode = "@mode" ~> commit("(" ~> functionModeType <~ ")" ^? ({
case "inc" => "inc"
}, (s) => s"${s}: unrecognized mode"))

def parallelism = "@parallelism" ~> "(" ~> integer <~ ")"

def supervision = "=" ~> (variableName | "TRUE" | "FALSE")

def conjunctiveQueryWithSupervision = // returns Parser[String], Parser[ConjunctiveQuery]
cqHeadTerms ~ opt("*") ~ opt("|" ~> decimalNumber) ~ opt(supervision) ~ ":-" ~ rep1sep(cqConjunctiveBody, ";") ^^ {
case (head ~ isDistinct ~ limit ~ sup ~ ":-" ~ disjunctiveBodies) =>
(sup, ConjunctiveQuery(head, disjunctiveBodies, isDistinct != None, limit map (_.toInt)))
}

def functionCallRule : Parser[FunctionCallRule] =
opt(functionMode) ~ opt(parallelism) ~ relationName ~ "+=" ~ functionName ~ conjunctiveQuery ^^ {
case (mode ~ parallelism ~ out ~ _ ~ func ~ cq) => FunctionCallRule(out, func, cq, mode, parallelism)
}

def oldstyleSupervision = "@label" ~> "(" ~> (variableName | "TRUE" | "FALSE") <~ ")"
def supervisionAnnotation = "@label" ~> "(" ~> expr <~ ")"

def conjunctiveQueryWithSupervision : Parser[ConjunctiveQuery] =
cqHeadTerms ~ opt("*") ~ opt("|" ~> decimalNumber) ~ ":-" ~ rep1sep(cqConjunctiveBody, ";") ^^ {
case (head ~ isDistinct ~ limit ~ ":-" ~ disjunctiveBodies) =>
ConjunctiveQuery(head, disjunctiveBodies, isDistinct != None, limit map (_.toInt))
}

def extractionRule =
( relationName ~ conjunctiveQueryWithSupervision ^^ {
case (head ~ cq) => ExtractionRule(head, cq._2, cq._1)
}
| oldstyleSupervision ~ relationName ~ conjunctiveQuery ^^ {
case (sup ~ head ~ cq) => ExtractionRule(head, cq, Some(sup))
}
)
( opt(supervisionAnnotation) ~ relationName ~ conjunctiveQuery ^^ {
case (sup ~ head ~ cq) => ExtractionRule(head, cq, sup)
}
| relationName ~ cqHeadTerms ~ ("=" ~> expr) ~ conjunctiveQueryBody ^^ {
case (head ~ headTerms ~ sup ~ cq) =>
ExtractionRule(head, cq.copy(headTerms = headTerms), Some(sup))
}
)

def factorWeight = "@weight" ~> "(" ~> rep1sep(expr, ",") <~ ")" ^^ { FactorWeight(_) }
def inferenceMode = "@mode" ~> commit("(" ~> inferenceModeType <~ ")" ^? ({
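Finally, to connect the new if ... then ... else ... end syntax parsed into IfThenElseExpr with its SQL translation in the compiler hunk above, here is a hedged, simplified sketch: the Lit/IfThenElse AST and string-valued conditions below are assumptions for illustration only, but the CASE WHEN ... THEN ... ELSE ... END shape matches what the new compileExpr branch emits.

object IfThenElseSketch {
  type Cond = String  // simplified: conditions are already-compiled SQL fragments
  sealed trait Expr
  case class Lit(sql: String) extends Expr
  // mirrors IfThenElseExpr(ifCondThenExprPairs, elseExpr)
  case class IfThenElse(pairs: List[(Cond, Expr)], elseExpr: Option[Expr]) extends Expr

  def compile(e: Expr): String = e match {
    case Lit(sql) => sql
    case IfThenElse(pairs, elseExpr) =>
      val whens = pairs map { case (c, t) => s"WHEN ${c} THEN ${compile(t)}" }
      val els   = elseExpr map (x => s"ELSE ${compile(x)}")
      (whens ++ els).mkString("CASE ", " ", " END")
  }

  def main(args: Array[String]): Unit = {
    // if label = 'pos' then 1 else if label = 'neg' then -1 else 0 end
    val e = IfThenElse(
      List("label = 'pos'" -> Lit("1"), "label = 'neg'" -> Lit("-1")),
      Some(Lit("0")))
    println(compile(e))
    // CASE WHEN label = 'pos' THEN 1 WHEN label = 'neg' THEN -1 ELSE 0 END
  }
}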
