Import packages

In [1]:
import torch
from torch import nn

import spacy

import math

  from .autonotebook import tqdm as notebook_tqdm


Load English and German tokenizers using Spacy.

In [2]:
eng_lang = spacy.load("en_core_web_sm")
ger_lang = spacy.load("de_core_news_sm")

Specify device as GPU

In [3]:
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
print(device)

cuda:0


Create a single batch of random input just to use to check if the shape initializations in the below classes are correct.

In [4]:
torch.cuda.empty_cache()

In [5]:
# Test input has shape (N, seq_len, embed_len) => ((8, 20, 512)).
# I decreased the batch size from 64 to 8 because of memory issues of my GPU.
# 64 was too big for the memory, and the biggest size that worked was 8.
test_query = torch.rand((8, 20, 512)).to(device)
test_key = torch.rand((8, 20, 512)).to(device)
test_value = torch.rand((8, 20, 512)).to(device)

Set hyperparameters. I used the same ones as those used in the paper.

In [8]:
num_heads = 8
d_model = 512

Begin building the Transformer. The first step is to build the 'Scaled Dot-Product Attention' block mentioned in the paper. This is still just the first draft; it will probably need some fixes once I get to later stages.

In [9]:
class ScaledDotProduct(nn.Module):
    def __init__(self, queries):
        super(ScaledDotProduct, self).__init__()
        
        # I probably don't need to initialize K, Q, and V here, since they will be passed to the 
        # scaled dot product when we call it from the MultiHeadAttention class in the forward method.
        # Will delete the variable below later if I turn out to be right.
        self.queries = queries
        #self.keys = keys
        #self.values = values
        self.dk = self.queries.shape[1]
        #print('self.dk type: ', type(self.dk))

        # Softmax operator. 'dim' still needs to be specified
        self.softmax = nn.Softmax()

    # Define the forward function
    def forward(self, queries, keys, values):
        compatibility = torch.bmm(queries, torch.transpose(keys, 1, 2))   # first batch MatMul operation
        compatibility = compatibility / math.sqrt((self.dk))             # scaling down by sqrt(dk)
        compatibility_softmax = self.softmax(compatibility)               # normalizing using Softmax
        output = torch.bmm(compatibility_softmax, values)                 # final batch MatMul operation

        return output

Build the 'Multi-Head Attention' block.

