In [2]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import math
import copy

In [3]:
src_vocab_size = 5000
tgt_vocab_size = 5000
d_model = 512
num_heads = 8
num_layers = 6
d_ff = 2048
max_seq_length = 100
dropout = 0.1

# transformer = Transformer(src_vocab_size, tgt_vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout)

# Generate random sample data
src_data = torch.randint(1, src_vocab_size, (64, max_seq_length))  # (batch_size, seq_length)
tgt_data = torch.randint(1, tgt_vocab_size, (64, max_seq_length))  # (batch_size, seq_length)

In [8]:
src_data[0]

tensor([3696,  567, 1260, 3447, 2848, 3346, 1956, 1942,  406,  797, 1865, 1405,
        2751, 2949, 2954, 2310, 2281,   44, 1428, 3923, 2286, 4694,   78, 1895,
        2875, 3763, 4156, 1972, 3645, 2169, 3961, 3762, 1171, 1855, 3018,  182,
        2772,  237,  855, 4553, 1595, 3147,  755, 1148, 1929, 2253, 1630, 4720,
        1061, 4548, 1626,  482, 4905, 4906, 4803, 2428,  706,  941,  880, 3301,
        4912, 3122,  367, 3986, 1820, 1878, 3637, 3129, 1639,  962,  953, 1773,
         622, 2947, 3731, 3727,  459, 4226,  479, 2636, 3324,  423, 1439, 4801,
        3788, 1982, 1625, 1948, 3455, 1953, 1307,   42,   61,  122,  411, 3821,
        4130, 2506, 4817, 1324])

In [12]:
tgt_data[:, :-1]

tensor([[3599, 1104, 1801,  ..., 1914,  292,  413],
        [4919, 2449,  587,  ..., 4325, 2466, 1288],
        [1868, 2833, 1877,  ..., 4059, 1336, 4121],
        ...,
        [ 939, 1600, 2281,  ...,  492, 3069, 4348],
        [2953, 3642, 2039,  ..., 1041, 4982, 1523],
        [3898, 2968,  320,  ..., 4523,  491, 1928]])

In [13]:
tgt_data[:, 1:]

tensor([[1104, 1801, 1554,  ...,  292,  413,  362],
        [2449,  587, 1519,  ..., 2466, 1288, 4714],
        [2833, 1877, 2751,  ..., 1336, 4121, 1597],
        ...,
        [1600, 2281, 3063,  ..., 3069, 4348, 4905],
        [3642, 2039, 3901,  ..., 4982, 1523,   11],
        [2968,  320,  431,  ...,  491, 1928, 4560]])

In [14]:
vocab_size = 10 # including [mask] and [pad]
max_len = 5
num_seq = 5

def gen_sample_data(vocab_size, max_len, num_seq):
    """generate a list of text with variable lengths
    """
    # minus 2 for [0: padding ,1: mask]
    gen_single_sequence = lambda : torch.randint(2, vocab_size-3, size=(torch.randint(1,max_len, size=(1,)),))
    return [gen_single_sequence() for _ in range(num_seq)]

seqs = gen_sample_data(vocab_size, max_len, num_seq)

def batch_data(data):
    """Generate batched_data with padding
    """
    num_samples = len(data)
    full_data = torch.zeros(num_samples, max_len)
    for i, sent in enumerate(data):
        min_length = min(len(sent), max_len)
        full_data[i, :min_length] = sent[:min_length]
    return full_data.long()

batch_data = batch_data(seqs)
batch_data

tensor([[2, 4, 0, 0, 0],
        [5, 5, 0, 0, 0],
        [3, 4, 5, 6, 0],
        [4, 2, 6, 2, 0],
        [4, 3, 3, 2, 0]])

In [15]:
masking_prob = 0.15 
full_mask = torch.randn(batch_data.shape) < masking_prob
full_mask

tensor([[ True, False, False, False,  True],
        [ True,  True, False,  True,  True],
        [False,  True,  True, False, False],
        [ True,  True,  True,  True, False],
        [False,  True,  True,  True, False]])

In [16]:
special_tokens = [0]
for tk in special_tokens:
    full_mask = full_mask & (batch_data != tk)
full_mask

tensor([[ True, False, False, False, False],
        [ True,  True, False, False, False],
        [False,  True,  True, False, False],
        [ True,  True,  True,  True, False],
        [False,  True,  True,  True, False]])

In [17]:
random_prob = 0.1
random_mask = torch.randn(batch_data.shape) < random_prob
# for all the tokens that should be masked, select those that should be randomly masked
full_mask_with_random = full_mask & (random_mask)
full_mask_with_random

tensor([[False, False, False, False, False],
        [False, False, False, False, False],
        [False,  True,  True, False, False],
        [False,  True, False, False, False],
        [False,  True, False,  True, False]])

In [18]:
unchanged_prob = 0.1
unchanged_mask = torch.randn(batch_data.shape) < unchanged_prob
# for all the tokens that should be masked, select those that should be unchanged
full_mask_with_unchanged = full_mask & (unchanged_mask)
full_mask_with_unchanged

tensor([[ True, False, False, False, False],
        [False, False, False, False, False],
        [False,  True,  True, False, False],
        [ True,  True,  True,  True, False],
        [False, False,  True, False, False]])

In [19]:
# get the mask for [mask] tokens
full_mask_with_mask = full_mask & (~full_mask_with_random) & (~full_mask_with_unchanged)
full_mask_with_mask

tensor([[False, False, False, False, False],
        [ True,  True, False, False, False],
        [False, False, False, False, False],
        [False, False, False, False, False],
        [False, False, False, False, False]])

In [21]:
final_mask = batch_data.clone()

num_random_tokens = full_mask_with_random.sum().item()
random_tokens = torch.randint(0, vocab_size, size=(num_random_tokens,))
indices = torch.nonzero(full_mask_with_random, as_tuple=True) # (returns tuples *(x,y), *(x,y)
final_mask[indices] = random_tokens

In [22]:
mask_token = 1
final_mask = final_mask.masked_fill_(full_mask_with_mask, mask_token)

In [23]:
y = batch_data.clone()
padding_token = 0
y = y.masked_fill_(~full_mask, padding_token)

In [25]:
batch_data

tensor([[2, 4, 0, 0, 0],
        [5, 5, 0, 0, 0],
        [3, 4, 5, 6, 0],
        [4, 2, 6, 2, 0],
        [4, 3, 3, 2, 0]])

In [24]:
final_mask

tensor([[2, 4, 0, 0, 0],
        [1, 1, 0, 0, 0],
        [3, 4, 1, 6, 0],
        [4, 2, 6, 2, 0],
        [4, 9, 3, 2, 0]])

In [26]:
y

tensor([[2, 0, 0, 0, 0],
        [5, 5, 0, 0, 0],
        [0, 4, 5, 0, 0],
        [4, 2, 6, 2, 0],
        [0, 3, 3, 2, 0]])