# Einstein Notation

## Lesson Goals:

By the end of this lesson, you will understand how to read `einsum` notation, and get a feel for the kinds of operations that it can replace.

## Core Concepts:

- `einsum`


# Matrix Operations:

- dot product
- matrix-vector
- matrix-matrix
- hadamard product
- outer product
- batched matrix mult
- transpose
- summing along an axis after multiplication
- trace
- diag

In [None]:
import numpy as np
np.random.seed(42)

import jax.numpy as jnp
from jax.scipy import stats

In [None]:
# Problem sizes shared by the vector and matrix examples below.
NUM_FEAT = 10
NUM_SAMPLES = 100

# Two random vectors of length NUM_FEAT, converted to JAX arrays.
v1 = jnp.asarray(np.random.rand(NUM_FEAT))
v2 = jnp.asarray(np.random.rand(NUM_FEAT))

# M1 is (NUM_FEAT, NUM_SAMPLES) and M2 is (NUM_SAMPLES, NUM_FEAT), so the
# product M1 @ M2 is a square (NUM_FEAT, NUM_FEAT) matrix.
M1 = jnp.asarray(np.random.rand(NUM_FEAT, NUM_SAMPLES))
M2 = jnp.asarray(np.random.rand(NUM_SAMPLES, NUM_FEAT))

# Vector-Vector dot product

The dot product of two vectors returns a scalar. It is equivalent to a Hadamard (element-wise) product followed by a sum.

In [None]:
# vector-vector dot-product
def vv_dot():
    """Compare three built-in dot products of v1 and v2 against einsum.

    Exercise: replace the ``...`` placeholder with a ``jnp.einsum`` call that
    computes the dot product. Prints whether all four results agree.
    """
    res = jnp.dot(v1, v2)
    res2 = v1 @ v2
    res3 = jnp.matmul(v1, v2)

    res_ein = ...  # Your code here

    # Compare every result to the reference with a floating-point tolerance.
    # A chained `a == b == c` demands bit-for-bit equality, which different
    # summation orders are not guaranteed to produce.
    print("All vector dot-products equivalent?",
        all(bool(jnp.allclose(res, other)) for other in (res2, res3, res_ein))
    )

vv_dot()

# Outer Product

Let 

$$u \in \mathbb{R}^m$$

and 

$$v \in \mathbb{R}^n$$

then the outer product is

$$(u \otimes v) \in \mathbb{R}^{m \times n}$$

In [None]:
# vector-vector outer-product

def vv_outer():
    """Check an einsum outer product of v1 and v2 against ``jnp.outer``.

    Exercise: replace the ``...`` placeholder with the einsum expression.
    """
    res_ein = ...  # Your code here

    # Reference value from the dedicated library routine.
    expected = jnp.outer(v1, v2)

    matches = jnp.allclose(expected, res_ein)
    print("Outer-products equivalent?", matches)

vv_outer()

# Matrix-Matrix multiplication

In [None]:
def mm_mult():
    """Compare matrix multiplication, and axis sums of it, against einsum.

    Exercise: replace each ``...`` placeholder with a ``jnp.einsum`` call that
    reproduces the reference value computed just above it. Prints one
    True/False line per check.
    """
    # M1 @ M2 is the reference for all three checks, so compute it only once.
    product = jnp.dot(M1, M2)

    # Matrix mult
    res_mm = product
    res_ein_mm = ...  # Your code here
    print("Matrix-mult equivalent?", jnp.allclose(res_mm, res_ein_mm))

    # Matrix mult then sum along the 0-th axis
    res_mm = jnp.sum(product, axis=0)
    res_ein_mm = ...  # Your code here
    print("Matrix-mult then sum-rows equivalent?", jnp.allclose(res_mm, res_ein_mm))

    # Matrix mult then sum along the 1st axis
    res_mm = jnp.sum(product, axis=1)
    res_ein_mm = ...  # Your code here
    print("Matrix-mult then sum-cols equivalent?", jnp.allclose(res_mm, res_ein_mm))

mm_mult()

# Misc. Operations

Some other operations that are useful involve taking:

- the trace, i.e. the sum of the diagonal entries
- the diagonals themselves
- transpose

In [None]:
def trace():
    """Compare ``jnp.trace`` with einsum on a non-square and a square matrix."""
    # Non-square case. The einsum trace ("ii->") is only defined when the
    # matrix is square, so this check uses "ij->", which sums over both
    # indices (every entry) rather than just the diagonal.
    rect_reference = jnp.trace(M1)
    rect_einsum = np.einsum("ij->", M1)
    print(
        "Non-square matrix trace equivalent?",
        jnp.allclose(rect_reference, rect_einsum),
    )

    # Square case: M1 @ M1.T is square, and "ii->" sums its diagonal.
    square = M1 @ M1.T
    square_reference = jnp.trace(square)
    square_einsum = np.einsum("ii->", square)
    print(
        "Square matrix trace equivalent?",
        jnp.allclose(square_reference, square_einsum),
    )

def transpose():
    """Check that the einsum "ij -> ji" reproduces ``M1.T``."""
    # Swapping the two output indices is what transposes the matrix.
    swapped = np.einsum("ij -> ji", M1)
    is_same = jnp.allclose(M1.T, swapped)
    print("Matrix Transpose equivalent?", is_same)

