In [36]:
import os
import torch
from IPython.core.debugger import set_trace

os.environ["TRITON_INTERPRET"] = "0"  # Need to set before importing Triton


def check_tensors_gpu_ready(*tensors: torch.Tensor):
    for t in tensors:
        assert t.is_contiguous(), "A tensor is not contiguous"
        if not os.environ.get("TRITON_INTERPRET") == "1":
            assert t.is_cuda, "A tensor is not on GPU"


def test_pid_conds(conds_str: str, pid_0=[0], pid_1=[0], pid_2=[0]) -> bool:
    pids = pid_0[0], pid_1[0], pid_2[0]
    conds = conds_str.replace(" ", "").split(",")
    for i, (cond, pid) in enumerate(zip(conds, pids)):
        if cond == "":
            continue
        try:
            op, threshold = cond[0], int(cond[1:])
        except ValueError as e:
            if len(cond[1:]) == 2:
                op, threshold = cond[0:2], int(cond[2:])
            else:
                raise ValueError(e)
        if op not in ["<", ">", ">=", "<=", "=", "!="]:
            raise ValueError(
                f"Rules may only use these ops: '<','>','>=','<=','=', '!='. Invalid rule: '{cond}'."
            )
        op = "==" if op == "=" else op
        if not eval(f"{pid} {op} {threshold}"):
            return False
    return True


assert test_pid_conds("")
assert test_pid_conds(">=0")
assert test_pid_conds(">0", [1], [1])
assert not test_pid_conds(">0", [0], [1])
assert test_pid_conds("=0,=1", [0], [1], [1])
assert test_pid_conds("=0, =1", [0], [1], [2])


def breakpoint_if(conds: str, pid_0=[0], pid_1=[0], pid_2=[0]):
    if test_pid_conds(conds, pid_0, pid_1, pid_2):
        set_trace()


def print_if(txt: str, conds: str, pid_0=[0], pid_1=[0], pid_2=[0]):
    if test_pid_conds(conds, pid_0, pid_1, pid_2):
        print(txt)


def cdiv(a: int, b: int):
    return (a + b - 1) // b


assert cdiv(10, 2) == 5
assert cdiv(11, 2) == 6

check_tensors_gpu_ready(
    torch.tensor([1, 2, 3], device="cuda"),
    torch.tensor([[4, 5, 6], [7, 8, 9], [10, 11, 12]], device="cuda")[
        :, 2
    ].contiguous(),
)

0 >= 0
1 > 0
0 > 0
0 == 0
1 == 1
0 == 0
1 == 1


In [1]:
import numpy as np
import opt_einsum as oe

# Define your contraction
equation = "ij,jk->ik"

array1 = np.random.rand(3, 3)
array2 = np.random.rand(3, 3)
operands = [array1, array2]

# Find the best path
path, info = oe.contract_path(equation, *operands, optimize="optimal")

# Use the path in your contraction
result = oe.contract(equation, *operands, optimize=path)

In [3]:
info

  Complete contraction:  ij,jk->ik
         Naive scaling:  3
     Optimized scaling:  3
      Naive FLOP count:  5.400e+1
  Optimized FLOP count:  5.400e+1
   Theoretical speedup:  1.000e+0
  Largest intermediate:  9.000e+0 elements
--------------------------------------------------------------------------------
scaling        BLAS                current                             remaining
--------------------------------------------------------------------------------
   3           GEMM              jk,ij->ik                                ik->ik

In [4]:
import opt_einsum as oe

# Take a complex string
einsum_string = "bdik,acaj,ikab,ajac,ikbd->"

# Build random views to represent this contraction
unique_inds = set(einsum_string) - {",", "-", ">"}
index_size = [10, 17, 9, 10, 13, 16, 15, 14, 12]
sizes_dict = dict(zip(unique_inds, index_size))

In [6]:
views = oe.helpers.build_views(einsum_string, sizes_dict)
views

