In [44]:
import torch
import torch.nn.functional as F
import os
import zipfile
import collections

In [45]:
import torch.nn as nn

In [46]:
class MLP(nn.Module):
    def __init__(self, embedding_size, hidden_size, output_size):
        super().__init__()
        self.linear1 = nn.Linear(embedding_size, hidden_size)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)
        return x

In [15]:
class SelfAttention(nn.Module):
    def __init__(self, embedding_size):
        super().__init__()
        self.attention = nn.MultiheadAttention(embedding_size, 1, batch_first = True)
        self.is_causal = True
    
    def forward(self, embeddings):
        batch, word_count, embedding_size = embeddings.shape
        attn_mask = nn.Transformer.generate_square_subsequent_mask(word_count, device = embeddings.device)
        output = self.attention(
            embeddings,
            embeddings,
            embeddings,
            is_causal = True,
            attn_mask = attn_mask
        )
        return output[0] + embeddings

In [47]:
class CrossAttention(nn.Module):
    def __init__(self, embedding_size):
        super().__init__()
        self.attention = nn.MultiheadAttention(embedding_size, 1, batch_first = True)

    def forward(self, french_embeddings, english_embeddings):
        # shape batch_size x number_of_english_embeddings x number_of_french_embeddings
        output = self.attention(
            english_embeddings, 
            french_embeddings, 
            french_embeddings, 
            )
        return output[0] + english_embeddings

In [48]:
class Vocab:
    """Vocabulary for text."""
    def __init__(self, tokens=None, min_freq=0, reserved_tokens=None):
        if tokens is None:
            tokens = []
        if reserved_tokens is None:
            reserved_tokens = []
        # Sort according to frequencies
        counter = count_corpus(tokens)
        self._token_freqs = sorted(counter.items(), key=lambda x: x[1],
                                   reverse=True)
        # The index for the unknown token is 0
        self.idx_to_token = ['<unk>'] + reserved_tokens
        self.token_to_idx = {token: idx
                             for idx, token in enumerate(self.idx_to_token)}
        for token, freq in self._token_freqs:
            if freq < min_freq:
                break
            if token not in self.token_to_idx:
                self.idx_to_token.append(token)
                self.token_to_idx[token] = len(self.idx_to_token) - 1

    def __len__(self):
        return len(self.idx_to_token)

    def __getitem__(self, tokens):
        if not isinstance(tokens, (list, tuple)):
            return self.token_to_idx.get(tokens, self.unk)
        return [self.__getitem__(token) for token in tokens]

    def to_tokens(self, indices):
        if not isinstance(indices, (list, tuple)):
            return self.idx_to_token[indices]
        return [self.idx_to_token[index] for index in indices]

    @property
    def unk(self):  # Index for the unknown token
        return 0

    @property
    def token_freqs(self):  # Token frequencies
        return self._token_freqs

def count_corpus(tokens):
    """Count token frequencies."""
    # Here `tokens` is a 1D list or 2D list
    if len(tokens) == 0 or isinstance(tokens[0], list):
        # Flatten a list of token lists into a list of tokens
        tokens = [token for line in tokens for token in line]
    return collections.Counter(tokens)

def download(url, cache_dir=os.path.join('..', 'data')):
    """Download a file, return the local filename."""
    os.makedirs(cache_dir, exist_ok=True)
    fname = os.path.join(cache_dir, url.split('/')[-1])
    if os.path.exists(fname):
        with open(fname, 'rb') as f:
            while True:
                data = f.read(1048576)
                if not data:
                    break
        return fname
    print(f'Downloading {fname} from {url}...')
    r = requests.get(url, stream=True, verify=True)
    with open(fname, 'wb') as f:
        f.write(r.content)
    return fname

def download_extract(url, folder=None):
    """Download and extract a zip file."""
    fname = download(url)
    base_dir = os.path.dirname(fname)
    data_dir, ext = os.path.splitext(fname)
    if ext == '.zip':
        fp = zipfile.ZipFile(fname, 'r')
    else:
        assert False, 'Only zip files can be extracted.'
    fp.extractall(base_dir)
    return os.path.join(base_dir, folder) if folder else data_dir

def read_data_nmt():
    """Load the English-French dataset."""
    data_dir = download_extract('http://d2l-data.s3-accelerate.amazonaws.com/fra-eng.zip')
    with open(os.path.join(data_dir, 'fra.txt'), 'r') as f:
        return f.read()

def preprocess_nmt(text):
    """Preprocess the English-French dataset."""
    def no_space(char, prev_char):
        return char in set(',.!?') and prev_char != ' '

    # Replace non-breaking space with space, and convert uppercase letters to
    # lowercase ones
    text = text.replace('\u202f', ' ').replace('\xa0', ' ').lower()
    # Insert space between words and punctuation marks
    out = [' ' + char if i > 0 and no_space(char, text[i - 1]) else char
           for i, char in enumerate(text)]
    return ''.join(out)

def tokenize_nmt(text, num_examples=None):
    """Tokenize the English-French dataset."""
    source, target = [], []
    for i, line in enumerate(text.split('\n')):
        if num_examples and i > num_examples:
            break
        parts = line.split('\t')
        if len(parts) == 2:
            source.append(parts[0].split(' '))
            target.append(parts[1].split(' '))
    return source, target

