Skip to content

Commit

Permalink
improve suggestion for repeated math operations
Browse files Browse the repository at this point in the history
  • Loading branch information
Luro02 committed Apr 6, 2024
1 parent df1047b commit fbba113
Show file tree
Hide file tree
Showing 3 changed files with 267 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,27 @@
import de.firemage.autograder.core.integrated.IntegratedCheck;
import de.firemage.autograder.core.integrated.SpoonUtil;
import de.firemage.autograder.core.integrated.StaticAnalysis;
import de.firemage.autograder.core.integrated.evaluator.Evaluator;
import de.firemage.autograder.core.integrated.evaluator.fold.Fold;
import spoon.processing.AbstractProcessor;
import spoon.reflect.code.BinaryOperatorKind;
import spoon.reflect.code.CtBinaryOperator;
import spoon.reflect.code.CtExpression;
import spoon.reflect.code.CtFieldRead;
import spoon.reflect.code.CtVariableRead;
import spoon.reflect.code.UnaryOperatorKind;
import spoon.reflect.declaration.CtElement;
import spoon.reflect.factory.TypeFactory;
import spoon.reflect.reference.CtVariableReference;

import java.util.Comparator;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.BiFunction;
import java.util.stream.Collectors;
import java.util.stream.Stream;

Expand All @@ -25,54 +36,156 @@ public class RepeatedMathOperationCheck extends IntegratedCheck {
private static final Map<BinaryOperatorKind, Integer> OCCURRENCE_THRESHOLDS =
Map.of(BinaryOperatorKind.PLUS, 2, BinaryOperatorKind.MUL, 3);

private record Variable(CtVariableReference<?> ctVariableReference, CtExpression<?> target) {}
private record Variable(CtVariableReference<?> ctVariableReference, CtExpression<?> target) {
}


private static List<CtExpression<?>> splitOperator(CtBinaryOperator<?> ctBinaryOperator, BinaryOperatorKind kind) {
List<CtExpression<?>> result = new ArrayList<>();

if (ctBinaryOperator.getKind() != kind) {
return new ArrayList<>(List.of(ctBinaryOperator));
}

CtExpression<?> left = ctBinaryOperator.getLeftHandOperand();
CtExpression<?> right = ctBinaryOperator.getRightHandOperand();

// The right hand side can also be a binary operator, e.g. a + (b + c)
if (right instanceof CtBinaryOperator<?> rightOperator) {
List<CtExpression<?>> rightOperands = splitOperator(rightOperator, kind);
// the right operands have to be added in reverse order
Collections.reverse(rightOperands);
result.addAll(rightOperands);
} else {
result.add(right);
}

while (left instanceof CtBinaryOperator<?> lhs && lhs.getKind() == kind) {
result.add(lhs.getRightHandOperand());
left = lhs.getLeftHandOperand();
}

result.add(left);

Collections.reverse(result);

return result;
}

/**
* This class optimizes repeated operations like `a + a + a + a` to `a * 4`.
*/
private record OperatorFolder(BinaryOperatorKind kind, int threshold, BiFunction<CtExpression<?>, Integer, CtExpression<?>> function) implements Fold {
@Override
public CtElement enter(CtElement ctElement) {
return this.fold(ctElement);
}

@Override
public CtElement exit(CtElement ctElement) {
return ctElement;
}


@Override
@SuppressWarnings("unchecked")
public <T> CtExpression<T> foldCtBinaryOperator(CtBinaryOperator<T> ctBinaryOperator) {
// skip if the operator is not supported
if (!OCCURRENCE_THRESHOLDS.containsKey(ctBinaryOperator.getKind()) ||
!SpoonUtil.isPrimitiveNumeric(ctBinaryOperator.getType())) {
return ctBinaryOperator;
}

List<CtExpression<?>> operands = splitOperator(ctBinaryOperator, this.kind);

Map<CtExpression<?>, Integer> occurrences = operands.stream()
.collect(Collectors.toMap(o -> o, o -> 1, Integer::sum, LinkedHashMap::new));

// reconstruct the binary operator (note: this will destroy the original parentheses)
return (CtExpression<T>) occurrences.entrySet()
.stream()
.map(entry -> {
var expression = entry.getKey();
int count = entry.getValue();

if (count < this.threshold) {
return repeatExpression(this.kind, expression, count);
} else {
return this.function.apply(expression, count);
}
})
.reduce((left, right) -> SpoonUtil.createBinaryOperator(left, right, this.kind))
.orElseThrow();
}
}


public static CtExpression<?> repeatExpression(BinaryOperatorKind kind, CtExpression<?> expression, int count) {
CtExpression<?>[] array = new CtExpression<?>[count - 1];
Arrays.fill(array, expression);
return joinExpressions(kind, expression, array);
}

public static CtExpression<?> joinExpressions(BinaryOperatorKind kind, CtExpression<?> first, CtExpression<?>... others) {
return Arrays.stream(others).reduce(first, (left, right) -> SpoonUtil.createBinaryOperator(left, right, kind));
}

@Override
protected void check(StaticAnalysis staticAnalysis, DynamicAnalysis dynamicAnalysis) {
staticAnalysis.processWith(new AbstractProcessor<CtBinaryOperator<?>>() {
staticAnalysis.processWith(new AbstractProcessor<CtExpression<?>>() {
@Override
public void process(CtBinaryOperator<?> operator) {
if (!OCCURRENCE_THRESHOLDS.containsKey(operator.getKind())) {
public void process(CtExpression<?> ctExpression) {
if (ctExpression.isImplicit()
|| !ctExpression.getPosition().isValidPosition()
// we only want to look at top level expressions:
|| ctExpression.getParent(CtExpression.class) != null) {
return;
}

// Only look at the top statement
if (operator.getParent() instanceof CtBinaryOperator<?> parent &&
parent.getKind() == operator.getKind()) {
return;
}

var occurrences = countOccurrences(operator, operator.getKind());

var optionalVariable = occurrences.entrySet().stream()
.filter(e -> e.getValue() >= OCCURRENCE_THRESHOLDS.get(operator.getKind()))
.max(Comparator.comparingInt(Map.Entry::getValue));

optionalVariable.ifPresent(ctVariableReferenceIntegerEntry -> {
Variable variable = ctVariableReferenceIntegerEntry.getKey();
String variableName = "%s".formatted(variable.ctVariableReference().getSimpleName());
if (variable.target() != null) {
variableName = "%s.%s".formatted(
variable.target().prettyprint(),
variable.ctVariableReference().getSimpleName()
AtomicInteger plusOptimizations = new AtomicInteger();
AtomicInteger mulOptimizations = new AtomicInteger();

Fold plusFolder = new OperatorFolder(
BinaryOperatorKind.PLUS,
OCCURRENCE_THRESHOLDS.get(BinaryOperatorKind.PLUS),
(expression, count) -> {
plusOptimizations.addAndGet(1);
return SpoonUtil.createBinaryOperator(
expression,
SpoonUtil.makeLiteralNumber(expression.getType(), count),
BinaryOperatorKind.MUL
);
}

int count = ctVariableReferenceIntegerEntry.getValue();
String suggestion = "%s * %d".formatted(variableName, count);
if (operator.getKind() == BinaryOperatorKind.MUL) {
suggestion = "Math.pow(%s, %d)".formatted(variableName, count);
);

Fold mulFolder = new OperatorFolder(
BinaryOperatorKind.MUL,
OCCURRENCE_THRESHOLDS.get(BinaryOperatorKind.MUL),
(expression, count) -> {
TypeFactory typeFactory = expression.getFactory().Type();
mulOptimizations.addAndGet(1);
return SpoonUtil.createStaticInvocation(
typeFactory.get(java.lang.Math.class).getReference(),
"pow",
expression,
SpoonUtil.makeLiteralNumber(typeFactory.INTEGER_PRIMITIVE, count)
);
}
);

CtExpression<?> suggestion = new Evaluator(plusFolder).evaluate(ctExpression);
suggestion = new Evaluator(mulFolder).evaluate(suggestion);

if (plusOptimizations.get() > 0 || mulOptimizations.get() > 0) {
addLocalProblem(
operator,
ctExpression,
new LocalizedMessage(
"common-reimplementation",
Map.of("suggestion", suggestion)
),
ProblemType.REPEATED_MATH_OPERATION
);
});
}
}
});
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,13 @@ public static boolean areLiteralsEqual(
return valLeft.longValue() == valRight.longValue();
}

@SuppressWarnings("unchecked")
public static <T> CtLiteral<T> makeLiteralNumber(CtTypeReference<T> ctTypeReference, Number number) {
Object value = FoldUtils.convert(ctTypeReference, number);

return SpoonUtil.makeLiteral(ctTypeReference, (T) value);
}

/**
* Makes a new literal with the given value and type.
*
Expand Down Expand Up @@ -733,10 +740,17 @@ public static <T> CtInvocation<T> createStaticInvocation(
CtExpression<?>... parameters
) {
Factory factory = targetType.getFactory();
CtMethod<T> methodHandle = targetType.getTypeDeclaration().getMethod(
methodName,
Arrays.stream(parameters).map(SpoonUtil::getExpressionType).toArray(CtTypeReference[]::new)
);

CtMethod<T> methodHandle = null;
List<CtMethod<?>> potentialMethods = targetType.getTypeDeclaration().getMethodsByName(methodName);
if (potentialMethods.size() == 1) {
methodHandle = (CtMethod<T>) potentialMethods.get(0);
} else {
methodHandle = targetType.getTypeDeclaration().getMethod(
methodName,
Arrays.stream(parameters).map(SpoonUtil::getExpressionType).toArray(CtTypeReference[]::new)
);
}

return factory.createInvocation(
factory.createTypeAccess(methodHandle.getDeclaringType().getReference()),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,70 @@ public static void main(String[] args) {
problems.assertExhausted();
}

@Test
void testMinimumThresholdPlus() throws IOException, LinterException {
ProblemIterator problems = this.checkIterator(StringSourceInfo.fromSourceString(
JavaVersion.JAVA_17,
"Test",
"""
public class Test {
public static void main(String[] args) {
int[] array = new int[10];
int[] array2 = new int[10];
int a = 4;
int b = 3;
int c = a + a;
int d = b + b + b;
int e = array.length + array.length + array.length + array.length;
int f = array2.length + array2.length + array2.length + array2.length + array2.length;
}
}
"""
), PROBLEM_TYPES);

assertEqualsRepeat("a * 2", problems.next());
assertEqualsRepeat("b * 3", problems.next());
assertEqualsRepeat("array.length * 4", problems.next());
assertEqualsRepeat("array2.length * 5", problems.next());

problems.assertExhausted();
}

@Test
void testMinimumThresholdMul() throws IOException, LinterException {
ProblemIterator problems = this.checkIterator(StringSourceInfo.fromSourceString(
JavaVersion.JAVA_17,
"Test",
"""
public class Test {
public static void main(String[] args) {
int[] array = new int[10];
int[] array2 = new int[10];
int a = 4;
int b = 3;
int c = a * a;
int d = b * b * b;
int e = array.length * array.length * array.length * array.length;
int f = array2.length * array2.length * array2.length * array2.length * array2.length;
}
}
"""
), PROBLEM_TYPES);

assertEqualsRepeat("Math.pow(b, 3)", problems.next());
assertEqualsRepeat("Math.pow(array.length, 4)", problems.next());
assertEqualsRepeat("Math.pow(array2.length, 5)", problems.next());

problems.assertExhausted();
}


@Test
void testFalsePositiveFieldAccess() throws IOException, LinterException {
ProblemIterator problems = this.checkIterator(StringSourceInfo.fromSourceString(
Expand All @@ -77,4 +141,45 @@ public static void main(String[] args) {

problems.assertExhausted();
}

@Test
void testRecursiveSuggestion() throws IOException, LinterException {
ProblemIterator problems = this.checkIterator(StringSourceInfo.fromSourceString(
JavaVersion.JAVA_17,
"Test",
"""
public class Test {
public static void main(String[] args) {
int index = args.length;
System.out.println(args[index + index + 1]);
int b = index * 2;
System.out.println(args[(index + index) + (b + b + b)]);
}
}
"""
), PROBLEM_TYPES);
assertEqualsRepeat("System.out.println(args[index * 2 + 1])", problems.next());
assertEqualsRepeat("System.out.println(args[index * 2 + b * 3])", problems.next());

problems.assertExhausted();
}

@Test
void testMulAndPlusComplexOptimization() throws IOException, LinterException {
ProblemIterator problems = this.checkIterator(StringSourceInfo.fromSourceString(
JavaVersion.JAVA_17,
"Test",
"""
public class Test {
public static int test(int a, int b, int c, int d) {
return a + a + (a + b - c) + 1 + d * a * a * (b * b * (a * a));
}
}
"""
), PROBLEM_TYPES);
assertEqualsRepeat("a * 2 + (a + b - c) + 1 + d * Math.pow(a, 4) * (b * b)", problems.next());

problems.assertExhausted();
}
}

0 comments on commit fbba113

Please sign in to comment.