# Tensor Shape Manipulation

A notebook on how to use the different reshaping functions for high-dimensional tensors. It focuses on:
1. `Tensor.view()` / `torch.reshape()`
2. `torch.permute()`

## Example Tensor 

The examples in this notebook will start with the 4D tensor below with shape 2 x 3 x 2 x 2 (2 RGB 2x2 images).

```python
[
    [
        [
            [0,1],
            [2,3]
        ],
        [
            [4,5],
            [6,7]
        ],
        [
            [8,9],
            [10,11]
        ],
    ],
    [
        [
            [12,13],
            [14,15]
        ],
        [
            [16,17],
            [18,19]
        ],
        [
            [20,21],
            [22,23]
        ],
    ]
]
```

## How is this tensor stored in memory? 

[Great reference!](https://blog.ezyang.com/2019/05/pytorch-internals/)

Under the hood, PyTorch stores the data of a dense tensor as one contiguous block of memory, no matter how many dimensions the tensor has. Alongside the data it keeps metadata such as the size, device, dtype and **stride**. The stride drives the indexing logic: since the data is a single flat block, the stride tells us which physical position to read for a given logical index.

For example, to access the bottom-left element of a 2x2 tensor we write `tensor[1,0]`. The stride of this tensor is `(2,1)`, so we fetch the element at offset (2x1 + 1x0) = 2. A 2x3x2x2 tensor has stride `(12,4,2,1)`. To get the very last element we write `tensor[1,2,1,1]`, which reads offset (12x1 + 4x2 + 2x1 + 1x1) = 23. Offsets start at 0, so this is the 24th and final element. For a contiguous tensor, each stride is the product of all sizes to its right, and the last stride is always 1.
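
The offset arithmetic above can be checked directly. This is a minimal sketch: `example` is a stand-in built with `torch.arange` that holds the same values as the notebook's example tensor, and the names `index`, `offset` and `flat` are chosen here for illustration.

```python
import torch

# Stand-in for the 2x3x2x2 example tensor (values 0..23 in row-major order).
example = torch.arange(24, dtype=torch.float32).reshape(2, 3, 2, 2)

# For a contiguous tensor, each stride is the product of the sizes to its right.
strides = example.stride()
print(strides)  # (12, 4, 2, 1)

# Turn a logical index into an offset into the flat block of data.
index = (1, 2, 1, 1)
offset = sum(i * s for i, s in zip(index, strides))
print(offset)  # 23

# The element at that flat offset is the one the logical index refers to.
flat = example.flatten()
print(flat[offset] == example[index])  # tensor(True)
```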

In [29]:
# Imports

import torch
import time
from collections import defaultdict

In [5]:
EXAMPLE_TENSOR = torch.Tensor([
    [
        [
            [0,1],
            [2,3]
        ],
        [
            [4,5],
            [6,7]
        ],
        [
            [8,9],
            [10,11]
        ],
    ],
    [
        [
            [12,13],
            [14,15]
        ],
        [
            [16,17],
            [18,19]
        ],
        [
            [20,21],
            [22,23]
        ],
    ]
])
# I could have reshaped it, but the whole point is to learn how these work
# Check we wrote the tensor correctly
assert EXAMPLE_TENSOR.shape == (2,3,2,2)

### Benchmarking

In [128]:
# Util function for timing the execution time
# Note: the wrapper prints the timing stats and discards the wrapped function's return value
def parent_decorator(num_iterations=10000):
    def average_execution_time(func):
        def wrapper(*args, **kwargs):
            stats = defaultdict(int)
            stats["min_time"] = float("inf")
            for _ in range(num_iterations):
                # perf_counter is a monotonic, high-resolution clock, better suited to short timings than time.time
                start_time = time.perf_counter()
                func(*args, **kwargs)
                elapsed = time.perf_counter() - start_time
                stats["total_time"] += elapsed
                stats["num_times"] += 1
                stats["max_time"] = max(elapsed, stats["max_time"])
                stats["min_time"] = min(elapsed, stats["min_time"])
            stats["avg_time"] = stats["total_time"] / stats["num_times"]
            print(stats)
        return wrapper
    return average_execution_time


@parent_decorator()
def reshape(t: torch.Tensor, shape: list[int]) -> torch.Tensor:
    return t.reshape(shape)

@parent_decorator()
def view(t: torch.Tensor, shape: list[int]) -> torch.Tensor:
    return t.view(shape)

@parent_decorator()
def permute(t: torch.Tensor, permute_shape: list[int]) -> torch.Tensor:
    return t.permute(permute_shape)

BENCHMARK_TENSOR = torch.cat([EXAMPLE_TENSOR for i in range(10000)])
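
As an alternative to a hand-written timing decorator, the same comparison could be sketched with the standard library's `timeit` module. This is only an illustration: `benchmark_tensor` is a stand-in with the same shape as `BENCHMARK_TENSOR`, and `best_time` is a hypothetical helper, not part of the notebook. No expected timings are shown because they depend on the machine.

```python
import timeit

import torch

# Stand-in with the same shape as BENCHMARK_TENSOR: 20000 x 3 x 2 x 2.
benchmark_tensor = torch.arange(240_000, dtype=torch.float32).reshape(20_000, 3, 2, 2)
new_shape = (20_000, 2, 3, 2)


def best_time(fn, number=1_000, repeat=5):
    """Return the best per-call time in seconds over several repeats."""
    return min(timeit.repeat(fn, number=number, repeat=repeat)) / number


timings = {
    "reshape": best_time(lambda: benchmark_tensor.reshape(new_shape)),
    "view": best_time(lambda: benchmark_tensor.view(new_shape)),
    "permute": best_time(lambda: benchmark_tensor.permute(0, 1, 2, 3)),
}

for name, seconds in timings.items():
    print(f"{name}: {seconds * 1e6:.2f} us per call")
```

Taking the minimum over repeats reduces the influence of one-off spikes, which is the kind of outlier that shows up as a large `max_time`.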

In [135]:
# The execution times for each method are pretty similar, because here all of them return a view
# We can't benchmark view on a tensor that needs to be copied, because view raises an error in that case
# (permute always returns a view, and reshape falls back to a copy)
# Given these findings, use view if a required copy should raise an error
# If not having errors is preferred, then reshape is better
reshape(BENCHMARK_TENSOR, (20000,2,3,2))
view(BENCHMARK_TENSOR, (20000,2,3,2))


permute(BENCHMARK_TENSOR, (0,1,2,3))

defaultdict(<class 'int'>, {'min_time': 7.152557373046875e-07, 'total_time': 0.020767688751220703, 'num_times': 10000, 'max_time': 0.0032639503479003906, 'avg_time': 2.07676887512207e-06})
defaultdict(<class 'int'>, {'min_time': 7.152557373046875e-07, 'total_time': 0.02815866470336914, 'num_times': 10000, 'max_time': 0.0007233619689941406, 'avg_time': 2.8158664703369143e-06})
defaultdict(<class 'int'>, {'min_time': 7.152557373046875e-07, 'total_time': 0.01273488998413086, 'num_times': 10000, 'max_time': 2.193450927734375e-05, 'avg_time': 1.273488998413086e-06})
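
The difference between `view` and `reshape` only shows up when the requested shape cannot be expressed over the existing memory layout. The following is a minimal sketch of that case, using a small transposed tensor as an assumed example; the variable names are illustrative.

```python
import torch

base = torch.arange(6, dtype=torch.float32).reshape(2, 3)
transposed = base.permute(1, 0)  # shape (3, 2), stride (1, 3): not contiguous

# view never copies, so it raises when the layout is incompatible.
try:
    transposed.view(6)
    view_raised = False
except RuntimeError:
    view_raised = True
print(view_raised)  # True

# reshape falls back to copying the data into a new block of memory.
copied = transposed.reshape(6)
shares_storage = copied.data_ptr() == base.data_ptr()
print(copied.tolist())  # [0.0, 3.0, 1.0, 4.0, 2.0, 5.0]
print(shares_storage)  # False
```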


### View / Reshape

In [137]:
EXAMPLE_TENSOR.view((2,2,2,3))

tensor([[[[ 0.,  1.,  2.],
          [ 3.,  4.,  5.]],

         [[ 6.,  7.,  8.],
          [ 9., 10., 11.]]],


        [[[12., 13., 14.],
          [15., 16., 17.]],

         [[18., 19., 20.],
          [21., 22., 23.]]]])

In [138]:
EXAMPLE_TENSOR.reshape((2,2,2,3))

tensor([[[[ 0.,  1.,  2.],
          [ 3.,  4.,  5.]],

         [[ 6.,  7.,  8.],
          [ 9., 10., 11.]]],


        [[[12., 13., 14.],
          [15., 16., 17.]],

         [[18., 19., 20.],
          [21., 22., 23.]]]])
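
Both calls give the same result here because the tensor is contiguous, so `reshape` can return a view as well. A short sketch to confirm that a view shares memory with the original and only the shape and stride metadata differ (`example` is a stand-in with the same values as the example tensor):

```python
import torch

example = torch.arange(24, dtype=torch.float32).reshape(2, 3, 2, 2)
viewed = example.view(2, 2, 2, 3)

# Same underlying memory, new stride computed for the new shape.
print(viewed.stride())  # (12, 6, 3, 1)
print(viewed.data_ptr() == example.data_ptr())  # True

# Writing through the view changes the original tensor too.
viewed[0, 0, 0, 2] = -1.0
print(example[0, 0, 1, 0])  # tensor(-1.)
```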

In [None]:
# PATCH EMBEDDING + MERGING EXAMPLE
# Imagine we have a batch of patch embeddings and want to turn it into a grid of patch embeddings
# Here we assume the embedding dimension is 3
p_embed = EXAMPLE_TENSOR.view(2, -1, 3)
print(p_embed)
# Now we want a grid of patches, i.e. split the middle dimension of size 4 into 2 x 2
# In the patch merge layer, a merge window of 2 combines each 2x2 block of the grid into one vector
# Here the grid is only 2x2, so all 4 rows of the first printed tensor end up in the same vector
grid_embed = p_embed.view(2, 2, 2, 3)
print(grid_embed)

x0 = grid_embed[:, 0::2, 0::2, :]
x1 = grid_embed[:, 0::2, 1::2, :]
x2 = grid_embed[:, 1::2, 0::2, :]
x3 = grid_embed[:, 1::2, 1::2, :]
merged = torch.concat([x0, x1, x2, x3], dim=-1)
# The merged result holds, for each image, its four patch embeddings concatenated into one 12-element vector
print(merged)
# Finally we flatten it so we just have 3 dimensions (i.e B x num_merged_patches x (4x3))
# The num of merged patches is 1 for each image
merged.view(2, -1, 12)

tensor([[[ 0.,  1.,  2.],
         [ 3.,  4.,  5.],
         [ 6.,  7.,  8.],
         [ 9., 10., 11.]],

        [[12., 13., 14.],
         [15., 16., 17.],
         [18., 19., 20.],
         [21., 22., 23.]]])
tensor([[[[ 0.,  1.,  2.],
          [ 3.,  4.,  5.]],

         [[ 6.,  7.,  8.],
          [ 9., 10., 11.]]],


        [[[12., 13., 14.],
          [15., 16., 17.]],

         [[18., 19., 20.],
          [21., 22., 23.]]]])
tensor([[[[ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11.]]],


        [[[12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23.]]]])


tensor([[[ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11.]],

        [[12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23.]]])
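
With a 2x2 grid, every patch falls into the same merge window, which hides what the strided slicing does. The sketch below repeats the same steps on a hypothetical 4x4 grid with an embedding dimension of 1, so each value identifies one patch. The sizes are assumptions picked for readability.

```python
import torch

# Hypothetical single image with a 4x4 grid of patches, embedding dim 1.
# Patch values 0..15 are laid out row by row.
grid = torch.arange(16, dtype=torch.float32).view(1, 4, 4, 1)

x0 = grid[:, 0::2, 0::2, :]  # top-left patch of each 2x2 window
x1 = grid[:, 0::2, 1::2, :]  # top-right
x2 = grid[:, 1::2, 0::2, :]  # bottom-left
x3 = grid[:, 1::2, 1::2, :]  # bottom-right
merged = torch.cat([x0, x1, x2, x3], dim=-1)  # shape (1, 2, 2, 4)

# torch.cat returns a contiguous tensor, so view is safe here.
tokens = merged.view(1, -1, 4)
print(tokens)
# tensor([[[ 0.,  1.,  4.,  5.],
#          [ 2.,  3.,  6.,  7.],
#          [ 8.,  9., 12., 13.],
#          [10., 11., 14., 15.]]])
```

Each row of `tokens` collects one 2x2 window of the grid, so 16 patches become 4 merged patches with 4 times the embedding size.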

In [164]:
# Flattening is a special case of reshaping
# It is a convenience function that, like reshape, returns a view when possible and a copy otherwise
# unsqueeze is similar: it inserts a size-1 dimension and returns a view
print(EXAMPLE_TENSOR.flatten(start_dim=-2, end_dim=-1))
EXAMPLE_TENSOR.flatten(start_dim=-2, end_dim=-1).shape

print(EXAMPLE_TENSOR.reshape(2,3,4))

tensor([[[ 0.,  1.,  2.,  3.],
         [ 4.,  5.,  6.,  7.],
         [ 8.,  9., 10., 11.]],

        [[12., 13., 14., 15.],
         [16., 17., 18., 19.],
         [20., 21., 22., 23.]]])
tensor([[[ 0.,  1.,  2.,  3.],
         [ 4.,  5.,  6.,  7.],
         [ 8.,  9., 10., 11.]],

        [[12., 13., 14., 15.],
         [16., 17., 18., 19.],
         [20., 21., 22., 23.]]])


In [165]:
EXAMPLE_TENSOR.unsqueeze(3).shape

torch.Size([2, 3, 2, 1, 2])
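
A small sketch to check that `flatten` and `unsqueeze` match the equivalent `reshape` and `view` calls on a contiguous tensor (`example` again stands in for the example tensor):

```python
import torch

example = torch.arange(24, dtype=torch.float32).reshape(2, 3, 2, 2)

# unsqueeze(3) inserts a size-1 dimension at position 3.
unsqueezed = example.unsqueeze(3)
same_as_view = torch.equal(unsqueezed, example.view(2, 3, 2, 1, 2))
print(same_as_view)  # True

# Flattening the last two dimensions matches reshape(2, 3, 4).
flattened = example.flatten(start_dim=-2, end_dim=-1)
same_as_reshape = torch.equal(flattened, example.reshape(2, 3, 4))
print(same_as_reshape)  # True

# On a contiguous input, no data is copied.
flatten_shares = flattened.data_ptr() == example.data_ptr()
print(flatten_shares)  # True
```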

### Permute

The main difference with permute is that it changes the logical order in which elements are read. The underlying data in memory is left untouched; permute only reorders the sizes and strides. By contrast, view and reshape keep the elements in the same row-major order and only regroup them into a different shape.

In [None]:
# Permute the last two dimensions
# Visually, it means that for each 2x2 matrix, we collect the data across columns instead of across rows
# For example, [[0,1],[2,3]] -> [[0,2], [1,3]]
torch.permute(EXAMPLE_TENSOR, (0,1,3,2))

tensor([[[[ 0.,  2.],
          [ 1.,  3.]],

         [[ 4.,  6.],
          [ 5.,  7.]],

         [[ 8., 10.],
          [ 9., 11.]]],


        [[[12., 14.],
          [13., 15.]],

         [[16., 18.],
          [17., 19.]],

         [[20., 22.],
          [21., 23.]]]])

In [154]:
t = torch.Tensor([[0,1,2], [3,4,5]])
t.permute((1,0))

tensor([[0., 3.],
        [1., 4.],
        [2., 5.]])
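
To see that permuting is not the same as reshaping to the same target shape, compare both on this 2x3 tensor. A minimal sketch:

```python
import torch

t = torch.tensor([[0.0, 1.0, 2.0], [3.0, 4.0, 5.0]])

permuted = t.permute(1, 0)  # swaps the axes (a transpose)
reshaped = t.reshape(3, 2)  # keeps row-major order, regroups into 3 rows

print(permuted.tolist())  # [[0.0, 3.0], [1.0, 4.0], [2.0, 5.0]]
print(reshaped.tolist())  # [[0.0, 1.0], [2.0, 3.0], [4.0, 5.0]]
print(torch.equal(permuted, reshaped))  # False
```

Both results have shape 3x2, but only the reshape keeps the elements in their original reading order.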

In [159]:
# Here we reverse the order of the dimensions
# The original column dimension becomes the first (batch) dimension
# The row dimension stays in the middle
# The new last dimension is the original batch dimension, so each row pairs the same position across consecutive batch elements
# The new stride is the original stride (6,3,1) in reversed order, (1,3,6)
# A reshape would not reorder the strides; it would recompute them to match the reshaped dimensions
# Reordering the strides is what changes the order in which elements are read
t = torch.Tensor([[[0,1,2],[3,4,5]], [[6,7,8],[9,10,11]]])
t.permute((2,1,0))

tensor([[[ 0.,  6.],
         [ 3.,  9.]],

        [[ 1.,  7.],
         [ 4., 10.]],

        [[ 2.,  8.],
         [ 5., 11.]]])
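
The stride behaviour described in the comments can be inspected directly. In this sketch `t` is rebuilt with `torch.arange` and holds the same values as the tensor in the cell above:

```python
import torch

t = torch.arange(12, dtype=torch.float32).reshape(2, 2, 3)
print(t.stride())  # (6, 3, 1)

# permute reorders the existing strides along with the sizes.
permuted = t.permute(2, 1, 0)
print(permuted.stride())  # (1, 3, 6)

# reshape to the same target shape recomputes contiguous strides instead.
reshaped = t.reshape(3, 2, 2)
print(reshaped.stride())  # (4, 2, 1)

# No data moved: the permuted tensor reads the same memory,
# and permuted[i, j, k] is t[k, j, i].
print(permuted.data_ptr() == t.data_ptr())  # True
print(permuted[2, 0, 1] == t[1, 0, 2])  # tensor(True)
```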

In [None]:
# Permuting and slicing can make a tensor non-contiguous, as the physical order in memory no longer matches the logical order
t.permute((2,1,0)).is_contiguous()

False
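
If a contiguous layout is needed afterwards, for example to call `view`, `Tensor.contiguous()` copies the data into a fresh block in logical order. A minimal sketch, with `t` rebuilt via `torch.arange` to hold the same values as above:

```python
import torch

t = torch.arange(12, dtype=torch.float32).reshape(2, 2, 3)

permuted = t.permute(2, 1, 0)
sliced = t[:, :, ::2]  # keeps every other column, stride (6, 3, 2)
print(permuted.is_contiguous(), sliced.is_contiguous())  # False False

# contiguous() copies the data so memory order matches logical order again.
fixed = permuted.contiguous()
print(fixed.is_contiguous(), fixed.stride())  # True (4, 2, 1)
print(fixed.data_ptr() == t.data_ptr())  # False

# view now works, and follows the permuted reading order.
flat = fixed.view(-1)
print(flat.tolist())
# [0.0, 6.0, 3.0, 9.0, 1.0, 7.0, 4.0, 10.0, 2.0, 8.0, 5.0, 11.0]
```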