In [2]:
# Reference: https://towardsdatascience.com/build-your-own-transformer-from-scratch-using-pytorch-84c850470dcb
import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data as data
import math
import copy

In [13]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        ''' attention(q,k,v) = sum_i(similarity(q, k_i)*v_i
            where q,k,v are n-dim learned vectors from input embedding
            and similarity is the scalar result of q dot k
            
            Hence, attention is a learned representation of input embedding
            x_1 from sentence (or sequence) [x_1, x_2, ..., x_i],
            based on weighted average of v_i's, where the scalar weights are 
            based on similarity of q = W_q*x_i among the rest of the 
            k_i = W_k*x_i (tokens in the same sentence).
        '''
        super().__init__()
        assert d_model % num_heads == 0, "d_model must be divisible by num_heads" # why?

        self.d_model = d_model # d_model is the concat'd dimensionality of the num_head instances of the input_embeddings
        self.num_heads = num_heads # num_heads is the # of kinds of weights to be learned for the weigthed averaging operation
        self.d_k = d_model // num_heads # d_k is the input embedding dimensionality (vector length)

        self.W_q = nn.Linear(d_model, d_model) # y = x*W.t + b, weights to convert input embedding to learned representation Q
        self.W_k = nn.Linear(d_model, d_model)
        self.W_v = nn.Linear(d_model, d_model)
        self.W_o = nn.Linear(d_model, d_model) # I'm guessing that the purpose of this is to learn a final 
                                               # combination for the multi-head attentions to produce a single long attention vector

    def scaled_dot_product_attention(self, Q, K, V, mask=None):
        # What are the shapes of Q,K, and V?
        # Q,K,V shape is (batch_size, seq_len, d_model) when combined,
        # (batch_size, num_heads, seq_len, d_k) when split into each attention vector
        # when they enter this operation, they're SPLIT
        
        # K's last two dim (seq_len, d_k) are swapped (transposed)
        # This matmul computes the similarity of all tokens against each of the other tokens (hence seq_len,seq_len shape)
        # It does this per head as well.
        # Shape of similarity_scores: (batch_sz, num_heads, seq_len, seq_len)
        similarity_scores = torch.matmul(Q, K.transpose(-2, -1)) / math.sqrt(self.d_k)
        
        if mask is not None:
            similarity_scores = similarity_scores.masked_fill(mask == 0, -1e9) # where mask is 0, replace tensor element value with -1e9 (e^x, used by softmax later, is 0 at -inf)
         
        similarity_probs = torch.softmax(similarity_scores, dim=-1) # apply softmax on last dim (seq_len column, 
                                                                    # so each row adds to 1, each column represents a 
                                                                    # token's similarity to token_i=row_num)
        # Output shape is back to (batch_sz, num_head, seq_len, d_k) 
        # Double check how each score is applied to V, it doesn't seem to perform the weighted average operation
        # The intuition here follows the "row interpretation" of a matrix multiplication
        # where matmul can be thought of as: each row of the matrix on the left is a vector of weights for taking a linear combination of the rows of the matrix on the right! 
        # That is, each row on the left produces its own linear combination of the rows of the matrix on the right, and they all get stacked together row-by-row in the output matrix.
        # see: https://forums.fast.ai/t/fun-matrix-math-example-the-transformers-attention-mechanism/41606
        output = torch.matmul(similarity_probs, V)
        return output

    def split_heads(self, x):
        batch_size, seq_length, d_model = x.size()
        # Final shape: (batch_sz, num_heads, seq_len, d_k)
        return x.view(batch_size, seq_length, self.num_heads, self.d_k).transpose(1,2)

    def combine_heads(self, x):
        batch_size, _, seq_length, d_k = x.size()
        # Return as contiguous chunk in memory
        # Final shape: (batch_sz, seq_len, d_model)
        return x.transpose(1,2).contiguous().view(batch_size, seq_length, self.d_model)

    def forward(self, Q, K, V, mask=None):
        Q = self.split_heads(self.W_q(Q)) # Note how the QKV operations are performed per head, separately
        K = self.split_heads(self.W_k(K)) # while on W_o, all attentions are then combined to transform into one.
        V = self.split_heads(self.W_v(V))
                                                                                                                            
        attn_output = self.scaled_dot_product_attention(Q, K, V, mask)
        output = self.W_o(self.combine_heads(attn_output))

        return output # shape is (batch_sz, seq_len, d_model)
    

