In [None]:
# In this notebook, you learn:
#
# 1) How torch.triu works
# 2) How torch.Tensor.masked_fill works
# 3) How torch.index_select works

In [1]:
import torch

## [torch.triu](https://pytorch.org/docs/stable/generated/torch.triu.html#torch.triu)

In [3]:
# torch.triu zeroes out the elements below the main diagonal of a matrix and returns an upper triangular
# matrix. It also works for rectangular matrices.

In [5]:
# Fix the random seed so that re-running the notebook reproduces the same random tensors.
SEED = 0
torch.manual_seed(SEED)

# A 5x5 matrix of standard-normal samples that the triu examples below operate on.
t1 = torch.randn(5, 5)
print("shape: ", t1.shape)
print("t1: ", t1)

shape:  torch.Size([5, 5])
t1:  tensor([[-1.6374,  2.5723, -1.1342, -0.2347,  0.3572],
        [ 0.2100, -1.4732, -0.0893, -0.2497, -0.6651],
        [ 0.9017,  0.8352,  1.8637, -1.3146,  1.2185],
        [-2.2434, -0.5520, -1.1878, -0.3896,  0.0609],
        [-0.9350, -0.1111,  1.6628, -1.4316, -0.0488]])


In [6]:
# With the default diagonal=0, every element below the main diagonal is replaced by zero and an upper
# triangular matrix is returned. The main diagonal itself is left untouched.
t2 = t1.triu(diagonal=0)
print("shape: ", t2.size())
print("t2: ", t2)

shape:  torch.Size([5, 5])
t2:  tensor([[-1.6374,  2.5723, -1.1342, -0.2347,  0.3572],
        [ 0.0000, -1.4732, -0.0893, -0.2497, -0.6651],
        [ 0.0000,  0.0000,  1.8637, -1.3146,  1.2185],
        [ 0.0000,  0.0000,  0.0000, -0.3896,  0.0609],
        [ 0.0000,  0.0000,  0.0000,  0.0000, -0.0488]])


In [8]:
# The diagonal parameter selects which diagonal forms the boundary. diagonal=0 is the main diagonal
# (row index == column index), diagonal=1 is the diagonal just above it, diagonal=2 the one above that, and
# so on. diagonal=-1 is the diagonal just below the main diagonal, diagonal=-2 the one below that, and so on.
# torch.triu keeps every element on or above the chosen diagonal (column index - row index >= diagonal) and
# sets everything below it to zero.
#
# So a positive value also zeroes out diagonals on and above the main diagonal: with diagonal=k (k > 0), the
# main diagonal and the k - 1 diagonals just above it are zeroed in addition to the lower part. A negative
# value zeroes out fewer elements: with diagonal=-1, only the elements strictly below the diagonal
# referenced by -1 are zeroed, and that diagonal itself is kept.

In [7]:
# diagonal=1 moves the boundary one diagonal up, so the elements on the main diagonal are zeroed out as well.
t3 = t1.triu(diagonal=1)
print("shape: ", t3.size())
print("t3: ", t3)

shape:  torch.Size([5, 5])
t3:  tensor([[ 0.0000,  2.5723, -1.1342, -0.2347,  0.3572],
        [ 0.0000,  0.0000, -0.0893, -0.2497, -0.6651],
        [ 0.0000,  0.0000,  0.0000, -1.3146,  1.2185],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0609],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000]])


In [9]:
# With diagonal=2, only the elements on and above the second diagonal above the main one are kept
# (column index - row index >= 2). The main diagonal and the diagonal right above it are zeroed out too.
t4 = torch.triu(input=t1, diagonal=2)
# The diagonal selected by diagonal=2 is itself kept unchanged.
assert torch.equal(t4.diagonal(offset=2), t1.diagonal(offset=2))
print("shape: ", t4.shape)
print("t4: ", t4)

shape:  torch.Size([5, 5])
t4:  tensor([[ 0.0000,  0.0000, -1.1342, -0.2347,  0.3572],
        [ 0.0000,  0.0000,  0.0000, -0.2497, -0.6651],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  1.2185],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000]])


