Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve clr.accepts/returns #1449

Merged
merged 4 commits into from
May 11, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Directory.Build.props
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@
</PropertyGroup>

<ItemGroup>
<Using Include="Microsoft.Scripting.Runtime.NotNullAttribute" Alias="NotNoneAttribute" />
<Using Include="Microsoft.Scripting.Runtime.NotNullAttribute" Alias="NotNone" />
</ItemGroup>

<!-- Release -->
Expand Down
75 changes: 49 additions & 26 deletions Src/IronPython/Runtime/ClrModule.cs
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
using System.Collections;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Reflection;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
Expand Down Expand Up @@ -592,7 +593,7 @@ public sealed class ReferencesList : List<Assembly>, ICodeFormattable {
/// </summary>
/// <param name="types"></param>
/// <returns></returns>
public static object accepts(params object[] types) {
public static object accepts([NotNone] params object[] types) {
return new ArgChecker(types);
}

Expand All @@ -615,10 +616,10 @@ public sealed class ReferencesList : List<Assembly>, ICodeFormattable {
/// Decorator for verifying the arguments to a function are of a specified type.
/// </summary>
public class ArgChecker {
private readonly object[] expected;
private readonly PythonType[] expected;

public ArgChecker(object[] prms) {
expected = prms;
public ArgChecker([NotNone] object[] types) {
expected = types.Select(t => t.ToPythonType()).ToArray();
}

#region ICallableWithCodeContext Members
Expand All @@ -638,16 +639,16 @@ public class ArgChecker {
/// then calls the original function.
/// </summary>
public class RuntimeArgChecker : PythonTypeSlot {
private readonly object[] _expected;
private readonly PythonType[] _expected;
private readonly object _func;
private readonly object _inst;

public RuntimeArgChecker(object function, object[] expectedArgs) {
public RuntimeArgChecker([NotNone] object function, [NotNone] PythonType[] expectedArgs) {
_expected = expectedArgs;
_func = function;
}

public RuntimeArgChecker(object instance, object function, object[] expectedArgs)
public RuntimeArgChecker(object instance, [NotNone] object function, [NotNone] PythonType[] expectedArgs)
: this(function, expectedArgs) {
_inst = instance;
}
Expand All @@ -661,11 +662,12 @@ public RuntimeArgChecker(object instance, object function, object[] expectedArgs

// no need to validate self... the method should handle it.
for (int i = start; i < args.Length + start; i++) {
PythonType dt = DynamicHelpers.GetPythonType(args[i - start]);

PythonType expct = _expected[i] as PythonType;
if (dt != _expected[i] && !dt.IsSubclassOf(expct)) {
throw PythonOps.AssertionError("argument {0} has bad value (got {1}, expected {2})", i, dt, _expected[i]);
object arg = args[i - start];
PythonType expct = _expected[i];
if (!IsInstanceOf(arg, expct)) {
throw PythonOps.AssertionError(
"argument {0} has bad value (expected {1}, got {2})",
i, expct.Name, PythonOps.GetPythonTypeName(arg));
}
}
}
Expand Down Expand Up @@ -714,10 +716,10 @@ public RuntimeArgChecker(object instance, object function, object[] expectedArgs
/// Decorator for verifying the return type of functions.
/// </summary>
public class ReturnChecker {
public object retType;
public PythonType retType;

public ReturnChecker(object returnType) {
retType = returnType;
retType = returnType.ToPythonType();
}

#region ICallableWithCodeContext Members
Expand All @@ -735,30 +737,25 @@ public class ReturnChecker {
/// validates the return type is of a specified type.
/// </summary>
public class RuntimeReturnChecker : PythonTypeSlot {
private readonly object _retType;
private readonly PythonType _retType;
private readonly object _func;
private readonly object _inst;

public RuntimeReturnChecker(object function, object expectedReturn) {
public RuntimeReturnChecker([NotNone] object function, [NotNone] PythonType expectedReturn) {
_retType = expectedReturn;
_func = function;
}

public RuntimeReturnChecker(object instance, object function, object expectedReturn)
public RuntimeReturnChecker(object instance, [NotNone] object function, [NotNone] PythonType expectedReturn)
: this(function, expectedReturn) {
_inst = instance;
}

private void ValidateReturn(object ret) {
// we return void...
if (ret == null && _retType == null) return;

PythonType dt = DynamicHelpers.GetPythonType(ret);
if (dt != _retType) {
PythonType expct = _retType as PythonType;

if (!dt.IsSubclassOf(expct))
throw PythonOps.AssertionError("bad return value returned (expected {0}, got {1})", _retType, dt);
if (!IsInstanceOf(ret, _retType)) {
throw PythonOps.AssertionError(
"bad return value returned (expected {0}, got {1})",
_retType.Name, PythonOps.GetPythonTypeName(ret));
}
}

Expand Down Expand Up @@ -811,6 +808,32 @@ public RuntimeReturnChecker(object instance, object function, object expectedRet
#endregion
}

private static PythonType ToPythonType(this object obj) {
return obj switch {
PythonType pt => pt,
Type t => DynamicHelpers.GetPythonTypeFromType(t),
null => TypeCache.Null,
_ => throw PythonOps.TypeErrorForTypeMismatch("type", obj)
};
}

private static bool IsInstanceOf(object obj, PythonType pt) {
// See also PythonOps.IsInstance
var objType = DynamicHelpers.GetPythonType(obj);

if (objType == pt) {
return true;
}

// PEP 237: int/long unification
// https://github.com/IronLanguages/ironpython3/issues/52
if (pt == TypeCache.BigInteger && obj is int) {
return true;
}

return pt.__subclasscheck__(objType);
}

/// <summary>
/// returns the result of dir(o) as-if "import clr" has not been performed.
/// </summary>
Expand Down
91 changes: 71 additions & 20 deletions Tests/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import unittest

from iptest import IronPythonTestCase, is_cli, is_mono, is_netcoreapp, is_posix, big, run_test, skipUnlessIronPython
from iptest import IronPythonTestCase, is_cli, is_mono, is_netcoreapp, is_posix, big, clr_int_types, run_test, skipUnlessIronPython
from types import FunctionType, MethodType

global init
Expand Down Expand Up @@ -560,19 +560,29 @@ def foo(x):
return x

self.assertEqual(foo('abc'), 'abc')
self.assertRaises(AssertionError, foo, 2)
self.assertRaises(AssertionError, foo, big(2))
self.assertRaises(AssertionError, foo, 2.0)
self.assertRaises(AssertionError, foo, True)
self.assertRaisesMessage(AssertionError, "argument 0 has bad value (expected str, got int)", foo, 2)
self.assertRaisesMessage(AssertionError, "argument 0 has bad value (expected str, got int)", foo, big(2))
self.assertRaisesMessage(AssertionError, "argument 0 has bad value (expected str, got float)", foo, 2.0)
self.assertRaisesMessage(AssertionError, "argument 0 has bad value (expected str, got bool)", foo, True)

@clr.accepts(int)
def foo(x):
return x

self.assertEqual(foo(1), 1)
self.assertEqual(foo(big(1)), 1)
self.assertEqual(foo(True), True)
self.assertRaisesMessage(AssertionError, "argument 0 has bad value (expected int, got str)", foo, 'abc')
self.assertRaisesMessage(AssertionError, "argument 0 has bad value (expected int, got float)", foo, 2.0)

@clr.accepts(str, bool)
def foo(x, y):
return x, y

self.assertEqual(foo('abc', True), ('abc', True))
self.assertRaises(AssertionError, foo, ('abc',2))
self.assertRaises(AssertionError, foo, ('abc',big(2)))
self.assertRaises(AssertionError, foo, ('abc',2.0))
self.assertRaisesMessage(AssertionError, "argument 1 has bad value (expected bool, got int)", foo, 'abc', 2)
self.assertRaisesMessage(AssertionError, "argument 1 has bad value (expected bool, got int)", foo, 'abc', big(2))
self.assertRaisesMessage(AssertionError, "argument 1 has bad value (expected bool, got float)", foo, 'abc', 2.0)


class bar:
Expand All @@ -583,21 +593,25 @@ def foo(self, x):

a = bar()
self.assertEqual(a.foo('xyz'), 'xyz')
self.assertRaises(AssertionError, a.foo, 2)
self.assertRaises(AssertionError, a.foo, big(2))
self.assertRaises(AssertionError, a.foo, 2.0)
self.assertRaises(AssertionError, a.foo, True)
self.assertRaisesMessage(AssertionError, "argument 1 has bad value (expected str, got int)", a.foo, 2)
self.assertRaisesMessage(AssertionError, "argument 1 has bad value (expected str, got int)", a.foo, big(2))
self.assertRaisesMessage(AssertionError, "argument 1 has bad value (expected str, got float)", a.foo, 2.0)
self.assertRaisesMessage(AssertionError, "argument 1 has bad value (expected str, got bool)", a.foo, True)

@clr.returns(str)
def foo(x):
return x


self.assertEqual(foo('abc'), 'abc')
self.assertRaises(AssertionError, foo, 2)
self.assertRaises(AssertionError, foo, big(2))
self.assertRaises(AssertionError, foo, 2.0)
self.assertRaises(AssertionError, foo, True)
self.assertRaisesMessage(AssertionError, "bad return value returned (expected str, got int)", foo, 2)
self.assertRaisesMessage(AssertionError, "bad return value returned (expected str, got int)", foo, big(2))
self.assertRaisesMessage(AssertionError, "bad return value returned (expected str, got float)", foo, 2.0)
self.assertRaisesMessage(AssertionError, "bad return value returned (expected str, got bool)", foo, True)

with self.assertRaisesMessage(TypeError, "expected type, got int"):
@clr.accepts(0)
def foo(x): pass


@clr.accepts(bool)
@clr.returns(str)
Expand All @@ -607,15 +621,52 @@ def foo(x):

self.assertEqual(foo(True), 'True')

self.assertRaises(AssertionError, foo, 2)
self.assertRaises(AssertionError, foo, big(2))
self.assertRaises(AssertionError, foo, False)
self.assertRaisesMessage(AssertionError, "argument 0 has bad value (expected bool, got int)", foo, 2)
self.assertRaisesMessage(AssertionError, "argument 0 has bad value (expected bool, got int)", foo, big(2))
self.assertRaisesMessage(AssertionError, "bad return value returned (expected str, got int)", foo, False)


@clr.returns(None)
def foo(): pass

self.assertEqual(foo(), None)

@clr.returns(type(None))
def foo(): pass

self.assertEqual(foo(), None)

@clr.returns(None)
def foo():
return 1

self.assertRaisesMessage(AssertionError, "bad return value returned (expected NoneType, got int)", foo)

with self.assertRaisesMessage(TypeError, "expected type, got int"):
@clr.returns(0)
def foo(): pass

import System
for t in clr_int_types:
@clr.accepts(t)
def foo(x):
return x

self.assertRaisesMessage(AssertionError, f"""argument 0 has bad value (expected {str(t).split("'")[1]}, got int)""", foo, big(0))
if t != System.Int32:
self.assertRaisesMessage(AssertionError, f"""argument 0 has bad value (expected {str(t).split("'")[1]}, got int)""", foo, 0)

@clr.accepts(System.IConvertible)
@clr.returns(System.IConvertible)
def foo(x):
return x

self.assertEqual(foo(1), 1)
self.assertEqual(foo(True), True)
self.assertEqual(foo('abc'), 'abc')
self.assertRaisesMessage(AssertionError, "argument 0 has bad value (expected IConvertible, got int)", foo, big(1))
self.assertRaisesMessage(AssertionError, "argument 0 has bad value (expected IConvertible, got Guid)", foo, System.Guid.Empty)

def test_error_message(self):
try:
repr()
Expand Down