In [4]:
t = MultiHeadAttention(6,2) # d_model=6,num_heads=2 
x = torch.tensor([[[1.,2.,3.,4.,5.,6.], [2.,4.,6.,8.,10.,12.]]]) # batch_sz=1, seq_len=2, d_model=6
y = t(x, x, x)

In [49]:
class PositionWiseFeedForward(nn.Module):
    ''' To be stacked after MultiHeadAttn. This will take in the
        attn vector combined (via add & norm) with the input embedding
        via a residual connection (it skipped the MultiHead).
        I don't know yet why it is designed this way.
        Maybe to retain more info about the input embedding AND POSITIONAL ENCODING
        instead of just indirectly working on it via Q, K, and V?
    '''
    def __init__(self, d_model, d_ff):
        super().__init__()
        self.fc1 = nn.Linear(d_model, d_ff) # in_dim=d_model, out_dim=d_ff, where d_ff is a hyperparam for internal dim of feedforward layer
        self.fc2 = nn.Linear(d_ff, d_model) # final output shape of feed forward is d_model
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.fc2(self.relu(self.fc1(x))) # why no activation on last layer?
                                                # range is all real numbers?

class EncoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout):
        # d_model = dimension of the concatenated learned embeddings from all heads
        # num_heads = number of attention heads that will learn an embedding
        # d_ff = feed forward inner dimension of the 2-layer MLP (d_model, d_ff, d_model)
        # dropout = probability of dropout applied on output of MultiHeadAttn & of MLP layer
        
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model) # standard normalization with learnable scale & bias
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout) # randomly zero out some input elements with probability=dropout, for regularization

    def forward(self, x, mask):
        # mask is a matrix where zero elements will result to masking out some learned embeddings in V
        
        attn_output = self.dropout(self.self_attn(x, x, x, mask)) # MultiHeadAttn
        x = self.norm1(x + attn_output)                           # Add & Norm (w/ residual connx)

        ff_output = self.dropout(self.feed_forward(x))            # Feed Forward
        x = self.norm2(x + ff_output)                             # Add & Norm (w/ residual connx)

        return x

class DecoderLayer(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout):
        super().__init__()
        self.self_attn = MultiHeadAttention(d_model, num_heads)
        self.cross_attn = MultiHeadAttention(d_model, num_heads)
        self.feed_forward = PositionWiseFeedForward(d_model, d_ff)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, enc_output, src_mask, tgt_mask):
        attn_output = self.dropout(self.self_attn(x, x, x, tgt_mask))
        x = self.norm1(x + attn_output)
        
        attn_output = self.dropout(self.cross_attn(x, enc_output, enc_output, src_mask)) #Q=x, K,V=enc_output
        x = self.norm2(x + attn_output)
        
        ff_output = self.dropout(self.feed_forward(x))  
        x = self.norm3(x + ff_output)                  

        return x  

In [38]:
class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_seq_len):
        # max_seq_len is the max sentence len (num of tokens in a sequence)
        super().__init__()

        pe = torch.zeros(max_seq_len, d_model) # init PosEnc as 2D matrix of zeros (each row is a token)
        # position = [[0],[1], ...]], column vector of positions from 0 to max_seq_len-1
        position = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1)
        # div_term is a freq defined for each dimension of the embedding of len d_model
        # skip by two bc this will be fed to both sin and cos (covering whole d_model len)
        # div_term is a row vector
        exp_term = torch.arange(0, d_model, 2).float() * -(math.log(10000.0) / d_model)
        div_term = torch.exp(exp_term) # this formula is same as in the reference, just more efficient.
                                       # just equate the term to e^x, and solve for x, then e^x again. 
                                       # See: https://ai.stackexchange.com/q/41670
        
        # matmul(col_vec,row_vec) to get matrix.shape=(max_seq_len, d_model//2) 
        pe[:, 0::2] = torch.sin(position * div_term) # broadcast, assign to all tokens and on even dims of embedding
        pe[:, 1::2] = torch.cos(position * div_term) # all max_seq_len tokens, on odd indices
        pe = pe.unsqueeze(0)
        # now we have a matrix (1, max_seq_len, d_model) with corresponding 
        # position value per token (ie row)
        
        # register_buffer => Tensor which is not a trainable-parameter,
        # but should be part of the modules variables.
        # persistent=False, to not add the buffer to the state dict (e.g. when we save the model)
        self.register_buffer('pe', pe, persistent=False)

    def forward(self, x):
        # expects input of shape (batch_sz, max_seq_len, d_model)
        return x + self.pe[:, :x.size(1)] # add to all batches, up to seq_len, vectors of len d_model

