In [147]:
import torch


In [148]:
# Trace of a square matrix: the repeated subscript 'ii' walks the main
# diagonal, and the empty output side sums those entries into a scalar.
x = torch.randn((2, 2))
print(
    f'x: {x}',
    f"sum of main-diagonal elements (aka trace): {torch.einsum('ii->', x)}",
    sep='\n',
)

x: tensor([[ 1.4554,  0.6291],
        [-0.4008,  0.3982]])
sum of main-diagonal elements (aka trace): 1.8536256551742554


In [149]:
# Extract the main diagonal: 'ii' selects the diagonal entries and keeping
# 'i' on the output side returns them as a vector instead of summing them.
x = torch.randn((2, 2))
print(
    f'x: {x}',
    f"extract elements along the main-diagonal: {torch.einsum('ii->i', x)}",
    sep='\n',
)

x: tensor([[ 0.7136, -1.7813],
        [-0.2499, -0.7382]])
extract elements along the main-diagonal: tensor([ 0.7136, -0.7382])


In [150]:
# Sum along one axis: 'j' is missing from the output side, so it is summed
# away, leaving one total per row (index i).
x = torch.randn((2, 3))
print(
    f'x: {x}',
    f"summations along axis1: {torch.einsum('ij->i', x)}",
    sep='\n',
)

x: tensor([[ 0.5722, -0.9945,  1.2906],
        [ 0.5316, -1.8923, -0.9052]])
summations along axis1: tensor([ 0.8682, -2.2659])


In [151]:
# Sum every element: no index survives on the output side, so both i and j
# are summed and the result is a scalar.
x = torch.randn((2, 3))
print(
    f'x: {x}',
    f"sum all elements: {torch.einsum('ij->', x)}",
    sep='\n',
)

x: tensor([[-0.5291, -0.2948, -2.1050],
        [-1.0407, -0.9205,  0.6792]])
sum all elements: -4.210868835449219


In [152]:
# Transpose / axis permutation: reordering the output subscripts ('ij' -> 'ji')
# reorders the axes without summing anything.
x = torch.randn((2, 3))
print(
    f'x: {x}',
    f" matrix transpose: {torch.einsum('ij->ji', x)}",
    sep='\n',
)

x: tensor([[ 1.0318,  1.4415,  1.4365],
        [-1.0435,  1.1437, -0.4630]])
 matrix transpose: tensor([[ 1.0318, -1.0435],
        [ 1.4415,  1.1437],
        [ 1.4365, -0.4630]])


In [153]:
# Hadamard (element-wise) product: both operands and the output share the
# subscripts 'ij', so entries are multiplied pairwise and nothing is summed.
x, y = torch.randn((2, 3)), torch.randn((2, 3))
print(
    f'x: {x}',
    f'y: {y}',
    f"element-wise product: {torch.einsum('ij, ij->ij', x, y)}",
    sep='\n',
)

x: tensor([[ 0.9742, -0.4430, -1.8800],
        [ 0.5570, -1.7751,  0.2413]])
y: tensor([[ 1.9212,  0.9515,  0.2009],
        [-2.9630,  0.7491, -0.5690]])
element-wise product: tensor([[ 1.8716, -0.4215, -0.3777],
        [-1.6505, -1.3298, -0.1373]])


In [154]:
# Batch Hadamard (element-wise) product on 3-D tensors: all three subscripts
# 'ijk' are kept, so the output has the same shape as the inputs.
# x and y stay in scope for the manual spot-check in the following cells.
x, y = torch.randn((2, 3, 4)), torch.randn((2, 3, 4))
print(
    f'x: {x}',
    f'y: {y}',
    f"batch element-wise product: {torch.einsum('ijk, ijk->ijk', x, y)}",
    sep='\n',
)

x: tensor([[[ 1.0526,  1.8596,  0.1192,  1.6131],
         [-0.3572,  0.1965,  1.4893, -0.2191],
         [ 0.9539,  1.4574,  0.0053,  0.5452]],

        [[ 0.0578,  0.4697,  1.8059, -0.4695],
         [-0.7423, -0.7920, -0.0337, -0.6462],
         [-0.6740,  0.8839, -0.0392,  0.1205]]])
y: tensor([[[-0.2666,  0.8063,  0.5379,  0.2846],
         [ 1.2565, -0.5064, -0.4661,  1.6448],
         [ 0.2930, -0.4391, -0.1996,  0.8416]],

        [[ 0.1985,  0.4576,  0.2440,  1.2956],
         [-0.7262,  0.2235,  0.5902,  2.4531],
         [-0.0584,  0.1470,  0.6056,  0.7997]]])
