In [50]:
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

# Word-embedding walkthrough for a sequence-modelling (source -> target) setup.
# Each sentence is a sequence of token ids, i.e. indices into a vocabulary.
batch_size = 2

# Vocabulary sizes. Id 0 is used for padding; real token ids are sampled
# from [1, max_num_*_words).
max_num_src_words = 8
max_num_tgt_words = 8
model_dim = 8

# Maximum sequence lengths
max_src_seq_len = 5
max_tgt_seq_len = 5
max_position_len = 5

# Fixed sentence lengths keep the printed example easy to follow. For random
# lengths, use torch.randint(2, 5, (batch_size,)) for both tensors instead.
src_len = torch.tensor([2, 4], dtype=torch.int32)
tgt_len = torch.tensor([4, 3], dtype=torch.int32)

# Longest sentence in each batch, as plain Python ints so they can be used
# directly as sizes and pad widths.
src_batch_max_len = int(src_len.max())
tgt_batch_max_len = int(tgt_len.max())


def _random_padded_batch(lengths, vocab_size):
    """Build a batch of random token-id sequences, right-padded with 0.

    Args:
        lengths: 1-D integer tensor holding the length of every sequence.
        vocab_size: exclusive upper bound for the sampled token ids
            (ids are drawn from [1, vocab_size)).

    Returns:
        An int64 tensor of shape [len(lengths), max(lengths)].
    """
    sizes = lengths.tolist()
    longest = max(sizes)
    rows = [F.pad(torch.randint(1, vocab_size, (n,)), (0, longest - n)) for n in sizes]
    return torch.stack(rows)


# Source and target sentences as token ids, batched and padded with 0
src_seq = _random_padded_batch(src_len, max_num_src_words)
tgt_seq = _random_padded_batch(tgt_len, max_num_tgt_words)

# Word embedding. Each table has max_num_*_words + 1 rows, which covers every
# id that can appear above (0 for padding, 1..max_num_*_words-1 for tokens).
src_embedding_table = nn.Embedding(max_num_src_words + 1, model_dim)
tgt_embedding_table = nn.Embedding(max_num_tgt_words + 1, model_dim)
src_embedding = src_embedding_table(src_seq)
tgt_embedding = tgt_embedding_table(tgt_seq)