def truncate_pad(line, num_steps, padding_token):
    """Truncate or pad sequences."""
    if len(line) > num_steps:
        return line[:num_steps]  # Truncate
    return line + [padding_token] * (num_steps - len(line))  # Pad

def build_array_nmt(lines, vocab, num_steps):
    """Transform text sequences of machine translation into mini-batches."""
    lines = [vocab[l] for l in lines]
    lines = [[vocab['<bos>']] + l + [vocab['<eos>']] for l in lines]
    array = torch.tensor([truncate_pad(
        l, num_steps, vocab['<pad>']) for l in lines])
    valid_len = (array != vocab['<pad>']).type(torch.int32).sum(1)
    return array, valid_len

def load_array(data_arrays, batch_size, is_train=True):
    """Construct a PyTorch data iterator."""
    dataset = torch.utils.data.TensorDataset(*data_arrays)
    return torch.utils.data.DataLoader(dataset, batch_size, shuffle=is_train)

def load_data_nmt(batch_size, num_steps, num_examples=600):
    """Return the iterator and the vocabularies of the translation dataset."""
    text = preprocess_nmt(read_data_nmt())
    source, target = tokenize_nmt(text, num_examples)
    src_vocab = Vocab(source, min_freq=2,
                          reserved_tokens=['<pad>', '<bos>', '<eos>'])
    tgt_vocab = Vocab(target, min_freq=2,
                          reserved_tokens=['<pad>', '<bos>', '<eos>'])
    src_array, src_valid_len = build_array_nmt(source, src_vocab, num_steps)
    tgt_array, tgt_valid_len = build_array_nmt(target, tgt_vocab, num_steps)
    data_arrays = (src_array, src_valid_len, tgt_array, tgt_valid_len)
    data_iter = load_array(data_arrays, batch_size)
    return data_iter, src_vocab, tgt_vocab

In [49]:
dataloader, vocab_english, vocab_french = load_data_nmt(2, 4)

In [57]:
for X, _, Y, _ in dataloader:
    break

In [55]:
vocab_english.to_tokens(10), vocab_english.to_tokens(73), vocab_english.to_tokens(4)

('tom', 'left', '.')

In [56]:
vocab_french.to_tokens(8), vocab_french.to_tokens(0), vocab_french.to_tokens(4)

('tom', '<unk>', '.')

In [58]:
X

tensor([[ 2,  9, 82,  4],
        [ 2, 36, 12,  5]])

In [59]:
Y

tensor([[  2,  67,   5,   3],
        [  2, 131,   5,   3]])

In [60]:
class EncoderDecoderTransformer(nn.Module):
    def __init__(self, unique_french_words, unique_english_words, embedding_size):
        super().__init__()

        self.french_embedding_layer = nn.Embedding(unique_french_words, embedding_size)

        self.attn1 = SelfAttention(embedding_size)
        self.mlp1 = MLP(embedding_size, 2 * embedding_size, embedding_size)

        self.attn2 = SelfAttention(embedding_size)
        self.mlp2 = MLP(embedding_size, 2*embedding_size, embedding_size)

        self.attn3 = SelfAttention(embedding_size)
        self.mlp3 = MLP(embedding_size, 2*embedding_size, embedding_size)

        self.english_embedding_layer = nn.Embedding(unique_english_words, embedding_size)

        self.attn4 = SelfAttention(embedding_size)
        self.cross_attn4 = CrossAttention(embedding_size)
        self.mlp4 = MLP(embedding_size, embedding_size * 2, embedding_size)

        self.attn5 = SelfAttention(embedding_size)
        self.cross_attn5 = CrossAttention(embedding_size)
        self.mlp5 = MLP(embedding_size, embedding_size * 2, embedding_size)

        self.attn6 = SelfAttention(embedding_size)
        self.cross_attn6 = CrossAttention(embedding_size)
        self.mlp6 = MLP(embedding_size, embedding_size * 2, embedding_size)

        self.to_out = nn.Linear(embedding_size, unique_english_words)

    def forward(self, french_input, incomplete_english_input):
        french_embeddings = self.french_embedding_layer(french_input)

        french_embeddings2 = self.attn1(french_embeddings)
        french_embeddings2 = self.mlp1(french_embeddings2)
        french_embeddings2 += french_embeddings
        
        french_embeddings3 = self.attn2(french_embeddings2)
        french_embeddings3 = self.mlp2(french_embeddings3)
        french_embeddings3 += french_embeddings2

        french_embeddings4 = self.attn3(french_embeddings3)
        french_embeddings4 = self.mlp3(french_embeddings4)
        french_embeddings4 += french_embeddings3

        english_embeddings = self.english_embedding_layer(incomplete_english_input)
        
        english_embeddings2 = self.attn4(english_embeddings)
        english_embeddings2 = self.cross_attn4(french_embeddings4, english_embeddings2)
        english_embeddings2 = self.mlp4(english_embeddings2)
        english_embeddings2 += english_embeddings

        english_embeddings3 = self.attn5(english_embeddings2)
        english_embeddings3 = self.cross_attn5(french_embeddings4, english_embeddings3)
        english_embeddings3 = self.mlp5(english_embeddings3)
        english_embeddings3 += english_embeddings2

        english_embeddings4 = self.attn6(english_embeddings3)
        english_embeddings4 = self.cross_attn6(french_embeddings4, english_embeddings4)
        english_embeddings4 = self.mlp6(english_embeddings4)
        english_embeddings4 += english_embeddings3

        output = self.to_out(english_embeddings4)

        return output

