## PyTorch Tensor Slicing and Concatenation

In [1]:
import torch

## Slicing and Concatenation

### Indexing and Slicing

Prepare target index.

In [2]:
x = torch.FloatTensor([[[1, 2],
                       [3, 4]],
                      [[5, 6],
                      [7, 8]],
                      [[9, 10],
                      [11, 12]]])
print(x.size())

torch.Size([3, 2, 2])


Access to certain dimension

In [3]:
print(x[0])
print(x[0, :])
print(x[0, :, :])

tensor([[1., 2.],
        [3., 4.]])
tensor([[1., 2.],
        [3., 4.]])
tensor([[1., 2.],
        [3., 4.]])


In [4]:
print(x[-1])
print(x[-1, :])
print(x[-1, :, :])

tensor([[ 9., 10.],
        [11., 12.]])
tensor([[ 9., 10.],
        [11., 12.]])
tensor([[ 9., 10.],
        [11., 12.]])


In [6]:
print(x[:, 0, :])

tensor([[ 1.,  2.],
        [ 5.,  6.],
        [ 9., 10.]])


Access by range. Note that the number of dimensions would not be changed.

In [7]:
print(x[1:3, : , :].size())
print(x[:, :1, :].size()) # 1이라는 dimension이 사라지지 않음.
print(x[:, :-1, :].size()) # 마지막 전까지 다 가져오는것도 똑같음.

torch.Size([2, 2, 2])
torch.Size([3, 1, 2])
torch.Size([3, 1, 2])


## Split : Split tensor to desirable shapes.

In [14]:
x = torch.FloatTensor(10, 4)

In [18]:
splits = x.split(4, dim = 0) # dimension 0번째를 4로 되게 쪼개고 나머지...
# 10 = 4 + 4 + 2

for s in splits:
    print(s.size())

torch.Size([4, 4])
torch.Size([4, 4])


In [16]:
splits

(tensor([[-1.2516e-12,  7.5250e-43, -9.9947e-10,  7.5250e-43],
         [ 1.4013e-45,  7.0065e-45, -1.6758e-12,  7.5250e-43],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  0.0000e+00]]),
 tensor([[1.4013e-45, 0.0000e+00, 0.0000e+00, 0.0000e+00],
         [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
         [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
         [0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00]]),
 tensor([[0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
         [2.1019e-44, 0.0000e+00, 0.0000e+00, 0.0000e+00]]))

## chunk : Split tensor to number of chunks.

In [17]:
x = torch.FloatTensor(8, 4)

In [19]:
chunks = x.chunk(3, dim = 0) # dimenison 0을 3개로 쪼개기

for c in chunks:
    print(c.size())

torch.Size([3, 4])
torch.Size([3, 4])
torch.Size([2, 4])


## index_select : Select elements by using dimension index.

In [20]:
x = torch.FloatTensor([[[1, 1],
                       [2, 2]],
                      [[3, 3],
                      [4, 4]],
                      [[5, 5],
                      [6, 6]]])
indice = torch.LongTensor([2, 1])
print(x.size())

torch.Size([3, 2, 2])


In [21]:
y = x.index_select(dim = 0, index = indice)

print(y)
print(y.size())

tensor([[[5., 5.],
         [6., 6.]],

        [[3., 3.],
         [4., 4.]]])
torch.Size([2, 2, 2])


## cat : Concatenation of multiple tensors in the list.

In [22]:
x = torch.FloatTensor([[1, 2, 3],
                      [4, 5, 6],
                      [7, 8, 9]])
y = torch.FloatTensor([[10, 11, 12],
                      [13, 14, 15],
                      [16, 17, 18]])
print(x.size(), y.size())

torch.Size([3, 3]) torch.Size([3, 3])


In [23]:
z = torch.cat([x, y], dim = 0)
print(z)
print(z.size())

tensor([[ 1.,  2.,  3.],
        [ 4.,  5.,  6.],
        [ 7.,  8.,  9.],
        [10., 11., 12.],
        [13., 14., 15.],
        [16., 17., 18.]])
torch.Size([6, 3])


In [25]:
z = torch.cat([x, y], dim = -1)
print(z)
print(z.size())

tensor([[ 1.,  2.,  3., 10., 11., 12.],
        [ 4.,  5.,  6., 13., 14., 15.],
        [ 7.,  8.,  9., 16., 17., 18.]])
torch.Size([3, 6])


## stack : Stacking of mulitple tensors in the list.

In [26]:
z = torch.stack([x, y])
print(z)
print(z.size())

tensor([[[ 1.,  2.,  3.],
         [ 4.,  5.,  6.],
         [ 7.,  8.,  9.]],

        [[10., 11., 12.],
         [13., 14., 15.],
         [16., 17., 18.]]])
torch.Size([2, 3, 3])


Or you can specify the dimension. Default is 0.

In [27]:
z = torch.stack([x, y], dim = -1)
print(z)
print(z.size())

tensor([[[ 1., 10.],
         [ 2., 11.],
         [ 3., 12.]],

        [[ 4., 13.],
         [ 5., 14.],
         [ 6., 15.]],

        [[ 7., 16.],
         [ 8., 17.],
         [ 9., 18.]]])
torch.Size([3, 3, 2])


### Implement 'stack' function by using 'cat'

In [28]:
z = torch.cat([x.unsqueeze(0), y.unsqueeze(0)], dim = 0)
print(z)
print(z.size())

tensor([[[ 1.,  2.,  3.],
         [ 4.,  5.,  6.],
         [ 7.,  8.,  9.]],

        [[10., 11., 12.],
         [13., 14., 15.],
         [16., 17., 18.]]])
torch.Size([2, 3, 3])


## Useful Trick : Merge results from iterations

In [30]:
result = [] # 빈 리스트를 만듦.
for i in range(5):
    x = torch.FloatTensor(2, 2)
    result += [x]

# print(result.size()) 'list' object has no attribute 'size'
result = torch.stack(result)
print(result.size())

torch.Size([5, 2, 2])


데이터가 너무 클때 사용할 수 있는 trick