In [1]:
import tensorflow as tf

In [2]:
tf.enable_eager_execution()

In [10]:
class ScaledDotProductAttention:
    
    def __init__(self, dim):
        self.scale_factor = tf.sqrt(float(dim))

    def forward(self, Q, K, V):
        K = tf.transpose(K, perm=[0,2,1])
        softmax_scaled_weights = tf.nn.softmax(tf.matmul(Q,K) / self.scale_factor)
        return tf.matmul(softmax_scaled_weights, V)
    
class MultiHeadAttention:
    
    def __init__(self, heads, hidden_inp, trainable=True):
        assert(hidden_inp % heads == 0)
        
        self.trainable = trainable
        self.heads = heads
        self.h_in = hidden_inp
        self.dk = hidden_inp // heads
        self.t_shape = (self.h_in, self.dk * self.heads)
        self.scaled_dpa = ScaledDotProductAttention(self.dk)
        self.build()
    
    def build(self):
        self.Wq = tf.Variable(tf.random_normal(self.t_shape))
        self.Wk = tf.Variable(tf.random_normal(self.t_shape))
        self.Wv = tf.Variable(tf.random_normal(self.t_shape))
        self.Wo = tf.Variable(tf.random_normal((self.h_in, self.dk * self.heads)))
        self._trainable_weights = [self.Wq, self.Wk, self.Wv, self.Wo]
        
    @tf.contrib.eager.defun
    def forward(self, Q, K, V):
        
        # input dims [batch, ts, dk*heads]
        q = tf.tensordot(Q, self.Wq, axes=[[-1], [0]])
        k = tf.tensordot(K, self.Wk, axes=[[-1], [0]])
        v = tf.tensordot(V, self.Wv, axes=[[-1], [0]])
        
        def reshape1(x):
            s = tf.shape(x)   # [batch_size, len_q, n_head * d_k]
            x = tf.reshape(x, [s[0], s[1], self.heads, self.dk])
            x = tf.transpose(x, [2, 0, 1, 3])  
            x = tf.reshape(x, [-1, s[1], self.dk])  # [n_head * batch_size, len_q, dk]
            return x
        
        # Reshape to do the scaled dot product attention to [batch*heads, ts, dk]
        q = reshape1(q)
        k = reshape1(k)
        v = reshape1(v)
        
        dp_att_out = self.scaled_dpa.forward(q,k,v)
        
        # Reshape back to [batch, ts, dk*heads]
        def reshape2(x):
            s = tf.shape(x)   # [n_head * batch_size, len_v, d_k]
            x = tf.reshape(x, [self.heads, -1, s[1], s[2]]) 
            x = tf.transpose(x, [1, 2, 0, 3])
            x = tf.reshape(x, [-1, s[1], self.heads*self.dk])  # [batch_size, len_v, n_head * d_k]
            return x
        
        dp_att_out = reshape2(dp_att_out)
        
        return tf.tensordot(dp_att_out, self.Wo, axes=[[-1], [0]])

    
    def get_trainable_weights(self):
        return self._trainable_weights if self.trainable else []

class Dense:
    
    def __init__(self, input_dim, output_dim, bias=True, trainable=True):
        self.input_dim = input_dim
        self.output_dim = output_dim
        self.bias = bias
        self.trainable = trainable
        self.build()
        
    def build(self):
        self.W = tf.Variable(tf.random_normal((self.input_dim, self.output_dim)))
        if self.bias:
            self.b  = tf.Variable(tf.random_uniform((self.output_dim,)))
            self._trainable_weights = [self.W, self.b]
        else:
            self._trainable_weights = [self.W]
   
    @tf.contrib.eager.defun
    def forward(self, input_):
        if self.bias:
            return tf.tensordot(input_, self.W, axes=[[-1], [0]]) + self.b
        else:
            return tf.tensordot(input_, self.W, axes=[[-1], [0]])
        
    def get_trainable_weights(self):
        return self._trainable_weights if self.trainable else []
    

class TransformerLayer:
    
    def __init__(self, heads, hidden_inp, dense_hidden, trainable=True):
        self.trainable = trainable
        self.heads = heads
        self.hidden_inp = hidden_inp
        self.hidden_output = hidden_inp // heads
        self.dense_hidden = dense_hidden
        self.build()
    
    def build(self):
        self.mha = MultiHeadAttention(self.heads, self.hidden_inp)
        self.dense1 = Dense(self.hidden_inp, self.dense_hidden)
        self.dense2 = Dense(self.dense_hidden, self.hidden_inp)
        self._trainable_variables = self.mha.get_trainable_weights() + [self.dense1, self.dense2]
        
    @tf.contrib.eager.defun
    def forward(self, Q, K, V):
        output_ma = self.mha.forward(Q, K, V)
        o_ma_norm = tf.contrib.layers.layer_norm(output_ma + K)
        o_dense1 = tf.nn.relu(self.dense1.forward(o_ma_norm))
        o_dense2 = self.dense2.forward(o_dense1)
        output = tf.contrib.layers.layer_norm(o_dense2 + o_ma_norm)
        return output
        
    def get_trainable_weights(self):
        return self._trainable_variables if self.trainable else False