In [10]:
# With diagonal=-1, the boundary is the diagonal just below the main one. Only the elements strictly below
# that diagonal are set to zeros; the diagonal referenced by -1 itself is kept.
t5 = torch.triu(input=t1, diagonal=-1)
# The diagonal referenced by -1 survives unchanged.
assert torch.equal(t5.diagonal(offset=-1), t1.diagonal(offset=-1))
print("shape: ", t5.shape)
print("t5: ", t5)

shape:  torch.Size([5, 5])
t5:  tensor([[-1.6374,  2.5723, -1.1342, -0.2347,  0.3572],
        [ 0.2100, -1.4732, -0.0893, -0.2497, -0.6651],
        [ 0.0000,  0.8352,  1.8637, -1.3146,  1.2185],
        [ 0.0000,  0.0000, -1.1878, -0.3896,  0.0609],
        [ 0.0000,  0.0000,  0.0000, -1.4316, -0.0488]])


In [11]:
# With diagonal=-2, the boundary is the second diagonal below the main one. Only the elements strictly below
# that diagonal are set to zeros; the diagonal referenced by -2 (and everything above it) is kept.
t6 = torch.triu(input=t1, diagonal=-2)
# The diagonal referenced by -2 survives unchanged.
assert torch.equal(t6.diagonal(offset=-2), t1.diagonal(offset=-2))
print("shape: ", t6.shape)
print("t6: ", t6)

shape:  torch.Size([5, 5])
t6:  tensor([[-1.6374,  2.5723, -1.1342, -0.2347,  0.3572],
        [ 0.2100, -1.4732, -0.0893, -0.2497, -0.6651],
        [ 0.9017,  0.8352,  1.8637, -1.3146,  1.2185],
        [ 0.0000, -0.5520, -1.1878, -0.3896,  0.0609],
        [ 0.0000,  0.0000,  1.6628, -1.4316, -0.0488]])


In [12]:
# Now let's look at how torch.triu works for rectangular matrices. The main diagonal is still made of all
# the elements whose row index equals their column index.
# NOTE: this rebinds t6, which previously held torch.triu(t1, diagonal=-2).
t6 = torch.randn(size=(6, 8))
print("shape: ", t6.size())
print("t6: ", t6)

shape:  torch.Size([6, 8])
t6:  tensor([[ 0.9437, -1.0712,  0.5793, -1.6571,  0.8172, -0.5198,  0.6952,  0.3813],
        [-1.2688, -0.0252, -1.6834, -0.4943, -0.2003, -0.5946, -1.8651, -0.7360],
        [-1.9460, -0.1967,  1.2759, -1.3324,  2.0369, -1.0557, -0.9046, -1.5130],
        [ 0.7944, -0.5504,  0.4719, -0.2772,  0.7654, -0.3927, -0.3990, -1.4614],
        [ 0.2334,  0.4527,  1.1078,  0.4759, -0.6146, -0.3308,  0.0224, -0.1788],
        [ 1.3209, -0.1060, -1.2837, -0.8006, -1.3743, -0.6603,  0.8471, -2.4122]])


In [13]:
# All the elements below the main diagonal are set to zeros. The result is stored under a new name instead
# of overwriting t6, so the rectangular matrix displayed in the previous cell is not silently replaced.
t6_upper = torch.triu(input=t6, diagonal=0)
print("shape: ", t6_upper.shape)
print("t6_upper: ", t6_upper)

shape:  torch.Size([6, 8])
t6:  tensor([[ 0.9437, -1.0712,  0.5793, -1.6571,  0.8172, -0.5198,  0.6952,  0.3813],
        [ 0.0000, -0.0252, -1.6834, -0.4943, -0.2003, -0.5946, -1.8651, -0.7360],
        [ 0.0000,  0.0000,  1.2759, -1.3324,  2.0369, -1.0557, -0.9046, -1.5130],
        [ 0.0000,  0.0000,  0.0000, -0.2772,  0.7654, -0.3927, -0.3990, -1.4614],
        [ 0.0000,  0.0000,  0.0000,  0.0000, -0.6146, -0.3308,  0.0224, -0.1788],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000, -0.6603,  0.8471, -2.4122]])


In [14]:
# diagonal=1 zeroes the main diagonal as well, just like in the square case.
t7 = t6.triu(diagonal=1)
print("shape: ", t7.size())
print("t7: ", t7)

