In [1]:
# In this notebook, you will learn:
#
# 1) What does torch.unsqueeze do?
# 2) What does torch.nn.functional.pad do?

In [1]:
import torch

## [torch.unsqueeze](https://pytorch.org/docs/stable/generated/torch.unsqueeze.html)

In [2]:
# unsqueeze inserts a new dimension of size 1 at the given position 'dim'. Let's think of a tensor as a
# container of smaller tensors. Using dim=2 with unsqueeze means we go inside 2 levels of containers
# (i.e., we traverse dimensions 0 and 1) and wrap every tensor we find at that level in an extra
# container of size 1. The result therefore has exactly one more dimension than the input.

In [3]:
# A 3-D integer tensor of shape (2, 3, 3): an outer container holding two 3x3 matrices.
# The literal is laid out so that each level of nesting is visible.
t1 = torch.tensor(
    data=[
        [[1, 2, 3], [4, 5, 6], [7, 8, 9]],
        [[10, 11, 12], [13, 14, 15], [16, 17, 18]],
    ]
)
print("shape: ", t1.shape)
print("t1: ", t1)

shape:  torch.Size([2, 3, 3])
t1:  tensor([[[ 1,  2,  3],
         [ 4,  5,  6],
         [ 7,  8,  9]],

        [[10, 11, 12],
         [13, 14, 15],
         [16, 17, 18]]])


In [6]:
# dim=0: no levels are traversed, so the whole of 't1' is wrapped in one extra outer container.
# A new leading dimension of size 1 appears: (2, 3, 3) -> (1, 2, 3, 3).
t2 = t1.unsqueeze(dim=0)
print("shape: ", t2.shape)
print("t2: ", t2)

shape:  torch.Size([1, 2, 3, 3])
t2:  tensor([[[[ 1,  2,  3],
          [ 4,  5,  6],
          [ 7,  8,  9]],

         [[10, 11, 12],
          [13, 14, 15],
          [16, 17, 18]]]])


In [7]:
# dim=1: step 1 level (1 container) into 't1' to reach its two 3x3 matrices
# [[1, 2, 3], [4, 5, 6], [7, 8, 9]] and [[10, 11, 12], [13, 14, 15], [16, 17, 18]].
# Each matrix of shape (3, 3) is wrapped in its own container and becomes shape (1, 3, 3).
# Result: a 4-D tensor of shape (2, 1, 3, 3), i.e. 2 '3-D' tensors of shape (1, 3, 3).
t3 = t1.unsqueeze(dim=1)
print("shape: ", t3.shape)
print("t3: ", t3)

shape:  torch.Size([2, 1, 3, 3])
t3:  tensor([[[[ 1,  2,  3],
          [ 4,  5,  6],
          [ 7,  8,  9]]],


        [[[10, 11, 12],
          [13, 14, 15],
          [16, 17, 18]]]])


In [8]:
# dim=2: traverse 2 levels (2 containers, i.e. dimensions 0 and 1) into 't1'. That reaches the six '1-D'
# rows [1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12], [13, 14, 15] and [16, 17, 18]. Each row of shape (3,)
# is put inside another container and becomes shape (1, 3). So, finally we get a '4-D' tensor of shape
# (2, 3, 1, 3), i.e. 2 '3-D' tensors, each of shape (3, 1, 3).
t4 = torch.unsqueeze(input=t1, dim=2)
print("shape: ", t4.shape)
print("t4: ", t4)

shape:  torch.Size([2, 3, 1, 3])
t4:  tensor([[[[ 1,  2,  3]],

         [[ 4,  5,  6]],

         [[ 7,  8,  9]]],


        [[[10, 11, 12]],

         [[13, 14, 15]],

         [[16, 17, 18]]]])


In [16]:
# dim=3: traverse 3 levels (3 containers, i.e. dimensions 0, 1 and 2) into 't1'. That reaches the 18
# individual numbers 1, 2, ..., 18. Each number is put inside its own container, becoming a tensor of
# shape (1,). So, finally we get a '4-D' tensor of shape (2, 3, 3, 1).
# Because 't1' is 3-D, dim=3 is the largest valid position: it appends a trailing dimension of size 1.
t5 = torch.unsqueeze(input=t1, dim=3)
print(t5, t5.shape)

tensor([[[[ 1],
          [ 2],
          [ 3]],

         [[ 4],
          [ 5],
          [ 6]],

         [[ 7],
          [ 8],
          [ 9]]],


        [[[10],
          [11],
          [12]],

         [[13],
          [14],
          [15]],

         [[16],
          [17],
          [18]]]]) torch.Size([2, 3, 3, 1])


## [torch.nn.functional.pad](https://pytorch.org/docs/stable/generated/torch.nn.functional.pad.html#torch.nn.functional.pad)

