In [45]:
# In this notebook, you will learn:
#
# 1) How torch.reshape works.
# 2) How torch.Tensor.view works.
# 3) How torch.transpose works.
# 4) How torch.Tensor.repeat works.

In [2]:
import torch

In [47]:
# Resources to go through before continuing further in this notebook:
#
# 1) tensor_manipulations/understanding_tensors_part_2.ipynb
#       -- Explains how stride, contiguity and the underlying storage work in a tensor.
# 2) https://dzone.com/articles/reshaping-pytorch-tensors
#       -- Explains how reshape and view work using the stride and contiguity properties.

## [torch.reshape](https://pytorch.org/docs/stable/generated/torch.reshape.html#torch-reshape)

In [48]:
# One of the most confusing things when I tried to understand how reshape works was figuring out
# how the elements of the original tensor are rearranged in the reshaped tensor. To understand
# this, you need to know how tensors are stored internally in memory. Please refer to the
# 'tensor_manipulations/understanding_tensors_part_2.ipynb' notebook to understand that before
# continuing further.

In [49]:
# Build a 4x5 int64 tensor holding the values 1..20, one row per inner list. The reshape and
# transpose examples that follow start from this tensor.
t1_rows = [
    [1, 2, 3, 4, 5],
    [6, 7, 8, 9, 10],
    [11, 12, 13, 14, 15],
    [16, 17, 18, 19, 20],
]
t1 = torch.tensor(data=t1_rows, dtype=torch.int64)
print("t1: ", t1)
print("shape: ", t1.shape)

t1:  tensor([[ 1,  2,  3,  4,  5],
        [ 6,  7,  8,  9, 10],
        [11, 12, 13, 14, 15],
        [16, 17, 18, 19, 20]])
shape:  torch.Size([4, 5])


In [50]:
# Report whether t1 is contiguous, then dump its underlying storage. Printing the storage
# object lists the elements of the flat buffer that backs t1.
t1_is_contiguous = t1.is_contiguous()
print("is_contiguous: ", t1_is_contiguous)
t1_storage = t1.storage()
print("storage: ", t1_storage)

is_contiguous:  True
storage:   1
 2
 3
 4
 5
 6
 7
 8
 9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
[torch.storage.TypedStorage(dtype=torch.int64, device=cpu) of size 20]


In [51]:
# Notice where the elements end up in t2. reshape takes the elements of the source tensor (t1)
# one after another and places them into the new tensor (t2). This is equivalent to iterating
# over the storage of t1, taking each element and then filling the positions of t2 in the order
# (0, 0); (0, 1); (1, 0); (1, 1); (2, 0); (2, 1); ... (9, 0); (9, 1)
#
# When the source tensor is contiguous (as t1 is), the order of its elements is the same as the
# order of the elements in its storage, so the storage order can be used to work out the result
# of reshape. When the source tensor is not contiguous, the order used by reshape is the order
# of the elements in the source tensor and not the order of its storage (examples for this below).
t2 = t1.reshape(10, 2)
t2_shape, t2_stride = t2.shape, t2.stride()
print("t2: ", t2)
print("shape: ", t2_shape)
print("stride: ", t2_stride)

t2:  tensor([[ 1,  2],
        [ 3,  4],
        [ 5,  6],
        [ 7,  8],
        [ 9, 10],
        [11, 12],
        [13, 14],
        [15, 16],
        [17, 18],
        [19, 20]])
shape:  torch.Size([10, 2])
stride:  (2, 1)


In [52]:
# reshape did not copy the underlying storage for the new tensor (t2): the new shape is
# compatible with the existing storage, so t2 is just a view that reads the same storage with
# a different stride. Comparing the start addresses of both storages confirms this.
t1_t2_share_storage = t1.storage().data_ptr() == t2.storage().data_ptr()
print(
    "Both the tensors t1 and t2 have same underlying storage"
    if t1_t2_share_storage
    else "The tensors t1 and t2 do not share the same storage"
)

Both the tensors t1 and t2 have same underlying storage


