In [8]:
import numpy as np
import torch
from einops import rearrange

# 2D case

In [19]:
# A small 2 x 3 integer array holding 0..5 in row-major order.
np_tensor = np.arange(2 * 3).reshape((2, 3))
np_tensor

array([[0, 1, 2],
       [3, 4, 5]])

In [20]:
# Same values as np_tensor, as a torch tensor, so the NumPy and PyTorch transposes below can be compared.
torch_tensor = torch.tensor(np_tensor)
torch_tensor

tensor([[0, 1, 2],
        [3, 4, 5]])

In [21]:
np_tensor.transpose(1, 0)  # swap rows and columns: (2, 3) -> (3, 2)

array([[0, 3],
       [1, 4],
       [2, 5]])

In [25]:
torch_tensor.transpose(0, 1)  # swap dim 0 and dim 1: (2, 3) -> (3, 2)

tensor([[0, 3],
        [1, 4],
        [2, 5]])

In [31]:
torch_tensor.transpose(1, 0)  # only the pair of dims matters: (1, 0) gives the same result as (0, 1) above



tensor([[0, 3],
        [1, 4],
        [2, 5]])

# 3D case

In [2]:
# A (2, 3, 4) array needs 2 * 3 * 4 = 24 elements; np.arange(6) cannot be reshaped to this shape.
np_tensor = np.arange(2 * 3 * 4).reshape(2, 3, 4)
np_tensor

array([[[ 0,  1,  2,  3],
        [ 4,  5,  6,  7],
        [ 8,  9, 10, 11]],

       [[12, 13, 14, 15],
        [16, 17, 18, 19],
        [20, 21, 22, 23]]])

In [3]:
# Rebuild torch_tensor from the 3-D np_tensor so np.transpose and einops.rearrange can be compared on it.
torch_tensor = torch.tensor(np_tensor)
torch_tensor

tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11]],

        [[12, 13, 14, 15],
         [16, 17, 18, 19],
         [20, 21, 22, 23]]])

In [6]:
np_tensor.transpose(2, 1, 0)  # reverse all three axes: (2, 3, 4) -> (4, 3, 2)

array([[[ 0, 12],
        [ 4, 16],
        [ 8, 20]],

       [[ 1, 13],
        [ 5, 17],
        [ 9, 21]],

       [[ 2, 14],
        [ 6, 18],
        [10, 22]],

       [[ 3, 15],
        [ 7, 19],
        [11, 23]]])

In [9]:
# einops version of np.transpose(np_tensor, (2, 1, 0)): the named axes are written in reverse order.
reverse_axes = 'b c h -> h c b'
rearrange(torch_tensor, reverse_axes)

tensor([[[ 0, 12],
         [ 4, 16],
         [ 8, 20]],

        [[ 1, 13],
         [ 5, 17],
         [ 9, 21]],

        [[ 2, 14],
         [ 6, 18],
         [10, 22]],

        [[ 3, 15],
         [ 7, 19],
         [11, 23]]])

In [36]:
# 24 consecutive values 0..23 as a float tensor of shape (4, 3, 2).
# torch.arange replaces the hand-typed list (and avoids reusing `A` for a list and then a tensor).
A = torch.arange(24, dtype=torch.float32).reshape(4, 3, 2)
A

tensor([[[ 0.,  1.],
         [ 2.,  3.],
         [ 4.,  5.]],

        [[ 6.,  7.],
         [ 8.,  9.],
         [10., 11.]],

        [[12., 13.],
         [14., 15.],
         [16., 17.]],

        [[18., 19.],
         [20., 21.],
         [22., 23.]]])

In [43]:
# A was already created with shape (4, 3, 2), so this reshape leaves its layout unchanged.
A_reshaped = torch.reshape(A, (4, 3, 2))
A_reshaped

tensor([[[ 0.,  1.],
         [ 2.,  3.],
         [ 4.,  5.]],

        [[ 6.,  7.],
         [ 8.,  9.],
         [10., 11.]],

        [[12., 13.],
         [14., 15.],
         [16., 17.]],

        [[18., 19.],
         [20., 21.],
         [22., 23.]]])

