## PyTorch Tensor Slicing and Concatenation

In [1]:
import torch

## Slicing and Concatenation

### Indexing and Slicing
Prepare a target tensor of shape (3, 2, 2).

In [2]:
# A 3 x 2 x 2 float32 tensor: three 2x2 matrices stacked along dim 0.
x = torch.tensor(
    [
        [[1, 2], [3, 4]],
        [[5, 6], [7, 8]],
        [[9, 10], [11, 12]],
    ],
    dtype=torch.float32,
)
print(x.size())

torch.Size([3, 2, 2])


Access a specific index along a dimension. Indexing with an integer removes that dimension from the result.

In [None]:
# An integer index on dim 0 drops that dimension. Omitted trailing dimensions
# default to ':', so the first three expressions select the same 2x2 matrix.
for selected in (x[0], x[0, :], x[0, :, :], x[1, :, :], x[:, :, :]):
    print(selected)

In [None]:
# Negative indices count from the end, so -1 picks the last matrix along dim 0.
# All three forms are equivalent because omitted trailing dimensions mean ':'.
for selected in (x[-1], x[-1, :], x[-1, :, :]):
    print(selected)

In [4]:
# Fixing an index on an inner dimension drops only that dimension:
# row 0 of every matrix, then column 1 of every matrix.
first_rows = x[:, 0, :]
second_cols = x[:, :, 1]
print(first_rows)
print(second_cols)

tensor([[ 1.,  2.],
        [ 5.,  6.],
        [ 9., 10.]])
tensor([[ 2.,  4.],
        [ 6.,  8.],
        [10., 12.]])


Access by range (slicing). Unlike integer indexing, slicing keeps the sliced dimension, even when only one element remains, so the number of dimensions does not change.

In [5]:
# Range slices on dim 0 keep all three dimensions.
for dim0_range in (slice(1, 3), slice(0, 2), slice(0, 3)):
    print(x[dim0_range, :, :].size())

# ':1' and ':-1' on dim 1 both keep only row 0 here, because dim 1 has size 2;
# the sliced dimension survives with length 1.
for dim1_range in (slice(None, 1), slice(None, -1)):
    sliced = x[:, dim1_range, :]
    print(sliced.size())
    print(sliced)

torch.Size([2, 2, 2])
torch.Size([2, 2, 2])
torch.Size([3, 2, 2])
torch.Size([3, 1, 2])
tensor([[[ 1.,  2.]],

        [[ 5.,  6.]],

        [[ 9., 10.]]])
torch.Size([3, 1, 2])
tensor([[[ 1.,  2.]],

        [[ 5.,  6.]],

        [[ 9., 10.]]])


### split: Split a tensor into pieces of a given size along a dimension.

In [None]:
# torch.FloatTensor(10, 4) returns *uninitialized* memory, so the printed values
# would be arbitrary and change between runs. Use deterministic values instead.
x = torch.arange(40, dtype=torch.float32).view(10, 4)
print(x)

In [None]:
# An integer split size yields pieces of that many rows along dim 0. The last
# piece holds the remainder, so the 10-row x above splits into 4, 4, 2 rows.
splits = torch.split(x, 4, dim=0)
print(splits)

for piece in splits:
    print(piece.shape)

### chunk: Split a tensor into a given number of chunks.

In [None]:
# Deterministic 8 x 4 input; torch.FloatTensor(8, 4) would be uninitialized memory.
x = torch.arange(32, dtype=torch.float32).view(8, 4)

In [None]:
# chunk takes the number of pieces rather than their size. For the 8-row x
# above, each piece gets ceil(8 / 3) = 3 rows, giving sizes 3, 3, 2.
chunks = torch.chunk(x, 3, dim=0)

for piece in chunks:
    print(piece.size())

### index_select: Select slices along a dimension using a tensor of indices.

In [None]:
# Three 2x2 matrices whose entries make each row easy to identify after selection.
x = torch.tensor(
    [
        [[1, 1], [2, 2]],
        [[3, 3], [4, 4]],
        [[5, 5], [6, 6]],
    ],
    dtype=torch.float32,
)
# Positions to pick along dim 0, as an integer (int64) index tensor for index_select.
indice = torch.tensor([2, 1], dtype=torch.long)

print(x.size())
print(x)
print(indice)

In [None]:
# Gather the matrices at positions 2 and 1 of dim 0, in that order -> shape (2, 2, 2).
y = torch.index_select(x, 0, indice)

print(y)
print(y.size())

### cat: Concatenate a list of tensors along an existing dimension.

In [None]:
# Two 3x3 float32 matrices to combine in the cells below.
x = torch.tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9]], dtype=torch.float32)
y = torch.tensor([[10, 11, 12],
                  [13, 14, 15],
                  [16, 17, 18]], dtype=torch.float32)

print(x.size(), y.size())

In [None]:
# cat joins along an existing dimension; dim=0 appends y's rows below x's -> (6, 3).
z = torch.cat((x, y), dim=0)
print(z)
print(z.size())

In [None]:
# dim=-1 is the last dimension (columns for these 2-D inputs), giving shape (3, 6).
pair = (x, y)
z = torch.cat(pair, dim=-1)
print(z)
print(z.size())

### stack: Stack a list of tensors along a new dimension.

In [None]:
# stack inserts a new leading dimension (dim 0), so two 3x3 inputs give (2, 3, 3).
z = torch.stack((x, y), dim=0)
print(z)
print(z.size())

You can also specify where the new dimension is inserted. The default is `dim=0`.

In [None]:
# With dim=-1 the new axis goes last: z[i, j] holds [x[i, j], y[i, j]], shape (3, 3, 2).
z = torch.stack((x, y), dim=-1)
print(z)
print(z.size())

### Implementing `stack` with `cat`

In [None]:
# Equivalent to torch.stack([x, y]): give each tensor a leading axis of size 1,
# then concatenate along that new axis.
z = torch.cat([t.unsqueeze(0) for t in (x, y)], dim=0)
print(z)
print(z.size())

### Useful trick: merging results from loop iterations

In [None]:
# Collect one tensor per iteration in a Python list, then stack them once at the end.
step_outputs = []
for i in range(5):
    # Deterministic stand-in for a per-iteration result. torch.FloatTensor(2, 2)
    # would be uninitialized memory, and naming it `x` clobbered earlier cells' x.
    step_outputs.append(torch.full((2, 2), float(i), dtype=torch.float32))

result = torch.stack(step_outputs)  # shape (5, 2, 2): one 2x2 slice per iteration
result.size()