Skip to content
This repository has been archived by the owner on Dec 1, 2023. It is now read-only.

Commit

Permalink
Merge pull request #61 from Quansight-Labs/kwargs
Browse files Browse the repository at this point in the history
Switch to native isinstance
  • Loading branch information
saulshanabrook committed May 16, 2019
2 parents 683078d + 9e6f9e0 commit 6655bba
Show file tree
Hide file tree
Showing 4 changed files with 191 additions and 97 deletions.
128 changes: 128 additions & 0 deletions explorations/2019.05.15.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Kwargs / signature exploration"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Question: If we define a function with overloads, can we get the signatures?"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"For example, here is `arange`:"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import typing"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"Array = typing.NewType('Array', object)\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"@typing.overload\n",
"def arange(stop, *, step=None, dtype=None) -> Array:\n",
" ...\n",
"\n",
"\n",
"# If arange is called with two positional arguments, the first is stop, and the second is start\n",
"@typing.overload\n",
"def arange(start, stop, step=None, dtype=None) -> Array:\n",
" ...\n",
"\n",
"\n",
"def arange(*args, **kwargs) -> Array:\n",
" ..."
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"import typing"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'return': <function typing.NewType.<locals>.new_type(x)>}"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"typing.get_type_hints(arange)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"No looks like we can't...."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
92 changes: 26 additions & 66 deletions metadsl/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import typing
import typing_inspect
import functools
import inspect
from .typing_tools import *

__all__ = [
Expand All @@ -15,14 +14,13 @@
"LiteralExpression",
"E",
"fold_identity",
"is_expression_type",
]

T_expression = typing.TypeVar("T_expression", bound="Expression")


@dataclasses.dataclass(eq=False, repr=False)
class Expression:
class Expression(GenericCheck):
"""
Subclass this type and provide relevent methods for your type. Do not add any fields.
Expand Down Expand Up @@ -71,90 +69,52 @@ class LiteralExpression(Expression, typing.Generic[T]):
E = typing.Union[T, LiteralExpression[T]]


def is_expression_type(t: typing.Type) -> bool:
def extract_expression_type(t: typing.Type) -> typing.Type[Expression]:
"""
Checks if a type is a subclass of expression. Also works on generic types.
If t is an expression type, return it, otherwise, it should be a union of an expression type and non expression type
"""
return safe_issubclass(t, Expression)
if typing_inspect.is_union_type(t):
expression_args = [
arg for arg in typing_inspect.get_args(t) if issubclass(arg, Expression)
]
if len(expression_args) != 1:
raise TypeError(
f"Union must contain exactly one expression type, not {len(expression_args)}: {t}"
)
return expression_args[0]
if issubclass(t, Expression):
return t
raise TypeError(f"{t} is not an expression type")


def extract_literal_expression_type(
t: typing.Type
) -> typing.Optional[typing.Type[LiteralExpression]]:
"""
If t is a literal expression type, then it returns the literal expression type, else none.
"""
if not typing_inspect.is_union_type(t):
return None
l, r = typing_inspect.get_args(t)
l_is_expression = is_expression_type(l)
r_is_expression = is_expression_type(r)
if l_is_expression and r_is_expression:
raise TypeError(f"Cannot use union of expression {t}")
if not l_is_expression and not r_is_expression:
return None
return l if l_is_expression else r


def create_expression(
fn: typing.Callable[..., T],
args: typing.Tuple,
return_type: typing.Optional[typing.Type[Expression]],
) -> T:
def create_expression(fn: typing.Callable[..., T], args: typing.Tuple) -> T:
"""
Given a function and some arguments, return the right expression for the return type.
"""
if not return_type:
# We need to get access to the actual function, because even though the wrapped
# one has the same signature, the globals wont be set properly for
# typing.inspect_type
fn_for_typing = getattr(fn, "__wrapped__", fn)

arg_types = [get_type(arg) for arg in args]
return_type = infer_return_type(fn_for_typing, *arg_types)

# If it is a literal return value, create the literal expression
return_type = extract_literal_expression_type(return_type) or return_type
# We need to get access to the actual function, because even though the wrapped
# one has the same signature, the globals wont be set properly for
# typing.inspect_type
fn_for_typing = getattr(fn, "__wrapped__", fn)

if not is_expression_type(return_type):
raise TypeError(f"Must return expression type not {return_type}")
arg_types = [get_type(arg) for arg in args]
return_type = infer_return_type(fn_for_typing, *arg_types)
expr_return_type = extract_expression_type(return_type)

return typing.cast(T, return_type(fn, args))
return typing.cast(T, expr_return_type(fn, args))


T_callable = typing.TypeVar("T_callable", bound=typing.Callable)


def n_function_args(fn: typing.Callable) -> int:
"""
Returns the number of args a function takes, raising an error if there are any non parameter args or variable args.
"""
n = 0
for param in inspect.signature(fn).parameters.values():
if param.kind != param.POSITIONAL_OR_KEYWORD:
raise TypeError(f"Arg type of {param} not supported for function {fn}")
n += 1
return n


def expression(fn: T_callable) -> T_callable:
"""
Creates an expresion object by wrapping a Python function and providing a function
that will take in the args and return an expression of the right type.
"""
# Note: Cannot do this because of forward type references not resolved for methods in classes

# # Verify that it can be called with just expression types
# inferred: typing.Type = infer_return_type(fn, *(Expression for _ in range(n_function_args(fn))))
# inferred = extract_literal_expression_type(inferred) or inferred
# if not typing_inspect.is_typevar(inferred) and not is_expression_type(inferred):
# raise TypeError(
# f"{fn} should return an expression type when passed in expression types, not a {inferred}"
# )

@functools.wraps(fn)
def expresion_(*args, _return_type=None):
return create_expression(expresion_, args, _return_type)
def expresion_(*args):
return create_expression(expresion_, args)

return typing.cast(T_callable, expresion_)

Expand Down
66 changes: 36 additions & 30 deletions metadsl/typing_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,41 @@

from .dict_tools import *

__all__ = [
"infer_return_type",
"get_type",
"get_arg_hints",
"safe_isinstance",
"safe_issubclass",
]
__all__ = ["infer_return_type", "get_type", "get_arg_hints", "GenericCheck"]
T = typing.TypeVar("T")


class GenericCheckType(type):
def __subclasscheck__(cls, sub):
"""
Modified from https://github.com/python/cpython/blob/aa73841a8fdded4a462d045d1eb03899cbeecd65/Lib/typing.py#L707-L717
"""
sub = getattr(sub, "__origin__", sub)
if hasattr(cls, "__origin__"):
return issubclass(sub, cls)
return super().__subclasscheck__(sub)


class GenericCheck(metaclass=GenericCheckType):
"""
Subclass this to support isinstance and issubclass checks with generic classes.
"""

pass


# Allow isinstance and issubclass calls on generic types
def generic_subclasscheck(self, cls):
"""
Modified from https://github.com/python/cpython/blob/aa73841a8fdded4a462d045d1eb03899cbeecd65/Lib/typing.py#L707-L717
"""
cls = getattr(cls, "__origin__", cls)
return issubclass(self.__origin__, cls)


typing._GenericAlias.__subclasscheck__ = generic_subclasscheck # type: ignore


def get_type(v: T) -> typing.Type[T]:
"""
Returns the type of the value with generic arguments preserved.
Expand All @@ -39,12 +64,13 @@ def match_types(hint: typing.Type, t: typing.Type) -> typevar_mapping_typing:
return match_types(l, t)
except TypeError:
pass

