In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
# Report whether PyTorch can see a CUDA device in this environment.
print(torch.cuda.is_available())

# word embedding

In [15]:
# Build toy source/target batches. Every sequence is a row of vocabulary
# indices, right-padded with 0 up to the maximum sequence length.
batch_size = 5
seed = 42
# Seed the global RNG so the random lengths and tokens below are identical on
# every "Restart Kernel -> Run All" (the seed was previously defined but unused).
torch.manual_seed(seed)

# Vocabulary sizes. Tokens are drawn from [1, size - 1]; 0 is only ever
# produced by padding.
max_num_src_words = 8
max_num_tar_words = 8

# Maximum (padded) sequence lengths.
max_src_seq_len = 5
max_tar_seq_len = 5


def build_padded_batch(lengths, vocab_size, max_len):
    """Sample one random token sequence per entry of ``lengths`` and pad it.

    Args:
        lengths: 1-D integer tensor holding the true length of each sequence.
        vocab_size: exclusive upper bound for sampled token indices
            (tokens are drawn from ``[1, vocab_size - 1]``).
        max_len: padded length of every row.

    Returns:
        Integer tensor of shape ``[len(lengths), max_len]`` whose rows contain
        the sampled tokens followed by zeros.
    """
    rows = [
        F.pad(torch.randint(1, vocab_size, (n, )), (0, max_len - n))
        for n in lengths.tolist()
    ]
    return torch.stack(rows)


# torch.randint excludes its upper bound, so lengths fall in [2, max_len - 1].
src_len = torch.randint(2, max_src_seq_len, (batch_size, ))
tar_len = torch.randint(2, max_tar_seq_len, (batch_size, ))
src_seq = build_padded_batch(src_len, max_num_src_words, max_src_seq_len)
tar_seq = build_padded_batch(tar_len, max_num_tar_words, max_tar_seq_len)

In [16]:
# Show the sampled target lengths, then the padded target batch, one per line.
print(tar_len, tar_seq, sep="\n")

tensor([3, 2, 2, 2, 4])
tensor([[2, 2, 2, 0, 0],
        [5, 3, 0, 0, 0],
        [3, 1, 0, 0, 0],
        [1, 2, 0, 0, 0],
        [6, 2, 7, 1, 0]])


In [37]:
# Build the word-embedding lookup tables and embed both batches.
model_dim = 8

# Each table is allocated with one more row than the configured vocabulary size.
src_embedding_table = nn.Embedding(
    num_embeddings=max_num_src_words + 1,
    embedding_dim=model_dim,
)
tar_embedding_table = nn.Embedding(
    num_embeddings=max_num_tar_words + 1,
    embedding_dim=model_dim,
)

# Index lookup appends a trailing model_dim axis to each index batch.
src_embedding, tar_embedding = src_embedding_table(src_seq), tar_embedding_table(tar_seq)

In [38]:
# Show the source lookup table, the index batch, and the looked-up embeddings
# so each row of the result can be matched to a row of the table.
print(src_embedding_table.weight, src_seq, src_embedding, sep="\n")

Parameter containing:
tensor([[ 1.1577, -1.1033, -0.0046,  0.2951,  1.1064, -0.4616,  0.0740, -1.2800],
        [-1.0922,  0.9452,  1.6080,  0.1181, -0.5269, -0.5717,  1.2725, -0.9821],
        [-0.3340,  0.5336, -0.3190,  0.3476, -0.1858, -0.8630,  0.0159, -1.2126],
        [-1.1064,  0.2684,  2.5226, -0.1207,  0.3365, -1.1780, -3.2235, -0.9720],
        [-1.4668, -0.9273, -1.0133,  1.0210, -1.0476, -0.3432, -0.7435, -0.6645],
        [-0.1514, -0.5776, -0.4437, -0.2365, -0.1730,  0.4924, -0.6754, -0.3440],
        [-0.7918,  0.3064,  1.4678, -1.1451,  1.5427, -1.2430,  1.6119,  0.1663],
        [ 1.8051,  1.1401, -1.3318,  1.1858, -1.7826, -1.3912,  0.7077, -1.3698],
        [ 0.2001, -0.2227,  1.6947, -0.1528, -2.4707, -1.2975,  0.0355, -0.6559]],
       requires_grad=True)
