In [52]:
import torch


In [53]:
# Trace via einsum: the repeated index in 'ii' walks the main diagonal, and the
# empty right-hand side sums those entries into a single scalar.
x = torch.randn(2, 2)
trace_value = torch.einsum('ii->', x)
print(f'x: {x}')
print(f"sum of main-diagonal elements (aka trace): {trace_value}")

x: tensor([[ 0.3641,  0.7233],
        [ 0.7548, -0.8655]])
sum of main-diagonal elements (aka trace): -0.501384973526001


In [54]:
# Diagonal extraction: keeping the repeated index on the output side ('ii->i')
# returns the main-diagonal entries instead of summing them.
x = torch.randn(2, 2)
main_diagonal = torch.einsum('ii->i', x)
print(f'x: {x}')
print(f"extract elements along the main-diagonal: {main_diagonal}")

x: tensor([[-0.2158,  0.2564],
        [-2.3935,  1.6143]])
extract elements along the main-diagonal: tensor([-0.2158,  1.6143])


In [55]:
# Axis summation: index j is absent from the output of 'ij->i', so it is summed
# out (axis 1), leaving one total per row.
x = torch.randn(2, 3)
row_sums = torch.einsum('ij->i', x)
print(f'x: {x}')
print(f"summations along axis1: {row_sums}")

x: tensor([[-0.0582,  0.8060,  1.2411],
        [-0.8027, -1.0002, -0.6116]])
summations along axis1: tensor([ 1.9889, -2.4145])


In [56]:
# Full reduction: with no indices on the output side, 'ij->' sums every element
# of the matrix into a scalar.
x = torch.randn(2, 3)
grand_total = torch.einsum('ij->', x)
print(f'x: {x}')
print(f"sum all elements: {grand_total}")

x: tensor([[ 0.0970, -0.4149,  0.0923],
        [ 0.5179,  1.2046, -0.4569]])
sum all elements: 1.0399783849716187


In [57]:
# Transposition / permutation: swapping the index order on the output side
# ('ij->ji') permutes the axes without reducing anything.
x = torch.randn(2, 3)
x_transposed = torch.einsum('ij->ji', x)
print(f'x: {x}')
print(f" matrix transpose: {x_transposed}")

x: tensor([[ 1.1707,  0.2190, -0.1990],
        [-2.0503, -1.4213,  0.0845]])
 matrix transpose: tensor([[ 1.1707, -2.0503],
        [ 0.2190, -1.4213],
        [-0.1990,  0.0845]])


In [58]:
# Hadamard (element-wise) product: both operands and the output carry the same
# indices, so nothing is summed and entries are multiplied position by position.
x = torch.randn(2, 3)
y = torch.randn(2, 3)
hadamard = torch.einsum('ij, ij->ij', x, y)
print(f'x: {x}')
print(f'y: {y}')
print(f"element-wise product: {hadamard}")

x: tensor([[ 1.3122,  0.6330,  0.4928],
        [ 1.1124,  0.6957, -0.7170]])
y: tensor([[ 1.2270, -0.2900,  0.2905],
        [ 1.6783,  2.1546,  0.9772]])
element-wise product: tensor([[ 1.6101, -0.1835,  0.1432],
        [ 1.8669,  1.4990, -0.7006]])


In [59]:
# Batch Hadamard product: the same element-wise multiplication, applied to 3-D
# tensors; all three indices survive on the output side, so no axis is reduced.
x = torch.randn(2, 3, 4)
y = torch.randn(2, 3, 4)
batch_hadamard = torch.einsum('ijk, ijk->ijk', x, y)
print(f'x: {x}')
print(f'y: {y}')
print(f"batch element-wise product: {batch_hadamard}")

x: tensor([[[ 0.3801,  0.3663, -0.8577,  0.2298],
         [-0.7928,  2.0844,  1.0269, -0.0355],
         [-0.1082,  1.2953, -1.2705, -0.5055]],

        [[-1.4554,  0.0119,  0.3708,  0.3849],
         [-3.1081,  0.6219, -0.2015, -0.8840],
         [ 0.8204,  1.4839, -0.6978, -0.4825]]])
y: tensor([[[ 1.2400,  0.3695, -0.1728, -0.9412],
         [ 0.1810, -0.4023, -1.3176, -2.5765],
         [ 0.5193, -0.5816, -0.5057, -0.1050]],

        [[ 0.2236, -0.8351,  0.4344, -0.4601],
         [ 1.4683, -0.5040, -0.3001, -2.2385],
         [ 1.1537,  0.4762,  1.0413,  1.0141]]])
