Skip to content

Commit

Permalink
bugfix: type checking logic for awaitable and coroutine (#64)
Browse files Browse the repository at this point in the history
* bugfix: type checking logic for awaitable and coroutine

* fix doctest

* fix tests and coverage

Co-authored-by: Willi Sontopski <willi.sontopski@peerox.de>
  • Loading branch information
LostInDarkMath and Willi Sontopski authored Oct 14, 2022
1 parent 8246efc commit 86cb684
Show file tree
Hide file tree
Showing 11 changed files with 615 additions and 17 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
# Changelog
## Pedantic 1.12.6
- bugfix in type checking logic concerning `typing.Aewaitable` and `typing.Coroutine`

## Pedantic 1.12.5
- fix type hints
- use `kw_only=True` in `frozen_dataclass` and `frozen_type_safe_dataclass`
Expand Down
5 changes: 5 additions & 0 deletions docs/pedantic/tests/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ <h1 class="title">Module <code>pedantic.tests</code></h1>
<section>
<h2 class="section-title" id="header-submodules">Sub-modules</h2>
<dl>
<dt><code class="name"><a title="pedantic.tests.test_assert_value_matches_type" href="test_assert_value_matches_type.html">pedantic.tests.test_assert_value_matches_type</a></code></dt>
<dd>
<div class="desc"></div>
</dd>
<dt><code class="name"><a title="pedantic.tests.test_frozen_dataclass" href="test_frozen_dataclass.html">pedantic.tests.test_frozen_dataclass</a></code></dt>
<dd>
<div class="desc"></div>
Expand Down Expand Up @@ -128,6 +132,7 @@ <h1>Index</h1>
</li>
<li><h3><a href="#header-submodules">Sub-modules</a></h3>
<ul>
<li><code><a title="pedantic.tests.test_assert_value_matches_type" href="test_assert_value_matches_type.html">pedantic.tests.test_assert_value_matches_type</a></code></li>
<li><code><a title="pedantic.tests.test_frozen_dataclass" href="test_frozen_dataclass.html">pedantic.tests.test_frozen_dataclass</a></code></li>
<li><code><a title="pedantic.tests.test_generator_wrapper" href="test_generator_wrapper.html">pedantic.tests.test_generator_wrapper</a></code></li>
<li><code><a title="pedantic.tests.test_rename_kwargs" href="test_rename_kwargs.html">pedantic.tests.test_rename_kwargs</a></code></li>
Expand Down
385 changes: 385 additions & 0 deletions docs/pedantic/tests/test_assert_value_matches_type.html

Large diffs are not rendered by default.

63 changes: 60 additions & 3 deletions docs/pedantic/tests/test_frozen_dataclass.html
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ <h1 class="title">Module <code>pedantic.tests.test_frozen_dataclass</code></h1>
</summary>
<pre><code class="python">import unittest
from dataclasses import dataclass, FrozenInstanceError
from typing import List, Dict, Set, Tuple
from typing import List, Dict, Set, Tuple, Awaitable, Callable

from pedantic.decorators.cls_deco_frozen_dataclass import frozen_dataclass, frozen_type_safe_dataclass
from pedantic.exceptions import PedanticTypeCheckException
Expand All @@ -52,6 +52,7 @@ <h1 class="title">Module <code>pedantic.tests.test_frozen_dataclass</code></h1>
bar: Dict[str, str]
values: Tuple[B, B]


class TestFrozenDataclass(unittest.TestCase):
def test_equals_and_hash(self):
a = Foo(a=6, b=&#39;hi&#39;, c=True)
Expand Down Expand Up @@ -311,7 +312,22 @@ <h1 class="title">Module <code>pedantic.tests.test_frozen_dataclass</code></h1>

a = b.copy_with()
self.assertEqual(b, a)
self.assertEqual(4, i)</code></pre>
self.assertEqual(4, i)

def test_type_safe_frozen_dataclass_with_awaitable(self):
@frozen_type_safe_dataclass
class A:
f: Callable[..., Awaitable[int]]

async def _cb() -&gt; int:
return 42

async def _cb_2() -&gt; str:
return &#39;42&#39;

A(f=_cb)
with self.assertRaises(expected_exception=PedanticTypeCheckException):
A(f=_cb_2)</code></pre>
</details>
</section>
<section>
Expand Down Expand Up @@ -915,7 +931,22 @@ <h3>Methods</h3>

a = b.copy_with()
self.assertEqual(b, a)
self.assertEqual(4, i)</code></pre>
self.assertEqual(4, i)

def test_type_safe_frozen_dataclass_with_awaitable(self):
@frozen_type_safe_dataclass
class A:
f: Callable[..., Awaitable[int]]

async def _cb() -&gt; int:
return 42

async def _cb_2() -&gt; str:
return &#39;42&#39;

A(f=_cb)
with self.assertRaises(expected_exception=PedanticTypeCheckException):
A(f=_cb_2)</code></pre>
</details>
<h3>Ancestors</h3>
<ul class="hlist">
Expand Down Expand Up @@ -1309,6 +1340,31 @@ <h3>Methods</h3>
)</code></pre>
</details>
</dd>
<dt id="pedantic.tests.test_frozen_dataclass.TestFrozenDataclass.test_type_safe_frozen_dataclass_with_awaitable"><code class="name flex">
<span>def <span class="ident">test_type_safe_frozen_dataclass_with_awaitable</span></span>(<span>self)</span>
</code></dt>
<dd>
<div class="desc"></div>
<details class="source">
<summary>
<span>Expand source code</span>
</summary>
<pre><code class="python">def test_type_safe_frozen_dataclass_with_awaitable(self):
@frozen_type_safe_dataclass
class A:
f: Callable[..., Awaitable[int]]

async def _cb() -&gt; int:
return 42

async def _cb_2() -&gt; str:
return &#39;42&#39;

A(f=_cb)
with self.assertRaises(expected_exception=PedanticTypeCheckException):
A(f=_cb_2)</code></pre>
</details>
</dd>
<dt id="pedantic.tests.test_frozen_dataclass.TestFrozenDataclass.test_validate_types"><code class="name flex">
<span>def <span class="ident">test_validate_types</span></span>(<span>self)</span>
</code></dt>
Expand Down Expand Up @@ -1398,6 +1454,7 @@ <h4><code><a title="pedantic.tests.test_frozen_dataclass.TestFrozenDataclass" hr
<li><code><a title="pedantic.tests.test_frozen_dataclass.TestFrozenDataclass.test_frozen_type_safe_dataclass_copy_with_check" href="#pedantic.tests.test_frozen_dataclass.TestFrozenDataclass.test_frozen_type_safe_dataclass_copy_with_check">test_frozen_type_safe_dataclass_copy_with_check</a></code></li>
<li><code><a title="pedantic.tests.test_frozen_dataclass.TestFrozenDataclass.test_frozen_typesafe_dataclass_with_post_init" href="#pedantic.tests.test_frozen_dataclass.TestFrozenDataclass.test_frozen_typesafe_dataclass_with_post_init">test_frozen_typesafe_dataclass_with_post_init</a></code></li>
<li><code><a title="pedantic.tests.test_frozen_dataclass.TestFrozenDataclass.test_frozen_typesafe_dataclass_without_post_init" href="#pedantic.tests.test_frozen_dataclass.TestFrozenDataclass.test_frozen_typesafe_dataclass_without_post_init">test_frozen_typesafe_dataclass_without_post_init</a></code></li>
<li><code><a title="pedantic.tests.test_frozen_dataclass.TestFrozenDataclass.test_type_safe_frozen_dataclass_with_awaitable" href="#pedantic.tests.test_frozen_dataclass.TestFrozenDataclass.test_type_safe_frozen_dataclass_with_awaitable">test_type_safe_frozen_dataclass_with_awaitable</a></code></li>
<li><code><a title="pedantic.tests.test_frozen_dataclass.TestFrozenDataclass.test_validate_types" href="#pedantic.tests.test_frozen_dataclass.TestFrozenDataclass.test_validate_types">test_validate_types</a></code></li>
</ul>
</li>
Expand Down
1 change: 0 additions & 1 deletion docs/pedantic/tests/tests_pedantic_async.html
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ <h1 class="title">Module <code>pedantic.tests.tests_pedantic_async</code></h1>
</summary>
<pre><code class="python">import asyncio
import unittest
from typing import Any, Coroutine

from pedantic.decorators.class_decorators import pedantic_class
from pedantic.exceptions import PedanticTypeCheckException
Expand Down
50 changes: 43 additions & 7 deletions docs/pedantic/type_checking_logic/check_types.html
Original file line number Diff line number Diff line change
Expand Up @@ -393,7 +393,7 @@ <h1 class="title">Module <code>pedantic.type_checking_logic.check_types</code></

def get_type_arguments(cls: Any) -&gt; Tuple[Any, ...]:
&#34;&#34;&#34; Works similar to typing.args()
&gt;&gt;&gt; from typing import Tuple, List, Union, Callable, Any, NewType, TypeVar, Optional
&gt;&gt;&gt; from typing import Tuple, List, Union, Callable, Any, NewType, TypeVar, Optional, Awaitable, Coroutine
&gt;&gt;&gt; get_type_arguments(int)
()
&gt;&gt;&gt; get_type_arguments(List[float])
Expand Down Expand Up @@ -434,6 +434,10 @@ <h1 class="title">Module <code>pedantic.type_checking_logic.check_types</code></
(&lt;class &#39;int&#39;&gt;, &lt;class &#39;NoneType&#39;&gt;)
&gt;&gt;&gt; get_type_arguments(str | int) if sys.version_info &gt;= (3, 10) else (str, int)
(&lt;class &#39;str&#39;&gt;, &lt;class &#39;int&#39;&gt;)
&gt;&gt;&gt; get_type_arguments(Awaitable[str])
(&lt;class &#39;str&#39;&gt;,)
&gt;&gt;&gt; get_type_arguments(Coroutine[int, bool, str])
(&lt;class &#39;int&#39;&gt;, &lt;class &#39;bool&#39;&gt;, &lt;class &#39;str&#39;&gt;)
&#34;&#34;&#34;

result = ()
Expand All @@ -459,7 +463,7 @@ <h1 class="title">Module <code>pedantic.type_checking_logic.check_types</code></

def get_base_generic(cls: Any) -&gt; Any:
&#34;&#34;&#34;
&gt;&gt;&gt; from typing import List, Union, Tuple, Callable, Dict, Set
&gt;&gt;&gt; from typing import List, Union, Tuple, Callable, Dict, Set, Awaitable, Coroutine
&gt;&gt;&gt; get_base_generic(List)
typing.List
&gt;&gt;&gt; get_base_generic(List[float])
Expand Down Expand Up @@ -490,6 +494,10 @@ <h1 class="title">Module <code>pedantic.type_checking_logic.check_types</code></
typing.Set
&gt;&gt;&gt; get_base_generic(Set[int])
typing.Set
&gt;&gt;&gt; get_base_generic(Awaitable[int])
typing.Awaitable
&gt;&gt;&gt; get_base_generic(Coroutine[None, None, int])
typing.Coroutine
&#34;&#34;&#34;

origin = cls.__origin__ if hasattr(cls, &#39;__origin__&#39;) else None
Expand Down Expand Up @@ -896,7 +904,19 @@ <h1 class="title">Module <code>pedantic.type_checking_logic.check_types</code></
if not _is_subtype(sub_type=param.annotation, super_type=expected_type):
return False

return _is_subtype(sub_type=sig.return_annotation, super_type=ret_type)
if not inspect.iscoroutinefunction(value):
return _is_subtype(sub_type=sig.return_annotation, super_type=ret_type)

base = get_base_generic(ret_type)

if base == typing.Awaitable:
arg = get_type_arguments(ret_type)[0]
elif base == typing.Coroutine:
arg = get_type_arguments(ret_type)[2]
else:
return False

return _is_subtype(sub_type=sig.return_annotation, super_type=arg)


def _is_lambda(obj: Any) -&gt; bool:
Expand Down Expand Up @@ -1014,7 +1034,7 @@ <h2 class="section-title" id="header-functions">Functions</h2>
<span>def <span class="ident">get_base_generic</span></span>(<span>cls: Any) ‑> Any</span>
</code></dt>
<dd>
<div class="desc"><pre><code class="language-python-repl">&gt;&gt;&gt; from typing import List, Union, Tuple, Callable, Dict, Set
<div class="desc"><pre><code class="language-python-repl">&gt;&gt;&gt; from typing import List, Union, Tuple, Callable, Dict, Set, Awaitable, Coroutine
&gt;&gt;&gt; get_base_generic(List)
typing.List
&gt;&gt;&gt; get_base_generic(List[float])
Expand Down Expand Up @@ -1045,14 +1065,18 @@ <h2 class="section-title" id="header-functions">Functions</h2>
typing.Set
&gt;&gt;&gt; get_base_generic(Set[int])
typing.Set
&gt;&gt;&gt; get_base_generic(Awaitable[int])
typing.Awaitable
&gt;&gt;&gt; get_base_generic(Coroutine[None, None, int])
typing.Coroutine
</code></pre></div>
<details class="source">
<summary>
<span>Expand source code</span>
</summary>
<pre><code class="python">def get_base_generic(cls: Any) -&gt; Any:
&#34;&#34;&#34;
&gt;&gt;&gt; from typing import List, Union, Tuple, Callable, Dict, Set
&gt;&gt;&gt; from typing import List, Union, Tuple, Callable, Dict, Set, Awaitable, Coroutine
&gt;&gt;&gt; get_base_generic(List)
typing.List
&gt;&gt;&gt; get_base_generic(List[float])
Expand Down Expand Up @@ -1083,6 +1107,10 @@ <h2 class="section-title" id="header-functions">Functions</h2>
typing.Set
&gt;&gt;&gt; get_base_generic(Set[int])
typing.Set
&gt;&gt;&gt; get_base_generic(Awaitable[int])
typing.Awaitable
&gt;&gt;&gt; get_base_generic(Coroutine[None, None, int])
typing.Coroutine
&#34;&#34;&#34;

origin = cls.__origin__ if hasattr(cls, &#39;__origin__&#39;) else None
Expand All @@ -1100,7 +1128,7 @@ <h2 class="section-title" id="header-functions">Functions</h2>
</code></dt>
<dd>
<div class="desc"><p>Works similar to typing.args()</p>
<pre><code class="language-python-repl">&gt;&gt;&gt; from typing import Tuple, List, Union, Callable, Any, NewType, TypeVar, Optional
<pre><code class="language-python-repl">&gt;&gt;&gt; from typing import Tuple, List, Union, Callable, Any, NewType, TypeVar, Optional, Awaitable, Coroutine
&gt;&gt;&gt; get_type_arguments(int)
()
&gt;&gt;&gt; get_type_arguments(List[float])
Expand Down Expand Up @@ -1141,14 +1169,18 @@ <h2 class="section-title" id="header-functions">Functions</h2>
(&lt;class 'int'&gt;, &lt;class 'NoneType'&gt;)
&gt;&gt;&gt; get_type_arguments(str | int) if sys.version_info &gt;= (3, 10) else (str, int)
(&lt;class 'str'&gt;, &lt;class 'int'&gt;)
&gt;&gt;&gt; get_type_arguments(Awaitable[str])
(&lt;class 'str'&gt;,)
&gt;&gt;&gt; get_type_arguments(Coroutine[int, bool, str])
(&lt;class 'int'&gt;, &lt;class 'bool'&gt;, &lt;class 'str'&gt;)
</code></pre></div>
<details class="source">
<summary>
<span>Expand source code</span>
</summary>
<pre><code class="python">def get_type_arguments(cls: Any) -&gt; Tuple[Any, ...]:
&#34;&#34;&#34; Works similar to typing.args()
&gt;&gt;&gt; from typing import Tuple, List, Union, Callable, Any, NewType, TypeVar, Optional
&gt;&gt;&gt; from typing import Tuple, List, Union, Callable, Any, NewType, TypeVar, Optional, Awaitable, Coroutine
&gt;&gt;&gt; get_type_arguments(int)
()
&gt;&gt;&gt; get_type_arguments(List[float])
Expand Down Expand Up @@ -1189,6 +1221,10 @@ <h2 class="section-title" id="header-functions">Functions</h2>
(&lt;class &#39;int&#39;&gt;, &lt;class &#39;NoneType&#39;&gt;)
&gt;&gt;&gt; get_type_arguments(str | int) if sys.version_info &gt;= (3, 10) else (str, int)
(&lt;class &#39;str&#39;&gt;, &lt;class &#39;int&#39;&gt;)
&gt;&gt;&gt; get_type_arguments(Awaitable[str])
(&lt;class &#39;str&#39;&gt;,)
&gt;&gt;&gt; get_type_arguments(Coroutine[int, bool, str])
(&lt;class &#39;int&#39;&gt;, &lt;class &#39;bool&#39;&gt;, &lt;class &#39;str&#39;&gt;)
&#34;&#34;&#34;

result = ()
Expand Down
78 changes: 78 additions & 0 deletions pedantic/tests/test_assert_value_matches_type.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import unittest
from dataclasses import dataclass
from typing import Callable, Awaitable, Coroutine

from pedantic.exceptions import PedanticTypeCheckException
from pedantic.type_checking_logic.check_types import assert_value_matches_type


@dataclass
class Foo:
value: int


class TestAssertValueMatchesType(unittest.TestCase):
def test_callable(self):
def _cb(foo: Foo) -> str:
return str(foo.value)

assert_value_matches_type(
value=_cb,
type_=Callable[..., str],
err='',
type_vars={},
)

with self.assertRaises(expected_exception=PedanticTypeCheckException):
assert_value_matches_type(
value=_cb,
type_=Callable[..., int],
err='',
type_vars={},
)

def test_callable_awaitable(self):
async def _cb(foo: Foo) -> str:
return str(foo.value)

assert_value_matches_type(
value=_cb,
type_=Callable[..., Awaitable[str]],
err='',
type_vars={},
)

with self.assertRaises(expected_exception=PedanticTypeCheckException):
assert_value_matches_type(
value=_cb,
type_=Callable[..., Awaitable[int]],
err='',
type_vars={},
)

with self.assertRaises(expected_exception=PedanticTypeCheckException):
assert_value_matches_type(
value=_cb,
type_=Callable[..., str],
err='',
type_vars={},
)

def test_coroutine_awaitable(self):
async def _cb(foo: Foo) -> str:
return str(foo.value)

assert_value_matches_type(
value=_cb,
type_=Callable[..., Coroutine[None, None, str]],
err='',
type_vars={},
)

with self.assertRaises(expected_exception=PedanticTypeCheckException):
assert_value_matches_type(
value=_cb,
type_=Callable[..., Coroutine[None, None, int]],
err='',
type_vars={},
)
18 changes: 17 additions & 1 deletion pedantic/tests/test_frozen_dataclass.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import unittest
from dataclasses import dataclass, FrozenInstanceError
from typing import List, Dict, Set, Tuple
from typing import List, Dict, Set, Tuple, Awaitable, Callable

from pedantic.decorators.cls_deco_frozen_dataclass import frozen_dataclass, frozen_type_safe_dataclass
from pedantic.exceptions import PedanticTypeCheckException
Expand All @@ -24,6 +24,7 @@ class A:
bar: Dict[str, str]
values: Tuple[B, B]


class TestFrozenDataclass(unittest.TestCase):
def test_equals_and_hash(self):
a = Foo(a=6, b='hi', c=True)
Expand Down Expand Up @@ -284,3 +285,18 @@ class B(A):
a = b.copy_with()
self.assertEqual(b, a)
self.assertEqual(4, i)

def test_type_safe_frozen_dataclass_with_awaitable(self):
@frozen_type_safe_dataclass
class A:
f: Callable[..., Awaitable[int]]

async def _cb() -> int:
return 42

async def _cb_2() -> str:
return '42'

A(f=_cb)
with self.assertRaises(expected_exception=PedanticTypeCheckException):
A(f=_cb_2)
Loading

0 comments on commit 86cb684

Please sign in to comment.