batch element-wise product: tensor([[[-2.8066e-01,  1.4994e+00,  6.4118e-02,  4.5915e-01],
         [-4.4878e-01, -9.9515e-02, -6.9407e-01, -3.6037e-01],
         [ 2.7951e-01, -6.3991e-01, -1.0505e-03,  4.5885e-01]],

        [[ 1.1468e-02,  2.1495e-01,  4.4055e-01, -6.0824e-01],
         [ 5.3908e-01, -1.7698e-01, -1.9891e-02, -1.5853e+00],
         [ 3.9338e-02,  1.2992e-01, -2.3750e-02,  9.6343e-02]]])


In [155]:
# First row of the first batch element of x (input to the manual check below).
x[0, 0]

tensor([1.0526, 1.8596, 0.1192, 1.6131])

In [156]:
# Matching first row of the first batch element of y.
y[0, 0]

tensor([-0.2666,  0.8063,  0.5379,  0.2846])

In [157]:
# Manual element-wise product of the two rows; it should match the first row
# of the einsum 'ijk, ijk->ijk' result printed above.
x[0, 0] * y[0, 0]

tensor([-0.2807,  1.4994,  0.0641,  0.4591])

In [158]:
# Element-wise squaring: the Hadamard product of x with itself.
x = torch.randn((2, 3))
print(
    f'x: {x}',
    f"element-wise squaring: {torch.einsum('ij, ij->ij', x, x)}",
    sep='\n',
)

x: tensor([[-0.6970,  0.5416, -2.3438],
        [-0.7064, -0.4908,  0.2314]])
element-wise squaring: tensor([[0.4858, 0.2933, 5.4933],
        [0.4991, 0.2409, 0.0536]])


In [159]:
# Batch element-wise squaring: the same self-product, on a 3-D tensor.
x = torch.randn((2, 3, 4))
print(
    f'x: {x}',
    f"batch element-wise squaring of 3D: {torch.einsum('ijk, ijk->ijk', x, x)}",
    sep='\n',
)

x: tensor([[[-0.6483, -0.2203,  1.7334, -0.1931],
         [ 1.6125,  0.2170,  0.4474,  0.6518],
         [ 0.5514,  0.7976,  0.8095, -0.1708]],

        [[-0.2352, -0.6071,  0.0503, -0.0709],
         [-0.0826, -0.5198, -1.7337, -0.1842],
         [ 2.0146, -0.2615,  0.9516, -0.3381]]])
batch element-wise squaring of 3D: tensor([[[4.2026e-01, 4.8525e-02, 3.0045e+00, 3.7298e-02],
         [2.6000e+00, 4.7094e-02, 2.0021e-01, 4.2487e-01],
         [3.0400e-01, 6.3618e-01, 6.5530e-01, 2.9176e-02]],

        [[5.5330e-02, 3.6859e-01, 2.5348e-03, 5.0274e-03],
         [6.8249e-03, 2.7020e-01, 3.0058e+00, 3.3919e-02],
         [4.0586e+00, 6.8387e-02, 9.0555e-01, 1.1428e-01]]])


In [160]:
# Matrix multiplication: the shared subscript 'j' is absent from the output,
# so it is contracted. Each output entry [i, k] is the dot product of row i
# of x with column k of y (the result is a matrix, not a single scalar).
x, y = torch.randn((2, 3)), torch.randn((3, 4))
print(
    f'x: {x}',
    f'y: {y}',
    f"matrix multiplication/aka dot product/aka inner product: {torch.einsum('ij, jk->ik', x, y)}",
    sep='\n',
)

x: tensor([[ 1.1797, -0.5684,  0.1082],
        [ 0.8384, -1.0231, -1.7403]])
y: tensor([[ 0.6265, -1.0299, -1.3438, -0.5172],
        [ 0.1588, -1.0652, -0.8883, -1.4499],
        [-1.8385, -1.2262,  1.2874, -0.3961]])
matrix multiplication/aka dot product/aka inner product: tensor([[ 0.4497, -0.7422, -0.9409,  0.1712],
        [ 3.5623,  2.3602, -2.4582,  1.7391]])


In [161]:
# Batch matrix multiplication: 'b' is kept as the batch index while the shared
# subscript 'j' is contracted, giving one (i, k) matrix product per batch entry.
# x and y stay in scope for the manual spot-check in the following cells.
x, y = torch.randn((2, 3, 4)), torch.randn((2, 4, 5))
print(
    f'x: {x}',
    f'y: {y}',
    f"batch matrix multiplication/aka batch dot product: {torch.einsum('bij,bjk ->bik', x, y)}",
    sep='\n',
)


x: tensor([[[ 0.1515, -0.3496,  0.7650,  0.6040],
         [ 0.8685, -1.4364,  0.0914,  0.1212],
         [ 0.3453, -0.2340, -0.3260,  0.5720]],

        [[-2.8902,  0.6418, -0.0195,  1.2386],
         [ 1.8281, -0.2336, -0.3397,  0.0692],
         [-0.4830,  0.0451,  0.5253,  0.0101]]])
