# PyTorch Useful Methods

In [9]:
import torch

### expand: copy the given tensor and concat those at desired dimension.

In [10]:
x = torch.FloatTensor([[[1, 2]],
                       [[3, 4]]])
print(x)
print(x.size())

tensor([[[1., 2.]],

        [[3., 4.]]])
torch.Size([2, 1, 2])


In [11]:
y = x.expand(*[2, 3, 2])

print(y)
print(y.size())

tensor([[[1., 2.],
         [1., 2.],
         [1., 2.]],

        [[3., 4.],
         [3., 4.],
         [3., 4.]]])
torch.Size([2, 3, 2])


#### Implement expand with cat.

In [16]:
z = torch.cat([x, x, x], dim=1)

print(z)
print(z.size())

z = torch.cat([x, x, x], dim=0)

print(x)
print(x.size())
print(z)
print(z.size())


tensor([[[1., 2.],
         [1., 2.],
         [1., 2.]],

        [[3., 4.],
         [3., 4.],
         [3., 4.]]])
torch.Size([2, 3, 2])
tensor([[[1., 2.]],

        [[3., 4.]]])
torch.Size([2, 1, 2])
tensor([[[1., 2.]],

        [[3., 4.]],

        [[1., 2.]],

        [[3., 4.]],

        [[1., 2.]],

        [[3., 4.]]])
torch.Size([6, 1, 2])


### randperm: Random Permutation

In [18]:
x = torch.randperm(10)

print(x)
print(x.size())

tensor([7, 3, 1, 9, 8, 2, 5, 0, 6, 4])
torch.Size([10])


### argmax: Return index of maximum values

In [81]:
x = torch.randperm(3**3).reshape(3, 3, -1)
print(x)
print(x.size())

tensor([[[25, 18,  7],
         [ 5, 23,  2],
         [16, 15, 17]],

        [[ 4, 12, 19],
         [ 1, 20, 22],
         [10,  8, 26]],

        [[ 3,  9, 14],
         [13,  6, 24],
         [21,  0, 11]]])
torch.Size([3, 3, 3])


In [None]:
y = x.argmax(dim=-1)

print(y)
print(y.size())

### topk: Return tuple of top-k values and indices.

In [None]:
values, indices = torch.topk(x, k=1, dim=-1)

print(values)
print(values.size())
print(indices)
print(indices.size())

Note that topk didn't reduce the dimension, even in $k=1$ case.

In [None]:
print(values.squeeze(-1))
print(values.squeeze(-1).size())
print(indices.squeeze(-1))
print(indices.squeeze(-1).size())

In [None]:
print(x.argmax(dim=-1))
print(indices.squeeze(-1))
print(x.argmax(dim=-1) == indices.squeeze(-1))

print(x.argmax(dim=-1).size())
print(indices)
print(x.argmax(dim=-1) == indices)

In [93]:
_, indices = torch.topk(x, k=2, dim=-1)
print(x)
print(x.argmax(dim=-1))
print(x.argmax(dim=-1).size())
print(indices)
print(indices.size())
print(indices[:, :, 0])
print(x.argmax(dim=-1) == indices[:, :, 0])

tensor([[[25, 18,  7],
         [ 5, 23,  2],
         [16, 15, 17]],

        [[ 4, 12, 19],
         [ 1, 20, 22],
         [10,  8, 26]],

        [[ 3,  9, 14],
         [13,  6, 24],
         [21,  0, 11]]])
tensor([[0, 1, 2],
        [2, 2, 2],
        [2, 2, 0]])
torch.Size([3, 3])
tensor([[[0, 1],
         [1, 0],
         [2, 0]],

        [[2, 1],
         [2, 1],
         [2, 0]],

        [[2, 1],
         [2, 0],
         [0, 2]]])
torch.Size([3, 3, 2])
tensor([[0, 1, 2],
        [2, 2, 2],
        [2, 2, 0]])
tensor([[True, True, True],
        [True, True, True],
        [True, True, True]])


### Sort by using topk

In [94]:

target_dim = -1

values, indices = torch.topk(x,
                             k=x.size(target_dim),
                             largest=True)

print(values)


tensor([[[25, 18,  7],
         [23,  5,  2],
         [17, 16, 15]],

        [[19, 12,  4],
         [22, 20,  1],
         [26, 10,  8]],

        [[14,  9,  3],
         [24, 13,  6],
         [21, 11,  0]]])


### Topk by using sort

In [121]:
print(x)
print(x[:,:,1])

tensor([[[25, 18,  7],
         [ 5, 23,  2],
         [16, 15, 17]],

        [[ 4, 12, 19],
         [ 1, 20, 22],
         [10,  8, 26]],

        [[ 3,  9, 14],
         [13,  6, 24],
         [21,  0, 11]]])
tensor([[18, 23, 15],
        [12, 20,  8],
        [ 9,  6,  0]])


In [134]:
k=1
#print(x)
values, indices = torch.sort(x, dim=-1, descending=True)
print(values)
#print(indices)
#print(values[:, :, :1])
print(values[:, :, :2].squeeze())

values, indices = values[:, :, :k], indices[:, :, :k]
#print(values.squeeze(-1))
#print(indices.squeeze(-1))

tensor([[[25, 18,  7],
         [23,  5,  2],
         [17, 16, 15]],

        [[19, 12,  4],
         [22, 20,  1],
         [26, 10,  8]],

        [[14,  9,  3],
         [24, 13,  6],
         [21, 11,  0]]])
tensor([[[25, 18],
         [23,  5],
         [17, 16]],

        [[19, 12],
         [22, 20],
         [26, 10]],

        [[14,  9],
         [24, 13],
         [21, 11]]])


### masked_fill: fill the value if element of mask is True.

In [None]:
x = torch.FloatTensor([i for i in range(3**2)]).reshape(3, -1)

print(x)
print(x.size())

In [None]:
mask = x > 4

print(mask)

In [None]:
y = x.masked_fill(mask, value=-1)

print(y)

### Ones and Zeros

In [None]:
print(torch.ones(2, 3))
print(torch.zeros(2, 3))

In [None]:
x = torch.FloatTensor([[1, 2, 3],
                       [4, 5, 6]])
print(x.size())

In [None]:
print(torch.ones_like(x))
print(torch.zeros_like(x))