Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Arithmetic term rewriter #263

Merged
merged 10 commits into from
Apr 21, 2021
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Copyright (c) 2017-2019, the Alpha Team.
* Copyright (c) 2017-2020, the Alpha Team.
* All rights reserved.
*
* Additional changes made by Siemens.
Expand Down Expand Up @@ -27,16 +27,16 @@
*/
package at.ac.tuwien.kr.alpha.common.atoms;

import static at.ac.tuwien.kr.alpha.Util.join;
import at.ac.tuwien.kr.alpha.common.Predicate;
import at.ac.tuwien.kr.alpha.common.fixedinterpretations.PredicateInterpretation;
import at.ac.tuwien.kr.alpha.common.terms.Term;
import at.ac.tuwien.kr.alpha.grounder.Substitution;

import java.util.Collections;
import java.util.List;
import java.util.stream.Collectors;

import at.ac.tuwien.kr.alpha.common.Predicate;
import at.ac.tuwien.kr.alpha.common.fixedinterpretations.PredicateInterpretation;
import at.ac.tuwien.kr.alpha.common.terms.Term;
import at.ac.tuwien.kr.alpha.grounder.Substitution;
import static at.ac.tuwien.kr.alpha.Util.join;

public class ExternalAtom extends Atom implements VariableNormalizableAtom {

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,17 @@ public static Term getInstance(Term left, ArithmeticOperator arithmeticOperator,
return INTERNER.intern(new ArithmeticTerm(left, arithmeticOperator, right));
}

public Term getLeft() {
return left;
}

public Term getRight() {
return right;
}

public ArithmeticOperator getArithmeticOperator() {
return arithmeticOperator;
}

@Override
public boolean isGround() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,14 @@ public int getUpperBound() {
return this.upperBound;
}

public Term getLowerBoundTerm() {
return lowerBoundTerm;
}

public Term getUpperBoundTerm() {
return upperBoundTerm;
}

@Override
public List<VariableTerm> getOccurringVariables() {
LinkedList<VariableTerm> variables = new LinkedList<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ private NaiveGrounder(InternalProgram program, AtomStore atomStore, GrounderHeur
final Set<InternalRule> uniqueGroundRulePerGroundHead = getRulesWithUniqueHead();
choiceRecorder = new ChoiceRecorder(atomStore);
noGoodGenerator = new NoGoodGenerator(atomStore, choiceRecorder, factsFromProgram, this.program, uniqueGroundRulePerGroundHead);

this.debugInternalChecks = debugInternalChecks;

// Initialize RuleInstantiator and instantiation strategy. Note that the instantiation strategy also
Expand Down Expand Up @@ -276,10 +276,10 @@ public AnswerSet assignmentToAnswerSet(Iterable<Integer> trueAtoms) {
if (knownPredicates.isEmpty()) {
return BasicAnswerSet.EMPTY;
}

return new BasicAnswerSet(knownPredicates, predicateInstances);
}

/**
* Prepares facts of the input program for joining and derives all NoGoods representing ground rules. May only be called once.
* @return
Expand Down
20 changes: 20 additions & 0 deletions src/main/java/at/ac/tuwien/kr/alpha/grounder/Unification.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
package at.ac.tuwien.kr.alpha.grounder;

import at.ac.tuwien.kr.alpha.common.atoms.Atom;
import at.ac.tuwien.kr.alpha.common.terms.ArithmeticTerm;
import at.ac.tuwien.kr.alpha.common.terms.FunctionTerm;
import at.ac.tuwien.kr.alpha.common.terms.Term;
import at.ac.tuwien.kr.alpha.common.terms.VariableTerm;
Expand Down Expand Up @@ -110,6 +111,25 @@ private static boolean unifyTerms(Term left, Term right, Unifier currentSubstitu
}
return true;
}
if (leftSubs instanceof ArithmeticTerm && rightSubs instanceof ArithmeticTerm) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since Unification#unifyTerms is already pretty long and also recursive, would it make sense to split it up into a few methods? I was thinking something along the lines of unifyFunctionTerms, unifyArithmeticTerms, unifyWithVariableTerm, basically getting each of the ifs and accompanying recursive calls into their own method.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As we discussed, I do not think that splitting this up into multiple methods increases readability much.

// ArithmeticTerms are similar to FunctionTerms, i.e. if the operator is the same and its subterms unify, the ArithmeticTerms unify.
final ArithmeticTerm leftArithmeticTerm = (ArithmeticTerm) leftSubs;
final ArithmeticTerm rightArithmeticTerm = (ArithmeticTerm) rightSubs;
if (!leftArithmeticTerm.getArithmeticOperator().equals(rightArithmeticTerm.getArithmeticOperator())) {
return false;
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Codecov is complaining that this and the other return false;s in this method aren't covered by any tests. In general, I couldn't find any dedicated unit tests for Unification. I thínk it would make sense to create a new UnificationTest that specifically checks correct behavior of Unifications public methods (thereby covering the whole class).

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Test cases added, I hope those statements are now better covered.

}
final Term leftTermLeftSubterm = leftArithmeticTerm.getLeft();
final Term rightTermLeftSubterm = rightArithmeticTerm.getLeft();
if (!unifyTerms(leftTermLeftSubterm, rightTermLeftSubterm, currentSubstitution, keepLeftAsIs)) {
return false;
}
final Term leftTermRightSubterm = leftArithmeticTerm.getRight();
final Term rightTermRightSubterm = rightArithmeticTerm.getRight();
if (!unifyTerms(leftTermRightSubterm, rightTermRightSubterm, currentSubstitution, keepLeftAsIs)) {
return false;
}
return true;
}
return false;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
package at.ac.tuwien.kr.alpha.grounder.transformation;

import at.ac.tuwien.kr.alpha.common.ComparisonOperator;
import at.ac.tuwien.kr.alpha.common.atoms.Atom;
import at.ac.tuwien.kr.alpha.common.atoms.BasicAtom;
import at.ac.tuwien.kr.alpha.common.atoms.ComparisonAtom;
import at.ac.tuwien.kr.alpha.common.atoms.ExternalAtom;
import at.ac.tuwien.kr.alpha.common.atoms.Literal;
import at.ac.tuwien.kr.alpha.common.program.NormalProgram;
import at.ac.tuwien.kr.alpha.common.rule.NormalRule;
import at.ac.tuwien.kr.alpha.common.rule.head.NormalHead;
import at.ac.tuwien.kr.alpha.common.terms.ArithmeticTerm;
import at.ac.tuwien.kr.alpha.common.terms.ConstantTerm;
import at.ac.tuwien.kr.alpha.common.terms.FunctionTerm;
import at.ac.tuwien.kr.alpha.common.terms.IntervalTerm;
import at.ac.tuwien.kr.alpha.common.terms.Term;
import at.ac.tuwien.kr.alpha.common.terms.VariableTerm;
import at.ac.tuwien.kr.alpha.grounder.atoms.IntervalAtom;

import java.util.ArrayList;
import java.util.List;

import static at.ac.tuwien.kr.alpha.Util.oops;

/**
* Transforms rules such that arithmetic terms only occur in comparison predicates.
* For example p(X+1) :- q(Y/2), r(f(X*2),Y), X-2 = Y*3, X = 0..9. is transformed into
* p(_A1) :- q(_A2), r(f(_A3),Y), X-2 = Y*3, _A1 = X+1, _A2 = Y/2, _A3 = X*2, X = 0..9.
*
* Copyright (c) 2020, the Alpha Team.
*/
public class ArithmeticTermsRewriting extends ProgramTransformation<NormalProgram, NormalProgram> {
private static final String ARITHMETIC_VARIABLES_PREFIX = "_A";
private int numArithmeticVariables;

@Override
public NormalProgram apply(NormalProgram inputProgram) {
List<NormalRule> rewrittenRules = new ArrayList<>();
boolean didRewrite = false;
for (NormalRule inputProgramRule : inputProgram.getRules()) {
if (containsArithmeticTermsToRewrite(inputProgramRule)) {
rewrittenRules.add(rewriteRule(inputProgramRule));
didRewrite = true;
} else {
// Keep Rule as-is if no ArithmeticTerm occurs.
rewrittenRules.add(inputProgramRule);
}
}
if (!didRewrite) {
return inputProgram;
}
// Create new program with rewritten rules.
return new NormalProgram(rewrittenRules, inputProgram.getFacts(), inputProgram.getInlineDirectives());
}

private NormalRule rewriteRule(NormalRule inputProgramRule) {
numArithmeticVariables = 0; // Reset numbers for introduced variables for each rule.
NormalHead rewrittenHead = null;
List<Literal> rewrittenBodyLiterals = new ArrayList<>();
// Rewrite head.
if (!inputProgramRule.isConstraint()) {
Atom headAtom = inputProgramRule.getHeadAtom();
if (containsArithmeticTermsToRewrite(headAtom)) {
rewrittenHead = new NormalHead(rewriteAtom(headAtom, rewrittenBodyLiterals));
} else {
rewrittenHead = inputProgramRule.getHead();
}
}
// Rewrite body.
for (Literal literal : inputProgramRule.getBody()) {
if (!containsArithmeticTermsToRewrite(literal.getAtom())) {
// Keep body literal as-is if no ArithmeticTerm occurs.
rewrittenBodyLiterals.add(literal);
continue;
}
rewrittenBodyLiterals.add(rewriteAtom(literal.getAtom(), rewrittenBodyLiterals).toLiteral(!literal.isNegated()));
}
return new NormalRule(rewrittenHead, rewrittenBodyLiterals);
}

private boolean containsArithmeticTermsToRewrite(NormalRule inputProgramRule) {
if (!inputProgramRule.isConstraint()) {
Atom headAtom = inputProgramRule.getHeadAtom();
if (containsArithmeticTermsToRewrite(headAtom)) {
return true;
}
}
// Check whether body contains an ArithmeticTerm.
for (Literal literal : inputProgramRule.getBody()) {
if (containsArithmeticTermsToRewrite(literal.getAtom())) {
return true;
}
}
return false;
}

private Term rewriteArithmeticSubterms(Term term, List<Literal> bodyLiterals) {
// Keep term as-is if it contains no ArithmeticTerm.
if (!containsArithmeticTerm(term)) {
return term;
}
// Switch on term type.
if (term instanceof ArithmeticTerm) {
VariableTerm replacementVariable = VariableTerm.getInstance(ARITHMETIC_VARIABLES_PREFIX + numArithmeticVariables++);
bodyLiterals.add(new ComparisonAtom(replacementVariable, term, ComparisonOperator.EQ).toLiteral());
return replacementVariable;
} else if (term instanceof VariableTerm || term instanceof ConstantTerm) {
return term;
} else if (term instanceof FunctionTerm) {
List<Term> termList = ((FunctionTerm) term).getTerms();
List<Term> rewrittenTermList = new ArrayList<>();
for (Term subterm : termList) {
rewrittenTermList.add(rewriteArithmeticSubterms(subterm, bodyLiterals));
}
return FunctionTerm.getInstance(((FunctionTerm) term).getSymbol(), rewrittenTermList);
} else {
throw oops("Rewriting unknown Term type: " + term.getClass());
}
}

private Atom rewriteAtom(Atom atomToRewrite, List<Literal> bodyLiterals) {
List<Term> rewrittenTerms = new ArrayList<>();
for (Term atomTerm : atomToRewrite.getTerms()) {
// Rewrite arithmetic term.
rewrittenTerms.add(rewriteArithmeticSubterms(atomTerm, bodyLiterals));
}

// NOTE: we have to construct an atom of the same type as atomToRewrite, there should be a nicer way than that instanceof checks below.
if (atomToRewrite instanceof BasicAtom) {
return new BasicAtom(atomToRewrite.getPredicate(), rewrittenTerms);
} else if (atomToRewrite instanceof ComparisonAtom) {
throw oops("Trying to rewrite ComparisonAtom.");
} else if (atomToRewrite instanceof ExternalAtom) {
// Rewrite output terms, as so-far only the input terms list has been rewritten.
List<Term> rewrittenOutputTerms = new ArrayList<>();
for (Term term : ((ExternalAtom) atomToRewrite).getOutput()) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Codecov is complaining that the ExternalAtom part is not covered by any tests. Generally, I think it would make sense to have a dedicated ArithmeticTermsRewritingTest in addition to the existing (more "regession-test-like") ArithmeticTermsTest. The advantage of such a "real unit test" would be (in addition to codecov being able to measure coverage better) that a set of more fine-grained tests would also be an ideal starting point for debugging if ever any issues related to this code occurred.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ArithmeticTermsRewritingTest is now also included.

rewrittenOutputTerms.add(rewriteArithmeticSubterms(term, bodyLiterals));
}
return new ExternalAtom(atomToRewrite.getPredicate(), ((ExternalAtom) atomToRewrite).getInterpretation(), rewrittenTerms, rewrittenOutputTerms);
} else if (atomToRewrite instanceof IntervalAtom) {
return new IntervalAtom((IntervalTerm) rewrittenTerms.get(0), rewrittenTerms.get(1));
} else {
throw oops("Unknown Atom type: " + atomToRewrite.getClass());
}
}

private boolean containsArithmeticTermsToRewrite(Atom atom) {
// ComparisonAtom needs no rewriting.
if (atom instanceof ComparisonAtom) {
return false;
}
for (Term term : atom.getTerms()) {
if (containsArithmeticTerm(term)) {
return true;
}
}
return false;
}

private boolean containsArithmeticTerm(Term term) {
// Note: this check probably should be part of the Term interface and done by subtype polymorphism.
if (term instanceof ArithmeticTerm) {
return true;
} else if (term instanceof ConstantTerm || term instanceof VariableTerm) {
return false;
} else if (term instanceof FunctionTerm) {
for (Term subterm : ((FunctionTerm) term).getTerms()) {
if (containsArithmeticTerm(subterm)) {
return true;
}
}
return false;
} else if (term instanceof IntervalTerm) {
return containsArithmeticTerm(((IntervalTerm) term).getLowerBoundTerm()) || containsArithmeticTerm(((IntervalTerm) term).getUpperBoundTerm());
} else {
throw oops("Unexpected term type: " + term.getClass());
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ public NormalProgram apply(InputProgram inputProgram) {
NormalProgram retVal = NormalProgram.fromInputProgram(tmpPrg);
// Transform intervals - CAUTION - this MUST come before VariableEqualityRemoval!
retVal = new IntervalTermToIntervalAtom().apply(retVal);
// Rewrite ArithmeticTerms.
retVal = new ArithmeticTermsRewriting().apply(retVal);
// Remove variable equalities.
retVal = new VariableEqualityRemoval().apply(retVal);
return retVal;
Expand Down
2 changes: 1 addition & 1 deletion src/test/java/at/ac/tuwien/kr/alpha/api/AlphaTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ public void withExternalSubtype() throws Exception {

Alpha system = new Alpha();

InputProgram prog = new InputProgram(singletonList(rule), emptyList(), new InlineDirectives());
InputProgram prog = new InputProgram(new ArrayList<>(singleton(rule)), emptyList(), new InlineDirectives());

Set<AnswerSet> actual = system.solve(prog).collect(Collectors.toSet());
Set<AnswerSet> expected = new HashSet<>(singletonList(new AnswerSetBuilder().predicate("p").instance("x").build()));
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
package at.ac.tuwien.kr.alpha.solver;

import org.junit.Test;

import java.io.IOException;

/**
* Tests ASP programs containing arithmetic terms at arbitrary positions.
*
* Copyright (c) 2020, the Alpha Team.
*/
public class ArithmeticTermsTest extends AbstractSolverTests {

@Test
public void testArithmeticTermInHead() throws IOException {
String program = "dom(1). dom(2)."
+ "p(X+3) :- dom(X).";
assertAnswerSet(program, "dom(1),dom(2),p(4),p(5)");
}

@Test
public void testArithmeticTermInRule() throws IOException {
String program = "dom(1). dom(2)."
+ "p(Y+4) :- dom(X+1), dom(X), Y=X, X=Y.";
assertAnswerSet(program, "dom(1),dom(2),p(5)");
}

@Test
public void testArithmeticTermInChoiceRule() throws IOException {
String program = "cycle_max(4). cycle(1)." +
"{ cycle(N+1) } :- cycle(N), cycle_max(K), N<K.";
assertAnswerSetsWithBase(program, "cycle_max(4),cycle(1)", "", "cycle(2)", "cycle(2),cycle(3)", "cycle(2),cycle(3),cycle(4)");
}

@Test
public void testMultipleArithmeticTermsInRules() throws IOException {
String program = "q(1). q(3). r(f(40),6)." +
"p(X+1) :- q(Y/2), r(f(X*2),Y), X-2 = Y*3, X = 0..20." +
"bar(X,Y) :- q(Y/2), r(f(X*2),Y), X-2 = Y*3, X = 20.";
assertAnswerSet(program, "q(1),q(3),r(f(40),6),p(21),bar(20,6)");
}
}