# TensorFlow code for Self-Attention 

### References:
* [Attention Is All You Need](https://arxiv.org/abs/1706.03762)
* [Transformer: A Novel Neural Network Architecture for Language Understanding](https://research.googleblog.com/2017/08/transformer-novel-neural-network.html)
* [Tensor2Tensor](https://github.com/tensorflow/tensor2tensor)

In [1]:
import tensorflow as tf

In [2]:
def input_fun(**config):
    """Build a random-normal tensor that stands in for real input data.

    Args:
        **config: keyword options.
            batch_size, sequence_length, hidden_dim (required): the three
                dimensions of the generated tensor, in that order.
            seed (optional): forwarded to ``tf.random_normal`` as its
                operation-level seed so the demo data can be made
                repeatable. Defaults to ``None`` (unseeded), which is the
                original behaviour.

    Returns:
        The tensor produced by ``tf.random_normal`` with shape
        ``(batch_size, sequence_length, hidden_dim)``.

    Raises:
        KeyError: if any of the three required keys is missing.
    """
    shape = (
        config['batch_size'], config['sequence_length'], config['hidden_dim'])
    # `seed=None` is tf.random_normal's own default, so callers that do not
    # pass a seed get exactly the same op as before.
    return tf.random_normal(shape, seed=config.get('seed'))

In [17]:
def attention_fun(Q, K, scaled_=True, masked_=False):
    """Compute softmax-normalised dot-product attention weights.

    Args:
        Q: query tensor; the caller in this notebook passes
            [batch_size, sequence_length, hidden_dim].
        K: key tensor with the same last dimension as Q.
        scaled_: when true, divide the raw scores by sqrt(d_k), where d_k is
            the size of K's last dimension.
        masked_: masking is not implemented; a true value raises.

    Returns:
        Softmax over the last axis of ``Q @ K^T``:
        [batch_size, sequence_length, sequence_length].

    Raises:
        NotImplementedError: if ``masked_`` is true.
    """
    if masked_:
        # Checked first so an unsupported request fails before any ops are
        # added to the graph.
        raise NotImplementedError('masked attention is not implemented')

    attention = tf.matmul(Q, K, transpose_b=True)  # [batch_size, sequence_length, sequence_length]

    if scaled_:
        # Cast to the scores' own dtype rather than a hard-coded float32, so
        # float16/float64 inputs do not hit a dtype mismatch in the division.
        d_k = tf.cast(tf.shape(K)[-1], dtype=attention.dtype)
        attention = tf.divide(attention, tf.sqrt(d_k))  # [batch_size, sequence_length, sequence_length]

    # softmax normalises over the last axis by default; omitting the keyword
    # avoids the deprecated `dim` argument while keeping the same result.
    return tf.nn.softmax(attention)  # [batch_size, sequence_length, sequence_length]


In [16]:
def model_fun(data, **config):
    """Single-head self-attention over `data`.

    Three independent dense projections of `data` act as queries, keys and
    values; the attention weights from `attention_fun` are then used to mix
    the value vectors.

    Args:
        data: input tensor, [batch_size, sequence_length, hidden_dim].
        **config: must contain ``hidden_dim`` (width of the query/key
            projections) and ``n_classes`` (width of the value projection).

    Returns:
        Tensor of shape [batch_size, sequence_length, n_classes].
    """
    projection_units = config['hidden_dim']

    # The layers are created in query, key, value order.
    queries = tf.layers.dense(data, projection_units)  # [batch_size, sequence_length, hidden_dim]
    keys = tf.layers.dense(data, projection_units)  # [batch_size, sequence_length, hidden_dim]
    values = tf.layers.dense(data, config['n_classes'])  # [batch_size, sequence_length, n_classes]

    weights = attention_fun(queries, keys)  # [batch_size, sequence_length, sequence_length]
    return tf.matmul(weights, values)  # [batch_size, sequence_length, n_classes]

In [19]:
if __name__ == '__main__':
    # Sample one random batch and print its concrete values.
    inputs = input_fun(batch_size=32, sequence_length=10, hidden_dim=128)
    with tf.Session() as sess:
        print(sess.run(inputs))

    # The model run below is currently disabled; uncomment it to build the
    # attention graph on `inputs` and print the output shape.
    # outputs = model_fun(inputs, hidden_dim=128, n_classes=2)
    #
    # with tf.Session() as sess:
    #     sess.run(tf.global_variables_initializer())
    #
    #     outputs_ = sess.run(outputs)
    #     print(outputs_.shape)

[[[ 1.4977258   0.29166442  0.00683167 ... -0.94390374  2.132048
    0.10638817]
  [ 0.19402815  0.7921161   0.80763364 ... -1.4685776  -1.7131332
    0.6364754 ]
  [-0.8855895   0.2456611  -1.2027229  ...  0.8814262   0.70316225
    0.2799854 ]
  ...
  [ 0.1501507   0.9974407  -0.39481825 ...  0.65158623  0.85993665
    0.38736168]
  [ 0.10191716 -0.10676987  0.44088697 ...  2.0026267   1.1333106
   -0.5815502 ]
  [ 0.17985587 -0.8863931  -0.46068347 ...  0.64572054  0.9446574
   -0.50968754]]

 [[-0.93790776  0.08903116 -1.7041032  ...  1.1371131   0.4319286
    0.8350716 ]
  [-0.18624283  1.282129    0.6841612  ...  0.3939439  -2.1237617
    1.065678  ]
  [-1.1238523   1.658433   -1.209692   ... -0.02763248 -0.3910349
    1.0298834 ]
  ...
  [-0.59902745 -0.7648693  -0.5120379  ...  1.1642501  -1.4624057
   -1.2281034 ]
  [-1.4802809   0.15423995 -0.27076465 ...  1.5185727  -0.12518252
   -0.76886964]
  [-1.4296391   0.1306755  -1.2566013  ...  0.24928606 -0.51026887
    1.8233235 ]

In [20]:
# Bare expression: the notebook displays the symbolic tensor's repr
# (name, shape, dtype) rather than its values. Relies on `inputs` having
# been defined by the previous cell.
inputs

<tf.Tensor 'random_normal_2:0' shape=(32, 10, 128) dtype=float32>

name: "random_normal"
op: "Add"
input: "random_normal/mul"
input: "random_normal/mean"
attr {
  key: "T"
  value {
    type: DT_FLOAT
  }
}

