In [1]:
import torch

## Matrix Transpose

$B_{ji}=A_{ij}$

In [2]:
a = torch.arange(6).reshape(2, 3)
torch.einsum('ij->ji', [a])

tensor([[0, 3],
        [1, 4],
        [2, 5]])

## Sum
$b = A_{ij}$

In [3]:
a = torch.arange(6).reshape(2, 3)
torch.einsum("ij->", [a])

tensor(15)

## Column Sum
$b_j = \sum_{i=1}^{n} A_{ij} = A_{i,j}$

In [4]:
torch.einsum("ij->i", [a])

tensor([ 3, 12])

## Matrix vector multiplication
$b_i = \sum_{j=1}^{m} A_{ij} b_j = A_{i,j}b_j$

In [6]:
b=torch.arange(3)
torch.einsum("ij,j->i", [a, b])

tensor([ 5, 14])

## Matrix multiplication
$c_{ij} = \sum_{k}\sum_{j} A_{ik} B_{kj} = A_{ik} B_{kj}$

In [7]:
a = torch.arange(6).reshape(2, 3)
b = torch.arange(12).reshape(3, 4)

torch.einsum("ij,jk->ik", [a, b])

tensor([[20, 23, 26, 29],
        [56, 68, 80, 92]])

## Dot Product
### Vector
$C = A \cdot B = \sum_{i} A_i B_i = A_i B_i$

### Matrix
$C = A \cdot B = \sum_{i}\sum_{j} A_{ij} B_{ij} = A_{ij} B_{ij}$

In [9]:
a = torch.arange(6).reshape(-1)
b = torch.arange(6).reshape(-1)

print(torch.einsum("i,i->", [a, b]))

a = torch.arange(6).reshape(2, 3)
b = torch.arange(6).reshape(2, 3)
print(torch.einsum("ij,ij->", [a, b]))

tensor(55)
tensor(55)


## Tensor Contraction
Batch matrix multiplication is a special case of a tensor contraction. Let's say we have two tensors, an order-n tensor $A$ and an order-m tensor $B$. The tensor contraction of $A$ and $B$ is a tensor $C$ of order $n+m-2$ obtained by summing over the product of the elements of $A$ and $B$ along the specified axes.

In [17]:
a = torch.randn(2,3,5,7)
b = torch.randn(11,13,3,17,5, 6)

torch.einsum("ijkl,mnjpkr->ilmnpr", [a, b])

tensor([[[[[[ 5.4583e+00,  1.8139e+00,  5.5668e+00,  6.7842e+00,  8.2728e-01,
              2.6351e+00],
            [ 1.7960e+00,  5.3653e+00, -1.5996e-01,  4.5314e+00,  3.8997e+00,
             -1.0338e+01],
            [-3.3416e+00,  2.9324e-02, -7.0183e-01, -3.5193e-02,  2.2598e-01,
             -5.0452e+00],
            ...,
            [-7.3948e+00, -7.8453e-01, -1.3901e+00,  2.3176e+00,  3.3408e+00,
              8.2196e-01],
            [-5.8006e+00,  3.4067e+00,  1.0083e+01, -1.1365e+00,  8.4844e-02,
              1.9293e+00],
            [ 1.8058e-01,  7.2709e-01,  3.9015e+00, -1.4610e+00, -5.4825e+00,
             -5.9912e+00]],

           [[-5.6220e+00, -1.2180e+00,  5.3934e+00,  4.4196e+00, -2.1951e-01,
              3.3457e+00],
            [-6.6970e-01,  2.7369e+00, -1.1523e+00,  3.1238e+00, -5.0317e+00,
             -4.2064e-01],
            [ 4.8512e+00,  2.1841e+00, -3.5200e+00, -1.7206e+00, -5.2994e+00,
             -9.3077e-01],
            ...,
            [-6.040

## Bilinear Transformation
$C = A \cdot B \cdot A = \sum_{i}\sum_{j}\sum_{k} A_{ij} B_{jk} A_{ki} = A_{ij} B_{jk} A_{ki}$

In [19]:
a = torch.randn(2,3)
b = torch.randn(3,4)
c = torch.randn(4, 5, 6)

torch.einsum("ij,jk,kmn->imn", [a, b, c])

tensor([[[ -1.7381,  -2.3121,   0.1990,   0.1354,  -0.7963,  -1.8758],
         [ -2.8041,   1.6287,  -0.3935,  -2.1446,  -2.9181,   0.2594],
         [ -0.9753,  -0.0448,  -0.7636,  -4.0472,   1.8998,  -0.0825],
         [  2.1368,  -1.1268,   1.4547,  -0.5861,  -2.4453,   2.6793],
         [ -0.5131,   0.3768,   1.6490,   2.7058,   0.1691,  -4.3313]],

        [[  1.1639,  -5.6645,  -0.7945,  -2.3685, -16.0739,  -8.1747],
         [ -3.6608,   3.9785,  -0.0499,   1.2027,  -6.9434,   3.9332],
         [ -1.9942,  -8.5210,   1.7137,  -6.3119,  10.8528,   2.5875],
         [  0.9083,  -7.7480,  -0.5555,  -2.8272,   8.9838,   3.1934],
         [ -5.5207,   4.7843,  -3.4518,  -3.4748,   5.0775,  -4.1768]]])

In [21]:
a = torch.randn(3,3)
print(a)
print(torch.einsum("ii->", [a]))
print(torch.einsum("ij->", [a]))
print(torch.einsum("ij->j", [a]))
print(torch.einsum("ij->i", [a]))
print(torch.einsum("ii->i", [a]))

tensor([[ 1.3963,  0.4845,  1.0258],
        [ 1.1368, -0.0149,  0.1346],
        [-1.2187, -0.1384, -1.5310]])
tensor(-0.1496)
tensor(1.2751)
tensor([ 1.3144,  0.3313, -0.3706])
tensor([ 2.9066,  1.2565, -2.8880])
tensor([ 1.3963, -0.0149, -1.5310])
