# Useful functions

In [2]:
import torch

  from .autonotebook import tqdm as notebook_tqdm


## Expand

In [3]:
# A (2, 1, 2) float tensor: two outer entries, each holding one row of two values.
x = torch.tensor(
    [
        [[1, 2]],
        [[3, 4]],
    ],
    dtype=torch.float32,
)


def f(x):
    """Print a tensor followed by its size, on one line."""
    print(x, x.size())


f(x)

tensor([[[1., 2.]],

        [[3., 4.]]]) torch.Size([2, 1, 2])


In [4]:
# Stretch the size-1 middle axis of x to length 3 with expand.
y = x.expand(2, 3, 2)

f(y)

tensor([[[1., 2.],
         [1., 2.],
         [1., 2.]],

        [[3., 4.],
         [3., 4.],
         [3., 4.]]]) torch.Size([2, 3, 2])


### Hard way

In [5]:
# Same result as expand, built by concatenating three copies of x along dim 1.
y = torch.cat([x] * 3, dim=1)

f(y)

tensor([[[1., 2.],
         [1., 2.],
         [1., 2.]],

        [[3., 4.],
         [3., 4.],
         [3., 4.]]]) torch.Size([2, 3, 2])


## Random permutation

In [6]:
# randperm(10) draws a random ordering of the integers 0..9. No seed is set in
# this cell, so the printed values will differ from run to run.
x = torch.randperm(10)

f(x)

tensor([2, 4, 9, 6, 7, 5, 3, 0, 8, 1]) torch.Size([10])


## Argument max

In [7]:
# Shuffle the integers 0..26 and arrange them as a 3x3x3 tensor.
x = torch.randperm(27).reshape(3, 3, 3)

f(x)

tensor([[[ 0, 15,  1],
         [14, 22, 25],
         [19,  7, 17]],

        [[18, 10,  3],
         [16, 23, 11],
         [ 6,  8,  2]],

        [[26,  5, 20],
         [ 4,  9, 21],
         [13, 24, 12]]]) torch.Size([3, 3, 3])


In [8]:
# Index of the largest entry along the last axis.
y = torch.argmax(x, dim=-1)

f(y)

tensor([[1, 2, 0],
        [0, 1, 1],
        [0, 2, 1]]) torch.Size([3, 3])


## Top-k values and indices

In [9]:
# Keep only the single largest entry (k=1) along the last axis, together with
# its position.
values, indices = x.topk(k=1, dim=-1)

f(values)
f(indices)

tensor([[[15],
         [25],
         [19]],

        [[18],
         [23],
         [ 8]],

        [[26],
         [21],
         [24]]]) torch.Size([3, 3, 1])
tensor([[[1],
         [2],
         [0]],

        [[0],
         [1],
         [1]],

        [[0],
         [2],
         [1]]]) torch.Size([3, 3, 1])


In [10]:
# squeeze(-1) drops the trailing size-1 axis left by topk with k=1,
# turning the (3, 3, 1) results into (3, 3).
f(values.squeeze(-1))
f(indices.squeeze(-1))

tensor([[15, 25, 19],
        [18, 23,  8],
        [26, 21, 24]]) torch.Size([3, 3])
tensor([[1, 2, 0],
        [0, 1, 1],
        [0, 2, 1]]) torch.Size([3, 3])


In [12]:
# Positions of the two largest entries along the last axis; the values
# themselves are discarded.
_, indices = torch.topk(x, k=2, dim=-1)

f(indices)

tensor([[[1, 2],
         [2, 1],
         [0, 2]],

        [[0, 1],
         [1, 0],
         [1, 0]],

        [[0, 2],
         [2, 1],
         [1, 0]]]) torch.Size([3, 3, 2])


### Application

In [22]:
# Show x again for reference before comparing argmax with the topk indices.
f(x)

tensor([[[ 0, 15,  1],
         [14, 22, 25],
         [19,  7, 17]],

        [[18, 10,  3],
         [16, 23, 11],
         [ 6,  8,  2]],

        [[26,  5, 20],
         [ 4,  9, 21],
         [13, 24, 12]]]) torch.Size([3, 3, 3])


