# PyTorch Tensor Manipulations

In [1]:
import torch

## Tensor Shaping

### reshape: Change Tensor Shape

In [5]:
# A 3 x 2 x 2 float32 tensor holding 1..12, used by the reshape examples below.
# torch.tensor(..., dtype=...) is the documented way to build a tensor from data;
# the typed constructor torch.FloatTensor is a legacy API.
x = torch.tensor([[[1, 2],
                   [3, 4]],
                  [[5, 6],
                   [7, 8]],
                  [[9, 10],
                   [11, 12]]], dtype=torch.float32)

print(x.size())

torch.Size([3, 2, 2])


In [6]:
# Flatten to 1-D: 12 == 3 * 2 * 2, and -1 asks PyTorch to infer that same length.
for flat_shape in ((12,), (-1,)):
    print(x.reshape(flat_shape))

tensor([ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12.])
tensor([ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10., 11., 12.])


In [7]:
# 3 * 4 == 3 * 2 * 2, so both shapes describe the same 3x4 matrix; -1 infers the 4.
for matrix_shape in ((3, 4), (3, -1)):
    print(x.reshape(matrix_shape))

tensor([[ 1.,  2.,  3.,  4.],
        [ 5.,  6.,  7.,  8.],
        [ 9., 10., 11., 12.]])
tensor([[ 1.,  2.,  3.,  4.],
        [ 5.,  6.,  7.,  8.],
        [ 9., 10., 11., 12.]])


In [7]:
# Insert a size-1 middle axis; -1 in the first slot infers 3 from the remaining sizes.
for cube_shape in ((3, 1, 4), (-1, 1, 4)):
    print(x.reshape(cube_shape))

tensor([[[ 1.,  2.,  3.,  4.]],

        [[ 5.,  6.,  7.,  8.]],

        [[ 9., 10., 11., 12.]]])
tensor([[[ 1.,  2.,  3.,  4.]],

        [[ 5.,  6.,  7.,  8.]],

        [[ 9., 10., 11., 12.]]])


In [None]:
print(torch.reshape(x, (3, 2, 2, 1)))  # append a trailing size-1 axis: (3, 2, 2) -> (3, 2, 2, 1)

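You can often use `view()` instead of `reshape()`. `view()` never copies data, so it only works when the new shape is compatible with the tensor's existing memory layout (strides), e.g. for a contiguous tensor. `reshape()` returns a view when possible and silently falls back to a copy otherwise.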

- https://pytorch.org/docs/stable/tensor_view.html
- https://pytorch.org/docs/stable/tensors.html?highlight=view#torch.Tensor.view

### squeeze: Remove dimensions of size 1

In [8]:
# A 1 x 2 x 2 float32 tensor; its leading dimension has size 1, so squeeze can drop it.
# Built with torch.tensor(..., dtype=...) rather than the legacy torch.FloatTensor constructor.
x = torch.tensor([[[1, 2],
                   [3, 4]]], dtype=torch.float32)
print(x.size())

torch.Size([1, 2, 2])


Called with no argument, `squeeze()` removes every dimension of size 1.

In [9]:
# squeeze() with no argument drops every size-1 dimension: (1, 2, 2) -> (2, 2).
y = x.squeeze()
print(y)
print(y.size())
# y has no size-1 dimensions left, so squeezing it again changes nothing.
y_resqueezed = y.squeeze()
print(y_resqueezed)
print(y_resqueezed.size())

tensor([[1., 2.],
        [3., 4.]])
torch.Size([2, 2])
tensor([[1., 2.],
        [3., 4.]])
torch.Size([2, 2])


Called with a dimension index, `squeeze(dim)` removes only that dimension, and only if its size is 1.
Otherwise the shape stays the same. Negative indices count from the last dimension.

In [14]:
print(x)
print(x.size())
# Only dim 0 (equivalently -3 for this 3-D tensor) has size 1, so only those calls
# drop a dimension; squeezing dims 1 and 2 leaves the shape as it was.
for squeeze_dim in (0, 1, 2, -3):
    squeezed = x.squeeze(squeeze_dim)
    print(squeezed.size())
    print(squeezed)

tensor([[[1., 2.],
         [3., 4.]]])
torch.Size([1, 2, 2])
torch.Size([2, 2])
tensor([[1., 2.],
        [3., 4.]])
torch.Size([1, 2, 2])
tensor([[[1., 2.],
         [3., 4.]]])
torch.Size([1, 2, 2])
tensor([[[1., 2.],
         [3., 4.]]])
torch.Size([2, 2])
tensor([[1., 2.],
        [3., 4.]])


### unsqueeze: Insert a size-1 dimension at a given index

In [15]:
# A 2 x 2 float32 tensor that the unsqueeze examples add a size-1 dimension to.
# Built with torch.tensor(..., dtype=...) rather than the legacy torch.FloatTensor constructor.
x = torch.tensor([[1, 2],
                  [3, 4]], dtype=torch.float32)
print(x.size())

torch.Size([2, 2])


In [17]:
# unsqueeze(0) prepends a size-1 dimension; it returns a new tensor and x keeps its shape.
x_leading = x.unsqueeze(0)
print(x_leading)
print(x_leading.size())
print(x.size())

tensor([[[1., 2.],
         [3., 4.]]])
torch.Size([1, 2, 2])
torch.Size([2, 2])


In [18]:
# Three equivalent ways to append a trailing size-1 dimension to the 2 x 2 tensor.
for with_trailing_axis in (x.unsqueeze(2), x.unsqueeze(-1), x.reshape(2, 2, -1)):
    print(with_trailing_axis)
    print(with_trailing_axis.size())
print(x.size())

tensor([[[1.],
         [2.]],

        [[3.],
         [4.]]])
torch.Size([2, 2, 1])
tensor([[[1.],
         [2.]],

        [[3.],
         [4.]]])
torch.Size([2, 2, 1])
tensor([[[1.],
         [2.]],

        [[3.],
         [4.]]])
torch.Size([2, 2, 1])
torch.Size([2, 2])
