Skip to content

Commit

Permalink
[DROOLS-7197] fix generics introspection on superclasses during exec …
Browse files Browse the repository at this point in the history
…model generation (#5944)
  • Loading branch information
mariofusco committed May 15, 2024
1 parent 1913397 commit 785ef86
Show file tree
Hide file tree
Showing 10 changed files with 355 additions and 112 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -53,9 +53,7 @@ String getMethodName(BinaryExpr.Operator operator) {

static SpecialComparisonCase specialComparisonFactory(TypedExpression left, TypedExpression right) {
if (isNumber(left) && !isObject(right.getRawClass()) || isNumber(right) && !isObject(left.getRawClass())) { // Don't coerce Object yet. EvaluationUtil will handle it dynamically later
Optional<Class<?>> leftCast = typeNeedsCast(left.getType());
Optional<Class<?>> rightCast = typeNeedsCast(right.getType());
if (leftCast.isPresent() || rightCast.isPresent()) {
if (typeNeedsCast(left.getType()) || typeNeedsCast(right.getType())) {
return new ComparisonWithCast(true, left, right, of(Number.class), of(Number.class));
} else {
return new NumberComparisonWithoutCast(left, right);
Expand All @@ -67,13 +65,8 @@ static SpecialComparisonCase specialComparisonFactory(TypedExpression left, Type
return new PlainEvaluation(left, right);
}

private static Optional<Class<?>> typeNeedsCast(Type t) {
boolean needCast = isObject((Class<?>)t) || isMap((Class<?>) t) || isList((Class<?>) t);
if (needCast) {
return of((Class<?>) t);
} else {
return Optional.empty();
}
private static boolean typeNeedsCast(Type t) {
return t instanceof Class && ( isObject((Class<?>)t) || isMap((Class<?>) t) || isList((Class<?>) t) );
}

private static boolean isList(Class<?> t) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.TypeVariable;
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.Arrays;
Expand Down Expand Up @@ -94,6 +93,7 @@
import org.drools.mvelcompiler.CompiledExpressionResult;
import org.drools.mvelcompiler.ConstraintCompiler;
import org.drools.mvelcompiler.util.BigDecimalArgumentCoercion;
import org.drools.util.ClassUtils;
import org.drools.util.MethodUtils;
import org.drools.util.Pair;
import org.drools.util.TypeResolver;
Expand Down Expand Up @@ -125,6 +125,7 @@
import static org.drools.mvel.parser.MvelParser.parseType;
import static org.drools.mvel.parser.printer.PrintUtil.printNode;
import static org.drools.util.ClassUtils.extractGenericType;
import static org.drools.util.ClassUtils.actualTypeFromGenerics;
import static org.drools.util.ClassUtils.getTypeArgument;
import static org.drools.util.ClassUtils.getter2property;
import static org.drools.util.ClassUtils.toRawClass;
Expand Down Expand Up @@ -922,16 +923,7 @@ private TypedExpressionCursor parseMethodCallExpr(MethodCallExpr methodCallExpr,
return new TypedExpressionCursor(methodCallExpr, ((ParameterizedType) originalTypeCursor).getActualTypeArguments()[0]);
}

java.lang.reflect.Type genericReturnType = m.getGenericReturnType();
if (genericReturnType instanceof TypeVariable) {
if (originalTypeCursor instanceof ParameterizedType) {
return new TypedExpressionCursor( methodCallExpr, getActualType( rawClassCursor, ( ParameterizedType ) originalTypeCursor, ( TypeVariable ) genericReturnType ) );
} else {
return new TypedExpressionCursor(methodCallExpr, Object.class);
}
} else {
return new TypedExpressionCursor(methodCallExpr, genericReturnType);
}
return new TypedExpressionCursor(methodCallExpr, actualTypeFromGenerics(originalTypeCursor, m.getGenericReturnType(), rawClassCursor));
}

private void promoteBigDecimalParameters(MethodCallExpr methodCallExpr, Class[] argsType, Class<?>[] actualArgumentTypes) {
Expand Down Expand Up @@ -967,17 +959,6 @@ private Optional<TypedExpressionCursor> checkStartsWithMVEL(MethodCallExpr metho
}
}

private java.lang.reflect.Type getActualType(Class<?> rawClassCursor, ParameterizedType originalTypeCursor, TypeVariable genericReturnType) {
int genericPos = 0;
for (TypeVariable typeVar : rawClassCursor.getTypeParameters()) {
if (typeVar.equals( genericReturnType )) {
return originalTypeCursor.getActualTypeArguments()[genericPos];
}
genericPos++;
}
throw new RuntimeException( "Unknonw generic type " + genericReturnType + " for type " + originalTypeCursor );
}

private TypedExpressionCursor objectCreationExpr(ObjectCreationExpr objectCreationExpr) {
parseNodeArguments( objectCreationExpr );
return new TypedExpressionCursor(objectCreationExpr, getClassFromType(ruleContext.getTypeResolver(), objectCreationExpr.getType()));
Expand Down Expand Up @@ -1163,7 +1144,7 @@ private Optional<TypedExpressionCursor> drlNameExpr(Expression drlxExpr, DrlName
return of(new TypedExpressionCursor(addCastToExpression(typeWithoutDollar, fieldAccessor, false), typeOfFirstAccessor ) );
}

return of(new TypedExpressionCursor(fieldAccessor, firstAccessor.getGenericReturnType() ) );
return of( new TypedExpressionCursor(fieldAccessor, ClassUtils.actualTypeFromGenerics(originalTypeCursor, firstAccessor.getGenericReturnType()) ) );
}

Field field = DrlxParseUtil.getField( classCursor, firstName );
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,10 @@

import java.util.ArrayList;
import java.util.List;
import java.util.Objects;

import org.drools.model.codegen.execmodel.domain.Address;
import org.drools.model.codegen.execmodel.domain.Person;

import org.junit.Test;
import org.kie.api.runtime.KieSession;

Expand Down Expand Up @@ -117,4 +117,173 @@ public void testClassWithGenericField() {
ksession.insert(classWithGenericField);
assertThat(ksession.fireAllRules()).isEqualTo(1);
}

@Test
public void testGenericsOnSuperclass() {
// KIE-DROOLS-5925
String str =
"import " + DieselCar.class.getCanonicalName() + ";\n " +
"dialect \"mvel\"\n" +
"\n" +
"rule \"Diesel vehicles with more than 95 kW use high-octane fuel (diesel has no octane, this is a test)\"\n" +
" when\n" +
" $v: DieselCar(motor.kw > 95, score<=0, !motor.highOctane)\n" +
" then\n" +
" System.out.println(\"Diesel vehicle with more than 95 kW: \" + $v+\", score=\"+$v.score);\n" +
" $v.engine.highOctane = true;\n" +
" update($v);\n" +
"end\n" +
"\n" +
"rule \"High-octane fuel engines newer serial numbers have slightly higher score\"\n" +
" when\n" +
" $v: DieselCar(engine.highOctane, score<=1, motor.serialNumber > 50000)\n" +
" then\n" +
" System.out.println(\"High octane engine vehicle with newer serial number: \" + $v.motor.serialNumber);\n" +
" $v.score = $v.score + 1;\n" +
" update($v);\n" +
"end";

KieSession ksession = getKieSession(str);

DieselCar vehicle1 = new DieselCar("Volkswagen", "Passat", 100);
vehicle1.setFrameMaxTorque(500);
vehicle1.getEngine().setMaxTorque(350);
vehicle1.getEngine().setSerialNumber(75_000);
vehicle1.setScore(0);

DieselCar vehicle2 = new DieselCar("Peugeot", "208", 50);
vehicle2.setFrameMaxTorque(100);
vehicle2.getEngine().setMaxTorque(200);
vehicle2.setScore(0);

ksession.insert(vehicle1);
ksession.insert(vehicle2);
assertThat(ksession.fireAllRules()).isEqualTo(3);
}

public static abstract class Vehicle<TEngine extends Engine> {

private final String maker;
private final String model;

private int score;

public Vehicle(String maker, String model) {
this.maker = Objects.requireNonNull(maker);
this.model = Objects.requireNonNull(model);
}

public String getMaker() {
return maker;
}

public String getModel() {
return model;
}

public abstract TEngine getEngine();

public TEngine getMotor() {
return getEngine();
}

public int getScore() {
return score;
}

public void setScore(int score) {
this.score = score;
}

@Override
public String toString() {
return "Vehicle{" + "maker='" + maker + '\'' + ", model='" + model + '\'' + '}';
}
}

public static abstract class Engine {

private final int kw;

public Engine(int kw) {
this.kw = kw;
}

public int getKw() {
return kw;
}

public abstract boolean isZeroEmissions();

}

public static class DieselEngine extends Engine {

// diesel has no octanes... but let's pretend it does
private boolean highOctane;

private int maxTorque;

private long serialNumber;

public DieselEngine(int kw) {
super(kw);
}

@Override
public boolean isZeroEmissions() {
return false;
}

public boolean isHighOctane() {
return highOctane;
}

public void setHighOctane(boolean highOctane) {
this.highOctane = highOctane;
}

public int getMaxTorque() {
return maxTorque;
}

public void setMaxTorque(int maxTorque) {
this.maxTorque = maxTorque;
}

public void setSerialNumber(long serialNumber) {
this.serialNumber = serialNumber;
}

public long getSerialNumber() {
return serialNumber;
}

}

public static class DieselCar extends Vehicle<DieselEngine> {
private final DieselEngine engine;

private long frameMaxTorque;



public DieselCar(String maker, String model, int kw) {
super(maker, model);
this.engine = new DieselEngine(kw);
}

@Override
public DieselEngine getEngine() {
return engine;
}

public long getFrameMaxTorque() {
return frameMaxTorque;
}

public void setFrameMaxTorque(long frameMaxTorque) {
this.frameMaxTorque = frameMaxTorque;
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -216,12 +216,14 @@ private Optional<TypedExpression> asEnum(SimpleName n) {
private Optional<TypedExpression> asPropertyAccessor(SimpleName n, VisitorContext arg) {
Optional<TypedExpression> lastTypedExpression = arg.getScope();

Optional<Type> scopeType = lastTypedExpression.filter(ListAccessExprT.class::isInstance)
.map(ListAccessExprT.class::cast)
.map(expr -> expr.getElementType())
.orElse(arg.getScopeType());
Optional<Type> propertyType = lastTypedExpression.filter(ListAccessExprT.class::isInstance)
.map(ListAccessExprT.class::cast)
.map(expr -> expr.getElementType())
.orElse(arg.getScopeType());

Optional<Method> optAccessor = scopeType.flatMap(t -> ofNullable(getAccessor(classFromType(t), n.asString())));
Optional<Type> scopeType = lastTypedExpression.flatMap(TypedExpression::getScopeType);

Optional<Method> optAccessor = propertyType.flatMap(t -> ofNullable(getAccessor(classFromType(t, scopeType.orElse(null)), n.asString())));

return map2(lastTypedExpression, optAccessor, FieldToAccessorTExpr::new);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,11 @@ public Optional<Type> getType() {
return Optional.of(type);
}

@Override
public Optional<Type> getScopeType() {
return scope.getType();
}

@Override
public Node toJavaExpression() {
List<Expression> expressionArguments = this.arguments.stream()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,11 @@ public Optional<Type> getType() {
return type;
}

@Override
public Optional<Type> getScopeType() {
return scope.flatMap(TypedExpression::getScopeType);
}

@Override
public Node toJavaExpression() {
Node scopeE = scope.map(TypedExpression::toJavaExpression).orElse(null);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,5 +28,9 @@ public interface TypedExpression {
Optional<Type> getType();

Node toJavaExpression();

default Optional<Type> getScopeType() {
return Optional.empty();
}
}

Loading

0 comments on commit 785ef86

Please sign in to comment.