[array([[[[6.42986249e-01, 1.29885860e-01, 8.79100708e-01, ...,
           2.76965680e-01, 1.59598470e-01, 1.28769986e-01],
          [3.87925661e-01, 3.78227206e-01, 3.37310048e-02, ...,
           1.24399532e-01, 8.24135821e-01, 9.60334804e-01],
          [3.79457552e-01, 3.78402661e-03, 7.50703658e-01, ...,
           8.08910715e-01, 2.57473920e-01, 5.66491152e-01],
          ...,
          [8.15943257e-01, 6.80262178e-01, 3.05033084e-01, ...,
           5.80951315e-01, 3.45062774e-01, 3.12933145e-01],
          [6.74765021e-01, 5.47978537e-01, 8.15455878e-01, ...,
           8.33216685e-01, 8.19061753e-01, 9.15607205e-02],
          [1.68160780e-01, 6.72677409e-01, 8.59920624e-01, ...,
           8.26846583e-01, 7.50304102e-01, 5.92610699e-01]],
 
         [[3.30371302e-01, 3.15782918e-01, 1.11099701e-01, ...,
           2.94054481e-01, 8.12819619e-02, 5.54832270e-01],
          [9.50376656e-01, 2.60377505e-03, 6.06783927e-01, ...,
           5.51162917e-01, 5.68507727e-01, 2.55570

In [9]:
path, path_info = oe.contract_path(einsum_string, *views)

In [11]:
path_info

  Complete contraction:  bdik,acaj,ikab,ajac,ikbd->
         Naive scaling:  7
     Optimized scaling:  4
      Naive FLOP count:  2.387e+8
  Optimized FLOP count:  8.673e+4
   Theoretical speedup:  2.752e+3
  Largest intermediate:  1.600e+3 elements
--------------------------------------------------------------------------------
scaling        BLAS                current                             remaining
--------------------------------------------------------------------------------
   4              0         ikbd,bdik->ikb                  acaj,ikab,ajac,ikb->
   4    GEMV/EINSUM            ikb,ikab->a                         acaj,ajac,a->
   3              0           ajac,acaj->a                                 a,a->
   1            DOT                  a,a->                                    ->

In [14]:
sizes_dict

{'b': 10, 'd': 17, 'a': 9, 'i': 10, 'j': 13, 'k': 16, 'c': 15}

In [26]:
import numpy as np

einsum_string = "bdik,acaj,ikab,ajac,ikbd->"
bdik = np.random.randn(10, 17, 10, 16)
acaj = np.random.randn(9, 15, 9, 13)
ikab = np.random.randn(10, 16, 9, 10)
ajac = np.random.randn(9, 13, 9, 15)
ikbd = np.random.randn(10, 16, 10, 17)

for i, tensor in enumerate([bdik, acaj, ikab, ajac, ikbd]):
    assert views[i].shape == tensor.shape

In [27]:
ikb = ikbd.reshape(10, 16, -1) @ bdik.reshape(-1, 160)
ikb = ikb.reshape(10, 16, 10, -1).sum(axis=-1)
ikb.shape

(10, 16, 10)

In [33]:
a = ikb.reshape(-1, 10).T @ ikab.reshape(-1, 90)
a = a.reshape(9, -1).sum(axis=-1)
a

array([ -1951.2297429 ,  -3743.92732334,   9300.62462895,   2425.66398296,
         2119.73749352,  -3978.27862869, -11144.73523855,  -2984.9926402 ,
        -2851.86661815])

In [39]:
a_second = ajac.reshape(9, -1) @ acaj.reshape(9, -1).T
a_second = a.reshape(9, -1).sum(axis=-1)

In [40]:
final_result = a @ a_second
final_result

np.float64(271778189.2550141)

In [42]:
oe.contract(einsum_string, *views)

array(1529910.11624432)

In [44]:
(
    (
        (ikbd.reshape(10, 16, -1) @ bdik.reshape(-1, 160))
        .reshape(10, 16, 10, -1)
        .sum(axis=-1)
        .reshape(-1, 10)
        .T
        @ ikab.reshape(-1, 90)
    ).reshape(-1)
).shape

(900,)

In [45]:
final_result = (
    (
        (ikbd.reshape(10, 16, -1) @ bdik.reshape(-1, 160))
        .reshape(10, 16, 10, -1)
        .sum(axis=-1)
        .reshape(-1, 10)
        .T
        @ ikab.reshape(-1, 90)
    ).reshape(-1)
    @ np.ones((900))
    * (ajac.reshape(9 * 13, 9 * 15) @ acaj.reshape(9 * 15, 9 * 13)).sum()
)

In [46]:
final_result

np.float64(-59747804.41907154)

In [51]:
import numpy as np

# Define the tensors
bdik = np.random.randn(10, 17, 10, 16)
acaj = np.random.randn(9, 15, 9, 13)
ikab = np.random.randn(10, 16, 9, 10)
ajac = np.random.randn(9, 13, 9, 15)
ikbd = np.random.randn(10, 16, 10, 17)

# Step 1: Contract ikbd and bdik over common dimensions (b, d)
step1 = np.tensordot(ikbd, bdik, axes=((0, 1), (3, 2)))  # Shape: (10, 16, 16, 17)

# Step 2: Contract step1 and ikab over common dimensions (i, k, b)
step2 = np.tensordot(step1, ikab, axes=((0, 1, 2), (0, 1, 3)))  # Shape: (10, 9)

# Step 3: Contract step2 and ajac over common dimensions (a, j, c)
step3 = np.tensordot(step2, ajac, axes=((0, 1), (2, 0)))  # Shape: (15)

# Step 4: Contract step3 and acaj over common dimensions (a, c, j)
final_result = np.tensordot(step3, acaj, axes=((0), (3)))  # Shape: scalar

print(final_result)  # Output: scalar

ValueError: shape-mismatch for sum

# Learn Einsum

Let’s say we have two matrices A and B that we want to multiply them element-wise and then take sum for axis = 1 (row wise)



In [1]:
import numpy as np

A = np.array([0, 1, 2])  # shape (3,)

B = np.array(
    [
        [0, 1, 2, 3],  # (3, 4)
        [4, 5, 6, 7],
        [8, 9, 10, 11],
    ]
)

In [2]:
np.einsum("i, ij -> i", A, B)

array([ 0, 22, 76])

In [5]:
(A[:, None] * B).sum(axis=1)

array([ 0, 22, 76])

In [6]:
np.einsum("ij->", B)

66

In [21]:
C = B * 2
np.einsum("ij, kj -> ikj", B, C)

array([[[  0,   2,   8,  18],
        [  0,  10,  24,  42],
        [  0,  18,  40,  66]],

       [[  0,  10,  24,  42],
        [ 32,  50,  72,  98],
        [ 64,  90, 120, 154]],

       [[  0,  18,  40,  66],
        [ 64,  90, 120, 154],
        [128, 162, 200, 242]]])

In [23]:
import torch

a = torch.arange(6).reshape(2, 3)
b = torch.arange(3)

torch.einsum("ij, j -> i", a, b)

tensor([ 5, 14])

In [24]:
a = torch.arange(6).reshape(2, 3)
b = torch.arange(15).reshape(3, 5)

torch.einsum("ij, jk -> ik", a, b)

tensor([[ 25,  28,  31,  34,  37],
        [ 70,  82,  94, 106, 118]])

In [26]:
a = torch.arange(3)
b = torch.arange(3, 6)  # -- a vector of length 3 containing [3, 4, 5]

torch.einsum("i, i", a, b)

tensor(14)

In [27]:
a = torch.arange(6).reshape(2, 3)
b = torch.arange(6, 12).reshape(2, 3)

torch.einsum("ij, ij", a, b)

tensor(145)

In [28]:
a = torch.arange(6).reshape(2, 3)
b = torch.arange(6, 12).reshape(2, 3)

torch.einsum("ij, ij->ij", a, b)

tensor([[ 0,  7, 16],
        [27, 40, 55]])

In [29]:
a = torch.arange(3)
b = torch.arange(3, 7)  # -- a vector of length 4 containing [3, 4, 5, 6]

torch.einsum("i, j -> ij", a, b)

tensor([[ 0,  0,  0,  0],
        [ 3,  4,  5,  6],
        [ 6,  8, 10, 12]])

In [30]:
a = torch.randn(3, 2, 5)
b = torch.randn(3, 5, 3)

torch.einsum("bij, bjk -> bik", a, b)

tensor([[[-1.0206e-01, -1.0155e+00, -5.0660e-03],
         [-2.5707e+00,  1.3002e+00, -5.1847e-02]],

        [[ 3.0338e+00, -3.6176e-02,  1.9567e+00],
         [ 3.6641e+00,  3.0247e-01,  2.9984e+00]],

        [[-2.9061e-01, -1.8554e-01, -7.4797e-01],
         [ 5.1176e+00, -1.5683e-01, -1.6907e+00]]])

In [33]:
# i2 = j3 and i3 = j5
a = torch.rand(2, 5, 6, 8)  # length of dimension is 4
b = torch.rand(10, 7, 5, 9, 6)  # length of dimension is 5

torch.einsum("ijkl, mnjok->ilmno", a, b)

tensor([[[[[ 6.7968,  8.1131,  5.8946,  ...,  8.4670,  6.8575,  6.9958],
           [ 8.1204,  6.6542,  7.0176,  ...,  6.5356,  7.1529,  7.3653],
           [ 6.8247,  7.8043,  8.3069,  ...,  6.4716,  6.4453,  7.4474],
           ...,
           [ 7.4007,  7.6729,  6.5106,  ...,  8.5467,  8.7577,  5.6266],
           [ 7.2576,  7.0697,  6.9894,  ...,  7.0265,  5.8699,  7.6196],
           [ 7.0938,  7.1244,  6.3140,  ...,  7.5240,  7.1434,  7.9037]],

          [[ 7.4518,  6.9562,  6.9441,  ...,  6.9254,  8.1319,  5.9144],
           [ 7.8143,  7.6178,  7.0612,  ...,  7.3658,  8.6168,  7.6718],
           [ 6.7776,  4.7999,  8.0208,  ...,  6.4700,  6.2627,  7.7761],
           ...,
           [ 5.8409,  8.0961,  7.3983,  ...,  5.9801,  8.0419,  6.4702],
           [ 7.4315,  7.0053,  7.1440,  ...,  5.1458,  6.0750,  7.4817],
           [ 6.7271,  6.9307,  7.6580,  ...,  5.9900,  8.2490,  6.8422]],

          [[ 8.7691,  8.7695,  6.2483,  ...,  8.2896,  8.3024,  7.1377],
           [ 7.

In [34]:
a = torch.randn(2, 3)
b = torch.randn(5, 3, 7)
c = torch.randn(2, 7)

torch.einsum("jk, lkm, jm -> jl", a, b, c)

tensor([[ 0.8396,  0.0357, -0.2733, -0.2993,  0.7561],
        [-6.7138,  6.9402,  3.0302, -2.6569, -3.5242]])

In [35]:
import torch.nn.functional as F


def random_tensors(shape, num=1, requires_grad=False):
    tensors = [torch.randn(shape, requires_grad=requires_grad) for i in range(0, num)]
    return tensors[0] if num == 1 else tensors


bM, br, w = random_tensors([7], num=3, requires_grad=True)
# -- [hidden_dimension x hidden_dimension]
WY, Wh, Wr, Wt = random_tensors([7, 7], num=4, requires_grad=True)