batch element-wise product: tensor([[[ 0.4713,  0.1353,  0.1482, -0.2163],
         [-0.1435, -0.8385, -1.3530,  0.0915],
         [-0.0562, -0.7533,  0.6425,  0.0531]],

        [[-0.3255, -0.0099,  0.1611, -0.1771],
         [-4.5637, -0.3135,  0.0605,  1.9787],
         [ 0.9464,  0.7066, -0.7266, -0.4893]]])


In [60]:
# First row of the first batch element of x, shown to check the batch
# element-wise product by hand.
x[0, 0]

tensor([ 0.3801,  0.3663, -0.8577,  0.2298])

In [61]:
# Matching first row of the first batch element of y.
y[0, 0]

tensor([ 1.2400,  0.3695, -0.1728, -0.9412])

In [62]:
# Multiply the two rows element by element; the result should equal the first
# row of the einsum batch product printed earlier.
x[0, 0] * y[0, 0]

tensor([ 0.4713,  0.1353,  0.1482, -0.2163])

In [63]:
# Element-wise squaring: a Hadamard product of the matrix with itself.
x = torch.randn(2, 3)
squared = torch.einsum('ij, ij->ij', x, x)
print(f'x: {x}')
print(f"element-wise squaring: {squared}")

x: tensor([[ 0.3525,  0.7294, -0.6127],
        [ 0.0971,  0.2584, -0.4535]])
element-wise squaring: tensor([[0.1243, 0.5320, 0.3754],
        [0.0094, 0.0668, 0.2056]])


In [64]:
# Batch element-wise squaring: the 3-D tensor is multiplied with itself entry
# by entry; all indices are kept, so the shape is unchanged.
x = torch.randn(2, 3, 4)
batch_squared = torch.einsum('ijk, ijk->ijk', x, x)
print(f'x: {x}')
print(f"batch element-wise squaring of 3D: {batch_squared}")

x: tensor([[[-0.6853, -0.5923, -1.3569,  0.5794],
         [ 0.9784,  0.7166,  0.3900,  0.5264],
         [-1.6944, -0.5283, -1.3119, -1.1720]],

        [[-0.6830,  1.0840,  0.3115,  0.4600],
         [-0.2965, -0.1621, -1.8388, -0.0961],
         [ 1.5060,  1.2872,  0.9905, -0.3560]]])
batch element-wise squaring of 3D: tensor([[[0.4696, 0.3508, 1.8412, 0.3357],
         [0.9573, 0.5135, 0.1521, 0.2770],
         [2.8711, 0.2792, 1.7210, 1.3736]],

        [[0.4665, 1.1751, 0.0971, 0.2116],
         [0.0879, 0.0263, 3.3811, 0.0092],
         [2.2682, 1.6569, 0.9812, 0.1267]]])


In [65]:
# Matrix multiplication: the shared index j is contracted in 'ij, jk->ik', so
# each output entry (i, k) is the dot product of row i of x with column k of y.
# The result is a (2, 4) matrix, not a scalar.
x = torch.randn(2, 3)
y = torch.randn(3, 4)
matmul_result = torch.einsum('ij, jk->ik', x, y)
print(f'x: {x}')
print(f'y: {y}')
print(f"matrix multiplication/aka dot product/aka inner product: {matmul_result}")

x: tensor([[ 1.4512,  0.7749, -0.6889],
        [-0.7218,  1.8518, -0.0519]])
y: tensor([[ 0.8589,  1.3685,  0.8603,  0.9341],
        [-1.9758, -0.4687, -0.0614, -0.5973],
        [ 0.5570,  1.0502,  0.8297,  0.3846]])
matrix multiplication/aka dot product/aka inner product: tensor([[-0.6684,  0.8992,  0.6293,  0.6278],
        [-4.3076, -1.9102, -0.7778, -1.8002]])


In [66]:
# Batch matrix multiplication: index b is carried through untouched while j is
# contracted, giving an independent (3, 4) @ (4, 5) product per batch element.
x = torch.randn(2, 3, 4)
y = torch.randn(2, 4, 5)
bmm_result = torch.einsum('bij,bjk ->bik', x, y)
print(f'x: {x}')
print(f'y: {y}')
print(f"batch matrix multiplication/aka batch dot product: {bmm_result}")


x: tensor([[[-1.0919e+00,  1.2996e+00,  1.2206e+00,  7.1390e-01],
         [ 6.6417e-01, -2.8571e-01,  1.0276e-01,  3.9829e-01],
         [ 1.7737e+00,  3.5152e-01, -1.4926e+00, -3.4014e-03]],

        [[ 1.5701e+00,  1.4211e-01,  9.0065e-01, -3.1597e-01],
         [-4.5500e-01,  9.0665e-04,  5.5186e-01, -9.0227e-01],
         [ 6.1600e-01, -1.1899e+00,  1.2303e+00, -5.9795e-01]]])
