Skip to content

Commit

Permalink
First stage of modularising recursive-relation translation.
Browse files Browse the repository at this point in the history
  • Loading branch information
azreika committed Nov 17, 2020
1 parent da4a969 commit 26bfdd9
Show file tree
Hide file tree
Showing 4 changed files with 124 additions and 117 deletions.
224 changes: 115 additions & 109 deletions src/ast2ram/AstToRamTranslator.cpp
Expand Up @@ -339,7 +339,7 @@ VecOwn<ram::Statement> AstToRamTranslator::clearExpiredRelations(
return stmts;
}

void AstToRamTranslator::addNegation(ast::Clause& clause, const ast::Atom* atom) {
void AstToRamTranslator::addNegation(ast::Clause& clause, const ast::Atom* atom) const {
if (clause.getHead()->getArity() > 0) {
clause.addToBody(mk<ast::Negation>(souffle::clone(atom)));
}
Expand Down Expand Up @@ -368,140 +368,141 @@ Own<ram::Statement> AstToRamTranslator::mergeRelations(
return stmt;
}

VecOwn<ram::Statement> AstToRamTranslator::createRecursiveClauseVersions(
const std::set<const ast::Relation*>& scc, const ast::Relation* rel) {
assert(contains(scc, rel) && "relation should belong to scc");
VecOwn<ram::Statement> loopRelSeq;

/* Find clauses for relation rel */
for (const auto& cl : relDetail->getClauses(rel->getQualifiedName())) {
// skip non-recursive clauses
if (!recursiveClauses->recursive(cl)) {
continue;
}

// each recursive rule results in several operations
int version = 0;
const auto& atoms = ast::getBodyLiterals<ast::Atom>(*cl);
for (size_t j = 0; j < atoms.size(); ++j) {
const ast::Atom* atom = atoms[j];
const ast::Relation* atomRelation = getAtomRelation(atom, program);

// only interested in atoms within the same SCC
if (!contains(scc, atomRelation)) {
continue;
}

// modify the processed rule to use delta relation and write to new relation
auto r1 = souffle::clone(cl);
r1->getHead()->setQualifiedName(getNewRelationName(rel));
ast::getBodyLiterals<ast::Atom>(*r1)[j]->setQualifiedName(getDeltaRelationName(atomRelation));
addNegation(*r1, cl->getHead());

// replace wildcards with variables to reduce indices
nameUnnamedVariables(r1.get());

// reduce R to P ...
for (size_t k = j + 1; k < atoms.size(); k++) {
if (contains(scc, getAtomRelation(atoms[k], program))) {
auto cur = souffle::clone(ast::getBodyLiterals<ast::Atom>(*r1)[k]);
cur->setQualifiedName(getDeltaRelationName(getAtomRelation(atoms[k], program)));
r1->addToBody(mk<ast::Negation>(std::move(cur)));
}
}

Own<ram::Statement> rule = ClauseTranslator(*this).translateClause(*r1, *cl, version);

// add loging
if (Global::config().has("profile")) {
const std::string& relationName = toString(rel->getQualifiedName());
const auto& srcLocation = cl->getSrcLoc();
const std::string clauseText = stringify(toString(*cl));
const std::string logTimerStatement =
LogStatement::tRecursiveRule(relationName, version, srcLocation, clauseText);
const std::string logSizeStatement =
LogStatement::nRecursiveRule(relationName, version, srcLocation, clauseText);
rule = mk<ram::LogRelationTimer>(std::move(rule), logTimerStatement, getNewRelationName(rel));
}

// add debug info
std::ostringstream ds;
ds << toString(*cl) << "\nin file ";
ds << cl->getSrcLoc();
rule = mk<ram::DebugInfo>(std::move(rule), ds.str());

// add to loop body
appendStmt(loopRelSeq, std::move(rule));

// increment version counter
version++;
}

// check that the correct number of versions have been created
if (cl->getExecutionPlan() != nullptr) {
int maxVersion = -1;
for (auto const& cur : cl->getExecutionPlan()->getOrders()) {
maxVersion = std::max(cur.first, maxVersion);
}
assert(version > maxVersion && "missing clause versions");
}
}

return loopRelSeq;
}

/** generate RAM code for recursive relations in a strongly-connected component */
Own<ram::Statement> AstToRamTranslator::translateRecursiveRelation(
const std::set<const ast::Relation*>& scc) {
// initialize sections
// -- Initialise all the individual sections --
VecOwn<ram::Statement> preamble;
VecOwn<ram::Statement> loopSeq;
VecOwn<ram::Statement> updateTable;
VecOwn<ram::Statement> postamble;

// --- create preamble ---
// Generate preamble
for (const ast::Relation* rel : scc) {
// Generate code for the non-recursive part of the relation */
appendStmt(preamble, translateNonRecursiveRelation(*rel));

/* Compute non-recursive clauses for relations in scc and push
the results in their delta tables. */
// Copy the result into the delta relation
appendStmt(preamble, mergeRelations(rel, getDeltaRelationName(rel), getConcreteRelationName(rel)));
}

// Generate in-between table updates
for (const ast::Relation* rel : scc) {
/* create update statements for fixpoint (even iteration) */
// Copy @new into main relation, @delta := @new, and empty out @new
Own<ram::Statement> updateRelTable =
mk<ram::Sequence>(mergeRelations(rel, getConcreteRelationName(rel), getNewRelationName(rel)),
mk<ram::Swap>(getDeltaRelationName(rel), getNewRelationName(rel)),
mk<ram::Clear>(getNewRelationName(rel)));

/* measure update time for each relation */
// Measure update time
if (Global::config().has("profile")) {
updateRelTable = mk<ram::LogRelationTimer>(std::move(updateRelTable),
LogStatement::cRecursiveRelation(toString(rel->getQualifiedName()), rel->getSrcLoc()),
getNewRelationName(rel));
}

/* drop temporary tables after recursion */
appendStmt(postamble, mk<ram::Clear>(getDeltaRelationName(rel)));
appendStmt(postamble, mk<ram::Clear>(getNewRelationName(rel)));

/* Generate code for non-recursive part of relation */
/* Generate merge operation for temp tables */
appendStmt(preamble, translateNonRecursiveRelation(*rel));
appendStmt(preamble, mergeRelations(rel, getDeltaRelationName(rel), getConcreteRelationName(rel)));

/* Add update operations of relations to parallel statements */
appendStmt(updateTable, std::move(updateRelTable));
}

// --- build main loop ---

VecOwn<ram::Statement> loopSeq;

// create a utility to check SCC membership
auto isInSameSCC = [&](const ast::Relation* rel) {
return std::find(scc.begin(), scc.end(), rel) != scc.end();
};

/* Compute temp for the current tables */
// Generate postamble
for (const ast::Relation* rel : scc) {
VecOwn<ram::Statement> loopRelSeq;

/* Find clauses for relation rel */
for (const auto& cl : relDetail->getClauses(rel->getQualifiedName())) {
// skip non-recursive clauses
if (!recursiveClauses->recursive(cl)) {
continue;
}

// each recursive rule results in several operations
int version = 0;
const auto& atoms = ast::getBodyLiterals<ast::Atom>(*cl);
for (size_t j = 0; j < atoms.size(); ++j) {
const ast::Atom* atom = atoms[j];
const ast::Relation* atomRelation = getAtomRelation(atom, program);

// only interested in atoms within the same SCC
if (!isInSameSCC(atomRelation)) {
continue;
}

// modify the processed rule to use delta relation and write to new relation
auto r1 = souffle::clone(cl);
r1->getHead()->setQualifiedName(getNewRelationName(rel));
ast::getBodyLiterals<ast::Atom>(*r1)[j]->setQualifiedName(getDeltaRelationName(atomRelation));
addNegation(*r1, cl->getHead());

// replace wildcards with variables (reduces indices when wildcards are used in recursive
// atoms)
nameUnnamedVariables(r1.get());

// reduce R to P ...
for (size_t k = j + 1; k < atoms.size(); k++) {
if (isInSameSCC(getAtomRelation(atoms[k], program))) {
auto cur = souffle::clone(ast::getBodyLiterals<ast::Atom>(*r1)[k]);
cur->setQualifiedName(getDeltaRelationName(getAtomRelation(atoms[k], program)));
r1->addToBody(mk<ast::Negation>(std::move(cur)));
}
}

Own<ram::Statement> rule = ClauseTranslator(*this).translateClause(*r1, *cl, version);

/* add logging */
if (Global::config().has("profile")) {
const std::string& relationName = toString(rel->getQualifiedName());
const auto& srcLocation = cl->getSrcLoc();
const std::string clauseText = stringify(toString(*cl));
const std::string logTimerStatement =
LogStatement::tRecursiveRule(relationName, version, srcLocation, clauseText);
const std::string logSizeStatement =
LogStatement::nRecursiveRule(relationName, version, srcLocation, clauseText);
rule = mk<ram::LogRelationTimer>(
std::move(rule), logTimerStatement, getNewRelationName(rel));
}

// add debug info
std::ostringstream ds;
ds << toString(*cl) << "\nin file ";
ds << cl->getSrcLoc();
rule = mk<ram::DebugInfo>(std::move(rule), ds.str());

// add to loop body
appendStmt(loopRelSeq, std::move(rule));

// increment version counter
version++;
}
// Drop temporary tables after recursion
appendStmt(postamble, mk<ram::Clear>(getDeltaRelationName(rel)));
appendStmt(postamble, mk<ram::Clear>(getNewRelationName(rel)));
}

if (cl->getExecutionPlan() != nullptr) {
// ensure that all required versions have been created, as expected
int maxVersion = -1;
for (auto const& cur : cl->getExecutionPlan()->getOrders()) {
maxVersion = std::max(cur.first, maxVersion);
}
assert(version > maxVersion && "missing clause versions");
}
}
// Generate the main loop
for (const ast::Relation* rel : scc) {
auto loopRelSeq = createRecursiveClauseVersions(scc, rel);

// if there was no rule, continue
// if there were no rules, continue
if (loopRelSeq.empty()) {
continue;
}

// label all versions
// add profiling information
if (Global::config().has("profile")) {
const std::string& relationName = toString(rel->getQualifiedName());
const auto& srcLocation = rel->getSrcLoc();
Expand All @@ -513,12 +514,11 @@ Own<ram::Statement> AstToRamTranslator::translateRecursiveRelation(
appendStmt(loopRelSeq, std::move(newStmt));
}

/* add rule computations of a relation to parallel statement */
appendStmt(loopSeq, mk<ram::Sequence>(std::move(loopRelSeq)));
}
auto loop = mk<ram::Parallel>(std::move(loopSeq));

/* construct exit conditions for odd and even iteration */
// --- Combine the individual sections into the final fixpoint loop --
// Construct exit conditions for odd and even iteration
auto addCondition = [](Own<ram::Condition>& cond, Own<ram::Condition> clause) {
cond = ((cond) ? mk<ram::Conjunction>(std::move(cond), std::move(clause)) : std::move(clause));
};
Expand All @@ -535,11 +535,15 @@ Own<ram::Statement> AstToRamTranslator::translateRecursiveRelation(
}
}

/* construct fixpoint loop */
VecOwn<ram::Statement> res;

// Add in the preamble
if (!preamble.empty()) {
appendStmt(res, mk<ram::Sequence>(std::move(preamble)));
}

// Add in the main loop and update sections
auto loop = mk<ram::Parallel>(std::move(loopSeq));
if (!loop->getStatements().empty() && exitCond && !updateTable.empty()) {
auto ramExitCondition = mk<ram::Exit>(std::move(exitCond));
auto ramExitSequence = mk<ram::Sequence>(std::move(exitStmts));
Expand All @@ -548,6 +552,8 @@ Own<ram::Statement> AstToRamTranslator::translateRecursiveRelation(
std::move(ramExitSequence), std::move(ramUpdateSequence)));
appendStmt(res, std::move(ramLoopSequence));
}

// Add in the postamble
if (!postamble.empty()) {
appendStmt(res, mk<ram::Sequence>(std::move(postamble)));
}
Expand Down
13 changes: 7 additions & 6 deletions src/ast2ram/AstToRamTranslator.h
Expand Up @@ -86,18 +86,16 @@ class AstToRamTranslator {
return sipsMetric.get();
}

size_t getEvaluationArity(const ast::Atom* atom) const;
const ram::Relation* lookupRelation(const std::string& name) const;

/** AST->RAM translation methods */
Own<ram::TranslationUnit> translateUnit(ast::TranslationUnit& tu);
Own<ram::Expression> translateValue(const ast::Argument* arg, const ValueIndex& index) const;
Own<ram::Condition> translateConstraint(const ast::Literal* arg, const ValueIndex& index);
Own<ram::Expression> translateConstant(const ast::Constant& c);
virtual Own<ram::Sequence> translateProgram(const ast::TranslationUnit& translationUnit);

/** determine the auxiliary for relations */
size_t getEvaluationArity(const ast::Atom* atom) const;

const ram::Relation* lookupRelation(const std::string& name) const;

protected:
const ast::Program* program = nullptr;
Own<ast::SipsMetric> sipsMetric;
Expand All @@ -119,7 +117,7 @@ class AstToRamTranslator {
* Translation methods
*/
Own<ram::Sequence> translateSCC(size_t scc, size_t idx);
virtual void addNegation(ast::Clause& clause, const ast::Atom* atom);
virtual void addNegation(ast::Clause& clause, const ast::Atom* atom) const;
virtual VecOwn<ram::Statement> clearExpiredRelations(
const std::set<const ast::Relation*>& expiredRelations) const;
RamDomain getConstantRamRepresentation(const ast::Constant& constant);
Expand Down Expand Up @@ -160,6 +158,9 @@ class AstToRamTranslator {

Own<ram::Statement> mergeRelations(
const ast::Relation* rel, const std::string& destRelation, const std::string& srcRelation) const;

VecOwn<ram::Statement> createRecursiveClauseVersions(
const std::set<const ast::Relation*>& scc, const ast::Relation* rel);
};

} // namespace souffle::ast2ram
2 changes: 1 addition & 1 deletion src/ast2ram/ProvenanceTranslator.cpp
Expand Up @@ -53,7 +53,7 @@ VecOwn<ram::Statement> ProvenanceTranslator::clearExpiredRelations(
return {};
}

void ProvenanceTranslator::addNegation(ast::Clause& clause, const ast::Atom* atom) {
void ProvenanceTranslator::addNegation(ast::Clause& clause, const ast::Atom* atom) const {
clause.addToBody(mk<ast::ProvenanceNegation>(souffle::clone(atom)));
}

Expand Down
2 changes: 1 addition & 1 deletion src/ast2ram/ProvenanceTranslator.h
Expand Up @@ -25,7 +25,7 @@ class ProvenanceTranslator : public AstToRamTranslator {

protected:
Own<ram::Sequence> translateProgram(const ast::TranslationUnit& translationUnit) override;
void addNegation(ast::Clause& clause, const ast::Atom* atom) override;
void addNegation(ast::Clause& clause, const ast::Atom* atom) const override;
VecOwn<ram::Statement> clearExpiredRelations(
const std::set<const ast::Relation*>& expiredRelations) const override;

Expand Down

0 comments on commit 26bfdd9

Please sign in to comment.