tensor([[4, 2, 7, 0, 0],
        [4, 2, 6, 5, 0],
        [6, 7, 1, 1, 0],
        [1, 2, 3, 0, 0],
        [6, 7, 6, 3, 0]])
tensor([[[-1.4668, -0.9273, -1.0133,  1.0210, -1.0476, -0.3432, -0.7435,
          -0.

# position embedding

In [44]:
# Sinusoidal position-embedding table:
#   PE[pos, 2i]     = sin(pos / 10000^(2i / model_dim))
#   PE[pos, 2i + 1] = cos(pos / 10000^(2i / model_dim))
max_position_len = 5
# The even/odd column split below requires an even number of columns.
assert model_dim % 2 == 0, "model_dim must be even"
# Column vector of positions: [max_position_len, 1].
pos_mat = torch.arange(max_position_len).reshape(-1, 1)
# Row vector of per-pair denominators: [1, model_dim / 2]. The even indices are
# derived from model_dim (not a literal 8) so the table stays consistent when
# model_dim changes.
i_mat = torch.pow(10000, torch.arange(0, model_dim, 2).reshape(1, -1) / model_dim)
# Broadcasting yields one angle per (position, column pair).
angle_mat = pos_mat / i_mat
pe_embedding_table = torch.zeros(max_position_len, model_dim)
pe_embedding_table[:, 0::2] = torch.sin(angle_mat)  # even columns
pe_embedding_table[:, 1::2] = torch.cos(angle_mat)  # odd columns

In [45]:
# Print the position-embedding table: one row per position, one column per
# model dimension (sin in even columns, cos in odd columns).
print(pe_embedding_table)

tensor([[ 0.0000e+00,  1.0000e+00,  0.0000e+00,  1.0000e+00,  0.0000e+00,
          1.0000e+00,  0.0000e+00,  1.0000e+00],
        [ 8.4147e-01,  5.4030e-01,  9.9833e-02,  9.9500e-01,  9.9998e-03,
          9.9995e-01,  1.0000e-03,  1.0000e+00],
        [ 9.0930e-01, -4.1615e-01,  1.9867e-01,  9.8007e-01,  1.9999e-02,
          9.9980e-01,  2.0000e-03,  1.0000e+00],
        [ 1.4112e-01, -9.8999e-01,  2.9552e-01,  9.5534e-01,  2.9995e-02,
          9.9955e-01,  3.0000e-03,  1.0000e+00],
        [-7.5680e-01, -6.5364e-01,  3.8942e-01,  9.2106e-01,  3.9989e-02,
          9.9920e-01,  4.0000e-03,  9.9999e-01]])


# MASK

In [65]:
# valid_encoder_pos: [batch_size, max_src_seq_len, 1], holding 1.0 where the
# position index is smaller than that sequence's length and 0.0 on padding.
valid_encoder_pos = (
    torch.arange(max_src_seq_len)[None, :] < src_len[:, None]
).float().unsqueeze(2)
# Batched outer product -> [batch_size, max_src_seq_len, max_src_seq_len];
# entry (i, j) is 1.0 only when positions i and j are both valid.
valid_encoder_pos_matrix = torch.bmm(valid_encoder_pos, valid_encoder_pos.permute(0, 2, 1))

In [66]:
# Shape check: one square position-by-position matrix per batch element.
valid_encoder_pos_matrix.shape

torch.Size([5, 5, 5])

In [67]:
# Display the validity matrices: entry (i, j) is 1 only when both positions
# i and j lie inside the sequence's true length.
valid_encoder_pos_matrix

tensor([[[1., 1., 1., 0., 0.],
         [1., 1., 1., 0., 0.],
         [1., 1., 1., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]],

        [[1., 1., 1., 1., 0.],
         [1., 1., 1., 1., 0.],
         [1., 1., 1., 1., 0.],
         [1., 1., 1., 1., 0.],
         [0., 0., 0., 0., 0.]],

        [[1., 1., 1., 1., 0.],
         [1., 1., 1., 1., 0.],
         [1., 1., 1., 1., 0.],
         [1., 1., 1., 1., 0.],
         [0., 0., 0., 0., 0.]],

        [[1., 1., 1., 0., 0.],
         [1., 1., 1., 0., 0.],
         [1., 1., 1., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]],

        [[1., 1., 1., 1., 0.],
         [1., 1., 1., 1., 0.],
         [1., 1., 1., 1., 0.],
         [1., 1., 1., 1., 0.],
         [0., 0., 0., 0., 0.]]])

In [69]:
# Source lengths, for comparison with the block of ones in each validity matrix.
src_len

tensor([3, 4, 4, 3, 4])

In [70]:
# Flip the 0/1 validity matrix so that 1 marks entries touching a padded position.
invalid_encoder_pos_matrix = torch.ones_like(valid_encoder_pos_matrix) - valid_encoder_pos_matrix
# Boolean form (True = mask out), as consumed by Tensor.masked_fill.
mask_encoder_pos_matrix = invalid_encoder_pos_matrix.bool()

In [71]:
# Display the boolean mask: True where the row or column is a padded position.
mask_encoder_pos_matrix

tensor([[[False, False, False,  True,  True],
         [False, False, False,  True,  True],
         [False, False, False,  True,  True],
         [ True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True]],

        [[False, False, False, False,  True],
         [False, False, False, False,  True],
         [False, False, False, False,  True],
         [False, False, False, False,  True],
         [ True,  True,  True,  True,  True]],

        [[False, False, False, False,  True],
         [False, False, False, False,  True],
         [False, False, False, False,  True],
         [False, False, False, False,  True],
         [ True,  True,  True,  True,  True]],

        [[False, False, False,  True,  True],
         [False, False, False,  True,  True],
         [False, False, False,  True,  True],
         [ True,  True,  True,  True,  True],
         [ True,  True,  True,  True,  True]],

        [[False, False, False, False,  True],
         [False, False, Fa

In [77]:
# The mask keeps the shape of the validity matrix it was derived from.
mask_encoder_pos_matrix.shape

torch.Size([5, 5, 5])

In [82]:
# Random stand-in for attention scores. Its shape must match the mask, which is
# built from max_src_seq_len; sizing it as max(src_len) + 1 only matched by
# coincidence and breaks masked_fill whenever the longest sampled sequence is
# shorter than max_src_seq_len - 1.
score = torch.randn(batch_size, max_src_seq_len, max_src_seq_len)
print(score.shape)
# Overwrite masked entries with a large negative value so softmax gives them
# (near-)zero weight. A finite -1e9 is used rather than -inf so rows that are
# entirely masked still produce finite probabilities instead of NaN.
masked_score = score.masked_fill(mask_encoder_pos_matrix, -1e9)

torch.Size([5, 5, 5])


In [83]:
# Normalise each row of scores into a probability distribution over the last axis.
prob = torch.softmax(masked_score, dim=-1)

In [85]:
# Display the scores after masking: masked entries hold -1e9, the rest keep
# their random values.
masked_score


tensor([[[ 1.9844e-02,  7.0300e-01,  7.9120e-01, -1.0000e+09, -1.0000e+09],
         [-7.0288e-01, -1.6758e+00,  1.8230e+00, -1.0000e+09, -1.0000e+09],
         [ 2.7541e-01, -1.6280e+00, -1.0990e+00, -1.0000e+09, -1.0000e+09],
         [-1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09],
         [-1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09]],

        [[ 1.3281e+00, -1.2928e+00, -5.3599e-01,  4.7906e-02, -1.0000e+09],
         [ 5.9424e-01,  4.2643e-01,  1.6178e+00,  2.2941e+00, -1.0000e+09],
         [ 1.7913e+00,  1.6131e+00, -6.9231e-01,  1.3680e+00, -1.0000e+09],
         [ 8.9133e-01, -1.2346e+00, -1.4193e+00,  4.0889e-02, -1.0000e+09],
         [-1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09, -1.0000e+09]],

        [[ 1.8870e-01,  7.3259e-01, -1.4324e+00,  6.5178e-01, -1.0000e+09],
         [ 4.8557e-01,  3.9648e-01, -5.4384e-01,  1.0643e+00, -1.0000e+09],
         [-1.8655e-02, -5.4257e-01,  1.7989e-01,  1.2249e+00, -1.0000e+09],
        