In [1]:
import numpy as np
import torch

# Basic einsum concepts (explicit mode, i.e. a spec that contains '->'):
# - Each axis of each operand is labelled with a letter (i, j, k, ...).
# - A letter written on the right of '->' is kept as an output axis; the order
#   of letters there sets the order of the output axes.
# - A letter that appears only on the left of '->' is summed over.
# - A letter shared by several operands (e.g. 'ij,jk') lines those axes up: the
#   matching elements are multiplied together, and then summed if the letter
#   is absent from the output.
# - A letter repeated within ONE operand (e.g. 'ii->i') selects that operand's
#   diagonal rather than multiplying anything.

print("=== Basic Einsum Examples ===\n")


=== Basic Einsum Examples ===



In [2]:
# 1. Vector dot product
# Traditional: np.dot(a, b) or a @ b
# Two length-3 vectors; einsum will pair up their shared axis 'i'.
a, b = np.array([1, 2, 3]), np.array([4, 5, 6])


In [5]:
# Shapes first (one per line), then the traditional dot product for reference.
print(a.shape, b.shape, sep="\n")
print(a.dot(b))

(3,)
(3,)
32


In [8]:
np.einsum("i->", b)  # 'i' is absent from the output, so it is summed away: same as b.sum()

np.int64(15)

In [9]:
np.einsum("i,i->", a, b)  # shared 'i' multiplies pairwise, empty output sums: the dot product

np.int64(32)

In [12]:
# 2. Matrix-vector multiplication
# Traditional: A @ v
A = np.array([
    [1, 2],
    [3, 4],
    [5, 6],
])  # 3 rows x 2 columns
v = np.array([7, 8])  # length 2, matching A's column count

# Matrix, then both shapes, one item per line.
print(A, A.shape, v.shape, sep="\n")

[[1 2]
 [3 4]
 [5 6]]
(3, 2)
(2,)


In [11]:
np.matmul(A, v)  # function form of the `@` operator

array([23, 53, 83])

In [13]:
mv_traditional = A @ v
# 'j' is shared and dropped (contracted); 'i' is kept, giving one value per row of A.
mv_einsum = np.einsum('ij,j->i', A, v)

report_lines = [
    "Matrix-vector multiplication:",
    f"A shape: {A.shape}, v shape: {v.shape}",
    f"Traditional: {mv_traditional}",
    f"Einsum 'ij,j->i': {mv_einsum}",
    "",  # trailing blank line separating this section from the next
]
print("\n".join(report_lines))

Matrix-vector multiplication:
A shape: (3, 2), v shape: (2,)
Traditional: [23 53 83]
Einsum 'ij,j->i': [23 53 83]



In [14]:
np.einsum("ij,j->", A, v)  # dropping 'i' too sums the whole product A @ v into one scalar

np.int64(159)

In [17]:
# 3. Matrix multiplication
# Traditional: A @ B
# The inner dimensions must agree: (2, 3) @ (3, 2) -> (2, 2).
A = np.array([
    [1, 2, 3],
    [4, 5, 6],
])  # shape (2, 3): 2 rows, 3 columns
B = np.array([
    [7, 8],
    [9, 10],
    [11, 12],
])  # shape (3, 2): 3 rows, 2 columns

# Both matrices, then both shapes, one item per line.
print(A, B, A.shape, B.shape, sep="\n")

[[1 2 3]
 [4 5 6]]
[[ 7  8]
 [ 9 10]
 [11 12]]
(2, 3)
(3, 2)


In [18]:
mm_traditional = A @ B
# 'j' is the shared inner axis and is dropped, so it is contracted; 'i' and 'k' survive.
mm_einsum = np.einsum('ij,jk->ik', A, B)

summary = "\n".join([
    "Matrix multiplication:",
    f"A shape: {A.shape}, B shape: {B.shape}",
    f"Traditional:\n{mm_traditional}",
    f"Einsum 'ij,jk->ik':\n{mm_einsum}",
])
print(summary, end="\n\n")  # extra newline leaves a blank line after the section

Matrix multiplication:
A shape: (2, 3), B shape: (3, 2)
Traditional:
[[ 58  64]
 [139 154]]
Einsum 'ij,jk->ik':
[[ 58  64]
 [139 154]]



In [19]:
# 4. Element-wise multiplication (Hadamard product)
# Every letter is shared AND kept, so nothing is summed: out[i, j] = A[i, j] * B[i, j].
A = np.array([[1, 2], [3, 4]])
B = np.array([[5, 6], [7, 8]])
hadamard_traditional = A * B
hadamard_einsum = np.einsum('ij,ij->ij', A, B)

