In [2]:
import torch

### Find the max along an axis

In [6]:
# A 3x4 float matrix; each row has a single largest entry.
x = torch.tensor(
    [
        [1.0, 0.5, 0.2, 0.0],
        [2.0, 1.9, 1.2, 1.8],
        [2.2, 2.1, 2.5, 3.0],
    ]
)
x.shape

torch.Size([3, 4])

In [8]:
# Reduce over dim 1 (across columns): one maximum per row, paired with the
# column index at which it occurs.
x.max(dim=1)

(tensor([1., 2., 3.]), tensor([0, 0, 3]))

### Convert vector to one-hot matrix

In [13]:
# Integer class indices, kept as a (3, 1) column so the tensor can be passed
# straight to scatter along dim 1 further down.
a = torch.tensor([3, 0, 3]).unsqueeze(1)
print(a.shape)
print(a)

torch.Size([3, 1])
tensor([[3],
        [0],
        [3]])


In [14]:
# All-zeros template for the one-hot matrix: one row per entry of `a`, and one
# column per class. The class count is inferred as max(a) + 1, which assumes
# the indices in `a` are 0-based.
# int(...) turns the 0-d tensor returned by a.max() into a plain Python int,
# so the size argument is an ordinary integer rather than a tensor.
x = torch.zeros(a.size(0), int(a.max()) + 1)
print(x.size())
print(x)

torch.Size([3, 4])
tensor([[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]])


In [15]:
# For each row i, write 1.0 at column a[i, 0] (dim=1, so `a` supplies column
# indices). The fill value is passed as a Python float (the `value` overload):
# a `src` tensor would have to match x's dtype (float32) and number of dims,
# which the 0-d int64 tensor `torch.tensor(1)` does not.
# Out-of-place `scatter` leaves the all-zeros `x` displayed above untouched,
# so re-running cells does not silently change earlier results.
x.scatter(1, a, 1.0)

tensor([[0., 0., 0., 1.],
        [1., 0., 0., 0.],
        [0., 0., 0., 1.]])

### Pick one value per row of a matrix (here, the non-zero one) with `torch.gather`

In [23]:
# Column index to take from each row of m, shaped (3, 1) so that it has the
# same number of dims as m.
a = torch.tensor([0, 3, 1]).view(-1, 1)
m = torch.tensor([
    [0.87, 0.0, 0.0, 0.0],
    [0.0,  0.0, 0.0, 0.1],
    [0.0,  0.4, 0.0, 0.0],
])
# Row i contributes m[i, a[i, 0]]; the result is a (3, 1) column.
m.gather(1, a)

tensor([[0.8700],
        [0.1000],
        [0.4000]])

### Select specific indexes from a matrix
Given a matrix $x \in \mathbb R^{m \times n}$ and a vector of column indices $y \in \{0, \dots, n-1\}^m$ (one index per row), select from each row of $x$ the element at the index specified by $y$.

$$
x = \begin{bmatrix}
1 & 2 & 3 \\
4 & 5 & 6 \\
\end{bmatrix}
$$

$$
y = \begin{bmatrix}
0 \\
2 \\
\end{bmatrix}
$$

I want to extract
$$
z = \begin{bmatrix}
1 \\
6
\end{bmatrix}
$$

My immediate reaction was to use an operation similar to calculating softmax probabilities: convert $y$ to a one-hot matrix, multiply it elementwise with $x$, and then extract the non-zero elements of the product. However, `torch.gather()` does this much more easily. It also keeps working when a selected element of $x$ happens to be zero.

In [27]:
# x is the 2x3 matrix from the example above; y holds one column index per row.
# torch.gather needs the index to have the same number of dims as its input,
# so y is a (2, 1) column rather than a flat vector.
x = torch.arange(1, 7, dtype=torch.float32).reshape(2, 3)
y = torch.tensor([0, 2]).unsqueeze(1)

In [30]:
# Along dim 1, row i yields x[i, y[i, 0]]; the output has y's shape.
x.gather(1, y)

tensor([[1.],
        [6.]])