## torch view and permute

In [1]:
import numpy as np
import torch

In [2]:
# view vs. permute:
# - view(4, 2) re-reads the same row-major elements under a new shape,
# - permute(1, 0) reorders the axes (for a 2-D tensor this is a transpose),
# - permute(0, 1) keeps the original axis order, so it prints z unchanged.
z = torch.arange(2 * 4).view(2, 4)
for shown in (z, z.view(4, 2), z.permute(1, 0)):
    print(shown)
    print()
print(z.permute(0, 1))

tensor([[0, 1, 2, 3],
        [4, 5, 6, 7]])

tensor([[0, 1],
        [2, 3],
        [4, 5],
        [6, 7]])

tensor([[0, 4],
        [1, 5],
        [2, 6],
        [3, 7]])

tensor([[0, 1, 2, 3],
        [4, 5, 6, 7]])


In [3]:
x = torch.arange(4 * 3 * 2).view(4, 3, 2)
# permute(2, 0, 1) takes x's axes in the order (2, 0, 1):
# y's first axis is x's last axis, so y has shape (2, 4, 3).
y = x.permute(2, 0, 1)

for shown in (x, y):
    print(shown)
    print()
print(y.shape)

tensor([[[ 0,  1],
         [ 2,  3],
         [ 4,  5]],

        [[ 6,  7],
         [ 8,  9],
         [10, 11]],

        [[12, 13],
         [14, 15],
         [16, 17]],

        [[18, 19],
         [20, 21],
         [22, 23]]])

tensor([[[ 0,  2,  4],
         [ 6,  8, 10],
         [12, 14, 16],
         [18, 20, 22]],

        [[ 1,  3,  5],
         [ 7,  9, 11],
         [13, 15, 17],
         [19, 21, 23]]])

torch.Size([2, 4, 3])


## torch view and reshape

In [4]:
# view regroups each tensor's elements in their current logical order; -1 lets
# torch infer the remaining size. y is x with permuted axes (and is not
# contiguous), yet view still accepts this particular regrouping of y.
x_regrouped = x.view(2, 3, -1)
display(x_regrouped)
print()
y_regrouped = y.view(2, 3, -1)
display(y_regrouped)

tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11]],

        [[12, 13, 14, 15],
         [16, 17, 18, 19],
         [20, 21, 22, 23]]])




tensor([[[ 0,  2,  4,  6],
         [ 8, 10, 12, 14],
         [16, 18, 20, 22]],

        [[ 1,  3,  5,  7],
         [ 9, 11, 13, 15],
         [17, 19, 21, 23]]])

In [5]:
# Write one element through a reshape of x, then through a reshape of y, and
# check x[0, 0, 0] / y[0, 0, 0] after each write. Both reshapes here come back
# as views, so 101 and then 707 show up in x and y alike.
for step, (source, value) in enumerate([(x, 101), (y, 707)]):
    if step:
        print()
    z = source.reshape(2, 3, -1)
    z[0, 0, 0] = value
    print((x[0, 0, 0], y[0, 0, 0]))

(tensor(101), tensor(101))

(tensor(707), tensor(707))


In [6]:
# view needs a compatible (here: contiguous) memory layout; x is contiguous,
# so flattening it with view(-1) succeeds.
print(x.is_contiguous(), end="\n\n")
print(x.view(-1))

True

tensor([707,   1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
         14,  15,  16,  17,  18,  19,  20,  21,  22,  23])


In [7]:
# view fails on y: y is non-contiguous and cannot be flattened by only changing
# strides, so torch raises a RuntimeError that points to .reshape(...).
# reshape works on non-contiguous tensors (it behaves like contiguous() + view).
print(y.is_contiguous())
print()
try:
    print(y.view(-1))
except RuntimeError as e:
    # The error is the point of this cell, so show it instead of failing.
    print(e)
print()

False

view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.



In [8]:
# Flattening: x.reshape(-1) is a view (101 reaches y as well), whereas
# y.reshape(-1) has to copy, so writing 202 into it leaves x and y untouched.
for step, (source, value) in enumerate([(x, 101), (y, 202)]):
    if step:
        print()
    z = source.reshape(-1)
    z[0] = value
    print((x[0, 0, 0], y[0, 0, 0]))

(tensor(101), tensor(101))

(tensor(101), tensor(101))


In [9]:
# For the non-contiguous y, reshape(-1) yields the same values as contiguous().view(-1).
print(y.reshape(-1), end="\n\n")
print(y.contiguous().view(-1))

tensor([101,   2,   4,   6,   8,  10,  12,  14,  16,  18,  20,  22,   1,   3,
          5,   7,   9,  11,  13,  15,  17,  19,  21,  23])

tensor([101,   2,   4,   6,   8,  10,  12,  14,  16,  18,  20,  22,   1,   3,
          5,   7,   9,  11,  13,  15,  17,  19,  21,  23])


In [10]:
# contiguous() on the non-contiguous y returns a fresh copy, so writing into it
# does not reach x or y (both still show 101).
z = y.contiguous()
z[0, 0, 0] = 303  # write only the probed element, like the other cells (z[0] would fill a whole slice)
print((x[0, 0, 0], y[0, 0, 0]))

(tensor(101), tensor(101))


## torch transpose (not recommended: it swaps only two dims, while permute handles any axis order)

In [11]:
# transpose swaps exactly two dimensions (here dims 0 and 1 of the 3-D x); the
# argument order does not matter, so both calls print the same tensor.
# permute is the general form that can reorder any number of dimensions.
print(x.transpose(0, 1))
print()
print(x.transpose(1, 0))

tensor([[[101,   1],
         [  6,   7],
         [ 12,  13],
         [ 18,  19]],

        [[  2,   3],
         [  8,   9],
         [ 14,  15],
         [ 20,  21]],

        [[  4,   5],
         [ 10,  11],
         [ 16,  17],
         [ 22,  23]]])

tensor([[[101,   1],
         [  6,   7],
         [ 12,  13],
         [ 18,  19]],

        [[  2,   3],
         [  8,   9],
         [ 14,  15],
         [ 20,  21]],

        [[  4,   5],
         [ 10,  11],
         [ 16,  17],
         [ 22,  23]]])


## numpy transpose (recommended in numpy)

In [12]:
# Named `arr` (not `x`) so the torch tensor `x` used above is not shadowed by a
# numpy array; re-running earlier torch cells keeps working.
# For a 2-D array, .T, .transpose() and np.transpose(...) all give the same result.
arr = np.arange(4).reshape((2, 2))
np.testing.assert_array_equal(arr.T, arr.transpose())
np.testing.assert_array_equal(arr.T, np.transpose(arr))
display(arr, arr.T, arr.transpose(), np.transpose(arr))

array([[0, 1],
       [2, 3]])




array([[0, 2],
       [1, 3]])




array([[0, 2],
       [1, 3]])




array([[0, 2],
       [1, 3]])

In [13]:
# With an explicit axis order, numpy's transpose works like torch's permute:
# (1, 0, 2) swaps the first two axes, turning shape (1, 2, 3) into (2, 1, 3).
# Named `arr` (not `x`) so the torch tensor `x` is not shadowed.
arr = np.arange(6).reshape(1, 2, 3)
np.testing.assert_array_equal(arr.transpose((1, 0, 2)), np.transpose(arr, (1, 0, 2)))
display(arr, arr.transpose((1, 0, 2)), np.transpose(arr, (1, 0, 2)))

array([[[0, 1, 2],
        [3, 4, 5]]])




array([[[0, 1, 2]],

       [[3, 4, 5]]])




array([[[0, 1, 2]],

       [[3, 4, 5]]])