In [2]:
# In this notebook, you learn:
#
# 1) What does torch.cat do?
# 2) What does torch.stack do?

In [2]:
import torch

## [torch.cat](https://pytorch.org/docs/stable/generated/torch.cat.html#torch.cat)

In [None]:
# Concatenating along a dimension means joining multiple tensors end-to-end along 
# the specified dimension. It effectively increases the size of the specified 
# dimension by adding tensors to that dimension.

In [3]:
# Inputs for the torch.cat examples: two (2, 3) integer matrices holding 1..6 and 7..12.
# arange produces the consecutive values; reshape lays them out row by row.
t1 = torch.arange(1, 7).reshape(2, 3)
t2 = torch.arange(7, 13).reshape(2, 3)
print(t1, t1.shape, "\n")
print(t2, t2.shape)


tensor([[1, 2, 3],
        [4, 5, 6]]) torch.Size([2, 3])
tensor([[ 7,  8,  9],
        [10, 11, 12]]) torch.Size([2, 3])


In [5]:
# Both t1 and t2 have shape (2, 3) meaning they have 2 rows and 3 columns.
# Dimension 0 points in the direction of rows i.e., top to bottom in a matrix.
# When concatenating along dimension 0, we are joining the tensors along the rows effectively 
# increasing the number of rows in the concatenated result.
# The rows of t2 are placed below the rows of t1: (2, 3) + (2, 3) -> (4, 3).
t3 = torch.cat((t1, t2), dim=0)
print(t3, t3.size())

tensor([[ 1,  2,  3],
        [ 4,  5,  6],
        [ 7,  8,  9],
        [10, 11, 12]]) torch.Size([4, 3])


In [6]:
# Both t1 and t2 have shape (2, 3) meaning they have 2 rows and 3 columns.
# Dimension 1 points in the direction of columns i.e., left to right in a matrix.
# When concatenating along dimension 1, we are joining the tensors along the columns effectively
# increasing the number of columns in the concatenated result.
# Each row of t2 is placed to the right of the matching row of t1: (2, 3) + (2, 3) -> (2, 6).
t4 = torch.cat((t1, t2), dim=1)
print(t4, t4.size())

tensor([[ 1,  2,  3,  7,  8,  9],
        [ 4,  5,  6, 10, 11, 12]]) torch.Size([2, 6])


In [7]:
# 3D inputs for the torch.cat examples: two (2, 2, 2) tensors holding 1..8 and 9..16.
t5 = torch.arange(1, 9).reshape(2, 2, 2)
t6 = torch.arange(9, 17).reshape(2, 2, 2)
print(t5, t5.shape, "\n")
print(t6, t6.shape)

tensor([[[1, 2],
         [3, 4]],

        [[5, 6],
         [7, 8]]]) torch.Size([2, 2, 2]) 

tensor([[[ 9, 10],
         [11, 12]],

        [[13, 14],
         [15, 16]]]) torch.Size([2, 2, 2])


In [8]:
# Let's now understand the concatenation operation by looking at tensors as containers.
# As we traverse along dimension 0, we get tensors of shape (2, 2). We append these (2, 2) tensors
# from 't6' to 't5' to obtain the concatenated result.
# To elaborate, we get the tensors [[9, 10], [11, 12]] and [[13, 14], [15, 16]] as we traverse
# along dimension 0 in 't6'. We append these tensors at the end of 't5' as we traverse along
# dimension 0 in 't5'. So, we first have [[1, 2], [3, 4]] and [[5, 6], [7, 8]] followed by the
# tensors from 't6' in concatenated result.
# The two (2, 2) blocks of t6 follow the two (2, 2) blocks of t5: result shape (4, 2, 2).
t7 = torch.cat((t5, t6), dim=0)
print(t7, t7.size())

