In [25]:
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F

# Sequence modelling with word embeddings.
# Build a source batch and a target batch in which every token is represented
# by its index in the corresponding vocabulary (word table).
batch_size = 2

# Vocabulary sizes. Real words use indices 1..max_num_*_words (inclusive);
# index 0 is reserved for padding.
max_num_src_words = 8
max_num_target_words = 8

model_dim = 8

# Maximum sequence lengths and number of encodable positions.
max_src_seq_len = 5
max_target_seq_len = 5
max_position_len = 5

# Alternative: random lengths in [2, 4] for each sentence.
# src_len = torch.randint(2, 5, (batch_size,))
# target_len = torch.randint(2, 5, (batch_size,))

# Fixed per-sentence lengths (one entry per sentence in the batch).
src_len = torch.tensor([2, 4], dtype=torch.int32)
target_len = torch.tensor([4, 3], dtype=torch.int32)


def make_padded_batch(lengths, vocab_size):
    """Return a batch of random word indices, right-padded with 0.

    Row i holds ``lengths[i]`` indices drawn uniformly from 1..vocab_size
    (inclusive), followed by zeros up to ``max(lengths)``.

    Args:
        lengths: 1-D integer tensor of per-sentence lengths.
        vocab_size: number of real words; index 0 is the padding index.

    Returns:
        int64 tensor of shape (len(lengths), max(lengths)).
    """
    max_len = int(lengths.max())
    rows = [
        # randint's upper bound is exclusive, hence vocab_size + 1.
        F.pad(torch.randint(1, vocab_size + 1, (n,)), (0, max_len - n))
        for n in lengths.tolist()
    ]
    return torch.stack(rows)


# Index sequences padded with 0 to the longest sentence in each batch.
src_seq = make_padded_batch(src_len, max_num_src_words)
target_seq = make_padded_batch(target_len, max_num_target_words)


# Word embedding lookup tables, one per side. Each table gets one extra row
# beyond the vocabulary size so that index 0 (used for padding in the index
# sequences) is addressable.
# NOTE(review): padding_idx is not set, so row 0 is an ordinary trainable row.
src_embedding_table = nn.Embedding(
    num_embeddings=max_num_src_words + 1,
    embedding_dim=model_dim,
)
target_embedding_table = nn.Embedding(
    num_embeddings=max_num_target_words + 1,
    embedding_dim=model_dim,
)

# Look up a model_dim-wide vector for every index in each batch.
src_embedding, target_embedding = (
    src_embedding_table(src_seq),
    target_embedding_table(target_seq),
)

# Sinusoidal position embedding:
#   PE[pos, 2i]   = sin(pos / 10000^(2i / model_dim))
#   PE[pos, 2i+1] = cos(pos / 10000^(2i / model_dim))
pos_mat = torch.arange(max_position_len).reshape((-1, 1))
# Derive the frequency exponents from model_dim (previously a hard-coded 8) so
# the table stays correct when the embedding width changes.
i_mat = torch.pow(10000, torch.arange(0, model_dim, 2).reshape((1, -1)) / model_dim)
pe_angles = pos_mat / i_mat  # broadcast to (max_position_len, ceil(model_dim / 2))
pe_embedding_table = torch.zeros(max_position_len, model_dim)
pe_embedding_table[:, 0::2] = torch.sin(pe_angles)
# For an odd model_dim there is one fewer odd column than even column.
pe_embedding_table[:, 1::2] = torch.cos(pe_angles[:, : model_dim // 2])

# Wrap the fixed table in an Embedding layer; requires_grad=False keeps the
# sinusoidal values out of training.
pe_embedding = nn.Embedding(max_position_len, model_dim)
pe_embedding.weight = nn.Parameter(pe_embedding_table, requires_grad=False)

# Position indices 0..max_len-1, repeated for every sentence in the batch.
# NOTE(review): nn.Embedding requires every index < max_position_len, so the
# longest sentence must not exceed max_position_len.
src_pos = torch.arange(int(src_len.max())).unsqueeze(0).repeat(len(src_len), 1).to(torch.int32)
target_pos = torch.arange(int(target_len.max())).unsqueeze(0).repeat(len(target_len), 1).to(torch.int32)

src_pe_embedding = pe_embedding(src_pos)
target_pe_embedding = pe_embedding(target_pos)

# Encoder self-attention mask.
# mask_encoder_self_attention has shape [batch_size, max_src_len, max_src_len]
# and is True wherever the query or the key position is padding; those scores
# are replaced by a large negative number before the softmax.
max_src_len = int(src_len.max())
# valid_encoder_pos[b, j, 0] is 1.0 for real tokens (j < src_len[b]) and 0.0
# for padding positions; shape [batch, max_src_len, 1], float32.
valid_encoder_pos = (
    (torch.arange(max_src_len).unsqueeze(0) < src_len.unsqueeze(1))
    .to(torch.float32)
    .unsqueeze(2)
)
# Per-sentence outer product: entry (q, k) is 1.0 only if both are real tokens.
valid_encoder_pos_matrix = torch.bmm(valid_encoder_pos, valid_encoder_pos.transpose(1, 2))
invalid_encoder_pos_matrix = 1 - valid_encoder_pos_matrix
mask_encoder_self_attention = invalid_encoder_pos_matrix.to(torch.bool)

score = torch.randn(batch_size, max_src_len, max_src_len)
# -1e9 rather than -inf: a row that is entirely masked (a padding query) then
# softmaxes to a uniform distribution instead of NaN.
masked_score = score.masked_fill(mask_encoder_self_attention, -1e9)
prob = F.softmax(masked_score, -1)

print(src_len)
print(score)
print(masked_score)
print(prob)


tensor([2, 4], dtype=torch.int32)
tensor([[[-0.0578, -0.4552, -0.2449,  0.5633],
         [ 1.0674, -0.4888,  0.5003,  1.1752],
         [-0.8320, -1.2310,  0.3404, -0.8350],
         [-1.0298, -0.9438, -1.2155,  0.1495]],

        [[ 0.7927,  2.3546,  0.2898, -1.3661],
         [ 1.7138,  0.8618, -0.3444,  0.9352],
         [ 0.0902, -1.0603,  1.2608, -2.2013],
         [ 0.3592, -0.1886, -1.3319, -0.1157]]])
tensor([[[-5.7835e-02, -4.5516e-01, -1.0000e+09, -1.0000e+09],
         [ 1.0674e+00, -4.8878e-01, -1.0000e+09, -1.0000e+09],
         [-1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09],
         [-1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09]],

        [[ 7.9267e-01,  2.3546e+00,  2.8975e-01, -1.3661e+00],
         [ 1.7138e+00,  8.6184e-01, -3.4442e-01,  9.3525e-01],
         [ 9.0245e-02, -1.0603e+00,  1.2608e+00, -2.2013e+00],
         [ 3.5916e-01, -1.8861e-01, -1.3319e+00, -1.1574e-01]]])
tensor([[[0.5980, 0.4020, 0.0000, 0.0000],
         [0.8258, 0.1742, 0.0000, 0