y: tensor([[[ 2.8698,  1.5064,  0.2011,  1.1555,  1.7895],
         [-0.8025, -0.0198, -0.7197,  0.9443, -1.5183],
         [-1.3365,  0.8935,  0.0840, -0.8880,  2.0808],
         [-0.1128, -1.6667, -0.9509, -0.1920, -0.7911]],

        [[-1.4455,  0.3156, -1.2371,  0.6933,  1.2301],
         [ 0.1289, -0.5228,  1.7736, -1.5112,  0.6014],
         [-2.4403,  0.7106, -0.4682,  1.4587,  0.2035],
         [-2.2412, -0.2329,  0.3385, -0.0991,  0.9527]]])
batch matrix multiplication/aka batch dot product: tensor([[[-5.8883e+00, -1.7698e+00, -1.7311e+00, -1.2555e+00, -1.9519e+00],
         [ 1.9531e+00,  4.3417e-0

In [67]:
# First row of the first batch matrix in x: the left operand of the dot
# product that produces output entry [0, 0, 0].
x[0, 0]

tensor([-1.0919,  1.2996,  1.2206,  0.7139])

In [68]:
# First column of the first batch matrix in y: the right operand of that
# same dot product.
y[0, :, 0]

tensor([ 2.8698, -0.8025, -1.3365, -0.1128])

In [69]:
# Row-times-column dot product computed by hand; it should match entry
# [0, 0, 0] of the einsum batch matrix product.
(x[0, 0] * y[0, :, 0]).sum()

tensor(-5.8883)

In [70]:
# Double dot product / Frobenius inner product: multiply element-wise, then sum
# everything to a scalar (equivalent to torch.sum of the Hadamard product).
x = torch.randn(2, 3)
y = torch.randn(2, 3)
frobenius = torch.einsum('ij,ij ->', x, y)
print(f'x: {x}')
print(f'y: {y}')
print(f"double dot products: {frobenius}")

x: tensor([[-0.2114,  0.2712, -0.5630],
        [ 0.0877,  0.0414, -0.8312]])
y: tensor([[ 0.2867,  2.2151, -2.3144],
        [ 1.3213, -0.8275, -0.3276]])
double dot products: 2.197045087814331


In [71]:
# Batch sum of element-wise products along one axis: b and j are kept while i
# (axis 1) is summed out, so a (2, 3, 4) pair reduces to a (2, 4) result.
x = torch.randn(2, 3, 4)
y = torch.randn(2, 3, 4)
reduced = torch.einsum('bij,bij ->bj', x, y)
print(f'x: {x}')
print(f'y: {y}')
print(f"Batch sum of element-wise product along the first dimension (axis=1): {reduced}")

x: tensor([[[-0.4511,  1.2946,  1.8473,  0.2751],
         [-2.1980,  0.8060,  0.4860, -0.2481],
         [-0.9504, -0.7786, -0.8639, -0.2424]],

        [[ 0.1349, -1.0983, -0.6413, -1.3904],
         [ 1.0772, -0.1990,  1.2253, -1.2364],
         [-0.6647, -0.9308,  0.5650,  0.4862]]])
y: tensor([[[ 0.0911, -0.3562,  0.7171, -0.3257],
         [-1.4227,  0.1090,  1.9728,  0.4494],
         [-0.5175, -0.2848,  0.0972, -0.8430]],

        [[ 0.2131, -1.9654, -0.6124, -1.9931],
         [-0.4625,  1.0667,  0.8689,  0.5110],
         [-1.1879, -0.1370, -0.4452,  1.5956]]])
Batch sum of element-wise product along the first dimension (axis=1): tensor([[ 3.5778e+00, -1.5150e-01,  2.1997e+00,  3.3348e-03],
        [ 3.2013e-01,  2.0738e+00,  1.2059e+00,  2.9151e+00]])


In [72]:
# First batch element of x, shown to check the batched reduction by hand.
x[0, :, :]

tensor([[-0.4511,  1.2946,  1.8473,  0.2751],
        [-2.1980,  0.8060,  0.4860, -0.2481],
        [-0.9504, -0.7786, -0.8639, -0.2424]])

In [73]:
# Matching first batch element of y.
y[0, :, :]

tensor([[ 0.0911, -0.3562,  0.7171, -0.3257],
        [-1.4227,  0.1090,  1.9728,  0.4494],
        [-0.5175, -0.2848,  0.0972, -0.8430]])

In [74]:
# Element-wise product of the two matrices summed down axis 0 (index i within
# a single batch element); it should match row 0 of the einsum result.
(x[0] * y[0]).sum(dim=0)

tensor([ 3.5778e+00, -1.5150e-01,  2.1997e+00,  3.3348e-03])