## PyTorch Tensor Slicing and Concatenation

In [1]:
import torch

## Slicing and Concatenation

### Indexing and Slicing

Prepare target tensor.

In [2]:
x = torch.FloatTensor([[[1, 2],
                        [3, 4]],
                       [[5, 6],
                        [7, 8]],
                       [[9, 10],
                        [11, 12]]])
print(x.size())

torch.Size([3, 2, 2])


Access to certain dimension.

In [3]:
print(x[0])
print(x[0,:])
print(x[0,:,:])

tensor([[1., 2.],
        [3., 4.]])
tensor([[1., 2.],
        [3., 4.]])
tensor([[1., 2.],
        [3., 4.]])


In [4]:
print(x[-1])
print(x[-1,:])
print(x[-1,:,:])

tensor([[ 9., 10.],
        [11., 12.]])
tensor([[ 9., 10.],
        [11., 12.]])
tensor([[ 9., 10.],
        [11., 12.]])


In [5]:
print(x[:,0,:])

tensor([[ 1.,  2.],
        [ 5.,  6.],
        [ 9., 10.]])


In [6]:
print(x[1:3, :, :].size())
print(x[:, :1, :].size())
print(x[:, :-1, :].size())

torch.Size([2, 2, 2])
torch.Size([3, 1, 2])
torch.Size([3, 1, 2])


### split: Split tensor to desirable shapes.

In [7]:
x = torch.FloatTensor(10,4)
x

tensor([[1.8615e+25, 5.5358e-11, 4.3059e+21, 3.0734e+29],
        [1.8467e+20, 6.7293e-04, 4.7429e+30, 1.2567e+19],
        [6.7452e+25, 7.0816e+31, 1.6530e+19, 4.6530e+33],
        [6.5357e+28, 6.7131e+22, 1.1257e+24, 1.8467e+20],
        [6.7674e-04, 4.7429e+30, 1.2567e+19, 6.7452e+25],
        [7.0816e+31, 1.6530e+19, 4.6530e+33, 1.5997e+34],
        [1.4607e-19, 1.3472e-08, 7.5553e+28, 5.2839e-11],
        [1.2557e+19, 1.7612e+19, 6.7721e+22, 6.8906e+22],
        [2.7912e+03, 3.4453e-12, 1.1446e+24, 5.0778e+31],
        [4.2964e+24, 4.4422e-11, 7.3972e+31, 7.1560e+22]])

In [8]:
x.shape

torch.Size([10, 4])

In [9]:
x.size()

torch.Size([10, 4])

In [11]:
splits = x.split(4,dim=0) # 0번 dim이 4가 되도록 쪼갠다 ex: 10 = 4 + 4+ 2
splits

(tensor([[1.8615e+25, 5.5358e-11, 4.3059e+21, 3.0734e+29],
         [1.8467e+20, 6.7293e-04, 4.7429e+30, 1.2567e+19],
         [6.7452e+25, 7.0816e+31, 1.6530e+19, 4.6530e+33],
         [6.5357e+28, 6.7131e+22, 1.1257e+24, 1.8467e+20]]),
 tensor([[6.7674e-04, 4.7429e+30, 1.2567e+19, 6.7452e+25],
         [7.0816e+31, 1.6530e+19, 4.6530e+33, 1.5997e+34],
         [1.4607e-19, 1.3472e-08, 7.5553e+28, 5.2839e-11],
         [1.2557e+19, 1.7612e+19, 6.7721e+22, 6.8906e+22]]),
 tensor([[2.7912e+03, 3.4453e-12, 1.1446e+24, 5.0778e+31],
         [4.2964e+24, 4.4422e-11, 7.3972e+31, 7.1560e+22]]))

In [17]:
for s in splits: # 원하는 모양으로 쪼갠다.
    print(s.size())

torch.Size([4, 4])
torch.Size([4, 4])
torch.Size([2, 4])


### chunk: Split tensor to number of chunks.

In [19]:
x = torch.FloatTensor(8,4)
x

tensor([[1.0194e-38, 1.0469e-38, 1.0010e-38, 8.9081e-39],
        [8.9082e-39, 5.9694e-39, 8.9082e-39, 1.0194e-38],
        [9.1837e-39, 4.6837e-39, 6.9796e-39, 9.0000e-39],
        [1.0561e-38, 1.0653e-38, 4.1327e-39, 8.9082e-39],
        [9.8265e-39, 9.4592e-39, 1.0561e-38, 1.0653e-38],
        [1.0469e-38, 9.5510e-39, 1.0378e-38, 8.9082e-39],
        [1.0653e-38, 1.1204e-38, 1.0653e-38, 1.0194e-38],
        [8.4490e-39, 1.1020e-38, 1.0378e-38, 8.9082e-39]])

