In [1]:
"import necessary packages"
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import *
from functools import reduce

# torch.where and torch.nonzero tutorial
According to:
* https://docs.pytorch.org/docs/stable/generated/torch.where.html#torch-where
* https://docs.pytorch.org/docs/stable/generated/torch.nonzero.html#torch.nonzero

若如下使用`torch.where`:
```
torch.where(condition, input, other, *, out=None)
```
你可以直接参考上面给出的`torch-where`的链接。

`torch.nonzero`同理。

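以下讲的是只传入 `condition` 的情况，即调用 `torch.where(condition) → tuple of LongTensor` 时会返回什么。
根据官方文档，该调用等价于 `torch.nonzero(condition, as_tuple=True)`：返回一个元组，其中第 $k$ 个张量是 `condition` 中所有 True 元素在第 $k$ 维上的索引。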


In [7]:
# Seed the RNG so "Restart & Run All" reproduces the same x and mask.
torch.manual_seed(0)
x = torch.randn(4, 6)
print("x:\n", x)
mask = x < 0
print("mask:\n", mask)

# With only a condition, torch.where returns one index tensor per dimension of `mask`
# (the same result as torch.nonzero(mask, as_tuple=True)).
rows, cols = torch.where(mask)
print(rows)
print(cols)

# rows[k] / cols[k] are the row / column indices of the k-th True entry of `mask`
# (i.e. of the k-th negative element of x), enumerated in row-major order:
# (rows[0], cols[0]) is the first True found when scanning row by row, and so on.
# Indexing x with both index tensors therefore gathers exactly the elements the mask selects:
print(x[rows, cols])
assert torch.equal(x[rows, cols], x[mask])

x:
 tensor([[ 1.9279,  0.4416, -0.6376,  0.4393,  2.1493,  0.6519],
        [-0.3132,  0.5309,  1.0463, -0.1153, -0.6102, -0.7488],
        [ 0.9631, -0.5186,  0.8598,  0.6168,  1.4107, -0.1460],
        [-0.1118, -0.4871, -0.1144,  0.7684, -0.5828, -1.6393]])
mask:
 tensor([[False, False,  True, False, False, False],
        [ True, False, False,  True,  True,  True],
        [False,  True, False, False, False,  True],
        [ True,  True,  True, False,  True,  True]])
tensor([0, 1, 1, 1, 1, 2, 2, 3, 3, 3, 3, 3])
tensor([2, 0, 3, 4, 5, 1, 5, 0, 1, 2, 4, 5])
tensor([-0.6376, -0.3132, -0.1153, -0.6102, -0.7488, -0.5186, -0.1460, -0.1118,
        -0.4871, -0.1144, -0.5828, -1.6393])


# torch stride

In [None]:
# stride: how many storage positions you move when the index along a dim grows by 1
test_shape = (2, 3, 4)
a = torch.arange(reduce(lambda x, y: x * y, test_shape)).view(*test_shape)
print(a.stride(0))
print(a.stride(1))
print(a.stride(2))
# Show the tensor itself so the layout behind the strides is visible.
print(a)

# For a contiguous tensor, stride(i) is the product of all sizes after dim i.
assert a.stride() == tuple(
    reduce(lambda x, y: x * y, test_shape[i + 1:], 1) for i in range(len(test_shape))
)

12
4
1
tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11]],

        [[12, 13, 14, 15],
         [16, 17, 18, 19],
         [20, 21, 22, 23]]])


# torch transpose && permute

In [None]:
# torch transpose: every call below swaps two dims of the same (2, 3, 4) tensor
test_shape = (2, 3, 4)
a = torch.arange(reduce(lambda lhs, rhs: lhs * rhs, test_shape)).reshape(test_shape)
print("origin:\n", a)

# Tensor.transpose(d0, d1) is the method form of torch.transpose(tensor, d0, d1);
# the result shares storage with `a`, only the sizes/strides of d0 and d1 are swapped.
trans_0_1_a = a.transpose(0, 1)      # (2, 3, 4) -> (3, 2, 4)
print("trans_0_1_a:\n", trans_0_1_a)

trans_1_2_a = a.transpose(1, 2)      # (2, 3, 4) -> (2, 4, 3)
print("trans_1_2_a:\n", trans_1_2_a)

transpose_0_2_a = a.transpose(0, 2)  # (2, 3, 4) -> (4, 3, 2)
print("transpose_0_2_a:\n", transpose_0_2_a)

origin:
 tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11]],

        [[12, 13, 14, 15],
         [16, 17, 18, 19],
         [20, 21, 22, 23]]])
trans_0_1_a:
 tensor([[[ 0,  1,  2,  3],
         [12, 13, 14, 15]],

        [[ 4,  5,  6,  7],
         [16, 17, 18, 19]],

        [[ 8,  9, 10, 11],
         [20, 21, 22, 23]]])
trans_1_2_a:
 tensor([[[ 0,  4,  8],
         [ 1,  5,  9],
         [ 2,  6, 10],
         [ 3,  7, 11]],

        [[12, 16, 20],
         [13, 17, 21],
         [14, 18, 22],
         [15, 19, 23]]])
transpose_0_2_a:
 tensor([[[ 0, 12],
         [ 4, 16],
         [ 8, 20]],

        [[ 1, 13],
         [ 5, 17],
         [ 9, 21]],

        [[ 2, 14],
         [ 6, 18],
         [10, 22]],

        [[ 3, 15],
         [ 7, 19],
         [11, 23]]])


### Conclusion
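stride：张量各维度的步长，即某一维度的索引加 1 时，元素在底层一维存储中移动的位置数。
对于形状为 $(d_0, d_1, \dots, d_{k-1})$ 的连续（contiguous）张量，有：
$$
\text{stride}_i = \prod_{j=i+1}^{k-1} d_j, \qquad
\text{offset}(i_0, i_1, \dots, i_{k-1}) = \sum_{j=0}^{k-1} i_j \cdot \text{stride}_j
$$
例如形状为 $[2,3,4]$ 的张量，stride 为 $(3\cdot4,\ 4,\ 1) = (12, 4, 1)$，与上面 `a.stride(0)`、`a.stride(1)`、`a.stride(2)` 的输出一致。
换言之，沿维度 $i$ 的第 $n$ 个元素（$0 \le n < d_i$）与该维度第 0 个元素在存储中相差 $n \cdot \text{stride}_i$ 个位置；
各维度的元素，就是按该维度的 `stride` 在同一块存储上跳跃遍历得到的。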

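由上面的结论可以很容易得出：`transpose`，或者更一般的 `permute`，并不移动底层数据，而只是交换各维度的 size 与 stride，即调换遍历各维度的先后顺序。
如有下矩阵：
$$
\begin{gathered}
\begin{bmatrix}
0 & 1 & 2 & 3 \\
4 & 5 & 6 & 7
\end{bmatrix} \\
\text{shape}=[2,4],\quad \text{stride(0)}=4,\ \text{stride(1)}=1 \\
\text{after transposing dim 0 and dim 1:} \\
\begin{bmatrix}
0 & 4 \\
1 & 5 \\
2 & 6 \\
3 & 7
\end{bmatrix} \\
\text{shape}=[4,2],\quad \text{stride(0)}=1,\ \text{stride(1)}=4
\end{gathered}
$$
注意：转置得到的是与原张量共享存储的视图，因此 stride 为 $(1, 4)$ 而不是 $(2, 1)$；只有调用 `.contiguous()` 复制出新的连续张量后，stride 才会变为 $(2, 1)$。

解释：转置前，沿维度 0 遍历得到 [0,4],[1,5],[2,6],[3,7]，相邻元素在存储中相差 stride(0)=4；沿维度 1 遍历得到 [0,1,2,3],[4,5,6,7]，相邻元素相差 stride(1)=1。

`transpose` 后：

沿维度 0 遍历得到 [0,1,2,3],[4,5,6,7]，相邻元素相差 stride(0)=1；沿维度 1 遍历得到 [0,4],[1,5],[2,6],[3,7]，相邻元素相差 stride(1)=4。

所以说 `transpose`、`permute` 之流，实际就是调换遍历各维度元素的先后顺序（交换 shape 与 stride），数据本身并没有移动。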


In [None]:
"reimplement torch permute"
from functools import reduce,partial
# product(iterable) multiplies every entry together, e.g. product([2, 3]) == 6
product = partial(reduce, lambda lhs, rhs: lhs * rhs)

shape = [2, 3, 4, 5]
mat = torch.arange(product(shape)).reshape(shape)
print("[origin]\n", mat)

#### torch implement ####
# Reference results from PyTorch; Tensor.permute returns a view with reordered dims.
permute_1 = mat.permute(0, 1, 3, 2)  # swaps only the last two dims, i.e. a plain transpose
mat2 = mat.permute(1, 0, 2, 3)       # shape (3, 2, 4, 5)
mat3 = mat.permute(2, 1, 0, 3)       # shape (4, 3, 2, 5)
mat4 = mat.permute(3, 2, 0, 1)       # shape (5, 4, 2, 3)
# To inspect any of them, e.g.: print("[permute(1,0,2,3)]\n", mat2)
#### torch implement ####

#### reimplement ####
def transpose_2d(tensor:torch.Tensor,new_dim_indices:Iterable=(1,0)):
    """Pure-Python re-implementation of a 2-D ``permute`` / ``transpose``.

    Output element ``[p, q]`` is read from the source position whose index along
    dimension ``new_dim_indices[0]`` is ``p`` and along ``new_dim_indices[1]`` is ``q``,
    so ``(1, 0)`` transposes and ``(0, 1)`` returns an equal copy.

    Args:
        tensor: a 2-D tensor.
        new_dim_indices: a permutation of ``(0, 1)``; negative dims are allowed.

    Returns:
        A new tensor (a copy, not a view) with the same dtype and device, equal to
        ``tensor.permute(*new_dim_indices)``.
    """
    dims = list(new_dim_indices)
    assert tensor.ndim == len(dims), "length of dimension indices passed in must be equal to shape of tensor passed in."
    dims = [d + tensor.ndim if d < 0 else d for d in dims]
    assert sorted(dims) == [0, 1], "new_dim_indices must be a permutation of (0, 1)."
    new_dims = [tensor.shape[d] for d in dims]

    new_tensor = []
    for p in range(new_dims[0]):
        row = []
        for q in range(new_dims[1]):
            # Scatter the output coordinates back onto the source dims they came from,
            # instead of hard-coding tensor[q, p] (which ignored new_dim_indices).
            src_index = [0, 0]
            src_index[dims[0]] = p
            src_index[dims[1]] = q
            row.append(tensor[tuple(src_index)].item())
        new_tensor.append(row)
    # reshape restores the shape when new_dims[0] == 0 (an empty outer list would become shape (0,)).
    return torch.tensor(new_tensor, dtype=tensor.dtype, device=tensor.device).reshape(new_dims)

# Compare the hand-written transpose with PyTorch's own on a small 3x5 matrix.
# (Integer tensors can never require grad, so no requires_grad_ call is needed.)
temp_tensor=torch.arange(3*5).view(3,5)
print(temp_tensor)
print(transpose_2d(temp_tensor,new_dim_indices=[1,0]))
assert torch.equal(transpose_2d(temp_tensor,new_dim_indices=[1,0]), temp_tensor.T)

def permute(tensor:torch.Tensor, dim_indices:Iterable):
    """Pure-Python re-implementation of ``torch.Tensor.permute``.

    Output dimension ``k`` walks source dimension ``dim_indices[k]``: the output element
    at ``(o_0, ..., o_{n-1})`` is the source element whose index along
    ``dim_indices[k]`` equals ``o_k``.

    Args:
        tensor: a tensor of any rank (0-d included).
        dim_indices: a permutation of ``range(tensor.ndim)``; negative dims are allowed.

    Returns:
        A new tensor (a copy, not a view) with the same dtype and device, equal to
        ``tensor.permute(*dim_indices)``.
    """
    dim_indices = list(dim_indices)
    assert tensor.ndim == len(dim_indices), "length of dimension indices passed in must be equal to shape of tensor passed in."
    dim_indices = [d + tensor.ndim if d < 0 else d for d in dim_indices]
    assert sorted(dim_indices) == list(range(tensor.ndim)), "dim_indices must be a permutation of range(tensor.ndim)."
    new_shape = [tensor.shape[d] for d in dim_indices]
    # Source index under construction; slot dim_indices[k] is filled while walking output dim k.
    access_index = [0] * tensor.ndim

    def _recursive_permute(out_dim):
        # All output dims are fixed: read one scalar from the source.
        if out_dim == len(new_shape):
            return tensor[tuple(access_index)].item()
        src_dim = dim_indices[out_dim]
        values = []
        for i in range(new_shape[out_dim]):
            access_index[src_dim] = i
            values.append(_recursive_permute(out_dim + 1))
        return values

    new_tensor = _recursive_permute(0)
    # reshape restores the exact shape when a size-0 dim collapses the nested list.
    return torch.tensor(new_tensor, dtype=tensor.dtype, device=tensor.device).reshape(new_shape)

#### reimplement ####


[origin]
 tensor([[[[  0,   1,   2,   3,   4],
          [  5,   6,   7,   8,   9],
          [ 10,  11,  12,  13,  14],
          [ 15,  16,  17,  18,  19]],

         [[ 20,  21,  22,  23,  24],
          [ 25,  26,  27,  28,  29],
          [ 30,  31,  32,  33,  34],
          [ 35,  36,  37,  38,  39]],

         [[ 40,  41,  42,  43,  44],
          [ 45,  46,  47,  48,  49],
          [ 50,  51,  52,  53,  54],
          [ 55,  56,  57,  58,  59]]],


        [[[ 60,  61,  62,  63,  64],
          [ 65,  66,  67,  68,  69],
          [ 70,  71,  72,  73,  74],
          [ 75,  76,  77,  78,  79]],

         [[ 80,  81,  82,  83,  84],
          [ 85,  86,  87,  88,  89],
          [ 90,  91,  92,  93,  94],
          [ 95,  96,  97,  98,  99]],

         [[100, 101, 102, 103, 104],
          [105, 106, 107, 108, 109],
          [110, 111, 112, 113, 114],
          [115, 116, 117, 118, 119]]]])
tensor([[ 0,  1,  2,  3,  4],
        [ 5,  6,  7,  8,  9],
        [10, 11, 12, 13, 14

# CrossEntropyLoss

In [6]:
"reimplement CrossEntropyLoss"
torch.manual_seed(0)  # make logits/label reproducible under "Restart & Run All"
reduction='mean'
num_predict=3
vocab_size=8

logits=torch.randint(0,10,(num_predict,vocab_size)).to(dtype=torch.float32)
# random_(to) samples integers from [0, to - 1]; pass vocab_size so every class 0..vocab_size-1 can be a label.
label=torch.arange(num_predict).random_(vocab_size)
print("[logits]\n",logits)
print("[label]\n",label)

#### torch implement ####
loss_func=nn.CrossEntropyLoss(reduction=reduction)  # 'mean' is also nn.CrossEntropyLoss's default
torch_loss=loss_func(logits,label)
print("[loss]\n",torch_loss)
#### torch implement ####

#### reimplement ####
# Subtracting the row max leaves softmax unchanged but keeps exp() from overflowing.
logits_dim_max=logits.amax(-1,keepdim=True)
after_diff_logits = logits-logits_dim_max
### log-softmax ###
# log(softmax(z)) = z - log(sum(exp(z))). Working in log space avoids log(0) = -inf when
# exp() of a very negative shifted logit underflows to 0.
log_softmax_logits = after_diff_logits - after_diff_logits.exp().sum(-1,keepdim=True).log()
after_softmax_logits = log_softmax_logits.exp()  # softmax probabilities, kept for inspection
### log-softmax ###
### NLL: negative log-likelihood ###
after_log_logits = -log_softmax_logits
### NLL: negative log-likelihood ###
### calculate loss ###
# For each prediction i, pick the NLL of its target class label[i].
before_reduction=after_log_logits[torch.arange(num_predict), label]
if reduction=='mean':
    loss=before_reduction.mean()
elif reduction=='sum':
    loss=before_reduction.sum()
elif reduction=='none':
    loss=before_reduction
else:
    raise ValueError(f"unsupported reduction: {reduction!r}")
### calculate loss ###
print("[loss(reimplement)] ",loss)
assert torch.allclose(loss, torch_loss), "reimplementation disagrees with nn.CrossEntropyLoss"
#### reimplement ####


[logits]
 tensor([[3., 1., 3., 7., 5., 1., 8., 2.],
        [8., 9., 9., 2., 5., 9., 7., 7.],
        [1., 0., 3., 6., 8., 1., 5., 5.]])
[label]
 tensor([6, 5, 2])
[loss]
 tensor(2.2922)
[loss(reimplement)]  tensor(2.2922)