class Transformer(nn.Module):
    def __init__(self, src_vocab_sz, tgt_vocab_sz, d_model, num_heads,
                       num_layers, d_ff, max_seq_len, dropout):
        super().__init__()
        # src_vocab_sz = num of words in src vocab
        # tgt_vocab_sz = num of words in tgt vocab
        # d_model = resulting len of embedding vector
        # max_seq_len = maximum number of words or tokens in a sentence or sequence
        self.encoder_embedding = nn.Embedding(src_vocab_sz, d_model) # Simple lookup table that returns an embedding vector for given indices
        self.decoder_embedding = nn.Embedding(tgt_vocab_sz, d_model)
        self.positional_encoding = PositionalEncoding(d_model, max_seq_len)

        # num_layers = number of encoder and decoder layers for embedding to pass through
        self.encoder_layers = nn.ModuleList([EncoderLayer(d_model, num_heads, d_ff, dropout)]*num_layers)
        self.decoder_layers = nn.ModuleList([DecoderLayer(d_model, num_heads, d_ff, dropout)]*num_layers)

        self.fc = nn.Linear(d_model, tgt_vocab_size)
        self.dropout = nn.Dropout(dropout)

    def generate_mask(self, src, tgt):
        src_mask = (src != 0).unsqueeze(1).unsqueeze(2) # insert dim=1 at specified position
        tgt_mask = (tgt != 0).unsqueeze(1).unsqueeze(3)
        seq_len = tgt.size(1)
        nopeak_mask = (1 - torch.triu(torch.ones(1, seq_len, seq_len), diagonal = 1)).bool() # get matrix's lower triangle of True's 
        tgt_mask = tgt_mask & nopeak_mask
        
        return src_mask, tgt_mask

    def forward(self, src, tgt):
        # src, tgt = a sentence is a batch of vectors containing indices for words (to be embedded)
        #          = shape is (batch_sz, max_seq_len)
        src_mask, tgt_mask = self.generate_mask(src, tgt)
        src_embed = self.encoder_embedding(src) # outputs (batch_sz, max_seq_len, d_model)
        src_embed = self.dropout(self.positional_encoding(src_embed)) # add positional info
        
        tgt_embed = self.dropout(self.positional_encoding(self.decoder_embedding(tgt))) # same as above but for tgt embeds

        enc_output = src_embed
        for enc_layer in self.encoder_layers:
            enc_output = enc_layer(enc_output, src_mask)

        dec_output = tgt_embed
        for dec_layer in self.decoder_layers:
            # NOTE: the same final encoder output is used for all dec layers
            dec_output = dec_layer(dec_output, enc_output, src_mask, tgt_mask)

        output = sef.fc(dec_output) # shape is (batch_sz, seq_len, tgt_vocab_size)
        
        return output # shape (batch_sz, seq_len, d_model)
            
        

In [39]:
# Prepare sample data
src_vocab_size = 1000
tgt_vocab_size = 1000
d_model = 128
num_heads = 4
num_layers = 2
d_ff = 512
max_seq_length = 50
dropout = 0.1

transformer = Transformer(src_vocab_size, tgt_vocab_size, d_model, num_heads, num_layers, d_ff, max_seq_length, dropout)

# Generate random sample data
src_data = torch.randint(1, src_vocab_size, (64, max_seq_length))  # (batch_size, seq_length)
tgt_data = torch.randint(1, tgt_vocab_size, (64, max_seq_length))  # (batch_size, seq_length)

In [48]:
# Sample training
criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.Adam(transformer.parameters(), lr= 0.0001, betas=(0.9, 0.98), eps=1e-9)

transformer.train()

for epoch in range(10): # no batches, use same data
    optimizer.zero_grad()
    output = transformer(src_data, tgt_data[:, :-1]) # decoder input is 1 to before end
    _output = output.contiguous().view(-1, tgt_vocab_size) # -1 means derive that size from other dims
    target = tgt_data[:, 1:].contiguous().view(-1) # shifted sequence by 1 to the right (and flattened)
    
    loss = criterion(_output, target)
    loss.backward()
    optimizer.step()
    print(f"Epoch: {epoch+1}, Loss: {loss.item()}")
    

AttributeError: 'PositionWiseFeedForward' object has no attribute 'fc2'

In [20]:
seq_length = 4
nopeak_mask = (1 - torch.triu(torch.ones(1, seq_length, seq_length), diagonal = 1)).bool()


In [46]:
x = torch.randn(2,2,2)
x.view(-1, 4)

tensor([[ 0.5132, -0.7557, -1.5234, -0.4160],
        [ 1.7056, -1.4635, -1.1913, -1.3772]])