In [53]:
# Transpose t1 and inspect the result (t3).
#
# is_contiguous is False for t3 because the elements of t3 are not stored sequentially
# (contiguously) in the underlying storage.
#
# The order of the elements in the tensor t3 is:
# [1, 6, 11, 16, 2, 7, 12, 17, 3, 8, 13, 18, 4, 9, 14, 19, 5, 10, 15, 20].
#
# The order of the elements in the storage for t3 is:
# [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20]
#
# The sequence order of the elements of any tensor can be determined as follows:
# 1) Take the sub-tensor at index 0 in dimension 0.
# 2) Keep descending into it recursively until you reach its first 1D tensor.
# 3) List all the elements of this 1D tensor.
# 4) Move to the next 1D tensor and list all its elements.
# 5) Repeat until the sub-tensor at index 0 in dimension 0 is exhausted.
# 6) Repeat the whole process for every other index in dimension 0.
#
# At the end of this process, we have a single 1D list of elements, which is the required order.
t3 = t1.transpose(0, 1)
for label, value in (
    ("t3: ", t3),
    ("shape: ", t3.shape),
    ("stride: ", t3.stride()),
    ("storage: ", t3.storage()),
    ("is_contiguous: ", t3.is_contiguous()),
):
    print(label, value)

# transpose did not copy anything either: compare the start addresses of both storages.
t1_t3_share_storage = t1.storage().data_ptr() == t3.storage().data_ptr()
print(
    "Both the tensors t1 and t3 have same underlying storage"
    if t1_t3_share_storage
    else "The tensors t1 and t3 do not share the same storage"
)

t3:  tensor([[ 1,  6, 11, 16],
        [ 2,  7, 12, 17],
        [ 3,  8, 13, 18],
        [ 4,  9, 14, 19],
        [ 5, 10, 15, 20]])
shape:  torch.Size([5, 4])
stride:  (1, 5)
storage:   1
 2
 3
 4
 5
 6
 7
 8
 9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
[torch.storage.TypedStorage(dtype=torch.int64, device=cpu) of size 20]
is_contiguous:  False
Both the tensors t1 and t3 have same underlying storage


In [54]:
# Notice the order of the elements in the reshaped tensor (t4). It follows the order of the
# elements in the tensor (t3) and not the order of its underlying storage.
#
# The order of the elements in t3 is defined by its stride. As explained above, you start
# iterating from the outer-most dimension, descend to the last dimension and list all the
# elements along it, then move on to the next 1D tensor, until every element is accounted for.
# The resulting flattened 1D sequence is what gets laid out in the new shape.
#
# A new storage is created for t4 because t3 is not contiguous and the new shape does not align
# with the original storage. In general, you cannot depend on reshape to create / not create a
# new storage: it first tries to return a view (no copy of the storage) and only when that is
# not possible does it create a new storage and fill it in the appropriate order.
t4 = t3.reshape(10, 2)
for label, value in (
    ("t4: ", t4),
    ("shape: ", t4.shape),
    ("stride: ", t4.stride()),
    ("storage: ", t4.storage()),
    ("is_contiguous: ", t4.is_contiguous()),
):
    print(label, value)

t3_t4_share_storage = t3.storage().data_ptr() == t4.storage().data_ptr()
print(
    "Both the tensors t3 and t4 have same underlying storage"
    if t3_t4_share_storage
    else "The tensors t3 and t4 do not share the same storage"
)

t4:  tensor([[ 1,  6],
        [11, 16],
        [ 2,  7],
        [12, 17],
        [ 3,  8],
        [13, 18],
        [ 4,  9],
        [14, 19],
        [ 5, 10],
        [15, 20]])
shape:  torch.Size([10, 2])
stride:  (2, 1)
storage:   1
 6
 11
 16
 2
 7
 12
 17
 3
 8
 13
 18
 4
 9
 14
 19
 5
 10
 15
 20
[torch.storage.TypedStorage(dtype=torch.int64, device=cpu) of size 20]
is_contiguous:  True
The tensors t3 and t4 do not share the same storage


