Skip to content

Commit

Permalink
Merge pull request #1577 from azreika/min-aggr
Browse files Browse the repository at this point in the history
Extended program minimiser to support clauses with aggregators.
  • Loading branch information
b-scholz committed Aug 11, 2020
2 parents 3a60ce4 + 70be9bd commit 01fab76
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 13 deletions.
66 changes: 53 additions & 13 deletions src/ast/transform/MinimiseProgram.cpp
Expand Up @@ -16,6 +16,7 @@

#include "ast/transform/MinimiseProgram.h"
#include "BinaryConstraintOps.h"
#include "ast/Aggregator.h"
#include "ast/Argument.h"
#include "ast/Atom.h"
#include "ast/BinaryConstraint.h"
Expand Down Expand Up @@ -62,7 +63,7 @@ class MinimiseProgramTransformer::NormalisedClauseRepr {

NormalisedClauseRepr(const AstClause* clause) {
// head
AstQualifiedName name("min:head");
AstQualifiedName name("@min:head");
std::vector<std::string> headVars;
for (const auto* arg : clause->getHead()->getArguments()) {
headVars.push_back(normaliseArgument(arg));
Expand All @@ -71,7 +72,7 @@ class MinimiseProgramTransformer::NormalisedClauseRepr {

// body
for (const auto* lit : clause->getBodyLiterals()) {
addClauseBodyLiteral(lit);
addClauseBodyLiteral("@min:scope:0", lit);
}
}

Expand All @@ -93,19 +94,20 @@ class MinimiseProgramTransformer::NormalisedClauseRepr {

private:
bool fullyNormalised{true};
size_t aggrScopeCount{0};
std::set<std::string> variables{};
std::set<std::string> constants{};
std::vector<NormalisedClauseElementRepr> clauseElements;
std::vector<NormalisedClauseElementRepr> clauseElements{};

/**
* Parse an atom with a preset name qualifier into the element list.
*/
void addClauseAtom(std::string qualifier, const AstAtom* atom);
void addClauseAtom(const std::string& qualifier, const std::string& scopeID, const AstAtom* atom);

/**
* Parse a body literal into the element list.
*/
void addClauseBodyLiteral(const AstLiteral* lit);
void addClauseBodyLiteral(const std::string& scopeID, const AstLiteral* lit);

/**
* Return a normalised string repr of an argument.
Expand All @@ -114,33 +116,39 @@ class MinimiseProgramTransformer::NormalisedClauseRepr {
};

void MinimiseProgramTransformer::NormalisedClauseRepr::addClauseAtom(
std::string qualifier, const AstAtom* atom) {
const std::string& qualifier, const std::string& scopeID, const AstAtom* atom) {
AstQualifiedName name(atom->getQualifiedName());
name.prepend(qualifier);

std::vector<std::string> vars;
vars.push_back(scopeID);
for (const auto* arg : atom->getArguments()) {
vars.push_back(normaliseArgument(arg));
}
clauseElements.push_back({.name = name, .params = vars});
}

void MinimiseProgramTransformer::NormalisedClauseRepr::addClauseBodyLiteral(const AstLiteral* lit) {
void MinimiseProgramTransformer::NormalisedClauseRepr::addClauseBodyLiteral(
const std::string& scopeID, const AstLiteral* lit) {
if (const auto* atom = dynamic_cast<const AstAtom*>(lit)) {
addClauseAtom("@min:atom", atom);
addClauseAtom("@min:atom", scopeID, atom);
} else if (const auto* neg = dynamic_cast<const AstNegation*>(lit)) {
addClauseAtom("@min:neg", neg->getAtom());
addClauseAtom("@min:neg", scopeID, neg->getAtom());
} else if (const auto* bc = dynamic_cast<const AstBinaryConstraint*>(lit)) {
AstQualifiedName name(toBinaryConstraintSymbol(bc->getOperator()));
name.prepend("@min:operator");
std::vector<std::string> vars;
vars.push_back(scopeID);
vars.push_back(normaliseArgument(bc->getLHS()));
vars.push_back(normaliseArgument(bc->getRHS()));
clauseElements.push_back({.name = name, .params = vars});
} else {
assert(lit != nullptr && "unexpected nullptr lit");
fullyNormalised = false;
std::stringstream qualifier;
qualifier << "@min:unhandled:lit:" << scopeID;
AstQualifiedName name(toString(*lit));
name.prepend("@min:unhandled:lit");
name.prepend(qualifier.str());
clauseElements.push_back({.name = name, .params = std::vector<std::string>()});
}
}
Expand Down Expand Up @@ -169,6 +177,39 @@ std::string MinimiseProgramTransformer::NormalisedClauseRepr::normaliseArgument(
name << "@min:unnamed:" << countUnnamed++;
variables.insert(name.str());
return name.str();
} else if (auto* aggr = dynamic_cast<const AstAggregator*>(arg)) {
// Set the scope to uniquely identify the aggregator
std::stringstream scopeID;
scopeID << "@min:scope:" << ++aggrScopeCount;
variables.insert(scopeID.str());

// Set the type signature of this aggregator
std::stringstream aggrTypeSignature;
aggrTypeSignature << "@min:aggrtype";
std::vector<std::string> aggrTypeSignatureComponents;

// - the operator is fixed and cannot be changed
aggrTypeSignature << ":" << aggr->getOperator();

// - the scope can be remapped as a variable
aggrTypeSignatureComponents.push_back(scopeID.str());

// - the normalised target expression can be remapped as a variable
if (aggr->getTargetExpression() != nullptr) {
std::string normalisedExpr = normaliseArgument(aggr->getTargetExpression());
aggrTypeSignatureComponents.push_back(normalisedExpr);
}

// Type signature is its own special atom
clauseElements.push_back({.name = aggrTypeSignature.str(), .params = aggrTypeSignatureComponents});

// Add each contained normalised clause literal, tying it with the new scope ID
for (const auto* literal : aggr->getBodyLiterals()) {
addClauseBodyLiteral(scopeID.str(), literal);
}

// Aggregator identified by the scope ID
return scopeID.str();
} else {
fullyNormalised = false;
return "@min:unhandled:arg";
Expand Down Expand Up @@ -353,9 +394,8 @@ bool MinimiseProgramTransformer::areBijectivelyEquivalent(
}

// create permutation matrix
permutationMatrix[0][0] = 1;
for (size_t i = 1; i < size; i++) {
for (size_t j = 1; j < size; j++) {
for (size_t i = 0; i < size; i++) {
for (size_t j = 0; j < size; j++) {
if (leftElements[i].name == rightElements[j].name) {
permutationMatrix[i][j] = 1;
}
Expand Down
57 changes: 57 additions & 0 deletions src/tests/ast_transformers_test.cpp
Expand Up @@ -319,6 +319,63 @@ TEST(AstTransformers, CheckClausalEquivalence) {
toString(*cMinClauses[1]));
}

/**
* Test the equivalence (or lack of equivalence) of aggregators using the MinimiseProgramTransfomer.
*/
TEST(AstTransformers, CheckAggregatorEquivalence) {
ErrorReport errorReport;
DebugReport debugReport;

std::unique_ptr<AstTranslationUnit> tu = ParserDriver::parseTranslationUnit(
R"(
.decl A,B,C,D(X:number) input
// first and second are equivalent
D(X) :-
B(X),
X < max Y : { C(Y), B(Y), Y < 2 },
A(Z),
Z = sum A : { C(A), B(A), A > count : { A(M), C(M) } }.
D(V) :-
B(V),
A(W),
W = sum test1 : { C(test1), B(test1), test1 > count : { C(X), A(X) } },
V < max test2 : { C(test2), B(test2), test2 < 2 }.
// third not equivalent
D(V) :-
B(V),
A(W),
W = min test1 : { C(test1), B(test1), test1 > count : { C(X), A(X) } },
V < max test2 : { C(test2), B(test2), test2 < 2 }.
.output D()
)",
errorReport, debugReport);

const auto& program = *tu->getProgram();
std::make_unique<MinimiseProgramTransformer>()->apply(*tu);

// A, B, C, D should still be the relations
EXPECT_EQ(4, program.getRelations().size());
EXPECT_NE(nullptr, getRelation(program, "A"));
EXPECT_NE(nullptr, getRelation(program, "B"));
EXPECT_NE(nullptr, getRelation(program, "C"));
EXPECT_NE(nullptr, getRelation(program, "D"));

// D should now only have the two clauses non-equivalent clauses
const auto& dClauses = getClauses(program, "D");
EXPECT_EQ(2, dClauses.size());
EXPECT_EQ(
"D(X) :- \n B(X),\n X < max Y : { C(Y),B(Y),Y < 2 },\n A(Z),\n Z = sum A : { C(A),B(A),A "
"> count : { A(M),C(M) } }.",
toString(*dClauses[0]));
EXPECT_EQ(
"D(V) :- \n B(V),\n A(W),\n W = min test1 : { C(test1),B(test1),test1 > count : { "
"C(X),A(X) } },\n V < max test2 : { C(test2),B(test2),test2 < 2 }.",
toString(*dClauses[1]));
}

/**
* Test the removal of redundancies within clauses using the MinimiseProgramTransformer.
*
Expand Down

0 comments on commit 01fab76

Please sign in to comment.