y: tensor([[[ 0.1877, -1.3799, -0.2127, -0.6544, -0.0269],
         [ 0.7993,  0.2086, -0.4630, -1.4449,  0.1118],
         [ 0.7959, -0.0929, -0.0520, -0.0802, -1.0885],
         [-0.3624,  0.3412,  0.6097,  0.0980, -0.2581]],

        [[ 1.3594,  1.0189,  0.3122, -0.9438, -0.8878],
         [ 0.6475, -0.9310, -0.0697, -0.9962,  0.6202],
         [-1.4462, -0.1789, -1.3859, -0.1202,  0.8754],
         [ 0.1682,  1.8850, -0.7201, -1.2129,  0.2301]]])
batch matrix multiplication/aka batch dot product: tensor([[[ 0.1390, -0.1470,  0.4581,  0.4038, -1.0317],
         [-0.9563, -1.4653,  0.5495,  1.5116, -0.3147],
         [-0.5890, -0.2998,  0.4006,  0.1943,  0.1718]],

        [[-3.2770, -1.2042, -1.8120

In [162]:
# First row of the first batch matrix of x.
x[0, 0]

tensor([ 0.1515, -0.3496,  0.7650,  0.6040])

In [163]:
# First column of the first batch matrix of y (row 0 of its transpose).
y[0, :, 0]

tensor([ 0.1877,  0.7993,  0.7959, -0.3624])

In [164]:
# Dot product of that row and column; it should match entry [0, 0, 0] of the
# einsum 'bij,bjk ->bik' result printed above.
(x[0, 0] * y[0, :, 0]).sum()

tensor(0.1390)

In [165]:
# Double dot product / Frobenius inner product: multiply element-wise, then sum
# over both indices (equivalent to torch.sum of the Hadamard product).
x, y = torch.randn((2, 3)), torch.randn((2, 3))
print(
    f'x: {x}',
    f'y: {y}',
    f"double dot products: {torch.einsum('ij,ij ->', x, y)}",
    sep='\n',
)

x: tensor([[-1.0790, -0.1838, -1.1243],
        [ 0.5787, -0.2924,  1.6458]])
y: tensor([[ 0.4051, -1.7559,  0.0279],
        [-0.0902, -0.5370, -0.5490]])
double dot products: -0.9443585872650146


In [166]:
# Batch sum of element-wise product: multiply x and y entry by entry, then sum
# over index i (dim 1) while keeping batch index b and index j, so each batch
# element yields one vector of per-j sums.
# x and y stay in scope for the manual spot-check in the following cells.
x = torch.randn(2,3,4)
y = torch.randn(2,3,4)
print(f'x: {x}')
print(f'y: {y}')
# The label was previously empty, so the result printed as a bare ": tensor(...)";
# it now carries a label like every other cell.
print(f"batch sum of element-wise product: {torch.einsum('bij,bij ->bj', x, y)}")

x: tensor([[[ 0.8471,  0.6298, -1.0686,  0.6563],
         [-0.9393,  0.0840,  0.2303,  0.4052],
         [ 1.1011, -0.8386, -1.3216,  0.9700]],

        [[ 1.7534, -0.6420,  0.8794, -0.1019],
         [-0.0932,  1.1071,  0.2051,  0.8260],
         [ 0.3092,  0.2295, -1.9509, -1.1938]]])
y: tensor([[[ 1.7827, -0.4926,  1.2277,  0.5379],
         [-0.1553, -0.6410, -0.9363, -1.0946],
         [-2.5254, -1.0114,  0.9536, -0.2253]],

        [[-0.5575,  0.4838,  0.5891,  0.1379],
         [-0.3922, -0.6975, -0.9899,  2.2213],
         [-0.7167, -0.0087,  0.0480, -0.4166]]])
: tensor([[-1.1249,  0.4841, -2.7879, -0.3091],
        [-1.1626, -1.0849,  0.2213,  2.3181]])


In [167]:
# First batch element of x (a 3x4 matrix), shown for the manual check below.
x[0]

tensor([[ 0.8471,  0.6298, -1.0686,  0.6563],
        [-0.9393,  0.0840,  0.2303,  0.4052],
        [ 1.1011, -0.8386, -1.3216,  0.9700]])

In [168]:
# First batch element of y (a 3x4 matrix), shown for the manual check below.
y[0]

tensor([[ 1.7827, -0.4926,  1.2277,  0.5379],
        [-0.1553, -0.6410, -0.9363, -1.0946],
        [-2.5254, -1.0114,  0.9536, -0.2253]])

In [169]:
# Manual check for the first batch element: element-wise product summed over
# dim 0 (index i); it should match the first row of the 'bij,bij ->bj' result.
(x[0] * y[0]).sum(dim=0)

tensor([-1.1249,  0.4841, -2.7879, -0.3091])