# Named dimensions
- [PyTorch named tensor](https://pytorch.org/docs/stable/named_tensor.html)
- [jaxtyping](https://docs.kidger.site/jaxtyping/) — note: [line (local-variable) annotations do not work as expected yet.](https://github.com/patrick-kidger/jaxtyping/issues/153)

We will look at jaxtyping and how to do named-dimension checking with a code example. This is really useful: the code becomes much easier to read, and on top of that the input and output dimensions are checked at call time. This is probably the only real option currently, as the PyTorch named tensor project sadly seems to be largely abandoned. Let us proceed to the code example.

In [1]:
import typing
from typing import Optional, Union, Tuple
import torch
from torch import nn
from torch.functional import F
from jaxtyping import Float, Int64, jaxtyped
# Use your favourite typechecker: usually one of the two lines below.
import beartype
from typeguard import typechecked as typechecker
%reload_ext jaxtyping
%jaxtyping.typechecker beartype.beartype
#jaxtyping.install_import_hook(module, typechecker=beartype.beartype)

# Short aliases for annotations. `Long` mirrors torch's "long" dtype, which is int64.
Long: typing.TypeAlias = Int64
Array: typing.TypeAlias = torch.Tensor

# Lookup table mapping each of the `vocab_size` token ids to a row of
# `vocab_size` values, which `forward` uses directly as logits.
vocab_size = 65
token_embedding_table = nn.Embedding(num_embeddings=vocab_size, embedding_dim=vocab_size)


#@jaxtyped(typechecker=typechecker)
def forward(idx: Long[torch.Tensor, "batch_dim context_dim"],
            targets: Optional[Long[torch.Tensor, "batch_dim context_dim"]] = None) -> Tuple[
    Union[Float[torch.Tensor, "batch_dim*context_dim latent_dim"], Float[
        torch.Tensor, "batch_dim context_dim latent_dim"]], Union[Float[torch.Tensor, ""], None]]:
    """Look up per-token logits and optionally compute the cross-entropy loss.

    The argument and return annotations are the ones jaxtyping can check at call
    time (via the ``%jaxtyping.typechecker`` cell magic or the commented-out
    ``@jaxtyped`` decorator): ``batch_dim`` and ``context_dim`` must agree between
    ``idx`` and ``targets``, and the returned logits must have the matching shape.

    Args:
        idx: Integer token ids of shape (batch_dim, context_dim).
        targets: Optional integer target ids with the same shape as ``idx``.

    Returns:
        ``(logits, None)`` with logits of shape (batch_dim, context_dim, latent_dim)
        when ``targets`` is None. Otherwise ``(flat_logits, loss)``, where the logits
        are flattened to (batch_dim*context_dim, latent_dim) and ``loss`` is the
        scalar mean cross-entropy.
    """
    # Embedding lookup: integer ids in, floating-point rows out. The local
    # annotation is therefore Float, matching the return annotation.
    logits: Float[torch.Tensor, "batch_dim context_dim latent_dim"] = token_embedding_table(
        idx)
    if targets is None:
        return logits, None

    # Plain Python ints read from the runtime shape. They are not tied to the
    # jaxtyping dimension names, so nothing below re-checks batch_dim/context_dim.
    B, T, C = logits.shape
    # F.cross_entropy expects class scores on dim 1, i.e. (B, C, T) for batched
    # input. Flattening to (B*T, C) and (B*T,) avoids the transpose by treating
    # every position as an independent sample.
    # NOTE: local-variable annotations like the ones below serve only as
    # documentation. They are not checked at runtime (jaxtyping issue #153), so a
    # wrong dtype or shape written here would go unnoticed.
    flat_logits: Float[torch.Tensor, "batch_dim*context_dim latent_dim"] = logits.view(B * T, C)
    flat_targets: Long[torch.Tensor, "batch_dim*context_dim"] = targets.view(B * T)
    # https://agustinus.kristia.de/techblog/2016/12/21/forward-reverse-kl/
    loss: Float[torch.Tensor, ""] = F.cross_entropy(flat_logits, flat_targets)
    return flat_logits, loss


# Random token ids in [0, vocab_size) for a batch of 256 sequences of 65 tokens.
# `high` is exclusive and tied to vocab_size: a larger bound would produce ids
# outside the embedding table and make the lookup fail.
# NOTE(review): the context length (65) equals vocab_size, so context_dim and
# latent_dim bind to the same size. A swapped context/latent axis would
# therefore not be caught by the shape checker. Consider a distinct length.
idx = torch.randint(low=0, high=vocab_size, size=(256, 65))
targets = torch.randint(low=0, high=vocab_size, size=(256, 65))
out_not_none = forward(idx, targets)  # (flattened logits, scalar loss)
out_none = forward(idx, None)  # (3-D logits, None)