tensor([[[ 1,  2],
         [ 3,  4]],

        [[ 5,  6],
         [ 7,  8]],

        [[ 9, 10],
         [11, 12]],

        [[13, 14],
         [15, 16]]]) torch.Size([4, 2, 2])


In [9]:
# As we traverse along dimension 1, we get the tensors of shape (2,). We append these (2,) tensors from
# 't6' to 't5' to obtain the concatenated result.
# To elaborate, we get the tensors [9, 10], [11, 12], [13, 14], [15, 16] as we traverse along dimension 1
# in 't6'. We append these tensors to the corresponding tensors in 't5' as we traverse along dimension 1
# in 't5'. Note that we obtain [9, 10] and [11, 12] by traversing the first (2, 2) tensor in 't6'. So, 
# these tensors are appended to the first (2, 2) tensor ([[1, 2], [3, 4]]) in 't5'. Similarly, we obtain 
# [13, 14] and [15, 16] by traversing the second (2, 2) tensor in 't6'. So, these tensors are appended to 
# the second (2, 2) tensor ([[5, 6], [7, 8]]) in 't5' in the concatenated result.
# Within each (2, 2) block, the rows of t6 follow the rows of t5: result shape (2, 4, 2).
t8 = torch.cat((t5, t6), dim=1)
print(t8, t8.size())

tensor([[[ 1,  2],
         [ 3,  4],
         [ 9, 10],
         [11, 12]],

        [[ 5,  6],
         [ 7,  8],
         [13, 14],
         [15, 16]]]) torch.Size([2, 4, 2])


In [10]:
# As we traverse along dimension 2, we get individual numbers. We append these numbers from 't6' to 't5'
# to obtain the concatenated result.
# To elaborate, we get the numbers 9, 10, 11, 12, 13, 14, 15, 16 as we traverse along dimension 2 in 't6'.
# We append these numbers to the corresponding tensors in 't5' as we traverse along dimension 2 in 't5'.
# Note that we obtain 9, 10 by traversing the first (2,) tensor in 't6'. So, these numbers are appended
# to the first (2,) tensor ([1, 2]) in 't5'. Similarly, with pairs {11, 12}; {13, 14}; {15, 16}. 
# Each (2,) row of t5 is extended with the matching row of t6: result shape (2, 2, 4).
t9 = torch.cat((t5, t6), dim=2)
print(t9, t9.size())

tensor([[[ 1,  2,  9, 10],
         [ 3,  4, 11, 12]],

        [[ 5,  6, 13, 14],
         [ 7,  8, 15, 16]]]) torch.Size([2, 2, 4])


In [11]:
# 4D inputs for the torch.cat examples: two (2, 2, 2, 2) tensors holding 1..16 and 17..32.
t10 = torch.arange(1, 17).reshape(2, 2, 2, 2)
t11 = torch.arange(17, 33).reshape(2, 2, 2, 2)
print(t10, t10.shape, "\n")
print(t11, t11.shape)

tensor([[[[ 1,  2],
          [ 3,  4]],

         [[ 5,  6],
          [ 7,  8]]],


        [[[ 9, 10],
          [11, 12]],

         [[13, 14],
          [15, 16]]]]) torch.Size([2, 2, 2, 2]) 

tensor([[[[17, 18],
          [19, 20]],

         [[21, 22],
          [23, 24]]],


        [[[25, 26],
          [27, 28]],

         [[29, 30],
          [31, 32]]]]) torch.Size([2, 2, 2, 2])


In [12]:
# Following the same logic as above, the 2 '3D' tensors from 't11' are appended along dimension 0
# to the 2 '3D' tensors in 't10'
# Joining along the outermost axis: result shape (4, 2, 2, 2).
t12 = torch.cat((t10, t11), dim=0)
print(t12, t12.size())

