In [13]:
import torch

# Slicing and Concatenation

## Accessing

In [14]:
x = torch.FloatTensor([
    [
        [1, 2],
        [3, 4],
    ],
    [
        [5, 6],
        [7, 8],
    ],
    [
        [9, 10],
        [11, 12],
    ],
])

f = lambda x: print(x, x.size())

f(x)

tensor([[[ 1.,  2.],
         [ 3.,  4.]],

        [[ 5.,  6.],
         [ 7.,  8.]],

        [[ 9., 10.],
         [11., 12.]]]) torch.Size([3, 2, 2])


In [15]:
# Same works

f(x[0])
f(x[0, :])
f(x[0, :, :])

tensor([[1., 2.],
        [3., 4.]]) torch.Size([2, 2])
tensor([[1., 2.],
        [3., 4.]]) torch.Size([2, 2])
tensor([[1., 2.],
        [3., 4.]]) torch.Size([2, 2])


In [16]:
f(x[-1])
f(x[-1, :])
f(x[-1, :, :])

tensor([[ 9., 10.],
        [11., 12.]]) torch.Size([2, 2])
tensor([[ 9., 10.],
        [11., 12.]]) torch.Size([2, 2])
tensor([[ 9., 10.],
        [11., 12.]]) torch.Size([2, 2])


In [17]:
f(x[:, 0])

tensor([[ 1.,  2.],
        [ 5.,  6.],
        [ 9., 10.]]) torch.Size([3, 2])


In [19]:
f(x[1:3, :, :])
f(x[:, :1, :])
f(x[:, :-1, :])

tensor([[[ 5.,  6.],
         [ 7.,  8.]],

        [[ 9., 10.],
         [11., 12.]]]) torch.Size([2, 2, 2])
tensor([[[ 1.,  2.]],

        [[ 5.,  6.]],

        [[ 9., 10.]]]) torch.Size([3, 1, 2])
tensor([[[ 1.,  2.]],

        [[ 5.,  6.]],

        [[ 9., 10.]]]) torch.Size([3, 1, 2])


## Split

In [20]:
x = torch.FloatTensor(10, 4)
f(x)

tensor([[ 3.6367e-38,  1.4013e-45, -5.4010e+20,  7.0600e+17],
        [ 1.3210e-36,  1.4013e-45,  1.3370e-36,  1.4013e-45],
        [ 0.0000e+00, -0.0000e+00,  1.7025e-27,  1.5846e+29],
        [ 6.5240e-37,  1.4013e-45,  6.5366e-37,  1.4013e-45],
        [ 6.5373e-37,  1.4013e-45,  6.5368e-37,  1.4013e-45],
        [ 6.5373e-37,  1.4013e-45,  6.5387e-37,  1.4013e-45],
        [ 6.5285e-37,  1.4013e-45,  6.5388e-37,  1.4013e-45],
        [ 6.5376e-37,  1.4013e-45,  6.5389e-37,  1.4013e-45],
        [ 6.5377e-37,  1.4013e-45,  6.5390e-37,  1.4013e-45],
        [ 6.5377e-37,  1.4013e-45,  6.5321e-37,  1.4013e-45]]) torch.Size([10, 4])


In [25]:
splits = x.split(4, dim=0)

for s in splits:
    f(s)

tensor([[ 3.6367e-38,  1.4013e-45, -5.4010e+20,  7.0600e+17],
        [ 1.3210e-36,  1.4013e-45,  1.3370e-36,  1.4013e-45],
        [ 0.0000e+00, -0.0000e+00,  1.7025e-27,  1.5846e+29],
        [ 6.5240e-37,  1.4013e-45,  6.5366e-37,  1.4013e-45]]) torch.Size([4, 4])
tensor([[6.5373e-37, 1.4013e-45, 6.5368e-37, 1.4013e-45],
        [6.5373e-37, 1.4013e-45, 6.5387e-37, 1.4013e-45],
        [6.5285e-37, 1.4013e-45, 6.5388e-37, 1.4013e-45],
        [6.5376e-37, 1.4013e-45, 6.5389e-37, 1.4013e-45]]) torch.Size([4, 4])
tensor([[6.5377e-37, 1.4013e-45, 6.5390e-37, 1.4013e-45],
        [6.5377e-37, 1.4013e-45, 6.5321e-37, 1.4013e-45]]) torch.Size([2, 4])


## Chunk

In [26]:
x = torch.FloatTensor(8, 4)
f(x)

tensor([[0.0000e+00, 5.6221e-05, 0.0000e+00, 4.7684e-06],
        [0.0000e+00, 5.6220e-05, 0.0000e+00, 4.7684e-06],
        [0.0000e+00, 5.6221e-05, 0.0000e+00, 4.7684e-06],
        [0.0000e+00, 5.6222e-05, 0.0000e+00, 4.7684e-06],
        [0.0000e+00, 5.6212e-05, 0.0000e+00, 4.7684e-06],
        [0.0000e+00, 5.6222e-05, 0.0000e+00, 4.7684e-06],
        [0.0000e+00, 5.6221e-05, 0.0000e+00, 4.7684e-06],
        [0.0000e+00, 5.6222e-05, 0.0000e+00, 4.7684e-06]]) torch.Size([8, 4])