In [55]:
# contiguous() returns a new tensor (t5) with the same shape and values as t3, but, unlike t3,
# backed by a storage that holds the elements contiguously.
# Notice that the order of the elements in the underlying storage of t5 is the same as for t4.
# In general, if it gets hard to work out the order of the elements used by a reshape operation,
# simply make the original tensor contiguous and read the order off its printed storage.
t5 = t3.contiguous()
t5_shape = t5.shape
t5_stride = t5.stride()
t5_storage = t5.storage()
print("t5: ", t5)
print("shape: ", t5_shape)
print("stride: ", t5_stride)
print("storage: ", t5_storage)

t5:  tensor([[ 1,  6, 11, 16],
        [ 2,  7, 12, 17],
        [ 3,  8, 13, 18],
        [ 4,  9, 14, 19],
        [ 5, 10, 15, 20]])
shape:  torch.Size([5, 4])
stride:  (4, 1)
storage:   1
 6
 11
 16
 2
 7
 12
 17
 3
 8
 13
 18
 4
 9
 14
 19
 5
 10
 15
 20
[torch.storage.TypedStorage(dtype=torch.int64, device=cpu) of size 20]


## [tensor.view](https://pytorch.org/docs/stable/generated/torch.Tensor.view.html#torch-tensor-view)

In [None]:
# Please go through the reshape operation first before trying to understand the view operation.
# The order of the elements in the new tensor after a view operation is determined in the same
# way as for the reshape operation. The only difference is that view never creates a new storage
# for the new tensor. If it cannot produce the new shape without copying the storage, it raises
# an error.

In [56]:
# Build the 3x4 int64 tensor (values 1..12) used by the view examples and print its layout.
t6_rows = [
    [1, 2, 3, 4],
    [5, 6, 7, 8],
    [9, 10, 11, 12],
]
t6 = torch.tensor(data=t6_rows, dtype=torch.int64)
for label, value in (
    ("t6: ", t6),
    ("shape: ", t6.shape),
    ("stride: ", t6.stride()),
    ("is_contiguous: ", t6.is_contiguous()),
    ("storage: ", t6.storage()),
):
    print(label, value)

t6:  tensor([[ 1,  2,  3,  4],
        [ 5,  6,  7,  8],
        [ 9, 10, 11, 12]])
shape:  torch.Size([3, 4])
stride:  (4, 1)
is_contiguous:  True
storage:   1
 2
 3
 4
 5
 6
 7
 8
 9
 10
 11
 12
[torch.storage.TypedStorage(dtype=torch.int64, device=cpu) of size 12]


In [1]:
# My initial understanding was that the original tensor needs to be contiguous for a view
# operation to be valid. Contiguity is sufficient, but it is not necessary: a view operation
# can also be applied to some non-contiguous tensors, as long as they satisfy a looser,
# contiguity-like condition. (Example for this below)
#
# The condition given in the official PyTorch documentation is that the new shape must be
# compatible with the size and stride of the original tensor: every dimension of the new shape
# must either be a subspace of one original dimension (i.e. obtained by splitting it), or span
# a run of original dimensions d, d+1, ..., d+k that are laid out like contiguous memory, i.e.
# for every i in d, ..., d+k-1:
#       stride[i] == stride[i+1] * size[i+1]
#
# 1) https://kamilelukosiute.com/pytorch/When+can+a+tensor+be+view()ed%3F
#       -- Explains the necessary conditions for a view operation to be valid.

In [58]:
# View the 3x4 tensor t6 as a 4x3 tensor and print the layout of the result.
t7 = t6.view(4, 3)
for label, value in (
    ("t7: ", t7),
    ("shape: ", t7.shape),
    ("stride: ", t7.stride()),
    ("is_contiguous: ", t7.is_contiguous()),
):
    print(label, value)

t7:  tensor([[ 1,  2,  3],
        [ 4,  5,  6],
        [ 7,  8,  9],
        [10, 11, 12]])
shape:  torch.Size([4, 3])
stride:  (3, 1)
is_contiguous:  True


In [59]:
# View the 3x4 tensor t6 as a 2x6 tensor and print the layout of the result.
t8 = t6.view(2, 6)
t8_shape, t8_stride, t8_is_contiguous = t8.shape, t8.stride(), t8.is_contiguous()
print("t8: ", t8)
print("shape: ", t8_shape)
print("stride: ", t8_stride)
print("is_contiguous: ", t8_is_contiguous)

