Skip to content

Commit

Permalink
Moved over bulk of provenance code to provenance translator.
Browse files Browse the repository at this point in the history
  • Loading branch information
azreika committed Nov 13, 2020
1 parent 0c6f5e3 commit e87cd96
Show file tree
Hide file tree
Showing 6 changed files with 400 additions and 321 deletions.
2 changes: 2 additions & 0 deletions src/Makefile.am
Expand Up @@ -203,6 +203,8 @@ souffle_sources = \
ast2ram/Location.h \
ast2ram/ProvenanceClauseTranslator.cpp \
ast2ram/ProvenanceClauseTranslator.h \
ast2ram/ProvenanceTranslator.cpp \
ast2ram/ProvenanceTranslator.h \
ast2ram/ValueIndex.cpp \
ast2ram/ValueIndex.h \
ast2ram/ValueTranslator.cpp \
Expand Down
292 changes: 3 additions & 289 deletions src/ast2ram/AstToRamTranslator.cpp
Expand Up @@ -130,7 +130,7 @@ AstToRamTranslator::AstToRamTranslator() = default;
AstToRamTranslator::~AstToRamTranslator() = default;

/** append statement to a list of statements */
inline void appendStmt(VecOwn<ram::Statement>& stmtList, Own<ram::Statement> stmt) {
void AstToRamTranslator::appendStmt(VecOwn<ram::Statement>& stmtList, Own<ram::Statement> stmt) {
if (stmt) {
stmtList.push_back(std::move(stmt));
}
Expand Down Expand Up @@ -383,9 +383,6 @@ VecOwn<ram::Statement> AstToRamTranslator::translateSCC(size_t scc, size_t idx)
const auto& internIns = sccGraph->getInternalInputRelations(scc);
const auto& internOuts = sccGraph->getInternalOutputRelations(scc);

// make a variable for all relations that are expired at the current SCC
const auto& internExps = relationSchedule->schedule().at(idx).expired();

// load all internal input relations from the facts dir with a .facts extension
for (const auto& relation : internIns) {
makeRamLoad(current, relation);
Expand All @@ -402,8 +399,8 @@ VecOwn<ram::Statement> AstToRamTranslator::translateSCC(size_t scc, size_t idx)
makeRamStore(current, relation);
}

// if provenance is disabled, drop all relations expired as per the topological order
if (!Global::config().has("provenance")) {
const auto& internExps = relationSchedule->schedule().at(idx).expired();
for (const auto& relation : internExps) {
makeRamClear(current, relation);
}
Expand All @@ -413,9 +410,7 @@ VecOwn<ram::Statement> AstToRamTranslator::translateSCC(size_t scc, size_t idx)
}

void AstToRamTranslator::addNegation(ast::Clause& clause, const ast::Atom* atom) {
if (Global::config().has("provenance")) {
clause.addToBody(mk<ast::ProvenanceNegation>(souffle::clone(atom)));
} else if (clause.getHead()->getArity() > 0) {
if (clause.getHead()->getArity() > 0) {
clause.addToBody(mk<ast::Negation>(souffle::clone(atom)));
}
}
Expand Down Expand Up @@ -625,262 +620,6 @@ Own<ram::Statement> AstToRamTranslator::translateRecursiveRelation(
return mk<ram::Sequence>(std::move(res));
}

/** make a subroutine to search for subproofs */
Own<ram::Statement> AstToRamTranslator::makeSubproofSubroutine(const ast::Clause& clause) {
auto intermediateClause = mk<ast::Clause>(souffle::clone(clause.getHead()));

// create a clone where all the constraints are moved to the end
for (auto bodyLit : clause.getBodyLiterals()) {
// first add all the things that are not constraints
if (!isA<ast::Constraint>(bodyLit)) {
intermediateClause->addToBody(souffle::clone(bodyLit));
}
}

// now add all constraints
for (auto bodyLit : ast::getBodyLiterals<ast::Constraint>(clause)) {
intermediateClause->addToBody(souffle::clone(bodyLit));
}

// name unnamed variables
nameUnnamedVariables(intermediateClause.get());

// add constraint for each argument in head of atom
ast::Atom* head = intermediateClause->getHead();
size_t auxiliaryArity = auxArityAnalysis->getArity(head);
auto args = head->getArguments();
for (size_t i = 0; i < head->getArity() - auxiliaryArity; i++) {
auto arg = args[i];

if (auto var = dynamic_cast<ast::Variable*>(arg)) {
// FIXME: float equiv (`FEQ`)
auto constraint = mk<ast::BinaryConstraint>(
BinaryConstraintOp::EQ, souffle::clone(var), mk<ast::SubroutineArgument>(i));
constraint->setFinalType(BinaryConstraintOp::EQ);
intermediateClause->addToBody(std::move(constraint));
} else if (auto func = dynamic_cast<ast::Functor*>(arg)) {
TypeAttribute returnType;
if (auto* inf = dynamic_cast<ast::IntrinsicFunctor*>(func)) {
assert(inf->getFinalReturnType().has_value() && "functor has missing return type");
returnType = inf->getFinalReturnType().value();
} else if (auto* udf = dynamic_cast<ast::UserDefinedFunctor*>(func)) {
assert(udf->getFinalReturnType().has_value() && "functor has missing return type");
returnType = udf->getFinalReturnType().value();
} else {
assert(false && "unexpected functor type");
}
auto opEq = returnType == TypeAttribute::Float ? BinaryConstraintOp::FEQ : BinaryConstraintOp::EQ;
auto constraint =
mk<ast::BinaryConstraint>(opEq, souffle::clone(func), mk<ast::SubroutineArgument>(i));
constraint->setFinalType(opEq);
intermediateClause->addToBody(std::move(constraint));
} else if (auto rec = dynamic_cast<ast::RecordInit*>(arg)) {
auto constraint = mk<ast::BinaryConstraint>(
BinaryConstraintOp::EQ, souffle::clone(rec), mk<ast::SubroutineArgument>(i));
constraint->setFinalType(BinaryConstraintOp::EQ);
intermediateClause->addToBody(std::move(constraint));
}
}

// index of level argument in argument list
size_t levelIndex = head->getArguments().size() - auxiliaryArity;

// add level constraints, i.e., that each body literal has height less than that of the head atom
const auto& bodyLiterals = intermediateClause->getBodyLiterals();
for (auto lit : bodyLiterals) {
if (auto atom = dynamic_cast<ast::Atom*>(lit)) {
auto arity = atom->getArity();
auto atomArgs = atom->getArguments();
// arity - 1 is the level number in body atoms
auto constraint = mk<ast::BinaryConstraint>(BinaryConstraintOp::LT,
souffle::clone(atomArgs[arity - 1]), mk<ast::SubroutineArgument>(levelIndex));
constraint->setFinalType(BinaryConstraintOp::LT);
intermediateClause->addToBody(std::move(constraint));
}
}
return ProvenanceClauseTranslator(*this).translateClause(*intermediateClause, clause);
}

/** make a subroutine to search for subproofs for the non-existence of a tuple */
Own<ram::Statement> AstToRamTranslator::makeNegationSubproofSubroutine(const ast::Clause& clause) {
// TODO (taipan-snake): Currently we only deal with atoms (no constraints or negations or aggregates
// or anything else...)
//
// The resulting subroutine looks something like this:
// IF (arg(0), arg(1), _, _) IN rel_1:
// return 1
// IF (arg(0), arg(1), _ ,_) NOT IN rel_1:
// return 0
// ...

// clone clause for mutation, rearranging constraints to be at the end
auto clauseReplacedAggregates = mk<ast::Clause>(souffle::clone(clause.getHead()));

// create a clone where all the constraints are moved to the end
for (auto bodyLit : clause.getBodyLiterals()) {
// first add all the things that are not constraints
if (!isA<ast::Constraint>(bodyLit)) {
clauseReplacedAggregates->addToBody(souffle::clone(bodyLit));
}
}

// now add all constraints
for (auto bodyLit : ast::getBodyLiterals<ast::Constraint>(clause)) {
clauseReplacedAggregates->addToBody(souffle::clone(bodyLit));
}

struct AggregatesToVariables : public ast::NodeMapper {
mutable int aggNumber{0};

AggregatesToVariables() = default;

Own<ast::Node> operator()(Own<ast::Node> node) const override {
if (dynamic_cast<ast::Aggregator*>(node.get()) != nullptr) {
return mk<ast::Variable>("agg_" + std::to_string(aggNumber++));
}

node->apply(*this);
return node;
}
};

AggregatesToVariables aggToVar;
clauseReplacedAggregates->apply(aggToVar);

// build a vector of unique variables
std::vector<const ast::Variable*> uniqueVariables;

visitDepthFirst(*clauseReplacedAggregates, [&](const ast::Variable& var) {
if (var.getName().find("@level_num") == std::string::npos) {
// use find_if since uniqueVariables stores pointers, and we need to dereference the pointer to
// check equality
if (std::find_if(uniqueVariables.begin(), uniqueVariables.end(),
[&](const ast::Variable* v) { return *v == var; }) == uniqueVariables.end()) {
uniqueVariables.push_back(&var);
}
}
});

// a mapper to replace variables with subroutine arguments
struct VariablesToArguments : public ast::NodeMapper {
const std::vector<const ast::Variable*>& uniqueVariables;

VariablesToArguments(const std::vector<const ast::Variable*>& uniqueVariables)
: uniqueVariables(uniqueVariables) {}

Own<ast::Node> operator()(Own<ast::Node> node) const override {
// replace unknown variables
if (auto varPtr = dynamic_cast<const ast::Variable*>(node.get())) {
if (varPtr->getName().find("@level_num") == std::string::npos) {
size_t argNum = std::find_if(uniqueVariables.begin(), uniqueVariables.end(),
[&](const ast::Variable* v) { return *v == *varPtr; }) -
uniqueVariables.begin();

return mk<ast::SubroutineArgument>(argNum);
} else {
return mk<ast::UnnamedVariable>();
}
}

// apply recursive
node->apply(*this);

// otherwise nothing
return node;
}
};

auto makeRamAtomExistenceCheck = [&](ast::Atom* atom) {
auto relName = getConcreteRelationName(atom);
size_t auxiliaryArity = auxArityAnalysis->getArity(atom);

// translate variables to subroutine arguments
VariablesToArguments varsToArgs(uniqueVariables);
atom->apply(varsToArgs);

// construct a query
VecOwn<ram::Expression> query;
auto atomArgs = atom->getArguments();

// add each value (subroutine argument) to the search query
for (size_t i = 0; i < atom->getArity() - auxiliaryArity; i++) {
auto arg = atomArgs[i];
query.push_back(translateValue(arg, ValueIndex()));
}

// fill up query with nullptrs for the provenance columns
for (size_t i = 0; i < auxiliaryArity; i++) {
query.push_back(mk<ram::UndefValue>());
}

// ensure the length of query tuple is correct
assert(query.size() == atom->getArity() && "wrong query tuple size");

// create existence checks to check if the tuple exists or not
return mk<ram::ExistenceCheck>(relName, std::move(query));
};

auto makeRamReturnTrue = [&]() {
VecOwn<ram::Expression> returnTrue;
returnTrue.push_back(mk<ram::SignedConstant>(1));
return mk<ram::SubroutineReturn>(std::move(returnTrue));
};

auto makeRamReturnFalse = [&]() {
VecOwn<ram::Expression> returnFalse;
returnFalse.push_back(mk<ram::SignedConstant>(0));
return mk<ram::SubroutineReturn>(std::move(returnFalse));
};

// the structure of this subroutine is a sequence where each nested statement is a search in each
// relation
VecOwn<ram::Statement> searchSequence;

// make a copy so that when we mutate clause, pointers to objects in newClause are not affected
auto newClause = souffle::clone(clauseReplacedAggregates);

// go through each body atom and create a return
size_t litNumber = 0;
for (const auto& lit : newClause->getBodyLiterals()) {
if (auto atom = dynamic_cast<ast::Atom*>(lit)) {
auto existenceCheck = makeRamAtomExistenceCheck(atom);
auto negativeExistenceCheck = mk<ram::Negation>(souffle::clone(existenceCheck));

// create a ram::Query to return true/false
appendStmt(searchSequence,
mk<ram::Query>(mk<ram::Filter>(std::move(existenceCheck), makeRamReturnTrue())));
appendStmt(searchSequence,
mk<ram::Query>(mk<ram::Filter>(std::move(negativeExistenceCheck), makeRamReturnFalse())));
} else if (auto neg = dynamic_cast<ast::Negation*>(lit)) {
auto atom = neg->getAtom();
auto existenceCheck = makeRamAtomExistenceCheck(atom);
auto negativeExistenceCheck = mk<ram::Negation>(souffle::clone(existenceCheck));

// create a ram::Query to return true/false
appendStmt(searchSequence,
mk<ram::Query>(mk<ram::Filter>(std::move(existenceCheck), makeRamReturnFalse())));
appendStmt(searchSequence,
mk<ram::Query>(mk<ram::Filter>(std::move(negativeExistenceCheck), makeRamReturnTrue())));
} else if (auto con = dynamic_cast<ast::Constraint*>(lit)) {
VariablesToArguments varsToArgs(uniqueVariables);
con->apply(varsToArgs);

// translate to a ram::Condition
auto condition = translateConstraint(con, ValueIndex());
auto negativeCondition = mk<ram::Negation>(souffle::clone(condition));

appendStmt(searchSequence,
mk<ram::Query>(mk<ram::Filter>(std::move(condition), makeRamReturnTrue())));
appendStmt(searchSequence,
mk<ram::Query>(mk<ram::Filter>(std::move(negativeCondition), makeRamReturnFalse())));
}

litNumber++;
}

return mk<ram::Sequence>(std::move(searchSequence));
}

bool AstToRamTranslator::removeADTs(const ast::TranslationUnit& translationUnit) {
struct ADTsFuneral : public ast::NodeMapper {
mutable bool changed{false};
Expand Down Expand Up @@ -1016,26 +755,6 @@ void AstToRamTranslator::createRamRelation(size_t scc) {
}
}

void AstToRamTranslator::addProvenanceClauseSubroutines(const ast::Program* program) {
visitDepthFirst(*program, [&](const ast::Clause& clause) {
std::stringstream relName;
relName << clause.getHead()->getQualifiedName();

// do not add subroutines for info relations or facts
if (relName.str().find("@info") != std::string::npos || clause.getBodyLiterals().empty()) {
return;
}

std::string subroutineLabel =
relName.str() + "_" + std::to_string(getClauseNum(program, &clause)) + "_subproof";
ramSubs[subroutineLabel] = makeSubproofSubroutine(clause);

std::string negationSubroutineLabel =
relName.str() + "_" + std::to_string(getClauseNum(program, &clause)) + "_negation_subproof";
ramSubs[negationSubroutineLabel] = makeNegationSubproofSubroutine(clause);
});
}

/** translates the given datalog program into an equivalent RAM program */
void AstToRamTranslator::translateProgram(const ast::TranslationUnit& translationUnit) {
// keep track of relevant analyses
Expand Down Expand Up @@ -1111,11 +830,6 @@ void AstToRamTranslator::translateProgram(const ast::TranslationUnit& translatio

// done for main prog
ramMain = mk<ram::Sequence>(std::move(res));

// add subroutines for each clause
if (Global::config().has("provenance")) {
addProvenanceClauseSubroutines(program);
}
}

Own<ram::TranslationUnit> AstToRamTranslator::translateUnit(ast::TranslationUnit& tu) {
Expand Down

0 comments on commit e87cd96

Please sign in to comment.