In [1]:
import tensorflow as tf
import numpy as np

In [2]:
class MaskedSoftmax(tf.keras.layers.Layer):
    
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        
    def call(self, z, mask = None):
        
        z = tf.keras.backend.exp(z)
        
        if not mask is None:
            assert(mask.shape == z.shape), 'Mask has incorrect dimensions: ' + str(z.shape) + ' vs. ' + str(mask.shape)
            z = tf.multiply(z, tf.dtypes.cast(mask, 'float32'))
        
        return tf.divide(z, tf.reduce_sum(z, axis = -1, keepdims=True))

In [10]:
class TrailingSoftmax(tf.keras.layers.Layer):
    
    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        
    def build(self, input_shape):
        assert(len(input_shape) == 3), 'Expect rank 3 input: (batch_size, Tx, d)'
        assert(input_shape[-1] == input_shape[-2]),'Last two ranks must be symmetrical for attn mechanism'
        
        (m, Tx, d) = input_shape
        
        print('here')
        
        self.trailing_mask = tf.linalg.LinearOperatorLowerTriangular(tf.ones((m, Tx,Tx)))
        
        print('there')
        
    def call(self, z, mask = None):
        
        z = tf.keras.backend.exp(z)
        
        #if not mask is None: # because embedding multiplication pre-applies embedding mask
            #num_mask = tf.dtypes.cast(mask, 'float32')

            #column_num_mask = tf.expand_dims(num_mask, -1)

            #trailing_mask = tf.multiply(self.trailing_mask, column_num_mask)
            
        z = tf.multiply(z, self.trailing_mask)
            
        return tf.divide(z, tf.reduce_sum(z, axis = -1, keepdims=True))       

In [11]:
A = tf.ones(shape = [1,5,5])
A

<tf.Tensor: id=14, shape=(1, 5, 5), dtype=float32, numpy=
array([[[1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.]]], dtype=float32)>

In [12]:
TrailingSoftmax()(A)

here
there


ValueError: Attempt to convert a value (<tensorflow.python.ops.linalg.linear_operator_lower_triangular.LinearOperatorLowerTriangular object at 0x0000026CD2333F08>) with an unsupported type (<class 'tensorflow.python.ops.linalg.linear_operator_lower_triangular.LinearOperatorLowerTriangular'>) to a Tensor.

In [4]:
class MultiHeadAttn(tf.keras.layers.Layer):
    
    def __init__(self, h = 8, **kwargs):
        super().__init__(**kwargs)
        self.h = h
        self.input_spec = tf.keras.layers.InputSpec(ndim = 3)
        
    def build(self, input_shape):
        assert(len(input_shape) == 3), 'Expected input shape of (m, Tx, d)'

        (self.m, self.k, self.model_dim) = input_shape
        
        self.projected_dim = self.model_dim//self.h
        
        self.softmaxer = TrailingSoftmax()
        
        self.softmaxer.build(input_shape)
        
        self.W1 = self.add_weight(
                shape = (self.h, self.model_dim, self.projected_dim), 
                initializer = 'glorot_normal', 
                trainable = True)
        
        self.W2 = self.add_weight(
            shape = (self.projected_dim * self.h, self.model_dim),
            initializer= 'glorot_normal',
            trainable = True)
        
    def call(self, X, mask = None):
        
        X = tf.expand_dims(X, 1)
        
        projected = tf.matmul(X, self.W1)
        
        energies = tf.multiply(tf.matmul(projected,projected,transpose_b=True),1/self.projected_dim**0.5)
        
        #alphas = tf.nn.softmax(energies, axis = -1)
        alphas = self.softmaxer(energies, mask)
        
        context = tf.matmul(alphas, projected)
        
        flattened = tf.reshape(context, (self.m, self.k, -1))
        
        output = tf.matmul(tf.dtypes.cast(flattened, 'float32'), self.W2)
        
        return output

In [342]:
class TransformerNode(tf.keras.layers.Layer):
    
    def __init__(self, h = 8, **kwargs):
        super().__init__(**kwargs)
        self.h = h
        
    def build(self, input_shape):
        
        (m, k, model_dim) = input_shape
        
        self.multihead_attn = MultiHeadAttn(self.h)
        self.norm1 = tf.keras.layers.BatchNormalization()
        self.dense1 = tf.keras.layers.Dense(model_dim, activation = 'relu', use_bias = True)
        self.dense2 = tf.keras.layers.Dense(model_dim, use_bias = True)
        self.norm2 = tf.keras.layers.BatchNormalization()
        
    def call(self, X):
        
        X = self.multihead_attn(X) + X
        
        X = self.norm1(X)
        
        X_bypass = X
        
        X = self.dense1(X)
        
        X = self.dense2(X)
        
        X = X + X_bypass
        
        X = self.norm2(X)
        
        return X      

In [339]:
embed = tf.keras.layers.Embedding(10000, 512, mask_zero=True)
attn = MultiHeadAttn(h = 8)
trans = Transformer(8)

X = np.array([[0,0,0,5,8],[1,0,0,3,3]])

X = embed(X)
X.get_shape()

TensorShape([2, 5, 512])

In [341]:
trans(X).get_shape()

TensorShape([2, 5, 512])

In [289]:
W = np.random.rand(8, 512, 64)
Wo = np.random.rand(512,512)
X = np.random.rand(2,1,5,512)

proj = tf.matmul(X,W)
proj.get_shape()

TensorShape([2, 8, 5, 64])

In [290]:
alpha = tf.multiply(tf.matmul(proj,proj,transpose_b=True),1/64**0.5)
alpha.get_shape()

TensorShape([2, 8, 5, 5])

In [291]:
soft_alpha = tf.nn.softmax(alpha_norm, axis = -1)
soft_alpha.get_shape()

TensorShape([2, 8, 5, 5])

In [292]:
energy = tf.matmul(soft_alpha, proj)
energy.get_shape()

TensorShape([2, 8, 5, 64])

In [293]:
c = tf.reshape(energy, (2, 5, -1))
c.get_shape()

TensorShape([2, 5, 512])

In [295]:
output = tf.matmul(c, Wo)
output.get_shape()

TensorShape([2, 5, 512])