t8:  tensor([[ 1,  2,  3,  4,  5,  6],
        [ 7,  8,  9, 10, 11, 12]])
shape:  torch.Size([2, 6])
stride:  (6, 1)
is_contiguous:  True


In [60]:
# A view may also change the number of dimensions: here the 3x4 tensor t6 is viewed as 2x3x2.
t10 = t6.view(2, 3, 2)
for label, value in (
    ("t10: ", t10),
    ("shape: ", t10.shape),
    ("stride: ", t10.stride()),
    ("is_contiguous: ", t10.is_contiguous()),
):
    print(label, value)

t10:  tensor([[[ 1,  2],
         [ 3,  4],
         [ 5,  6]],

        [[ 7,  8],
         [ 9, 10],
         [11, 12]]])
shape:  torch.Size([2, 3, 2])
stride:  (6, 2, 1)
is_contiguous:  True


In [61]:
# A transpose operation usually breaks contiguity: t11 reads the storage of t6 with swapped
# strides, so its elements are no longer laid out sequentially in memory.
t11 = t6.transpose(0, 1)
t11_shape, t11_stride, t11_is_contiguous = t11.shape, t11.stride(), t11.is_contiguous()
print("t11: ", t11)
print("shape: ", t11_shape)
print("stride: ", t11_stride)
print("is_contiguous: ", t11_is_contiguous)

t11:  tensor([[ 1,  5,  9],
        [ 2,  6, 10],
        [ 3,  7, 11],
        [ 4,  8, 12]])
shape:  torch.Size([4, 3])
stride:  (1, 4)
is_contiguous:  False


In [62]:
# Notice that the view operation on t11 does not raise an error even though t11 is not
# contiguous. This shows that view can be applied to a non-contiguous tensor as long as the
# requested shape is compatible with the existing sizes and strides.
#
# Here t11 has shape (4, 3) and stride (1, 4). The new shape (2, 2, 3) keeps the last dimension
# as it is (size 3, stride 4) and only splits dimension 0 (size 4, stride 1) into two dimensions
# of size 2 with strides (2, 1). Splitting a single dimension never requires moving data, so the
# result can be expressed purely with new strides, (2, 1, 4), over the same storage.
t12 = t11.view(2, 2, 3)
for label, value in (
    ("t12: ", t12),
    ("shape: ", t12.shape),
    ("stride: ", t12.stride()),
    ("is_contiguous: ", t12.is_contiguous()),
):
    print(label, value)

t12_t6_share_storage = t12.storage().data_ptr() == t6.storage().data_ptr()
print(
    "Both t12 and t6 share the same storage."
    if t12_t6_share_storage
    else "Both t12 and t6 do not share the same storage."
)

t12:  tensor([[[ 1,  5,  9],
         [ 2,  6, 10]],

        [[ 3,  7, 11],
         [ 4,  8, 12]]])
shape:  torch.Size([2, 2, 3])
stride:  (2, 1, 4)
is_contiguous:  False
Both t12 and t6 share the same storage.


## [torch.transpose](https://pytorch.org/docs/stable/generated/torch.transpose.html#torch-transpose)

In [2]:
# In 2D matrices, performing a transpose interchanges rows and columns. Old Rows
# become the New columns and Old columns become the New rows.
#
# [0, 1, 2]
# [3, 4, 5]
#
# when transposed changes to
#
# [0, 3]
# [1, 4]
# [2, 5]
#
# Transpose in higher dimensions works in the same way but with a bit of
# added complexity. Let's say we have a tensor of shape (2, 3, 4, 5) and want to
# transpose the dimensions 0 and 3. The shape of the transposed tensor now
# becomes (5, 3, 4, 2). The transpose operation follows the following logic:
#
# With the indices of all the other dimensions held fixed, the elements we meet as we
# traverse the original tensor along dimension 0 are the same, and in the same order,
# as the elements we meet as we traverse the transposed tensor along dimension 3, and
# vice versa. Only the position at which each such run of elements sits inside the
# tensor changes (example below).
# 
# It might be hard to intuitively understand what transpose does for higher dimensions. 
# The best way I understood is through the mathematical definition of transpose and 
# extending it to higher dimensions.
#
# The exact mathematical logic if a 4D tensor is transposed along dimensions 0 and 3 is:
# OriginalTensor[ind0][ind1][ind2][ind3] = TransposedTensor[ind3][ind1][ind2][ind0].
#
# 1) Refer to understanding_tensors_part_1.ipynb (link to the notebook) to understand more
#    about dimensions in a tensor.