In [45]:
# Swap the first two axes: (4, 3, 2) -> (3, 4, 2).
# Together with the next cell this computes A.permute(1, 0, 2).reshape(12, 2).
A_permute = A.permute(1, 0, 2)
A_permute

tensor([[[ 0.,  1.],
         [ 6.,  7.],
         [12., 13.],
         [18., 19.]],

        [[ 2.,  3.],
         [ 8.,  9.],
         [14., 15.],
         [20., 21.]],

        [[ 4.,  5.],
         [10., 11.],
         [16., 17.],
         [22., 23.]]])

In [46]:
# Merge the first two axes of A_permute: (3, 4, 2) -> (12, 2).
A_mat = A_permute.flatten(0, 1)
A_mat

tensor([[ 0.,  1.],
        [ 6.,  7.],
        [12., 13.],
        [18., 19.],
        [ 2.,  3.],
        [ 8.,  9.],
        [14., 15.],
        [20., 21.],
        [ 4.,  5.],
        [10., 11.],
        [16., 17.],
        [22., 23.]])

In [47]:
# A is unchanged by the permute/reshape cells above: compare with its first display.
A

tensor([[[ 0.,  1.],
         [ 2.,  3.],
         [ 4.,  5.]],

        [[ 6.,  7.],
         [ 8.,  9.],
         [10., 11.]],

        [[12., 13.],
         [14., 15.],
         [16., 17.]],

        [[18., 19.],
         [20., 21.],
         [22., 23.]]])

In [53]:
# Show A again as the input to the einops axis split that follows.
A

tensor([[[ 0.,  1.],
         [ 2.,  3.],
         [ 4.,  5.]],

        [[ 6.,  7.],
         [ 8.,  9.],
         [10., 11.]],

        [[12., 13.],
         [14., 15.],
         [16., 17.]],

        [[18., 19.],
         [20., 21.],
         [22., 23.]]])

In [52]:
# Split the last axis (size 2) into d=2 and an inferred l=1: (4, 3, 2) -> (4, 3, 2, 1).
split_last_axis = 'b h (d l) -> b h d l'
rearrange(A, split_last_axis, d=2)

tensor([[[[ 0.],
          [ 1.]],

         [[ 2.],
          [ 3.]],

         [[ 4.],
          [ 5.]]],


        [[[ 6.],
          [ 7.]],

         [[ 8.],
          [ 9.]],

         [[10.],
          [11.]]],


        [[[12.],
          [13.]],

         [[14.],
          [15.]],

         [[16.],
          [17.]]],


        [[[18.],
          [19.]],

         [[20.],
          [21.]],

         [[22.],
          [23.]]]])

# 5D case: reversing all axes of a (1, 1, 1, 729, 1152) tensor

In [13]:
# torch is already imported in the first cell; no need to re-import it here.
# Fixed seed so the displayed slice is reproducible on Restart & Run All.
torch.manual_seed(0)
# NOTE(review): "iuput" looks like a typo for "input"; the name is kept because the next cell reads it.
torch_iuput = torch.randn(1, 1, 1, 729, 1152)
# Top-left 3 x 4 corner of the last two axes.
torch_iuput[0,0,0,:3,:4]

tensor([[-1.1258, -1.1524, -0.2506, -0.4339],
        [ 0.4047, -0.6549,  0.0521,  0.3401],
        [ 0.2245,  0.2179, -0.9257,  0.3448]])

In [14]:
# Reverse every axis: (1, 1, 1, 729, 1152) -> (1152, 729, 1, 1, 1), then drop the trailing singleton -> (1152, 729, 1, 1).
# The [:4, :3] corner is the transpose of the [:3, :4] slice displayed above.
torch_iuput.permute(4, 3, 2, 1, 0).squeeze(-1)[:4, :3, 0, 0]

tensor([[-1.1258,  0.4047,  0.2245],
        [-1.1524, -0.6549,  0.2179],
        [-0.2506,  0.0521, -0.9257],
        [-0.4339,  0.3401,  0.3448]])