for text in (
    "Element-wise multiplication:",
    f"Traditional:\n{hadamard_traditional}",
    f"Einsum 'ij,ij->ij':\n{hadamard_einsum}",
    "",  # blank line separating this section from the next
):
    print(text)

Element-wise multiplication:
Traditional:
[[ 5 12]
 [21 32]]
Einsum 'ij,ij->ij':
[[ 5 12]
 [21 32]]



In [20]:
# 5. Sum operations
# Any letter missing from the output is summed away, so dropping letters
# reproduces numpy's usual axis reductions.
matrix = np.array([[1, 2, 3], [4, 5, 6]])
sum_all = np.einsum('ij->', matrix)    # i and j both dropped -> grand total, like matrix.sum()
sum_rows = np.einsum('ij->i', matrix)  # j dropped -> one total per row, like matrix.sum(axis=1)
sum_cols = np.einsum('ij->j', matrix)  # i dropped -> one total per column, like matrix.sum(axis=0)

# Check each einsum form against the traditional reduction, mirroring the
# traditional-vs-einsum comparison the other sections make (silent when they agree).
np.testing.assert_array_equal(sum_all, matrix.sum())
np.testing.assert_array_equal(sum_rows, matrix.sum(axis=1))
np.testing.assert_array_equal(sum_cols, matrix.sum(axis=0))

# In the labels below, "along columns" means "adding across the columns of each
# row" (per-row totals), and "along rows" means per-column totals.
print(f"Sum operations on matrix:\n{matrix}")
print(f"Sum all 'ij->': {sum_all}")
print(f"Sum along columns 'ij->i': {sum_rows}")
print(f"Sum along rows 'ij->j': {sum_cols}")
print()

Sum operations on matrix:
[[1 2 3]
 [4 5 6]]
Sum all 'ij->': 21
Sum along columns 'ij->i': [ 6 15]
Sum along rows 'ij->j': [5 7 9]



In [21]:
# 6. Transpose
# Reordering the output letters permutes the axes: 'ij->ji' swaps rows and columns.
matrix = np.array([[1, 2, 3], [4, 5, 6]])
transpose_traditional = matrix.T
transpose_einsum = np.einsum('ij->ji', matrix)

print("Transpose:")
print("Original:", matrix, sep="\n")
print("Traditional .T:", transpose_traditional, sep="\n")
print("Einsum 'ij->ji':", transpose_einsum, sep="\n")
print()

Transpose:
Original:
[[1 2 3]
 [4 5 6]]
Traditional .T:
[[1 4]
 [2 5]
 [3 6]]
Einsum 'ij->ji':
[[1 4]
 [2 5]
 [3 6]]



In [22]:
# 7. Batch operations
# Batch matrix multiplication: 'b' appears in every operand and in the output,
# so each batch slice is multiplied independently; 'j' is shared but dropped,
# so it is contracted exactly as in the plain 'ij,jk->ik' case.
rng = np.random.default_rng(0)  # seeded so Restart & Run All reproduces the same data
batch_A = rng.standard_normal((3, 4, 5))  # 3 matrices of size 4x5
batch_B = rng.standard_normal((3, 5, 2))  # 3 matrices of size 5x2
batch_mm = np.einsum('bij,bjk->bik', batch_A, batch_B)

# Like the earlier sections, confirm einsum matches the traditional operator;
# `@` broadcasts over the leading batch axis. Silent when the results agree.
np.testing.assert_allclose(batch_mm, batch_A @ batch_B)

print("Batch matrix multiplication:")
print(f"batch_A shape: {batch_A.shape}")
print(f"batch_B shape: {batch_B.shape}")
print(f"Result shape: {batch_mm.shape}")
print("Einsum 'bij,bjk->bik' multiplies corresponding matrices in each batch")
print()

# Rule 2 below is shorthand: a letter shared across operands multiplies the
# matching elements, while a letter repeated within one operand (e.g. 'ii')
# takes that operand's diagonal.
print("=== Key Einsum Rules ===")
print("1. Each unique letter represents a dimension")
print("2. Repeated letters on the left side get multiplied element-wise")
print("3. Letters that appear only on left side get summed over")
print("4. Letters on the right side specify output dimensions")
print("5. Order of letters on right side determines output shape")

Batch matrix multiplication:
batch_A shape: (3, 4, 5)
batch_B shape: (3, 5, 2)
Result shape: (3, 4, 2)
Einsum 'bij,bjk->bik' multiplies corresponding matrices in each batch

=== Key Einsum Rules ===
1. Each unique letter represents a dimension
2. Repeated letters on the left side get multiplied element-wise
3. Letters that appear only on left side get summed over
4. Letters on the right side specify output dimensions
5. Order of letters on right side determines output shape
