In [7]:
import numpy as np

class TinyBERT:
    def __init__(self):
        self.word_to_id = {'the': 0, 'cat': 1, 'sits': 2, 'sleeps': 3}

        self.embeddings = np.array([
            [0.2, -0.5, 0.1],   # the
            [-0.3, 0.4, 0.2],   # cat
            [0.1, 0.3, -0.4],   # sits
            [-0.2, -0.1, 0.5]   # sleeps
        ])

        self.max_sequence_length = 10
        self.position_dim = 3
        self.position_embeddings = np.random.randn(self.max_sequence_length, self.position_dim) * 0.1

    def attention(self, sentence):
        print('TinyBert --- attention()')
        print(sentence)
        # convert to embeddings
        word_ids = [self.word_to_id[word] for word in sentence]
        print(word_ids)
        vectors = self.embeddings[word_ids]
        print(vectors)

        positional_vectors = self.position_embeddings[:len(sentence)]
        vectors = vectors + positional_vectors
        print(vectors)

        # calculate attention scores by seeing how aligned they are in terms of drections
        # more similar words/embeddings pointing in same direction should have similar attentions
        scores = np.dot(vectors, vectors.T)
        print(scores)
        scores = scores / np.sqrt(3) # scale scores
        print(scores)

        # convert to probabilities
        probs = np.exp(scores)
        print(probs)
        attention_probs = probs / probs.sum(axis=1, keepdims=True)
        print(attention_probs)

        # work out new embeddings for words based on weighted sum of other words embeddings
        newEmbeddings = np.dot(attention_probs, vectors)
        print(newEmbeddings)
        return attention_probs, newEmbeddings


In [10]:
bert = TinyBERT()
sentence = ['cat', 'sits', 'the', ]
weights, new_reps = bert.attention(sentence)

TinyBert --- attention()
['cat', 'sits', 'the']
[1, 2, 0]
[[-0.3  0.4  0.2]
 [ 0.1  0.3 -0.4]
 [ 0.2 -0.5  0.1]]
[[-0.25497954  0.4399019   0.11021631]
 [ 0.07694333  0.38871628 -0.50786617]
 [ 0.18630067 -0.41576185  0.31090431]]
[[ 0.27067589  0.09540292 -0.19613056]
 [ 0.09540292  0.41494866 -0.30517659]
 [-0.19613056 -0.30517659  0.30422735]]
[[ 0.1562748   0.0550809  -0.11323603]
 [ 0.0550809   0.23957072 -0.17619378]
 [-0.11323603 -0.17619378  0.17564574]]
[[1.16914744 1.05662609 0.89293987]
 [1.05662609 1.27070355 0.83845549]
 [0.89293987 0.83845549 1.1920157 ]]
[[0.37488133 0.33880192 0.28631675]
 [0.33376431 0.40138654 0.26484915]
 [0.30544451 0.28680725 0.40774824]]
[[-0.01617752  0.17756925 -0.04173089]
 [-0.00487748  0.19273487 -0.08472163]
 [ 0.02014957  0.07632611  0.01477595]]


In [11]:
print("\nAttention weights:")
for i, word in enumerate(sentence):
    print(f"\n{word} pays attention to each word:")
    for j, other_word in enumerate(sentence):
        print(f"{other_word}: {weights[i][j]:.3f}")


Attention weights:

cat pays attention to each word:
cat: 0.375
sits: 0.339
the: 0.286

sits pays attention to each word:
cat: 0.334
sits: 0.401
the: 0.265

the pays attention to each word:
cat: 0.305
sits: 0.287
the: 0.408
