# Useful PyTorch tensor functions

In [2]:
import torch

  from .autonotebook import tqdm as notebook_tqdm


## Expand

In [3]:
# A (2, 1, 2) float32 tensor used by the "Expand" examples below.
x = torch.tensor(
    [
        [[1, 2]],
        [[3, 4]],
    ],
    dtype=torch.float32,
)


def f(x):
    """Print a tensor followed by its shape (``x.size()``).

    Args:
        x: tensor to display.
    """
    print(x, x.size())


f(x)

tensor([[[1., 2.]],

        [[3., 4.]]]) torch.Size([2, 1, 2])


In [4]:
# Broadcast the size-1 middle dimension of x (2, 1, 2) to length 3 -> (2, 3, 2).
# Sizes are passed directly; unpacking a list literal (*[2, 3, 2]) adds nothing.
y = x.expand(2, 3, 2)

f(y)

tensor([[[1., 2.],
         [1., 2.],
         [1., 2.]],

        [[3., 4.],
         [3., 4.],
         [3., 4.]]]) torch.Size([2, 3, 2])


### Hard way: concatenate copies with `torch.cat`

In [5]:
# Same (2, 3, 2) result as expand above, built by joining three copies of x along dim 1.
y = torch.cat([x] * 3, dim=1)
f(y)

tensor([[[1., 2.],
         [1., 2.],
         [1., 2.]],

        [[3., 4.],
         [3., 4.],
         [3., 4.]]]) torch.Size([2, 3, 2])


## Random permutation

In [6]:
# A shuffled arrangement of 0..9. No seed is set here, so the values differ between runs.
x = torch.randperm(10, dtype=torch.int64)
f(x)

tensor([1, 8, 3, 2, 7, 4, 0, 5, 9, 6]) torch.Size([10])


## Argument max

In [7]:
# 27 distinct integers (a permutation of 0..26) laid out as a 3x3x3 cube; distinct values mean no argmax ties.
x = torch.randperm(27).view(3, 3, 3)
f(x)

tensor([[[23,  6,  1],
         [22, 19,  4],
         [18, 17, 11]],

        [[13,  9,  2],
         [26,  0, 16],
         [14, 12,  5]],

        [[ 7, 20, 10],
         [25, 24, 15],
         [ 3,  8, 21]]]) torch.Size([3, 3, 3])


In [8]:
# Position of the largest entry along the last axis; that axis is removed -> (3, 3).
y = torch.argmax(x, dim=-1)
f(y)

tensor([[0, 0, 0],
        [0, 0, 0],
        [1, 0, 2]]) torch.Size([3, 3])


## Top-k values and indices

In [9]:
# Largest single entry (k=1) along the last axis; unlike argmax, the k axis is kept -> (3, 3, 1).
values, indices = x.topk(1, dim=-1)
f(values)
f(indices)

tensor([[[23],
         [22],
         [18]],

        [[13],
         [26],
         [14]],

        [[20],
         [25],
         [21]]]) torch.Size([3, 3, 1])
tensor([[[0],
         [0],
         [0]],

        [[0],
         [0],
         [0]],

        [[1],
         [0],
         [2]]]) torch.Size([3, 3, 1])


In [10]:
# Drop the trailing size-1 k axis so the shapes match argmax's (3, 3) output.
f(torch.squeeze(values, dim=-1))
f(torch.squeeze(indices, dim=-1))

tensor([[23, 22, 18],
        [13, 26, 14],
        [20, 25, 21]]) torch.Size([3, 3])
tensor([[0, 0, 0],
        [0, 0, 0],
        [1, 0, 2]]) torch.Size([3, 3])


In [11]:
# Indices of the two largest entries along the last axis -> (3, 3, 2).
# Read `.indices` from the returned namedtuple instead of unpacking values into `_`:
# rebinding `_` stops IPython from updating its last-output variable.
indices = x.topk(k=2, dim=-1).indices

f(indices)

tensor([[[0, 1],
         [0, 1],
         [0, 1]],

        [[0, 1],
         [0, 2],
         [0, 1]],

        [[1, 2],
         [0, 1],
         [2, 1]]]) torch.Size([3, 3, 2])


### Application: the first topk index equals argmax

In [12]:
# Show x again (the 3x3x3 tensor built in the "Argument max" section) for comparison.
f(x)

tensor([[[23,  6,  1],
         [22, 19,  4],
         [18, 17, 11]],

        [[13,  9,  2],
         [26,  0, 16],
         [14, 12,  5]],

        [[ 7, 20, 10],
         [25, 24, 15],
         [ 3,  8, 21]]]) torch.Size([3, 3, 3])


In [13]:
f(torch.argmax(x, dim=-1))  # reference result: per-row position of the maximum

tensor([[0, 0, 0],
        [0, 0, 0],
        [1, 0, 2]]) torch.Size([3, 3])


In [14]:
f(indices[..., 0])  # Dimension reduction: keep only the top-1 index, dropping the k axis

tensor([[0, 0, 0],
        [0, 0, 0],
        [1, 0, 2]]) torch.Size([3, 3])


In [15]:
torch.eq(x.argmax(dim=-1), indices[:, :, 0])  # element-wise: all True means top-1 of topk == argmax