tensor([[[[ 1,  2],
          [ 3,  4]],

         [[ 5,  6],
          [ 7,  8]]],


        [[[ 9, 10],
          [11, 12]],

         [[13, 14],
          [15, 16]]],


        [[[17, 18],
          [19, 20]],

         [[21, 22],
          [23, 24]]],


        [[[25, 26],
          [27, 28]],

         [[29, 30],
          [31, 32]]]]) torch.Size([4, 2, 2, 2])


In [13]:
# Following the same logic as above, the 4 '2D' tensors from 't11' are appended along dimension 1
# to the corresponding '2D' tensors in 't10'.
# Joining along axis 1: result shape (2, 4, 2, 2).
t13 = torch.cat((t10, t11), dim=1)
print(t13, t13.size())

tensor([[[[ 1,  2],
          [ 3,  4]],

         [[ 5,  6],
          [ 7,  8]],

         [[17, 18],
          [19, 20]],

         [[21, 22],
          [23, 24]]],


        [[[ 9, 10],
          [11, 12]],

         [[13, 14],
          [15, 16]],

         [[25, 26],
          [27, 28]],

         [[29, 30],
          [31, 32]]]]) torch.Size([2, 4, 2, 2])


In [14]:
# Following the same logic as above, the 8 '1D' tensors from 't11' are appended along dimension 2
# to the corresponding '1D' tensors in 't10'.
# Joining along axis 2: result shape (2, 2, 4, 2).
t14 = torch.cat((t10, t11), dim=2)
print(t14, t14.size())

tensor([[[[ 1,  2],
          [ 3,  4],
          [17, 18],
          [19, 20]],

         [[ 5,  6],
          [ 7,  8],
          [21, 22],
          [23, 24]]],


        [[[ 9, 10],
          [11, 12],
          [25, 26],
          [27, 28]],

         [[13, 14],
          [15, 16],
          [29, 30],
          [31, 32]]]]) torch.Size([2, 2, 4, 2])


In [16]:
# Following the same logic as above, the 16 numbers from 't11' are appended along dimension 3
# to the corresponding numbers in 't10'.
# Joining along the innermost axis: result shape (2, 2, 2, 4).
t15 = torch.cat((t10, t11), dim=3)
print(t15, t15.size())

tensor([[[[ 1,  2, 17, 18],
          [ 3,  4, 19, 20]],

         [[ 5,  6, 21, 22],
          [ 7,  8, 23, 24]]],


        [[[ 9, 10, 25, 26],
          [11, 12, 27, 28]],

         [[13, 14, 29, 30],
          [15, 16, 31, 32]]]]) torch.Size([2, 2, 2, 4])


## [torch.stack](https://pytorch.org/docs/stable/generated/torch.stack.html#torch.stack)

In [None]:
# This [video](https://www.youtube.com/watch?v=kF2AlpykJGY) presents a view where the torch.stack operation
# is a 'torch.unsqueeze' operation followed by 'torch.cat'.
#
# Honestly, I didn't really understand intuitively what 'torch.stack' does in cases other than when dim=0 is used.
# The official documentation says that it concatenates the tensors along a new dimension, which is not
# clear to me. However, to understand how the stack operation manipulates the tensors, it can be viewed as
# a combination of 'unsqueeze' and 'cat', i.e., we first unsqueeze (add a dimension to) each tensor at the given
# index and then concatenate the unsqueezed tensors along the given dimension.
# 
# Please refer to 'understanding_simple_pytorch_tensor_manipulations_part_1.ipynb' notebook to understand
# the unsqueeze operation.

In [4]:
# 1D inputs for the torch.stack examples: two length-3 vectors holding 1..3 and 4..6.
t16 = torch.arange(1, 4)
t17 = torch.arange(4, 7)
print(t16, t16.shape, "\n")
print(t17, t17.shape)

tensor([1, 2, 3]) torch.Size([3]) 

tensor([4, 5, 6]) torch.Size([3])


In [None]:
# Add a new dimension (at dim 0) and stack the tensors along this dimension.
# So, [1, 2, 3] and [4, 5, 6] are stacked on top of one another to get a 2D tensor. 
t18 = torch.stack((t16, t17), dim=0)
print(t18, t18.size(), "\n")