# Sinusoidal position embedding:
#   PE[pos, 2k]   = sin(pos / 10000^(2k / model_dim))
#   PE[pos, 2k+1] = cos(pos / 10000^(2k / model_dim))
# The frequency axis is derived from model_dim instead of a hard-coded width.
pos_mat = torch.arange(max_position_len).reshape((-1, 1))
i_mat = torch.pow(10000, torch.arange(0, model_dim, 2).reshape((1, -1)) / model_dim)
pe_embedding_table = torch.zeros(max_position_len, model_dim)
pe_embedding_table[:, 0::2] = torch.sin(pos_mat / i_mat)
# For an odd model_dim there is one fewer cos column than sin columns.
pe_embedding_table[:, 1::2] = torch.cos(pos_mat / i_mat)[:, : model_dim // 2]

# The table is fixed, so it is wrapped as a frozen (non-trainable) embedding.
pe_embedding = nn.Embedding.from_pretrained(pe_embedding_table, freeze=True)

# Position ids 0..max_len-1, repeated for every sentence in the batch
src_pos = torch.arange(src_batch_max_len, dtype=torch.int32).repeat(len(src_len), 1)
tgt_pos = torch.arange(tgt_batch_max_len, dtype=torch.int32).repeat(len(tgt_len), 1)

src_pe_embedding = pe_embedding(src_pos)
tgt_pe_embedding = pe_embedding(tgt_pos)

# Encoder self-attention mask.
# mask_encoder_self_attention is a bool tensor of shape
# [batch_size, max_src_len, max_src_len]; True marks a (query, key) pair in
# which at least one of the two positions is padding.
# valid_encoder_pos: [batch_size, max_src_len, 1], 1.0 for real tokens, 0.0 for padding.
valid_encoder_pos = (
    (torch.arange(src_batch_max_len).unsqueeze(0) < src_len.unsqueeze(1))
    .to(torch.float32)
    .unsqueeze(2)
)

# Outer product per sentence: 1.0 only where both positions are real tokens.
valid_encoder_pos_matrix = torch.bmm(valid_encoder_pos, valid_encoder_pos.transpose(1, 2))
invalid_encoder_pos_matrix = 1 - valid_encoder_pos_matrix
mask_encoder_self_attention = invalid_encoder_pos_matrix.to(torch.bool)

# Masked entries receive a large negative finite value (not -inf), so rows that
# are masked entirely (padding queries) softmax to a uniform distribution
# rather than NaN.
score = torch.randn(batch_size, src_batch_max_len, src_batch_max_len)
masked_score = score.masked_fill(mask_encoder_self_attention, -1e9)
prob = F.softmax(masked_score, -1)

print(src_len)
print(score)
print(masked_score)
print(prob)

tensor([2, 4], dtype=torch.int32)
tensor([[[ 0.3815,  0.5220,  0.2785,  0.6042],
         [ 1.5050, -1.6018,  0.2356, -0.2815],
         [ 1.7953,  0.5009, -0.1580,  0.7617],
         [-1.3702, -0.7591, -0.9953, -1.3197]],

        [[ 1.1532,  1.2433,  1.2719, -0.2345],
         [-0.2523,  1.5268, -0.3140,  0.1079],
         [-0.2026,  0.4382, -1.1197,  1.0557],
         [-0.1494, -0.6837,  0.0867, -0.5056]]])
tensor([[[ 3.8153e-01,  5.2205e-01, -1.0000e+09, -1.0000e+09],
         [ 1.5050e+00, -1.6018e+00, -1.0000e+09, -1.0000e+09],
         [-1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09],
         [-1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09]],

        [[ 1.1532e+00,  1.2433e+00,  1.2719e+00, -2.3452e-01],
         [-2.5229e-01,  1.5268e+00, -3.1398e-01,  1.0791e-01],
         [-2.0264e-01,  4.3819e-01, -1.1197e+00,  1.0557e+00],
         [-1.4941e-01, -6.8375e-01,  8.6683e-02, -5.0557e-01]]])
tensor([[[0.4649, 0.5351, 0.0000, 0.0000],
         [0.9572, 0.0428, 0.0000, 0

In [44]:
# Softmax demo: why scaling the logits matters.
# The same random logits are pushed through softmax twice, once shrunk by
# alpha1 and once amplified by alpha2.
alpha1, alpha2 = 0.1, 10
score = torch.randn(5)
prob1 = F.softmax(alpha1 * score, dim=-1)
prob2 = F.softmax(alpha2 * score, dim=-1)
def softmax_func(score):
    """Return the softmax of ``score`` taken over its last dimension.

    The dimension is passed explicitly because calling ``F.softmax`` without
    ``dim`` relies on a deprecated implicit choice and emits a warning.
    For the 1-D inputs this function is called with, the result is unchanged.

    Args:
        score: tensor of logits.

    Returns:
        A tensor with the same shape as ``score`` whose entries along the last
        dimension are non-negative and sum to 1.
    """
    return F.softmax(score, dim=-1)
# Jacobian of softmax evaluated at the shrunk (alpha1) and amplified (alpha2)
# logits, printed after the raw logits for comparison.
jaco_mat1 = torch.autograd.functional.jacobian(softmax_func, alpha1 * score)
jaco_mat2 = torch.autograd.functional.jacobian(softmax_func, alpha2 * score)
print(score, jaco_mat1, jaco_mat2, sep="\n")

tensor([-0.3876,  0.1989,  0.6898,  1.0837, -0.5549])
tensor([[ 0.1527, -0.0375, -0.0394, -0.0410, -0.0348],
        [-0.0375,  0.1597, -0.0418, -0.0435, -0.0369],
        [-0.0394, -0.0418,  0.1656, -0.0457, -0.0388],
        [-0.0410, -0.0435, -0.0457,  0.1704, -0.0403],
        [-0.0348, -0.0369, -0.0388, -0.0403,  0.1508]])
tensor([[ 3.9960e-07, -5.6306e-11, -7.6300e-09, -3.9191e-07, -2.9967e-14],
        [-5.6306e-11,  1.4089e-04, -2.6905e-06, -1.3820e-04, -1.0567e-11],
        [-7.6300e-09, -2.6905e-06,  1.8730e-02, -1.8727e-02, -1.4319e-09],
        [-3.9191e-07, -1.3820e-04, -1.8727e-02,  1.8866e-02, -7.3551e-08],
        [-2.9967e-14, -1.0567e-11, -1.4319e-09, -7.3551e-08,  7.4993e-08]])


  
