diff --git a/Src/IronPython/Compiler/Ast/CallExpression.cs b/Src/IronPython/Compiler/Ast/CallExpression.cs index 94c0de8ab..0589f35ce 100644 --- a/Src/IronPython/Compiler/Ast/CallExpression.cs +++ b/Src/IronPython/Compiler/Ast/CallExpression.cs @@ -22,18 +22,20 @@ namespace IronPython.Compiler.Ast { public class CallExpression : Expression, IInstructionProvider { - public CallExpression(Expression target, Arg[] args) { - // TODO: use two arrays (args/keywords) instead - if (args == null) throw new ArgumentNullException(nameof(args)); + public CallExpression(Expression target, IReadOnlyList? args, IReadOnlyList? kwargs) { Target = target; - Args = args; + Args = args?.ToArray() ?? Array.Empty(); + Kwargs = kwargs?.ToArray() ?? Array.Empty(); } public Expression Target { get; } - public Arg[] Args { get; } + public IReadOnlyList Args { get; } - internal IList ImplicitArgs { get; } = new List(); + public IReadOnlyList Kwargs { get; } + + internal IList ImplicitArgs => _implicitArgs ??= new List(); + private List? _implicitArgs; public bool NeedsLocalsDictionary() { if (!(Target is NameExpression nameExpr)) return false; @@ -41,21 +43,29 @@ public bool NeedsLocalsDictionary() { if (nameExpr.Name == "eval" || nameExpr.Name == "exec") return true; if (nameExpr.Name == "dir" || nameExpr.Name == "vars" || nameExpr.Name == "locals") { // could be splatting empty list or dict resulting in 0-param call which needs context - return Args.All(arg => arg.Name == "*" || arg.Name == "**"); + return Args.All(arg => arg.Name == "*") && Kwargs.All(arg => arg.Name == "**"); } return false; } public override MSAst.Expression Reduce() { - Arg[] args = Args; - if (Args.Length == 0 && ImplicitArgs.Count > 0) { - args = ImplicitArgs.ToArray(); + IReadOnlyList args = Args; + if (Args.Count == 0 && _implicitArgs != null && _implicitArgs.Count > 0) { + args = _implicitArgs; } - SplitArgs(args, out var simpleArgs, out var listArgs, out var namedArgs, out var dictArgs, out var numDict); + // count splatted list args and find the lowest index of a list argument, if any + ScanArgs(args, ArgumentType.List, out var numList, out int firstListPos); + Debug.Assert(numList == 0 || firstListPos < args.Count); + + // count splatted dictionary args and find the lowest index of a dict argument, if any + ScanArgs(Kwargs, ArgumentType.Dictionary, out var numDict, out int firstDictPos); + Debug.Assert(numDict == 0 || firstDictPos < Kwargs.Count); - Argument[] kinds = new Argument[simpleArgs.Length + Math.Min(listArgs.Length, 1) + namedArgs.Length + (dictArgs.Length - numDict) + Math.Min(numDict, 1)]; + // all list arguments and all simple arguments after the first list will be collated into a single list for the actual call + // all dictionary arguments will be merged into a single dictionary for the actual call + Argument[] kinds = new Argument[firstListPos + Math.Min(numList, 1) + (Kwargs.Count - numDict) + Math.Min(numDict, 1)]; MSAst.Expression[] values = new MSAst.Expression[2 + kinds.Length]; values[0] = Parent.LocalContext; @@ -64,31 +74,28 @@ public override MSAst.Expression Reduce() { int i = 0; // add simple arguments - foreach (var arg in simpleArgs) { - kinds[i] = arg.GetArgumentInfo(); - values[i + 2] = arg.Expression; - i++; + if (firstListPos > 0) { + foreach (var arg in args) { + if (i == firstListPos) break; + Debug.Assert(arg.GetArgumentInfo().Kind == ArgumentType.Simple); + kinds[i] = arg.GetArgumentInfo(); + values[i + 2] = arg.Expression; + i++; + } } // unpack list arguments - if (listArgs.Length > 0) { - var arg = listArgs[0]; + if (numList > 0) { + var arg = args[firstListPos]; Debug.Assert(arg.GetArgumentInfo().Kind == ArgumentType.List); kinds[i] = arg.GetArgumentInfo(); - values[i + 2] = UnpackListHelper(listArgs); + values[i + 2] = UnpackListHelper(args, firstListPos); i++; } // add named arguments - foreach (var arg in namedArgs) { - kinds[i] = arg.GetArgumentInfo(); - values[i + 2] = arg.Expression; - i++; - } - - // add named arguments specified after a dict unpack - if (dictArgs.Length != numDict) { - foreach (var arg in dictArgs) { + if (Kwargs.Count != numDict) { + foreach (var arg in Kwargs) { var info = arg.GetArgumentInfo(); if (info.Kind == ArgumentType.Named) { kinds[i] = info; @@ -99,11 +106,11 @@ public override MSAst.Expression Reduce() { } // unpack dict arguments - if (dictArgs.Length > 0) { - var arg = dictArgs[0]; + if (numDict > 0) { + var arg = Kwargs[firstDictPos]; Debug.Assert(arg.GetArgumentInfo().Kind == ArgumentType.Dictionary); kinds[i] = arg.GetArgumentInfo(); - values[i + 2] = UnpackDictHelper(Parent.LocalContext, dictArgs); + values[i + 2] = UnpackDictHelper(Parent.LocalContext, Kwargs, numDict, firstDictPos); } return Parent.Invoke( @@ -111,62 +118,33 @@ public override MSAst.Expression Reduce() { values ); - static void SplitArgs(Arg[] args, out ReadOnlySpan simpleArgs, out ReadOnlySpan listArgs, out ReadOnlySpan namedArgs, out ReadOnlySpan dictArgs, out int numDict) { - if (args.Length == 0) { - simpleArgs = default; - listArgs = default; - namedArgs = default; - dictArgs = default; - numDict = 0; - return; - } + static void ScanArgs(IReadOnlyList args, ArgumentType scanForType, out int numArgs, out int firstArgPos) { + numArgs = 0; + firstArgPos = args.Count; - int idxSimple = args.Length; - int idxList = args.Length; - int idxNamed = args.Length; - int idxDict = args.Length; - numDict = 0; + if (args.Count == 0) return; - // we want idxSimple <= idxList <= idxNamed <= idxDict - for (var i = args.Length - 1; i >= 0; i--) { + for (var i = args.Count - 1; i >= 0; i--) { var arg = args[i]; var info = arg.GetArgumentInfo(); - switch (info.Kind) { - case ArgumentType.Simple: - idxSimple = i; - break; - case ArgumentType.List: - idxList = i; - break; - case ArgumentType.Named: - idxNamed = i; - break; - case ArgumentType.Dictionary: - idxDict = i; - numDict++; - break; - default: - throw new InvalidOperationException(); + if (info.Kind == scanForType) { + firstArgPos = i; + numArgs++; } } - dictArgs = args.AsSpan(idxDict); - if (idxNamed > idxDict) idxNamed = idxDict; - namedArgs = args.AsSpan(idxNamed, idxDict - idxNamed); - if (idxList > idxNamed) idxList = idxNamed; - listArgs = args.AsSpan(idxList, idxNamed - idxList); - if (idxSimple > idxList) idxSimple = idxList; - simpleArgs = args.AsSpan(idxSimple, idxList - idxSimple); } - static MSAst.Expression UnpackListHelper(ReadOnlySpan args) { - Debug.Assert(args.Length > 0); - Debug.Assert(args[0].GetArgumentInfo().Kind == ArgumentType.List); - if (args.Length == 1) return args[0].Expression; + static MSAst.Expression UnpackListHelper(IReadOnlyList args, int firstListPos) { + Debug.Assert(args.Count > 0); + Debug.Assert(args[firstListPos].GetArgumentInfo().Kind == ArgumentType.List); + + if (args.Count - firstListPos == 1) return args[firstListPos].Expression; - var expressions = new ReadOnlyCollectionBuilder(args.Length + 2); + var expressions = new ReadOnlyCollectionBuilder(args.Count - firstListPos + 2); var varExpr = Expression.Variable(typeof(PythonList), "$coll"); expressions.Add(Expression.Assign(varExpr, Expression.Call(AstMethods.MakeEmptyList))); - foreach (var arg in args) { + for (int i = firstListPos; i < args.Count; i++) { + var arg = args[i]; if (arg.GetArgumentInfo().Kind == ArgumentType.List) { expressions.Add(Expression.Call(AstMethods.ListExtend, varExpr, AstUtils.Convert(arg.Expression, typeof(object)))); } else { @@ -177,15 +155,18 @@ static MSAst.Expression UnpackListHelper(ReadOnlySpan args) { return Expression.Block(typeof(PythonList), new MSAst.ParameterExpression[] { varExpr }, expressions); } - static MSAst.Expression UnpackDictHelper(MSAst.Expression context, ReadOnlySpan args) { - Debug.Assert(args.Length > 0); - Debug.Assert(args[0].GetArgumentInfo().Kind == ArgumentType.Dictionary); - if (args.Length == 1) return args[0].Expression; + static MSAst.Expression UnpackDictHelper(MSAst.Expression context, IReadOnlyList kwargs, int numDict, int firstDictPos) { + Debug.Assert(kwargs.Count > 0); + Debug.Assert(0 < numDict && numDict <= kwargs.Count); + Debug.Assert(kwargs[firstDictPos].GetArgumentInfo().Kind == ArgumentType.Dictionary); - var expressions = new List(args.Length + 2); + if (numDict == 1) return kwargs[firstDictPos].Expression; + + var expressions = new ReadOnlyCollectionBuilder(numDict + 2); var varExpr = Expression.Variable(typeof(PythonDictionary), "$dict"); expressions.Add(Expression.Assign(varExpr, Expression.Call(AstMethods.MakeEmptyDict))); - foreach (var arg in args) { + for (int i = firstDictPos; i < kwargs.Count; i++) { + var arg = kwargs[i]; if (arg.GetArgumentInfo().Kind == ArgumentType.Dictionary) { expressions.Add(Expression.Call(AstMethods.DictMerge, context, varExpr, arg.Expression)); } @@ -198,19 +179,24 @@ static MSAst.Expression UnpackDictHelper(MSAst.Expression context, ReadOnlySpan< #region IInstructionProvider Members void IInstructionProvider.AddInstructions(LightCompiler compiler) { - Arg[] args = Args; - if (args.Length == 0 && ImplicitArgs.Count > 0) { - args = ImplicitArgs.ToArray(); + IReadOnlyList args = Args; + if (args.Count == 0 && _implicitArgs != null && _implicitArgs.Count > 0) { + args = _implicitArgs; } - for (int i = 0; i < args.Length; i++) { + if (Kwargs.Count > 0) { + compiler.Compile(Reduce()); + return; + } + + for (int i = 0; i < args.Count; i++) { if (!args[i].GetArgumentInfo().IsSimple) { compiler.Compile(Reduce()); return; } } - switch (args.Length) { + switch (args.Count) { #region Generated Python Call Expression Instruction Switch // *** BEGIN GENERATED CODE *** @@ -441,11 +427,14 @@ public override void Walk(PythonWalker walker) { arg.Walk(walker); } } - if (ImplicitArgs.Count > 0) { + if (_implicitArgs != null && _implicitArgs.Count > 0) { foreach (Arg arg in ImplicitArgs) { arg.Walk(walker); } } + foreach (Arg arg in Kwargs) { + arg.Walk(walker); + } } walker.PostWalk(this); } diff --git a/Src/IronPython/Compiler/Ast/PythonNameBinder.cs b/Src/IronPython/Compiler/Ast/PythonNameBinder.cs index f95349544..398c33a0c 100644 --- a/Src/IronPython/Compiler/Ast/PythonNameBinder.cs +++ b/Src/IronPython/Compiler/Ast/PythonNameBinder.cs @@ -835,7 +835,7 @@ public override bool Walk(CallExpression node) { if (node.Target is NameExpression nameExpr && nameExpr.Name == "super" && _currentScope is FunctionDefinition func) { _currentScope.Reference("__class__"); - if (node.Args.Length == 0 && func.ParameterNames.Length > 0) { + if (node.Args.Count == 0 && node.Kwargs.Count == 0 && func.ParameterNames.Length > 0) { node.ImplicitArgs.Add(new Arg(new NameExpression("__class__"))); node.ImplicitArgs.Add(new Arg(new NameExpression(func.ParameterNames[0]))); } diff --git a/Src/IronPython/Compiler/Parser.cs b/Src/IronPython/Compiler/Parser.cs index e14ab6a8a..b3c42e5fd 100644 --- a/Src/IronPython/Compiler/Parser.cs +++ b/Src/IronPython/Compiler/Parser.cs @@ -1064,7 +1064,7 @@ private List ParseDecorators() { if (MaybeEat(TokenKind.LeftParenthesis)) { ParserSink?.StartParameters(GetSourceSpan()); - Arg[] args = FinishArgumentList(null); + IReadOnlyList args = FinishArgumentList(null); decorator = FinishCallExpr(decorator, args); } decorator.SetLoc(_globalParent, start, GetEnd()); @@ -1988,12 +1988,12 @@ private Expression AddTrailers(Expression ret, bool allowGeneratorExpression) { if (!allowGeneratorExpression) return ret; NextToken(); - Arg[] args = FinishArgListOrGenExpr(); + IReadOnlyList args = FinishArgListOrGenExpr(); CallExpression call; if (args != null) { call = FinishCallExpr(ret, args); } else { - call = new CallExpression(ret, new Arg[0]); + call = new CallExpression(ret, null, null); } call.SetLoc(_globalParent, ret.StartIndex, GetEnd()); @@ -2132,7 +2132,7 @@ private List ParseExprList(out bool trailingComma) { // expression "=" expression rest_of_arguments // expression "for" gen_expr_rest // - private Arg[] FinishArgListOrGenExpr() { + private IReadOnlyList FinishArgListOrGenExpr() { Arg a = null; ParserSink?.StartParameters(GetSourceSpan()); @@ -2204,7 +2204,7 @@ private void CheckUniqueArgument(List names, Arg arg) { //arglist: (argument ',')* (argument [',']| '*' expression [',' '**' expression] | '**' expression) //argument: [expression '='] expression # Really [keyword '='] expression - private Arg[] FinishArgumentList(Arg first) { + private IReadOnlyList FinishArgumentList(Arg first) { const TokenKind terminator = TokenKind.RightParenthesis; List l = new List(); @@ -2246,8 +2246,7 @@ private Arg[] FinishArgumentList(Arg first) { ParserSink?.EndParameters(GetSourceSpan()); - Arg[] ret = l.ToArray(); - return ret; + return l; } // testlist: test (',' test)* [','] @@ -2900,9 +2899,11 @@ private void PushFunction(FunctionDefinition function) { _functions.Push(function); } - private CallExpression FinishCallExpr(Expression target, params Arg[] args) { + private CallExpression FinishCallExpr(Expression target, IEnumerable args) { bool hasKeyword = false; bool hasKeywordUnpacking = false; + List posargs = null; + List kwargs = null; foreach (Arg arg in args) { if (arg.Name == null) { @@ -2911,18 +2912,26 @@ private CallExpression FinishCallExpr(Expression target, params Arg[] args) { } else if (hasKeyword) { ReportSyntaxError(arg.StartIndex, arg.EndIndex, "positional argument follows keyword argument"); } + posargs ??= new List(); + posargs.Add(arg); } else if (arg.Name == "*") { if (hasKeywordUnpacking) { ReportSyntaxError(arg.StartIndex, arg.EndIndex, "iterable argument unpacking follows keyword argument unpacking"); } + posargs ??= new List(); + posargs.Add(arg); } else if (arg.Name == "**") { hasKeywordUnpacking = true; + kwargs ??= new List(); + kwargs.Add(arg); } else { hasKeyword = true; + kwargs ??= new List(); + kwargs.Add(arg); } } - return new CallExpression(target, args); + return new CallExpression(target, posargs, kwargs); } #endregion diff --git a/Src/IronPython/Modules/_ast.cs b/Src/IronPython/Modules/_ast.cs index e7be0be03..7c4f5e858 100755 --- a/Src/IronPython/Modules/_ast.cs +++ b/Src/IronPython/Modules/_ast.cs @@ -1039,17 +1039,19 @@ public Call(expr func, PythonList args, PythonList keywords, expr starargs, expr internal Call(CallExpression call) : this() { - args = new PythonList(call.Args.Length); + args = new PythonList(call.Args.Count); keywords = new PythonList(); func = Convert(call.Target); foreach (Arg arg in call.Args) { if (arg.Name == null) args.Add(Convert(arg.Expression)); - else if (arg.Name == "*") + else // name == "*" starargs = Convert(arg.Expression); - else if (arg.Name == "**") + } + foreach (Arg arg in call.Kwargs) { + if (arg.Name == "**") kwargs = Convert(arg.Expression); - else + else // name is proper keywords.Add(new keyword(arg)); } } @@ -1057,15 +1059,16 @@ internal Call(CallExpression call) internal override AstExpression Revert() { AstExpression target = expr.Revert(func); List newArgs = new List(); + List newKwargs = new List(); foreach (expr ex in args) newArgs.Add(new Arg(expr.Revert(ex))); if (null != starargs) newArgs.Add(new Arg("*", expr.Revert(starargs))); if (null != kwargs) - newArgs.Add(new Arg("**", expr.Revert(kwargs))); + newKwargs.Add(new Arg("**", expr.Revert(kwargs))); foreach (keyword kw in keywords) - newArgs.Add(new Arg(kw.arg, expr.Revert(kw.value))); - return new CallExpression(target, newArgs.ToArray()); + newKwargs.Add(new Arg(kw.arg, expr.Revert(kw.value))); + return new CallExpression(target, newArgs, newKwargs); } public expr func { get; set; }