# Operations on tensors

In [2]:
import torch

Multiplying all elements of `x` by `10`:

In [3]:
x = torch.tensor([[1,2,3,4], [5,6,7,8]])
print(x * 10)

tensor([[10, 20, 30, 40],
        [50, 60, 70, 80]])
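Scalar multiplication is a special case of element-wise multiplication. The small sketch below checks that `x * 10` matches the functional form `torch.mul(x, 10)`. It also shows that `*` between two tensors of the same shape multiplies them element by element rather than computing a matrix product. The names `scaled`, `same` and `squared` are only illustrative.

```python
import torch

x = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]])

# Multiplying by a scalar scales every element
scaled = x * 10
# torch.mul is the functional form of the * operator
same = torch.mul(x, 10)
print(torch.equal(scaled, same))  # True

# * between two tensors of the same shape is element-wise, not a matrix product
squared = x * x
print(squared)
```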


Adding `10` to every element of `x`, either with the `add` method or with the `+` operator:

In [4]:
x = torch.tensor([[1,2,3,4], [5,6,7,8]])
y = x.add(10)
print(y)
z = x + 10
print(z)

tensor([[11, 12, 13, 14],
        [15, 16, 17, 18]])
tensor([[11, 12, 13, 14],
        [15, 16, 17, 18]])
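`add` and `+` both return a new tensor and leave `x` unchanged. PyTorch also offers in-place variants whose names end with an underscore, such as `add_`. A minimal sketch contrasting the two:

```python
import torch

x = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]])

# add() returns a new tensor and leaves x untouched
y = x.add(10)
print(x[0, 0].item(), y[0, 0].item())  # 1 11

# add_() (trailing underscore) modifies x in place
x.add_(10)
print(x[0, 0].item())  # 11
print(torch.equal(x, y))  # True
```

In-place operations overwrite the original values, so they are only appropriate when the old contents of `x` are no longer needed.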


Reshaping a tensor with `view`:

In [5]:
y = torch.tensor([2,3,1,0])
print(y, y.shape)
y = y.view(4,1)
print(y, y.shape)

tensor([2, 3, 1, 0]) torch.Size([4])
tensor([[2],
        [3],
        [1],
        [0]]) torch.Size([4, 1])
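When reshaping, one dimension can be passed as `-1`, and `view` infers it from the total number of elements. The sketch below also shows that the requested shape must hold exactly as many elements as the original tensor. The names `col` and `row` are only illustrative.

```python
import torch

y = torch.tensor([2, 3, 1, 0])

# -1 asks view to infer that dimension from the number of elements (4 here)
col = y.view(-1, 1)
print(col.shape)  # torch.Size([4, 1])

row = y.view(1, -1)
print(row.shape)  # torch.Size([1, 4])

# The element count must match: 4 elements cannot be viewed as (3, 1)
try:
    y.view(3, 1)
except RuntimeError as err:
    print("view failed:", err)
```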


Another way to reshape a tensor is the `squeeze` method, which takes the index of the axis we want to remove. This only applies when that axis has size 1. If it does not, `squeeze` leaves the tensor unchanged instead of removing the axis.

In [7]:
x = torch.randn(4, 1, 4)
z1 = torch.squeeze(x, 1)
# the same result can be obtained by calling the squeeze method directly on x
z2 = x.squeeze(1)

print(x)
print(z1)
print(z2)

tensor([[[-0.6210,  0.4516, -0.5400, -0.5010]],

        [[-0.4470,  2.0191, -0.1397, -0.4768]],

        [[ 0.0896,  0.2328, -1.2045,  0.2512]],

        [[-0.0093, -0.2077,  0.1938,  0.9376]]])
tensor([[-0.6210,  0.4516, -0.5400, -0.5010],
        [-0.4470,  2.0191, -0.1397, -0.4768],
        [ 0.0896,  0.2328, -1.2045,  0.2512],
        [-0.0093, -0.2077,  0.1938,  0.9376]])
tensor([[-0.6210,  0.4516, -0.5400, -0.5010],
        [-0.4470,  2.0191, -0.1397, -0.4768],
        [ 0.0896,  0.2328, -1.2045,  0.2512],
        [-0.0093, -0.2077,  0.1938,  0.9376]])
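A quick sketch of the size-1 rule for a tensor of the same `(4, 1, 4)` shape. Squeezing a dimension whose size is not 1 leaves the tensor unchanged, calling `squeeze()` with no argument removes every size-1 dimension, and `unsqueeze` inserts one back:

```python
import torch

x = torch.randn(4, 1, 4)

# Dimension 0 has size 4, so squeeze(0) returns the tensor unchanged
print(x.squeeze(0).shape)  # torch.Size([4, 1, 4])

# Without an argument, squeeze removes every dimension of size 1
print(x.squeeze().shape)  # torch.Size([4, 4])

# unsqueeze is the inverse: it inserts a size-1 dimension at the given index
z = x.squeeze(1)
print(z.unsqueeze(1).shape)  # torch.Size([4, 1, 4])
```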


Matrix multiplication of two tensors with `torch.matmul`. Here a `(2, 4)` matrix times a `(4, 1)` matrix gives a `(2, 1)` result:

In [9]:
x = torch.tensor([[1,2,3,4], [5,6,7,8]])
y = torch.tensor([[2], [3], [1], [0]])
print(torch.matmul(x, y))

tensor([[11],
        [35]])


Alternatively, matrix multiplication can be performed with the `@` operator:

In [10]:
print(x @ y)

tensor([[11],
        [35]])
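To see where `[[11], [35]]` comes from, the sketch below recomputes the product by hand. Each output entry is a row of `x` multiplied element-wise by the column of `y` and then summed. The inner dimensions must agree, so multiplying `x` by itself fails. The names `result` and `manual` are only illustrative.

```python
import torch

x = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]])  # shape (2, 4)
y = torch.tensor([[2], [3], [1], [0]])          # shape (4, 1)

# Inner dimensions must agree: (2, 4) @ (4, 1) -> (2, 1)
result = x @ y
print(result.shape)  # torch.Size([2, 1])

# Each entry is a row of x dotted with the column of y,
# e.g. 1*2 + 2*3 + 3*1 + 4*0 = 11
manual = (x * y.T).sum(dim=1, keepdim=True)
print(torch.equal(result, manual))  # True

# Mismatched inner dimensions raise an error: (2, 4) @ (2, 4)
try:
    x @ x
except RuntimeError as err:
    print("matmul failed:", err)
```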


Similar to `concatenate` in NumPy, we can concatenate tensors using the `cat` method. PyTorch accepts the NumPy-style `axis` keyword as an alias for its usual `dim` argument:

In [26]:
x = torch.tensor([
    [[1,2],[3,4]], 
    [[5,6],[7,8]]
])
print(x, x.shape)
z = torch.cat([x,x], axis=0)
print("\ncat axis=0\n", z, z.shape)
z = torch.cat([x,x], axis=1)
print("\ncat axis=1\n", z, z.shape)

tensor([[[1, 2],
         [3, 4]],

        [[5, 6],
         [7, 8]]]) torch.Size([2, 2, 2])

cat axis=0
 tensor([[[1, 2],
         [3, 4]],

        [[5, 6],
         [7, 8]],

        [[1, 2],
         [3, 4]],

        [[5, 6],
         [7, 8]]]) torch.Size([4, 2, 2])

cat axis=1
 tensor([[[1, 2],
         [3, 4],
         [1, 2],
         [3, 4]],

        [[5, 6],
         [7, 8],
         [5, 6],
         [7, 8]]]) torch.Size([2, 4, 2])
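The same idea extends to the last dimension, where `dim=-1` counts from the end. In general, every size except the one along the concatenation dimension must match. The sketch below uses hypothetical zero tensors `a` and `b` to show a valid and an invalid case:

```python
import torch

x = torch.tensor([
    [[1, 2], [3, 4]],
    [[5, 6], [7, 8]],
])

# Concatenating along the last dimension (dim=2 or, equivalently, dim=-1)
z = torch.cat([x, x], dim=-1)
print(z.shape)  # torch.Size([2, 2, 4])

# All sizes except the concatenation dimension must match
a = torch.zeros(2, 3)
b = torch.zeros(4, 3)
print(torch.cat([a, b], dim=0).shape)  # torch.Size([6, 3])
try:
    torch.cat([a, b], dim=1)  # dim 0 differs (2 vs 4), so this fails
except RuntimeError as err:
    print("cat failed:", err)
```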


Extracting the maximum value of a tensor:

In [28]:
x = torch.arange(25).reshape(5,5)
print('Max:', x.max())

Max: tensor(24)
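The minimum works the same way with `min`, and `argmax` returns the position of the maximum instead of its value. Without a `dim` argument, that position is an index into the flattened tensor. The sketch below converts it back to a (row, column) pair with Python's `divmod`. The names `flat_idx`, `row` and `col` are only illustrative.

```python
import torch

x = torch.arange(25).reshape(5, 5)

print('Min:', x.min())        # Min: tensor(0)
# argmax without dim returns the index into the flattened tensor
flat_idx = x.argmax()
print('Argmax:', flat_idx)    # Argmax: tensor(24)

# Convert the flat index back to (row, column)
row, col = divmod(flat_idx.item(), x.shape[1])
print(row, col)  # 4 4
```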


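Passing a `dim` argument makes `max` return two tensors: the maximum values along that dimension and the indices where they occur. With `dim=0` we get the maximum of each column together with its row index, and with `dim=1` the maximum of each row together with its column index: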

In [37]:
print(x)
print(x.max(dim=0))

m, argm = x.max(dim=1)
print("\nmax in axis 1:\n", m, argm)

tensor([[ 0,  1,  2,  3,  4],
        [ 5,  6,  7,  8,  9],
        [10, 11, 12, 13, 14],
        [15, 16, 17, 18, 19],
        [20, 21, 22, 23, 24]])
torch.return_types.max(
values=tensor([20, 21, 22, 23, 24]),
indices=tensor([4, 4, 4, 4, 4]))

max in axis 1:
 tensor([ 4,  9, 14, 19, 24]) tensor([4, 4, 4, 4, 4])
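The returned indices can be used to index back into the tensor. Below is a minimal sketch for `dim=0`, where pairing the row indices with the column numbers recovers the maximum values. It is followed by `keepdim=True`, which keeps the reduced dimension with size 1. The names `values`, `indices` and `cols` are only illustrative.

```python
import torch

x = torch.arange(25).reshape(5, 5)

# Along dim=0 (down each column): indices are row positions
values, indices = x.max(dim=0)
cols = torch.arange(x.shape[1])
print(torch.equal(x[indices, cols], values))  # True

# keepdim=True keeps the reduced dimension with size 1
m = x.max(dim=1, keepdim=True).values
print(m.shape)  # torch.Size([5, 1])
```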


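Permuting the dimensions with `permute`, which takes the new order as a list of the old dimension indices (here the old dimension 2 becomes the first one):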

In [40]:
x = torch.tensor([[[1,2,3],[4,5,6]]])
print(x, x.shape)
z = x.permute(2, 0, 1)
print(z, z.shape)

tensor([[[1, 2, 3],
         [4, 5, 6]]]) torch.Size([1, 2, 3])
tensor([[[1, 4]],

        [[2, 5]],

        [[3, 6]]]) torch.Size([3, 1, 2])
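`permute(2, 0, 1)` means that new dimension 0 is old dimension 2, new dimension 1 is old dimension 0, and new dimension 2 is old dimension 1. The brute-force check below confirms the resulting index mapping `z[a, b, c] == x[b, c, a]`. It also shows `transpose`, which swaps just two dimensions.

```python
import torch

x = torch.tensor([[[1, 2, 3], [4, 5, 6]]])  # shape (1, 2, 3)
z = x.permute(2, 0, 1)                     # shape (3, 1, 2)

# New dim 0 is old dim 2, new dim 1 is old dim 0, new dim 2 is old dim 1,
# so z[a, b, c] == x[b, c, a] for every valid index
ok = all(
    z[a, b, c] == x[b, c, a]
    for a in range(z.shape[0])
    for b in range(z.shape[1])
    for c in range(z.shape[2])
)
print(ok)  # True

# To swap only two dimensions, transpose is enough
print(x.transpose(1, 2).shape)  # torch.Size([1, 3, 2])
```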


**PRO TIP**

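Never reshape a tensor (that is, call `tensor.view` on it) to swap its dimensions.
Even though Torch will not throw an error, this is *wrong*: `view` keeps the elements in their original memory order and only changes the shape. The result has the expected shape, but its elements end up in the wrong positions, producing unforeseen results during training.
If you need to swap dimensions, *always* use `permute`.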
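A small example makes the difference concrete. On a `(2, 3)` tensor, `permute(1, 0)` and `view(3, 2)` produce the same shape but different contents. The names `swapped` and `reshaped` are only illustrative.

```python
import torch

x = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])  # shape (2, 3)

# permute really swaps the axes: row i becomes column i
swapped = x.permute(1, 0)
print(swapped)
# tensor([[1, 4],
#         [2, 5],
#         [3, 6]])

# view only re-reads the same elements in memory order with a new shape
reshaped = x.view(3, 2)
print(reshaped)
# tensor([[1, 2],
#         [3, 4],
#         [5, 6]])

print(swapped.shape == reshaped.shape)  # True: same shape...
print(torch.equal(swapped, reshaped))   # False: ...different contents
```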

There are many more operations.

You can always run `dir(torch.Tensor)` to list all the methods available on a Torch tensor, and `help(torch.Tensor.<method>)` to read the official documentation for a specific method.
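For example, the public method names can be filtered out of `dir` with a list comprehension that skips names starting with an underscore. `help` then prints the docstring of a single method. The name `methods` is only illustrative.

```python
import torch

# Public method and attribute names of torch.Tensor (skip the _private ones)
methods = [name for name in dir(torch.Tensor) if not name.startswith('_')]
print(len(methods))
print([m for m in methods if m.startswith('perm')])

# Print the docstring of a single method
help(torch.Tensor.permute)
```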