In [1]:
import tensorflow as tf
import numpy as np
import tensorflow_probability as tfp

# Implementing MaskedSoftmax Layer:

Computes a softmax over a Tx × Tx energy matrix for self-attention or interconnected layers. In trailing mode, row $i$ of the softmax output only attends to entries $j$ with $j \le i$, i.e. the energies are masked with the following lower-triangular pattern:

$$\begin{bmatrix}
1 & 0 & 0 & 0 & 0 \\
1 & 1 & 0 & 0 & 0 \\
1 & 1 & 1 & 0 & 0 \\
1 & 1 & 1 & 1 & 0 \\
1 & 1 & 1 & 1 & 1
\end{bmatrix}$$
 
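The layer also accepts the mask produced by the embedding layer, but `call` does not apply it yet: the softmax output is the same with or without a mask (see the example output below).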

In [26]:
class MaskedSoftmax(tf.keras.layers.Layer):
    """Softmax over the last axis of a square attention-energy tensor.

    With ``trailing=True`` a lower-triangular mask is applied so that row ``i``
    of the output only puts weight on columns ``j <= i``.

    Args:
        trailing: when True, apply the lower-triangular ("trailing") mask.
        **kwargs: forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, trailing = False, **kwargs):
        super().__init__(**kwargs)
        self.trailing = trailing

    def build(self, input_shape):
        """Validate the input shape and, in trailing mode, create the (Tx, Tx) mask."""
        assert(len(input_shape) >= 3), 'Expect rank >=3 input: (batch_size, Tx, Tx) or (batch_size, heads, Tx, Tx)'
        assert(input_shape[-1] == input_shape[-2]),'Last two ranks must be symmetrical for attn mechanism'

        if self.trailing:
            Tx = int(input_shape[-1])

            # band_part(x, -1, 0) keeps the whole lower triangle (diagonal
            # included) and zeroes everything above it; it takes an integer
            # shape, so no float element count is needed.
            self.trailing_mask = tf.linalg.band_part(tf.ones((Tx, Tx)), -1, 0)

    def call(self, z, mask = None):
        """Return softmax(z) along the last axis, restricted to j <= i in trailing mode.

        Args:
            z: energies whose last two dimensions are equal (checked in ``build``).
            mask: accepted so the layer can be called with a Keras mask; it is
                currently NOT applied to the result.

        Returns:
            Tensor with the same shape as ``z`` whose rows sum to 1.
        """
        if self.trailing:
            keep = tf.cast(self.trailing_mask, z.dtype)
            # Push disallowed logits to the most negative finite value so they
            # can never be the row maximum used for stabilisation below.
            z = z * keep + (1.0 - keep) * z.dtype.min

        # Subtracting the row maximum leaves the softmax unchanged but keeps
        # exp() from overflowing on large energies.
        z = tf.exp(z - tf.reduce_max(z, axis = -1, keepdims = True))

        if self.trailing:
            # Guarantee exact zeros above the diagonal.
            z = z * keep

        return z / tf.reduce_sum(z, axis = -1, keepdims = True)

In [27]:
# Uniform energies shaped (batch=2, heads=1, Tx=5, Tx=5) and a boolean mask
# shaped (batch=2, Tx=5) to feed the MaskedSoftmax example below.
X = tf.ones(shape = [2, 1, 5, 5])
mask = tf.constant(
    [[True, True, False, False, False],
     [True, True, True, True, False]],
    dtype = tf.bool,
)
mask.shape

TensorShape([2, 5])

In [28]:
# Trailing variant: each output row may only attend to columns j <= i.
t = MaskedSoftmax(trailing = True)

In [29]:
# MaskedSoftmax.call accepts `mask` but does not apply it, so the result is the
# plain lower-triangular softmax for both batch entries.
t(X, mask = mask)

<tf.Tensor: id=303, shape=(2, 1, 5, 5), dtype=float32, numpy=
array([[[[1.        , 0.        , 0.        , 0.        , 0.        ],
         [0.5       , 0.5       , 0.        , 0.        , 0.        ],
         [0.33333334, 0.33333334, 0.33333334, 0.        , 0.        ],
         [0.25      , 0.25      , 0.25      , 0.25      , 0.        ],
         [0.2       , 0.2       , 0.2       , 0.2       , 0.2       ]]],


       [[[1.        , 0.        , 0.        , 0.        , 0.        ],
         [0.5       , 0.5       , 0.        , 0.        , 0.        ],
         [0.33333334, 0.33333334, 0.33333334, 0.        , 0.        ],
         [0.25      , 0.25      , 0.25      , 0.25      , 0.        ],
         [0.2       , 0.2       , 0.2       , 0.2       , 0.2       ]]]],
      dtype=float32)>

# Multihead Attention Layer

1. Input is rank 3: (m, Tx, d_model)
2. Project to h-dimensional space (h = num heads):
<ol><li>a. expand dimension of X to be (m, 1, Tx, d_model)</li>
    <li>b. $X \in R^{m, 1, Tx, d_{model}} * W1 \in R^{h, d_{model}, dv} = Proj \in R^{m, h, Tx, dv}$</li>
</ol><br>
3. Calculate raw energies through dot product attention mechanism:<br><br>
$ \dfrac{1}{\sqrt{d_v}} Proj\cdot Proj^{T} = Energies \in R^{m, h, Tx, Tx}$<br><br>
4. Compute trailing (if decoder) or unmasked (if encoder) softmax to compute alphas:<br><br>
$ Softmax(Energies) = Alphas $<br><br>
5. Compute new context vectors through multiplication of Alphas (Query dot Key) and Values:<br><br>
$ Alphas \in R^{m, h, Tx, Tx} \cdot Proj \in R^{m, h, Tx, dv} = context \in R^{m, h, Tx, dv} $<br>
6. Stack context vectors from different heads back into original dimension:<br><br>
$ Stack(context) = context \in R^{m, Tx, h\cdot dv} $ <br><br>
7. Project back to model space:<br><br>
$ context \cdot W2 \in R^{h\cdot dv, d_{model}} = Y \in R^{m, Tx, d_{model}}$

In [42]:
class MultiHeadAttn(tf.keras.layers.Layer):
    """Multi-head self-attention using one shared projection per head.

    The same projection ``W1`` produces the tensor used as query, key and
    value. The per-head contexts are concatenated and projected back to the
    model dimension with ``W2``.

    Args:
        heads: number of attention heads.
        encoder: when False, the softmax uses the trailing (lower-triangular)
            mask; when True it is unmasked.
        **kwargs: forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, heads = 8, encoder = True, **kwargs):
        super().__init__(**kwargs)
        self.h = heads
        self.input_spec = tf.keras.layers.InputSpec(ndim = 3)
        self.encoder = encoder

    def build(self, input_shape):
        """Create the projection weights for an input of shape (m, Tx, d)."""
        assert(len(input_shape) == 3), 'Expected input shape of (m, Tx, d)'

        (self.m, self.k, self.model_dim) = input_shape

        self.projected_dim = self.model_dim//self.h
        # A zero-width projection would only surface later as a
        # ZeroDivisionError when the energies are scaled in call().
        assert(self.projected_dim > 0), 'Model dimension must be at least the number of heads'

        self.softmaxer = MaskedSoftmax(trailing = not self.encoder)

        self.reshaper = tf.keras.layers.Reshape((self.k, -1))

        self.W1 = self.add_weight(
                shape = (self.h, self.model_dim, self.projected_dim),
                initializer = 'glorot_normal',
                trainable = True)

        self.W2 = self.add_weight(
            shape = (self.projected_dim * self.h, self.model_dim),
            initializer= 'glorot_normal',
            trainable = True)

    def call(self, X, mask = None):
        """Apply self-attention to ``X``.

        Args:
            X: rank-3 input (m, Tx, d).
            mask: optional (m, Tx) mask; positions where it is False are
                zeroed in the output.

        Returns:
            Tensor of shape (m, Tx, d).
        """
        # (m, Tx, d) -> (m, 1, Tx, d) so the matmul broadcasts over heads.
        X = tf.expand_dims(X, 1)

        # (m, h, Tx, dv)
        projected = tf.matmul(X, self.W1)

        # Scaled dot-product energies, (m, h, Tx, Tx).
        energies = tf.multiply(1/self.projected_dim**0.5,tf.matmul(projected,projected,transpose_b=True))

        alphas = self.softmaxer(energies, mask = mask)

        # (m, h, Tx, dv)
        context = tf.matmul(alphas, projected)

        # Bring the head axis next to dv before flattening. Reshaping
        # (h, Tx, dv) directly into (Tx, h*dv) would interleave values from
        # different time steps whenever h > 1.
        context = tf.transpose(context, perm = [0, 2, 1, 3])

        # (m, Tx, h*dv)
        flattened = self.reshaper(context)

        output = tf.matmul(tf.dtypes.cast(flattened, 'float32'), self.W2)

        if mask is not None:
            output = tf.multiply(output, tf.expand_dims(tf.dtypes.cast(mask, 'float32'), -1))

        return output

In [43]:
# Single-head attention; encoder = False makes the layer use the trailing softmax.
attn = MultiHeadAttn(heads = 1, encoder = False)

In [44]:
# All-ones input shaped (batch=2, Tx=5, d_model=3) plus a (2, 5) boolean mask;
# the cell displays the output shape together with the output values.
X = tf.ones(shape = [2, 5, 3])
m = tf.constant(
    [[True, True, False, False, False],
     [True, True, True, True, False]],
    dtype = tf.bool,
)
X = attn(X, mask = m)
X.shape, X

(TensorShape([2, 5, 3]),
 <tf.Tensor: id=646, shape=(2, 5, 3), dtype=float32, numpy=
 array([[[ 0.3999936 , -1.092978  , -0.13726237],
         [ 0.3999936 , -1.092978  , -0.13726237],
         [ 0.        , -0.        , -0.        ],
         [ 0.        , -0.        , -0.        ],
         [ 0.        , -0.        , -0.        ]],
 
        [[ 0.3999936 , -1.092978  , -0.13726237],
         [ 0.3999936 , -1.092978  , -0.13726237],
         [ 0.3999936 , -1.092978  , -0.13726236],
         [ 0.3999936 , -1.092978  , -0.13726237],
         [ 0.        , -0.        , -0.        ]]], dtype=float32)>)

# Encoder-Decoder Multihead

1. Input is two matrices rank 3 (m, Tx, d_model): D, output from previous decoder layer; E, output from encoder stack
2. Project E and D to h-dimensional space (h = num heads):
<ol><li>a. expand dimension of E/D to be (m, 1, Tx, d_model)</li>
    <li>b. $E/D \in R^{m, 1, Tx, d_{model}} * W1 \in R^{h, d_{model}, dv} = E/D_{proj} \in R^{m, h, Tx, dv}$</li>
</ol><br>
3. Calculate raw energies through dot product attention mechanism, query = $D_{proj}$, key = $E_{proj}$:<br><br>
$ \dfrac{1}{\sqrt{d_v}} D_{proj}\cdot E_{proj}^{T} = Energies \in R^{m, h, Tx, Tx}$<br><br>
4. Compute softmax to compute alphas:<br><br>
$ Softmax(Energies) = Alphas $<br><br>
5. Compute new context vectors through multiplication of Alphas and Values (also $E_{proj}$):<br><br>
$ Alphas \in R^{m, h, Tx, Tx} \cdot E_{proj} \in R^{m, h, Tx, dv} = context \in R^{m, h, Tx, dv} $<br>
6. Stack context vectors from different heads back into original dimension:<br><br>
$ Stack(context) = context \in R^{m, Tx, h\cdot dv} $ <br><br>
7. Project back to model space:<br><br>
$ context \cdot W2 \in R^{h\cdot dv, d_{model}} = Y \in R^{m, Tx, d_{model}}$

In [51]:
class EncoderDecoderMultiHead(tf.keras.layers.Layer):
    """Multi-head attention from decoder positions over encoder positions.

    Called on ``[encoder_seq, decoder_seq]``. Both sequences are projected
    with the same ``W1``; the decoder projection is the query and the encoder
    projection is both key and value.

    Args:
        heads: number of attention heads.
        **kwargs: forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, heads = 8, **kwargs):
        super().__init__(**kwargs)
        self.h = heads
        # One spec per input: Keras compares the number of specs with the
        # number of input tensors, so a single spec rejects a two-tensor call.
        self.input_spec = [tf.keras.layers.InputSpec(ndim = 3),
                           tf.keras.layers.InputSpec(ndim = 3)]

    def build(self, input_shape):
        """Create the projection weights from the two (m, Tx, d) input shapes."""
        assert(isinstance(input_shape, (list, tuple)) and len(input_shape) == 2),'Expected input as "[encoder_seq, decoder_seq]"'
        assert(input_shape[0] == input_shape[1]), 'Expected encoder and decoder inputs to have same dimensions'

        (self.m, self.k, self.model_dim) = input_shape[0]

        self.projected_dim = self.model_dim//self.h

        self.softmaxer = MaskedSoftmax(trailing = False)

        self.reshaper = tf.keras.layers.Reshape((self.k, -1))

        self.W1 = self.add_weight(
                shape = (self.h, self.model_dim, self.projected_dim),
                initializer = 'glorot_normal',
                trainable = True)

        self.W2 = self.add_weight(
            shape = (self.projected_dim * self.h, self.model_dim),
            initializer= 'glorot_normal',
            trainable = True)

    def call(self, inputs, mask = None):
        """Attend from the decoder sequence over the encoder sequence.

        Args:
            inputs: ``[encoder_seq, decoder_seq]``, each of shape (m, Tx, d).
            mask: optional (m, Tx) mask, or a ``[encoder_mask, decoder_mask]``
                pair. Output rows correspond to decoder positions, so the
                decoder mask is used when a pair is given.

        Returns:
            Tensor of shape (m, Tx, d).
        """
        (encoder_in, decoder_in) = inputs

        # (m, Tx, d) -> (m, 1, Tx, d) so the matmul broadcasts over heads.
        encoder_in, decoder_in = tf.expand_dims(encoder_in, 1), tf.expand_dims(decoder_in, 1)

        # (m, h, Tx, dv) each
        encoder_proj, decoder_proj = tf.matmul(encoder_in, self.W1), tf.matmul(decoder_in, self.W1)

        # Rows index decoder (query) positions, columns encoder (key) positions.
        energies = tf.multiply(1/self.projected_dim**0.5,tf.matmul(decoder_proj,encoder_proj,transpose_b=True))

        alphas = self.softmaxer(energies)

        # (m, h, Tx, dv)
        context = tf.matmul(alphas, encoder_proj)

        # Bring the head axis next to dv before flattening. Reshaping
        # (h, Tx, dv) directly into (Tx, h*dv) would interleave values from
        # different time steps whenever h > 1.
        context = tf.transpose(context, perm = [0, 2, 1, 3])

        # (m, Tx, h*dv)
        flattened = self.reshaper(context)

        output = tf.matmul(tf.dtypes.cast(flattened, 'float32'), self.W2)

        # A layer called on a list of tensors can receive a list of masks;
        # casting that list as one tensor would give the wrong shape.
        if isinstance(mask, (list, tuple)):
            mask = mask[1]

        if mask is not None:
            output = tf.multiply(output, tf.expand_dims(tf.dtypes.cast(mask, 'float32'), -1))

        return output

# Transformer Layers

### Encoder Layer

1. $X = Normalize(Multihead(X) + X)$
2. $X = Normalize(W2\cdot relu(W1\cdot X + b1) + b2 + X)$

In [152]:
class EncoderNode(tf.keras.layers.Layer):
    """One encoder block: self-attention and a two-layer dense stack.

    Each sub-block adds its input back (residual connection) and is followed
    by a ``BatchNormalization`` layer.

    Args:
        h: number of attention heads.
        **kwargs: forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, h = 8, **kwargs):
        super().__init__(**kwargs)
        self.h = h

    def build(self, input_shape):
        """Instantiate the sub-layers for an input of shape (m, Tx, d)."""
        assert(len(input_shape) == 3), 'Expected input of rank 3: (m, Tx, d)'

        # Only the feature width is needed; the dense layers keep it unchanged.
        _, _, width = input_shape

        self.multihead_attn = MultiHeadAttn(self.h, encoder = True)
        self.norm1 = tf.keras.layers.BatchNormalization()
        self.norm2 = tf.keras.layers.BatchNormalization()
        self.dense1 = tf.keras.layers.Dense(width, activation = 'relu', use_bias = True)
        self.dense2 = tf.keras.layers.Dense(width, activation = 'linear', use_bias = True)

    def call(self, X, mask = None):
        """Run the block on ``X``; positions where ``mask`` is False are zeroed.

        NOTE(review): ``mask`` is not forwarded to the attention sub-layer; it
        is only applied to the final output.
        """
        attended = self.multihead_attn(X)
        X = self.norm1(attended + X)

        hidden = self.dense1(X)
        X = self.norm2(self.dense2(hidden) + X)

        if mask is None:
            return X

        keep = tf.expand_dims(tf.dtypes.cast(mask, 'float32'), -1)
        return tf.multiply(X, keep)

### Decoder Layer

Inputs (passed to the layer as `[E, D]`):

- D: output from the previous decoder layer
- E: output from the encoder stack
         
1. $X = Normalize(Multihead(D) + D)$
2. $X = Normalize(Multihead(E,X) + X)$
3. $X = Normalize(W2\cdot relu(W1\cdot X + b1) + b2 + X)$

In [180]:
class DecoderNode(tf.keras.layers.Layer):
    """One decoder block: trailing self-attention, encoder-decoder attention, dense stack.

    Called on ``[encoder_seq, decoder_seq]``. Each sub-block adds its input
    back (residual connection) and is followed by a ``BatchNormalization``
    layer.

    Args:
        h: number of attention heads.
        **kwargs: forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, h = 8, **kwargs):
        # **kwargs must be declared here; the original forwarded an undefined
        # name to the base class and raised NameError on construction.
        super().__init__(**kwargs)
        self.h = h

    def build(self, input_shape):
        """Instantiate the sub-layers from the two (m, Tx, d) input shapes."""
        assert(isinstance(input_shape, (list, tuple)) and len(input_shape) == 2),'Expected input as "[encoder_seq, decoder_seq]"'
        assert(input_shape[0] == input_shape[1]), 'Expected encoder and decoder inputs to have same dimensions'

        (m, Tx, model_dim) = input_shape[0]

        self.masked_attn = MultiHeadAttn(self.h, encoder = False)
        self.merged_attn = EncoderDecoderMultiHead(self.h)
        self.norm1 = tf.keras.layers.BatchNormalization()
        self.norm2 = tf.keras.layers.BatchNormalization()
        self.norm3 = tf.keras.layers.BatchNormalization()
        self.dense1 = tf.keras.layers.Dense(model_dim, activation = 'relu', use_bias = True)
        self.dense2 = tf.keras.layers.Dense(model_dim, activation = 'linear', use_bias = True)

    def call(self, inputs, mask = None):
        """Run the block on ``[encoder_seq, decoder_seq]``.

        Args:
            inputs: ``[encoder_seq, decoder_seq]``, each of shape (m, Tx, d).
            mask: optional (m, Tx) mask, or a ``[encoder_mask, decoder_mask]``
                pair. The output follows decoder positions, so the decoder
                mask is used when a pair is given.

        Returns:
            Tensor of shape (m, Tx, d), zeroed where the mask is False.
        """
        (encoder_in, decoder_in) = inputs

        X = self.norm1(self.masked_attn(decoder_in) + decoder_in)

        X = self.norm2(self.merged_attn([encoder_in, X]) + X)

        # Feed-forward sub-block: dense2(dense1(X)) plus the residual.
        X = self.norm3(self.dense2(self.dense1(X)) + X)

        # A layer called on a list of tensors can receive a list of masks.
        if isinstance(mask, (list, tuple)):
            mask = mask[1]

        if mask is not None:

            X = tf.multiply(X, tf.expand_dims(tf.dtypes.cast(mask, 'float32'), -1))

        return X