In [20]:
chunks = x.chunk(3, dim = 0) # 원하는 개수로 나눈다.

for c in chunks:
    print(c.size())

torch.Size([3, 4])
torch.Size([3, 4])
torch.Size([2, 4])


### index_select: Select elements by using dimension index.

In [21]:
x = torch.FloatTensor([[[1, 1],
                        [2, 2]],
                       [[3, 3],
                        [4, 4]],
                       [[5, 5],
                        [6, 6]]])
indice = torch.LongTensor([2,1])

print(x.size())

torch.Size([3, 2, 2])


In [22]:
y = x.index_select(dim = 0, index = indice) # 인덱스를 select 해서 합쳐준다.

print(y)
print(y.size())

tensor([[[5., 5.],
         [6., 6.]],

        [[3., 3.],
         [4., 4.]]])
torch.Size([2, 2, 2])


### cat: Concatenation of multiple tensors in the list.

In [23]:
x = torch.FloatTensor([[1, 2, 3],
                       [4, 5, 6],
                       [7, 8, 9]])
y = torch.FloatTensor([[10, 11, 12],
                       [13, 14, 15],
                       [16, 17, 18]])

print(x.size(), y.size())

torch.Size([3, 3]) torch.Size([3, 3])


In [24]:
z = torch.cat([x, y], dim=0)
print(z)
print(z.size())

tensor([[ 1.,  2.,  3.],
        [ 4.,  5.,  6.],
        [ 7.,  8.,  9.],
        [10., 11., 12.],
        [13., 14., 15.],
        [16., 17., 18.]])
torch.Size([6, 3])


In [25]:
z = torch.cat([x,y],dim = -1)
print(z)
print(z.size())

tensor([[ 1.,  2.,  3., 10., 11., 12.],
        [ 4.,  5.,  6., 13., 14., 15.],
        [ 7.,  8.,  9., 16., 17., 18.]])
torch.Size([3, 6])


In [26]:
z = torch.cat([x,y],dim = 1)
print(z)
print(z.size())

tensor([[ 1.,  2.,  3., 10., 11., 12.],
        [ 4.,  5.,  6., 13., 14., 15.],
        [ 7.,  8.,  9., 16., 17., 18.]])
torch.Size([3, 6])


### stack: Stacking of multiple tensors in the list.

In [27]:
z = torch.stack([x, y])
print(z)
print(z.size())

tensor([[[ 1.,  2.,  3.],
         [ 4.,  5.,  6.],
         [ 7.,  8.,  9.]],

        [[10., 11., 12.],
         [13., 14., 15.],
         [16., 17., 18.]]])
torch.Size([2, 3, 3])


Or you can specify the dimension. Default is 0.

In [28]:
z = torch.stack([x, y], dim=-1)
print(z)
print(z.size())

tensor([[[ 1., 10.],
         [ 2., 11.],
         [ 3., 12.]],

        [[ 4., 13.],
         [ 5., 14.],
         [ 6., 15.]],

        [[ 7., 16.],
         [ 8., 17.],
         [ 9., 18.]]])
torch.Size([3, 3, 2])


In [29]:
z = torch.stack([x, y], dim=0)
print(z)
print(z.size())

tensor([[[ 1.,  2.,  3.],
         [ 4.,  5.,  6.],
         [ 7.,  8.,  9.]],

        [[10., 11., 12.],
         [13., 14., 15.],
         [16., 17., 18.]]])
torch.Size([2, 3, 3])


### Implement 'stack' function by using 'cat'.

In [30]:
# z = torch.stack([x, y])
z = torch.cat([x.unsqueeze(0), y.unsqueeze(0)], dim=0)
print(z)
print(z.size())

tensor([[[ 1.,  2.,  3.],
         [ 4.,  5.,  6.],
         [ 7.,  8.,  9.]],

        [[10., 11., 12.],
         [13., 14., 15.],
         [16., 17., 18.]]])
torch.Size([2, 3, 3])


### Useful Trick: Merge results from iterations

In [31]:
result = []
for i in range(5):
    x = torch.FloatTensor(2, 2)
    result += [x]

result = torch.stack(result)
result.size()

torch.Size([5, 2, 2])