In [2]:
import torch


  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Trace via einsum: the repeated index in 'ii' walks the main diagonal, and the
# empty right-hand side ('->') sums those entries down to a single scalar.
x = torch.randn(2, 2)
print(f'x: {x}')
diag_sum = torch.einsum('ii->', x)
print(f"sum of main-diagonal elements (aka trace): {diag_sum}")

x: tensor([[-0.9246,  0.5725],
        [ 1.5614,  0.7287]])
sum of main-diagonal elements (aka trace): -0.19591361284255981


In [4]:
# Diagonal extraction: keeping the repeated index on the output side ('ii->i')
# returns the diagonal entries themselves instead of summing them.
x = torch.randn(2, 2)
print(f'x: {x}')
main_diag = torch.einsum('ii->i', x)
print(f"extract elements along the main-diagonal: {main_diag}")

x: tensor([[-1.2552, -1.5755],
        [-0.4351, -1.1378]])
extract elements along the main-diagonal: tensor([-1.2552, -1.1378])


In [5]:
# Axis summation: 'ij->i' drops index j, so every row is summed over its
# columns (axis 1) and one value per row remains.
x = torch.randn(2, 3)
print(f'x: {x}')
row_sums = torch.einsum('ij->i', x)
print(f"summations along axis1: {row_sums}")

x: tensor([[-0.6573, -2.6793,  0.5065],
        [-1.2995,  1.0958, -1.0628]])
summations along axis1: tensor([-2.8301, -1.2666])


In [6]:
# Full reduction: 'ij->' lists no output indices, so both i and j are summed
# away and a scalar is produced.
x = torch.randn(2, 3)
print(f'x: {x}')
total = torch.einsum('ij->', x)
print(f"sum all elements: {total}")

x: tensor([[ 1.1395,  0.1412,  0.4186],
        [-1.3967,  1.1091,  1.7122]])
sum all elements: 3.1238160133361816


In [7]:
# Transposition / permutation: 'ij->ji' only reorders the indices, nothing is
# summed, so a (2, 3) input becomes a (3, 2) output.
x = torch.randn(2, 3)
print(f'x: {x}')
x_transposed = torch.einsum('ij->ji', x)
print(f" matrix transpose: {x_transposed}")

x: tensor([[-0.7484,  0.4253,  0.0512],
        [ 0.8153, -0.2533,  1.1448]])
 matrix transpose: tensor([[-0.7484,  0.8153],
        [ 0.4253, -0.2533],
        [ 0.0512,  1.1448]])


In [8]:
# Hadamard (element-wise) product: both operands and the output share the same
# indices ('ij, ij->ij'), so matching entries are multiplied and nothing is summed.
x, y = torch.randn(2, 3), torch.randn(2, 3)
print(f'x: {x}')
print(f'y: {y}')
hadamard = torch.einsum('ij, ij->ij', x, y)
print(f"element-wise product: {hadamard}")

x: tensor([[ 1.9482, -0.4157,  0.5294],
        [-0.4393,  0.6180, -0.7716]])
y: tensor([[ 1.0552, -1.9182,  0.4017],
        [ 0.4868, -0.0984, -1.6488]])
element-wise product: tensor([[ 2.0558,  0.7973,  0.2126],
        [-0.2139, -0.0608,  1.2723]])


In [9]:
# Batched Hadamard product: same idea as the 2-D case with a third index;
# 'ijk, ijk->ijk' multiplies matching entries of two (2, 3, 4) tensors.
x, y = torch.randn(2, 3, 4), torch.randn(2, 3, 4)
print(f'x: {x}')
print(f'y: {y}')
batch_hadamard = torch.einsum('ijk, ijk->ijk', x, y)
print(f"batch element-wise product: {batch_hadamard}")

x: tensor([[[-0.4088,  0.4388, -0.0113,  0.2040],
         [ 0.1149, -0.8467, -0.5215, -0.3189],
         [-0.8025,  0.6985,  1.9038,  0.4198]],

        [[-0.1224,  0.4112,  1.8690, -0.8694],
         [ 0.4843, -1.6301, -0.4105, -0.0347],
         [-0.3375,  0.4601, -2.0345,  0.4969]]])
y: tensor([[[ 0.9675, -1.9007,  0.5053,  0.8311],
         [-0.0776, -0.5151,  0.4124,  0.9237],
         [ 1.0182, -0.5842,  0.8982,  2.0471]],

        [[-0.2510, -0.1995,  0.6869, -0.1266],
         [ 0.2144,  0.9323, -1.3991, -1.6575],
         [-0.4247,  1.5779,  0.5201, -0.3391]]])
batch element-wise product: tensor([[[-0.3955, -0.8340, -0.0057,  0.1695],
         [-0.0089,  0.4361, -0.2151, -0.2945],
         [-0.8171, -0.4080,  1.7100,  0.8593]],

        [[ 0.0307, -0.0820,  1.2837,  0.1100],
         [ 0.1038, -1.5198,  0.5744,  0.0575],
         [ 0.1433,  0.7259, -1.0581, -0.1685]]])


