# Tensor Shape Manipulation

Notebook on how to use the different reshaping functions for high-dimensional tensors. Focuses on:
1. `Tensor.view()` / `torch.reshape()`
2. `torch.permute()`

## Example Tensor 

The examples in this notebook will start with the 4D tensor below with shape 2 x 3 x 2 x 2 (2 RGB 2x2 images).

```python
[
    [
        [
            [0,1],
            [2,3]
        ],
        [
            [4,5],
            [6,7]
        ],
        [
            [8,9],
            [10,11]
        ],
    ],
    [
        [
            [12,13],
            [14,15]
        ],
        [
            [16,17],
            [18,19]
        ],
        [
            [20,21],
            [22,23]
        ],
    ]
]
```

## How is this tensor stored in memory? 

[Great reference!](https://blog.ezyang.com/2019/05/pytorch-internals/)

Under the hood, PyTorch stores a tensor's data as a single one-dimensional, contiguous block of memory (its storage), together with metadata such as the size, device, dtype and **stride**. For dense (strided) tensors, the indexing logic relies on the stride: since all elements live in one flat block, the stride says how many elements to skip in memory to move one step along each dimension, and therefore which physical position to read for a given index.

For example, to access the bottom-left element of a 2x2 tensor we write `tensor[1,0]`. The stride of this tensor is `(2,1)`, so we read the element at offset 2x1 + 1x0 = 2. For a 2x3x2x2 tensor the stride is `(12,4,2,1)`. To get the very last element we write `tensor[1,2,1,1]`, which reads offset 12x1 + 4x2 + 2x1 + 1x1 = 23, i.e. the 24th (last) element of the block. For a contiguous tensor, the stride of each dimension is the product of the sizes of all dimensions to its right, and the last dimension has a stride of 1.
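A minimal sketch of the stride arithmetic described above. `contiguous_strides` is a hypothetical helper written for illustration (it is not a PyTorch function); everything else uses only `Tensor.stride()` and ordinary indexing.

```python
import torch


def contiguous_strides(shape):
    """Hypothetical helper: row-major strides, i.e. the product of all sizes to the right."""
    strides = []
    running = 1
    for size in reversed(shape):
        strides.append(running)
        running *= size
    return tuple(reversed(strides))


t = torch.arange(24).view(2, 3, 2, 2)

# For a contiguous tensor, PyTorch's stride matches the hand-computed one
print(t.stride(), contiguous_strides(t.shape))  # (12, 4, 2, 1) (12, 4, 2, 1)

# The memory offset of an index is the dot product of the index with the stride
index = (1, 2, 1, 1)
offset = sum(i * s for i, s in zip(index, t.stride()))
print(offset)  # 23

# Reading that offset from the flattened data gives the same element
print(t.flatten()[offset].item() == t[index].item())  # True
```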

In [173]:
# Imports

import torch
import time
from collections import defaultdict

In [174]:
EXAMPLE_TENSOR = torch.Tensor([
    [
        [
            [0,1],
            [2,3]
        ],
        [
            [4,5],
            [6,7]
        ],
        [
            [8,9],
            [10,11]
        ],
    ],
    [
        [
            [12,13],
            [14,15]
        ],
        [
            [16,17],
            [18,19]
        ],
        [
            [20,21],
            [22,23]
        ],
    ]
])
# I could have reshaped it, but the whole point is to learn how these work
# Check we wrote the tensor correctly
assert EXAMPLE_TENSOR.shape == (2,3,2,2)

### Benchmarking

In [175]:
# Util function for timing the execution time
from functools import wraps


def parent_decorator(num_iterations=10000):
    """Call the decorated function num_iterations times and print min/max/average/total run time.

    The wrapper only prints the stats; the decorated function's return value is discarded.
    """
    def average_execution_time(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            stats = defaultdict(int)
            stats["min_time"] = float("inf")
            for _ in range(num_iterations):
                # perf_counter is a monotonic, high-resolution clock, better suited to timing than time.time
                start_time = time.perf_counter()
                func(*args, **kwargs)
                elapsed = time.perf_counter() - start_time
                stats["total_time"] += elapsed
                stats["num_times"] += 1
                stats["max_time"] = max(elapsed, stats["max_time"])
                stats["min_time"] = min(elapsed, stats["min_time"])
            stats["avg_time"] = stats["total_time"] / stats["num_times"]
            print(stats)
        return wrapper
    return average_execution_time


@parent_decorator()
def reshape(t: torch.Tensor, shape: list[int]) -> torch.Tensor:
    return t.reshape(shape)

@parent_decorator()
def view(t: torch.Tensor, shape: list[int]) -> torch.Tensor:
    return t.view(shape)

@parent_decorator()
def permute(t: torch.Tensor, permute_shape: list[int]) -> torch.Tensor:
    return t.permute(permute_shape)

BENCHMARK_TENSOR = torch.cat([EXAMPLE_TENSOR for i in range(10000)])

In [176]:
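# The execution times of all three methods are similar (about 1 µs per call) because each one just returns a view:
# only the metadata (shape/stride) changes and no data is copied, so the cost does not grow with the tensor size
# A case that requires a copy can't be benchmarked this way, because view would raise an error there
# (permute never copies: it always returns a view)
# Takeaway: use view if you want an error whenever a copy would be required,
# and reshape if you just want it to work (it copies only when necessary)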
reshape(BENCHMARK_TENSOR, (20000,2,3,2))
view(BENCHMARK_TENSOR, (20000,2,3,2))


permute(BENCHMARK_TENSOR, (0,1,2,3))

defaultdict(<class 'int'>, {'min_time': 7.152557373046875e-07, 'total_time': 0.013417482376098633, 'num_times': 10000, 'max_time': 0.002130270004272461, 'avg_time': 1.3417482376098633e-06})
defaultdict(<class 'int'>, {'min_time': 7.152557373046875e-07, 'total_time': 0.01190328598022461, 'num_times': 10000, 'max_time': 0.0008130073547363281, 'avg_time': 1.190328598022461e-06})
defaultdict(<class 'int'>, {'min_time': 7.152557373046875e-07, 'total_time': 0.014023065567016602, 'num_times': 10000, 'max_time': 0.0009658336639404297, 'avg_time': 1.4023065567016602e-06})
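The benchmark above only times cases where no copy is needed. As a small sketch of the other case, assuming a permuted (and therefore non-contiguous) tensor: `view` raises a `RuntimeError`, while `reshape` falls back to copying. Comparing `data_ptr()` is used here only as a quick way to tell whether two tensors share memory.

```python
import torch

t = torch.arange(24).view(2, 3, 2, 2)

# Merging the first two dims is compatible with the strides, so reshape returns a view
same = t.reshape(6, 2, 2)
print(same.data_ptr() == t.data_ptr())  # True

# After swapping the last two dims, merging them cannot be expressed with strides alone
p = t.permute(0, 1, 3, 2)
try:
    p.view(2, 3, 4)
    view_ok = True
except RuntimeError:
    view_ok = False
print(view_ok)  # False

# reshape falls back to copying the data into a new contiguous tensor
r = p.reshape(2, 3, 4)
print(r.data_ptr() == t.data_ptr(), r.is_contiguous())  # False True
```

PyTorch's `reshape` documentation warns against depending on whether it returns a view or a copy, so treat the `data_ptr()` comparison as a diagnostic rather than something to build logic on.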


### View / Reshape

In [177]:
EXAMPLE_TENSOR.view((2,2,2,3))

tensor([[[[ 0.,  1.,  2.],
          [ 3.,  4.,  5.]],

         [[ 6.,  7.,  8.],
          [ 9., 10., 11.]]],


        [[[12., 13., 14.],
          [15., 16., 17.]],

         [[18., 19., 20.],
          [21., 22., 23.]]]])

In [178]:
EXAMPLE_TENSOR.reshape((2,2,2,3))

tensor([[[[ 0.,  1.,  2.],
          [ 3.,  4.,  5.]],

         [[ 6.,  7.,  8.],
          [ 9., 10., 11.]]],


        [[[12., 13., 14.],
          [15., 16., 17.]],

         [[18., 19., 20.],
          [21., 22., 23.]]]])

In [179]:
# PATCH EMBEDDING + MERGING EXAMPLE
# Imagine we had a batch of patch embeddings and wanted to turn them into a grid of patch embeddings
# Here we assume the embedding dimension is 3, so each image becomes 4 patches of 3 values (B x num_patches x C)
p_embed = EXAMPLE_TENSOR.view(2, -1, 3)
print(p_embed)
# Now we want a 2 x 2 grid of patches, i.e. split the middle (patch) dimension of size 4 into 2 x 2 (B x H x W x C)
# In a patch merging layer with a 2 x 2 merge window, the 4 neighbouring patches of each 2 x 2 block are
# concatenated along the channel dimension. Our grid is only 2 x 2, so all 4 rows of each image in the
# first printed tensor end up in a single merged patch
grid_embed = p_embed.view(2, 2, 2, 3)
print(grid_embed)

# Top-left, top-right, bottom-left and bottom-right patch of every 2 x 2 block
x0 = grid_embed[:, 0::2, 0::2, :]
x1 = grid_embed[:, 0::2, 1::2, :]
x2 = grid_embed[:, 1::2, 0::2, :]
x3 = grid_embed[:, 1::2, 1::2, :]
merged = torch.concat([x0, x1, x2, x3], dim=-1)
# Shape B x H/2 x W/2 x 4C: the 4 patches of each image are now laid out as one 12-value vector
print(merged)
# Finally we flatten it so we just have 3 dimensions (i.e. B x num_merged_patches x (4x3))
# The number of merged patches is 1 for each image
merged.view(2, -1, 12)

tensor([[[ 0.,  1.,  2.],
         [ 3.,  4.,  5.],
         [ 6.,  7.,  8.],
         [ 9., 10., 11.]],

        [[12., 13., 14.],
         [15., 16., 17.],
         [18., 19., 20.],
         [21., 22., 23.]]])
tensor([[[[ 0.,  1.,  2.],
          [ 3.,  4.,  5.]],

         [[ 6.,  7.,  8.],
          [ 9., 10., 11.]]],


        [[[12., 13., 14.],
          [15., 16., 17.]],

         [[18., 19., 20.],
          [21., 22., 23.]]]])
tensor([[[[ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11.]]],


        [[[12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23.]]]])


tensor([[[ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11.]],

        [[12., 13., 14., 15., 16., 17., 18., 19., 20., 21., 22., 23.]]])
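The slicing-and-concatenation merge above can also be written as a single view, permute and reshape. The sketch below compares both on a larger 4 x 6 grid; `merge_by_slicing` and `merge_by_permute` are illustrative names, and the input is assumed to be a contiguous B x H x W x C tensor with even H and W.

```python
import torch


def merge_by_slicing(x):
    """Patch merging with strided slices, as in the cell above. x: B x H x W x C, H and W even."""
    x0 = x[:, 0::2, 0::2, :]  # top-left patch of every 2 x 2 block
    x1 = x[:, 0::2, 1::2, :]  # top-right
    x2 = x[:, 1::2, 0::2, :]  # bottom-left
    x3 = x[:, 1::2, 1::2, :]  # bottom-right
    merged = torch.cat([x0, x1, x2, x3], dim=-1)  # B x H/2 x W/2 x 4C
    B, h, w, c4 = merged.shape
    return merged.view(B, h * w, c4)


def merge_by_permute(x):
    """Same result with one view, one permute and one reshape (no slicing or concatenation)."""
    B, H, W, C = x.shape
    # Split H into (H/2, dy) and W into (W/2, dx): B x H/2 x 2 x W/2 x 2 x C
    x = x.view(B, H // 2, 2, W // 2, 2, C)
    # Move the in-block offsets next to C, in (dy, dx) order to match x0, x1, x2, x3
    x = x.permute(0, 1, 3, 2, 4, 5)  # B x H/2 x W/2 x 2 x 2 x C
    return x.reshape(B, (H // 2) * (W // 2), 4 * C)


x = torch.arange(2 * 4 * 6 * 3).view(2, 4, 6, 3)
print(torch.equal(merge_by_slicing(x), merge_by_permute(x)))  # True
print(merge_by_permute(x).shape)  # torch.Size([2, 6, 12])
```

Both versions copy the data exactly once: the slicing version in `torch.cat`, the permute version inside `reshape`, because the permuted tensor is no longer contiguous.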

In [180]:
# Flattening is the same as reshaping: flatten is a convenience function that uses reshape under the hood
# unsqueeze is a similar convenience: it returns a view with an extra size-1 dimension inserted
print(EXAMPLE_TENSOR.flatten(start_dim=-2, end_dim=-1))
assert EXAMPLE_TENSOR.flatten(start_dim=-2, end_dim=-1).shape == (2, 3, 4)

print(EXAMPLE_TENSOR.reshape(2,3,4))

tensor([[[ 0.,  1.,  2.,  3.],
         [ 4.,  5.,  6.,  7.],
         [ 8.,  9., 10., 11.]],

        [[12., 13., 14., 15.],
         [16., 17., 18., 19.],
         [20., 21., 22., 23.]]])
tensor([[[ 0.,  1.,  2.,  3.],
         [ 4.,  5.,  6.,  7.],
         [ 8.,  9., 10., 11.]],

        [[12., 13., 14., 15.],
         [16., 17., 18., 19.],
         [20., 21., 22., 23.]]])


In [181]:
EXAMPLE_TENSOR.unsqueeze(3).shape

torch.Size([2, 3, 2, 1, 2])
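A quick check of both convenience functions, written as a sketch on an integer version of the example tensor (`torch.arange` instead of the float `EXAMPLE_TENSOR`):

```python
import torch

t = torch.arange(24).view(2, 3, 2, 2)

# flatten(start_dim, end_dim) merges dims start_dim..end_dim, exactly like the matching reshape
print(torch.equal(t.flatten(start_dim=-2, end_dim=-1), t.reshape(2, 3, 4)))  # True
print(torch.equal(t.flatten(), t.reshape(-1)))  # True

# unsqueeze(3) inserts a size-1 dim at position 3, the same as an explicit view, without copying
u = t.unsqueeze(3)
print(u.shape)  # torch.Size([2, 3, 2, 1, 2])
print(torch.equal(u, t.view(2, 3, 2, 1, 2)), u.data_ptr() == t.data_ptr())  # True True
```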

### Permute

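The main difference with `permute` is that it changes the order in which the elements are read, not where they are stored. `view` and `reshape` keep the elements in the same (row-major) order and only regroup them into a different shape, computing new strides to match. `permute` instead reorders the dimensions together with their strides, so the data in memory stays put but iterating over the result visits it in a different order.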
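A small sketch of what happens to the metadata, using an integer version of the example tensor: `permute` reorders the stride entries together with the dimensions, and the result reads from the same memory.

```python
import torch

t = torch.arange(24).view(2, 3, 2, 2)
p = t.permute(0, 1, 3, 2)

# The stride entries are reordered exactly like the dims: (12, 4, 2, 1) -> (12, 4, 1, 2)
print(t.stride(), p.stride())
# No data is moved: both tensors start at the same memory address
print(p.data_ptr() == t.data_ptr())  # True
# Same memory, different reading order: the first 2x2 matrix is now transposed
print(p[0, 0].tolist())  # [[0, 2], [1, 3]]
```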

In [182]:
# Permute the last two dimensions
# Visually, each 2x2 matrix is transposed: its rows become columns
# For example, [[0,1],[2,3]] -> [[0,2], [1,3]]
torch.permute(EXAMPLE_TENSOR, (0,1,3,2))

tensor([[[[ 0.,  2.],
          [ 1.,  3.]],

         [[ 4.,  6.],
          [ 5.,  7.]],

         [[ 8., 10.],
          [ 9., 11.]]],


        [[[12., 14.],
          [13., 15.]],

         [[16., 18.],
          [17., 19.]],

         [[20., 22.],
          [21., 23.]]]])

In [183]:
t = torch.Tensor([[0,1,2], [3,4,5]])
t.permute((1,0))

tensor([[0., 3.],
        [1., 4.],
        [2., 5.]])

In [184]:
# Here we reverse the order of the dimensions, so result[k, j, i] == t[i, j, k]:
# the original column dimension (size 3) becomes the new first dimension,
# the row dimension stays in the middle,
# and the original batch dimension (size 2) becomes the new last dimension, so moving along it
# steps to the same position in the next element of the original batch
# The new stride is the original stride with its entries reordered the same way: (6, 3, 1) -> (1, 3, 6)
# A reshape would not reorder the stride entries; it only computes new strides that match the reshaped dimensions
# Reordering the strides is what changes the order in which the elements are read
t = torch.Tensor([[[0,1,2],[3,4,5]], [[6,7,8],[9,10,11]]])
t.permute((2,1,0))

tensor([[[ 0.,  6.],
         [ 3.,  9.]],

        [[ 1.,  7.],
         [ 4., 10.]],

        [[ 2.,  8.],
         [ 5., 11.]]])
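To check the rule stated in the comments, a small sketch that compares the strides and every element of the permuted tensor (again using `torch.arange`, so the values are integers):

```python
import torch

t = torch.arange(12).view(2, 2, 3)
p = t.permute(2, 1, 0)

# The stride entries are reordered like the dims: (6, 3, 1) -> (1, 3, 6)
print(t.stride(), p.stride())

# Every element obeys result[k, j, i] == t[i, j, k]
all_match = all(
    p[k, j, i].item() == t[i, j, k].item()
    for i in range(2)
    for j in range(2)
    for k in range(3)
)
print(all_match)  # True
```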

In [185]:
# Permuting (and some kinds of slicing) makes a tensor non-contiguous, because the order of the elements in memory no longer matches their logical order
t.permute((2,1,0)).is_contiguous()

False
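Whether slicing breaks contiguity depends on what is sliced. A short sketch, together with `contiguous()`, which copies into a fresh row-major block only when needed:

```python
import torch

t = torch.arange(24).view(2, 3, 2, 2)

# Indexing the first dim keeps one contiguous block of memory
print(t[0].is_contiguous())  # True
# Taking every other entry along dim 1 skips over memory, so the result is not contiguous
print(t[:, ::2].is_contiguous())  # False

# contiguous() copies a non-contiguous tensor into a fresh row-major block...
s = t[:, ::2].contiguous()
print(s.is_contiguous(), s.stride())  # True (8, 4, 2, 1)
# ...and returns the tensor itself when it is already contiguous
print(t.contiguous() is t)  # True
```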

### Window Splitter Example

In [193]:
# Imagine we have a grid of (merged) patch embeddings
# Has shape B x H x W x C
# We will split H and W into windows of size M x M, so H and W must be divisible by M
# Create a 2 x 6 x 6 x 3 tensor (with M = 2 this gives a 3 x 3 grid of windows per image)
t = torch.arange(2*6*6*3).view((2,6,6,3))
t

tensor([[[[  0,   1,   2],
          [  3,   4,   5],
          [  6,   7,   8],
          [  9,  10,  11],
          [ 12,  13,  14],
          [ 15,  16,  17]],

         [[ 18,  19,  20],
          [ 21,  22,  23],
          [ 24,  25,  26],
          [ 27,  28,  29],
          [ 30,  31,  32],
          [ 33,  34,  35]],

         [[ 36,  37,  38],
          [ 39,  40,  41],
          [ 42,  43,  44],
          [ 45,  46,  47],
          [ 48,  49,  50],
          [ 51,  52,  53]],

         [[ 54,  55,  56],
          [ 57,  58,  59],
          [ 60,  61,  62],
          [ 63,  64,  65],
          [ 66,  67,  68],
          [ 69,  70,  71]],

         [[ 72,  73,  74],
          [ 75,  76,  77],
          [ 78,  79,  80],
          [ 81,  82,  83],
          [ 84,  85,  86],
          [ 87,  88,  89]],

         [[ 90,  91,  92],
          [ 93,  94,  95],
          [ 96,  97,  98],
          [ 99, 100, 101],
          [102, 103, 104],
          [105, 106, 107]]],


        [[[108

In [194]:
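# Split H into (H/M, M) and W into (W/M, M), giving B x H/M x M x W/M x M x C
# Permute to B x H/M x W/M x M x M x C so the window-index dims sit next to the batch dim,
# then reshape to fold them into the batch: (B x H/M x W/M) x M x M x C, one entry per M x M window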
windowed = t.view((2,3,2,3,2,3))
windowed.permute((0,1,3,2,4,5)).reshape(-1, 2,2,3)

tensor([[[[  0,   1,   2],
          [  3,   4,   5]],

         [[ 18,  19,  20],
          [ 21,  22,  23]]],


        [[[  6,   7,   8],
          [  9,  10,  11]],

         [[ 24,  25,  26],
          [ 27,  28,  29]]],


        [[[ 12,  13,  14],
          [ 15,  16,  17]],

         [[ 30,  31,  32],
          [ 33,  34,  35]]],


        [[[ 36,  37,  38],
          [ 39,  40,  41]],

         [[ 54,  55,  56],
          [ 57,  58,  59]]],


        [[[ 42,  43,  44],
          [ 45,  46,  47]],

         [[ 60,  61,  62],
          [ 63,  64,  65]]],


        [[[ 48,  49,  50],
          [ 51,  52,  53]],

         [[ 66,  67,  68],
          [ 69,  70,  71]]],


        [[[ 72,  73,  74],
          [ 75,  76,  77]],

         [[ 90,  91,  92],
          [ 93,  94,  95]]],


        [[[ 78,  79,  80],
          [ 81,  82,  83]],

         [[ 96,  97,  98],
          [ 99, 100, 101]]],


        [[[ 84,  85,  86],
          [ 87,  88,  89]],

         [[102, 103, 104],
     

In [199]:
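# What if we instead view the tensor directly with the window dims in their final place (B x H/M x W/M x M x M x C)?
# The identity permute does nothing, so this just regroups the row-major data: each "window" is 4 consecutive
# pixels in memory (e.g. 0-11 are pixels 0-3 of row 0, and 12-23 wrap from the end of row 0 into row 1),
# not an M x M spatial block, so this approach is incorrect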
windowed_2 = t.view((2,3,3,2,2,3))
windowed_2.permute((0,1,2,3,4,5)).reshape(-1, 2,2,3)

tensor([[[[  0,   1,   2],
          [  3,   4,   5]],

         [[  6,   7,   8],
          [  9,  10,  11]]],


        [[[ 12,  13,  14],
          [ 15,  16,  17]],

         [[ 18,  19,  20],
          [ 21,  22,  23]]],


        [[[ 24,  25,  26],
          [ 27,  28,  29]],

         [[ 30,  31,  32],
          [ 33,  34,  35]]],


        [[[ 36,  37,  38],
          [ 39,  40,  41]],

         [[ 42,  43,  44],
          [ 45,  46,  47]]],


        [[[ 48,  49,  50],
          [ 51,  52,  53]],

         [[ 54,  55,  56],
          [ 57,  58,  59]]],


        [[[ 60,  61,  62],
          [ 63,  64,  65]],

         [[ 66,  67,  68],
          [ 69,  70,  71]]],


        [[[ 72,  73,  74],
          [ 75,  76,  77]],

         [[ 78,  79,  80],
          [ 81,  82,  83]]],


        [[[ 84,  85,  86],
          [ 87,  88,  89]],

         [[ 90,  91,  92],
          [ 93,  94,  95]]],


        [[[ 96,  97,  98],
          [ 99, 100, 101]],

         [[102, 103, 104],
     

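The correct split from `In [194]` can be packaged as a pair of functions with an inverse, so windows can be put back into the grid after attention. This is a sketch: `window_partition` and `window_reverse` are illustrative names, and the input is assumed to be B x H x W x C with H and W divisible by the window size `m`.

```python
import torch


def window_partition(x, m):
    """B x H x W x C -> (B * H/m * W/m) x m x m x C; H and W must be divisible by m."""
    B, H, W, C = x.shape
    x = x.view(B, H // m, m, W // m, m, C)
    # B x H/m x W/m x m x m x C, then fold the window indices into the batch dim
    return x.permute(0, 1, 3, 2, 4, 5).reshape(-1, m, m, C)


def window_reverse(windows, m, B, H, W):
    """Inverse of window_partition: (B * H/m * W/m) x m x m x C -> B x H x W x C."""
    C = windows.shape[-1]
    x = windows.reshape(B, H // m, W // m, m, m, C)
    # Swapping dims 2 and 3 again undoes the permute: B x H/m x m x W/m x m x C
    return x.permute(0, 1, 3, 2, 4, 5).reshape(B, H, W, C)


t = torch.arange(2 * 6 * 6 * 3).view(2, 6, 6, 3)
windows = window_partition(t, 2)
print(windows.shape)  # torch.Size([18, 2, 2, 3])
print(torch.equal(window_reverse(windows, 2, 2, 6, 6), t))  # True
```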
### Window Self Attention Example

In [None]:
# Imagine we have a (B x H/M x W/M) x M^2 x 3C tensor to run windowed self-attention on (the output of the QKV projection)
# We need to split the projected tensor into 3 separate tensors (Q, K, V), each with a head dimension
# We want to end up with a 3 x (B x H/M x W/M) x HEADS x M^2 x (C / HEADS) tensor
# What is the difference if we reshape it directly versus permuting?

t = torch.arange(8 * 4 * 12).reshape(8, 4, 12)
# Here there are 8 windows of 4 patches, and each patch has a 12-value embedding: 3 (Q, K, V) x 2 heads x head dim 2
# Imagine that for the first patch, 0-3 are Q, 4-7 are K and 8-11 are V
# The first head would use 0-1 as Q, 4-5 as K and 8-9 as V
# Therefore, for attention to combine the right values, Q for the first window and first head should be
# [0,1],[12,13],[24,25],[36,37]: one 2-vector from each of the 4 patches in the window. This is the first element in the overall tensor.
# We can't keep the memory order, otherwise each Q block would only contain values from a single patch!
t

tensor([[[  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11],
         [ 12,  13,  14,  15,  16,  17,  18,  19,  20,  21,  22,  23],
         [ 24,  25,  26,  27,  28,  29,  30,  31,  32,  33,  34,  35],
         [ 36,  37,  38,  39,  40,  41,  42,  43,  44,  45,  46,  47]],

        [[ 48,  49,  50,  51,  52,  53,  54,  55,  56,  57,  58,  59],
         [ 60,  61,  62,  63,  64,  65,  66,  67,  68,  69,  70,  71],
         [ 72,  73,  74,  75,  76,  77,  78,  79,  80,  81,  82,  83],
         [ 84,  85,  86,  87,  88,  89,  90,  91,  92,  93,  94,  95]],

        [[ 96,  97,  98,  99, 100, 101, 102, 103, 104, 105, 106, 107],
         [108, 109, 110, 111, 112, 113, 114, 115, 116, 117, 118, 119],
         [120, 121, 122, 123, 124, 125, 126, 127, 128, 129, 130, 131],
         [132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143]],

        [[144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155],
         [156, 157, 158, 159, 160, 161, 162, 163, 164, 165, 166, 167],


In [None]:
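# Directly viewing to the output shape we want (3 x 8 x 2 x 4 x 2) is incorrect: view keeps the memory order,
# so the first "Q, window 0, head 0" block is [0..7], four 2-vectors all from the first patch (mixing its Q and K values)
# It should instead take one 2-vector from each of the first 4 patches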
t.view(3,8,2,4,2)

tensor([[[[[  0,   1],
           [  2,   3],
           [  4,   5],
           [  6,   7]],

          [[  8,   9],
           [ 10,  11],
           [ 12,  13],
           [ 14,  15]]],


         [[[ 16,  17],
           [ 18,  19],
           [ 20,  21],
           [ 22,  23]],

          [[ 24,  25],
           [ 26,  27],
           [ 28,  29],
           [ 30,  31]]],


         [[[ 32,  33],
           [ 34,  35],
           [ 36,  37],
           [ 38,  39]],

          [[ 40,  41],
           [ 42,  43],
           [ 44,  45],
           [ 46,  47]]],


         [[[ 48,  49],
           [ 50,  51],
           [ 52,  53],
           [ 54,  55]],

          [[ 56,  57],
           [ 58,  59],
           [ 60,  61],
           [ 62,  63]]],


         [[[ 64,  65],
           [ 66,  67],
           [ 68,  69],
           [ 70,  71]],

          [[ 72,  73],
           [ 74,  75],
           [ 76,  77],
           [ 78,  79]]],


         [[[ 80,  81],
           [ 82,  83],
    

In [None]:
# Instead, we first decompose the embedding dimension of each patch into (Q/K/V, head, head dim):
# 8 x 4 x 3 x 2 x 2, i.e. (window, patch, qkv, head, head_dim)
# Then we permute into the order we want to read the values in, from the innermost dim outwards:
#   - head_dim (size 2, position 4) stays last, as it holds e.g. the first head's 2-vector query for one patch
#   - next the within-window patch dim (size 4, position 1), collecting these vectors across the patches of a window
#   - then the head dim (position 3), grouping the same within-window patches per head
#   - then the window dim (position 0), repeating this for each window
#   - and finally the Q/K/V dim (position 2) goes first, separating Q, K and V
# Result: 3 x 8 x 2 x 4 x 2
t.view(8,4,3,2,2).permute(2,0,3,1,4)

tensor([[[[[  0,   1],
           [ 12,  13],
           [ 24,  25],
           [ 36,  37]],

          [[  2,   3],
           [ 14,  15],
           [ 26,  27],
           [ 38,  39]]],


         [[[ 48,  49],
           [ 60,  61],
           [ 72,  73],
           [ 84,  85]],

          [[ 50,  51],
           [ 62,  63],
           [ 74,  75],
           [ 86,  87]]],


         [[[ 96,  97],
           [108, 109],
           [120, 121],
           [132, 133]],

          [[ 98,  99],
           [110, 111],
           [122, 123],
           [134, 135]]],


         [[[144, 145],
           [156, 157],
           [168, 169],
           [180, 181]],

          [[146, 147],
           [158, 159],
           [170, 171],
           [182, 183]]],


         [[[192, 193],
           [204, 205],
           [216, 217],
           [228, 229]],

          [[194, 195],
           [206, 207],
           [218, 219],
           [230, 231]]],


         [[[240, 241],
           [252, 253],
    
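Continuing from the correct split, a sketch of how Q, K and V might then be consumed: `unbind` separates them, a plain scaled dot-product attention runs within each window and head, and a transpose plus reshape merges the heads back. The sizes are the ones assumed in this example (8 windows, 4 patches per window, 2 heads, head dim 2), and the values are converted to small floats purely to keep the arithmetic tame.

```python
import torch

num_windows, tokens, heads, head_dim = 8, 4, 2, 2
t = torch.arange(num_windows * tokens * 3 * heads * head_dim).reshape(num_windows, tokens, 3 * heads * head_dim)

# (window, patch, qkv, head, head_dim) -> (qkv, window, head, patch, head_dim)
qkv = t.view(num_windows, tokens, 3, heads, head_dim).permute(2, 0, 3, 1, 4)
q, k, v = qkv.unbind(0)  # each: num_windows x heads x tokens x head_dim
print(q[0, 0].tolist())  # [[0, 1], [12, 13], [24, 25], [36, 37]]

# Plain scaled dot-product attention within each window and head (values scaled down to small floats)
qf, kf, vf = (x.float() / t.numel() for x in (q, k, v))
scores = qf @ kf.transpose(-2, -1) * head_dim ** -0.5  # num_windows x heads x tokens x tokens
out = scores.softmax(dim=-1) @ vf  # num_windows x heads x tokens x head_dim

# Undo the head split: move tokens before heads, then merge heads x head_dim back into one dim
merged = out.transpose(1, 2).reshape(num_windows, tokens, heads * head_dim)
print(merged.shape)  # torch.Size([8, 4, 4])
```

The final `transpose(1, 2)` is the reverse of the head/patch swap done by the permute, which is why a plain `view` of `out` would again mix values from different patches.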