# torchtyping

#### Quick tutorial of runtime type-checking for PyTorch tensors, for deep learning practitioners. 
[`torchtyping` GitHub link](https://github.com/patrick-kidger/torchtyping/blob/master/README.md)

(DISCLAIMER: I'm not the author of `torchtyping` and I'm not a programming language expert or type-theory aficionado.)

In [None]:
!pip install torchtyping 
!pip install torch --extra-index-url https://download.pytorch.org/whl/cpu

## Gripes

#### Are you sick of tensor shape errors in your code?
#### Do you want more interpretable error messages on shape mismatches?

In [None]:
import torch
weight = torch.randn((4, 4))
x = torch.randn((16, 4, 4))
torch.bmm(weight, x)

In [None]:
weight = torch.randn((1, 4, 4))
x = torch.randn((16, 4, 4))
torch.bmm(weight, x)

#### Are you frustrated by poor tensor shape documentation?

Shapes are usually documented in code comments, if the code is ever commented at all.

```python
embed = Embedding(x) # [batch, seq, dim]    descriptive
layer1 = Layer1(embed) # [b, s, d]          terse
layer2 = Layer2(layer1) # batch first       what are the other dimensions?
logits = FunkyLayer(layer2) #               no annotation at all!

def foo(x: torch.Tensor) -> torch.Tensor: # very generic type annotations. like... obviously?
    return x
``` 

#### Are `asserts` your only defense against shape mismatches?

```python
HIDDEN_DIM = 1024
logits = my_op(input)
assert logits.size(-1) == HIDDEN_DIM
```

#### Are you lazy and only want to check shapes when you have to?

```python
def complex_tensor_wizardry_op(t):  # we want to document this!
    ...

def clamper(t):                     # not worth documenting or checking
    return torch.clamp(t, max=1.0)
```

#### Do you want to check for valid shapes in unit tests before loading a massive model into memory?

10 seconds to 10 minutes wasted on model loading, distributed training initialization, and priming data loaders only to find out that you transposed dims or forgot to `(un)squeeze`.

#### Do you want to avoid the effort of annotating an entire codebase?

C'mon, you want to spend your time running experiments - not tackling technical debt!

#### What if you're locked to a particular Python or torch version and don't want to upgrade to get experimental support for named tensors?

You're pinned to that golden PyTorch 1.2.0 Docker image that is the only one that works in production.

## BEHOLD [torchtyping](https://github.com/patrick-kidger/torchtyping)

Turn this:

```python
def batch_outer_product(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # x has shape (batch, x_channels)
    # y has shape (batch, y_channels)
    # return has shape (batch, x_channels, y_channels)

    return x.unsqueeze(-1) * y.unsqueeze(-2)
```

into this:

```python
@typechecked
def batch_outer_product(x:   TensorType["batch", "x_channels"],
                        y:   TensorType["batch", "y_channels"]
                        ) -> TensorType["batch", "x_channels", "y_channels"]:

    return x.unsqueeze(-1) * y.unsqueeze(-2)
```

(^ shamelessly stolen from the project's [readme](https://github.com/patrick-kidger/torchtyping/blob/master/README.md))

## Tutorial

Remember that type annotations ("type hints") in Python are completely optional: __Python itself does not enforce them, either statically or at runtime__.
```python
def foo(bar: int) -> int:
    return 42
```
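
A quick way to convince yourself of this, using the same `foo` as above: the call below passes a string where an `int` is annotated, and plain Python accepts it without complaint. The hints are only stored as metadata on the function.

```python
def foo(bar: int) -> int:
    return 42

# The annotation says `int`, but plain Python does not check it at call time.
result = foo("definitely not an int")

# The hints are kept as metadata on the function object, nothing more.
hints = foo.__annotations__
print(result, hints)
```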

`mypy` allows **optional static** type-checking.

Use `torchtyping`'s `patch_typeguard` in your module to enforce tensor type checking at **runtime**.

In [None]:
from torchtyping import TensorType, patch_typeguard
from typeguard import typechecked

patch_typeguard()

A **generic** type is a type parameterized by other types, and it can be used as a type annotation like any other type. Custom generics can be declared with a `TypeVar`. e.g.

```python
from typing import List
x: List[int] = [1, 2, 3]
```
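
For a user-defined generic, here is a minimal sketch with a `TypeVar`. The function name `first` is made up for illustration; a static checker such as `mypy` would infer `int` for the first call and `str` for the second.

```python
from typing import List, TypeVar

T = TypeVar("T")  # placeholder for "whatever the element type is"

def first(items: List[T]) -> T:
    """Return the first element; the return type follows the list's element type."""
    return items[0]

n = first([1, 2, 3])   # a static checker infers int here
s = first(["a", "b"])  # and str here
```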

Tensors are multidimensional, so they can have a different generic type for each dimension/axis.

We want something like this:

```python
x: Tensor['batch', 'seqlen', 'hidden'] = torch.randn((16, 128, 1024))
```
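
Plain Python already allows this kind of subscript syntax on a class through `__class_getitem__`. The sketch below only records the names it is given; `ShapeSpec` is a hypothetical stand-in for illustration and is not how `torchtyping` implements `TensorType`.

```python
class ShapeSpec:
    """Hypothetical stand-in that only records the dimension names it is given."""

    def __init__(self, dims):
        self.dims = dims

    def __class_getitem__(cls, item):
        # A single name arrives as a bare object, several names arrive as a tuple.
        dims = item if isinstance(item, tuple) else (item,)
        return cls(dims)

spec = ShapeSpec['batch', 'seqlen', 'hidden']
print(spec.dims)
```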

Tensors can also have layout and sparsity properties, as well as actual data type (`float64`, `int`, etc.).

This specification is called a **variadic generic** type.
Variadic generics were proposed in [PEP 646](https://peps.python.org/pep-0646/) and were [accepted for Python 3.11](https://mail.python.org/archives/list/python-dev@python.org/message/OR5RKV7GAVSGLVH3JAGQ6OXFAXIP5XDX/).

#### **`torchtyping` enables variadic generics for PyTorch tensors.**

These can be used in function signatures or left-hand-side (LHS) type annotations.

`torchtyping` can __enforce__ variadic generic types on function signatures at runtime or during testing.

In [None]:
@typechecked
def mm(A: TensorType['m','n'], B: TensorType['m','p']) -> TensorType['n','p']:
    return A.T @ B

mm(torch.eye(3), torch.arange(6).float().reshape((3,2)))

What does a `torchtyping` error look like?

In [None]:
mm(torch.eye(4), torch.arange(6).float().reshape((3,2)))

What about matrix-vector multiplication?

In [None]:
mm(torch.eye(4), torch.arange(4).float())
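
Both failing calls above can be understood with a toy model of the check: every dimension name is bound to the first size it sees, and any later disagreement (or a wrong number of dimensions) is an error. The sketch below works on plain shape tuples. `bind_dims` is a hypothetical helper written for this tutorial, not `torchtyping`'s implementation, and its messages are not the library's messages.

```python
from typing import Dict, Sequence, Tuple

def bind_dims(
    specs: Dict[str, Sequence[str]],
    shapes: Dict[str, Tuple[int, ...]],
) -> Dict[str, int]:
    """Bind each dimension name to a size, raising TypeError on any conflict."""
    bound: Dict[str, int] = {}
    for arg, names in specs.items():
        shape = shapes[arg]
        if len(shape) != len(names):
            raise TypeError(
                f"{arg}: expected {len(names)} dimensions {tuple(names)}, got shape {shape}"
            )
        for name, size in zip(names, shape):
            # setdefault binds the name on first sight and returns the old size afterwards.
            if bound.setdefault(name, size) != size:
                raise TypeError(
                    f"{arg}: dimension '{name}' is {size}, but was already bound to {bound[name]}"
                )
    return bound

specs = {"A": ("m", "n"), "B": ("m", "p")}

# Shapes of the successful call: eye(3) and a 3x2 matrix.
ok = bind_dims(specs, {"A": (3, 3), "B": (3, 2)})

# Shapes of the two failing calls: mismatched 'm', then a 1-D second argument.
errors = []
for shapes in ({"A": (4, 4), "B": (3, 2)}, {"A": (4, 4), "B": (4,)}):
    try:
        bind_dims(specs, shapes)
    except TypeError as exc:
        errors.append(str(exc))
```

The first failure is a size conflict on `m`; the second is a rank mismatch, which is why matrix-vector multiplication needs its own annotation.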

`torchtyping` has support for `Union` and `Optional`.

In [None]:
from typing import Union

@typechecked
def mm2(A: TensorType['m', 'n'], B: Union[TensorType['m', 'p'], TensorType['m']] ) -> Union[TensorType['n',  'p'], TensorType['n']]:
    return A.T @ B

# matrix-vector
mm2(torch.eye(3), torch.arange(3).float())

In [None]:
# matrix-matrix
mm2(torch.eye(3), torch.arange(6).float().reshape((3,2)))

`torchtyping` handles constant dimensions and scalar return types as well.

In [None]:
@typechecked
def intdot2(A: TensorType[2, int], B: TensorType[int]) -> int:
    return A.dot(B).item()

# [0, 1] . [1, 2]
intdot2(torch.arange(2), torch.arange(2)+1)

In [None]:
# [float, float] . [float, float]
intdot2(torch.randn((2,)), torch.randn((2,)))

In [None]:
# [long, long, long] . [long, long, long]
intdot2(torch.randn((3,)).long(), torch.randn((3,)).long())

`torchtyping` can also annotate a scalar (zero-dimensional) tensor as the return type.

In [None]:
@typechecked
def scalar_intdot2(A: TensorType[2, int], B: TensorType[int]) -> TensorType[()]:
    return A.dot(B)

scalar_intdot2(torch.arange(2), torch.arange(2)+1)

If you want to group a sequence of dimensions together as a single tuple, use `<dimension group name>: ...`.
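
The `name: ...` form may look like invalid Python, but inside square brackets it is ordinary slice syntax. A small sketch that echoes back what the brackets receive (`Echo` is a hypothetical helper for illustration only):

```python
class Echo:
    """Hypothetical helper: returns whatever appears inside the brackets."""

    def __class_getitem__(cls, item):
        return item

group = Echo['dims': ...]            # a slice with start='dims', stop=Ellipsis
mixed = Echo['batch', 'dims': ...]   # a tuple: a plain name plus a slice
print(group, mixed)
```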

In [None]:
# adds one to the whole tensor without changing its shape.

@typechecked
def add_one(x: TensorType['dims': ...]) -> TensorType['dims': ...]:
    return x + 1

add_one(torch.arange(6).reshape((1,2,3)))

In [None]:
@typechecked
def bad_add_one(x: TensorType['dims': ...]) -> TensorType['dims': ...]:
    return (x + 1).squeeze() # could reduce number of dimensions

bad_add_one(torch.arange(6).reshape((1,2,3)))
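
Why the second version is a problem, worked out on plain shape tuples: a no-argument squeeze drops every size-1 axis, so the group bound on the way in no longer matches the shape on the way out. `squeeze_shape` is a hypothetical helper that mimics that behaviour on tuples.

```python
from typing import Tuple

def squeeze_shape(shape: Tuple[int, ...]) -> Tuple[int, ...]:
    """Shape after dropping all size-1 axes, as a no-argument squeeze does."""
    return tuple(size for size in shape if size != 1)

in_shape = (1, 2, 3)                  # 'dims' is bound to this on input
out_shape = squeeze_shape(in_shape)   # shape of the returned tensor
group_matches = in_shape == out_shape
print(in_shape, out_shape, group_matches)
```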

## torchtyping limitations

#### 1. No real linting support with `mypy` or `flake8`

Add `# noqa` everywhere to make linters stop complaining. [This is documented](https://github.com/patrick-kidger/torchtyping/blob/master/FURTHER-DOCUMENTATION.md).

#### 2. Types aren't static and have no static type checking support.

[This is documented](https://github.com/patrick-kidger/torchtyping/blob/master/FURTHER-DOCUMENTATION.md).
    
##### BUT I want static tensor type-checking!

You can test sensitive functions with unit tests. Define unit tests with the `@typechecked` decorator, to be run with `pytest` using the `--torchtyping-patch-typeguard` option:

```bash
pytest --torchtyping-patch-typeguard --tb=short
```

##### _NO_ I want _actual_ static type checking

Check out PyTorch's [named tensors](https://pytorch.org/docs/stable/named_tensor.html#creating-named-tensors) or HarvardNLP's [NamedTensor](https://github.com/harvardnlp/NamedTensor).

#### 3. Types aren't strong

`torchtyping` doesn't enforce **strong** types on tensors. I.e. tensors aren't required to strictly have the specified named dimensions in order to execute the program.

Strongly-typed tensors are enabled by PyTorch's own [`named tensors`](https://pytorch.org/docs/stable/named_tensor.html#creating-named-tensors).

Named dimensions also (generally) propagate with `named tensors`.

E.g.

```python
>>> x = torch.randn(3, 3, names=('N', 'C'))
>>> x.abs().names
('N', 'C')
```

**However**, there are some caveats to PyTorch's `named tensors`:

* Still experimental and has been for several years. The feature is on [hiatus](https://github.com/pytorch/pytorch/issues/60832), and [may be deprecated entirely](https://github.com/pytorch/pytorch/pull/76093).

* You lose the convenience of dynamic typing with Python.

* Named dimensions [do not propagate through `autograd`](https://pytorch.org/docs/stable/named_tensor.html#autograd-support).

## Good News

Fortunately, many of these limitations are documented and being explored by the `torchtyping` author. See [here](https://github.com/patrick-kidger/torchtyping/blob/master/FURTHER-DOCUMENTATION.md).

Also, [PEP 646](https://peps.python.org/pep-0646/) was accepted for [Python 3.11](https://mail.python.org/archives/list/python-dev@python.org/message/OR5RKV7GAVSGLVH3JAGQ6OXFAXIP5XDX/).