diff --git a/Src/IronPython/Compiler/Ast/AstMethods.cs b/Src/IronPython/Compiler/Ast/AstMethods.cs index 6d8baac24..bc070891f 100644 --- a/Src/IronPython/Compiler/Ast/AstMethods.cs +++ b/Src/IronPython/Compiler/Ast/AstMethods.cs @@ -20,7 +20,7 @@ internal static class AstMethods { public static readonly MethodInfo IsTrue = GetMethod((Func)PythonOps.IsTrue); public static readonly MethodInfo RaiseAssertionError = GetMethod((Action)PythonOps.RaiseAssertionError); public static readonly MethodInfo RaiseAssertionErrorNoMessage = GetMethod((Action)PythonOps.RaiseAssertionError); - public static readonly MethodInfo MakeClass = GetMethod((Func, CodeContext, string, object[], object, string, object>)PythonOps.MakeClass); + public static readonly MethodInfo MakeClass = GetMethod((Func, CodeContext, string, PythonTuple, object, string, object>)PythonOps.MakeClass); public static readonly MethodInfo PrintExpressionValue = GetMethod((Action)PythonOps.PrintExpressionValue); public static readonly MethodInfo ImportWithNames = GetMethod((Func)PythonOps.ImportWithNames); public static readonly MethodInfo ImportFrom = GetMethod((Func)PythonOps.ImportFrom); @@ -45,6 +45,7 @@ internal static class AstMethods { public static readonly MethodInfo CheckException = GetMethod((Func)PythonOps.CheckException); public static readonly MethodInfo SetCurrentException = GetMethod((Func)PythonOps.SetCurrentException); public static readonly MethodInfo MakeTuple = GetMethod((Func)PythonOps.MakeTuple); + public static readonly MethodInfo MakeEmptyTuple = GetMethod((Func)PythonOps.MakeEmptyTuple); public static readonly MethodInfo IsNot = GetMethod((Func)PythonOps.IsNot); public static readonly MethodInfo Is = GetMethod((Func)PythonOps.Is); public static readonly MethodInfo ImportTop = GetMethod((Func)PythonOps.ImportTop); diff --git a/Src/IronPython/Compiler/Ast/CallExpression.cs b/Src/IronPython/Compiler/Ast/CallExpression.cs index 9e933624f..f0f1a740a 100644 --- a/Src/IronPython/Compiler/Ast/CallExpression.cs +++ b/Src/IronPython/Compiler/Ast/CallExpression.cs @@ -134,6 +134,7 @@ static void ScanArgs(IReadOnlyList args, ArgumentType scanForType, out int } } + // Compare to: ClassDefinition.Reduce.__UnpackBasesHelper static MSAst.Expression UnpackListHelper(IReadOnlyList args, int firstListPos) { Debug.Assert(args.Count > 0); Debug.Assert(args[firstListPos].ArgumentInfo.Kind == ArgumentType.List); diff --git a/Src/IronPython/Compiler/Ast/ClassDefinition.cs b/Src/IronPython/Compiler/Ast/ClassDefinition.cs index c3e788fd0..2b2ec7019 100644 --- a/Src/IronPython/Compiler/Ast/ClassDefinition.cs +++ b/Src/IronPython/Compiler/Ast/ClassDefinition.cs @@ -12,6 +12,7 @@ using IronPython.Runtime; using Microsoft.Scripting; +using Microsoft.Scripting.Actions; using Microsoft.Scripting.Utils; using AstUtils = Microsoft.Scripting.Ast.Utils; @@ -23,7 +24,7 @@ namespace IronPython.Compiler.Ast { public class ClassDefinition : ScopeStatement { private readonly string _name; - private readonly Expression[] _bases; + private readonly Arg[] _bases; private readonly Arg[] _keywords; private LightLambdaExpression _dlrBody; // the transformed body including all of our initialization, etc... @@ -33,15 +34,12 @@ public class ClassDefinition : ScopeStatement { private static readonly MSAst.ParameterExpression _parentContextParam = Ast.Parameter(typeof(CodeContext), "$parentContext"); private static readonly MSAst.Expression _tupleExpression = MSAst.Expression.Call(AstMethods.GetClosureTupleFromContext, _parentContextParam); - public ClassDefinition(string name, Expression[] bases, Arg[] keywords, Statement body = null) { - ContractUtils.RequiresNotNullItems(bases, nameof(bases)); - ContractUtils.RequiresNotNullItems(keywords, nameof(keywords)); - + public ClassDefinition(string name, IReadOnlyList bases, IReadOnlyList keywords, Statement body = null) { _name = name; - _bases = bases; - _keywords = keywords; + _bases = bases?.ToArray() ?? Array.Empty(); + _keywords = keywords?.ToArray() ?? Array.Empty(); Body = body; - Metaclass = keywords.Where(arg => arg.Name == "metaclass").Select(arg => arg.Expression).FirstOrDefault(); + Metaclass = _keywords.Where(arg => arg.Name == "metaclass").Select(arg => arg.Expression).FirstOrDefault(); } public SourceLocation Header => GlobalParent.IndexToLocation(HeaderIndex); @@ -50,7 +48,7 @@ public ClassDefinition(string name, Expression[] bases, Arg[] keywords, Statemen public override string Name => _name; - public IReadOnlyList Bases => _bases; + public IReadOnlyList Bases => _bases; public IReadOnlyList Keywords => _keywords; @@ -181,10 +179,7 @@ public override MSAst.Expression Reduce() { lambda, Parent.LocalContext, AstUtils.Constant(_name), - Ast.NewArrayInit( - typeof(object), - ToObjectArray(_bases) - ), + UnpackBasesHelper(_bases), Metaclass is null ? AstUtils.Constant(null, typeof(object)) : AstUtils.Convert(Metaclass, typeof(object)), AstUtils.Constant(FindSelfNames()) ); @@ -198,6 +193,33 @@ public override MSAst.Expression Reduce() { GlobalParent.IndexToLocation(HeaderIndex) ) ); + + // Compare to: CallExpression.Reduce.__UnpackListHelper + static MSAst.Expression UnpackBasesHelper(IReadOnlyList bases) { + if (bases.Count == 0) { + return Expression.Call(AstMethods.MakeEmptyTuple); + } else if (bases.All(arg => arg.ArgumentInfo.Kind is ArgumentType.Simple)) { + return Expression.Call(AstMethods.MakeTuple, + Expression.NewArrayInit( + typeof(object), + ToObjectArray(bases.Select(arg => arg.Expression).ToList()) + ) + ); + } else { + var expressions = new ReadOnlyCollectionBuilder(bases.Count + 2); + var varExpr = Expression.Variable(typeof(PythonList), "$coll"); + expressions.Add(Expression.Assign(varExpr, Expression.Call(AstMethods.MakeEmptyList))); + foreach (var arg in bases) { + if (arg.ArgumentInfo.Kind == ArgumentType.List) { + expressions.Add(Expression.Call(AstMethods.ListExtend, varExpr, AstUtils.Convert(arg.Expression, typeof(object)))); + } else { + expressions.Add(Expression.Call(AstMethods.ListAppend, varExpr, AstUtils.Convert(arg.Expression, typeof(object)))); + } + } + expressions.Add(Expression.Call(AstMethods.ListToTuple, varExpr)); + return Expression.Block(typeof(PythonTuple), new MSAst.ParameterExpression[] { varExpr }, expressions); + } + } } private Microsoft.Scripting.Ast.LightExpression> MakeClassBody() { @@ -297,10 +319,11 @@ public override void Walk(PythonWalker walker) { decorator.Walk(walker); } } - if (_bases != null) { - foreach (Expression b in _bases) { - b.Walk(walker); - } + foreach (Arg b in _bases) { + b.Walk(walker); + } + foreach (Arg b in _keywords) { + b.Walk(walker); } Body?.Walk(walker); } diff --git a/Src/IronPython/Compiler/Ast/FlowChecker.cs b/Src/IronPython/Compiler/Ast/FlowChecker.cs index 633075c46..a90af5bcb 100644 --- a/Src/IronPython/Compiler/Ast/FlowChecker.cs +++ b/Src/IronPython/Compiler/Ast/FlowChecker.cs @@ -328,7 +328,10 @@ public override bool Walk(ClassDefinition node) { } else { // analyze the class definition itself (it is visited while analyzing parent scope): Define(node.Name); - foreach (Expression e in node.Bases) { + foreach (Arg e in node.Bases) { + e.Walk(this); + } + foreach (Arg e in node.Keywords) { e.Walk(this); } return false; diff --git a/Src/IronPython/Compiler/Ast/PythonNameBinder.cs b/Src/IronPython/Compiler/Ast/PythonNameBinder.cs index 398c33a0c..6bd8ddce3 100644 --- a/Src/IronPython/Compiler/Ast/PythonNameBinder.cs +++ b/Src/IronPython/Compiler/Ast/PythonNameBinder.cs @@ -206,7 +206,7 @@ public override bool Walk(ClassDefinition node) { node.PythonVariable = DefineName(node.Name); // Base references are in the outer context - foreach (Expression b in node.Bases) b.Walk(this); + foreach (Arg b in node.Bases) b.Walk(this); foreach (Arg a in node.Keywords) a.Walk(this); diff --git a/Src/IronPython/Compiler/Parser.cs b/Src/IronPython/Compiler/Parser.cs index 900572419..4e45ec03e 100644 --- a/Src/IronPython/Compiler/Parser.cs +++ b/Src/IronPython/Compiler/Parser.cs @@ -1006,27 +1006,21 @@ private ClassDefinition ParseClassDef() { string name = ReadName(); if (name == null) { // no name, assume there's no class. - return new ClassDefinition(null, new Expression[0], new Arg[0], ErrorStmt()); + return new ClassDefinition(null, null, null, ErrorStmt()); } - var bases = new List(); - var keywords = new List(); + List bases = null; + List keywords = null; if (MaybeEat(TokenKind.LeftParenthesis)) { - foreach (var arg in FinishArgumentList(null)) { - var info = arg.ArgumentInfo; - if (info.Kind == Microsoft.Scripting.Actions.ArgumentType.Simple) { - bases.Add(arg.Expression); - } else if (info.Kind == Microsoft.Scripting.Actions.ArgumentType.Named) { - keywords.Add(arg); - } - } + IReadOnlyList args = FinishArgumentList(null); + SplitAndValidateArguments(args, out bases, out keywords); } var mid = GetEnd(); // Save private prefix string savedPrefix = SetPrivatePrefix(name); - var ret = new ClassDefinition(name, bases.ToArray(), keywords.ToArray()); + var ret = new ClassDefinition(name, bases, keywords); PushClass(ret); // Parse the class body @@ -1989,13 +1983,7 @@ private Expression AddTrailers(Expression ret, bool allowGeneratorExpression) { NextToken(); IReadOnlyList args = FinishArgListOrGenExpr(); - CallExpression call; - if (args != null) { - call = FinishCallExpr(ret, args); - } else { - call = new CallExpression(ret, null, null); - } - + CallExpression call = FinishCallExpr(ret, args); call.SetLoc(_globalParent, ret.StartIndex, GetEnd()); ret = call; break; @@ -2900,11 +2888,22 @@ private void PushFunction(FunctionDefinition function) { } private CallExpression FinishCallExpr(Expression target, IEnumerable args) { - bool hasKeyword = false; - bool hasKeywordUnpacking = false; List posargs = null; List kwargs = null; + if (args is not null) { + SplitAndValidateArguments(args, out posargs, out kwargs); + } + + return new CallExpression(target, posargs, kwargs); + } + + private void SplitAndValidateArguments(IEnumerable args, out List posargs, out List kwargs) { + bool hasKeyword = false; + bool hasKeywordUnpacking = false; + + posargs = kwargs = null; + foreach (Arg arg in args) { if (arg.Name == null) { if (hasKeywordUnpacking) { @@ -2930,8 +2929,6 @@ private CallExpression FinishCallExpr(Expression target, IEnumerable args) kwargs.Add(arg); } } - - return new CallExpression(target, posargs, kwargs); } #endregion diff --git a/Src/IronPython/Modules/_ast.cs b/Src/IronPython/Modules/_ast.cs index 7c4f5e858..7c3fcdadf 100755 --- a/Src/IronPython/Modules/_ast.cs +++ b/Src/IronPython/Modules/_ast.cs @@ -1064,10 +1064,10 @@ internal override AstExpression Revert() { newArgs.Add(new Arg(expr.Revert(ex))); if (null != starargs) newArgs.Add(new Arg("*", expr.Revert(starargs))); - if (null != kwargs) - newKwargs.Add(new Arg("**", expr.Revert(kwargs))); foreach (keyword kw in keywords) newKwargs.Add(new Arg(kw.arg, expr.Revert(kw.value))); + if (null != kwargs) + newKwargs.Add(new Arg("**", expr.Revert(kwargs))); return new CallExpression(target, newArgs, newKwargs); } @@ -1106,8 +1106,19 @@ internal ClassDef(ClassDefinition def) : this() { name = def.Name; bases = new PythonList(def.Bases.Count); - foreach (AstExpression expr in def.Bases) - bases.Add(Convert(expr)); + foreach (Arg arg in def.Bases) { + if (arg.Name == null) + bases.Add(Convert(arg.Expression)); + else // name == "*" + starargs = Convert(arg.Expression); + } + keywords = new PythonList(def.Keywords.Count); + foreach (Arg arg in def.Keywords) { + if (arg.Name == "**") + kwargs = Convert(arg.Expression); + else // name is proper + keywords.Add(new keyword(arg)); + } body = ConvertStatements(def.Body); if (def.Decorators != null) { decorator_list = new PythonList(def.Decorators.Count); @@ -1116,18 +1127,20 @@ internal ClassDef(ClassDefinition def) } else { decorator_list = new PythonList(0); } - if (def.Keywords != null) { - keywords = new PythonList(def.Keywords.Count); - foreach (Arg arg in def.Keywords) - keywords.AddNoLock(new keyword(arg)); - } else { - keywords = new PythonList(0); - } } internal override Statement Revert() { - var newBases = expr.RevertExprs(bases); - var newKeywords = keywords.Cast().Select(kw => new Arg(kw.arg, expr.Revert(kw.value))).ToArray(); + List newBases = new List(); + List newKeywords = new List(); + foreach (expr ex in bases) + newBases.Add(new Arg(expr.Revert(ex))); + if (null != starargs) + newBases.Add(new Arg("*", expr.Revert(starargs))); + foreach (keyword kw in keywords) + newKeywords.Add(new Arg(kw.arg, expr.Revert(kw.value))); + if (null != kwargs) + newKeywords.Add(new Arg("**", expr.Revert(kwargs))); + ClassDefinition cd = new ClassDefinition(name, newBases, newKeywords, RevertStmts(body)); if (decorator_list.Count != 0) cd.Decorators = expr.RevertExprs(decorator_list); diff --git a/Src/IronPython/Runtime/Operations/PythonOps.cs b/Src/IronPython/Runtime/Operations/PythonOps.cs index 0b88f6249..4f3d13189 100644 --- a/Src/IronPython/Runtime/Operations/PythonOps.cs +++ b/Src/IronPython/Runtime/Operations/PythonOps.cs @@ -1321,7 +1321,7 @@ public static void InitializeForFinalization(CodeContext/*!*/ context, object ne return classdict; } - public static object MakeClass(FunctionCode funcCode, Func body, CodeContext/*!*/ parentContext, string name, object[] bases, object metaclass, string selfNames) { + public static object MakeClass(FunctionCode funcCode, Func body, CodeContext/*!*/ parentContext, string name, PythonTuple bases, object metaclass, string selfNames) { Func func = GetClassCode(parentContext, funcCode, body); return MakeClass(parentContext, name, bases, metaclass, selfNames, func(parentContext).Dict); @@ -1342,11 +1342,11 @@ private static Func GetClassCode(CodeContext/*!*/ cont } } - private static object MakeClass(CodeContext/*!*/ context, string name, object[] bases, object metaclass, string selfNames, PythonDictionary vars) { - foreach (object dt in bases) { + private static object MakeClass(CodeContext/*!*/ context, string name, PythonTuple bases, object metaclass, string selfNames, PythonDictionary vars) { + foreach (object? dt in bases) { if (dt is TypeGroup) { - object[] newBases = new object[bases.Length]; - for (int i = 0; i < bases.Length; i++) { + object?[] newBases = new object[bases.Count]; + for (int i = 0; i < bases.Count; i++) { if (bases[i] is TypeGroup tc) { if (!tc.TryGetNonGenericType(out Type nonGenericType)) { throw PythonOps.TypeError("cannot derive from open generic types {0}", Repr(context, tc)); @@ -1356,7 +1356,7 @@ private static object MakeClass(CodeContext/*!*/ context, string name, object[] newBases[i] = bases[i]; } } - bases = newBases; + bases = PythonTuple.MakeTuple(newBases); break; } else if (dt is PythonType pt) { if (pt.Equals(PythonType.GetPythonType(typeof(Enum))) || pt.Equals(PythonType.GetPythonType(typeof(Array))) @@ -1367,20 +1367,18 @@ private static object MakeClass(CodeContext/*!*/ context, string name, object[] } } - PythonTuple tupleBases = PythonTuple.MakeTuple(bases); - if (metaclass is null) { // this makes sure that object is a base - if (tupleBases.Count == 0) { - tupleBases = PythonTuple.MakeTuple(DynamicHelpers.GetPythonTypeFromType(typeof(object))); + if (bases.Count == 0) { + bases = PythonTuple.MakeTuple(DynamicHelpers.GetPythonTypeFromType(typeof(object))); } - return PythonType.__new__(context, TypeCache.PythonType, name, tupleBases, vars, selfNames); + return PythonType.__new__(context, TypeCache.PythonType, name, bases, vars, selfNames); } object? classdict = vars; if (metaclass is PythonType) { - classdict = CallPrepare(context, (PythonType)metaclass, name, tupleBases, vars); + classdict = CallPrepare(context, (PythonType)metaclass, name, bases, vars); } // eg: @@ -1395,7 +1393,7 @@ private static object MakeClass(CodeContext/*!*/ context, string name, object[] context, metaclass, name, - tupleBases, + bases, classdict ); @@ -1486,6 +1484,15 @@ public static PythonTuple MakeTupleFromSequence(object items) { return PythonTuple.Make(items); } + /// + /// Python runtime helper to create an instance of an empty Tuple + /// + [NoSideEffects] + [EditorBrowsable(EditorBrowsableState.Never)] + public static PythonTuple MakeEmptyTuple() { + return PythonTuple.MakeTuple(); + } + /// /// DICT_MERGE /// diff --git a/Tests/test_class.py b/Tests/test_class.py index 4ab7073f0..7f5b76cf0 100644 --- a/Tests/test_class.py +++ b/Tests/test_class.py @@ -844,6 +844,29 @@ class L(K,E): pass self.assertEqual(L.__mro__, (L, K, H, I, G, E, D, B, C, A, object)) + def test_class_args(self): + class A1: pass + class A2: pass + class A3: pass + + class B1(A1, *(A2, A3)): pass + self.assertEqual(B1.__mro__, (B1, A1, A2, A3, object)) + + clist = [A1, A2] + with self.assertRaisesMessage(TypeError, "duplicate base class A1"): + class B2(A1, *clist): pass + + def foo(x): return x + + class B3(foo(A1), clist[1], A3): pass + self.assertEqual(B3.__mro__, (B3, A1, A2, A3, object)) + + class B4(A1, *foo([A2, A3])): pass + self.assertEqual(B4.__mro__, (B4, A1, A2, A3, object)) + + with self.assertRaisesMessage(UnboundLocalError, "local variable 'A4' referenced before assignment"): + class B5(A1, *(A4,)): pass + class A4: pass def test_newstyle_lookup(self): """new-style classes should only lookup methods from the class, not from the instance"""