In [62]:
loss_function = nn.CrossEntropyLoss(ignore_index = 1)

In [63]:
transformer = EncoderDecoderTransformer(len(vocab_french), len(vocab_english), 256)

In [64]:
optimizer = torch.optim.SGD(transformer.parameters(), lr = 0.01)

In [65]:
transformer.to("cuda")

EncoderDecoderTransformer(
  (french_embedding_layer): Embedding(201, 256)
  (attn1): SelfAttention(
    (attention): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
    )
  )
  (mlp1): MLP(
    (linear1): Linear(in_features=256, out_features=512, bias=True)
    (relu): ReLU()
    (linear2): Linear(in_features=512, out_features=256, bias=True)
  )
  (attn2): SelfAttention(
    (attention): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
    )
  )
  (mlp2): MLP(
    (linear1): Linear(in_features=256, out_features=512, bias=True)
    (relu): ReLU()
    (linear2): Linear(in_features=512, out_features=256, bias=True)
  )
  (attn3): SelfAttention(
    (attention): MultiheadAttention(
      (out_proj): NonDynamicallyQuantizableLinear(in_features=256, out_features=256, bias=True)
    )
  )
  (mlp3): MLP(
    (linear1): Linear(in_features=256, out_features=512

In [66]:
for i in range(1000):
    epoch_loss = 0
    batch_count = 0
    for english_batch, _, french_batch, _ in dataloader:
        english_batch = english_batch.to("cuda")
        french_batch = french_batch.to("cuda")
        # Je suis bored_french
        # <bos> I am
        # model says: I am sleepy_english 
        # wanted output: I am bored_english
        model_prediction = transformer(french_batch, english_batch)
        
        model_next_prediction = model_prediction[:, :-1].permute(0, 2, 1)
        model_wanted_output = english_batch[:, 1:]

        loss = loss_function(model_next_prediction, model_wanted_output)
        epoch_loss += loss
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        batch_count += 1

    print(epoch_loss / batch_count)

tensor(3.5355, device='cuda:0', grad_fn=<DivBackward0>)
tensor(2.5557, device='cuda:0', grad_fn=<DivBackward0>)
tensor(2.1788, device='cuda:0', grad_fn=<DivBackward0>)
tensor(1.9028, device='cuda:0', grad_fn=<DivBackward0>)
tensor(1.7324, device='cuda:0', grad_fn=<DivBackward0>)
tensor(1.5694, device='cuda:0', grad_fn=<DivBackward0>)
tensor(1.4071, device='cuda:0', grad_fn=<DivBackward0>)
tensor(1.2649, device='cuda:0', grad_fn=<DivBackward0>)
tensor(1.1648, device='cuda:0', grad_fn=<DivBackward0>)
tensor(1.0494, device='cuda:0', grad_fn=<DivBackward0>)
tensor(0.9609, device='cuda:0', grad_fn=<DivBackward0>)
tensor(0.8753, device='cuda:0', grad_fn=<DivBackward0>)
tensor(0.7942, device='cuda:0', grad_fn=<DivBackward0>)
tensor(0.7511, device='cuda:0', grad_fn=<DivBackward0>)
tensor(0.6928, device='cuda:0', grad_fn=<DivBackward0>)
tensor(0.6529, device='cuda:0', grad_fn=<DivBackward0>)
tensor(0.6343, device='cuda:0', grad_fn=<DivBackward0>)
tensor(0.5936, device='cuda:0', grad_fn=<DivBack

KeyboardInterrupt: 

In [41]:
def predict_seq2seq(net, src_sentence, src_vocab, tgt_vocab, num_steps):
    """Predict for sequence to sequence."""
    # Set `net` to eval mode for inference
    net.eval()
    src_tokens = [src_vocab['<bos>']] + src_vocab[src_sentence.lower().split(' ')] + [
        src_vocab['<eos>']]

    output_seq = [tgt_vocab['<bos>']]
    for _ in range(num_steps):
        Y = net(torch.tensor(src_tokens).unsqueeze(0).to("cuda"), torch.tensor(output_seq).unsqueeze(0).to("cuda"))
        # We use the token with the highest prediction likelihood as the input
        # of the decoder at the next time step
        dec_X = Y[:, -1, :].argmax()
        pred = dec_X.item()
        if pred == tgt_vocab['<eos>']:
            break
        output_seq.append(pred)
    return ' '.join(tgt_vocab.to_tokens(output_seq))

In [68]:
predict_seq2seq(transformer, "j'ai perdu .", vocab_french, vocab_english, 4)

'<bos> i lost .'