In [10]:
# First row of the first batch of x, used to spot-check the product by hand.
x[0, 0]

tensor([-0.4088,  0.4388, -0.0113,  0.2040])

In [11]:
# Matching first row of the first batch of y.
y[0, 0]

tensor([ 0.9675, -1.9007,  0.5053,  0.8311])

In [12]:
# Element-wise product of the two rows; should equal the first row of the
# einsum result for batch 0.
x[0, 0] * y[0, 0]

tensor([-0.3955, -0.8340, -0.0057,  0.1695])

In [13]:
# Element-wise squaring: the Hadamard pattern 'ij, ij->ij' with the same tensor
# passed as both operands.
x = torch.randn(2, 3)
print(f'x: {x}')
squared = torch.einsum('ij, ij->ij', x, x)
print(f"element-wise squaring: {squared}")

x: tensor([[-0.8581,  0.7851, -2.1807],
        [ 0.5431,  1.8429, -0.4270]])
element-wise squaring: tensor([[0.7364, 0.6164, 4.7555],
        [0.2950, 3.3962, 0.1823]])


In [14]:
# Batched element-wise squaring: 'ijk, ijk->ijk' applied to a 3-D tensor and
# itself.
x = torch.randn(2, 3, 4)
print(f'x: {x}')
batch_squared = torch.einsum('ijk, ijk->ijk', x, x)
print(f"batch element-wise squaring of 3D: {batch_squared}")

x: tensor([[[-1.2866,  0.7288,  0.9549, -0.0288],
         [ 0.8938,  1.3885,  0.5825, -1.0367],
         [-0.0518,  1.2771, -0.2792, -0.4933]],

        [[-0.1655, -1.3628, -0.3494,  2.0258],
         [ 0.1322, -0.3116,  0.5265, -1.5448],
         [ 0.6981, -1.0780,  1.0337,  2.4459]]])
batch element-wise squaring of 3D: tensor([[[1.6553e+00, 5.3116e-01, 9.1193e-01, 8.2992e-04],
         [7.9893e-01, 1.9278e+00, 3.3927e-01, 1.0747e+00],
         [2.6847e-03, 1.6309e+00, 7.7980e-02, 2.4335e-01]],

        [[2.7405e-02, 1.8571e+00, 1.2207e-01, 4.1040e+00],
         [1.7477e-02, 9.7080e-02, 2.7716e-01, 2.3865e+00],
         [4.8736e-01, 1.1620e+00, 1.0685e+00, 5.9826e+00]]])


In [15]:
# Matrix multiplication: the shared index j in 'ij, jk->ik' is contracted, so
# each output entry [i, k] is the inner (dot) product of row i of x with
# column k of y. The result is a (2, 4) matrix, not a single scalar.
x, y = torch.randn(2, 3), torch.randn(3, 4)
print(f'x: {x}')
print(f'y: {y}')
matmul_out = torch.einsum('ij, jk->ik', x, y)
print(f"matrix multiplication/aka dot product/aka inner product: {matmul_out}")

x: tensor([[ 0.1936,  0.5547,  0.3681],
        [ 1.8390,  1.0020, -1.1445]])
y: tensor([[-1.3602,  0.2244, -0.3784, -0.3083],
        [-0.6085, -0.7188,  0.2970, -0.2583],
        [-0.0902, -0.5231,  0.3228,  0.9014]])
matrix multiplication/aka dot product/aka inner product: tensor([[-0.6341, -0.5478,  0.2103,  0.1289],
        [-3.0079,  0.2911, -0.7676, -1.8574]])


In [16]:
# Batched matrix multiplication: the batch index b is carried through unchanged
# while j is contracted, i.e. one (3, 4) @ (4, 5) product per batch element.
x, y = torch.randn(2, 3, 4), torch.randn(2, 4, 5)
print(f'x: {x}')
print(f'y: {y}')
bmm_out = torch.einsum('bij,bjk ->bik', x, y)
print(f"batch matrix multiplication/aka batch dot product: {bmm_out}")


x: tensor([[[ 0.1046,  1.1601, -1.3931, -0.3127],
         [ 1.4275, -1.3718,  0.6854, -1.2418],
         [ 0.1169,  0.0023,  0.8855,  0.0998]],

        [[-0.5069, -0.6111, -0.9057, -1.5333],
         [-0.0666,  1.0754, -1.2889, -1.4844],
         [-0.6905, -0.4721, -0.9851, -0.7915]]])
