### This notebook implements the self-attention computation shown in the GIF below, in both PyTorch and TensorFlow

![Self-attention animation](https://pic2.zhimg.com/80/v2-b900fb952a100acd7dd8cd65ebd8bd61_1440w.gif)

# PyTorch

In [2]:
import torch

## 0. Define the inputs

In [13]:
# Three toy input vectors, one per row, each with 4 features.
x = torch.tensor(
    [
        [1, 0, 1, 0],  # Input 1
        [0, 2, 0, 2],  # Input 2
        [1, 1, 1, 1],  # Input 3
    ],
    dtype=torch.float32,
)
print(x)

tensor([[1., 0., 1., 0.],
        [0., 2., 0., 2.],
        [1., 1., 1., 1.]])


## 0. Define the weight matrices for key, query, and value

In [4]:
# Every input is projected with the same three 4x3 weight matrices:
# one producing keys, one producing queries, one producing values.
w_key = torch.tensor(
    [
        [0, 0, 1],
        [1, 1, 0],
        [0, 1, 0],
        [1, 1, 0],
    ],
    dtype=torch.float32,
)
w_query = torch.tensor(
    [
        [1, 0, 1],
        [1, 0, 0],
        [0, 0, 1],
        [0, 1, 1],
    ],
    dtype=torch.float32,
)
w_value = torch.tensor(
    [
        [0, 2, 0],
        [0, 3, 0],
        [1, 0, 3],
        [1, 1, 0],
    ],
    dtype=torch.float32,
)

print("Weights for key: \n", w_key)
print("Weights for query: \n", w_query)
print("Weights for value: \n", w_value)

Weights for key: 
 tensor([[0., 0., 1.],
        [1., 1., 0.],
        [0., 1., 0.],
        [1., 1., 0.]])
Weights for query: 
 tensor([[1., 0., 1.],
        [1., 0., 0.],
        [0., 0., 1.],
        [0., 1., 1.]])
Weights for value: 
 tensor([[0., 2., 0.],
        [0., 3., 0.],
        [1., 0., 3.],
        [1., 1., 0.]])


## 1. Multiply the inputs by the weight matrices to get each input's key, query, and value

In [5]:
# Project every input row through each weight matrix: (3x4) @ (4x3) -> (3x3),
# so row i of each result belongs to Input i.
keys = torch.matmul(x, w_key)
querys = torch.matmul(x, w_query)
values = torch.matmul(x, w_value)

print("Keys: \n", keys)
# Expected:
# tensor([[0., 1., 1.],   # key for Input 1
#         [4., 4., 0.],   # key for Input 2
#         [2., 3., 1.]])  # key for Input 3
print("Querys: \n", querys)
# Expected:
# tensor([[1., 0., 2.],   # query for Input 1
#         [2., 2., 2.],   # query for Input 2
#         [2., 1., 3.]])  # query for Input 3
print("Values: \n", values)
# Expected:
# tensor([[1., 2., 3.],   # value for Input 1
#         [2., 8., 0.],   # value for Input 2
#         [2., 6., 3.]])  # value for Input 3

Keys: 
 tensor([[0., 1., 1.],
        [4., 4., 0.],
        [2., 3., 1.]])
Querys: 
 tensor([[1., 0., 2.],
        [2., 2., 2.],
        [2., 1., 3.]])
Values: 
 tensor([[1., 2., 3.],
        [2., 8., 0.],
        [2., 6., 3.]])


## 2. For each input, compute its attention scores as the dot product (similarity) between its query and every key, including its own

In [6]:
# Dot every query with every key: row i holds Query i's scores against
# Key 1, Key 2 and Key 3.
attn_scores = torch.matmul(querys, keys.transpose(0, 1))
print(attn_scores)
# Expected:
# tensor([[ 2.,  4.,  4.],   # attention scores from Query 1
#         [ 4., 16., 12.],   # attention scores from Query 2
#         [ 4., 12., 10.]])  # attention scores from Query 3

tensor([[ 2.,  4.,  4.],
        [ 4., 16., 12.],
        [ 4., 12., 10.]])


## 3. Apply softmax to the attention scores to normalize them into weights

In [7]:
from torch.nn.functional import softmax

# Normalize along the last axis so the weights from each query sum to 1.
attn_scores_softmax = softmax(attn_scores, dim=-1)
print(attn_scores_softmax)
# Expected:
# tensor([[6.3379e-02, 4.6831e-01, 4.6831e-01],
#         [6.0337e-06, 9.8201e-01, 1.7986e-02],
#         [2.9539e-04, 8.8054e-01, 1.1917e-01]])

# For readability, replace the exact weights above with rounded values.
# Anything computed from attn_scores_softmax afterwards is therefore only an
# approximation of the exact attention output.
attn_scores_softmax = torch.tensor([
    [0.0, 0.5, 0.5],  # attention scores from Query 1
    [0.0, 1.0, 0.0],  # attention scores from Query 2
    [0.0, 0.9, 0.1],  # attention scores from Query 3
])
print(attn_scores_softmax)

tensor([[6.3379e-02, 4.6831e-01, 4.6831e-01],
        [6.0337e-06, 9.8201e-01, 1.7986e-02],
        [2.9539e-04, 8.8054e-01, 1.1917e-01]])
tensor([[0.0000, 0.5000, 0.5000],
        [0.0000, 1.0000, 0.0000],
        [0.0000, 0.9000, 0.1000]])


## 4. For each input, compute its contextualized embedding as the attention-weighted sum of the values

In [9]:
# Each output row is the attention-weighted sum of the three value vectors.
outputs = torch.matmul(attn_scores_softmax, values)
print(outputs)
# Expected:
# tensor([[2.0000, 7.0000, 1.5000],   # Output 1
#         [2.0000, 8.0000, 0.0000],   # Output 2
#         [2.0000, 7.8000, 0.3000]])  # Output 3

tensor([[2.0000, 7.0000, 1.5000],
        [2.0000, 8.0000, 0.0000],
        [2.0000, 7.8000, 0.3000]])


# TensorFlow

In [10]:
import tensorflow as tf

In [14]:
# Same three toy inputs as the PyTorch section, one per row, 4 features each.
x = tf.convert_to_tensor(
    [
        [1, 0, 1, 0],  # Input 1
        [0, 2, 0, 2],  # Input 2
        [1, 1, 1, 1],  # Input 3
    ],
    dtype=tf.float32,
)
print(x)

tf.Tensor(
[[1. 0. 1. 0.]
 [0. 2. 0. 2.]
 [1. 1. 1. 1.]], shape=(3, 4), dtype=float32)


In [15]:
# Every input is projected with the same three 4x3 weight matrices:
# one producing keys, one producing queries, one producing values.
w_key = tf.convert_to_tensor(
    [
        [0, 0, 1],
        [1, 1, 0],
        [0, 1, 0],
        [1, 1, 0],
    ],
    dtype=tf.float32,
)
w_query = tf.convert_to_tensor(
    [
        [1, 0, 1],
        [1, 0, 0],
        [0, 0, 1],
        [0, 1, 1],
    ],
    dtype=tf.float32,
)
w_value = tf.convert_to_tensor(
    [
        [0, 2, 0],
        [0, 3, 0],
        [1, 0, 3],
        [1, 1, 0],
    ],
    dtype=tf.float32,
)

print("Weights for key: \n", w_key)
print("Weights for query: \n", w_query)
print("Weights for value: \n", w_value)

Weights for key: 
 tf.Tensor(
[[0. 0. 1.]
 [1. 1. 0.]
 [0. 1. 0.]
 [1. 1. 0.]], shape=(4, 3), dtype=float32)
Weights for query: 
 tf.Tensor(
[[1. 0. 1.]
 [1. 0. 0.]
 [0. 0. 1.]
 [0. 1. 1.]], shape=(4, 3), dtype=float32)
Weights for value: 
 tf.Tensor(
[[0. 2. 0.]
 [0. 3. 0.]
 [1. 0. 3.]
 [1. 1. 0.]], shape=(4, 3), dtype=float32)


In [16]:
# Step 1: project the TensorFlow inputs with the TensorFlow weight matrices.
# This must happen in this section. Otherwise `keys`/`querys`/`values` would
# be whatever tensors happen to be left over from earlier cells, and
# tf.Tensor has no `.T` attribute unless NumPy behavior is enabled.
keys = tf.matmul(x, w_key)
querys = tf.matmul(x, w_query)
values = tf.matmul(x, w_value)

# Step 2: attention scores = Q @ K^T. Row i holds Query i's scores against
# every key. transpose_b avoids materializing K^T separately.
attn_scores = tf.matmul(querys, keys, transpose_b=True)

print(attn_scores)

tf.Tensor(
[[ 2.  4.  4.]
 [ 4. 16. 12.]
 [ 4. 12. 10.]], shape=(3, 3), dtype=float32)


In [17]:
# Normalize along the last axis so the weights from each query sum to 1.
attn_scores_softmax = tf.nn.softmax(attn_scores, axis=-1)
print(attn_scores_softmax)

# For readability, replace the exact weights above with rounded values.
# Anything computed from attn_scores_softmax afterwards is therefore only an
# approximation of the exact attention output.
attn_scores_softmax = tf.convert_to_tensor([
    [0.0, 0.5, 0.5],  # attention scores from Query 1
    [0.0, 1.0, 0.0],  # attention scores from Query 2
    [0.0, 0.9, 0.1],  # attention scores from Query 3
])
print(attn_scores_softmax)

tf.Tensor(
[[6.3378938e-02 4.6831051e-01 4.6831051e-01]
 [6.0336647e-06 9.8200780e-01 1.7986100e-02]
 [2.9538720e-04 8.8053685e-01 1.1916770e-01]], shape=(3, 3), dtype=float32)
tf.Tensor(
[[0.  0.5 0.5]
 [0.  1.  0. ]
 [0.  0.9 0.1]], shape=(3, 3), dtype=float32)


In [18]:
# Each output row is the attention-weighted sum of the three value vectors.
weighted_values = tf.linalg.matmul(attn_scores_softmax, values)
outputs = weighted_values
print(outputs)

tf.Tensor(
[[2.        7.        1.5      ]
 [2.        8.        0.       ]
 [2.        7.7999997 0.3      ]], shape=(3, 3), dtype=float32)
