# torchtyping

A quick tutorial on runtime type checking for PyTorch tensors with `torchtyping`. [GitHub link](https://github.com/patrick-kidger/torchtyping/blob/master/README.md)


In [None]:
!pip install torchtyping 
!pip install torch --extra-index-url https://download.pytorch.org/whl/cpu

## Gripes

#### Are you sick of tensor shape errors in your code?
#### Do you want more interpretable error messages on shape mismatches?

In [None]:
import torch

weight = torch.randn((4, 4))
x = torch.randn((16, 4, 4))
torch.bmm(weight, x)  # fails: `weight` is 2-D, but bmm needs two 3-D tensors

In [None]:
weight = torch.randn((1, 4, 4))
x = torch.randn((16, 4, 4))
torch.bmm(weight, x)  # fails: batch sizes differ (1 vs 16) and bmm does not broadcast

#### Are you frustrated by poor tensor shape documentation?

```python

embed = Embedding(x) # [batch, seq, dim]    descriptive
layer1 = Layer1(embed) # [b, s, d]          terse
layer2 = Layer2(layer1) # batch first       what are the other dimensions?
logits = FunkyLayer(layer2) #               no annotation at all!

def foo(x: torch.Tensor) -> torch.Tensor: # like... obviously?
    return x
``` 

#### Are `asserts` your only defense against shape mismatches?

```python
HIDDEN_DIM = 1024
logits = my_op(input)
assert logits.size(-1) == HIDDEN_DIM
```

#### Do you only want to type check when you have to?

```python
def complex_op(t):  # we want to document this!
    ...

def clamper(t):     # not worth documenting or checking
    return torch.clamp(t, max=1.0)
```

#### Do you want to check for valid shapes in unit tests before loading a massive model into memory?

#### Do you want to avoid the effort of annotating an entire codebase or implementing statically-enforced tensor types?

#### What if you're locked to a particular Python or torch version and don't want to upgrade to get experimental support for named tensors?

## BEHOLD [torchtyping](https://github.com/patrick-kidger/torchtyping)

Turn this:
```python
def batch_outer_product(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # x has shape (batch, x_channels)
    # y has shape (batch, y_channels)
    # return has shape (batch, x_channels, y_channels)

    return x.unsqueeze(-1) * y.unsqueeze(-2)
```
into this:
```python
@typechecked
def batch_outer_product(x:   TensorType["batch", "x_channels"],
                        y:   TensorType["batch", "y_channels"]
                        ) -> TensorType["batch", "x_channels", "y_channels"]:

    return x.unsqueeze(-1) * y.unsqueeze(-2)
```

(^ shamelessly stolen from the project's [readme](https://github.com/patrick-kidger/torchtyping/blob/master/README.md))

## Tutorial

Remember type annotations ("type hints") in Python are completely optional: they aren't enforced statically or at runtime.
```python
def foo(bar: int) -> int:
    return 42
```
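
To see that the interpreter really does ignore hints, here is a small stdlib-only sketch. The function `double` is a made-up example, not part of `torchtyping`: its hints are stored as metadata, but a call that violates them still succeeds.

```python
def double(value: int) -> int:
    # The hints say int -> int, but nothing checks them when the call happens.
    return value * 2


# The hints are only stored as metadata on the function object.
print(double.__annotations__)

# Passing a str violates the hint, yet the call succeeds:
# multiplying a str by 2 simply repeats it.
result = double("ab")
print(result)
```

This is the gap that runtime checkers such as `typeguard` are meant to close.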

`mypy` allows optional _static_ type-checking.

Call `patch_typeguard()` in your module, then decorate functions with typeguard's `@typechecked` to enforce tensor type checking at _runtime_.

In [None]:
import torch
from torchtyping import TensorType, patch_typeguard
from typeguard import typechecked

patch_typeguard()
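
For intuition about what a runtime-checking decorator does, here is a toy sketch built only on the standard library. `simple_typechecked` is a hypothetical name and this is not how `typeguard` or `torchtyping` are implemented. It only handles plain classes such as `int` as annotations.

```python
import functools
import inspect


def simple_typechecked(func):
    """Toy runtime checker (hypothetical, NOT typeguard's implementation)."""
    signature = inspect.signature(func)

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        bound = signature.bind(*args, **kwargs)
        for name, value in bound.arguments.items():
            expected = signature.parameters[name].annotation
            # Skip parameters that carry no annotation.
            if expected is inspect.Parameter.empty:
                continue
            if not isinstance(value, expected):
                raise TypeError(
                    f"argument {name!r} must be {expected.__name__}, "
                    f"got {type(value).__name__}"
                )
        result = func(*args, **kwargs)
        expected_return = signature.return_annotation
        if expected_return is not inspect.Signature.empty and not isinstance(
            result, expected_return
        ):
            raise TypeError(
                f"return value must be {expected_return.__name__}, "
                f"got {type(result).__name__}"
            )
        return result

    return wrapper


@simple_typechecked
def add(a: int, b: int) -> int:
    return a + b


print(add(1, 2))

try:
    add(1, "2")  # violates the hint on `b`
except TypeError as err:
    print(f"TypeError: {err}")
```

The real libraries do much more (generics, unions, tensor shapes and dtypes), but the shape of the idea is the same: inspect the annotations, then compare them with the actual values on every call.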

A _generic_ type is parameterized by other types (declared with a `TypeVar`) and can be used as a type annotation, e.g.
```python
from typing import List
x: List[int] = [1, 2, 3]
```

`TensorType` goes a step further: it accepts any number of parameters, one per dimension, so it behaves like a _variadic_ generic type.
Variadic generics were proposed in [PEP 646](https://peps.python.org/pep-0646/) and were [accepted for Python 3.11](https://mail.python.org/archives/list/python-dev@python.org/message/OR5RKV7GAVSGLVH3JAGQ6OXFAXIP5XDX/).
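
As a stdlib-only illustration of the difference between a fixed-arity and a variable-arity generic, the sketch below defines a one-parameter generic and then inspects `List` and `Tuple`. `Box` is a hypothetical class made up for this example. `Tuple` is used because it is the one built-in generic that already accepts any number of parameters.

```python
from typing import Generic, List, Tuple, TypeVar, get_args, get_origin

T = TypeVar("T")


class Box(Generic[T]):
    """Hypothetical one-parameter generic container."""

    def __init__(self, item: T) -> None:
        self.item = item


int_box: Box[int] = Box(1)

# List takes exactly one type parameter.
print(get_origin(List[int]), get_args(List[int]))

# Tuple accepts any number of parameters, one per element.
print(get_args(Tuple[int, str, float]))
```

`TensorType['m', 'n']` reads the same way as `Tuple[int, str]`: one parameter per position, with as many positions as the tensor has dimensions.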

In [None]:
@typechecked
def mm(A: TensorType['m','n'], B: TensorType['m','p']) -> TensorType['n','p']:
    return A.T @ B

mm(torch.eye(3), torch.arange(6).float().reshape((3,2)))

What does an error look like?

In [None]:
mm(torch.eye(4), torch.arange(6).float().reshape((3,2)))
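
The call fails because a dimension name such as `'m'` must resolve to a single size across all annotated arguments. Below is a pure-Python sketch of that rule, using plain shape tuples instead of tensors. `bind_dims` is a hypothetical helper for illustration and is not how `torchtyping` is implemented.

```python
from typing import Dict, Sequence, Tuple, Union

Dim = Union[int, str]


def bind_dims(
    spec: Sequence[Dim], shape: Tuple[int, ...], bindings: Dict[str, int]
) -> None:
    """Check `shape` against `spec`, recording named sizes in `bindings`."""
    if len(spec) != len(shape):
        raise TypeError(f"expected {len(spec)} dimensions, got shape {shape}")
    for dim, size in zip(spec, shape):
        if isinstance(dim, int):
            # Constant dimension: the size must match exactly.
            if dim != size:
                raise TypeError(f"expected size {dim}, got {size}")
        else:
            # Named dimension: the first use binds the name, later uses must agree.
            bound = bindings.setdefault(dim, size)
            if bound != size:
                raise TypeError(f"dimension {dim!r} is {bound} but got {size}")


# Shapes from the passing call: A is (3, 3), B is (3, 2).
ok: Dict[str, int] = {}
bind_dims(("m", "n"), (3, 3), ok)
bind_dims(("m", "p"), (3, 2), ok)
print(ok)

# Shapes from the failing call: A is (4, 4), B is (3, 2).
bad: Dict[str, int] = {}
bind_dims(("m", "n"), (4, 4), bad)
try:
    bind_dims(("m", "p"), (3, 2), bad)
except TypeError as err:
    print(f"TypeError: {err}")
```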

What about matrix-vector multiplication?

In [None]:
mm(torch.eye(4), torch.arange(4).float())  # fails: B is 1-D, but TensorType['m','p'] requires two dimensions

`torchtyping` has support for `Union` and `Optional`.

In [None]:
from typing import Union

@typechecked
def mm2(
    A: TensorType['m','n'],
    B: Union[TensorType['m','p'], TensorType['m']],
) -> Union[TensorType['n','p'], TensorType['n']]:
    return A.T @ B

mm2(torch.eye(3), torch.arange(3).float())

In [None]:
mm2(torch.eye(3), torch.arange(6).float().reshape((3,2)))

`torchtyping` also handles constant dimension sizes, dtypes, and plain Python scalar return types.

In [None]:
@typechecked
def intdot2(A: TensorType[2, int], B: TensorType[int]) -> int:
    return A.dot(B).item()

intdot2(torch.arange(2), torch.arange(2)+1)

In [None]:
intdot2(torch.randn((2,)), torch.randn((2,)))  # fails: float tensors, but the annotations require an integer dtype

In [None]:
intdot2(torch.randn((3,)).long(), torch.randn((3,)).long())  # fails: A has 3 elements, but TensorType[2, int] requires exactly 2

A function can return a scalar (zero-dimensional) tensor as well, annotated as `TensorType[()]`.

In [None]:
@typechecked
def scalar_intdot2(A: TensorType[2, int], B: TensorType[int]) -> TensorType[()]:
    return A.dot(B)

scalar_intdot2(torch.arange(2), torch.arange(2)+1)

If you want to handle an arbitrary number of dimensions as a single tuple, use `<name>: ...`.

In [None]:
@typechecked
def add_one(x: TensorType['d': ...]) -> TensorType['d': ...]:
    return x + 1

add_one(torch.arange(6).reshape((1,2,3)))

In [None]:
@typechecked
def bad_add_one(x: TensorType['d': ...]) -> TensorType['d': ...]:
    return (x + 1).squeeze()

bad_add_one(torch.arange(6).reshape((1,2,3)))
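
`bad_add_one` fails because `'d': ...` binds the name `d` to the entire shape tuple, and `squeeze()` changes that tuple. Here is a minimal pure-Python sketch of the binding rule, with plain tuples standing in for tensor shapes. `bind_named_ellipsis` is a hypothetical helper, not part of `torchtyping`.

```python
from typing import Dict, Tuple

Shape = Tuple[int, ...]


def bind_named_ellipsis(name: str, shape: Shape, bindings: Dict[str, Shape]) -> None:
    """Bind `name` to a whole shape tuple; later uses must match it exactly."""
    bound = bindings.setdefault(name, shape)
    if bound != shape:
        raise TypeError(f"{name!r} was bound to {bound}, got {shape}")


# add_one: input and output both have shape (1, 2, 3), so 'd' is consistent.
good: Dict[str, Shape] = {}
bind_named_ellipsis("d", (1, 2, 3), good)
bind_named_ellipsis("d", (1, 2, 3), good)

# bad_add_one: squeeze() drops the size-1 dimension, leaving (2, 3).
bad: Dict[str, Shape] = {}
bind_named_ellipsis("d", (1, 2, 3), bad)
try:
    bind_named_ellipsis("d", (2, 3), bad)
except TypeError as err:
    print(f"TypeError: {err}")
```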

## torchtyping limitations

1. No linting support (with `mypy` or `flake8`). [This is documented](https://github.com/patrick-kidger/torchtyping/blob/master/FURTHER-DOCUMENTATION.md).

1. No static type checking support. [This is documented](https://github.com/patrick-kidger/torchtyping/blob/master/FURTHER-DOCUMENTATION.md).
    
    **BUT I want static tensor type-checking!**

    You can approximate this by catching shape errors in unit tests: decorate the functions under test with `@typechecked`, then run `pytest` with the torchtyping flag:
    ```bash
    pytest --torchtyping-patch-typeguard --tb=short
    ```

    **_NO_ I want _actual_ static type checking**
    
    Check out PyTorch's [named tensors](https://pytorch.org/docs/stable/named_tensor.html#creating-named-tensors) or HarvardNLP's [NamedTensor](https://github.com/harvardnlp/NamedTensor).

1. Types aren't strong and don't propagate through operations the way names do in PyTorch's own [named tensors](https://pytorch.org/docs/stable/named_tensor.html#creating-named-tensors).
    For example, dimension names survive `abs()`:
    ```python
    >>> x = torch.randn(3, 3, names=('N', 'C'))
    >>> x.abs().names
    ('N', 'C')
    ```

    *HOWEVER*
    * Named tensors are still experimental (and have been for several years), are on [hiatus](https://github.com/pytorch/pytorch/issues/60832), and [may be deprecated entirely](https://github.com/pytorch/pytorch/pull/76093).
    * You lose the convenience of dynamic typing with Python.
    * Named dimensions [do not propagate through `autograd`](https://pytorch.org/docs/stable/named_tensor.html#autograd-support).

Fortunately, many of these limitations are documented and being explored by the `torchtyping` author; see the [further documentation](https://github.com/patrick-kidger/torchtyping/blob/master/FURTHER-DOCUMENTATION.md).
Also, [PEP 646](https://peps.python.org/pep-0646/) was [accepted for Python 3.11](https://mail.python.org/archives/list/python-dev@python.org/message/OR5RKV7GAVSGLVH3JAGQ6OXFAXIP5XDX/).