shape:  torch.Size([6, 8])
t7:  tensor([[ 0.0000, -1.0712,  0.5793, -1.6571,  0.8172, -0.5198,  0.6952,  0.3813],
        [ 0.0000,  0.0000, -1.6834, -0.4943, -0.2003, -0.5946, -1.8651, -0.7360],
        [ 0.0000,  0.0000,  0.0000, -1.3324,  2.0369, -1.0557, -0.9046, -1.5130],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.7654, -0.3927, -0.3990, -1.4614],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000, -0.3308,  0.0224, -0.1788],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.8471, -2.4122]])


In [15]:
# diagonal=2 additionally zeroes the diagonal just above the main one.
t8 = t6.triu(diagonal=2)
print("shape: ", t8.size())
print("t8: ", t8)

shape:  torch.Size([6, 8])
t8:  tensor([[ 0.0000,  0.0000,  0.5793, -1.6571,  0.8172, -0.5198,  0.6952,  0.3813],
        [ 0.0000,  0.0000,  0.0000, -0.4943, -0.2003, -0.5946, -1.8651, -0.7360],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  2.0369, -1.0557, -0.9046, -1.5130],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000, -0.3927, -0.3990, -1.4614],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0224, -0.1788],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000, -2.4122]])


## [torch.Tensor.masked_fill](https://pytorch.org/docs/stable/generated/torch.Tensor.masked_fill_.html#torch-tensor-masked-fill)

In [8]:
# A 5x6 tensor of random integers drawn from [1, 200).
t9 = torch.randint(1, 200, (5, 6))
print("shape: ", t9.size())
print("t9: \n", t9)

shape:  torch.Size([5, 6])
t9: 
 tensor([[139, 112, 185, 121,  10,  10],
        [154,  66, 144,  78,   4, 184],
        [122,  69,   6,  44, 192,  67],
        [ 65,  71, 125, 199, 101, 149],
        [166,  42,  73,  22, 141, 185]])


In [9]:
# A random boolean mask: an entry is True wherever the integer drawn from [-30, 30) is negative.
t10 = torch.randint(-30, 30, (5, 6)).lt(0)
print("shape: ", t10.size())
print("t10: \n", t10)

shape:  torch.Size([5, 6])
t10: 
 tensor([[ True,  True,  True,  True, False, False],
        [False, False,  True, False,  True, False],
        [ True, False, False, False,  True, False],
        [ True, False, False,  True,  True,  True],
        [False, False,  True, False,  True,  True]])


In [10]:
# Wherever the mask t10 is True, masked_fill writes the value 78; every other element is copied from t9.
t11 = t9.masked_fill(t10, 78)
print("shape: ", t11.size())
print("t11: \n", t11)

shape:  torch.Size([5, 6])
t11: 
 tensor([[ 78,  78,  78,  78,  10,  10],
        [154,  66,  78,  78,  78, 184],
        [ 78,  69,   6,  44,  78,  67],
        [ 78,  71, 125,  78,  78,  78],
        [166,  42,  78,  22,  78,  78]])


## [torch.index_select](https://pytorch.org/docs/stable/generated/torch.index_select.html)

In [None]:
# torch.index_select selects sub-tensors from a tensor along a given dimension, based on a 1-D tensor of
# indices, and returns them as a new tensor. This is straightforward when the tensor is 1D or 2D. However,
# when the tensor is 3D or more, it is a bit trickier to picture.

In [4]:
# A 4x5 float32 matrix of standard-normal samples.
t12 = torch.randn(4, 5, dtype=torch.float32)
print("shape: ", t12.size())
print("t12: \n", t12)

shape:  torch.Size([4, 5])
t12: 
 tensor([[-1.3549,  0.1244, -1.2694, -1.0895,  1.6906],
        [ 1.5157, -1.3576, -0.8575,  1.2428, -0.7041],
        [ 1.4359, -0.3758,  1.1248, -1.1707,  0.3895],
        [-1.0861, -0.7470, -1.0660,  0.1809, -0.6891]])


In [5]:
# Positions of the rows (or columns) to pick.
t13_indices = torch.tensor((0, 3))
print("t13_indices shape: ", t13_indices.size())
print("t13_indices: ", t13_indices)