In [28]:
chunks = x.chunk(3, dim=0)

for c in chunks:
    f(c)

tensor([[0.0000e+00, 5.6221e-05, 0.0000e+00, 4.7684e-06],
        [0.0000e+00, 5.6220e-05, 0.0000e+00, 4.7684e-06],
        [0.0000e+00, 5.6221e-05, 0.0000e+00, 4.7684e-06]]) torch.Size([3, 4])
tensor([[0.0000e+00, 5.6222e-05, 0.0000e+00, 4.7684e-06],
        [0.0000e+00, 5.6212e-05, 0.0000e+00, 4.7684e-06],
        [0.0000e+00, 5.6222e-05, 0.0000e+00, 4.7684e-06]]) torch.Size([3, 4])
tensor([[0.0000e+00, 5.6221e-05, 0.0000e+00, 4.7684e-06],
        [0.0000e+00, 5.6222e-05, 0.0000e+00, 4.7684e-06]]) torch.Size([2, 4])


## Index selection

In [29]:
x = torch.FloatTensor([
    [
        [1, 2],
        [3, 4],
    ],
    [
        [5, 6],
        [7, 8],
    ],
    [
        [9, 10],
        [11, 12],
    ]
])

f(x)

tensor([[[ 1.,  2.],
         [ 3.,  4.]],

        [[ 5.,  6.],
         [ 7.,  8.]],

        [[ 9., 10.],
         [11., 12.]]]) torch.Size([3, 2, 2])


In [30]:
indice = torch.LongTensor([2, 1])
f(indice)

tensor([2, 1]) torch.Size([2])


In [32]:
y = x.index_select(0, indice)
f(y)

tensor([[[ 9., 10.],
         [11., 12.]],

        [[ 5.,  6.],
         [ 7.,  8.]]]) torch.Size([2, 2, 2])


## Concatenation

In [34]:
x = torch.FloatTensor([
    [1, 2, 3],
    [4, 5, 6],
    [7, 8, 9]
])
y = torch.FloatTensor([
    [11, 12, 13],
    [14, 15, 16],
    [17, 18, 19]
])

f(x)
f(y)

tensor([[1., 2., 3.],
        [4., 5., 6.],
        [7., 8., 9.]]) torch.Size([3, 3])
tensor([[11., 12., 13.],
        [14., 15., 16.],
        [17., 18., 19.]]) torch.Size([3, 3])


In [35]:
z = torch.cat([x, y], 0)
f(z)

tensor([[ 1.,  2.,  3.],
        [ 4.,  5.,  6.],
        [ 7.,  8.,  9.],
        [11., 12., 13.],
        [14., 15., 16.],
        [17., 18., 19.]]) torch.Size([6, 3])


In [36]:
z = torch.cat([x, y], -1)
f(z)

tensor([[ 1.,  2.,  3., 11., 12., 13.],
        [ 4.,  5.,  6., 14., 15., 16.],
        [ 7.,  8.,  9., 17., 18., 19.]]) torch.Size([3, 6])


## Stack: Increasing of Dimention

In [38]:
z = torch.stack([x, y])
z = torch.stack([x, y], 0)
f(z)

tensor([[[ 1.,  2.,  3.],
         [ 4.,  5.,  6.],
         [ 7.,  8.,  9.]],

        [[11., 12., 13.],
         [14., 15., 16.],
         [17., 18., 19.]]]) torch.Size([2, 3, 3])


In [40]:
z = torch.stack([x, y], -1)
f(z)

tensor([[[ 1., 11.],
         [ 2., 12.],
         [ 3., 13.]],

        [[ 4., 14.],
         [ 5., 15.],
         [ 6., 16.]],

        [[ 7., 17.],
         [ 8., 18.],
         [ 9., 19.]]]) torch.Size([3, 3, 2])


### Hard way: Using cat()

In [44]:
z = torch.cat([x.unsqueeze(0), y.unsqueeze(0)], 0)
f(z)

tensor([[[ 1.,  2.,  3.],
         [ 4.,  5.,  6.],
         [ 7.,  8.,  9.]],

        [[11., 12., 13.],
         [14., 15., 16.],
         [17., 18., 19.]]]) torch.Size([2, 3, 3])


## Useful tricks

In [47]:
result = []

for i in range(5):
    x = torch.FloatTensor(2, 2)
    result += [x]

result = torch.stack(result)
f(result)

tensor([[[ 0.0000e+00,  4.4766e+00],
         [ 1.7009e-27,  1.5846e+29]],

        [[ 2.5223e-44,  0.0000e+00],
         [ 1.7009e-27,  1.5846e+29]],

        [[ 0.0000e+00,  0.0000e+00],
         [ 1.7009e-27,  1.5846e+29]],

        [[ 0.0000e+00, -0.0000e+00],
         [ 1.6990e-27, -8.5899e+09]],

        [[        inf,  2.3694e-38],
         [ 2.3694e-38,  2.3694e-38]]]) torch.Size([5, 2, 2])
