# Demonstration of the `top-k` operator in `pytorch`

This notebook demonstrates how the `top-k` operator works in `pytorch`.
It shows that, for both the forward values and the gradients, applying `top-k` is equivalent to multiplying by a one-hot selection mask that picks out the top-k elements.

In [1]:
import torch

In [2]:
x = torch.tensor(
    [0.4, 0.5, 0.6], requires_grad=True
)
x

tensor([0.4000, 0.5000, 0.6000], requires_grad=True)

In [3]:
w = torch.tensor(
    [
        [1.0, 1.0, 1.0],
        [2.0, 2.0, 2.0],
        [3.0, 3.0, 3.0],
    ],
    requires_grad=True
)
w

tensor([[1., 1., 1.],
        [2., 2., 2.],
        [3., 3., 3.]], requires_grad=True)

In [4]:
y = w @ x
y

tensor([1.5000, 3.0000, 4.5000], grad_fn=<MvBackward0>)

In [5]:
val1, idx1 = torch.topk(y, 2)
val1, idx1

(tensor([4.5000, 3.0000], grad_fn=<TopkBackward0>), tensor([2, 1]))

In [6]:
loss1 = (val1 ** 2).sum()
loss1.backward()
w.grad

tensor([[0.0000, 0.0000, 0.0000],
        [2.4000, 3.0000, 3.6000],
        [3.6000, 4.5000, 5.4000]])

In [7]:
w.grad.zero_()

tensor([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]])

In [8]:
x.grad

tensor([39., 39., 39.])
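
The gradients printed in cells 6 and 8 can be checked by hand. The derivation below is a sketch under the notebook's own setup; the symbol $S$ for the set of selected (0-based) indices is notation introduced here, not something defined in the notebook.

$$
y = W x = (1.5,\ 3.0,\ 4.5)^\top, \qquad S = \{1, 2\}, \qquad L = \sum_{i \in S} y_i^2
$$

$$
\frac{\partial L}{\partial y_i} =
\begin{cases}
2 y_i & i \in S \\
0 & i \notin S
\end{cases}
\quad\Longrightarrow\quad
\frac{\partial L}{\partial y} = (0,\ 6,\ 9)^\top
$$

$$
\frac{\partial L}{\partial W} = \frac{\partial L}{\partial y}\, x^\top =
\begin{pmatrix}
0 & 0 & 0 \\
2.4 & 3.0 & 3.6 \\
3.6 & 4.5 & 5.4
\end{pmatrix},
\qquad
\frac{\partial L}{\partial x} = W^\top \frac{\partial L}{\partial y}
= (1 \cdot 0 + 2 \cdot 6 + 3 \cdot 9)\,(1,\ 1,\ 1)^\top = (39,\ 39,\ 39)^\top
$$

The first row of $\partial L / \partial W$ is zero because $y_0$ is not among the top-k values and therefore does not enter the loss.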

In [9]:
x.grad.zero_()

tensor([0., 0., 0.])

In [10]:
y2 = w @ x
y2

tensor([1.5000, 3.0000, 4.5000], grad_fn=<MvBackward0>)

In [11]:
w_topk = torch.zeros(2, y2.shape[0])
w_topk[torch.arange(2), idx1] = 1
w_topk.requires_grad = True
w_topk

tensor([[0., 0., 1.],
        [0., 1., 0.]], requires_grad=True)

In [12]:
val2 = w_topk @ y2
val2

tensor([4.5000, 3.0000], grad_fn=<MvBackward0>)

In [13]:
loss2 = (val2 ** 2).sum()
loss2.backward()
w.grad

tensor([[0.0000, 0.0000, 0.0000],
        [2.4000, 3.0000, 3.6000],
        [3.6000, 4.5000, 5.4000]])

In [14]:
x.grad

tensor([39., 39., 39.])

In [15]:
x.grad.zero_()

tensor([0., 0., 0.])
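
The two runs above can also be compared programmatically rather than by eye. The following is a minimal sketch, not part of the original notebook: the helper names `make_inputs`, `grads_via`, `select_topk`, and `select_mask` are hypothetical, and using `torch.allclose` for the comparison is an assumption about what counts as "equivalent" here.

```python
import torch

K = 2  # number of elements to keep, as in the notebook


def make_inputs():
    # Fresh leaf tensors with the same values as in the notebook,
    # so that gradients from one run cannot leak into the next.
    w = torch.tensor(
        [[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0]],
        requires_grad=True,
    )
    x = torch.tensor([0.4, 0.5, 0.6], requires_grad=True)
    return w, x


def grads_via(select):
    # `select` maps y to its K selected values; returns (w.grad, x.grad).
    w, x = make_inputs()
    y = w @ x
    loss = (select(y) ** 2).sum()
    loss.backward()
    return w.grad, x.grad


def select_topk(y):
    # Path 1: the built-in operator.
    return torch.topk(y, K).values


def select_mask(y):
    # Path 2: a constant one-hot selection matrix built from the top-k indices.
    idx = torch.topk(y.detach(), K).indices
    mask = torch.zeros(K, y.shape[0])
    mask[torch.arange(K), idx] = 1.0
    return mask @ y


w_grad_topk, x_grad_topk = grads_via(select_topk)
w_grad_mask, x_grad_mask = grads_via(select_mask)

same_w = torch.allclose(w_grad_topk, w_grad_mask)
same_x = torch.allclose(x_grad_topk, x_grad_mask)
print(same_w, same_x)
```

Creating new leaf tensors for each run avoids the manual `grad.zero_()` calls used in the cells above, at the cost of rebuilding the inputs twice.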

In [16]:
# Baseline without top-k: define the matrix W and the vector x
W = torch.tensor([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0]], requires_grad=True)
x = torch.tensor([0.4, 0.5, 0.6], requires_grad=True)

# Compute y and z
y = torch.matmul(W, x)
z = y**2

# Sum the elements of z to obtain a scalar
z_total = z.sum()

# Backpropagate to compute the gradients
z_total.backward()

# Display the gradients
W.grad, x.grad


(tensor([[1.2000, 1.5000, 1.8000],
         [2.4000, 3.0000, 3.6000],
         [3.6000, 4.5000, 5.4000]]),
 tensor([42., 42., 42.]))
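
Compared with the top-k runs, this baseline differs only by the contribution of the dropped element `y[0]`: the first row of `W.grad` is no longer zero, and each entry of `x.grad` grows from 39 to 42. The sketch below reproduces both sets of gradients in closed form instead of through autograd, assuming the sum-of-squares loss used throughout; the names `keep`, `g_full`, and `g_topk` are illustrative and do not appear in the notebook.

```python
import torch

W = torch.tensor([[1.0, 1.0, 1.0], [2.0, 2.0, 2.0], [3.0, 3.0, 3.0]])
x = torch.tensor([0.4, 0.5, 0.6])
y = W @ x

k = 2
idx = torch.topk(y, k).indices
keep = torch.zeros_like(y)
keep[idx] = 1.0  # 1 for the top-k entries of y, 0 for the rest

g_full = 2 * y          # dL/dy when every element of y enters the loss
g_topk = g_full * keep  # dL/dy when only the top-k elements enter the loss

# Chain rule for y = W @ x: dL/dW = outer(dL/dy, x), dL/dx = W^T @ dL/dy
W_grad_full = torch.outer(g_full, x)
W_grad_topk = torch.outer(g_topk, x)
x_grad_full = W.T @ g_full
x_grad_topk = W.T @ g_topk

# The gap between the two x gradients is the dropped element's contribution
dropped = x_grad_full - x_grad_topk
print(x_grad_full, x_grad_topk, dropped)
```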