t13_indices shape:  torch.Size([2])
t13_indices:  tensor([0, 3])


In [6]:
# dim is the dimension along which the indices are applied. With dim=0, rows 0 and 3 of t12 are gathered
# into a new tensor of shape (2, 5).
t14 = t12.index_select(0, t13_indices)
print("shape: ", t14.size())
print("t14: \n", t14)

shape:  torch.Size([2, 5])
t14: 
 tensor([[-1.3549,  0.1244, -1.2694, -1.0895,  1.6906],
        [-1.0861, -0.7470, -1.0660,  0.1809, -0.6891]])


In [7]:
# With dim=1, the indices refer to columns: columns 0 and 3 of t12 are gathered into a new (4, 2) tensor.
t15 = t12.index_select(1, t13_indices)
print("shape: ", t15.size())
print("t15: \n", t15)

shape:  torch.Size([4, 2])
t15: 
 tensor([[-1.3549, -1.0895],
        [ 1.5157,  1.2428],
        [ 1.4359, -1.1707],
        [-1.0861,  0.1809]])


#### Let's see how index_select works for a 3D tensor

In [8]:
# A 3x4x5 tensor holding 0..59 in row-major order, which makes it easy to see which elements get picked.
t16 = torch.arange(3 * 4 * 5).view(3, 4, 5)
print("shape: ", t16.size())
print("t16: \n", t16)

shape:  torch.Size([3, 4, 5])
t16: 
 tensor([[[ 0,  1,  2,  3,  4],
         [ 5,  6,  7,  8,  9],
         [10, 11, 12, 13, 14],
         [15, 16, 17, 18, 19]],

        [[20, 21, 22, 23, 24],
         [25, 26, 27, 28, 29],
         [30, 31, 32, 33, 34],
         [35, 36, 37, 38, 39]],

        [[40, 41, 42, 43, 44],
         [45, 46, 47, 48, 49],
         [50, 51, 52, 53, 54],
         [55, 56, 57, 58, 59]]])


In [11]:
# Indices 0 and 1, reused for every dim in the examples below.
t17_indices = torch.arange(2)
print("t17_indices shape: ", t17_indices.size())
print("t17_indices: ", t17_indices)

t17_indices shape:  torch.Size([2])
t17_indices:  tensor([0, 1])


In [13]:
# With dim=0, the sub-tensors t16[0] and t16[1] are selected, giving a result of shape (2, 4, 5).
t18 = t16.index_select(0, t17_indices)
print("shape: ", t18.size())
print("t18: \n", t18)

shape:  torch.Size([2, 4, 5])
t18: 
 tensor([[[ 0,  1,  2,  3,  4],
         [ 5,  6,  7,  8,  9],
         [10, 11, 12, 13, 14],
         [15, 16, 17, 18, 19]],

        [[20, 21, 22, 23, 24],
         [25, 26, 27, 28, 29],
         [30, 31, 32, 33, 34],
         [35, 36, 37, 38, 39]]])


In [None]:
# Let's understand the behaviour when dim is set to 1. It can seem surprising at first.
#
# Fixing an index along dim 0 splits t16 into its 3 sub-tensors t16[0], t16[1] and t16[2], each of shape (4, 5):
# tensor([[ 0,  1,  2,  3,  4],
#         [ 5,  6,  7,  8,  9],
#         [10, 11, 12, 13, 14],
#         [15, 16, 17, 18, 19]])
#
# tensor([[20, 21, 22, 23, 24],
#         [25, 26, 27, 28, 29],
#         [30, 31, 32, 33, 34],
#         [35, 36, 37, 38, 39]])
#
# tensor([[40, 41, 42, 43, 44],
#         [45, 46, 47, 48, 49],
#         [50, 51, 52, 53, 54],
#         [55, 56, 57, 58, 59]])
#
# With dim=1, index_select picks rows 0 and 1 (the entries of t17_indices) from each of these sub-tensors:
#
# From t16[0]:
# tensor([[ 0,  1,  2,  3,  4],
#         [ 5,  6,  7,  8,  9]])
#
# From t16[1]:
# tensor([[20, 21, 22, 23, 24],
#         [25, 26, 27, 28, 29]])
#
# From t16[2]:
# tensor([[40, 41, 42, 43, 44],
#         [45, 46, 47, 48, 49]])
#
# Only the size of dim 1 changes (from 4 to the number of indices, 2). Every other dimension keeps its
# original size, so the output tensor is of shape (3, 2, 5). This is the same as the indexing expression
# t16[:, t17_indices, :].
t19 = torch.index_select(input=t16, dim=1, index=t17_indices)
assert torch.equal(t19, t16[:, t17_indices, :])
print("shape: ", t19.shape)
print("t19: \n", t19)

