Skip to content

Commit

Permalink
Add more unit test for ArithmeticTermsRewriting.
Browse files Browse the repository at this point in the history
- Polished ArithmeticTermsRewriting a bit.
- Add ArithmeticTermsRewritingTest, also testing with ExternalAtoms.
  • Loading branch information
AntoniusW committed Apr 15, 2021
1 parent 00b5594 commit 7ed0c23
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
* For example p(X+1) :- q(Y/2), r(f(X*2),Y), X-2 = Y*3, X = 0..9. is transformed into
* p(_A1) :- q(_A2), r(f(_A3),Y), X-2 = Y*3, _A1 = X+1, _A2 = Y/2, _A3 = X*2, X = 0..9.
*
* Copyright (c) 2020, the Alpha Team.
* Copyright (c) 2020-2021, the Alpha Team.
*/
public class ArithmeticTermsRewriting extends ProgramTransformation<NormalProgram, NormalProgram> {
private static final String ARITHMETIC_VARIABLES_PREFIX = "_A";
Expand All @@ -42,7 +42,7 @@ public NormalProgram apply(NormalProgram inputProgram) {
rewrittenRules.add(rewriteRule(inputProgramRule));
didRewrite = true;
} else {
// Keep Rule as-is if no ArithmeticTerm occurs.
// Keep rule as-is if no ArithmeticTerm occurs.
rewrittenRules.add(inputProgramRule);
}
}
Expand All @@ -53,8 +53,14 @@ public NormalProgram apply(NormalProgram inputProgram) {
return new NormalProgram(rewrittenRules, inputProgram.getFacts(), inputProgram.getInlineDirectives());
}

