# Einstein Notation

## Goals:

- understand the various matrix operations `einsum` can replace

## Concepts:

- `einsum`

# Matrix Operations:

- dot product
- matrix-vector
- matrix-matrix
- hadamard product
- outer product
- batched matrix mult
- transpose
- summing along an axis after multiplication
- trace
- diag

In [1]:
import numpy as np
np.random.seed(42)

import jax.numpy as jnp
from jax.scipy import stats

In [2]:
NUM_FEAT = 10
NUM_SAMPLES = 100

v1 = jnp.asarray(np.random.rand(NUM_FEAT))
v2 = jnp.asarray(np.random.rand(NUM_FEAT))

M1 = jnp.asarray(np.random.rand(NUM_FEAT, NUM_SAMPLES))
M2 = jnp.asarray(np.random.rand(NUM_SAMPLES, NUM_FEAT))

BATCH_SIZE = 128
NUM_CHANNELS = 3
HEIGHT = 32
WIDTH = 32

LEARNED_FILTERS = 5
images = jnp.asarray(np.random.rand(BATCH_SIZE, HEIGHT, WIDTH, NUM_CHANNELS))
learned_kernels = jnp.asarray(np.random.rand(LEARNED_FILTERS, HEIGHT, WIDTH, NUM_CHANNELS))

# Vector-Vector dot product

which returns a scalar. This is equivalent to doing a hadamard-then-sum

In [3]:
# vector-vector dot-product
def vv_dot():
    res = jnp.dot(v1, v2)
    res2 = v1 @ v2
    res3 = jnp.matmul(v1, v2)
    
    res_ein = jnp.einsum("j, j ->", v1, v2)
    print("All vector dot-products equivalent?",
        jnp.all(res == res2 == res3 == res_ein)
    )

vv_dot()

All vector dot-products equivalent? True


2024-04-26 16:00:19.246122: W external/xla/xla/service/gpu/nvptx_compiler.cc:718] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.4.131). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


# Outer Product

Let 

$$u \in \mathbb{R}^m$$

and 

$$v \in \mathbb{R}^n$$

then the dot-product is

$$(u \otimes v) \in \mathbb{R}^{m, n}$$

In [4]:
# vector-vector outer-product

def vv_outer():
    res = jnp.outer(v1, v2)
    
    res_ein = jnp.einsum("i, j ->ij", v1, v2)
    print(
        "Outer-products equivalent?",
        jnp.allclose(res, res_ein))

vv_outer()

Outer-products equivalent? True


# Matrix-Matrix multiplication

In [5]:
def mm_mult():
    # Matrix mult
    res_mm = jnp.dot(M1, M2)
    res_ein_mm = np.einsum("ij, jl ->il", M1, M2)
    print("Matrix-mult equivalent?",jnp.allclose(res_mm, res_ein_mm))

    # Matrix mult then sum along the 0-th axis
    res_mm = np.sum(jnp.dot(M1, M2), axis=0)
    res_ein_mm = np.einsum("ij, jl ->l", M1, M2)
    print("Matrix-mult then sum-rows equivalent?",jnp.allclose(res_mm, res_ein_mm))

    # Matrix mult then sum along the 1st axis
    res_mm = np.sum(jnp.dot(M1, M2), axis=1)
    res_ein_mm = np.einsum("ij, jl ->i", M1, M2)
    print("Matrix-mult then sum-cols equivalent?",jnp.allclose(res_mm, res_ein_mm))

mm_mult()

Matrix-mult equivalent? True
Matrix-mult then sum-rows equivalent? True
Matrix-mult then sum-cols equivalent? True


# Misc. Operations

Some other operations that are useful involve taking:

- the trace, the sum of the diagonals
- the diagonals themselves
- transpose

In [7]:
def trace():

    #############################################
    # Note: the einsum trace is only defined when the matrix is square
    #############################################
    print(
        "Non-square matrix trace equivalent?",
        jnp.allclose(
            jnp.trace(M1), 
            np.einsum("ij->", M1)
        )
    )

    print(
        "Square matrix trace equivalent?",
        jnp.allclose(
            jnp.trace(M1 @ M1.T), 
            np.einsum("ii->", M1 @ M1.T)
        )
    )

