In [59]:
import numpy as np
import torch

# Toy parallel corpus. Each row holds three whitespace-tokenised strings:
# the encoder input (source, padded with 'P'), the decoder input (target
# prefixed with the start token 'S') and the decoder output (target suffixed
# with the end token 'E').
sentences = [
    # enc_input           dec_input         dec_output
    ['ich mochte ein bier P', 'S i want a beer .', 'i want a beer . E'],
    ['ich mochte ein cola P', 'S i want a coke .', 'i want a coke . E']
]

# The padding token 'P' must map to id 0 in both vocabularies.
src_vocab = {'P': 0, 'ich': 1, 'mochte': 2, 'ein': 3, 'bier': 4, 'cola': 5}
src_vocab_size = len(src_vocab)

tgt_vocab = {'P': 0, 'i': 1, 'want': 2, 'a': 3, 'beer': 4, 'coke': 5, 'S': 6, 'E': 7, '.': 8}
# Invert the mapping from its stored ids instead of enumerating the keys, so
# the lookup stays correct even if the ids are not listed in insertion order.
idx2word = {idx: word for word, idx in tgt_vocab.items()}
tgt_vocab_size = len(tgt_vocab)

src_len = 5  # enc_input max sequence length
tgt_len = 6  # dec_input(=dec_output) max sequence length

# Transformer hyper-parameters
d_model = 512   # embedding size
d_ff = 2048     # feed-forward hidden dimension
d_k = 64        # per-head dimension of K (and Q)
d_v = 64        # per-head dimension of V
n_layers = 6    # number of encoder / decoder layers
n_heads = 8     # number of heads in multi-head attention


# Use the GPU when one is available, otherwise fall back to the CPU, and
# report the choice.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("use gpu........" if device.type == "cuda" else "use cpu........")


def make_data(sentences):
    '''
    Convert the text corpus into index tensors.

    sentences: sequence of rows [enc_input, dec_input, dec_output], each a
               whitespace-separated string.
    Returns three LongTensors (enc_inputs, dec_inputs, dec_outputs), one row
    per sentence. Source tokens are looked up in src_vocab, target tokens in
    tgt_vocab; an unknown token raises KeyError.
    '''
    def encode(text, vocab):
        # Map every whitespace-separated token to its vocabulary id.
        return [vocab[token] for token in text.split()]

    src_rows, tgt_in_rows, tgt_out_rows = [], [], []
    for example in sentences:
        src_rows.append(encode(example[0], src_vocab))      # e.g. [1, 2, 3, 4, 0]
        tgt_in_rows.append(encode(example[1], tgt_vocab))   # e.g. [6, 1, 2, 3, 4, 8]
        tgt_out_rows.append(encode(example[2], tgt_vocab))  # e.g. [1, 2, 3, 4, 8, 7]

    return (
        torch.LongTensor(src_rows),
        torch.LongTensor(tgt_in_rows),
        torch.LongTensor(tgt_out_rows),
    )


# Encode the toy corpus once and report the encoder batch shape.
enc_inputs, dec_inputs, dec_outputs = make_data(sentences)
print("enc_inputs: ", enc_inputs.shape)


use cpu........
enc_inputs:  torch.Size([2, 5])


In [60]:
def get_attn_pad_mask(seq_q, seq_k, pad_idx=0):
    '''
    Build the padding mask for attention between queries and keys.

    seq_q: [batch_size, len_q] token ids of the query sequence
    seq_k: [batch_size, len_k] token ids of the key sequence
    pad_idx: id of the padding token (default 0)

    len_q and len_k may differ (src_len or tgt_len). Only seq_k's content is
    inspected; seq_q only contributes len_q.

    Returns a bool tensor [batch_size, len_q, len_k] in which True marks a key
    position holding the padding token (i.e. a position to be masked out).
    The result is an expanded view, so it must not be written to in place.
    '''
    len_q = seq_q.size(1)
    batch_size, len_k = seq_k.size()
    # True where the key token is PAD; the same row is shared by every query.
    pad_attn_mask = seq_k.eq(pad_idx).unsqueeze(1)  # [batch_size, 1, len_k]
    return pad_attn_mask.expand(batch_size, len_q, len_k)  # [batch_size, len_q, len_k]

