# Einstein Notation

## Goals:

- understand the various matrix operations `einsum` can replace

## Concepts:

- `einsum`

# Matrix Operations:

- dot product
- matrix-vector
- matrix-matrix
- Hadamard (element-wise) product
- outer product
- batched matrix multiplication
- transpose
- summing along an axis after multiplication
- trace
- diag

In [1]:
import numpy as np
np.random.seed(42)

import jax.numpy as jnp

In [2]:
# Sizes for the vector / matrix examples.
NUM_FEAT = 10
NUM_SAMPLES = 100

# Two random vectors of length NUM_FEAT (dot- and outer-product examples).
# NOTE: every array below is drawn from NumPy's global RNG, so the values
# depend on the order of these statements; keep the order if you need the
# same numbers after re-seeding.
v1 = jnp.asarray(np.random.rand(NUM_FEAT))
v2 = jnp.asarray(np.random.rand(NUM_FEAT))

# M1 is (NUM_FEAT, NUM_SAMPLES) and M2 is (NUM_SAMPLES, NUM_FEAT), so
# M1 @ M2 is a valid product of shape (NUM_FEAT, NUM_FEAT).
M1 = jnp.asarray(np.random.rand(NUM_FEAT, NUM_SAMPLES))
M2 = jnp.asarray(np.random.rand(NUM_SAMPLES, NUM_FEAT))

# Sizes for the image-like tensor-contraction example.
BATCH_SIZE = 128
NUM_CHANNELS = 3
HEIGHT = 32
WIDTH = 32

LEARNED_FILTERS = 5
# T1: (BATCH_SIZE, HEIGHT, WIDTH, NUM_CHANNELS) -- a batch of inputs.
T1 = jnp.asarray(np.random.rand(BATCH_SIZE, HEIGHT, WIDTH, NUM_CHANNELS))
# T2: (LEARNED_FILTERS, HEIGHT, WIDTH, NUM_CHANNELS) -- one full-size filter
# per row, sharing its trailing three axes with T1.
T2 = jnp.asarray(np.random.rand(LEARNED_FILTERS, HEIGHT, WIDTH, NUM_CHANNELS))

In [62]:
# vector-vector dot-product
def vv_dot():
    """Print whether four spellings of the dot product of `v1` and `v2` agree.

    Compares `jnp.dot`, the `@` operator and `jnp.matmul` against
    `jnp.einsum("j, j ->", ...)`. Reads the module-level vectors `v1` and
    `v2`; returns nothing.
    """
    res = jnp.dot(v1, v2)
    res2 = v1 @ v2
    res3 = jnp.matmul(v1, v2)

    # `j` appears in both operands and not in the output, so the element-wise
    # products are summed over `j` -- i.e. a dot product.
    res_ein = jnp.einsum("j, j ->", v1, v2)

    # Compare with a tolerance instead of a chained exact `==`: the four
    # routes are free to accumulate in a different order, so bit-for-bit
    # equality of the floating-point results is not guaranteed.
    all_match = all(
        bool(jnp.allclose(res, other)) for other in (res2, res3, res_ein)
    )
    print("All vector dot-products equivalent?", all_match)

vv_dot()

All vector dot-products equivalent? True


In [63]:
# vector-vector outer-product

def vv_outer():
    """Print whether `jnp.outer` and the einsum `"i, j ->ij"` agree on `v1`, `v2`."""
    # Two distinct indices that both survive into the output: every pairwise
    # product v1[i] * v2[j] is kept and nothing is summed.
    via_einsum = jnp.einsum("i, j ->ij", v1, v2)
    via_outer = jnp.outer(v1, v2)

    is_same = jnp.allclose(via_outer, via_einsum)
    print("Outer-products equivalent?", is_same)

vv_outer()

Outer-products equivalent? True


In [27]:
def mm_mult():
    """Print whether einsum reproduces a matrix product and its axis sums.

    Uses the module-level matrices `M1` and `M2`. Three checks are printed:
    the plain product, the product summed along axis 0, and the product
    summed along axis 1. Returns nothing.
    """
    # The reference product is the same for all three checks; compute it once.
    product = jnp.dot(M1, M2)

    # Matrix mult: `j` is shared and dropped from the output, so it is summed.
    res_ein_mm = jnp.einsum("ij, jl ->il", M1, M2)
    print("Matrix-mult equivalent?", jnp.allclose(product, res_ein_mm))

    # Matrix mult then sum along the 0-th axis: also dropping `i` from the
    # output folds the row axis away, leaving one value per column `l`.
    res_mm = jnp.sum(product, axis=0)
    res_ein_mm = jnp.einsum("ij, jl ->l", M1, M2)
    print("Matrix-mult then sum-rows equivalent?", jnp.allclose(res_mm, res_ein_mm))

    # Matrix mult then sum along the 1st axis: dropping `l` instead leaves one
    # value per row `i`.
    res_mm = jnp.sum(product, axis=1)
    res_ein_mm = jnp.einsum("ij, jl ->i", M1, M2)
    print("Matrix-mult then sum-cols equivalent?", jnp.allclose(res_mm, res_ein_mm))

mm_mult()

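Matrix-mult equivalent? True
Matrix-mult then sum-rows equivalent? True
Matrix-mult then sum-cols equivalent? True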


In [65]:
def trace_transpose_diagonal():
    """Print how einsum's trace, transpose and diagonal compare with `jnp`.

    Uses the module-level matrix `M1`, which is not square, and the square
    matrix `M1 @ M1.T` built from it. Returns nothing.
    """
    # Square matrix used by the trace and diagonal checks.
    square = M1 @ M1.T

    #############################################
    # Trace
    # Note: the einsum trace ("ii->") is only defined when the matrix is
    # square. "ij->" below has no repeated index, so it sums *every* element
    # of M1, whereas jnp.trace adds up only the main diagonal -- this check
    # is therefore expected to report a mismatch.
    #############################################
    print(
        "Non-square matrix trace equivalent?",
        jnp.allclose(jnp.trace(M1), np.einsum("ij->", M1)),
    )

    print(
        "Square matrix trace equivalent?",
        jnp.allclose(jnp.trace(square), np.einsum("ii->", square)),
    )

    #############################################
    # Transpose
    print(
        "Matrix Transpose equivalent?",
        jnp.allclose(M1.T, np.einsum("ij -> ji", M1)),
    )

    #############################################
    # Diagonal
    # Note: the einsum diagonal is only defined when the matrix is square.
    # A repeated index on one operand needs both axes to have the same
    # length, so np.einsum raises for M1; show that error instead of failing.
    try:
        einsum_diag = np.einsum("ii -> i", M1)
    except ValueError as e:
        print("\n", e)
    else:
        print(
            "Diag on non-square matrix equivalent?",
            jnp.allclose(jnp.diag(M1), einsum_diag),
        )

    # The square case is checked unconditionally; it does not rely on the
    # non-square attempt above having failed.
    print(
        "Diag on square matrix equivalent?",
        jnp.allclose(jnp.diag(square), np.einsum("ii -> i", square)),
    )


trace_transpose_diagonal()

Non-square matrix trace equivalent? False
Square matrix trace equivalent? True
Matrix Transpose equivalent? True

 dimensions in single operand for collapsing index 'i' don't match (10 != 100)
Diag on square matrix equivalent? True


In [70]:
def tensor_contraction():
    """Apply the filters `T2` to the batch `T1` two ways and print whether they agree."""
    # Note: for more interesting examples, check out how attention is implemented!

    #############################################
    # Apply our learned filters (T2) to our data (T1)
    #############################################
    flat_dim = HEIGHT * WIDTH * NUM_CHANNELS

    # Einsum route: h, w and c appear in both operands but not in the output,
    # so they are summed away, leaving one value per (batch, filter) pair.
    via_einsum = jnp.einsum('bhwc,fhwc->bf', T1, T2)

    # Manual route: flatten height, width and channels on both tensors, then
    # transpose the filters so the flattened axis lines up for the matmul.
    batch_flat = T1.reshape(BATCH_SIZE, flat_dim)
    filters_flat_t = T2.reshape(LEARNED_FILTERS, flat_dim).transpose()
    via_matmul = jnp.matmul(batch_flat, filters_flat_t).reshape(
        BATCH_SIZE, LEARNED_FILTERS
    )

    is_same = jnp.allclose(via_matmul, via_einsum)
    print("Tensor Contraction Equivalent?", is_same)

tensor_contraction()

Tensor Contraction Equivalent? True
