# Named tensor notation with funsors (Part 1)

## Introduction

**This is a translation of the [Named Tensor Notation (Chiang, Rush, Barak 2021)](https://namedtensor.github.io/) example from the [funsor](https://funsor.pyro.ai/) library to `effectful`. Much of the expository text is taken directly from the original.**

The mathematical notation with *named axes* introduced in [Named Tensor Notation (Chiang, Rush, Barak 2021)](https://namedtensor.github.io/) improves the readability of mathematical formulas involving multidimensional arrays. This includes tensor operations such as elementwise operations, reductions, contractions, renaming, indexing, and broadcasting. Part 1 covers examples from [2 Informal Overview](https://namedtensor.github.io/#sec:overview), [3.4.2 Advanced Indexing](https://namedtensor.github.io/#sec:examples), and [5 Formal Definitions](https://namedtensor.github.io/#sec:definitions).

In [1]:
import functools

import torch
from torch import tensor

from effectful.ops.core import evaluate
from effectful.ops.handler import handler
from effectful.internals.sugar import gensym, torch_getitem
from effectful.indexed.ops import Indexable, to_tensor


def subst(term, substs):
    with handler(
        {k: functools.partial(lambda vv: vv, v) for (k, v) in substs.items()},
    ):
        return evaluate(term)


def reduce(indexes, indexed_tensor, reducer):
    """Reduce an indexed tensor along one or more named dimensions.

    Args:
    - indexes: Names of dimensions to reduce.
    - indexed_tensor: The tensor to reduce.
    - reducer: A reduction function like `torch.sum`. Must take `tensor`, `dim`, and `keepdim` arguments.

    Returns: A new indexed tensor with the specified dimensions reduced.

    Example:
    >>> width, height = gensym(int, name='width'), gensym(int, name='height')
    >>> t = Indexable(torch.ones(2, 3))[width(), height()]
    >>> reduce([width], t, torch.sum)
    Indexable(tensor([2., 2., 2.]))[height()]
    """
    num_positional = len(indexed_tensor.shape)
    # convert indexed dimensions to positional and flatten all new positional dims
    t = to_tensor(indexed_tensor, indexes)
    new_positional = len(t.shape) - num_positional
    t_flat = torch.flatten(t, 0, new_positional - 1)

    # reduce dim 0 into the first index of dim 0, then return reduction
    return reducer(t_flat, 0, keepdim=True)[0]

## Named Tensors

Each tensor axis is given a name:

$$
\begin{aligned}
  A &\in \mathbb{R}^{\mathsf{\vphantom{fg}height}[3] \times \mathsf{\vphantom{fg}width}[3]} = \mathbb{R}^{\mathsf{\vphantom{fg}width}[3] \times \mathsf{\vphantom{fg}height}[3]} \\
  A &= \mathsf{\vphantom{fg}height}
  \begin{array}[b]{@{}c@{}}\mathsf{\vphantom{fg}width}\\\begin{bmatrix}
    3 & 1 & 4 \\
    1 & 5 & 9 \\
    2 & 6 & 5
  \end{bmatrix}\end{array} =
  \mathsf{\vphantom{fg}width}
  \begin{array}[b]{@{}c@{}}\mathsf{\vphantom{fg}height}\\\begin{bmatrix}
    3 & 1 & 2 \\
    1 & 5 & 6 \\
    4 & 9 & 5
  \end{bmatrix}\end{array}.
\end{aligned}
$$

In [2]:
height, width = gensym(int, name="height"), gensym(int, name="width")
t = tensor([[3, 1, 4], [1, 5, 9], [2, 6, 5]])
A = Indexable(t)[height(), width()]
A

Indexable(tensor([[3, 1, 4],
                  [1, 5, 9],
                  [2, 6, 5]]))[height(), width()]

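Access elements of $A$ using named indices. The notation counts from 1 while the code counts from 0, so row 1, column 3 in the formula becomes `{height: 0, width: 2}` in the code: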

$$
A_{\mathsf{\vphantom{fg}height}(1), \mathsf{\vphantom{fg}width}(3)} = A_{\mathsf{\vphantom{fg}width}(3), \mathsf{\vphantom{fg}height}(1)} = 4
$$

In [3]:
subst(A, {height: 0, width: 2})

tensor(4)

Partial indexing:

$$
\begin{aligned}
A_{\mathsf{\vphantom{fg}height}(1)} &= \begin{array}[b]{@{}c@{}}\mathsf{\vphantom{fg}width}\\
\begin{bmatrix}
  3 & 1 & 4
\end{bmatrix}\end{array}
&
A_{\mathsf{\vphantom{fg}width}(3)} &= \begin{array}[b]{@{}c@{}}\mathsf{\vphantom{fg}height}\\
\begin{bmatrix}
  4 & 9 & 5
\end{bmatrix}\end{array}.
\end{aligned}
$$

In [4]:
subst(A, {height: 0})

Indexable(tensor([3, 1, 4]))[width()]

In [5]:
subst(A, {width: 2})

Indexable(tensor([4, 9, 5]))[height()]
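
To make the role of the names concrete, here is a minimal plain-Python sketch of named selection, independent of `effectful` and `torch`. The helper `select` and the nested-list storage are assumptions made for illustration and are not part of either library; indices are 0-based, as in the code cells above.

```python
# A stored as nested lists, with its axis names listed outermost-first.
A_AXES = ("height", "width")
A_DATA = [[3, 1, 4], [1, 5, 9], [2, 6, 5]]


def select(axes, data, **fixed):
    """Fix the named axes given in `fixed`; return (remaining_axes, values)."""
    if not axes:
        return (), data
    name, rest = axes[0], axes[1:]
    if name in fixed:
        # This axis is fixed: descend into one slice and drop its name.
        return select(rest, data[fixed[name]], **fixed)
    # This axis stays free: apply the same selection to every slice along it.
    parts = [select(rest, sub, **fixed) for sub in data]
    return (name,) + parts[0][0], [values for _, values in parts]


print(select(A_AXES, A_DATA, height=0, width=2))  # ((), 4)
print(select(A_AXES, A_DATA, width=2, height=0))  # ((), 4)
print(select(A_AXES, A_DATA, height=0))           # (('width',), [3, 1, 4])
print(select(A_AXES, A_DATA, width=2))            # (('height',), [4, 9, 5])
```

The order in which names are given does not matter, and any name left unfixed survives as an axis of the result, which mirrors the three `subst` calls above.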

## Named tensor operations

### Elementwise operations and broadcasting

Elementwise operations:

$$
\frac1{1+\exp(-A)} = \mathsf{\vphantom{fg}height}
\begin{array}[b]{@{}c@{}}\mathsf{\vphantom{fg}width}\\
\begin{bmatrix}
  \frac 1{1+\exp(-3)} & \frac 1{1+\exp(-1)} & \frac 1{1+\exp(-4)} \\[1ex]
  \frac 1{1+\exp(-1)} & \frac 1{1+\exp(-5)} & \frac 1{1+\exp(-9)} \\[1ex]
  \frac 1{1+\exp(-2)} & \frac 1{1+\exp(-6)} & \frac 1{1+\exp(-5)}
\end{bmatrix}\end{array}.
$$

In [6]:
1 / (1 + (-A).exp())

Indexable(tensor([[0.9526, 0.7311, 0.9820],
                  [0.7311, 0.9933, 0.9999],
                  [0.8808, 0.9975, 0.9933]]))[height(), width()]

Tensors with different shapes are automatically broadcast against each other before an operation is applied. Let

$$
\begin{aligned}
  x &\in \mathbb{R}^{\mathsf{\vphantom{fg}height}[3]} & y &\in \mathbb{R}^{\mathsf{\vphantom{fg}width}[3]} \\
  x &= \mathsf{\vphantom{fg}height}
  \begin{array}[b]{@{}c@{}}\\
  \begin{bmatrix}
    2 \\ 7 \\ 1
  \end{bmatrix}\end{array} & 
  y &= 
  \begin{array}[b]{@{}c@{}}\mathsf{\vphantom{fg}width}\\\begin{bmatrix}
    1 & 4 & 1
  \end{bmatrix}\end{array}.
\end{aligned}
$$

In [7]:
x = Indexable(tensor([2, 7, 1]))[height()]
y = Indexable(tensor([1, 4, 1]))[width()]

Binary addition operation:

$$
\begin{aligned}
A + x &= \mathsf{\vphantom{fg}height}
\begin{array}[b]{@{}c@{}}\mathsf{\vphantom{fg}width}\\\begin{bmatrix}
  3+2 & 1+2 & 4+2 \\
  1+7 & 5+7 & 9+7 \\
  2+1 & 6+1 & 5+1
\end{bmatrix}\end{array} &
A + y &= \mathsf{\vphantom{fg}height}
\begin{array}[b]{@{}c@{}}\mathsf{\vphantom{fg}width}\\\begin{bmatrix}
  3+1 & 1+4 & 4+1 \\
  1+1 & 5+4 & 9+1 \\
  2+1 & 6+4 & 5+1
\end{bmatrix}\end{array}.
\end{aligned}
$$

In [8]:
A + x

Indexable(tensor([[ 5,  3,  6],
                  [ 8, 12, 16],
                  [ 3,  7,  6]]))[height(), width()]

In [9]:
A + y

Indexable(tensor([[ 4,  5,  5],
                  [ 2,  9, 10],
                  [ 3, 10,  6]]))[height(), width()]
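
The two sums above can be checked with explicit loops. This is a plain-Python sketch (nested lists rather than `Indexable` tensors, an assumption made so the arithmetic is visible): each operand is simply read at the names it actually has.

```python
A = [[3, 1, 4], [1, 5, 9], [2, 6, 5]]  # indexed as A[height][width]
x = [2, 7, 1]                          # indexed as x[height]
y = [1, 4, 1]                          # indexed as y[width]

# x only has a height axis, so it is constant along width.
A_plus_x = [[A[h][w] + x[h] for w in range(3)] for h in range(3)]

# y only has a width axis, so it is constant along height.
A_plus_y = [[A[h][w] + y[w] for w in range(3)] for h in range(3)]

print(A_plus_x)  # [[5, 3, 6], [8, 12, 16], [3, 7, 6]]
print(A_plus_y)  # [[4, 5, 5], [2, 9, 10], [3, 10, 6]]
```

With positional PyTorch tensors the first sum would need a manual alignment such as `A + x[:, None]`; with named axes the alignment follows from the names.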

Binary multiplication operation:

$$
A \odot x = \mathsf{\vphantom{fg}height}
\begin{array}[b]{@{}c@{}}\mathsf{\vphantom{fg}width}\\\begin{bmatrix}
  3\cdot2 & 1\cdot2 & 4\cdot2 \\
  1\cdot7 & 5\cdot7 & 9\cdot7 \\
  2\cdot1 & 6\cdot1 & 5\cdot1
\end{bmatrix}\end{array}
$$

In [10]:
A * x

Indexable(tensor([[ 6,  2,  8],
                  [ 7, 35, 63],
                  [ 2,  6,  5]]))[height(), width()]

Binary maximum operation:

$$
\max(A, y) = \mathsf{\vphantom{fg}height}
\begin{array}[b]{@{}c@{}}\mathsf{\vphantom{fg}width}\\\begin{bmatrix}
  \max(3, 1) & \max(1, 4) & \max(4, 1) \\
  \max(1, 1) & \max(5, 4) & \max(9, 1) \\
  \max(2, 1) & \max(6, 4) & \max(5, 1)
\end{bmatrix}\end{array}.
$$

In [11]:
torch.max(A, y)

Indexable(tensor([[3, 4, 4],
                  [1, 5, 9],
                  [2, 6, 5]]))[height(), width()]

### Reductions

Named axes can be reduced over by calling the `reduce` helper defined above with the names of the reduced axes, the tensor, and a [reduction operator](https://en.wikipedia.org/wiki/Reduction_Operator) such as `torch.sum`. Note that reduction is defined only for operators that are associative and commutative.

$$
\sum\limits_{\substack{\mathsf{\vphantom{fg}height}}} A = \sum_i A_{\mathsf{\vphantom{fg}height}(i)} = \begin{array}[b]{@{}c@{}}\mathsf{\vphantom{fg}width}\\
\begin{bmatrix}
  3+1+2 & 1+5+6 & 4+9+5
\end{bmatrix}\end{array}.
$$

In [12]:
reduce([height], A, torch.sum)

Indexable(tensor([ 6, 12, 18]))[width()]

$$
\sum\limits_{\substack{\mathsf{\vphantom{fg}width}}} A = \sum_j A_{\mathsf{\vphantom{fg}width}(j)} = \begin{array}[b]{@{}c@{}}\mathsf{\vphantom{fg}height}\\
\begin{bmatrix}
  3+1+4 & 1+5+9 & 2+6+5
\end{bmatrix}\end{array}.
$$

In [13]:
reduce([width], A, torch.sum)

Indexable(tensor([ 8, 15, 13]))[height()]

Reduction over multiple axes:

$$
\sum\limits_{\substack{\mathsf{\vphantom{fg}height}\\
 \mathsf{\vphantom{fg}width}}} A = \sum_i \sum_j A_{\mathsf{\vphantom{fg}height}(i),\mathsf{\vphantom{fg}width}(j)} = 3+1+4+1+5+9+2+6+5.
 $$

In [14]:
reduce([height, width], A, torch.sum)

tensor(36)
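
Because the operator is associative and commutative, the order in which axes are reduced does not matter, and reducing several axes at once is the same as flattening them into one axis and reducing that. A small plain-Python check of this property (the name `fold` is just an alias chosen here to avoid clashing with the notebook's `reduce` helper):

```python
from functools import reduce as fold
import operator

A = [[3, 1, 4], [1, 5, 9], [2, 6, 5]]  # indexed as A[height][width]

# Reduce height first (column sums), then width.
height_then_width = fold(operator.add, [fold(operator.add, col) for col in zip(*A)])

# Reduce width first (row sums), then height.
width_then_height = fold(operator.add, [fold(operator.add, row) for row in A])

# Flatten both axes into one and reduce once.
flattened = fold(operator.add, [v for row in A for v in row])

print(height_then_width, width_then_height, flattened)  # 36 36 36
```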

Multiplication reduction:

$$
\prod\limits_{\substack{\mathsf{\vphantom{fg}height}}} A = \prod_i A_{\mathsf{\vphantom{fg}height}(i)} = \begin{array}[b]{@{}c@{}}\mathsf{\vphantom{fg}width}\\
\begin{bmatrix}
  3\cdot1\cdot2 & 1\cdot5\cdot6 & 4\cdot9\cdot5
\end{bmatrix}\end{array}.
$$

In [15]:
reduce([height], A, torch.prod)

Indexable(tensor([  6,  30, 180]))[width()]

Max reduction:

$$
\max\limits_{\substack{\mathsf{\vphantom{fg}height}}} A = \max \{A_{\mathsf{\vphantom{fg}height}(i)} \mid 1 \leq i \leq n\} = \begin{array}[b]{@{}c@{}}\mathsf{\vphantom{fg}width}\\
\begin{bmatrix}
  \max(3, 1, 2) & \max(1, 5, 6) & \max(4, 9, 5)
\end{bmatrix}\end{array}.
$$

In [16]:
reduce([height], A, torch.amax)

Indexable(tensor([3, 6, 9]))[width()]

### Contraction

A contraction can be written as elementwise multiplication followed by summation over an axis:

$$
A \mathbin{\underset{\substack{\mathsf{\vphantom{fg}width}}}{\vphantom{fg}\odot}} y = \sum_j A_{\mathsf{\vphantom{fg}width}(j)} \, y_{\mathsf{\vphantom{fg}width}(j)} = \mathsf{\vphantom{fg}height}
\begin{array}[b]{@{}c@{}}\\\begin{bmatrix}
  3\cdot 1 + 1\cdot 4 + 4\cdot 1 \\
  1\cdot 1 + 5\cdot 4 + 9\cdot 1 \\
  2\cdot 1 + 6\cdot 4 + 5\cdot 1
\end{bmatrix}\end{array}.
$$

In [17]:
reduce([width], A * y, torch.sum)

Indexable(tensor([11, 30, 31]))[height()]

Some other operations from linear algebra:

$$
x \mathbin{\underset{\substack{\mathsf{\vphantom{fg}height}}}{\vphantom{fg}\odot}} x = \sum_i x_{\mathsf{\vphantom{fg}height}(i)} \, x_{\mathsf{\vphantom{fg}height}(i)} \qquad \text{inner product}
$$

In [18]:
reduce([height], x * x, torch.sum)

tensor(54)

$$
[x \odot y]_{\mathsf{\vphantom{fg}height}(i), \mathsf{\vphantom{fg}width}(j)} = x_{\mathsf{\vphantom{fg}height}(i)} \, y_{\mathsf{\vphantom{fg}width}(j)} \qquad \text{outer product}
$$

In [19]:
x * y

Indexable(tensor([[ 2,  8,  2],
                  [ 7, 28,  7],
                  [ 1,  4,  1]]))[height(), width()]
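
The inner and outer products use the same elementwise `*`; what differs is whether the operands share a name. A plain-Python sketch with explicit loops (lists stand in for the indexed tensors, as an illustration only):

```python
x = [2, 7, 1]  # indexed as x[height]
y = [1, 4, 1]  # indexed as y[width]

# Same name on both operands: multiply position by position, then sum over height.
inner = sum(x[h] * x[h] for h in range(3))

# Different names: every height is paired with every width, giving two axes.
outer = [[x[h] * y[w] for w in range(3)] for h in range(3)]

print(inner)  # 54
print(outer)  # [[2, 8, 2], [7, 28, 7], [1, 4, 1]]
```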

$$
A \mathbin{\underset{\substack{\mathsf{\vphantom{fg}width}}}{\vphantom{fg}\odot}} y = \sum_i A_{\mathsf{\vphantom{fg}width}(i)} \, y_{\mathsf{\vphantom{fg}width}(i)} \qquad \text{matrix-vector product}
$$

In [20]:
reduce([width], A * y, torch.sum)

Indexable(tensor([11, 30, 31]))[height()]

$$
x \mathbin{\underset{\substack{\mathsf{\vphantom{fg}height}}}{\vphantom{fg}\odot}} A = \sum_i x_{\mathsf{\vphantom{fg}height}(i)} \, A_{\mathsf{\vphantom{fg}height}(i)} \qquad \text{vector-matrix product}
$$

In [21]:
reduce([height], x * A, torch.sum)

Indexable(tensor([15, 43, 76]))[width()]

$$
A \mathbin{\underset{\substack{\mathsf{\vphantom{fg}width}}}{\vphantom{fg}\odot}} B = \sum_i A_{\mathsf{\vphantom{fg}width}(i)} \odot B_{\mathsf{\vphantom{fg}width}(i)} \qquad \text{matrix-matrix product}~(B \in \mathbb{R}^{\mathsf{\vphantom{fg}width}\times \mathsf{\vphantom{fg}width2}})
$$

In [22]:
width2 = gensym(int, name="width2")
B = Indexable(
    tensor([[3, 2, 5], [5, 4, 0], [8, 3, 6]]),
)[width(), width2()]

reduce([width], A * B, torch.sum)

Indexable(tensor([[ 46,  22,  39],
                  [100,  49,  59],
                  [ 76,  43,  40]]))[height(), width2()]
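
It may help to see what `A * B` contains before the reduction. Since `A` has axes height and width while `B` has width and width2, the elementwise product has all three axes, and summing over width leaves height and width2. A plain-Python sketch under that reading (nested lists, for illustration only):

```python
A = [[3, 1, 4], [1, 5, 9], [2, 6, 5]]  # indexed as A[height][width]
B = [[3, 2, 5], [5, 4, 0], [8, 3, 6]]  # indexed as B[width][width2]

# Broadcast elementwise product, indexed as prod[height][width][width2].
prod = [[[A[h][w] * B[w][k] for k in range(3)] for w in range(3)] for h in range(3)]

# Sum over the shared width axis; the result is indexed as AB[height][width2].
AB = [[sum(prod[h][w][k] for w in range(3)) for k in range(3)] for h in range(3)]

print(AB)  # [[46, 22, 39], [100, 49, 59], [76, 43, 40]]
```

This sketch materializes the full three-axis intermediate. With positional PyTorch tensors, `A @ B` or `torch.einsum("hw,wk->hk", A, B)` computes the same matrix.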

Contraction can be generalized to other binary and reduction operations:

$$
\max_{\mathsf{\vphantom{fg}width}} (A + y) = \mathsf{\vphantom{fg}height}
\begin{array}[b]{@{}c@{}}\\\begin{bmatrix}
  \max(3+1, 1+4, 4+1) \\
  \max(1+1, 5+4, 9+1) \\
  \max(2+1, 6+4, 5+1)
\end{bmatrix}\end{array}.
$$

In [23]:
reduce([width], A + y, torch.amax)

Indexable(tensor([ 5, 10, 10]))[height()]
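
The ordinary contraction and the max-plus variant differ only in which binary operation and which reduction are plugged in. A plain-Python sketch of that pattern; `contract_width` is a hypothetical helper written for this illustration, not an `effectful` or `torch` function:

```python
import operator

A = [[3, 1, 4], [1, 5, 9], [2, 6, 5]]  # indexed as A[height][width]
y = [1, 4, 1]                          # indexed as y[width]


def contract_width(combine, reduce_fn, matrix, vector):
    """Combine each row with the vector elementwise, then reduce over width."""
    return [reduce_fn(combine(a, b) for a, b in zip(row, vector)) for row in matrix]


sum_product = contract_width(operator.mul, sum, A, y)  # the usual contraction
max_plus = contract_width(operator.add, max, A, y)     # max over width of A + y

print(sum_product)  # [11, 30, 31]
print(max_plus)     # [5, 10, 10]
```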

### Renaming and reshaping

Renaming an axis is a substitution of a new name for the old one:

$$
A_{\mathsf{\vphantom{fg}height}\rightarrow\mathsf{\vphantom{fg}height2}} = \mathsf{\vphantom{fg}height2}
\begin{array}[b]{@{}c@{}}\mathsf{\vphantom{fg}width}
\\\begin{bmatrix}
  3 & 1 & 4 \\
  1 & 5 & 9 \\
  2 & 6 & 5 \\
\end{bmatrix}\end{array}.
$$

In [24]:
height2 = gensym(int, name="height2")
subst(A, {height: height2()})

Indexable(tensor([[3, 1, 4],
                  [1, 5, 9],
                  [2, 6, 5]]))[height2(), width()]

$$
A_{(\mathsf{\vphantom{fg}height},\mathsf{\vphantom{fg}width})\rightarrow\mathsf{\vphantom{fg}layer}} = \begin{array}[b]{@{}c@{}}\mathsf{\vphantom{fg}layer}\\
\begin{bmatrix}
    3 & 1 & 4 & 1 & 5 & 9 & 2 & 6 & 5
\end{bmatrix}\end{array}
$$

In [25]:
layer = gensym(int, name="layer")
A_layer = subst(A, {height: layer() // 3, width: layer() % 3})
print(subst(A_layer, {layer: 2}))

tensor(4)
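
The substitution `height = layer // 3`, `width = layer % 3` reads `A` in row-major order. A plain-Python sketch of the same index arithmetic over all nine values of `layer` (the constants `HEIGHT` and `WIDTH` are named here for readability):

```python
A = [[3, 1, 4], [1, 5, 9], [2, 6, 5]]  # indexed as A[height][width]
HEIGHT, WIDTH = 3, 3

# For each flat position, recover (height, width) by integer division and remainder.
A_layer = [A[layer // WIDTH][layer % WIDTH] for layer in range(HEIGHT * WIDTH)]

print(A_layer)     # [3, 1, 4, 1, 5, 9, 2, 6, 5]
print(A_layer[2])  # 4
```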


$$
A_{\mathsf{\vphantom{fg}layer}\rightarrow(\mathsf{\vphantom{fg}height},\mathsf{\vphantom{fg}width})} = \mathsf{\vphantom{fg}height}
\begin{array}[b]{@{}c@{}}\mathsf{\vphantom{fg}width}
\\\begin{bmatrix}
  3 & 1 & 4 \\
  1 & 5 & 9 \\
  2 & 6 & 5 \\
\end{bmatrix}\end{array}.
$$

In [26]:
print(subst(A_layer, {layer: height() * 3 + width() % 3}))

_torch_op(tensor([[3, 1, 4],
        [1, 5, 9],
        [2, 6, 5]]), [floordiv(add(mul(height(), 3), mod(width(), 3)), 3), mod(add(mul(height(), 3), mod(width(), 3)), 3)])
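
The printed result is an unevaluated term whose index expressions are `(height * 3 + width % 3) // 3` and `(height * 3 + width % 3) % 3`. Assuming `0 <= width < 3`, these simplify back to `height` and `width`, so the term denotes the original matrix. A plain-Python check of that arithmetic (`a_layer` is a stand-in for `A_layer`, written for illustration):

```python
A = [[3, 1, 4], [1, 5, 9], [2, 6, 5]]  # indexed as A[height][width]


def a_layer(layer):
    """Stand-in for A_layer: read A at (layer // 3, layer % 3)."""
    return A[layer // 3][layer % 3]


for h in range(3):
    for w in range(3):
        layer = h * 3 + w % 3  # the expression substituted for `layer` above
        # For in-range indices, division and remainder undo the flattening.
        assert (layer // 3, layer % 3) == (h, w)

round_trip = [[a_layer(h * 3 + w % 3) for w in range(3)] for h in range(3)]
print(round_trip == A)  # True
```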


## Advanced indexing

All of advanced indexing can be expressed as name substitution.

$$
\mathop{\underset{\substack{\mathsf{\vphantom{fg}ax}}}{\vphantom{fg}\mathrm{index}}} \colon \mathbb{R}^{\mathsf{\vphantom{fg}ax}[n]} \times [n] \rightarrow \mathbb{R}\\
\mathop{\underset{\substack{\mathsf{\vphantom{fg}ax}}}{\vphantom{fg}\mathrm{index}}}(A, i) = A_{\mathsf{\vphantom{fg}ax}(i)}.
$$

$$
\begin{aligned}
  E &\in \mathbb{R}^{\mathsf{\vphantom{fg}vocab}[n] \times \mathsf{\vphantom{fg}emb}} \\
  i &\in [n] \\
  I &\in [n]^{\mathsf{\vphantom{fg}seq}} \\
  P &\in \mathbb{R}^{\mathsf{\vphantom{fg}seq}\times \mathsf{\vphantom{fg}vocab}[n]}
\end{aligned}
$$

Partial indexing $\mathop{\underset{\substack{\mathsf{\vphantom{fg}vocab}}}{\vphantom{fg}\mathrm{index}}}(E,i)$:

In [27]:
vocab, emb = gensym(int, name="vocab"), gensym(int, name="emb")
E = Indexable(
    tensor([[2, 1, 5], [3, 4, 2], [1, 3, 7], [1, 4, 3], [5, 9, 2]]),
)[vocab(), emb()]

subst(E, {vocab: 2})

Indexable(tensor([1, 3, 7]))[emb()]

Integer array indexing $\mathop{\underset{\substack{\mathsf{\vphantom{fg}vocab}}}{\vphantom{fg}\mathrm{index}}}(E,I)$:

In [28]:
seq = gensym(int, name="seq")
I = Indexable(tensor([3, 2, 4, 0]))[seq()]

subst(E, {vocab: I})

Indexable(tensor([[1, 4, 3],
                  [1, 3, 7],
                  [5, 9, 2],
                  [2, 1, 5]]))[seq(), emb()]

Gather operation $\mathop{\underset{\substack{\mathsf{\vphantom{fg}vocab}}}{\vphantom{fg}\mathrm{index}}}(P,I)$:

In [29]:
P = Indexable(
    tensor([[6, 2, 4, 2], [8, 2, 1, 3], [5, 5, 7, 0], [1, 3, 8, 2], [5, 9, 2, 3]]),
)[vocab(), seq()]

subst(P, {vocab: I})

Indexable(tensor([1, 5, 2, 2]))[seq()]
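
The same substitution `{vocab: I}` acts as a row lookup on `E` but as a gather on `P`. The difference is whether the target already has a `seq` axis. A plain-Python sketch with explicit loops (nested lists stand in for the indexed tensors, for illustration only):

```python
E = [[2, 1, 5], [3, 4, 2], [1, 3, 7], [1, 4, 3], [5, 9, 2]]              # E[vocab][emb]
P = [[6, 2, 4, 2], [8, 2, 1, 3], [5, 5, 7, 0], [1, 3, 8, 2], [5, 9, 2, 3]]  # P[vocab][seq]
I = [3, 2, 4, 0]                                                        # I[seq]

# E has no seq axis, so substituting I for vocab introduces one: a row lookup.
E_at_I = [[E[I[s]][e] for e in range(3)] for s in range(4)]  # indexed [seq][emb]

# P already has a seq axis, so I and P are read at the same seq position: a gather.
P_at_I = [P[I[s]][s] for s in range(4)]                      # indexed [seq]

print(E_at_I)  # [[1, 4, 3], [1, 3, 7], [5, 9, 2], [2, 1, 5]]
print(P_at_I)  # [1, 5, 2, 2]
```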

Indexing with two integer arrays:

$$
\begin{aligned}
  |\mathsf{\vphantom{fg}seq}| &= m \\
  I_1 &\in [m]^\mathsf{\vphantom{fg}subseq}\\
  I_2 &\in [n]^\mathsf{\vphantom{fg}subseq}\\
  S &= \mathop{\underset{\substack{\mathsf{\vphantom{fg}vocab}}}{\vphantom{fg}\mathrm{index}}}(\mathop{\underset{\substack{\mathsf{\vphantom{fg}seq}}}{\vphantom{fg}\mathrm{index}}}(P, I_1), I_2) \in \mathbb{R}^{\mathsf{\vphantom{fg}subseq}} \\
  S_{\mathsf{\vphantom{fg}subseq}(i)} &= P_{\mathsf{\vphantom{fg}seq}((I_1)_{\mathsf{\vphantom{fg}subseq}(i)}), \mathsf{\vphantom{fg}vocab}((I_2)_{\mathsf{\vphantom{fg}subseq}(i)})}.
\end{aligned}
$$

In [30]:
subseq = gensym(int, name="subseq")
I1 = Indexable(tensor([1, 2, 0]))[subseq()]
I2 = Indexable(tensor([3, 0, 4]))[subseq()]

subst(P, {seq: I1, vocab: I2})

Indexable(tensor([3, 4, 5]))[subseq()]
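
Because `I1` and `I2` share the name `subseq`, their entries are paired position by position instead of forming all combinations. A plain-Python sketch of the lookup (nested lists, for illustration only):

```python
P = [[6, 2, 4, 2], [8, 2, 1, 3], [5, 5, 7, 0], [1, 3, 8, 2], [5, 9, 2, 3]]  # P[vocab][seq]
I1 = [1, 2, 0]  # indexed by subseq; values are seq positions
I2 = [3, 0, 4]  # indexed by subseq; values are vocab positions

# One output per subseq position: read P at (vocab=I2[i], seq=I1[i]).
S = [P[I2[i]][I1[i]] for i in range(3)]

print(S)  # [3, 4, 5]
```

With positional PyTorch tensors the same pairing would be written `P[I2, I1]`, where the argument order has to match the axis order of `P`.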