In [13]:
# Build a 3x3x3 tensor holding the values 0..26; it is the input of the transpose example.
# The printed label names the tensor that is actually printed (t13).
t13 = torch.arange(27).reshape(3, 3, 3)
print("shape: ", t13.shape)
print("t13: ", t13, "\n")

shape:  torch.Size([3, 3, 3])
t15:  tensor([[[ 0,  1,  2],
         [ 3,  4,  5],
         [ 6,  7,  8]],

        [[ 9, 10, 11],
         [12, 13, 14],
         [15, 16, 17]],

        [[18, 19, 20],
         [21, 22, 23],
         [24, 25, 26]]]) 



In [16]:
# A "line" along dimension 0 of the original tensor (t13) is obtained by fixing ind1 and ind2
# and varying ind0. Listing these lines in row-major order of (ind1, ind2) gives:
# (0, 9, 18); (1, 10, 19); (2, 11, 20); (3, 12, 21); ... etc.
#
# The transposed tensor (t14) contains exactly the same lines along its dimension 2, because
# t14[ind2][ind1][ind0] == t13[ind0][ind1][ind2]. Only the position of each line changes: the
# line found at (ind1, ind2) in t13 is found at (ind2, ind1) in t14. Listing the lines of t14
# in row-major order of its first two indices therefore gives a different sequence of lines:
# (0, 9, 18); (3, 12, 21); (6, 15, 24); (1, 10, 19); ... etc.
# which is exactly what the printed t14 shows.
#
# So the verbal rule (traversing t13 along dimension 0 yields the same elements as traversing
# t14 along dimension 2) is correct for each individual line; it is only the order in which
# the lines themselves are enumerated that differs between the two tensors. When in doubt,
# fall back to the mathematical definition of the transpose operation.
t14 = t13.transpose(0, 2)
print("shape: ", t14.shape)
# The label names the tensor that is actually printed (t14).
print("t14: ", t14)

shape:  torch.Size([3, 3, 3])
t16:  tensor([[[ 0,  9, 18],
         [ 3, 12, 21],
         [ 6, 15, 24]],

        [[ 1, 10, 19],
         [ 4, 13, 22],
         [ 7, 16, 25]],

        [[ 2, 11, 20],
         [ 5, 14, 23],
         [ 8, 17, 26]]])


In [15]:
# Walk over every index triple of the original tensor and print its element next to the element
# of the transposed tensor found at the mirrored index (dimensions 0 and 2 swapped).
size0, size1, size2 = t13.shape
for i in range(size0):
    for j in range(size1):
        for k in range(size2):
            print("original_tensor_elem: ", t13[i, j, k], " : transposed_tensor_elem: ", t14[k, j, i])

# Basically, just like with 2D matrices, where Transposed[j][i] == Original[i][j], the
# transpose of a higher order tensor only swaps the two transposed indices:
# Transposed[k][j][i] == Original[i][j][k]. Visualizing this for higher order tensors is not easy.