try:
return match_types(r, t)
except TypeError:
raise TypeError(f"Cannot match type {t} with hitn {hint}")
raise TypeError(f"Cannot match type {t} with hint {hint}")

if not safe_issubclass(t, hint):
if not issubclass(t, hint):
raise TypeError(f"Cannot match concrete type {t} with hint {hint}")
return safe_merge(
*(
Expand Down Expand Up @@ -105,7 +131,7 @@ def infer_return_type(
by looking at the type signature and matching generics.
"""
arg_hints = get_arg_hints(fn, arg_types[0] if arg_types else None)
return_hint = typing.get_type_hints(fn)['return']
return_hint = typing.get_type_hints(fn)["return"]

matches: typevar_mapping_typing = safe_merge(
*(
Expand All @@ -115,23 +141,3 @@ def infer_return_type(
)
)
return replace_typevars(matches, return_hint)


def get_base_type(t: typing.Type) -> typing.Type:
return typing_inspect.get_origin(t) or t


def safe_isinstance(obj: object, t: typing.Type) -> bool:
"""
Works with types that are generic. If they are generic,
just checks if the value is an instance of the base type.
"""
return isinstance(obj, get_base_type(t))


def safe_issubclass(cls: typing.Type, cls_: typing.Type) -> bool:
"""
Works with types that are generic. If they are generic,
just checks if the base types ar e
"""
return issubclass(get_base_type(cls), get_base_type(cls_))
2 changes: 1 addition & 1 deletion metadsl/typing_tools_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def _generic_return(arg: T) -> T:
...


class _GenericClass(typing.Generic[T]):
class _GenericClass(GenericCheck, typing.Generic[T]):
def method(self) -> T:
...

Expand Down

0 comments on commit 6655bba

Please sign in to comment.