In [157]:
import torch
import numpy as np

# Einstein summation notation: tests to get familiar with the notation

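`einsum(...)` multiplies the input tensors along the indices named in the subscript string and sums the products over every index that does not appear in the output (the part after `->`).  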
 
examples:  
https://stackoverflow.com/questions/55894693/understanding-pytorch-einsum  
https://pytorch.org/docs/stable/generated/torch.einsum.html  
https://rockt.github.io/2018/04/30/einsum  
https://notebook.community/ESSS/notebooks/einsum-notebook


## 1-dimensional einsum operations

In [158]:
# Inner (scalar) product of two vectors, written as an einsum.
# How to read the subscript string 'ij,ij->':
#   * the comma separates the subscripts of the two operands; entries that
#     share the same index values are multiplied together
#   * the arrow separates the input subscripts from the output subscripts
#     (it is not an operation by itself)
#   * every index that does not appear after the arrow is summed over;
#     nothing follows the arrow here, so the result is a 0-dim tensor
a = torch.arange(3)
b = torch.arange(3, 6)

# Add a leading axis of size 1 so that both operands are 2-D with shape (1, 3).
a = a.unsqueeze(dim=0)
b = b.unsqueeze(dim=0)

print(f"{a=}")
print(f"{a.shape=}\n")
print(f"{b=}")
print(f"{b.shape=}\n")

scalar_product = torch.einsum('ij,ij->', a, b)
# Cross-check the einsum against the explicit multiply-then-sum formulation.
assert scalar_product == (a * b).sum()
print(f"Scalarproduct: {scalar_product}")

a=tensor([[0, 1, 2]])
a.shape=torch.Size([1, 3])

b=tensor([[3, 4, 5]])
b.shape=torch.Size([1, 3])

Scalarproduct: 14


## 2-dimensional einsum operations

In [159]:
# setup: two random 2x2 integer matrices with entries in [0, 10)
# NOTE: no seed is set in this cell, so (unless the RNG was seeded elsewhere)
# the values differ on every run.
a = torch.randint(0, 10, (2, 2))
b = torch.randint(0, 10, (2, 2))

print(f"{a=}")
print(f"{a.shape=}\n")
print(f"{b=}")
print(f"{b.shape=}\n")

# matrix product (not an outer product): 'ik,kj->ij' multiplies along the
# shared index k and sums it away, which is what `a @ b` computes
matrix_product = a @ b
einsum_matrix_product = torch.einsum('ik,kj->ij', a, b)
# Both formulations must agree exactly (integer tensors, no rounding involved).
assert torch.equal(matrix_product, einsum_matrix_product)
print(f"Matrix product: {matrix_product}\n")
print(f"Einsum matrix product: {einsum_matrix_product}\n")


a=tensor([[7, 7],
        [2, 4]])
a.shape=torch.Size([2, 2])

b=tensor([[9, 1],
        [7, 5]])
b.shape=torch.Size([2, 2])

Outer product: tensor([[112,  42],
        [ 46,  22]])

Einsum outer product: tensor([[112,  42],
        [ 46,  22]])



## 3-dimensional einsum operations

In [160]:
# Draw two random integer tensors of shape (2, 2, 2) with entries in [0, 10).
a = torch.randint(low=0, high=10, size=(2, 2, 2))
b = torch.randint(low=0, high=10, size=(2, 2, 2))

print(f"{a=}")
print(f"{a.shape=}\n")
print(f"{b=}")
print(f"{b.shape=}\n")

# Batched matrix product, computed two ways. In 'ijk,ikl->ijl' the leading
# index i is shared by both operands and kept in the output, while the
# shared index k is missing from the output and is therefore summed over.
batched_matmul = a @ b
batched_einsum = torch.einsum('ijk,ikl->ijl', a, b)
print(f"Batch matrix product: {batched_matmul}\n")
print(f"Einsum batch matrix product: {batched_einsum}\n")


a=tensor([[[0, 7],
         [9, 3]],

        [[2, 5],
         [4, 4]]])
a.shape=torch.Size([2, 2, 2])

b=tensor([[[3, 7],
         [6, 6]],

        [[2, 0],
         [9, 1]]])
b.shape=torch.Size([2, 2, 2])

Batch matrix product: tensor([[[42, 42],
         [45, 81]],

        [[49,  5],
         [44,  4]]])

Einsum batch matrix product: tensor([[[42, 42],
         [45, 81]],

        [[49,  5],
         [44,  4]]])



## 4-dimensional einsum operations

In [161]:
# setup: two random integer tensors of shape (2, 2, 2, 2), entries in [0, 10)
a = torch.randint(0, 10, (2, 2, 2, 2))
b = torch.randint(0, 10, (2, 2, 2, 2))

print(f"{a=}")
print(f"{a.shape=}\n")
print(f"{b=}")
print(f"{b.shape=}\n")

# batch matrix product: `@` treats the two leading axes as batch axes and
# multiplies the trailing 2x2 matrices
batch_matrix_product = a @ b
print(f"Batch matrix product: {batch_matrix_product}\n")

# the same contraction written as an einsum: m and n label the two batch
# axes (kept in the output), j is shared and summed over
einsum_batch_matrix_product = torch.einsum('mnij,mnjk->mnik', a, b)
# Both formulations must agree exactly (integer tensors, no rounding involved).
assert torch.equal(batch_matrix_product, einsum_batch_matrix_product)
print(f"Einsum batch matrix product: {einsum_batch_matrix_product}\n")

a=tensor([[[[0, 2],
          [2, 9]],

         [[1, 5],
          [3, 9]]],


        [[[7, 4],
          [3, 2]],

         [[3, 4],
          [6, 8]]]])
a.shape=torch.Size([2, 2, 2, 2])

b=tensor([[[[6, 9],
          [8, 4]],

         [[7, 0],
          [6, 2]]],


        [[[2, 0],
          [0, 9]],

         [[9, 4],
          [0, 8]]]])
b.shape=torch.Size([2, 2, 2, 2])

Batch matrix product: tensor([[[[16,  8],
          [84, 54]],

         [[37, 10],
          [75, 18]]],


        [[[14, 36],
          [ 6, 18]],

         [[27, 44],
          [54, 88]]]])



# Einstein summation notation used in the project

In [162]:
# The subscripts a, x and y occur in both operands but not in the output, so
# they are summed over; i (first operand) and j (second operand) are kept,
# giving a matrix indexed by (i, j).
# torch.einsum('aixy,ajxy', a, b) -> same results
contracted_ij = torch.einsum('aixy,ajxy->ij', a, b)
contracted_ij

tensor([[102, 125],
        [189, 150]])

In [163]:
# Nothing follows the arrow, so every subscript (a, i, j, x, y) is summed
# over and the result is a single 0-dim tensor.
contracted_total = torch.einsum('aixy,ajxy->', a, b)
contracted_total

tensor(566)

In [164]:
# Every subscript is kept in the output, so nothing is summed: the result
# holds the individual products, with one axis per subscript (a, i, j, x, y).
uncontracted = torch.einsum('aixy,ajxy->aijxy', a, b)
uncontracted.shape

torch.Size([2, 2, 2, 2, 2])

In [165]:
# Verify that the arrow can be omitted here. In implicit mode the output is
# made of the subscripts that appear exactly once in the inputs (i and j),
# in alphabetical order, so 'aixy,ajxy' is equivalent to 'aixy,ajxy->ij'.
explicit_ij = torch.einsum('aixy,ajxy->ij', a, b)
implicit_ij = torch.einsum('aixy,ajxy', a, b)
assert torch.equal(explicit_ij, implicit_ij)
explicit_ij

tensor([[102, 125],
        [189, 150]])

In [166]:
# Both operands carry identical subscripts, so matching entries are
# multiplied element-wise; x and y are absent from the output and are summed
# away, leaving one value per (a, i) pair.
contracted_ai = torch.einsum('aixy,aixy->ai', a, b)
contracted_ai

tensor([[ 70,  43],
        [ 32, 107]])