tensor([[True, True, True],
        [True, True, True],
        [True, True, True]])

#### Using topk

In [16]:
torch.argmax(x, dim=-1)  # baseline to compare against topk(k=1) below

tensor([[0, 0, 0],
        [0, 0, 0],
        [1, 0, 2]])

In [17]:
# Index of the single largest entry along the last axis -> (3, 3, 1).
# `.indices` avoids rebinding `_`, which would stop IPython from updating its last-output variable.
indices = x.topk(k=1, dim=-1).indices

f(indices)

tensor([[[0],
         [0],
         [0]],

        [[0],
         [0],
         [0]],

        [[1],
         [0],
         [2]]]) torch.Size([3, 3, 1])


In [18]:
torch.eq(indices.squeeze(-1), x.argmax(dim=-1))  # squeeze the k=1 axis, then compare with argmax

tensor([[True, True, True],
        [True, True, True],
        [True, True, True]])

## Sorting via topk()

In [19]:
# Fresh 3x3x3 cube holding a permutation of 0..26 (distinct values, so orderings are unambiguous).
x = torch.randperm(27).view(3, 3, 3)
f(x)

tensor([[[ 5,  3, 14],
         [ 2, 21, 16],
         [10, 25, 18]],

        [[ 4, 12, 11],
         [22,  1,  8],
         [ 6,  7, 26]],

        [[24, 15,  0],
         [ 9, 13, 17],
         [20, 23, 19]]]) torch.Size([3, 3, 3])


In [20]:
target_dim = -1

# With k equal to the full size of `target_dim`, topk returns every element in
# descending order (largest=True, sorted=True), i.e. a descending sort along that dim.
# `dim` must be passed explicitly; otherwise topk always works on the last dim,
# no matter what `target_dim` says.
values, indices = x.topk(k=x.size(target_dim), dim=target_dim, largest=True, sorted=True)

f(values)

tensor([[[14,  5,  3],
         [21, 16,  2],
         [25, 18, 10]],

        [[12, 11,  4],
         [22,  8,  1],
         [26,  7,  6]],

        [[24, 15,  0],
         [17, 13,  9],
         [23, 20, 19]]]) torch.Size([3, 3, 3])


## Top-k via sort()

In [29]:
k = 1

# A descending sort along the last dim followed by keeping the first k entries of
# that same dim reproduces topk(k, dim=-1). `[..., :k]` slices the last axis for
# any rank, matching dim=-1 instead of assuming a 3-D tensor.
values, indices = x.sort(dim=-1, descending=True)
values, indices = values[..., :k], indices[..., :k]

# squeeze(-1) only removes the trailing axis when k == 1.
f(values.squeeze(-1))
f(indices.squeeze(-1))

tensor([[14, 21, 25],
        [12, 22, 26],
        [24, 17, 23]]) torch.Size([3, 3])
tensor([[2, 1, 1],
        [1, 0, 2],
        [0, 2, 1]]) torch.Size([3, 3])


## masked_fill: replace elements with a value wherever the mask is True

In [36]:
# 0..8 as float32, laid out as a 3x3 matrix. arange builds the sequence directly,
# with no Python list and no legacy FloatTensor constructor.
x = torch.arange(3 ** 2, dtype=torch.float32).reshape(3, -1)

f(x)

tensor([[0., 1., 2.],
        [3., 4., 5.],
        [6., 7., 8.]]) torch.Size([3, 3])


In [37]:
mask = torch.gt(x, 4)  # boolean tensor, True where the element is strictly greater than 4
f(mask)

tensor([[False, False, False],
        [False, False,  True],
        [ True,  True,  True]]) torch.Size([3, 3])


In [40]:
# New tensor equal to x, except that positions where mask is True hold -999.
y = x.masked_fill(mask, -999)
f(y)

tensor([[   0.,    1.,    2.],
        [   3.,    4., -999.],
        [-999., -999., -999.]]) torch.Size([3, 3])


## Ones and zeros

In [42]:
# Shape given as a tuple; torch.ones/torch.zeros also accept the sizes as separate ints.
f(torch.ones((2, 3)))
f(torch.zeros((2, 3)))

tensor([[1., 1., 1.],
        [1., 1., 1.]]) torch.Size([2, 3])
tensor([[0., 0., 0.],
        [0., 0., 0.]]) torch.Size([2, 3])


In [43]:
# A 2x3 float32 template for the *_like examples below.
# torch.tensor with an explicit dtype replaces the legacy FloatTensor constructor.
x = torch.tensor(
    [
        [1, 2, 3],
        [4, 5, 6],
    ],
    dtype=torch.float32,
)
f(x)

tensor([[1., 2., 3.],
        [4., 5., 6.]]) torch.Size([2, 3])


In [44]:
# Same shape as x; dtype=x.dtype spells out what *_like already inherits by default.
f(torch.ones_like(x, dtype=x.dtype))
f(torch.zeros_like(x, dtype=x.dtype))

tensor([[1., 1., 1.],
        [1., 1., 1.]]) torch.Size([2, 3])
tensor([[0., 0., 0.],
        [0., 0., 0.]]) torch.Size([2, 3])
