## awfutils.typecheck: Run-time type checking for annotations

I love run-time type checkers, [particularly with JAX](https://github.com/google/jaxtyping/blob/main/FAQ.md#what-about-support-for-static-type-checkers-like-mypy-pyright-etc), but by default they don't check statement-level annotations like these:

```py
def foo(x : int, y : float):
  z : int = x * y # This should error, but doesn't
  w : float = z * 3.2
  return w

foo(3, 1.3)
```

```
12.480000000000002
```
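One reason a checker that only wraps the function object misses these is that Python itself discards them. Per PEP 526, annotations on local variables inside a function body are neither evaluated nor stored, so `foo.__annotations__` only records the parameters. A quick plain-Python check (no awfutils involved; `bar` and the undefined name are purely illustrative):

```python
def foo(x: int, y: float):
    z: int = x * y
    w: float = z * 3.2
    return w

# Only the parameter annotations are recorded on the function object.
print(foo.__annotations__)  # {'x': <class 'int'>, 'y': <class 'float'>}

def bar():
    # Local annotations are never evaluated, so this undefined name
    # does not raise NameError when bar() runs.
    z: this_name_is_not_defined = 1
    return z

print(bar())  # 1
```

So anything that wants to check `z` and `w` has to look at the function's source rather than its run-time metadata.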

With the awfutils `typecheck` decorator, statement-level annotations are checked too:

```py
from awfutils import typecheck

@typecheck
def foo(x : int, y : float):
  z : int = x * y # Now it raises TypeError: z not of type int
  w : float = z * 3.2
  return w

try:
  foo(3, 1.3) # Error comes from this call
  print("BOO. Should not get here")
except TypeError as e:
  print("YAY. Caught TypeError")
  print("Error:", e)
```

```
YAY. Caught TypeError
Error: z not of type int, was <class 'float'>, value 3.9000000000000004
```


And you can use a callable instead of a type to check more complex properties:

```py
import torch

def is_square_tensor(x):
  return x.shape[0] == x.shape[1]

@typecheck
def foo(x : torch.Tensor):
  z : is_square_tensor = x * 3 # check result is square
  return z

try:
  foo(torch.ones(3,4))
  print("BOO. Should not get here")
except TypeError as e:
  print("YAY. Caught TypeError")
  print("Error:", e)
```

```
YAY. Caught TypeError
Error: z does not satisfy is_square_tensor
```
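The error messages so far suggest a simple rule: a class annotation is checked with `isinstance`, and any other callable is treated as a predicate whose result must be truthy. Here is a minimal sketch of a checker following that rule. It is an assumption, not awfutils' code: the name `check_annot_sketch` and its argument order are hypothetical (loosely modelled on the `typecheck.check_annot(...)` calls in the transformed source further down), the message formats are copied from the outputs above, and `is_positive` is just an example predicate.

```python
def check_annot_sketch(value, annot, name, annot_src):
    """Hypothetical checker: isinstance for classes, truthiness for other callables."""
    # Test for a class first: classes are callable too, so the order matters.
    if isinstance(annot, type):
        if not isinstance(value, annot):
            raise TypeError(
                f"{name} not of type {annot_src}, was {type(value)}, value {value}"
            )
    elif callable(annot):
        if not annot(value):
            raise TypeError(f"{name} does not satisfy {annot_src}")


def is_positive(v):
    # Example predicate, standing in for is_square_tensor above.
    return v > 0


check_annot_sketch(3, int, "x", "int")                    # passes silently
check_annot_sketch(2.5, is_positive, "y", "is_positive")  # passes silently
try:
    check_annot_sketch(-1.0, is_positive, "z", "is_positive")
except TypeError as e:
    print("Error:", e)  # Error: z does not satisfy is_positive
```

Passing the annotation's source text (`annot_src`) separately lets the message name the checker as written, even when it is a lambda with no useful `__name__`.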


Or define local shape checkers using values derived at run time:

```py
def is_shape(*sh):
  return lambda x: x.shape == sh

@typecheck
def foo(x : torch.Tensor):
  L,D = x.shape # Get shape of x
  LxD = is_shape(L,D) # LxD(v) checks that v is LxD
  LxL = is_shape(L,L) # LxL(v) checks that v is LxL

  z : LxL = x @ x.T # check result is LxL
  w : LxL = z @ x # Should fail: we meant LxD
  return w

try:
  foo(torch.ones(3,4))
  print("BOO. Should not get here")
except TypeError as e:
  print("YAY. Caught TypeError")
  print("Error:", e)
```

```
YAY. Caught TypeError
Error: w does not satisfy LxL
```


## How it works: source-code transformation
This works by AST transformation: `typecheck` replaces the function with a new version that has extra statements inserted to perform the type checks.

So the function `foo` from above,
```py
def foo(x : int, y : float):
  z : int = x * y
  w : float = z * 3.2
  return w
```
is transformed into roughly the following (the `show_src=True` cell below prints the exact generated source):
```py
def foo_typecheck_wrap(x: int, y: float):
  # Check argument annotations:
  typecheck.check_annot(x, int, 'x', 'int')
  typecheck.check_annot(y, float, 'y', 'float')

  # Function body with checked statement annotations:
  z: int = x * y
  typecheck.check_annot(z, int, 'z', 'int')
  w: float = z * 3.2
  typecheck.check_annot(w, float, 'w', 'float')
  return w
```

```py
@typecheck(show_src=True)
def foo(x : int, y : float):
  z : int = x * y
  w : float = z * 3.2
  return w
```

```
typecheck: Transformed source code
def _():

    @typecheck(show_src=True)
    def foo_typecheck_wrap(x: int, y: float):
        typecheck.check_annot(x, int, 'x', 'int')
        typecheck.check_annot(y, float, 'y', 'float')
        z: int = x * y
        typecheck.check_annot(z, int, 'z', 'int')
        w: float = z * 3.2
        typecheck.check_annot(w, float, 'w', 'float')
        return w
```
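To make the transformation concrete, here is a minimal sketch of the same idea using only the standard library's `ast` module. It is not awfutils' implementation: `check_annot`, `make_check`, `InsertAnnotChecks` and `typecheck_source` are hypothetical names, it works on a source string rather than a decorated function, it only handles positional parameters and simple `name: T = value` statements, and it needs Python 3.9+ for `ast.unparse`.

```python
import ast
import copy


def check_annot(value, annot, name, annot_src):
    # Hypothetical helper: isinstance for classes, truthiness for other callables.
    if isinstance(annot, type):
        if not isinstance(value, annot):
            raise TypeError(f"{name} not of type {annot_src}, was {type(value)}, value {value}")
    elif callable(annot) and not annot(value):
        raise TypeError(f"{name} does not satisfy {annot_src}")


def make_check(name, annotation):
    """Build the statement: check_annot(<name>, <annotation>, '<name>', '<annotation source>')."""
    call = ast.Call(
        func=ast.Name(id="check_annot", ctx=ast.Load()),
        args=[
            ast.Name(id=name, ctx=ast.Load()),
            copy.deepcopy(annotation),                    # evaluated inside the function body
            ast.Constant(value=name),                     # variable name, for the message
            ast.Constant(value=ast.unparse(annotation)),  # annotation source text
        ],
        keywords=[],
    )
    return ast.Expr(value=call)


class InsertAnnotChecks(ast.NodeTransformer):
    def visit_FunctionDef(self, node):
        # First rewrite the body (this also reaches nested functions).
        self.generic_visit(node)
        # Then prepend one check per annotated positional parameter.
        arg_checks = [
            make_check(arg.arg, arg.annotation)
            for arg in node.args.args
            if arg.annotation is not None
        ]
        node.body = arg_checks + node.body
        return node

    def visit_AnnAssign(self, node):
        # Handle `name: T = value`; leave bare declarations (`z: int`) and
        # attribute/subscript targets (`self.z: int = ...`) untouched.
        if isinstance(node.target, ast.Name) and node.value is not None:
            # Returning a list splices both statements into the enclosing body.
            return [node, make_check(node.target.id, node.annotation)]
        return node


def typecheck_source(src, func_name):
    """Insert checks into `src`, compile it, and return (function, transformed source)."""
    tree = InsertAnnotChecks().visit(ast.parse(src))
    ast.fix_missing_locations(tree)  # new nodes need line numbers before compile()
    namespace = {"check_annot": check_annot}
    exec(compile(tree, "<typecheck>", "exec"), namespace)
    return namespace[func_name], ast.unparse(tree)


src = """
def foo(x: int, y: float):
    z: int = x * y
    w: float = z * 3.2
    return w
"""

foo, new_src = typecheck_source(src, "foo")
print(new_src)
try:
    foo(3, 1.3)
except TypeError as e:
    print("Error:", e)
```

In this sketch the annotation expression is copied into the check call and evaluated at that point in the function body, which is what would let local checkers such as `LxL` in the earlier example resolve to their run-time values.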
