Skip to content

Commit

Permalink
refactor: use Vector in Fixpoint/Ram (flix#6035)
Browse files Browse the repository at this point in the history
  • Loading branch information
sockmaster27 committed Jun 7, 2023
1 parent e7cfcdc commit 538c0a6
Show file tree
Hide file tree
Showing 8 changed files with 78 additions and 81 deletions.
24 changes: 12 additions & 12 deletions main/src/library/Fixpoint/Compiler.flix
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ mod Fixpoint {
Map.insertWith(List.append, ruleStratum, rule :: Nil)
}, Map#{}, rules) |>
Map.forEach((_, s) -> compileStratum(stmts, s));
RamStmt.Seq(MutList.toList(stmts))
RamStmt.Seq(MutList.toVector(stmts))
}
case _ => bug!("Datalog normalization bug")
}
Expand Down Expand Up @@ -103,8 +103,9 @@ mod Fixpoint {
Map.mapWithKey(predSym -> match (arity, den) -> {
BoolExp.Empty(RamSym.Delta(predSym, arity, den))
}, idb) |>
Map.valuesOf;
let untilBody = RamStmt.Seq(MutList.toList(loopBody));
Map.valuesOf |>
List.toVector;
let untilBody = RamStmt.Seq(MutList.toVector(loopBody));
let fixpoint = RamStmt.Until(loopTest, untilBody);
MutList.push!(fixpoint, stmts)
}
Expand All @@ -127,8 +128,8 @@ mod Fixpoint {
case Constraint(HeadAtom(headSym, headDen, headTerms), body) =>
let augBody = augmentBody(body);
let env = unifyVars(augBody);
let ramTerms = Vector.toList(Vector.map(compileHeadTerm(env), headTerms));
let arity = List.length(ramTerms);
let ramTerms = Vector.map(compileHeadTerm(env), headTerms);
let arity = Vector.length(ramTerms);
let projection = RelOp.Project(ramTerms, RamSym.Full(headSym, arity, headDen));
let join = compileBody(env, augBody);
let loopBody = RelOp.If(join, projection);
Expand Down Expand Up @@ -184,10 +185,9 @@ mod Fixpoint {
let env = unifyVars(augBody);
let ramTerms = Vector.map(compileHeadTerm(env), headTerms);
let arity = Vector.length(ramTerms);
let ramTermsList = Vector.toList(ramTerms);
let projection = RelOp.Project(ramTermsList, RamSym.New(headSym, arity, headDen));
let projection = RelOp.Project(ramTerms, RamSym.New(headSym, arity, headDen));
let join = compileBody(env, augBody);
let loopBody = RelOp.If(BoolExp.NotMemberOf(ramTermsList, RamSym.Full(headSym, arity, headDen)) :: join, projection);
let loopBody = RelOp.If(Vector.append(Vector#{BoolExp.NotMemberOf(ramTerms, RamSym.Full(headSym, arity, headDen))}, join), projection);
let compile = delta -> { // `delta` designates the focused atom.
let insert =
Vector.foldRight(match (atom, rowVar) -> acc -> match atom {
Expand Down Expand Up @@ -322,8 +322,8 @@ mod Fixpoint {
/// (2) comes from the negative atom `not A(x)`.
/// (3) is a function call that computes the expression `x > 0`.
///
def compileBody(env: Map[VarSym, RamTerm[v]], body: Vector[(BodyPredicate[v], RowVar)]): List[BoolExp[v]] =
Vector.toList(Vector.flatMap(match (atom, rowVar) ->
def compileBody(env: Map[VarSym, RamTerm[v]], body: Vector[(BodyPredicate[v], RowVar)]): Vector[BoolExp[v]] =
Vector.flatMap(match (atom, rowVar) ->
let compileBodyTerm = j -> term -> match term {
case BodyTerm.Wild => RamTerm.RowLoad(rowVar, j)
case BodyTerm.Var(var) => unwrap(Map.get(var, env))
Expand All @@ -343,7 +343,7 @@ mod Fixpoint {
case BodyAtom(bodySym, denotation, Polarity.Negative, _, terms) =>
let ramTerms = Vector.mapWithIndex(compileBodyTerm, terms);
let arity = Vector.length(ramTerms);
Vector#{BoolExp.NotMemberOf(Vector.toList(ramTerms), RamSym.Full(bodySym, arity, denotation))}
Vector#{BoolExp.NotMemberOf(ramTerms, RamSym.Full(bodySym, arity, denotation))}
case Functional(_, _, _) => Vector.empty()
case Guard0(f) =>
Vector#{BoolExp.Guard0(f)}
Expand Down Expand Up @@ -372,7 +372,7 @@ mod Fixpoint {
let t4 = unwrap(Map.get(v4, env));
let t5 = unwrap(Map.get(v5, env));
Vector#{BoolExp.Guard5(f, t1, t2, t3, t4, t5)}
}, body))
}, body)

def unwrap(o: Option[a]): a = match o {
case Some(a) => a
Expand Down
53 changes: 25 additions & 28 deletions main/src/library/Fixpoint/IndexSelection.flix
Original file line number Diff line number Diff line change
Expand Up @@ -63,14 +63,14 @@ mod Fixpoint {
pub def queryStmt(stmt: RamStmt[v]): RamStmt[v] = match stmt {
case RamStmt.Insert(op) =>
let (innerOp, ground) = queryOp(op, Set#{});
if (List.isEmpty(ground))
if (Vector.isEmpty(ground))
RamStmt.Insert(innerOp)
else
RamStmt.Insert(RelOp.If(ground, innerOp))
case RamStmt.Merge(_, _) => stmt
case RamStmt.Assign(_, _) => stmt
case RamStmt.Purge(_) => stmt
case RamStmt.Seq(xs) => RamStmt.Seq(List.map(queryStmt, xs))
case RamStmt.Seq(xs) => RamStmt.Seq(Vector.map(queryStmt, xs))
case RamStmt.Until(test, body) => RamStmt.Until(test, queryStmt(body))
case RamStmt.Comment(_) => stmt
}
Expand All @@ -80,15 +80,15 @@ mod Fixpoint {
/// `freeVars` is the set of variables bound by an outer loop.
/// Returns the optimized op and the conditions that occur in `op` that have to be hoisted.
///
def queryOp(op: RelOp[v], freeVars: Set[RowVar]): (RelOp[v], List[BoolExp[v]]) = match op {
def queryOp(op: RelOp[v], freeVars: Set[RowVar]): (RelOp[v], Vector[BoolExp[v]]) = match op {
case RelOp.Search(var, ramSym, body) =>
use Fixpoint.Ram.BoolExp.Eq;
use Fixpoint.Ram.RamTerm.{RowLoad, Lit};
let (innerOp, innerGround) = queryOp(body, Set.insert(var, freeVars));
let (ground, notGround) = List.partition(isExpGround(freeVars), innerGround);
let (ground, notGround) = Vector.partition(isExpGround(freeVars), innerGround);
let (varQuery, rest1) =
// Make sure `var` is on the lhs of all equalities.
List.map(exp -> match exp {
Vector.map(exp -> match exp {
case Eq(RowLoad(row1, i), RowLoad(row2, j)) =>
if (row2 == var)
Eq(RowLoad(row2, j), RowLoad(row1, i))
Expand All @@ -98,72 +98,69 @@ mod Fixpoint {
case _ => exp
}, notGround) |>
// Partition into those equalities that have `var` on the lhs and those that don't.
List.partition(exp -> match exp {
Vector.partition(exp -> match exp {
case Eq(RowLoad(row1, _), RowLoad(row2, _)) => row1 != row2 and row1 == var
case Eq(RowLoad(row, _), Lit(_)) => row == var
case _ => false
});
let (prefixQuery, rest2) = longestPrefixQuery(varQuery);
let test = rest1 ::: rest2;
if (List.isEmpty(prefixQuery))
if (List.isEmpty(test))
let test = Vector.append(rest1, rest2);
if (Vector.isEmpty(prefixQuery))
if (Vector.isEmpty(test))
let search = RelOp.Search(var, ramSym, innerOp);
(search, ground)
else
let search = RelOp.Search(var, ramSym, RelOp.If(test, innerOp));
(search, ground)
else
let query =
List.map(x -> match x {
Vector.map(x -> match x {
case Eq(RamTerm.RowLoad(_, j), rhs) => (j, rhs)
case _ => ???
}, prefixQuery);
if (List.isEmpty(test))
if (Vector.isEmpty(test))
let search = RelOp.Query(var, ramSym, query, innerOp);
(search, ground)
else
let search = RelOp.Query(var, ramSym, query, RelOp.If(test, innerOp));
(search, ground)
case RelOp.Query(_) => (op, Nil)
case RelOp.Functional(_) => (op, Nil)
case RelOp.Project(_) => (op, Nil)
case RelOp.Query(_) => (op, Vector.empty())
case RelOp.Functional(_) => (op, Vector.empty())
case RelOp.Project(_) => (op, Vector.empty())
case RelOp.If(test, then) =>
let (innerOp, innerGround) = queryOp(then, freeVars);
(innerOp, test ::: innerGround)
(innerOp, Vector.append(test, innerGround))
}

def longestPrefixQuery(varQuery: List[BoolExp[v]]): (List[BoolExp[v]], List[BoolExp[v]]) =
def longestPrefixQuery(varQuery: Vector[BoolExp[v]]): (Vector[BoolExp[v]], Vector[BoolExp[v]]) =
use Fixpoint.Ram.BoolExp.Eq;
use Fixpoint.Ram.RamTerm.{RowLoad, Lit};
// Sort equalities of the form `var[i] = rhs` ascending on `i`.
List.sortWith(x -> y -> match (x, y) {
Vector.sortWith(x -> y -> match (x, y) {
case (Eq(RowLoad(_, index1), _), Eq(RowLoad(_, index2), _)) => Comparison.fromInt32(index1 - index2)
case _ => ???
}, varQuery) |>
// Group `var[i] = rhs` by `i`.
List.groupBy(x -> y -> match (x, y) {
Vector.groupBy(x -> y -> match (x, y) {
case (Eq(RowLoad(_, index1), _), Eq(RowLoad(_, index2), _)) => index1 == index2
case _ => ???
}) |>
prefixHelper(0, Nil, Nil)

def prefixHelper(i: Int32, prefix: List[BoolExp[v]], rest: List[BoolExp[v]], eqs: List[List[BoolExp[v]]]): (List[BoolExp[v]], List[BoolExp[v]]) =
match eqs {
case e :: es =>
let (p, r) = List.partition(be -> match be {
Vector.foldLeft(match (i, prefix, rest) -> e -> {
let (p, r) = Vector.partition(be -> match be {
case BoolExp.Eq(RamTerm.RowLoad(_, j), _) => i == j
case _ => ???
}, e);
prefixHelper(i + 1, prefix ::: p, rest ::: r, es)
case Nil => (prefix, rest)
}
(i + 1, Vector.append(prefix, p), Vector.append(rest, r))
}, (0, Vector.empty(), Vector.empty())) |>
match (_, prefix, rest) ->
(prefix, rest)

///
/// An expression is ground if all its terms are ground.
///
def isExpGround(freeVars: Set[RowVar], exp: BoolExp[v]): Bool = match exp {
case BoolExp.Empty(_) => true
case BoolExp.NotMemberOf(terms, _) => List.forAll(isTermGround(freeVars), terms)
case BoolExp.NotMemberOf(terms, _) => Vector.forAll(isTermGround(freeVars), terms)
case BoolExp.Eq(lhs, rhs) => isTermGround(freeVars, lhs) and isTermGround(freeVars, rhs)
case BoolExp.Leq(_, lhs, rhs) => isTermGround(freeVars, lhs) and isTermGround(freeVars, rhs)
case BoolExp.Guard0(_) => true
Expand Down
30 changes: 15 additions & 15 deletions main/src/library/Fixpoint/Interpreter.flix
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ mod Fixpoint {
case RamStmt.Assign(lhs, rhs) =>
MutMap.put!(lhs, MutMap.getWithDefault(rhs, MutMap.new(r), db), db)
case RamStmt.Purge(ramSym) => MutMap.remove!(ramSym, db)
case RamStmt.Seq(stmts) => List.forEach(evalStmt(r, db), stmts)
case RamStmt.Seq(stmts) => Vector.forEach(evalStmt(r, db), stmts)
case RamStmt.Until(test, body) =>
if (evalBoolExp(r, db, (Array#{} @ r, Array#{} @ r), test)) {
()
Expand Down Expand Up @@ -74,7 +74,7 @@ mod Fixpoint {
}, MutMap.getWithDefault(ramSym, MutMap.new(r1), db))
case RelOp.Query(RowVar.Index(i), ramSym, query, body) =>
let (tupleEnv, latEnv) = env;
MutMap.queryWith(evalQuery(env, query), t -> l -> {
MutMap.queryWith(evalQuery(env, Vector.toList(query)), t -> l -> {
Array.put(t, i, tupleEnv);
Array.put(l, i, latEnv);
evalOp(r1, db, env, body)
Expand All @@ -92,15 +92,14 @@ mod Fixpoint {
let rel = MutMap.getOrElsePut!(ramSym, MutMap.new(r1), db);
match toDenotation(ramSym) {
case Denotation.Relational =>
let tuple = List.toVector(List.map(evalTerm(env), terms));
let tuple = Vector.map(evalTerm(env), terms);
MutMap.put!(tuple, Reflect.default(), rel)
case Denotation.Latticenal(bot, leq, lub, _) =>
// assume that length(terms) > 0
let len = List.length(terms);
let keyList = terms |> List.map(evalTerm(env));
let (relKeys, latValList) = List.splitAt(len-1, keyList);
let key = List.toVector(relKeys);
let latVal = match List.head(latValList) {
let len = Vector.length(terms);
let keyList = terms |> Vector.map(evalTerm(env));
let (key, latValList) = Vector.splitAt(len - 1, keyList);
let latVal = match Vector.head(latValList) {
case None => bug!("Found predicate without terms")
case Some(k) => k
};
Expand All @@ -125,21 +124,22 @@ mod Fixpoint {
}
}

def evalBoolExp(r1: Region[r1], db: Database[v, r1], env: SearchEnv[v, r2], es: List[BoolExp[v]]): Bool \ { r1, r2 } with Order[v] =
List.forAll(exp -> match exp {
def evalBoolExp(r1: Region[r1], db: Database[v, r1], env: SearchEnv[v, r2], es: Vector[BoolExp[v]]): Bool \ { r1, r2 } with Order[v] =
Vector.forAll(exp -> match exp {
case BoolExp.Empty(ramSym) =>
MutMap.isEmpty(MutMap.getWithDefault(ramSym, MutMap.new(r1), db))
case BoolExp.NotMemberOf(terms, ramSym) =>
let rel = MutMap.getWithDefault(ramSym, MutMap.new(r1), db);
match toDenotation(ramSym) {
case Denotation.Relational =>
let tuple = List.toVector(List.map(evalTerm(env), terms));
let tuple = Vector.map(evalTerm(env), terms);
not MutMap.memberOf(tuple, rel)
case Denotation.Latticenal(bot, leq, _, _) =>
let len = List.length(terms);
let (keyTerms, latTermList) = terms |> List.map(evalTerm(env)) |> List.splitAt(len - 1);
let key = List.toVector(keyTerms);
let latTerm = match List.head(latTermList) {
let len = Vector.length(terms);
let evalTerms = Vector.map(evalTerm(env), terms);
let key = Vector.take(len - 1, evalTerms);
let latTerms = Vector.drop(len - 1, evalTerms);
let latTerm = match Vector.head(latTerms) {
case None => bug!("Found predicate without terms")
case Some(hd) => hd
};
Expand Down
4 changes: 2 additions & 2 deletions main/src/library/Fixpoint/Ram/BoolExp.flix
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ mod Fixpoint.Ram {
@Internal
pub enum BoolExp[v] {
case Empty(RamSym[v])
case NotMemberOf(List[RamTerm[v]], RamSym[v])
case NotMemberOf(Vector[RamTerm[v]], RamSym[v])
case Eq(RamTerm[v], RamTerm[v])
case Leq(v -> v -> Bool, RamTerm[v], RamTerm[v])
case Guard0(Unit -> Bool)
Expand All @@ -34,7 +34,7 @@ mod Fixpoint.Ram {
pub def toString(exp: BoolExp[v]): String =
match exp {
case BoolExp.Empty(ramSym) => "${ramSym} == ∅"
case BoolExp.NotMemberOf(terms, ramSym) => "(${terms |> List.join(", ")}) ∉ ${ramSym}"
case BoolExp.NotMemberOf(terms, ramSym) => "(${terms |> Vector.join(", ")}) ∉ ${ramSym}"
case BoolExp.Eq(lhs, rhs) => "${lhs} == ${rhs}"
case BoolExp.Leq(_, lhs, rhs) => "${lhs} ≤ ${rhs}"
case BoolExp.Guard0(_) => "<clo>()"
Expand Down
8 changes: 4 additions & 4 deletions main/src/library/Fixpoint/Ram/RamStmt.flix
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ mod Fixpoint.Ram {
case Merge(RamSym[v], RamSym[v])
case Assign(RamSym[v], RamSym[v])
case Purge(RamSym[v])
case Seq(List[RamStmt[v]])
case Until(List[BoolExp[v]], RamStmt[v])
case Seq(Vector[RamStmt[v]])
case Until(Vector[BoolExp[v]], RamStmt[v])
case Comment(String)
}

Expand All @@ -37,9 +37,9 @@ mod Fixpoint.Ram {
case RamStmt.Merge(src, dst) => "merge ${src} into ${dst}"
case RamStmt.Assign(lhs, rhs) => "${lhs} := ${rhs}"
case RamStmt.Purge(ramSym) => "purge ${ramSym}"
case RamStmt.Seq(xs) => List.join(";${nl}", xs)
case RamStmt.Seq(xs) => Vector.join(";${nl}", xs)
case RamStmt.Until(test, body) =>
let tst = test |> List.join(" Λ ");
let tst = test |> Vector.join(" Λ ");
"until(${tst}) do${nl}${String.indent(4, "${body}")}end"
case RamStmt.Comment(comment) => "// ${comment}"
}
Expand Down
12 changes: 6 additions & 6 deletions main/src/library/Fixpoint/Ram/RelOp.flix
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,10 @@ mod Fixpoint.Ram {
@Internal
pub enum RelOp[v] {
case Search(RowVar, RamSym[v], RelOp[v])
case Query(RowVar, RamSym[v], List[(Int32, RamTerm[v])], RelOp[v])
case Query(RowVar, RamSym[v], Vector[(Int32, RamTerm[v])], RelOp[v])
case Functional(RowVar, Vector[v] -> Vector[Vector[v]], Vector[RamTerm[v]], RelOp[v])
case Project(List[RamTerm[v]], RamSym[v])
case If(List[BoolExp[v]], RelOp[v])
case Project(Vector[RamTerm[v]], RamSym[v])
case If(Vector[BoolExp[v]], RelOp[v])
}

instance ToString[RelOp[v]] {
Expand All @@ -32,16 +32,16 @@ mod Fixpoint.Ram {
case RelOp.Search(var, ramSym, body) =>
"search ${var} ∈ ${ramSym} do${nl}${String.indent(4, "${body}")}end"
case RelOp.Query(var, ramSym, prefixQuery, body) =>
let query = List.joinWith(match (i, term) -> {
let query = Vector.joinWith(match (i, term) -> {
ToString.toString(BoolExp.Eq(RamTerm.RowLoad(var, i), term))
}, " ∧ ", prefixQuery);
"query {${var} ∈ ${ramSym} | ${query}} do${nl}${String.indent(4, "${body}")}end"
case RelOp.Functional(rowVar, _, terms, body) =>
"loop(${rowVar} <- f(${terms |> Vector.join(", ")})) do${nl}${String.indent(4, "${body}")}end"
case RelOp.Project(terms, ramSym) =>
"project (${terms |> List.join(", ")}) into ${ramSym}"
"project (${terms |> Vector.join(", ")}) into ${ramSym}"
case RelOp.If(test, then) =>
let tst = test |> List.join(" ∧ ");
let tst = test |> Vector.join(" ∧ ");
"if(${tst}) then${nl}${String.indent(4, "${then}")}end"
}
}
Expand Down
Loading

0 comments on commit 538c0a6

Please sign in to comment.