Skip to content
Permalink
Browse files
GROOVY-10597: SC: allow spread expression(s) for variadic parameter
  • Loading branch information
eric-milles committed May 9, 2022
1 parent e752790 commit d62fb4dec213fdd8d4ac32dc10243c0f01101b4a
Showing 3 changed files with 95 additions and 98 deletions.
@@ -165,58 +165,58 @@ private void generateGetCallSiteArray() {
mv.visitEnd();
}

private void generateCreateCallSiteArray() {
List<String> callSiteInitMethods = new LinkedList<String>();
int index = 0;
int methodIndex = 0;
final int size = callSites.size();
final int maxArrayInit = 5000;
private void generateCreateCallSiteArray() {
List<String> callSiteInitMethods = new LinkedList<String>();
int index = 0;
int methodIndex = 0;
final int size = callSites.size();
final int maxArrayInit = 5000;
// create array initialization methods
while (index < size) {
methodIndex++;
String methodName = "$createCallSiteArray_" + methodIndex;
callSiteInitMethods.add(methodName);
while (index < size) {
methodIndex++;
String methodName = "$createCallSiteArray_" + methodIndex;
callSiteInitMethods.add(methodName);
MethodVisitor mv = controller.getClassVisitor().visitMethod(MOD_PRIVSS, methodName, "([Ljava/lang/String;)V", null, null);
controller.setMethodVisitor(mv);
mv.visitCode();
int methodLimit = size;
mv.visitCode();
int methodLimit = size;
// check if the next block is over the max allowed
if ((methodLimit - index) > maxArrayInit) {
methodLimit = index + maxArrayInit;
}
for (; index < methodLimit; index++) {
mv.visitVarInsn(ALOAD, 0);
mv.visitLdcInsn(index);
mv.visitLdcInsn(callSites.get(index));
mv.visitInsn(AASTORE);
}
mv.visitInsn(RETURN);
mv.visitMaxs(2,1);
mv.visitEnd();
if ((methodLimit - index) > maxArrayInit) {
methodLimit = index + maxArrayInit;
}
for (; index < methodLimit; index++) {
mv.visitVarInsn(ALOAD, 0);
mv.visitLdcInsn(index);
mv.visitLdcInsn(callSites.get(index));
mv.visitInsn(AASTORE);
}
mv.visitInsn(RETURN);
mv.visitMaxs(2,1);
mv.visitEnd();
}
// create base createCallSiteArray method
MethodVisitor mv = controller.getClassVisitor().visitMethod(MOD_PRIVSS, CREATE_CSA_METHOD, GET_CALLSITEARRAY_DESC, null, null);
controller.setMethodVisitor(mv);
mv.visitCode();
mv.visitLdcInsn(size);
mv.visitTypeInsn(ANEWARRAY, "java/lang/String");
mv.visitVarInsn(ASTORE, 0);
for (String methodName : callSiteInitMethods) {
mv.visitCode();
mv.visitLdcInsn(size);
mv.visitTypeInsn(ANEWARRAY, "java/lang/String");
mv.visitVarInsn(ASTORE, 0);
for (String methodName : callSiteInitMethods) {
mv.visitVarInsn(ALOAD, 0);
mv.visitMethodInsn(INVOKESTATIC, controller.getInternalClassName(), methodName, "([Ljava/lang/String;)V", false);
}
}

mv.visitTypeInsn(NEW, CALLSITE_ARRAY_CLASS);
mv.visitInsn(DUP);
controller.getAcg().visitClassExpression(new ClassExpression(controller.getClassNode()));
mv.visitTypeInsn(NEW, CALLSITE_ARRAY_CLASS);
mv.visitInsn(DUP);
controller.getAcg().visitClassExpression(new ClassExpression(controller.getClassNode()));

mv.visitVarInsn(ALOAD, 0);

mv.visitMethodInsn(INVOKESPECIAL, CALLSITE_ARRAY_CLASS, "<init>", "(Ljava/lang/Class;[Ljava/lang/String;)V", false);
mv.visitInsn(ARETURN);
mv.visitMaxs(0,0);
mv.visitEnd();
}
mv.visitInsn(ARETURN);
mv.visitMaxs(0,0);
mv.visitEnd();
}

private int allocateIndex(String name) {
callSites.add(name);
@@ -294,54 +294,66 @@ public void makeGetPropertySite(Expression receiver, String methodName, boolean
invokeSafe(safe, "callGetProperty", "callGetPropertySafe");
}

public void makeCallSite(Expression receiver, String message, Expression arguments, boolean safe, boolean implicitThis, boolean callCurrent, boolean callStatic) {
public void makeCallSite(final Expression receiver, final String message, final Expression arguments,
final boolean safe, final boolean implicitThis, final boolean callCurrent, final boolean callStatic) {
prepareSiteAndReceiver(receiver, message, implicitThis);

CompileStack compileStack = controller.getCompileStack();
compileStack.pushImplicitThis(implicitThis);
compileStack.pushLHS(false);
boolean constructor = message.equals(CONSTRUCTOR);
OperandStack operandStack = controller.getOperandStack();
AsmClassGenerator acg = controller.getAcg();
CompileStack cs = controller.getCompileStack();
OperandStack os = controller.getOperandStack();
MethodVisitor mv = controller.getMethodVisitor();

cs.pushLHS(false);
cs.pushImplicitThis(implicitThis);

// arguments
boolean containsSpreadExpression = AsmClassGenerator.containsSpreadExpression(arguments);
int numberOfArguments = containsSpreadExpression ? -1 : AsmClassGenerator.argumentSize(arguments);
int operandsToReplace = 1;
int numberOfArguments = AsmClassGenerator.argumentSize(arguments), operandsToReplace = 1;
if (numberOfArguments > MethodCallerMultiAdapter.MAX_ARGS || containsSpreadExpression) {
ArgumentListExpression ae = InvocationWriter.makeArgumentList(arguments);
controller.getCompileStack().pushImplicitThis(false);
ArgumentListExpression list = InvocationWriter.makeArgumentList(arguments);
cs.pushImplicitThis(false);
if (containsSpreadExpression) {
numberOfArguments = -1;
controller.getAcg().despreadList(ae.getExpressions(), true);
acg.despreadList(list.getExpressions(), true);
} else {
numberOfArguments = ae.getExpressions().size();
for (int i = 0; i < numberOfArguments; i++) {
Expression argument = ae.getExpression(i);
argument.visit(controller.getAcg());
operandStack.box();
if (argument instanceof CastExpression) controller.getAcg().loadWrapper(argument);
numberOfArguments = list.getExpressions().size();
for (Expression argument : list) {
argument.visit(acg);
os.box();
if (argument instanceof CastExpression) {
acg.loadWrapper(argument);
}
}
operandsToReplace += numberOfArguments;
}
controller.getCompileStack().popImplicitThis();
cs.popImplicitThis();
}
controller.getCompileStack().popLHS();
controller.getCompileStack().popImplicitThis();

MethodVisitor mv = controller.getMethodVisitor();
cs.popLHS();
cs.popImplicitThis();

if (numberOfArguments > 4) {
final String createArraySignature = getCreateArraySignature(numberOfArguments);
mv.visitMethodInsn(INVOKESTATIC, "org/codehaus/groovy/runtime/ArrayUtil", "createArray", createArraySignature, false);
//TODO: use pre-generated Object[]
operandStack.replace(ClassHelper.OBJECT_TYPE.makeArray(),numberOfArguments);
operandsToReplace = operandsToReplace-numberOfArguments+1;
String desc;
switch (numberOfArguments) {
case 0:
desc = ")Ljava/lang/Object;"; break;
case 1:
desc = "Ljava/lang/Object;)Ljava/lang/Object;"; break;
case 2:
desc = "Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;"; break;
case 3:
desc = "Ljava/lang/Object;Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;"; break;
case 4:
desc = "Ljava/lang/Object;Ljava/lang/Object;Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;"; break;
default:
mv.visitMethodInsn(INVOKESTATIC, "org/codehaus/groovy/runtime/ArrayUtil", "createArray", getCreateArraySignature(numberOfArguments), false);
os.replace(ClassHelper.OBJECT_TYPE.makeArray(), numberOfArguments);
operandsToReplace = operandsToReplace - numberOfArguments + 1;
case -1: // spread expression case produces Object[]
desc = "[Ljava/lang/Object;)Ljava/lang/Object;";
}

final String desc = getDescForParamNum(numberOfArguments);
if (callStatic) {
mv.visitMethodInsn(INVOKEINTERFACE, CALLSITE_CLASS, "callStatic", "(Ljava/lang/Class;" + desc, true);
} else if (constructor) {
} else if (message.equals(CONSTRUCTOR)) {
mv.visitMethodInsn(INVOKEINTERFACE, CALLSITE_CLASS, "callConstructor", "(Ljava/lang/Object;" + desc, true);
} else if (callCurrent) {
mv.visitMethodInsn(INVOKEINTERFACE, CALLSITE_CLASS, "callCurrent", "(Lgroovy/lang/GroovyObject;" + desc, true);
@@ -350,24 +362,8 @@ public void makeCallSite(Expression receiver, String message, Expression argumen
} else {
mv.visitMethodInsn(INVOKEINTERFACE, CALLSITE_CLASS, "call", "(Ljava/lang/Object;" + desc, true);
}
operandStack.replace(ClassHelper.OBJECT_TYPE,operandsToReplace);
}

private static String getDescForParamNum(int numberOfArguments) {
switch (numberOfArguments) {
case 0:
return ")Ljava/lang/Object;";
case 1:
return "Ljava/lang/Object;)Ljava/lang/Object;";
case 2:
return "Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;";
case 3:
return "Ljava/lang/Object;Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;";
case 4:
return "Ljava/lang/Object;Ljava/lang/Object;Ljava/lang/Object;Ljava/lang/Object;)Ljava/lang/Object;";
default:
return "[Ljava/lang/Object;)Ljava/lang/Object;";
}
os.replace(ClassHelper.OBJECT_TYPE, operandsToReplace);
}

public List<String> getCallSites() {
@@ -37,6 +37,7 @@
import org.codehaus.groovy.ast.expr.ExpressionTransformer;
import org.codehaus.groovy.ast.expr.MethodCallExpression;
import org.codehaus.groovy.ast.expr.PropertyExpression;
import org.codehaus.groovy.ast.expr.SpreadExpression;
import org.codehaus.groovy.ast.expr.TupleExpression;
import org.codehaus.groovy.ast.expr.VariableExpression;
import org.codehaus.groovy.ast.stmt.ForStatement;
@@ -430,24 +431,31 @@ protected void loadArguments(final List<Expression> argumentList, final Paramete
|| isGStringType(lastArgType) && isStringType(lastPrmType.getComponentType())))
)) {
OperandStack operandStack = controller.getOperandStack();
int stackLength = operandStack.getStackLength() + nArgs;
// first arguments/parameters as usual
for (int i = 0; i < nPrms - 1; i += 1) {
visitArgument(argumentList.get(i), parameters[i].getType());
}
// wrap remaining arguments in an array for last parameter
boolean spread = false;
List<Expression> lastArgs = new ArrayList<>();
for (int i = nPrms - 1; i < nArgs; i += 1) {
lastArgs.add(argumentList.get(i));
Expression arg = argumentList.get(i);
lastArgs.add(arg);
spread = spread || arg instanceof SpreadExpression;
}
ArrayExpression array = new ArrayExpression(lastPrmType.getComponentType(), lastArgs);
array.visit(controller.getAcg());
// adjust stack length
while (operandStack.getStackLength() < stackLength) {
operandStack.push(ClassHelper.OBJECT_TYPE);
if (spread) { // GROOVY-10597
controller.getAcg().despreadList(lastArgs, true);
operandStack.push(ClassHelper.OBJECT_TYPE.makeArray());
controller.getInvocationWriter().coerce(operandStack.getTopOperand(), lastPrmType);
} else {
controller.getAcg().visitArrayExpression(new ArrayExpression(lastPrmType.getComponentType(), lastArgs));
}
// adjust operand stack
if (nArgs == nPrms - 1) {
operandStack.remove(1);
} else {
for (int n = lastArgs.size(); n > 1; n -= 1)
operandStack.push(ClassHelper.OBJECT_TYPE);
}
} else if (nArgs == nPrms) {
for (int i = 0; i < nArgs; i += 1) {
@@ -36,13 +36,6 @@ public class MethodCallsStaticCompilationTest extends MethodCallsSTCTest impleme
}
}

@Override
void testSpreadArgsRestrictedInConstructorCall() {
shouldFail {
super.testSpreadArgsRestrictedInConstructorCall()
}
}

// GROOVY-7863
void testDoublyNestedPrivateMethodAccess() {
assertScript '''

0 comments on commit d62fb4d

Please sign in to comment.