In [25]:
class MultiHeadAttention(nn.Module):
    def __init__(self, h, d_model, queries, keys, values):
        super(MultiHeadAttention, self).__init__()
        self.num_heads = h
        self.batch_num = queries.shape[0]
        self.seq_len = queries.shape[1]
        self.embed_len = queries.shape[2]
        self.d_model = d_model
        self.queries = queries
        self.keys = keys
        self.values = values
        self.head_length = self.d_model/self.num_heads

        self.concat_output = []

        # For an input, Q for example, which would originally have a shape
        # of (N, seq_len, embed_len), it would be split up into the number of 
        # heads that we define (ex: 8). So, the new shape would be
        # (N, seq_len, embed_len/8). This would also apply to K and V too.

        # Since we are flattening batches of matrices, I'm not sure if the flattening
        # should be done in another way. I'll come back to this later if it needs changing.
        self.q_in = (torch.flatten(self.queries, start_dim=1, end_dim=2) / self.num_heads).shape[1]        
        self.k_in = (torch.flatten(self.keys, start_dim=1, end_dim=2) / self.num_heads).shape[1]
        self.v_in = (torch.flatten(self.values, start_dim=1, end_dim=2) / self.num_heads).shape[1]
        
        
        # For the input of each Linear layer, we would have the divided Q, K, 
        # and V calculated above. q_in, k_in, and v_in = 10240 each.
        print(type(self.q_in))
        print(self.q_in)
        self.q_linear = nn.Linear(int(self.q_in/8), int(self.q_in/8))
        self.k_linear = nn.Linear(int(self.k_in/8), int(self.k_in/8))
        self.v_linear = nn.Linear(int(self.v_in/8), int(self.v_in/8))

        # Attention layer.
        self.attention = ScaledDotProduct(self.queries)

        # This is the final Linear layer, after the outputs of all the heads
        # from the Scaled Dot Product layer have been concatenated together. The
        # output dimension of this layer is a hyperparameter that we define. Here
        # we use d_model, which is 512.
        self.output_linear = nn.Linear(self.q_in, self.d_model)

    def forward(self, queries, keys, values):
        # Feed the 8 heads of Q, K, and V into the linear layers in parallel, and then into the
        # attention block. Let's say the original tensor Q has the following shape: 
        # (N, seq_len, embed_len) -> (8, 20, 512).
        # The segment that will go into each head will be of the following size:
        # (N, seq_len, embed_len/num_heads) -> (8, 20, 64). So we need to slice the third dimension.
        for i in range(self.num_heads):

            # The output of each of the linear layers has length -> (N, seq_len*embed_len/num_heads) -> (N, 1280)
            q_linear_output = self.q_linear(torch.flatten(queries[:, :, int(i*self.head_length):int((i+1)*self.head_length)], start_dim=1, end_dim=2))
            k_linear_output = self.k_linear(torch.flatten(keys[:, :, int(i*self.head_length):int((i+1)*self.head_length)], start_dim=1, end_dim=2))
            v_linear_output = self.v_linear(torch.flatten(values[:, :, int(i*self.head_length):int((i+1)*self.head_length)], start_dim=1, end_dim=2))

            #print('q_linear_output shape: ', q_linear_output.shape)

            # Since the three outputs computed from the linear layers above are just 1D vectors of length
            # (N, seq_len*embed_len/num_heads) -> (N, 1280), and the ScaledDotProduct forward method expects 3D tensors,
            # I will reshape the 1D vectors into 3D tensors of shape (N, seq_len, embed_len/num_heads)
            q_reshaped_output = torch.reshape(q_linear_output, (self.batch_num, self.seq_len, int(self.embed_len/self.num_heads)))
            k_reshaped_output = torch.reshape(k_linear_output, (self.batch_num, self.seq_len, int(self.embed_len/self.num_heads)))
            v_reshaped_output = torch.reshape(v_linear_output, (self.batch_num, self.seq_len, int(self.embed_len/self.num_heads)))

            #print('q_reshaped_output shape: ', q_reshaped_output.shape)

            # Feed reshaped Q, K, and V into ScaledDotProduct layer.
            # 'sdp_output' should have shape (N, seq_len, embed_len/num_heads)
            sdp_output = self.attention.forward(q_reshaped_output, k_reshaped_output, v_reshaped_output)

            # Each 'sdp_output' is a Tensor of shape (N, seq_len, embed_len/num_heads) -> (8, 20, 64).
            # Each flattened Tensor has length (8, 20*64) = 10240
            #print('sdp_output shape: ', sdp_output.shape)
            #print('sdp_output flattened length: ', torch.flatten(sdp_output, start_dim=1, end_dim=2).shape)
            
            # We need to concatenate the outputs of all the heads
            # into one vector and pass it through a final linear layer
            self.concat_output.append(torch.flatten(sdp_output, start_dim=1, end_dim=2))
            
        flattened_concat_output = torch.flatten(torch.stack(self.concat_output), start_dim=1, end_dim=2)
        
        # Pass the concatenated vector in a final linear layer and return output
        return self.output_linear(flattened_concat_output)

Test shapes

In [26]:
multihead = MultiHeadAttention(num_heads, d_model, test_query, test_key, test_value).to(device)

<class 'int'>
10240


In [27]:
test_output = multihead.forward(test_query, test_key, test_value)

flattened_concat_output shape:  torch.Size([8, 10240])


  compatibility_softmax = self.softmax(compatibility)               # normalizing using Softmax


Building the Encoder. 

In [None]:
class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        