original_tensor_elem:  tensor(0)  : transposed_tensor_elem:  tensor(0)
original_tensor_elem:  tensor(1)  : transposed_tensor_elem:  tensor(1)
original_tensor_elem:  tensor(2)  : transposed_tensor_elem:  tensor(2)
original_tensor_elem:  tensor(3)  : transposed_tensor_elem:  tensor(3)
original_tensor_elem:  tensor(4)  : transposed_tensor_elem:  tensor(4)
original_tensor_elem:  tensor(5)  : transposed_tensor_elem:  tensor(5)
original_tensor_elem:  tensor(6)  : transposed_tensor_elem:  tensor(6)
original_tensor_elem:  tensor(7)  : transposed_tensor_elem:  tensor(7)
original_tensor_elem:  tensor(8)  : transposed_tensor_elem:  tensor(8)
original_tensor_elem:  tensor(9)  : transposed_tensor_elem:  tensor(9)
original_tensor_elem:  tensor(10)  : transposed_tensor_elem:  tensor(10)
original_tensor_elem:  tensor(11)  : transposed_tensor_elem:  tensor(11)
original_tensor_elem:  tensor(12)  : transposed_tensor_elem:  tensor(12)
original_tensor_elem:  tensor(13)  : transposed_tensor_elem:  tensor(13

## [torch.Tensor.repeat](https://pytorch.org/docs/stable/generated/torch.Tensor.repeat.html#torch-tensor-repeat)

In [4]:
# repeat copies the tensor the specified number of times along each dimension.
# Build the 2x2x3 int64 tensor (values 1..12) used by the repeat examples.
t15_blocks = [
    [[1, 2, 3], [4, 5, 6]],
    [[7, 8, 9], [10, 11, 12]],
]
t15 = torch.tensor(data=t15_blocks, dtype=torch.int64)
print("shape: ", t15.shape)
print("t15: \n", t15)

shape:  torch.Size([2, 2, 3])
t15: 
 tensor([[[ 1,  2,  3],
         [ 4,  5,  6]],

        [[ 7,  8,  9],
         [10, 11, 12]]])


In [5]:
# Note: with one argument per existing dimension, as here, repeat does not change the number of
# dimensions of the tensor. It just copies the tensor along the specified dimensions.
# The arguments to repeat are the number of times to repeat the tensor along each dimension,
# starting from the 0th dimension:
#   2 --> Repeat the tensor twice along the 0th dimension.
#   1 --> Repeat the tensor once along the 1st dimension i.e., no change.
#   1 --> Repeat the tensor once along the 2nd dimension i.e., no change.
# So the original two 2D tensors are repeated twice along the 0th dimension.
t16 = t15.repeat(2, 1, 1)
t16_shape = t16.shape
print("shape: ", t16_shape)
print("t16: \n", t16)

shape:  torch.Size([4, 2, 3])
t16: 
 tensor([[[ 1,  2,  3],
         [ 4,  5,  6]],

        [[ 7,  8,  9],
         [10, 11, 12]],

        [[ 1,  2,  3],
         [ 4,  5,  6]],

        [[ 7,  8,  9],
         [10, 11, 12]]])


In [6]:
# Repeat only along the 1st dimension: inside each of the two 2D tensors, the pair of 1D
# tensors (rows) is repeated twice, so every 2D tensor grows from 2 rows to 4 rows.
t17 = t15.repeat(1, 2, 1)
t17_shape = t17.shape
print("shape: ", t17_shape)
print("t17: \n", t17)

shape:  torch.Size([2, 4, 3])
t17: 
 tensor([[[ 1,  2,  3],
         [ 4,  5,  6],
         [ 1,  2,  3],
         [ 4,  5,  6]],

        [[ 7,  8,  9],
         [10, 11, 12],
         [ 7,  8,  9],
         [10, 11, 12]]])


In [7]:
# Now, let's figure out how repeat works when it is given a value greater than 1 in multiple
# dimensions. It is equivalent to first applying the repeat operation along the 0th dimension,
# then applying the repeat operation along the 1st dimension, and so on.
t18 = t15.repeat(2, 2, 1)
t18_shape = t18.shape
print("shape: ", t18_shape)
print("t18: \n", t18)

shape:  torch.Size([4, 4, 3])
t18: 
 tensor([[[ 1,  2,  3],
         [ 4,  5,  6],
         [ 1,  2,  3],
         [ 4,  5,  6]],

        [[ 7,  8,  9],
         [10, 11, 12],
         [ 7,  8,  9],
         [10, 11, 12]],

        [[ 1,  2,  3],
         [ 4,  5,  6],
         [ 1,  2,  3],
         [ 4,  5,  6]],

        [[ 7,  8,  9],
         [10, 11, 12],
         [ 7,  8,  9],
         [10, 11, 12]]])