y: tensor([[[-0.9000, -0.6691, -1.6288, -0.4369,  0.1838],
         [ 1.0008,  0.2858,  0.3163,  0.0715,  1.7307],
         [-1.0459, -1.0507, -1.2698, -0.1862,  1.3261],
         [ 0.5438,  0.1990, -0.5217,  0.0895,  0.1649]],

        [[-0.3631, -0.1758, -0.2286, -0.9361, -1.3516],
         [-0.2695, -0.4870, -1.4853,  2.0657, -0.4406],
         [-0.1402,  0.3775,  1.2639, -0.9296, -0.2855],
         [ 1.2885,  0.0359,  0.6569, -0.1324, -0.6414]]])
batch matrix multiplication/aka batch dot product: tensor([[[ 2.3539,  1.6631,  2.1286,  0.2686,  0.1281],
         [-4.0497, -2.3144, -2.9815, -0.9605, -1.4077],
         [-0.9748, -0.9881, -1.3662, -0.2068,  1.2162]],

        [[-1.5000, -0.0102, -1.1284

In [17]:
# First row of the first batch matrix of x (left operand of the hand check).
x[0, 0]

tensor([ 0.1046,  1.1601, -1.3931, -0.3127])

In [18]:
# First column of the first batch matrix of y (right operand of the hand check).
y[0, :, 0]

tensor([-0.9000,  1.0008, -1.0459,  0.5438])

In [19]:
# Row-times-column sum; should equal entry [0, 0, 0] of the batched product.
(x[0, 0] * y[0, :, 0]).sum()

tensor(2.3539)

In [20]:
# Double dot / Frobenius inner product: multiply matching entries and sum over
# both indices ('ij,ij ->'), i.e. the sum of the Hadamard product.
x, y = torch.randn(2, 3), torch.randn(2, 3)
print(f'x: {x}')
print(f'y: {y}')
frobenius = torch.einsum('ij,ij ->', x, y)
print(f"double dot products: {frobenius}")

x: tensor([[-0.8592,  0.6566,  0.7164],
        [-1.0285, -1.2278, -0.5428]])
y: tensor([[-1.4377,  0.9829,  0.9472],
        [ 1.2654, -0.3706,  0.1436]])
double dot products: 1.6347986459732056


In [26]:
# Batched sum of element-wise products: 'bij,bij ->bj' multiplies matching
# entries, then sums over i (axis 1 of the 3-D inputs), keeping the batch
# index b and the last index j. Output shape is (2, 4).
x, y = torch.randn(2, 3, 4), torch.randn(2, 3, 4)
print(f'x: {x}')
print(f'y: {y}')
summed_over_i = torch.einsum('bij,bij ->bj', x, y)
print(f"Batch sum of element-wise product along the first dimension (axis=1): {summed_over_i}")

x: tensor([[[ 0.1832,  1.0843, -1.1946, -0.8681],
         [-0.9427,  0.2572, -0.1599, -0.6560],
         [-0.9601,  0.3200,  0.7864,  0.3208]],

        [[ 1.0330, -0.8597, -2.4755, -0.1428],
         [-0.8975, -0.8498, -1.0475,  0.1290],
         [ 1.7242, -0.0284, -0.0817, -1.3094]]])
y: tensor([[[ 0.0761,  0.4934,  1.0517,  0.3773],
         [-0.9937, -0.0217, -0.2790,  1.3316],
         [ 1.8066,  0.1890,  1.5630,  0.5897]],

        [[-0.1358,  0.2856,  0.5682,  0.6544],
         [-0.9042,  0.3156, -0.4115,  0.9386],
         [-0.2048,  0.7711,  1.0042, -1.0369]]])
Batch sum of element-wise product along the first dimension (axis=1): tensor([[-0.7839,  0.5898,  0.0174, -1.0119],
        [ 0.3182, -0.5356, -1.0577,  1.3853]])


In [27]:
# First batch matrix of x, shown for the hand check.
x[0, :, :]

tensor([[ 0.1832,  1.0843, -1.1946, -0.8681],
        [-0.9427,  0.2572, -0.1599, -0.6560],
        [-0.9601,  0.3200,  0.7864,  0.3208]])

In [28]:
# First batch matrix of y, shown for the hand check.
y[0, :, :]

tensor([[ 0.0761,  0.4934,  1.0517,  0.3773],
        [-0.9937, -0.0217, -0.2790,  1.3316],
        [ 1.8066,  0.1890,  1.5630,  0.5897]])

In [24]:
# Manual cross-check of the 'bij,bij ->bj' einsum for batch 0: multiply x[0]
# and y[0] element-wise, then sum over dim 0 of the 2-D slice (the `i` index).
manual_bj = torch.sum(torch.mul(x[0], y[0]), 0)
# Assert the equivalence here instead of comparing two separately printed
# outputs by eye; that comparison silently goes stale when cells are run out
# of order. A small atol covers float32 rounding differences between the paths.
expected_bj = torch.einsum('bij,bij ->bj', x, y)[0]
assert torch.allclose(manual_bj, expected_bj, atol=1e-6), (manual_bj, expected_bj)
manual_bj

tensor([-0.7839,  0.5898,  0.0174, -1.0119])