def transpose():
    #############################################
    # Transpose
    print(
        "Matrix Transpose equivalent?", jnp.allclose(
        M1.T, 
        np.einsum("ij -> ji", M1)
    ))

def diagonal():
    #############################################
    # Diagonal
    # Note: the einsum diagonal is only defined when the matrix is square
    try:
        print(
            "Diag on non-square matrix equivalent?",
            jnp.allclose(
            jnp.diag(M1), 
            np.einsum("ii -> i", M1)
        ))
    except Exception as e:
        print("\n", e)
        # We get the expected result
        print(
            "Diag on square matrix equivalent?",
            jnp.allclose(
            jnp.diag(M1 @ M1.T), 
            np.einsum("ii -> i", M1 @ M1.T)
        ))


trace()
transpose()
diagonal()

Non-square matrix trace equivalent? False
Square matrix trace equivalent? True
Matrix Transpose equivalent? True

 dimensions in single operand for collapsing index 'i' don't match (10 != 100)
Diag on square matrix equivalent? True


# Higher-dimensional operations

Scenario: we have learned some filters and want to apply those filters to an image.

Another good example, which we do not cover, is calculating attention in LLMs. See [Einsum is All you Need - Einstein Summation in Deep Learning#3.2 Attention](https://rockt.github.io/2018/04/30/einsum)

In [9]:
def tensor_contraction():
    #############################################
    # Apply our learned filters (T2) to our data (T1)
    #############################################
    flattened_images = images.reshape(BATCH_SIZE, HEIGHT * WIDTH * NUM_CHANNELS)
    
    # Flatten height, width, and channels for T2, but need to transpose to match the matrix multiplication requirements
    flattened_kernel = learned_kernels.reshape(LEARNED_FILTERS, HEIGHT * WIDTH * NUM_CHANNELS).transpose()
    
    # Matrix multiplication
    res_tc = jnp.matmul(flattened_images, flattened_kernel).reshape(BATCH_SIZE, LEARNED_FILTERS)

    # Einsum
    res_ein = jnp.einsum('bhwc,fhwc->bf', images, learned_kernels)

    print(
        "Tensor Contraction Equivalent?",
        jnp.allclose(
            res_tc,
            res_ein
    ))

tensor_contraction()

Tensor Contraction Equivalent? True


# Gaussian PDF

One of the steps involves calculating 

$$(x - \mu)^T\Sigma(x - \mu)$$

In [8]:
import numpy as np

def gaussian_pdf(x, mu, Sigma):
    # Dimensions of the data
    k = mu.shape[0]
    
    # Calculate determinant and inverse of the covariance matrix
    Sigma_det = jnp.linalg.det(Sigma)
    Sigma_inv = jnp.linalg.inv(Sigma)
    
    # Calculate the normalization factor
    normalization_factor = 1 / jnp.sqrt((2 * jnp.pi) ** k * Sigma_det)
    x_mu = x - mu
    ###############################################################
    exponent1 = -0.5 * jnp.sum(x_mu @ Sigma_inv * x_mu, axis=1)
    exponent2 = -0.5 * jnp.einsum('ij,jj,ij->i', x_mu, Sigma_inv, x_mu)
    print("Exponent calculations equivalent?: ", exponent1 == exponent2)
    ###############################################################
    
    # Compute the Gaussian PDF
    return normalization_factor * jnp.exp(exponent2)

# Example usage
mu = np.array([0, 0])  # Mean vector
Sigma = np.array([[1, 0], [0, 1]])  # Covariance matrix
x = np.array([[1, 1], [2, 2], [3, 3]])  # Point to evaluate the PDF

pdf_value = gaussian_pdf(x, mu, Sigma)

print(jnp.allclose(
    pdf_value, 
    stats.multivariate_normal.pdf(x, mu, Sigma)
))

Exponent calculations equivalent?:  [ True  True  True]
True