/**
* Takes a normal rule and rewrites it such that {@link ArithmeticTerm}s only appear inside {@link at.ac.tuwien.kr.alpha.common.atoms.ComparisonLiteral}s.
*
* @param inputProgramRule the rule to rewrite.
* @return the rewritten rule. Note that a new {@link NormalRule} is returned for every call of this method.
*/
private NormalRule rewriteRule(NormalRule inputProgramRule) {
numArithmeticVariables = 0; // Reset numbers for introduced variables for each rule.
numArithmeticVariables = 0; // Reset number of introduced variables for each rule.
NormalHead rewrittenHead = null;
List<Literal> rewrittenBodyLiterals = new ArrayList<>();
// Rewrite head.
Expand All @@ -78,6 +84,12 @@ private NormalRule rewriteRule(NormalRule inputProgramRule) {
return new NormalRule(rewrittenHead, rewrittenBodyLiterals);
}

/**
* Checks whether a normal rule contains an {@link ArithmeticTerm} outside of a {@link at.ac.tuwien.kr.alpha.common.atoms.ComparisonLiteral}.
*
* @param inputProgramRule the rule to check for presence of arithmetic terms outside comparison literals.
* @return true if the inputProgramRule contains an {@link ArithmeticTerm} outside of a {@link at.ac.tuwien.kr.alpha.common.atoms.ComparisonLiteral}.
*/
private boolean containsArithmeticTermsToRewrite(NormalRule inputProgramRule) {
if (!inputProgramRule.isConstraint()) {
Atom headAtom = inputProgramRule.getHeadAtom();
Expand Down Expand Up @@ -127,7 +139,7 @@ private Atom rewriteAtom(Atom atomToRewrite, List<Literal> bodyLiterals) {

// NOTE: we have to construct an atom of the same type as atomToRewrite, there should be a nicer way than that instanceof checks below.
if (atomToRewrite instanceof BasicAtom) {
return new BasicAtom(atomToRewrite.getPredicate(), rewrittenTerms);
return atomToRewrite.withTerms(rewrittenTerms);
} else if (atomToRewrite instanceof ComparisonAtom) {
throw oops("Trying to rewrite ComparisonAtom.");
} else if (atomToRewrite instanceof ExternalAtom) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
package at.ac.tuwien.kr.alpha.grounder.transformation;

import at.ac.tuwien.kr.alpha.api.externals.Externals;
import at.ac.tuwien.kr.alpha.api.externals.Predicate;
import at.ac.tuwien.kr.alpha.common.atoms.ExternalAtom;
import at.ac.tuwien.kr.alpha.common.atoms.ExternalLiteral;
import at.ac.tuwien.kr.alpha.common.atoms.Literal;
import at.ac.tuwien.kr.alpha.common.fixedinterpretations.PredicateInterpretation;
import at.ac.tuwien.kr.alpha.common.program.NormalProgram;
import at.ac.tuwien.kr.alpha.common.rule.NormalRule;
import at.ac.tuwien.kr.alpha.common.terms.ConstantTerm;
import at.ac.tuwien.kr.alpha.common.terms.Terms;
import at.ac.tuwien.kr.alpha.common.terms.VariableTerm;
import at.ac.tuwien.kr.alpha.grounder.parser.ProgramParser;
import org.junit.Test;

import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.Set;

import static java.util.stream.Collectors.toList;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;

/**
* Copyright (c) 2021, the Alpha Team.
*/
public class ArithmeticTermsRewritingTest {

private final Map<String, PredicateInterpretation> externalsOfThisClass = Externals.scan(ArithmeticTermsRewritingTest.class);
private final ProgramParser parser = new ProgramParser(externalsOfThisClass); // Create parser that knows an implementation of external atom &extArithTest[]().

@Predicate(name = "extArithTest")
public static Set<List<ConstantTerm<Integer>>> externalForArithmeticTermsRewriting(Integer in) {
List<ConstantTerm<Integer>> terms = Terms.asTermList(
in * 314);
return Collections.singleton(terms);
}

@Test
public void rewriteRule() {
NormalProgram inputProgram = NormalProgram.fromInputProgram(parser.parse("p(X+1) :- q(Y/2), r(f(X*2),Y), X-2 = Y*3, X = 0..9."));
assertEquals(1, inputProgram.getRules().size());
ArithmeticTermsRewriting arithmeticTermsRewriting = new ArithmeticTermsRewriting();
NormalProgram rewrittenProgram = arithmeticTermsRewriting.apply(inputProgram);
// Expect the rewritten program to be one rule with: p(_A0) :- _A0 = X+1, _A1 = Y/2, q(_A1), _A2 = X*2, r(f(_A2),Y), X-2 = Y*3, X = 0..9.
assertEquals(1, rewrittenProgram.getRules().size());
NormalRule rewrittenRule = rewrittenProgram.getRules().get(0);
assertTrue(rewrittenRule.getHeadAtom().getTerms().get(0) instanceof VariableTerm);
assertEquals(7, rewrittenRule.getBody().size());
}

@Test
public void rewriteExternalAtom() {
NormalProgram inputProgram = NormalProgram.fromInputProgram(parser.parse("p :- Y = 13, &extArithTest[Y*5](Y-4)."));
assertEquals(1, inputProgram.getRules().size());
ArithmeticTermsRewriting arithmeticTermsRewriting = new ArithmeticTermsRewriting();
NormalProgram rewrittenProgram = arithmeticTermsRewriting.apply(inputProgram);
assertEquals(1, rewrittenProgram.getRules().size());
NormalRule rewrittenRule = rewrittenProgram.getRules().get(0);
assertEquals(4, rewrittenRule.getBody().size());
List<Literal> externalLiterals = rewrittenRule.getBody().stream().filter(lit -> lit instanceof ExternalLiteral).collect(toList());
assertEquals(1, externalLiterals.size());
ExternalAtom rewrittenExternal = ((ExternalLiteral) externalLiterals.get(0)).getAtom();
assertEquals(1, rewrittenExternal.getInput().size());
assertTrue(rewrittenExternal.getInput().get(0) instanceof VariableTerm);
assertEquals(1, rewrittenExternal.getOutput().size());
assertTrue(rewrittenExternal.getOutput().get(0) instanceof VariableTerm);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@

import org.junit.Test;

import java.io.IOException;

/**
* Tests ASP programs containing arithmetic terms at arbitrary positions.
*
Expand All @@ -12,31 +10,45 @@
public class ArithmeticTermsTest extends AbstractSolverTests {

@Test
public void testArithmeticTermInHead() throws IOException {
public void testArithmeticTermInHead() {
String program = "dom(1). dom(2)."
+ "p(X+3) :- dom(X).";
assertAnswerSet(program, "dom(1),dom(2),p(4),p(5)");
}

@Test
public void testArithmeticTermInRule() throws IOException {
public void testArithmeticTermInRule() {
String program = "dom(1). dom(2)."
+ "p(Y+4) :- dom(X+1), dom(X), Y=X, X=Y.";
assertAnswerSet(program, "dom(1),dom(2),p(5)");
}

@Test
public void testArithmeticTermInChoiceRule() throws IOException {
public void testArithmeticTermInChoiceRule() {
String program = "cycle_max(4). cycle(1)." +
"{ cycle(N+1) } :- cycle(N), cycle_max(K), N<K.";
assertAnswerSetsWithBase(program, "cycle_max(4),cycle(1)", "", "cycle(2)", "cycle(2),cycle(3)", "cycle(2),cycle(3),cycle(4)");
}

@Test
public void testMultipleArithmeticTermsInRules() throws IOException {
public void testMultipleArithmeticTermsInRules() {
String program = "q(1). q(3). r(f(40),6)." +
"p(X+1) :- q(Y/2), r(f(X*2),Y), X-2 = Y*3, X = 0..20." +
"bar(X,Y) :- q(Y/2), r(f(X*2),Y), X-2 = Y*3, X = 20.";
assertAnswerSet(program, "q(1),q(3),r(f(40),6),p(21),bar(20,6)");
}

@Test
public void testMultipleArithmeticTermsInFunctionTermsInHead() {
String program = "dom(1). dom(2)."
+ "p(f(X+1),g(X+3)) :- dom(X).";
assertAnswerSet(program, "dom(1),dom(2),p(f(2),g(4)),p(f(3),g(5))");
}

@Test
public void testMultipleNestedArithmeticTermsInRules() {
String program = "domx(1). domx(2). domy(6)."
+ "p(f(X+(Y/2))) :- domx(X), domy(Y).";
assertAnswerSet(program, "domx(1),domx(2),domy(6),p(f(4)),p(f(5))");
}
}

0 comments on commit 7ed0c23

Please sign in to comment.