def diagonal():
    """Compare ``jnp.diag`` with the einsum diagonal "ii -> i".

    The einsum diagonal is only defined when the matrix is square, so the
    non-square attempt is expected to fail; its error message is printed and
    the square-matrix comparison then runs.
    """
    # Non-square: the repeated index "ii" requires both axes to have the same
    # length, so einsum is expected to reject M1 with a ValueError.
    try:
        print(
            "Diag on non-square matrix equivalent?",
            jnp.allclose(jnp.diag(M1), np.einsum("ii -> i", M1)),
        )
    except ValueError as e:
        print("\n", e)

    # Square: kept outside the handler so this check always runs, whether or
    # not the non-square attempt above raised.
    square = M1 @ M1.T
    print(
        "Diag on square matrix equivalent?",
        jnp.allclose(jnp.diag(square), np.einsum("ii -> i", square)),
    )


# Run the trace, transpose and diagonal comparisons; each prints its results.
trace()
transpose()
diagonal()

# Higher-dimensional operations

Scenario: we have learned some filters and want to apply those filters to an image.

Another good example, which we do not cover, is calculating attention in LLMs. See [Einsum is All you Need - Einstein Summation in Deep Learning#3.2 Attention](https://rockt.github.io/2018/04/30/einsum)

P.s. What is the expected output shape of `res_tc`?

In [None]:
# Dimensions of the random image batch built below:
# (BATCH_SIZE, HEIGHT, WIDTH, NUM_CHANNELS).
BATCH_SIZE = 128
NUM_CHANNELS = 3
HEIGHT = 32
WIDTH = 32

# Number of kernels; each kernel is built with the same
# (HEIGHT, WIDTH, NUM_CHANNELS) shape as a single image.
LEARNED_FILTERS = 5

def tensor_contraction(images, learned_kernels):
    """Apply every learned kernel to every image and compare with einsum.

    Args:
        images: array of shape (batch, height, width, channels).
        learned_kernels: array of shape (filters, height, width, channels);
            its trailing dimensions must match those of ``images``.

    Prints whether the einsum result matches the reshape-and-matmul
    reference. Exercise: replace the ``...`` placeholder with the
    ``jnp.einsum`` call.
    """
    #############################################
    # Apply our learned filters (learned_kernels) to our data (images)
    #############################################
    # Flatten everything except the leading axis. Sizes are read from the
    # arguments rather than module constants, so any batch, image or filter
    # count works.
    flattened_images = images.reshape(images.shape[0], -1)

    # Flatten height, width, and channels for the kernels too, then transpose
    # so the inner dimensions line up for the matrix multiplication.
    flattened_kernel = learned_kernels.reshape(learned_kernels.shape[0], -1).transpose()

    # Matrix multiplication
    res_tc = jnp.matmul(flattened_images, flattened_kernel)

    # Einsum
    res_ein = ...  # TODO: Your code here

    print(
        "Tensor Contraction Equivalent?",
        jnp.allclose(
            res_tc,
            res_ein
    ))


# Random stand-ins for a batch of images and a bank of learned kernels,
# followed by the comparison itself.
images = jnp.asarray(np.random.rand(BATCH_SIZE, HEIGHT, WIDTH, NUM_CHANNELS))
learned_kernels = jnp.asarray(np.random.rand(LEARNED_FILTERS, HEIGHT, WIDTH, NUM_CHANNELS))
tensor_contraction(images, learned_kernels)

# Gaussian PDF

Recall the Gaussian PDF calculation from our [vmap lecture](./exe_04_vmap.ipynb). Another way of handling this is via the `einsum`. One of the steps involves calculating 

$$(x - \mu)^T\Sigma^{-1}(x - \mu)$$

and we can leverage the einsum here

In [None]:
import numpy as np

def gaussian_pdf(x, mu, Sigma):
    """Evaluate the multivariate Gaussian density at each row of ``x``.

    Args:
        x: points to evaluate, one per row (the exponent sums over axis 1).
        mu: mean vector; its length gives the dimension ``k``.
        Sigma: covariance matrix; it is inverted, so it must be non-singular.

    Returns:
        The density at each row of ``x``, computed from the einsum exponent.

    Exercise: replace the ``...`` placeholder with a ``jnp.einsum`` call that
    reproduces ``exponent1``.
    """
    # Dimensions of the data
    k = mu.shape[0]

    # Calculate determinant and inverse of the covariance matrix
    Sigma_det = jnp.linalg.det(Sigma)
    Sigma_inv = jnp.linalg.inv(Sigma)

    # Calculate the normalization factor
    normalization_factor = 1 / jnp.sqrt((2 * jnp.pi) ** k * Sigma_det)
    x_mu = x - mu
    ###############################################################
    # Row-wise quadratic form: (x_mu @ Sigma_inv) * x_mu, summed over features.
    exponent1 = -0.5 * jnp.sum(x_mu @ Sigma_inv * x_mu, axis=1)
    exponent2 = ...  # TODO: Your code here
    # Compare with a tolerance: the two formulations may round differently,
    # so exact `==` can report a mismatch for results that are equal in
    # practice.
    print("Exponent calculations equivalent?: ", jnp.allclose(exponent1, exponent2))
    ###############################################################

    # Compute the Gaussian PDF
    return normalization_factor * jnp.exp(exponent2)

# Example usage
mu = np.array([0, 0])  # Mean vector
Sigma = np.array([[1, 0], [0, 1]])  # Covariance matrix
x = np.array([[1, 1], [2, 2], [3, 3]])  # Points at which to evaluate the PDF, one per row

pdf_value = gaussian_pdf(x, mu, Sigma)

# Check the result against the jax.scipy.stats reference implementation.
print(jnp.allclose(
    pdf_value, 
    stats.multivariate_normal.pdf(x, mu, Sigma)
))