## PyTorch Tensor Manipulations

In [1]:
import torch

## Tensor Shaping

### reshape: Change Tensor Shape

In [2]:
# A 3-D float32 tensor of shape (3, 2, 2) holding the values 1..12.
# torch.tensor(..., dtype=...) is the recommended factory; torch.FloatTensor is a legacy constructor.
x = torch.tensor([[[1, 2],
                   [3, 4]],
                  [[5, 6],
                   [7, 8]],
                  [[9, 10],
                   [11, 12]]], dtype=torch.float32)
print(x.size())

torch.Size([3, 2, 2])


In [3]:
# Flatten all 3 * 2 * 2 = 12 elements into one dimension; -1 asks PyTorch to infer that length.
print(x.reshape((12,)))
print(x.reshape((-1,)))

tensor([ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12.])
tensor([ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12.])


In [5]:
# Regroup the 12 elements as 3 rows of 4 (3 * 4 == 3 * 2 * 2); with -1 the 4 is inferred.
print(x.reshape((3, 4)))
print(x.reshape((3, -1)))

tensor([[ 1.,  2.,  3.,  4.],
        [ 5.,  6.,  7.,  8.],
        [ 9., 10., 11., 12.]])
tensor([[ 1.,  2.,  3.,  4.],
        [ 5.,  6.,  7.,  8.],
        [ 9., 10., 11., 12.]])


In [6]:
# Insert a size-1 middle dimension. In the second call, -1 is inferred as 3 because 12 / (1 * 4) == 3.
print(x.reshape((3, 1, 4)))
print(x.reshape((-1, 1, 4)))

tensor([[[ 1.,  2.,  3.,  4.]],

        [[ 5.,  6.,  7.,  8.]],

        [[ 9., 10., 11., 12.]]])
tensor([[[ 1.,  2.,  3.,  4.]],

        [[ 5.,  6.,  7.,  8.]],

        [[ 9., 10., 11., 12.]]])


In [7]:
# Append a trailing size-1 dimension: (3, 2, 2) -> (3, 2, 2, 1).
print(x.reshape((3, 2, 2, 1)))

tensor([[[[ 1.],
          [ 2.]],

         [[ 3.],
          [ 4.]]],


        [[[ 5.],
          [ 6.]],

         [[ 7.],
          [ 8.]]],


        [[[ 9.],
          [10.]],

         [[11.],
          [12.]]]])


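`view()` can replace `reshape()` when the tensor's memory layout allows it (e.g. a contiguous tensor). It never copies data and raises an error otherwise.<br>
`reshape()` behaves roughly like `contiguous().view()`: it returns a view when possible and a copy otherwise.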

### squeeze: Remove dimensions of size 1

In [8]:
# A float32 tensor of shape (1, 2, 2): the leading dimension has size 1, so it can be squeezed away.
# torch.tensor(..., dtype=...) replaces the legacy torch.FloatTensor constructor.
x = torch.tensor([[[1, 2],
                   [3, 4]]], dtype=torch.float32)
print(x.size())

torch.Size([1, 2, 2])


Without an argument, `squeeze()` removes every dimension of size 1.

In [10]:
# With no argument, squeeze drops every size-1 dimension: (1, 2, 2) -> (2, 2).
print(x.squeeze())
print(x.squeeze().shape)

tensor([[1., 2.],
        [3., 4.]])
torch.Size([2, 2])


With an argument, `squeeze(dim)` removes only that dimension, and only if its size is 1; otherwise the tensor's shape is left unchanged.

In [11]:
# Dim 0 has size 1, so it is removed.
print(x.squeeze(dim=0).shape)
# Dim 1 has size 2, so squeeze is a no-op and the shape stays (1, 2, 2).
print(x.squeeze(dim=1).shape)

torch.Size([2, 2])
torch.Size([1, 2, 2])


### unsqueeze: Insert a size-1 dimension at a given index

In [12]:
# A 2-D float32 tensor of shape (2, 2) used for the unsqueeze examples.
# torch.tensor(..., dtype=...) replaces the legacy torch.FloatTensor constructor.
x = torch.tensor([[1, 2],
                  [3, 4]], dtype=torch.float32)
print(x.size())

torch.Size([2, 2])


In [13]:
# Three equivalent ways to turn (2, 2) into (2, 2, 1):
# insert at index 2, insert at the last position (-1), or reshape with an inferred last dim.
print(x.unsqueeze(dim=2))
print(x.unsqueeze(dim=-1))
print(x.reshape((2, 2, -1)))

tensor([[[1.],
         [2.]],

        [[3.],
         [4.]]])
tensor([[[1.],
         [2.]],

        [[3.],
         [4.]]])
tensor([[[1.],
         [2.]],

        [[3.],
         [4.]]])
