# PyTorch Useful Methods

In [1]:
import torch

### expand: repeat a tensor along its size-1 dimensions to reach a desired shape (returns a view; no data is copied).

In [6]:
# A (2, 1, 2) float32 tensor; dim 1 has size 1, so it can be expanded below.
# torch.tensor with an explicit dtype replaces the legacy torch.FloatTensor constructor.
x = torch.tensor([[[1, 2]],
                  [[3, 4]]], dtype=torch.float32)
print(x.size())

torch.Size([2, 1, 2])


In [7]:
# expand() grows size-1 dims (here dim 1: 1 -> 3) to the requested size.
# It returns a view whose expanded dim has stride 0, so no data is copied;
# writing into y would therefore alias x's memory.
y = x.expand(2, 3, 2)

print(y)
print(y.size())

tensor([[[1., 2.],
         [1., 2.],
         [1., 2.]],

        [[3., 4.],
         [3., 4.],
         [3., 4.]]])
torch.Size([2, 3, 2])


#### Implementing expand with cat (same values, but cat allocates a new tensor and copies the data).

In [8]:
# Three copies of x stacked along dim 1 give the same values as
# x.expand(2, 3, 2), but in a freshly allocated tensor.
y = torch.cat([x] * 3, dim=1)

print(y)
print(y.size())

tensor([[[1., 2.],
         [1., 2.],
         [1., 2.]],

        [[3., 4.],
         [3., 4.],
         [3., 4.]]])
torch.Size([2, 3, 2])


### randperm: Random Permutation

In [9]:
# The integers 0..9 in random order. Useful, for example, as the index
# argument of index_select to shuffle a tensor along a dimension.
x = torch.randperm(10, dtype=torch.int64)

print(x)
print(x.size())

tensor([5, 7, 0, 8, 1, 3, 9, 6, 4, 2])
torch.Size([10])


### argmax: Return the indices of the maximum values along a dimension

In [10]:
# 27 distinct integers (0..26) in random order, laid out as a 3x3x3 tensor.
x = torch.randperm(27).view(3, 3, 3)

print(x)
print(x.size())

tensor([[[25, 24, 14],
         [26, 21,  0],
         [ 1, 22, 20]],

        [[ 5, 19, 13],
         [11, 18,  2],
         [ 8, 16,  7]],

        [[10, 17, 23],
         [ 4, 15, 12],
         [ 3,  6,  9]]])
torch.Size([3, 3, 3])


In [13]:
# For each position in the leading dims, the index along the last dim
# that holds the largest value; the last dim is removed from the result.
y = torch.argmax(x, dim=-1)

print(y)
print(y.size())

tensor([[0, 0, 1],
        [1, 1, 1],
        [2, 1, 2]])
torch.Size([3, 3])


### topk: Return a tuple of the top-k values and their indices.

In [36]:
# Pick the k largest entries along the last dim (here k=1). topk returns a
# (values, indices) pair, and both keep that dim with size k.
values, indices = x.topk(1, dim=-1)

print(values)
print(indices)
print(values.size())
print(indices.size())

# Index the trailing size-1 dim away to get a plain index grid.
print(indices[..., 0])

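tensor([[[25],
         [26],
         [22]],

        [[19],
         [18],
         [16]],

        [[23],
         [15],
         [ 9]]])
tensor([[[0],
         [0],
         [1]],

        [[1],
         [1],
         [1]],

        [[2],
         [1],
         [2]]])
torch.Size([3, 3, 1])
torch.Size([3, 3, 1])
tensor([[0, 0, 1],
        [1, 1, 1],
        [2, 1, 2]])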


Note that topk does not remove the reduced dimension, even when $k=1$.

In [9]:
# Remove the size-1 last dim that topk(k=1) kept.
print(torch.squeeze(values, -1))
print(torch.squeeze(indices, -1))

tensor([[25, 26, 22],
        [19, 18, 16],
        [23, 15,  9]])
tensor([[0, 0, 1],
        [1, 1, 1],
        [2, 1, 2]])


In [10]:
# Element-wise check that argmax agrees with topk's top-1 index.
print(torch.eq(x.argmax(dim=-1), indices.squeeze(-1)))

tensor([[True, True, True],
        [True, True, True],
        [True, True, True]])


In [11]:
# With k=2 the last dim of the result has size 2. topk sorts by default,
# so slot 0 holds the largest entry and should still match argmax.
_, indices = x.topk(2, dim=-1)
print(indices.size())

print(x.argmax(dim=-1) == indices[..., 0])

torch.Size([3, 3, 2])
tensor([[True, True, True],
        [True, True, True],
        [True, True, True]])


### Sorting with topk

In [41]:
# topk with k equal to the full size of a dim (and the default sorted=True)
# returns every element of that dim in descending order, i.e. a sort.
target_dim = -1
values, indices = torch.topk(x,
                             k=x.size(target_dim),
                             dim=target_dim,  # must be the same dim k was taken from
                             largest=True)

print(values)

tensor([[[25, 24, 14],
         [26, 21,  0],
         [22, 20,  1]],

        [[19, 13,  5],
         [18, 11,  2],
         [16,  8,  7]],

        [[23, 17, 10],
         [15, 12,  4],
         [ 9,  6,  3]]])


### topk with sort

In [43]:
# Sorting in descending order and keeping the first k entries of the sorted
# dim reproduces torch.topk(x, k, dim=-1) (tie order may differ).
k = 1
values, indices = torch.sort(x, dim=-1, descending=True)
# Slice the last (sorted) dim via `...` so this works for any tensor rank.
values, indices = values[..., :k], indices[..., :k]

print(values.squeeze(-1))
print(indices.squeeze(-1))

tensor([[25, 26, 22],
        [19, 18, 16],
        [23, 15,  9]])
tensor([[0, 0, 1],
        [1, 1, 1],
        [2, 1, 2]])


### masked_fill: fill elements with the given value wherever the mask is True.

In [44]:
# The values 0..8 as float32, laid out as a 3x3 matrix. arange replaces the
# redundant list comprehension passed to the legacy FloatTensor constructor.
x = torch.arange(3**2, dtype=torch.float32).reshape(3, -1)

print(x)
print(x.size())

tensor([[0., 1., 2.],
        [3., 4., 5.],
        [6., 7., 8.]])
torch.Size([3, 3])


In [45]:
# Boolean tensor: True wherever the element of x is greater than 4.
mask = x.gt(4)

print(mask)

tensor([[False, False, False],
        [False, False,  True],
        [ True,  True,  True]])


In [46]:
# New tensor equal to x, except that every position where mask is True
# becomes -1 (masked_fill is out-of-place; x itself is left unchanged).
y = x.masked_fill(mask, -1)

print(y)

tensor([[ 0.,  1.,  2.],
        [ 3.,  4., -1.],
        [-1., -1., -1.]])


### Ones and Zeros

In [47]:
# The shape can be given as a tuple as well as separate ints.
print(torch.ones((2, 3)))
print(torch.zeros((2, 3)))

tensor([[1., 1., 1.],
        [1., 1., 1.]])
tensor([[0., 0., 0.],
        [0., 0., 0.]])


In [51]:
# A 2x3 int64 tensor; torch.tensor with an explicit dtype replaces the
# legacy torch.LongTensor constructor.
x = torch.tensor([[1, 2, 3],
                  [4, 5, 6]], dtype=torch.long)
print(x.size())

torch.Size([2, 3])


In [52]:
# The *_like factories take the shape, dtype and device from x by default,
# then fill the result with ones or zeros respectively.
for like_factory in (torch.ones_like, torch.zeros_like):
    print(like_factory(x))

tensor([[1, 1, 1],
        [1, 1, 1]])
tensor([[0, 0, 0],
        [0, 0, 0]])
