# Named dimensions
- [PyTorch named tensor](https://pytorch.org/docs/stable/named_tensor.html)
- [jaxtyping](https://docs.kidger.site/jaxtyping/) — note: [line annotations do not work as expected yet.](https://github.com/patrick-kidger/jaxtyping/issues/153)

We will look at jaxtyping and how to check named dimensions, using a code example. This is really useful: the code becomes much easier to read, and on top of that, input and output dimensions are checked at runtime. It is probably the only real option at the moment, since the PyTorch named tensor project sadly seems to be largely abandoned. Let us proceed to the code example.

In [2]:
import typing
from typing import Optional, Union, Tuple
import torch
from torch import nn
from torch.functional import F
from jaxtyping import Float, Int64, jaxtyped
# Use your favourite typechecker: usually one of the two lines below.
import beartype
from typeguard import typechecked as typechecker
%reload_ext jaxtyping
# Pick the runtime checker jaxtyping should use. beartype only randomly samples
# container contents; for this example we want deterministic, exhaustive checks,
# so typeguard is used instead of the commented-out beartype line.
# %jaxtyping.typechecker beartype.beartype
%jaxtyping.typechecker typechecker
# Alternative outside IPython: install jaxtyping's import hook for a module.
# NOTE(review): install_import_hook expects the typechecker as a dotted string
# (e.g. "beartype.beartype"), not the callable — verify before uncommenting.
#jaxtyping.install_import_hook(module, typechecker=beartype.beartype)

# NOTE(review): `Array` is not referenced anywhere in this cell.
Array: typing.TypeAlias = torch.Tensor
# Short alias for jaxtyping's Int64 dtype annotation (torch.long token ids).
Long: typing.TypeAlias = Int64

vocab_size = 65
# Output width equals vocab_size, so each looked-up row is used directly as a
# vector of logits over the vocabulary (see forward below).
token_embedding_table = nn.Embedding(vocab_size, vocab_size)


# Runtime checking is applied by the `%jaxtyping.typechecker` magic in this
# notebook; in a plain module, decorate explicitly instead:
#     @jaxtyped(typechecker=typechecker)
def forward(
    idx: Long[torch.Tensor, "batch_dim context_dim"],
    targets: Optional[Long[torch.Tensor, "batch_dim context_dim"]] = None,
) -> Tuple[
    Union[
        Float[torch.Tensor, "batch_dim*context_dim latent_dim"],
        Float[torch.Tensor, "batch_dim context_dim latent_dim"],
    ],
    Union[Float[torch.Tensor, ""], None],
]:
    """Look up logits for token ids and optionally compute the cross-entropy loss.

    Args:
        idx: integer token ids, shape (batch_dim, context_dim).
        targets: optional integer target ids with the same shape as ``idx``.

    Returns:
        ``(logits, None)`` with logits of shape (batch_dim, context_dim, latent_dim)
        when ``targets`` is None. Otherwise ``(logits, loss)``, where the logits
        are flattened to (batch_dim*context_dim, latent_dim) and ``loss`` is a
        scalar tensor.
    """
    # nn.Embedding yields floating-point rows: (B, T, C=vocab_size).
    logits: Float[torch.Tensor, "batch_dim context_dim latent_dim"] = token_embedding_table(idx)
    if targets is None:
        return logits, None

    # NOTE: annotations on local variables are not evaluated at runtime (PEP 526),
    # so only the signature above is checked; the ones below are documentation.
    B, T, C = logits.shape
    # F.cross_entropy wants class scores on dim 1, i.e. (B, C, T) for sequences.
    # Flattening to (B*T, C) / (B*T,) uses the plain (N, C) form and avoids a transpose.
    flat_logits: Float[torch.Tensor, "batch_dim*context_dim latent_dim"] = logits.view(B * T, C)
    flat_targets: Long[torch.Tensor, "batch_dim*context_dim"] = targets.view(B * T)
    # https://agustinus.kristia.de/techblog/2016/12/21/forward-reverse-kl/
    loss: Float[torch.Tensor, ""] = F.cross_entropy(flat_logits, flat_targets)
    return flat_logits, loss


def forward_broken(
    idx: Long[torch.Tensor, "batch_dim context_dim"],
    targets: Long[torch.Tensor, "batch_dim context_dim"],
) -> Tuple[
    Float[torch.Tensor, "batch_dim context_dim latent_dim"], Float[torch.Tensor, ""]
]:
    """Deliberately mis-annotated variant of ``forward`` that always computes the loss.

    The return annotation promises 3-D logits (batch_dim, context_dim, latent_dim),
    but the body returns the flattened 2-D (batch_dim*context_dim, latent_dim)
    logits. With the runtime checker active, calling this raises a type-check
    error on the return value (typeguard's TypeCheckError in this notebook).
    That failure is the point of the example.
    """
    # nn.Embedding yields floating-point rows: (B, T, C=vocab_size).
    logits: Float[torch.Tensor, "batch_dim context_dim latent_dim"] = token_embedding_table(idx)
    B, T, C = logits.shape
    # F.cross_entropy wants class scores on dim 1, i.e. (B, C, T) for sequences.
    # Flattening to (B*T, C) / (B*T,) uses the plain (N, C) form and avoids a transpose.
    # Local annotations are not evaluated at runtime (PEP 526), so they are not checked.
    flat_logits: Float[torch.Tensor, "batch_dim*context_dim latent_dim"] = logits.view(B * T, C)
    flat_targets: Long[torch.Tensor, "batch_dim*context_dim"] = targets.view(B * T)
    # https://agustinus.kristia.de/techblog/2016/12/21/forward-reverse-kl/
    loss: Float[torch.Tensor, ""] = F.cross_entropy(flat_logits, flat_targets)
    # Intentional bug: 2-D logits violate the 3-D return annotation above.
    return flat_logits, loss


# Random token ids in [0, vocab_size): 256 sequences of 65 tokens each.
# Tie the id range to vocab_size so the embedding lookup stays in bounds if the
# vocabulary size changes. (The context length 65 merely coincides with vocab_size.)
idx = torch.randint(low=0, high=vocab_size, size=(256, 65))
targets = torch.randint(low=0, high=vocab_size, size=(256, 65))
logits_not_none, loss_not_none = forward(idx, targets)
logits_none, loss_none = forward(idx, None)
# Expected to raise: forward_broken returns 2-D logits although its return
# annotation promises 3-D ones (see the error output below).
logits_broken, loss_broken = forward_broken(idx, targets)

TypeCheckError: Type-check error whilst checking the return value of forward_broken.
Actual value: ( tensor([[ 0.7932,  1.2100, -0.0075,  ..., -1.1108, -0.0285, -0.5872],
        [ 1.2552,  0.1470, -1.0430,  ...,  0.5065, -0.6637,  0.9305],
        [-0.6178,  0.4888, -1.0378,  ...,  2.0027,  1.0175, -1.0587],
        ...,
        [-1.5059, -0.8984, -1.0045,  ..., -0.0105,  0.7734,  0.7782],
        [-0.9111,  0.1937, -2.5012,  ...,  0.4195, -0.7629, -0.6584],
        [ 0.1368, -0.4695, -0.3229,  ..., -0.6181,  0.5164,  1.2186]],
       grad_fn=<ViewBackward0>),
  tensor(4.6550, grad_fn=<NllLossBackward0>))
Expected type: typing.Tuple[Float[Tensor, 'batch_dim context_dim latent_dim'], Float[Tensor, '']].
----------------------
Called with parameters: { 'idx': tensor([[63, 25,  9,  ..., 28,  0, 61],
        [ 8, 37, 18,  ..., 57,  5, 39],
        [50, 13,  0,  ..., 21, 33, 45],
        ...,
        [35, 16, 49,  ..., 54, 11, 20],
        [22, 34, 51,  ...,  9, 25, 62],
        [64, 35,  6,  ..., 62, 10, 58]]),
  'targets': tensor([[16, 63, 24,  ...,  8,  9, 46],
        [ 0, 15,  8,  ..., 31,  5, 22],
        [ 2, 32, 40,  ..., 23,  0, 18],
        ...,
        [31, 47, 10,  ..., 31,  4, 15],
        [32, 57, 46,  ..., 31, 24, 45],
        [60, 50,  3,  ..., 36, 53, 50]])}
Parameter annotations: (idx: Int64[Tensor, 'batch_dim context_dim'], targets: Int64[Tensor, 'batch_dim context_dim']).
The current values for each jaxtyping axis annotation are as follows.
batch_dim=256
context_dim=65