# Encoder self-attention padding mask, then one copy per attention head.
enc_self_attn_mask = get_attn_pad_mask(enc_inputs, enc_inputs)  # [batch_size, src_len, src_len]
self_attn_mask = enc_self_attn_mask[:, None, :, :].repeat(1, n_heads, 1, 1)  # [batch_size, n_heads, src_len, src_len]
print(self_attn_mask.shape)


torch.Size([2, 8, 5, 5])


In [61]:
def get_attn_subsequence_mask(seq):
    '''
    Build the look-ahead mask for decoder self-attention.

    seq: [batch_size, tgt_len]; only its shape and device are used.

    Returns a uint8 tensor [batch_size, tgt_len, tgt_len] whose strict upper
    triangle is 1: position i is flagged for every position j > i. The tensor
    is allocated on seq's device.
    '''
    attn_shape = [seq.size(0), seq.size(1), seq.size(1)]
    # Build the mask directly in torch instead of allocating a float64 NumPy
    # array and converting it; torch.triu acts on the last two dimensions.
    ones = torch.ones(attn_shape, dtype=torch.uint8, device=seq.device)
    subsequence_mask = torch.triu(ones, diagonal=1)  # strict upper triangle
    return subsequence_mask  # [batch_size, tgt_len, tgt_len]

# Decoder self-attention mask: a position is masked when it is padding OR
# lies in the future, i.e. when the sum of the two masks is positive.
dec_self_attn_pad_mask = get_attn_pad_mask(dec_inputs, dec_inputs).to(device)
dec_self_attn_subsequence_mask = get_attn_subsequence_mask(dec_inputs).to(device)
dec_self_attn_mask = (dec_self_attn_pad_mask + dec_self_attn_subsequence_mask) > 0
print("dec_self_attn_mask:\n", dec_self_attn_mask)

# Decoder-encoder attention mask: decoder positions are the queries,
# encoder positions are the keys.
dec_enc_attn_mask = get_attn_pad_mask(dec_inputs, enc_inputs)
print("dec_enc_attn_mask:\n", dec_enc_attn_mask)

# Target ids of the whole batch flattened into a single 1-D tensor.
print("训练用GT的batch合并:", dec_outputs.view(-1))

# Start with an empty [1, 0] sequence and append the start token.
# The buffer is created as int64 on `device` so that it matches the appended
# token tensor: a default torch.zeros(1, 0) is float32 on the CPU, which makes
# torch.cat fail on a GPU (device mismatch) and otherwise turns the token ids
# into floats that cannot be used as embedding indices.
dec_input = torch.zeros(1, 0, dtype=torch.long, device=device)
print(dec_input)
start_symbol = tgt_vocab["S"]
next_token = torch.tensor([[start_symbol]], dtype=torch.long, device=device)
dec_input = torch.cat([dec_input.detach(), next_token], -1)  # [1, 1]
print(dec_input)



dec_self_attn_mask:
 tensor([[[False,  True,  True,  True,  True,  True],
         [False, False,  True,  True,  True,  True],
         [False, False, False,  True,  True,  True],
         [False, False, False, False,  True,  True],
         [False, False, False, False, False,  True],
         [False, False, False, False, False, False]],

        [[False,  True,  True,  True,  True,  True],
         [False, False,  True,  True,  True,  True],
         [False, False, False,  True,  True,  True],
         [False, False, False, False,  True,  True],
         [False, False, False, False, False,  True],
         [False, False, False, False, False, False]]])
dec_enc_attn_mask:
 tensor([[[False, False, False, False,  True],
         [False, False, False, False,  True],
         [False, False, False, False,  True],
         [False, False, False, False,  True],
         [False, False, False, False,  True],
         [False, False, False, False,  True]],

        [[False, False, False, False,  Tr