shape:  torch.Size([3, 2, 5])
t19: 
 tensor([[[ 0,  1,  2,  3,  4],
         [ 5,  6,  7,  8,  9]],

        [[20, 21, 22, 23, 24],
         [25, 26, 27, 28, 29]],

        [[40, 41, 42, 43, 44],
         [45, 46, 47, 48, 49]]])


In [None]:
# When dim is set to 2, first fix the indices along dims 0 and 1. That splits t16 into 3 * 4 = 12 rows
# t16[i, j], each of shape (5,):
# tensor([ 0,  1,  2,  3,  4])
# tensor([ 5,  6,  7,  8,  9])
# tensor([10, 11, 12, 13, 14])
# tensor([15, 16, 17, 18, 19])
# tensor([20, 21, 22, 23, 24])
# tensor([25, 26, 27, 28, 29])
# tensor([30, 31, 32, 33, 34])
# tensor([35, 36, 37, 38, 39])
# tensor([40, 41, 42, 43, 44])
# tensor([45, 46, 47, 48, 49])
# tensor([50, 51, 52, 53, 54])
# tensor([55, 56, 57, 58, 59])
#
# From each of these rows, index_select picks the elements at positions 0 and 1 (the entries of
# t17_indices). Only dim 2 shrinks (from 5 to 2), so the output tensor is of shape (3, 4, 2). This is the
# same as the indexing expression t16[:, :, t17_indices].
t20 = torch.index_select(input=t16, dim=2, index=t17_indices)
assert torch.equal(t20, t16[:, :, t17_indices])
print("shape: ", t20.shape)
print("t20: \n", t20)

shape:  torch.Size([3, 4, 2])
t20: 
 tensor([[[ 0,  1],
         [ 5,  6],
         [10, 11],
         [15, 16]],

        [[20, 21],
         [25, 26],
         [30, 31],
         [35, 36]],

        [[40, 41],
         [45, 46],
         [50, 51],
         [55, 56]]])


In [16]:
# Indices may repeat: index 0 appears three times and index 1 twice, so those sub-tensors get duplicated.
t21_indices = torch.tensor([0] * 3 + [1] * 2)
print("t21_indices shape: ", t21_indices.size())
print("t21_indices: ", t21_indices)

t21_indices shape:  torch.Size([5])
t21_indices:  tensor([0, 0, 0, 1, 1])


In [None]:
# As expected, t16[0] appears 3 times followed by t16[1] twice, giving a result of shape (5, 4, 5).
t22 = t16.index_select(0, t21_indices)
print("shape: ", t22.size())
print("t22: \n", t22)

shape:  torch.Size([5, 4, 5])
t22: 
 tensor([[[ 0,  1,  2,  3,  4],
         [ 5,  6,  7,  8,  9],
         [10, 11, 12, 13, 14],
         [15, 16, 17, 18, 19]],

        [[ 0,  1,  2,  3,  4],
         [ 5,  6,  7,  8,  9],
         [10, 11, 12, 13, 14],
         [15, 16, 17, 18, 19]],

        [[ 0,  1,  2,  3,  4],
         [ 5,  6,  7,  8,  9],
         [10, 11, 12, 13, 14],
         [15, 16, 17, 18, 19]],

        [[20, 21, 22, 23, 24],
         [25, 26, 27, 28, 29],
         [30, 31, 32, 33, 34],
         [35, 36, 37, 38, 39]],

        [[20, 21, 22, 23, 24],
         [25, 26, 27, 28, 29],
         [30, 31, 32, 33, 34],
         [35, 36, 37, 38, 39]]])
