# TensorFlow code for Self-Attention 

### References:
* [Attention Is All You Need](https://arxiv.org/abs/1706.03762)
* [Transformer: A Novel Neural Network Architecture for Language Understanding](https://research.googleblog.com/2017/08/transformer-novel-neural-network.html)
* [Tensor2Tensor](https://github.com/tensorflow/tensor2tensor)

In [1]:
import tensorflow as tf

In [2]:
def input_fun(**config):
    """Create a random batch of input sequences.

    Reads three required keys from ``config`` -- ``batch_size``,
    ``sequence_length`` and ``hidden_dim`` -- and uses them, in that order,
    as the shape handed to ``tf.random_normal``. Extra keys are ignored; a
    missing key raises ``KeyError``.

    Returns:
        The tensor produced by ``tf.random_normal`` for the shape
        ``(batch_size, sequence_length, hidden_dim)``.
    """
    shape_keys = ('batch_size', 'sequence_length', 'hidden_dim')
    shape = tuple(config[key] for key in shape_keys)
    return tf.random_normal(shape)

In [17]:
def attention_fun(Q, K, scaled_=True, masked_=False):
    """Compute (optionally scaled) dot-product attention weights.

    Args:
        Q: Query tensor. The notebook passes
            [batch_size, sequence_length, hidden_dim].
        K: Key tensor whose last axis has the same size as Q's last axis.
        scaled_: When true, the raw scores ``Q @ K^T`` are divided by
            ``sqrt(d_k)``, where ``d_k`` is the size of K's last axis.
        masked_: Masking is not implemented; a truthy value raises.

    Returns:
        The softmax over the last axis of the (scaled) scores; for the
        notebook's inputs this is
        [batch_size, sequence_length, sequence_length].

    Raises:
        NotImplementedError: If ``masked_`` is truthy.
    """
    # Fail fast: reject the unsupported option before any op is created, so a
    # rejected call does not leave orphaned matmul/scale nodes in the graph.
    if masked_:
        raise NotImplementedError('masked attention is not implemented')

    scores = tf.matmul(Q, K, transpose_b=True)  # [batch_size, sequence_length, sequence_length]

    if scaled_:
        # Cast d_k to the dtype of the scores instead of a hard-coded float32:
        # the division needs both operands to share a dtype, so a fixed
        # float32 scale would break float16/float64 inputs.
        d_k = tf.cast(tf.shape(K)[-1], dtype=scores.dtype)
        scores = tf.divide(scores, tf.sqrt(d_k))

    # softmax normalises over the last axis by default, so the deprecated
    # `dim=-1` keyword is simply omitted.
    return tf.nn.softmax(scores)  # [batch_size, sequence_length, sequence_length]


In [16]:
def model_fun(data, **config):
    """Apply a single self-attention layer to ``data``.

    Three dense projections of ``data`` are created, in this order: queries
    and keys with ``config['hidden_dim']`` units each, then values with
    ``config['n_classes']`` units. The attention weights computed from the
    queries and keys are used to mix the values.

    Returns:
        ``tf.matmul(attention, values)``; for the notebook's inputs the
        shape is [batch_size, sequence_length, n_classes].
    """
    def _project(units):
        # Dense projection of the shared input to `units` output features.
        return tf.layers.dense(data, units)

    queries = _project(config['hidden_dim'])  # [batch_size, sequence_length, hidden_dim]
    keys = _project(config['hidden_dim'])     # [batch_size, sequence_length, hidden_dim]
    values = _project(config['n_classes'])    # [batch_size, sequence_length, n_classes]

    weights = attention_fun(queries, keys)    # [batch_size, sequence_length, sequence_length]
    return tf.matmul(weights, values)         # [batch_size, sequence_length, n_classes]

In [3]:
if __name__ == '__main__':
    # Sample a random batch: 32 sequences of 10 positions with 128 features each.
    inputs = input_fun(
        batch_size=32,
        sequence_length=10,
        hidden_dim=128,
    )

    # Disabled end-to-end run (TF1 graph mode): builds the attention model on
    # `inputs`, initialises the variables and prints the output shape.
    #outputs = model_fun(inputs, hidden_dim=128, n_classes=2)

    #with tf.Session() as sess:
    #    sess.run(tf.global_variables_initializer())

    #    outputs_ = sess.run(outputs)
    #    print(outputs_.shape)

Tensor("random_normal:0", shape=(32, 10, 128), dtype=float32)


In [7]:
# List the attribute and method names available on the `inputs` object.
dir(inputs)

['OVERLOADABLE_OPERATORS',
 '__abs__',
 '__add__',
 '__and__',
 '__array_priority__',
 '__bool__',
 '__class__',
 '__copy__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__div__',
 '__doc__',
 '__eq__',
 '__floordiv__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__getitem__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__invert__',
 '__iter__',
 '__le__',
 '__lt__',
 '__matmul__',
 '__mod__',
 '__module__',
 '__mul__',
 '__ne__',
 '__neg__',
 '__new__',
 '__nonzero__',
 '__or__',
 '__pow__',
 '__radd__',
 '__rand__',
 '__rdiv__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__rfloordiv__',
 '__rmatmul__',
 '__rmod__',
 '__rmul__',
 '__ror__',
 '__rpow__',
 '__rsub__',
 '__rtruediv__',
 '__rxor__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__sub__',
 '__subclasshook__',
 '__truediv__',
 '__weakref__',
 '__xor__',
 '_as_node_def_input',
 '_as_tf_output',
 '_c_api_shape',
 '_consumers',
 '_dtype',
 '_get_input_ops_without_shapes',
 '_id',
 '_name',
 '_op',
 '_o

In [15]:
# Show the `op` attribute of `inputs`; the recorded output is a node named
# "random_normal" whose op type is "Add".
print(inputs.op)

name: "random_normal"
op: "Add"
input: "random_normal/mul"
input: "random_normal/mean"
attr {
  key: "T"
  value {
    type: DT_FLOAT
  }
}