In [None]:
# Pads the given 'input' tensor with the provided 'value' (when mode="constant").
# The argument pad=(3, 2) adds 3 values at the start and 2 values at the end of every tensor along the last dimension.
# So, with value=2.5, the float tensor [10., 20.] padded using pad=(3, 2) turns into [2.5, 2.5, 2.5, 10.0, 20.0, 2.5, 2.5].
# The size of the last dimension increases by 3 + 2 = 5.
#
# In general, the 'pad' argument has the following form:
#
# (padding_left, padding_right) to pad only the last dimension of the input tensor.
# (padding_left, padding_right, padding_top, padding_bottom) to pad the last 2 dimensions of the input tensor.
# (padding_left, padding_right, padding_top, padding_bottom, padding_front, padding_back) to pad the last 3
#       dimensions of the input tensor.
#
# Similarly, extend the logic to higher dimensions: each additional (before, after) pair pads the next dimension,
# counting backwards from the last one. So len(pad) must be even and at most 2 * input.dim().
#

In [7]:
# A 3-D tensor of shape (3, 2, 3): three 2x3 matrices.
# The dtype is spelled out as torch.float64 (what the Python builtin `float` maps to) rather than
# relying on that implicit mapping. A floating-point dtype is needed so the non-integer padding
# value 2.5 used in the cells below can be stored in the padded result.
t6 = torch.tensor(
    [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]], [[13, 14, 15], [16, 17, 18]]],
    dtype=torch.float64,
)
print(t6, t6.shape)


tensor([[[ 1.,  2.,  3.],
         [ 4.,  5.,  6.]],

        [[ 7.,  8.,  9.],
         [10., 11., 12.]],

        [[13., 14., 15.],
         [16., 17., 18.]]], dtype=torch.float64) torch.Size([3, 2, 3])


In [6]:
# A 2-element pad only touches the last dimension: every row gets three 2.5s in front and two 2.5s
# behind, so the last dimension grows from 3 to 3 + 3 + 2 = 8 while the other dimensions are unchanged.
pad_last_dim = (3, 2)  # (padding_left, padding_right)
t7 = torch.nn.functional.pad(input=t6, pad=pad_last_dim, mode="constant", value=2.5)
print(t7, t7.shape)

tensor([[[ 2.5000,  2.5000,  2.5000,  1.0000,  2.0000,  3.0000,  2.5000,
           2.5000],
         [ 2.5000,  2.5000,  2.5000,  4.0000,  5.0000,  6.0000,  2.5000,
           2.5000]],

        [[ 2.5000,  2.5000,  2.5000,  7.0000,  8.0000,  9.0000,  2.5000,
           2.5000],
         [ 2.5000,  2.5000,  2.5000, 10.0000, 11.0000, 12.0000,  2.5000,
           2.5000]],

        [[ 2.5000,  2.5000,  2.5000, 13.0000, 14.0000, 15.0000,  2.5000,
           2.5000],
         [ 2.5000,  2.5000,  2.5000, 16.0000, 17.0000, 18.0000,  2.5000,
           2.5000]]], dtype=torch.float64) torch.Size([3, 2, 8])


In [8]:
# A 4-element pad touches the last two dimensions. The last dimension gets 3 values on the left and 2 on
# the right (as before). The second-to-last dimension gets 1 new row on top and 1 at the bottom.
# So every (2, 3) matrix becomes (4, 8): two new 1-D rows filled with 2.5 are added around it.
pad_last_two_dims = (3, 2, 1, 1)  # (padding_left, padding_right, padding_top, padding_bottom)
t8 = torch.nn.functional.pad(input=t6, pad=pad_last_two_dims, mode="constant", value=2.5)
print(t8, t8.shape)

tensor([[[ 2.5000,  2.5000,  2.5000,  2.5000,  2.5000,  2.5000,  2.5000,
           2.5000],
         [ 2.5000,  2.5000,  2.5000,  1.0000,  2.0000,  3.0000,  2.5000,
           2.5000],
         [ 2.5000,  2.5000,  2.5000,  4.0000,  5.0000,  6.0000,  2.5000,
           2.5000],
         [ 2.5000,  2.5000,  2.5000,  2.5000,  2.5000,  2.5000,  2.5000,
           2.5000]],

        [[ 2.5000,  2.5000,  2.5000,  2.5000,  2.5000,  2.5000,  2.5000,
           2.5000],
         [ 2.5000,  2.5000,  2.5000,  7.0000,  8.0000,  9.0000,  2.5000,
           2.5000],
         [ 2.5000,  2.5000,  2.5000, 10.0000, 11.0000, 12.0000,  2.5000,
           2.5000],
         [ 2.5000,  2.5000,  2.5000,  2.5000,  2.5000,  2.5000,  2.5000,
           2.5000]],

        [[ 2.5000,  2.5000,  2.5000,  2.5000,  2.5000,  2.5000,  2.5000,
           2.5000],
         [ 2.5000,  2.5000,  2.5000, 13.0000, 14.0000, 15.0000,  2.5000,
           2.5000],
         [ 2.5000,  2.5000,  2.5000, 16.0000, 17.0000, 18.0000,  2