In [None]:
class TransformerEncoder:
    
    def __init__(self, num_layers, heads, embedding_dim, fc_hidden_dim, num_classes):
        self.num_layers = num_layers
        self.heads = heads
        self.embedding_dim = embedding_dim
        self.dense_hidden = fc_hidden_dim
        self.num_classes = num_classes
        self.build()
        
    def build(self):
        self.layers = []
        for _ in range(self.num_layers):
            self.layers += [TransformerLayer(self.heads, self.embedding_dim, self.dense_hidden)]
        
        self.dense_out = Dense(self.embedding_dim, self.num_classes, bias=False)
        self.layers += [self.dense_out]
    
    @tf.contrib.eager.defun
    def forward(self, input_):
        o_step_i = input_
        for layer in self.layers[:-1]:
            o_step_i = layer.forward(o_step_i,o_step_i,o_step_i)

        return self.dense_out.forward(o_step_i)
    
    def get_trainable_weights(self):
        return [weight for layer in self.layers for weight in layer.get_trainable_weights()]
    
    def train(self,
              x_train, 
              y_train, 
              x_val, 
              y_val, 
              vocab,
              loss, 
              epochs,
              score_fun,
              tensorboard=False,
              log_dir="./transformer_log/",
              ckpt_dir="./transformer_ckpt/",
              pad_value=0,
              batch_size=32, 
              val_bs=32):
        
        if tensorboard:
            summary_writer = tf.contrib.summary.create_file_writer(log_dir, flush_millis=10000)
            summary_writer.set_as_default()
            global_step = tf.train.get_or_create_global_step()
            
        iteration = 0
        n_classes = y_train.shape[1]
        current_val_score = self.compute_score(x_val, y_val, vocab, n_classes, val_bs, score_fun)
        for epoch in range(epochs):
            for x, y in get_embedded_iterator(x_train, y_train, n_classes, batch_size, vocab):
                
                if tensorboard:
                    global_step.assign_add(1)
                
                minimize(self.optimizer, self, loss, x, y, lr=0.001, logging=tensorboard, it=iteration, log_every=10)
            
            val_score = self.compute_score(x_val, y_val, vocab, n_classes, val_bs, score_fun)

            if tensorboard:
                log_scalar('val_score', val_score)

            print("Validation score is {0}".format(val_score))
            
            if val_score > current_val_score:
                self.save_model(ckpt=ckpt_dir)
                
    def compute_score(self, x_val, y_true, vocab, n_classes, bs, score_fun):
        scores = []
        
        for x, y in get_embedded_iterator(x_val, y_true, n_classes, bs, vocab):
            scores.append(score_fun(self.forward(x), y))
        return np.mean(scores)
    
    def save_model(self, ckpt="./transformer_log/"):
        self.ckp.save(ckpt)

    def restore_model(self, ckpt="./transformer_ckpt/"):
        self.ckp.restore(tf.train.latest_checkpoint(ckpt))

In [None]:
mha = MultiHeadAttention(8, 512)

In [None]:
%%time
mha.forward(tf.random_normal((10,12,512)), tf.random_normal((10,12,512)), tf.random_normal((10,12,512)))

In [None]:
t_layer = TransformerLayer(8, 512, 2048)

In [None]:
%%time
t_layer.forward(tf.random_normal((10,2,512)), tf.random_normal((10,2,512)), tf.random_normal((10,2,512)))

In [None]:
t_encoder = TransformerEncoder(6, 8, 512, 2048, 2)

In [14]:
%%time
t_encoder.forward(tf.random_normal((128,182,512)))

Wall time: 6.99 ms


<tf.Tensor: id=7332, shape=(128, 182, 2), dtype=float32, numpy=
array([[[-33.881607  ,  -2.3976965 ],
        [-25.505833  ,  -5.985231  ],
        [-38.95021   ,  -3.8312526 ],
        ...,
        [-38.821518  ,   7.856987  ],
        [-25.54805   ,  -2.078445  ],
        [-22.891998  ,  -1.8945389 ]],

       [[ -9.196106  ,   8.707629  ],
        [-23.193325  ,   7.910781  ],
        [-17.789616  ,  10.899169  ],
        ...,
        [-28.601572  ,  -2.873456  ],
        [-20.607645  ,   1.9075794 ],
        [-12.547087  ,  -3.5973682 ]],

       [[-21.083778  ,  -0.2351141 ],
        [-22.88121   ,   6.1116323 ],
        [-29.783657  ,   9.990348  ],
        ...,
        [-14.516047  ,   1.4761353 ],
        [-18.951696  ,  -0.8923621 ],
        [-29.320147  ,   0.24620867]],

       ...,

       [[-13.379937  ,  -4.368475  ],
        [-13.5032015 ,  -7.1722164 ],
        [-16.199566  ,  -6.9870105 ],
        ...,
        [-16.495617  , -10.109959  ],
        [-11.504489  ,  -1.24