In [1]:
from icecream import ic
help(compile)

Help on built-in function compile in module builtins:

compile(source, filename, mode, flags=0, dont_inherit=False, optimize=-1, *, _feature_version=-1)
    Compile source into a code object that can be executed by exec() or eval().
    
    The source code may represent a Python module, statement or expression.
    The filename will be used for run-time error messages.
    The mode must be 'exec' to compile a module, 'single' to compile a
    single (interactive) statement, or 'eval' to compile an expression.
    The flags argument, if present, controls which future statements influence
    the compilation of the code.
    The dont_inherit argument, if true, stops the compilation inheriting
    the effects of any future statements in effect in the code calling
    compile; if absent or false these statements do influence the compilation,
    in addition to any features explicitly specified.



In [2]:
import types
import ast
import inspect
import astpretty

class AnnotsToCommentsVisitor(ast.NodeTransformer):
    def visit_AnnAssign(self, node):
        node2 = ast.Assign([node.target], node.value, str(node.annotation))
        return ast.copy_location(node2, node)
  
class TypeCheckVisitor(ast.NodeTransformer):
    def visit_FunctionDef(self, node):
        # name is a raw string of the function name.
        # args is a arguments node.
        # body is the list of nodes inside the function.
        # decorator_list is the list of decorators to be applied, stored outermost first (i.e. the first in the list will be applied last).
        # returns is the return annotation (Python 3 only).
        # type_comment is optional. It is a string containing the PEP 484 type comment of the function (added in Python 3.8)        
        node = self.generic_visit(node)
        new_decorator_list = [dec for dec in node.decorator_list 
                              if not (isinstance(dec, ast.Name) and dec.id == 'typecheck')]
        new_node = ast.FunctionDef(node.name + '_checked', node.args, 
                                   node.body, new_decorator_list, 
                                   node.returns, node.type_comment)
        new_node = ast.copy_location(new_node, node)
        return new_node
 

    def visit_AnnAssign(self, node):
        # An assignment with a type annotation. 
        # mode.target is a single node and can be a Name, a Attribute or a Subscript. 
        # annotation is the annotation, such as a Str or Name node. 
        # value is a single optional node.
        # simple is True for a Name node in target that do not appear in 
        #   between parenthesis and are hence pure names and not expressions.

        if not node.simple:
          return node
        
        assert isinstance(node.target, ast.Name) # Should be guaranteed by node.simple
        node_assert = ast.Assert(
          test=ast.Call(ast.Name('isinstance', ctx=ast.Load()), [
                          node.target, # ast.Name(node.target.id, ctx=ast.Load()), # Convert ctx from Store to Load 
                          node.annotation
                        ], 
                        []),
          msg=ast.Constant(value='failed', kind=None) # f'{node.target.id} not an {node.annotation}'
        )
        node_assert = ast.copy_location(node_assert, node)
        ast.fix_missing_locations(node_assert)
        return [node, node_assert]
  
def typecheck(f, show_src=False):
  # If this fails, try inspect-via-pytorch.py
  src = inspect.getsource(f)
  filename = inspect.getsourcefile(f)
  node = ast.parse(src, filename=filename)
  new_node = TypeCheckVisitor().visit(node)
  
  if show_src:
    print('typecheck: Transformed source code')
    new_src = ast.unparse(new_node)
    print(new_src)

  # Compile new AST to get wrapped function
  try:
    code = compile(new_node, filename='<typecheck>', mode='exec')
  except Exception as e:
    # Most compile errors are pretty opaque (https://stackoverflow.com/a/25795966)
    # So call astpretty.  If it succeeds, it's helpful to debug, if it fails, its
    # error messages are much more helpful
    msg = astpretty.pformat(new_node)
    print(msg)
    raise ValueError("See AST printed above") from e

  f_code = code.co_consts[3] # TODO search better
  f_checked = types.FunctionType(f_code, globals=f.__globals__)
  f_checked.__wrapped__ = f
  return f_checked


import jax
import jax.numpy as jnp

@typecheck
def foo(x : int, t : float) -> float:
  y : float = x * t
  assert isinstance(y, float), f"y a {type(y)} not a float"
  z : int = x // 2
  assert isinstance(z, int), "z not a int"
  return z * y

foo.__wrapped__(3,4.2)

foo(3,4.2)

@jax.jit
@typecheck
def foo1(x : int, t : jnp.ndarray) -> float:
  y : int = x *t # fred
  z : jnp.ndarray = y / 2
  return z


Module(
    body=[
        FunctionDef(
            lineno=2,
            col_offset=0,
            end_lineno=7,
            end_col_offset=14,
            name='foo_checked',
            args=arguments(
                posonlyargs=[],
                args=[
                    arg(
                        lineno=2,
                        col_offset=8,
                        end_lineno=2,
                        end_col_offset=15,
                        arg='x',
                        annotation=Name(lineno=2, col_offset=12, end_lineno=2, end_col_offset=15, id='int', ctx=Load()),
                        type_comment=None,
                    ),
                    arg(
                        lineno=2,
                        col_offset=17,
                        end_lineno=2,
                        end_col_offset=26,
                        arg='t',
                        annotation=Name(lineno=2, col_offset=21, end_lineno=2, end_col_offset=26, id='float', ctx=Load()),
       

ValueError: See AST printed above

In [108]:
x = 3
t 



{'__name__': '__main__',
 '__doc__': 'Automatically created module for IPython interactive environment',
 '__package__': None,
 '__loader__': None,
 '__spec__': None,
 '__builtin__': <module 'builtins' (built-in)>,
 '__builtins__': <module 'builtins' (built-in)>,
 '_ih': ['',
  'import jax\nimport jax.numpy as jnp\n\nimport ast\nimport inspect\n\nclass Fred:\n  def __instancecheck__(cls, obj):\n    return isinstance(obj, torch.Tensor)\n\ndef f(x : int, t : jnp.ndarray) -> float:\n  y : int = x *t\n  z : jnp.ndarray = y / 2\n  return z\n\n# If this fails, try inspect-via-pytorch.py\nsrc = inspect.getsource(f)\nprint(src)\npy_ast = ast.parse(src)\nprint(ast.unparse(py_ast))\n\n\nfor n in ast.walk(py_ast):\n  print(n)',
  'import jax\nimport jax.numpy as jnp\n\nimport ast\nimport inspect\n\nclass Fred:\n  def __instancecheck__(cls, obj):\n    return isinstance(obj, torch.Tensor)\n\ndef f(x : int, t : jnp.ndarray) -> float:\n  y : int = x *t # fred\n  z : jnp.ndarray = y / 2\n  return z\n\