# TensorFlow code for Self-Attention 

### References:
* [Attention Is All You Need](https://arxiv.org/abs/1706.03762)
* [Transformer: A Novel Neural Network Architecture for Language Understanding](https://research.googleblog.com/2017/08/transformer-novel-neural-network.html)
* [Tensor2Tensor](https://github.com/tensorflow/tensor2tensor)

In [1]:
import tensorflow as tf

In [2]:
def input_fun(**config):
    """Create a random input tensor drawn from a normal distribution.

    Args:
        **config: must provide 'batch_size', 'sequence_length' and
            'hidden_dim'; together they define the shape of the result.

    Returns:
        A tensor of shape [batch_size, sequence_length, hidden_dim].
    """
    shape = (
        config['batch_size'],
        config['sequence_length'],
        config['hidden_dim'],
    )
    return tf.random_normal(shape)

In [17]:
def attention_fun(Q, K, scaled_=True, masked_=False):
    """Compute (optionally scaled) dot-product attention weights.

    Args:
        Q: query tensor; callers in this notebook pass
            [batch_size, sequence_length, hidden_dim].
        K: key tensor whose last dimension matches that of `Q`.
        scaled_: if True, divide the raw scores by sqrt(d_k), where d_k is
            the size of the last dimension of `K`.
        masked_: masking is not implemented; passing True raises.

    Returns:
        Attention weights normalised with a softmax over the last axis
        ([batch_size, sequence_length, sequence_length] for the shapes above).

    Raises:
        NotImplementedError: if `masked_` is True.
    """
    attention = tf.matmul(Q, K, transpose_b=True)  # [batch_size, sequence_length, sequence_length]

    if scaled_:
        # Cast the scale to the dtype of the scores (not a fixed float32) so
        # that float16 / float64 inputs do not fail with a dtype mismatch.
        d_k = tf.cast(tf.shape(K)[-1], dtype=attention.dtype)
        attention = tf.divide(attention, tf.sqrt(d_k))  # [batch_size, sequence_length, sequence_length]

    if masked_:
        raise NotImplementedError

    # `axis` replaces the deprecated `dim` keyword of tf.nn.softmax.
    attention = tf.nn.softmax(attention, axis=-1)  # [batch_size, sequence_length, sequence_length]
    return attention


In [16]:
def model_fun(data, **config):
    """Apply a single self-attention layer to `data`.

    Three dense projections of `data` produce queries, keys and values; the
    values are then mixed using the weights returned by `attention_fun`.

    Args:
        data: input tensor, [batch_size, sequence_length, hidden_dim].
        **config: must provide 'hidden_dim' (width of the query/key
            projections) and 'n_classes' (width of the value projection).

    Returns:
        A tensor of shape [batch_size, sequence_length, n_classes].
    """
    # The layers are created in query, key, value order.
    queries = tf.layers.dense(data, units=config['hidden_dim'])
    keys = tf.layers.dense(data, units=config['hidden_dim'])
    values = tf.layers.dense(data, units=config['n_classes'])

    # weights: [batch_size, sequence_length, sequence_length]
    weights = attention_fun(queries, keys)
    return tf.matmul(weights, values)

In [30]:
if __name__ == '__main__':
    # Build the model inside a fresh graph: re-running this cell then starts
    # from an empty graph instead of piling duplicate ops and variables onto
    # the global default graph.
    with tf.Graph().as_default():
        inputs = input_fun(batch_size=32, sequence_length=10, hidden_dim=128)
        outputs = model_fun(inputs, hidden_dim=128, n_classes=2)

        # Created inside the `as_default()` block, so the session is bound to
        # the fresh graph built above.
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())

            outputs_ = sess.run(outputs)
            # [batch_size, sequence_length, n_classes]
            print(outputs_.shape)

[[[ 8.64741611e+00  4.31900692e+00 -9.97002125e-02 ...  1.65095139e+00
    2.85142422e+01 -1.14556618e+01]
  [-3.04894686e+00  3.77915645e+00  9.54752541e+00 ... -2.43744087e+01
    2.64614868e+00  6.24607468e+00]
  [-6.80655432e+00  2.43296318e+01 -1.13980789e+01 ...  1.79394302e+01
    9.53603935e+00 -1.17955828e+01]
  ...
  [-1.68654323e-02 -1.27606792e+01  6.51028395e-01 ... -2.44753838e+00
    6.09450722e+00  9.07594490e+00]
  [-6.79663181e+00  1.92100763e-01 -6.65590811e+00 ...  2.99443665e+01
    2.50837660e+00 -1.48180513e+01]
  [ 7.70586395e+00 -3.26023459e+00  2.29179263e-02 ... -3.32949114e+00
    2.43676496e+00  7.07141685e+00]]

 [[-5.38709736e+00 -9.43058968e-01  1.58931513e+01 ... -8.20481491e+00
    5.34416199e-01 -8.28966141e+00]
  [-6.71831131e+00 -1.42814913e+01 -2.01013255e+00 ...  3.77998567e+00
   -1.75789642e+00  4.20802712e-01]
  [ 1.55538101e+01  8.83392715e+00  1.32826500e+01 ...  3.69509029e+00
    1.87220726e+01 -1.55943251e+01]
  ...
  [ 3.87775612e+00  1.3

In [20]:
if __name__ == '__main__':
    # Use a dedicated graph for this run so that executing the cell more than
    # once does not accumulate ops and variables in the global default graph.
    with tf.Graph().as_default():
        inputs = input_fun(batch_size=32, sequence_length=10, hidden_dim=128)
        outputs = model_fun(inputs, hidden_dim=128, n_classes=2)

        # The session picks up the graph that is current inside this block.
        with tf.Session() as sess:
            sess.run(tf.global_variables_initializer())

            outputs_ = sess.run(outputs)
            # [batch_size, sequence_length, n_classes]
            print(outputs_.shape)

<tf.Tensor 'random_normal_2:0' shape=(32, 10, 128) dtype=float32>

name: "random_normal"
op: "Add"
input: "random_normal/mul"
input: "random_normal/mean"
attr {
  key: "T"
  value {
    type: DT_FLOAT
  }
}