In [29]:
# Position of the maximum along the last axis.
f(torch.argmax(x, dim=-1))

tensor([[1, 2, 0],
        [0, 1, 1],
        [0, 2, 1]]) torch.Size([3, 3])


In [40]:
f(indices[:, :, 0]) # Dimension reduction: indexing the last axis with 0 drops that axis

tensor([[1, 2, 0],
        [0, 1, 1],
        [0, 2, 1]]) torch.Size([3, 3])


In [35]:
# The first topk index along the last axis should coincide with argmax.
torch.eq(x.argmax(dim=-1), indices[..., 0])

tensor([[True, True, True],
        [True, True, True],
        [True, True, True]])

#### Using topk

In [41]:
# Reference result: index of the maximum along the last axis.
x.argmax(dim=-1)

tensor([[1, 2, 0],
        [0, 1, 1],
        [0, 2, 1]])

In [47]:
# topk with k=1 keeps a trailing axis of length 1; only the indices are kept here.
_, indices = x.topk(k=1, dim=-1)

f(indices)

tensor([[[1],
         [2],
         [0]],

        [[0],
         [1],
         [1]],

        [[0],
         [2],
         [1]]]) torch.Size([3, 3, 1])


In [50]:
# After squeezing the trailing axis, the k=1 topk indices should equal argmax.
x.argmax(dim=-1).eq(indices.squeeze(-1))

tensor([[True, True, True],
        [True, True, True],
        [True, True, True]])

## Sort by topk()

In [68]:
# Draw a fresh random 3x3x3 tensor holding each of 0..26 exactly once.
x = torch.randperm(3 ** 3).reshape(3, 3, -1)
f(x)

tensor([[[21, 15,  0],
         [ 4,  5, 11],
         [ 6, 17, 24]],

        [[ 8, 26, 18],
         [19,  3, 20],
         [ 7, 22, 25]],

        [[14,  2,  1],
         [10, 13,  9],
         [23, 12, 16]]]) torch.Size([3, 3, 3])


In [72]:
target_dim = -1

# Asking topk for every element along target_dim (k == size of that axis)
# returns the whole axis in sorted order; largest=True means descending.
# dim must be passed explicitly: topk defaults to the last axis, so without it
# changing target_dim would size k from one axis while sorting another.
values, indices = x.topk(
    k=x.size(target_dim),
    dim=target_dim,
    largest=True,
    sorted=True,
)

f(values)

tensor([[[21, 15,  0],
         [11,  5,  4],
         [24, 17,  6]],

        [[26, 18,  8],
         [20, 19,  3],
         [25, 22,  7]],

        [[14,  2,  1],
         [13, 10,  9],
         [23, 16, 12]]]) torch.Size([3, 3, 3])


## Topk by sort()

In [74]:
k = 1

# Sort the last axis from largest to smallest, then keep its first k entries.
values, indices = torch.sort(x, dim=-1, descending=True)
values = values[..., :k]
indices = indices[..., :k]

f(values.squeeze(-1))
f(indices.squeeze(-1))

tensor([[21, 11, 24],
        [26, 20, 25],
        [14, 13, 23]]) torch.Size([3, 3])
tensor([[0, 2, 2],
        [1, 2, 2],
        [0, 1, 0]]) torch.Size([3, 3])


In [79]:
k = 1

# Sorts along dim 0 (the outermost axis), largest first.
values, indices = x.sort(dim=0, descending=True) # Descending
# NOTE(review): the slice below keeps the first k entries of the LAST axis, not
# of the sorted axis (dim 0), so this is not a top-k along dim 0 — that would be
# values[:k] / indices[:k]. Confirm which was intended.
values, indices = values[:, :, :k], indices[:, :, :k]

f(values.squeeze(-1))
f(indices.squeeze(-1))

tensor([[21, 19, 23],
        [14, 10,  7],
        [ 8,  4,  6]]) torch.Size([3, 3])
tensor([[0, 1, 2],
        [2, 2, 1],
        [1, 0, 0]]) torch.Size([3, 3])
