Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
165 changes: 77 additions & 88 deletions Src/IronPython/Compiler/Ast/CallExpression.cs
Original file line number Diff line number Diff line change
Expand Up @@ -22,40 +22,50 @@
namespace IronPython.Compiler.Ast {

public class CallExpression : Expression, IInstructionProvider {
public CallExpression(Expression target, Arg[] args) {
// TODO: use two arrays (args/keywords) instead
if (args == null) throw new ArgumentNullException(nameof(args));
public CallExpression(Expression target, IReadOnlyList<Arg>? args, IReadOnlyList<Arg>? kwargs) {
Target = target;
Args = args;
Args = args?.ToArray() ?? Array.Empty<Arg>();
Kwargs = kwargs?.ToArray() ?? Array.Empty<Arg>();
}

public Expression Target { get; }

public Arg[] Args { get; }
public IReadOnlyList<Arg> Args { get; }

internal IList<Arg> ImplicitArgs { get; } = new List<Arg>();
public IReadOnlyList<Arg> Kwargs { get; }

internal IList<Arg> ImplicitArgs => _implicitArgs ??= new List<Arg>();
private List<Arg>? _implicitArgs;

public bool NeedsLocalsDictionary() {
if (!(Target is NameExpression nameExpr)) return false;

if (nameExpr.Name == "eval" || nameExpr.Name == "exec") return true;
if (nameExpr.Name == "dir" || nameExpr.Name == "vars" || nameExpr.Name == "locals") {
// could be splatting empty list or dict resulting in 0-param call which needs context
return Args.All(arg => arg.Name == "*" || arg.Name == "**");
return Args.All(arg => arg.Name == "*") && Kwargs.All(arg => arg.Name == "**");
}

return false;
}

public override MSAst.Expression Reduce() {
Arg[] args = Args;
if (Args.Length == 0 && ImplicitArgs.Count > 0) {
args = ImplicitArgs.ToArray();
IReadOnlyList<Arg> args = Args;
if (Args.Count == 0 && _implicitArgs != null && _implicitArgs.Count > 0) {
args = _implicitArgs;
}

SplitArgs(args, out var simpleArgs, out var listArgs, out var namedArgs, out var dictArgs, out var numDict);
// count splatted list args and find the lowest index of a list argument, if any
ScanArgs(args, ArgumentType.List, out var numList, out int firstListPos);
Debug.Assert(numList == 0 || firstListPos < args.Count);

// count splatted dictionary args and find the lowest index of a dict argument, if any
ScanArgs(Kwargs, ArgumentType.Dictionary, out var numDict, out int firstDictPos);
Debug.Assert(numDict == 0 || firstDictPos < Kwargs.Count);

Argument[] kinds = new Argument[simpleArgs.Length + Math.Min(listArgs.Length, 1) + namedArgs.Length + (dictArgs.Length - numDict) + Math.Min(numDict, 1)];
// all list arguments and all simple arguments after the first list will be collated into a single list for the actual call
// all dictionary arguments will be merged into a single dictionary for the actual call
Argument[] kinds = new Argument[firstListPos + Math.Min(numList, 1) + (Kwargs.Count - numDict) + Math.Min(numDict, 1)];
MSAst.Expression[] values = new MSAst.Expression[2 + kinds.Length];

values[0] = Parent.LocalContext;
Expand All @@ -64,31 +74,28 @@ public override MSAst.Expression Reduce() {
int i = 0;

// add simple arguments
foreach (var arg in simpleArgs) {
kinds[i] = arg.GetArgumentInfo();
values[i + 2] = arg.Expression;
i++;
if (firstListPos > 0) {
foreach (var arg in args) {
if (i == firstListPos) break;
Debug.Assert(arg.GetArgumentInfo().Kind == ArgumentType.Simple);
kinds[i] = arg.GetArgumentInfo();
values[i + 2] = arg.Expression;
i++;
}
}

// unpack list arguments
if (listArgs.Length > 0) {
var arg = listArgs[0];
if (numList > 0) {
var arg = args[firstListPos];
Debug.Assert(arg.GetArgumentInfo().Kind == ArgumentType.List);
kinds[i] = arg.GetArgumentInfo();
values[i + 2] = UnpackListHelper(listArgs);
values[i + 2] = UnpackListHelper(args, firstListPos);
i++;
}

// add named arguments
foreach (var arg in namedArgs) {
kinds[i] = arg.GetArgumentInfo();
values[i + 2] = arg.Expression;
i++;
}

// add named arguments specified after a dict unpack
if (dictArgs.Length != numDict) {
foreach (var arg in dictArgs) {
if (Kwargs.Count != numDict) {
foreach (var arg in Kwargs) {
var info = arg.GetArgumentInfo();
if (info.Kind == ArgumentType.Named) {
kinds[i] = info;
Expand All @@ -99,74 +106,45 @@ public override MSAst.Expression Reduce() {
}

// unpack dict arguments
if (dictArgs.Length > 0) {
var arg = dictArgs[0];
if (numDict > 0) {
var arg = Kwargs[firstDictPos];
Debug.Assert(arg.GetArgumentInfo().Kind == ArgumentType.Dictionary);
kinds[i] = arg.GetArgumentInfo();
values[i + 2] = UnpackDictHelper(Parent.LocalContext, dictArgs);
values[i + 2] = UnpackDictHelper(Parent.LocalContext, Kwargs, numDict, firstDictPos);
}

return Parent.Invoke(
new CallSignature(kinds),
values
);

static void SplitArgs(Arg[] args, out ReadOnlySpan<Arg> simpleArgs, out ReadOnlySpan<Arg> listArgs, out ReadOnlySpan<Arg> namedArgs, out ReadOnlySpan<Arg> dictArgs, out int numDict) {
if (args.Length == 0) {
simpleArgs = default;
listArgs = default;
namedArgs = default;
dictArgs = default;
numDict = 0;
return;
}
static void ScanArgs(IReadOnlyList<Arg> args, ArgumentType scanForType, out int numArgs, out int firstArgPos) {
numArgs = 0;
firstArgPos = args.Count;

int idxSimple = args.Length;
int idxList = args.Length;
int idxNamed = args.Length;
int idxDict = args.Length;
numDict = 0;
if (args.Count == 0) return;

// we want idxSimple <= idxList <= idxNamed <= idxDict
for (var i = args.Length - 1; i >= 0; i--) {
for (var i = args.Count - 1; i >= 0; i--) {
var arg = args[i];
var info = arg.GetArgumentInfo();
switch (info.Kind) {
case ArgumentType.Simple:
idxSimple = i;
break;
case ArgumentType.List:
idxList = i;
break;
case ArgumentType.Named:
idxNamed = i;
break;
case ArgumentType.Dictionary:
idxDict = i;
numDict++;
break;
default:
throw new InvalidOperationException();
if (info.Kind == scanForType) {
firstArgPos = i;
numArgs++;
}
}
dictArgs = args.AsSpan(idxDict);
if (idxNamed > idxDict) idxNamed = idxDict;
namedArgs = args.AsSpan(idxNamed, idxDict - idxNamed);
if (idxList > idxNamed) idxList = idxNamed;
listArgs = args.AsSpan(idxList, idxNamed - idxList);
if (idxSimple > idxList) idxSimple = idxList;
simpleArgs = args.AsSpan(idxSimple, idxList - idxSimple);
}

static MSAst.Expression UnpackListHelper(ReadOnlySpan<Arg> args) {
Debug.Assert(args.Length > 0);
Debug.Assert(args[0].GetArgumentInfo().Kind == ArgumentType.List);
if (args.Length == 1) return args[0].Expression;
static MSAst.Expression UnpackListHelper(IReadOnlyList<Arg> args, int firstListPos) {
Debug.Assert(args.Count > 0);
Debug.Assert(args[firstListPos].GetArgumentInfo().Kind == ArgumentType.List);

if (args.Count - firstListPos == 1) return args[firstListPos].Expression;

var expressions = new ReadOnlyCollectionBuilder<MSAst.Expression>(args.Length + 2);
var expressions = new ReadOnlyCollectionBuilder<MSAst.Expression>(args.Count - firstListPos + 2);
var varExpr = Expression.Variable(typeof(PythonList), "$coll");
expressions.Add(Expression.Assign(varExpr, Expression.Call(AstMethods.MakeEmptyList)));
foreach (var arg in args) {
for (int i = firstListPos; i < args.Count; i++) {
var arg = args[i];
if (arg.GetArgumentInfo().Kind == ArgumentType.List) {
expressions.Add(Expression.Call(AstMethods.ListExtend, varExpr, AstUtils.Convert(arg.Expression, typeof(object))));
} else {
Expand All @@ -177,15 +155,18 @@ static MSAst.Expression UnpackListHelper(ReadOnlySpan<Arg> args) {
return Expression.Block(typeof(PythonList), new MSAst.ParameterExpression[] { varExpr }, expressions);
}

static MSAst.Expression UnpackDictHelper(MSAst.Expression context, ReadOnlySpan<Arg> args) {
Debug.Assert(args.Length > 0);
Debug.Assert(args[0].GetArgumentInfo().Kind == ArgumentType.Dictionary);
if (args.Length == 1) return args[0].Expression;
static MSAst.Expression UnpackDictHelper(MSAst.Expression context, IReadOnlyList<Arg> kwargs, int numDict, int firstDictPos) {
Debug.Assert(kwargs.Count > 0);
Debug.Assert(0 < numDict && numDict <= kwargs.Count);
Debug.Assert(kwargs[firstDictPos].GetArgumentInfo().Kind == ArgumentType.Dictionary);

var expressions = new List<MSAst.Expression>(args.Length + 2);
if (numDict == 1) return kwargs[firstDictPos].Expression;

var expressions = new ReadOnlyCollectionBuilder<MSAst.Expression>(numDict + 2);
var varExpr = Expression.Variable(typeof(PythonDictionary), "$dict");
expressions.Add(Expression.Assign(varExpr, Expression.Call(AstMethods.MakeEmptyDict)));
foreach (var arg in args) {
for (int i = firstDictPos; i < kwargs.Count; i++) {
var arg = kwargs[i];
if (arg.GetArgumentInfo().Kind == ArgumentType.Dictionary) {
expressions.Add(Expression.Call(AstMethods.DictMerge, context, varExpr, arg.Expression));
}
Expand All @@ -198,19 +179,24 @@ static MSAst.Expression UnpackDictHelper(MSAst.Expression context, ReadOnlySpan<
#region IInstructionProvider Members

void IInstructionProvider.AddInstructions(LightCompiler compiler) {
Arg[] args = Args;
if (args.Length == 0 && ImplicitArgs.Count > 0) {
args = ImplicitArgs.ToArray();
IReadOnlyList<Arg> args = Args;
if (args.Count == 0 && _implicitArgs != null && _implicitArgs.Count > 0) {
args = _implicitArgs;
}

for (int i = 0; i < args.Length; i++) {
if (Kwargs.Count > 0) {
compiler.Compile(Reduce());
return;
}

for (int i = 0; i < args.Count; i++) {
if (!args[i].GetArgumentInfo().IsSimple) {
compiler.Compile(Reduce());
return;
}
}

switch (args.Length) {
switch (args.Count) {
#region Generated Python Call Expression Instruction Switch

// *** BEGIN GENERATED CODE ***
Expand Down Expand Up @@ -441,11 +427,14 @@ public override void Walk(PythonWalker walker) {
arg.Walk(walker);
}
}
if (ImplicitArgs.Count > 0) {
if (_implicitArgs != null && _implicitArgs.Count > 0) {
foreach (Arg arg in ImplicitArgs) {
arg.Walk(walker);
}
}
foreach (Arg arg in Kwargs) {
arg.Walk(walker);
}
}
walker.PostWalk(this);
}
Expand Down
2 changes: 1 addition & 1 deletion Src/IronPython/Compiler/Ast/PythonNameBinder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -835,7 +835,7 @@ public override bool Walk(CallExpression node) {

if (node.Target is NameExpression nameExpr && nameExpr.Name == "super" && _currentScope is FunctionDefinition func) {
_currentScope.Reference("__class__");
if (node.Args.Length == 0 && func.ParameterNames.Length > 0) {
if (node.Args.Count == 0 && node.Kwargs.Count == 0 && func.ParameterNames.Length > 0) {
node.ImplicitArgs.Add(new Arg(new NameExpression("__class__")));
node.ImplicitArgs.Add(new Arg(new NameExpression(func.ParameterNames[0])));
}
Expand Down
27 changes: 18 additions & 9 deletions Src/IronPython/Compiler/Parser.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1064,7 +1064,7 @@ private List<Expression> ParseDecorators() {

if (MaybeEat(TokenKind.LeftParenthesis)) {
ParserSink?.StartParameters(GetSourceSpan());
Arg[] args = FinishArgumentList(null);
IReadOnlyList<Arg> args = FinishArgumentList(null);
decorator = FinishCallExpr(decorator, args);
}
decorator.SetLoc(_globalParent, start, GetEnd());
Expand Down Expand Up @@ -1988,12 +1988,12 @@ private Expression AddTrailers(Expression ret, bool allowGeneratorExpression) {
if (!allowGeneratorExpression) return ret;

NextToken();
Arg[] args = FinishArgListOrGenExpr();
IReadOnlyList<Arg> args = FinishArgListOrGenExpr();
CallExpression call;
if (args != null) {
call = FinishCallExpr(ret, args);
} else {
call = new CallExpression(ret, new Arg[0]);
call = new CallExpression(ret, null, null);
}

call.SetLoc(_globalParent, ret.StartIndex, GetEnd());
Expand Down Expand Up @@ -2132,7 +2132,7 @@ private List<Expression> ParseExprList(out bool trailingComma) {
// expression "=" expression rest_of_arguments
// expression "for" gen_expr_rest
//
private Arg[] FinishArgListOrGenExpr() {
private IReadOnlyList<Arg> FinishArgListOrGenExpr() {
Arg a = null;

ParserSink?.StartParameters(GetSourceSpan());
Expand Down Expand Up @@ -2204,7 +2204,7 @@ private void CheckUniqueArgument(List<Arg> names, Arg arg) {

//arglist: (argument ',')* (argument [',']| '*' expression [',' '**' expression] | '**' expression)
//argument: [expression '='] expression # Really [keyword '='] expression
private Arg[] FinishArgumentList(Arg first) {
private IReadOnlyList<Arg> FinishArgumentList(Arg first) {
const TokenKind terminator = TokenKind.RightParenthesis;
List<Arg> l = new List<Arg>();

Expand Down Expand Up @@ -2246,8 +2246,7 @@ private Arg[] FinishArgumentList(Arg first) {

ParserSink?.EndParameters(GetSourceSpan());

Arg[] ret = l.ToArray();
return ret;
return l;
}

// testlist: test (',' test)* [',']
Expand Down Expand Up @@ -2900,9 +2899,11 @@ private void PushFunction(FunctionDefinition function) {
_functions.Push(function);
}

private CallExpression FinishCallExpr(Expression target, params Arg[] args) {
private CallExpression FinishCallExpr(Expression target, IEnumerable<Arg> args) {
bool hasKeyword = false;
bool hasKeywordUnpacking = false;
List<Arg> posargs = null;
List<Arg> kwargs = null;

foreach (Arg arg in args) {
if (arg.Name == null) {
Expand All @@ -2911,18 +2912,26 @@ private CallExpression FinishCallExpr(Expression target, params Arg[] args) {
} else if (hasKeyword) {
ReportSyntaxError(arg.StartIndex, arg.EndIndex, "positional argument follows keyword argument");
}
posargs ??= new List<Arg>();
posargs.Add(arg);
} else if (arg.Name == "*") {
if (hasKeywordUnpacking) {
ReportSyntaxError(arg.StartIndex, arg.EndIndex, "iterable argument unpacking follows keyword argument unpacking");
}
posargs ??= new List<Arg>();
posargs.Add(arg);
} else if (arg.Name == "**") {
hasKeywordUnpacking = true;
kwargs ??= new List<Arg>();
kwargs.Add(arg);
} else {
hasKeyword = true;
kwargs ??= new List<Arg>();
kwargs.Add(arg);
}
}

return new CallExpression(target, args);
return new CallExpression(target, posargs, kwargs);
}

#endregion
Expand Down
Loading