# Rebuild the same result by hand: give each vector a new leading axis, then
# concatenate along that axis, and report whether both results match.
t19 = torch.cat((t16.unsqueeze(0), t17.unsqueeze(0)), dim=0)
print("t18 = t19" if torch.equal(t18, t19) else "t18 != t19")

tensor([[1, 2, 3],
        [4, 5, 6]]) torch.Size([2, 3]) 

t18 = t19


In [None]:
# Stacking along a new axis 1 turns each vector into a column: result shape (3, 2).
t20 = torch.stack((t16, t17), dim=1)
print(t20, t20.size(), "\n")

# Rebuild the same result by hand: unsqueeze each vector at axis 1, then
# concatenate along axis 1, and report whether both results match.
t21 = torch.cat((t16.unsqueeze(1), t17.unsqueeze(1)), dim=1)
print("t20 = t21" if torch.equal(t20, t21) else "t20 != t21")

tensor([[1, 4],
        [2, 5],
        [3, 6]]) torch.Size([3, 2]) 

t20 = t21


In [None]:
# 2D inputs for the torch.stack examples: two (2, 3) matrices holding 1..6 and 7..12.
t22 = torch.arange(1, 7).reshape(2, 3)
t23 = torch.arange(7, 13).reshape(2, 3)

print(t22, t22.shape, "\n")
print(t23, t23.shape)

tensor([[1, 2, 3],
        [4, 5, 6]]) torch.Size([2, 3]) 

tensor([[ 7,  8,  9],
        [10, 11, 12]]) torch.Size([2, 3])


In [None]:
# Stacking along a new axis 0 keeps each matrix whole: result shape (2, 2, 3).
t24 = torch.stack((t22, t23), dim=0)
print(t24, t24.size(), "\n")

# Rebuild the same result by hand: unsqueeze each matrix at axis 0, then
# concatenate along axis 0, and report whether both results match.
t25 = torch.cat((t22.unsqueeze(0), t23.unsqueeze(0)), dim=0)
print("t24 = t25" if torch.equal(t24, t25) else "t24 != t25")

tensor([[[ 1,  2,  3],
         [ 4,  5,  6]],

        [[ 7,  8,  9],
         [10, 11, 12]]]) torch.Size([2, 2, 3]) 

t24 = t25


In [None]:
# Stacking along a new axis 1 pairs up matching rows of the two matrices: result shape (2, 2, 3).
t26 = torch.stack((t22, t23), dim=1)
print(t26, t26.size(), "\n")

# Rebuild the same result by hand: unsqueeze each matrix at axis 1, then
# concatenate along axis 1, and report whether both results match.
t27 = torch.cat((t22.unsqueeze(1), t23.unsqueeze(1)), dim=1)
print("t26 = t27" if torch.equal(t26, t27) else "t26 != t27")

tensor([[[ 1,  2,  3],
         [ 7,  8,  9]],

        [[ 4,  5,  6],
         [10, 11, 12]]]) torch.Size([2, 2, 3]) 

t26 = t27


In [None]:
# Stacking along a new axis 2 pairs up matching elements of the two matrices: result shape (2, 3, 2).
t28 = torch.stack((t22, t23), dim=2)
print(t28, t28.size(), "\n")

# Rebuild the same result by hand: unsqueeze each matrix at axis 2, then
# concatenate along axis 2, and report whether both results match.
t29 = torch.cat((t22.unsqueeze(2), t23.unsqueeze(2)), dim=2)
print("t28 = t29" if torch.equal(t28, t29) else "t28 != t29")

tensor([[[ 1,  7],
         [ 2,  8],
         [ 3,  9]],

        [[ 4, 10],
         [ 5, 11],
         [ 6, 12]]]) torch.Size([2, 3, 2]) 

t28 = t29
