-
Notifications
You must be signed in to change notification settings - Fork 10
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Arithmetic term rewriter #263
Changes from 6 commits
043a448
0782bce
128f0c4
29e713d
0dd4656
0f8ff14
00b5594
7ed0c23
3b59689
52e895a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -28,6 +28,7 @@ | |
package at.ac.tuwien.kr.alpha.grounder; | ||
|
||
import at.ac.tuwien.kr.alpha.common.atoms.Atom; | ||
import at.ac.tuwien.kr.alpha.common.terms.ArithmeticTerm; | ||
import at.ac.tuwien.kr.alpha.common.terms.FunctionTerm; | ||
import at.ac.tuwien.kr.alpha.common.terms.Term; | ||
import at.ac.tuwien.kr.alpha.common.terms.VariableTerm; | ||
|
@@ -110,6 +111,25 @@ private static boolean unifyTerms(Term left, Term right, Unifier currentSubstitu | |
} | ||
return true; | ||
} | ||
if (leftSubs instanceof ArithmeticTerm && rightSubs instanceof ArithmeticTerm) { | ||
// ArithmeticTerms are similar to FunctionTerms, i.e. if the operator is the same and its subterms unify, the ArithmeticTerms unify. | ||
final ArithmeticTerm leftArithmeticTerm = (ArithmeticTerm) leftSubs; | ||
final ArithmeticTerm rightArithmeticTerm = (ArithmeticTerm) rightSubs; | ||
if (!leftArithmeticTerm.getArithmeticOperator().equals(rightArithmeticTerm.getArithmeticOperator())) { | ||
return false; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Codecov is complaining that this and the other There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Test cases added, I hope those statements are now better covered. |
||
} | ||
final Term leftTermLeftSubterm = leftArithmeticTerm.getLeft(); | ||
final Term rightTermLeftSubterm = rightArithmeticTerm.getLeft(); | ||
if (!unifyTerms(leftTermLeftSubterm, rightTermLeftSubterm, currentSubstitution, keepLeftAsIs)) { | ||
return false; | ||
} | ||
final Term leftTermRightSubterm = leftArithmeticTerm.getRight(); | ||
final Term rightTermRightSubterm = rightArithmeticTerm.getRight(); | ||
if (!unifyTerms(leftTermRightSubterm, rightTermRightSubterm, currentSubstitution, keepLeftAsIs)) { | ||
return false; | ||
} | ||
return true; | ||
} | ||
return false; | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,179 @@ | ||
package at.ac.tuwien.kr.alpha.grounder.transformation; | ||
|
||
import at.ac.tuwien.kr.alpha.common.ComparisonOperator; | ||
import at.ac.tuwien.kr.alpha.common.atoms.Atom; | ||
import at.ac.tuwien.kr.alpha.common.atoms.BasicAtom; | ||
import at.ac.tuwien.kr.alpha.common.atoms.ComparisonAtom; | ||
import at.ac.tuwien.kr.alpha.common.atoms.ExternalAtom; | ||
import at.ac.tuwien.kr.alpha.common.atoms.Literal; | ||
import at.ac.tuwien.kr.alpha.common.program.NormalProgram; | ||
import at.ac.tuwien.kr.alpha.common.rule.NormalRule; | ||
import at.ac.tuwien.kr.alpha.common.rule.head.NormalHead; | ||
import at.ac.tuwien.kr.alpha.common.terms.ArithmeticTerm; | ||
import at.ac.tuwien.kr.alpha.common.terms.ConstantTerm; | ||
import at.ac.tuwien.kr.alpha.common.terms.FunctionTerm; | ||
import at.ac.tuwien.kr.alpha.common.terms.IntervalTerm; | ||
import at.ac.tuwien.kr.alpha.common.terms.Term; | ||
import at.ac.tuwien.kr.alpha.common.terms.VariableTerm; | ||
import at.ac.tuwien.kr.alpha.grounder.atoms.IntervalAtom; | ||
|
||
import java.util.ArrayList; | ||
import java.util.List; | ||
|
||
import static at.ac.tuwien.kr.alpha.Util.oops; | ||
|
||
/** | ||
* Transforms rules such that arithmetic terms only occur in comparison predicates. | ||
* For example p(X+1) :- q(Y/2), r(f(X*2),Y), X-2 = Y*3, X = 0..9. is transformed into | ||
* p(_A1) :- q(_A2), r(f(_A3),Y), X-2 = Y*3, _A1 = X+1, _A2 = Y/2, _A3 = X*2, X = 0..9. | ||
* | ||
* Copyright (c) 2020, the Alpha Team. | ||
*/ | ||
public class ArithmeticTermsRewriting extends ProgramTransformation<NormalProgram, NormalProgram> { | ||
private static final String ARITHMETIC_VARIABLES_PREFIX = "_A"; | ||
private int numArithmeticVariables; | ||
|
||
@Override | ||
public NormalProgram apply(NormalProgram inputProgram) { | ||
List<NormalRule> rewrittenRules = new ArrayList<>(); | ||
boolean didRewrite = false; | ||
for (NormalRule inputProgramRule : inputProgram.getRules()) { | ||
if (containsArithmeticTermsToRewrite(inputProgramRule)) { | ||
rewrittenRules.add(rewriteRule(inputProgramRule)); | ||
didRewrite = true; | ||
} else { | ||
// Keep Rule as-is if no ArithmeticTerm occurs. | ||
rewrittenRules.add(inputProgramRule); | ||
} | ||
} | ||
if (!didRewrite) { | ||
return inputProgram; | ||
} | ||
// Create new program with rewritten rules. | ||
return new NormalProgram(rewrittenRules, inputProgram.getFacts(), inputProgram.getInlineDirectives()); | ||
} | ||
|
||
private NormalRule rewriteRule(NormalRule inputProgramRule) { | ||
numArithmeticVariables = 0; // Reset numbers for introduced variables for each rule. | ||
NormalHead rewrittenHead = null; | ||
List<Literal> rewrittenBodyLiterals = new ArrayList<>(); | ||
// Rewrite head. | ||
if (!inputProgramRule.isConstraint()) { | ||
Atom headAtom = inputProgramRule.getHeadAtom(); | ||
if (containsArithmeticTermsToRewrite(headAtom)) { | ||
rewrittenHead = new NormalHead(rewriteAtom(headAtom, rewrittenBodyLiterals)); | ||
} else { | ||
rewrittenHead = inputProgramRule.getHead(); | ||
} | ||
} | ||
// Rewrite body. | ||
for (Literal literal : inputProgramRule.getBody()) { | ||
if (!containsArithmeticTermsToRewrite(literal.getAtom())) { | ||
// Keep body literal as-is if no ArithmeticTerm occurs. | ||
rewrittenBodyLiterals.add(literal); | ||
continue; | ||
} | ||
rewrittenBodyLiterals.add(rewriteAtom(literal.getAtom(), rewrittenBodyLiterals).toLiteral(!literal.isNegated())); | ||
} | ||
return new NormalRule(rewrittenHead, rewrittenBodyLiterals); | ||
} | ||
|
||
private boolean containsArithmeticTermsToRewrite(NormalRule inputProgramRule) { | ||
if (!inputProgramRule.isConstraint()) { | ||
Atom headAtom = inputProgramRule.getHeadAtom(); | ||
if (containsArithmeticTermsToRewrite(headAtom)) { | ||
return true; | ||
} | ||
} | ||
// Check whether body contains an ArithmeticTerm. | ||
for (Literal literal : inputProgramRule.getBody()) { | ||
if (containsArithmeticTermsToRewrite(literal.getAtom())) { | ||
return true; | ||
} | ||
} | ||
return false; | ||
} | ||
|
||
private Term rewriteArithmeticSubterms(Term term, List<Literal> bodyLiterals) { | ||
// Keep term as-is if it contains no ArithmeticTerm. | ||
if (!containsArithmeticTerm(term)) { | ||
return term; | ||
} | ||
// Switch on term type. | ||
if (term instanceof ArithmeticTerm) { | ||
VariableTerm replacementVariable = VariableTerm.getInstance(ARITHMETIC_VARIABLES_PREFIX + numArithmeticVariables++); | ||
bodyLiterals.add(new ComparisonAtom(replacementVariable, term, ComparisonOperator.EQ).toLiteral()); | ||
return replacementVariable; | ||
} else if (term instanceof VariableTerm || term instanceof ConstantTerm) { | ||
return term; | ||
} else if (term instanceof FunctionTerm) { | ||
List<Term> termList = ((FunctionTerm) term).getTerms(); | ||
List<Term> rewrittenTermList = new ArrayList<>(); | ||
for (Term subterm : termList) { | ||
rewrittenTermList.add(rewriteArithmeticSubterms(subterm, bodyLiterals)); | ||
} | ||
return FunctionTerm.getInstance(((FunctionTerm) term).getSymbol(), rewrittenTermList); | ||
} else { | ||
throw oops("Rewriting unknown Term type: " + term.getClass()); | ||
} | ||
} | ||
|
||
private Atom rewriteAtom(Atom atomToRewrite, List<Literal> bodyLiterals) { | ||
List<Term> rewrittenTerms = new ArrayList<>(); | ||
for (Term atomTerm : atomToRewrite.getTerms()) { | ||
// Rewrite arithmetic term. | ||
rewrittenTerms.add(rewriteArithmeticSubterms(atomTerm, bodyLiterals)); | ||
} | ||
|
||
// NOTE: we have to construct an atom of the same type as atomToRewrite, there should be a nicer way than that instanceof checks below. | ||
if (atomToRewrite instanceof BasicAtom) { | ||
return new BasicAtom(atomToRewrite.getPredicate(), rewrittenTerms); | ||
} else if (atomToRewrite instanceof ComparisonAtom) { | ||
throw oops("Trying to rewrite ComparisonAtom."); | ||
} else if (atomToRewrite instanceof ExternalAtom) { | ||
// Rewrite output terms, as so-far only the input terms list has been rewritten. | ||
List<Term> rewrittenOutputTerms = new ArrayList<>(); | ||
for (Term term : ((ExternalAtom) atomToRewrite).getOutput()) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Codecov is complaining that the There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
rewrittenOutputTerms.add(rewriteArithmeticSubterms(term, bodyLiterals)); | ||
} | ||
return new ExternalAtom(atomToRewrite.getPredicate(), ((ExternalAtom) atomToRewrite).getInterpretation(), rewrittenTerms, rewrittenOutputTerms); | ||
} else if (atomToRewrite instanceof IntervalAtom) { | ||
return new IntervalAtom((IntervalTerm) rewrittenTerms.get(0), rewrittenTerms.get(1)); | ||
} else { | ||
throw oops("Unknown Atom type: " + atomToRewrite.getClass()); | ||
} | ||
} | ||
|
||
private boolean containsArithmeticTermsToRewrite(Atom atom) { | ||
// ComparisonAtom needs no rewriting. | ||
if (atom instanceof ComparisonAtom) { | ||
return false; | ||
} | ||
for (Term term : atom.getTerms()) { | ||
if (containsArithmeticTerm(term)) { | ||
return true; | ||
} | ||
} | ||
return false; | ||
} | ||
|
||
private boolean containsArithmeticTerm(Term term) { | ||
// Note: this check probably should be part of the Term interface and done by subtype polymorphism. | ||
if (term instanceof ArithmeticTerm) { | ||
return true; | ||
} else if (term instanceof ConstantTerm || term instanceof VariableTerm) { | ||
return false; | ||
} else if (term instanceof FunctionTerm) { | ||
for (Term subterm : ((FunctionTerm) term).getTerms()) { | ||
if (containsArithmeticTerm(subterm)) { | ||
return true; | ||
} | ||
} | ||
return false; | ||
} else if (term instanceof IntervalTerm) { | ||
return containsArithmeticTerm(((IntervalTerm) term).getLowerBoundTerm()) || containsArithmeticTerm(((IntervalTerm) term).getUpperBoundTerm()); | ||
} else { | ||
throw oops("Unexpected term type: " + term.getClass()); | ||
} | ||
} | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,42 @@ | ||
package at.ac.tuwien.kr.alpha.solver; | ||
|
||
import org.junit.Test; | ||
|
||
import java.io.IOException; | ||
|
||
/** | ||
* Tests ASP programs containing arithmetic terms at arbitrary positions. | ||
* | ||
* Copyright (c) 2020, the Alpha Team. | ||
*/ | ||
public class ArithmeticTermsTest extends AbstractSolverTests { | ||
|
||
@Test | ||
public void testArithmeticTermInHead() throws IOException { | ||
String program = "dom(1). dom(2)." | ||
+ "p(X+3) :- dom(X)."; | ||
assertAnswerSet(program, "dom(1),dom(2),p(4),p(5)"); | ||
} | ||
|
||
@Test | ||
public void testArithmeticTermInRule() throws IOException { | ||
String program = "dom(1). dom(2)." | ||
+ "p(Y+4) :- dom(X+1), dom(X), Y=X, X=Y."; | ||
assertAnswerSet(program, "dom(1),dom(2),p(5)"); | ||
} | ||
|
||
@Test | ||
public void testArithmeticTermInChoiceRule() throws IOException { | ||
String program = "cycle_max(4). cycle(1)." + | ||
"{ cycle(N+1) } :- cycle(N), cycle_max(K), N<K."; | ||
assertAnswerSetsWithBase(program, "cycle_max(4),cycle(1)", "", "cycle(2)", "cycle(2),cycle(3)", "cycle(2),cycle(3),cycle(4)"); | ||
} | ||
|
||
@Test | ||
public void testMultipleArithmeticTermsInRules() throws IOException { | ||
String program = "q(1). q(3). r(f(40),6)." + | ||
"p(X+1) :- q(Y/2), r(f(X*2),Y), X-2 = Y*3, X = 0..20." + | ||
"bar(X,Y) :- q(Y/2), r(f(X*2),Y), X-2 = Y*3, X = 20."; | ||
assertAnswerSet(program, "q(1),q(3),r(f(40),6),p(21),bar(20,6)"); | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since
Unification#unifyTerms
is already pretty long and also recursive, would it make sense to split it up into a few methods? I was thinking something along the lines ofunifyFunctionTerms
,unifyArithmeticTerms
,unifyWithVariableTerm
, basically getting each of theif
s and accompanying recursive calls into their own method.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
As we discussed, I do not think that splitting this up into multiple methods increases readability much.