In [41]:
import tensorflow as tf

In [42]:
# Pin TensorFlow's global random seed so any random ops that follow are reproducible.
tf.random.set_seed(seed=42)

In [43]:
B, T, C = 4, 4, 2    # batch, time, channels (C stands in for the per-token embedding channels)
# Fill with 0 .. B*T*C - 1 so the element count always matches the (B, T, C) reshape,
# even if B, T or C are changed above; x[b, t, c] == b*T*C + t*C + c.
x = tf.reshape(tf.range(B * T * C, dtype=tf.float32), (B, T, C))

In [44]:
# Inspect the toy input: values increase along time and channels, so the running averages below are easy to check by eye.
x

<tf.Tensor: shape=(4, 4, 2), dtype=float32, numpy=
array([[[ 0.,  1.],
        [ 2.,  3.],
        [ 4.,  5.],
        [ 6.,  7.]],

       [[ 8.,  9.],
        [10., 11.],
        [12., 13.],
        [14., 15.]],

       [[16., 17.],
        [18., 19.],
        [20., 21.],
        [22., 23.]],

       [[24., 25.],
        [26., 27.],
        [28., 29.],
        [30., 31.]]], dtype=float32)>

In [45]:
# Lower-triangular ones: row t has ones in columns 0..t, i.e. the causal pattern.
tril = tf.linalg.band_part(tf.ones(shape=(T, T)), num_lower=-1, num_upper=0)
# Start from all-zero affinities; equal logits turn into a uniform softmax.
weights = tf.zeros(shape=(T, T))
(weights, tril)

(<tf.Tensor: shape=(4, 4), dtype=float32, numpy=
 array([[0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.],
        [0., 0., 0., 0.]], dtype=float32)>,
 <tf.Tensor: shape=(4, 4), dtype=float32, numpy=
 array([[1., 0., 0., 0.],
        [1., 1., 0., 0.],
        [1., 1., 1., 0.],
        [1., 1., 1., 1.]], dtype=float32)>)

In [46]:
# Push every slot above the diagonal (future positions) to -inf so softmax gives it zero weight.
weights = tf.where(tf.equal(tril, 0.0), float('-inf'), weights)
weights

<tf.Tensor: shape=(4, 4), dtype=float32, numpy=
array([[  0., -inf, -inf, -inf],
       [  0.,   0., -inf, -inf],
       [  0.,   0.,   0., -inf],
       [  0.,   0.,   0.,   0.]], dtype=float32)>

In [47]:
# Normalise each row over its last axis: exp(0) for allowed slots, exp(-inf) = 0 for masked
# ones, so row t becomes a uniform average over positions 0..t.
# NOTE(review): not idempotent - re-running this cell softmaxes the already-normalised rows
# and leaks weight into the masked (future) slots; re-run the masking cell first.
weights = tf.nn.softmax(weights, axis=-1)
weights

<tf.Tensor: shape=(4, 4), dtype=float32, numpy=
array([[1.        , 0.        , 0.        , 0.        ],
       [0.5       , 0.5       , 0.        , 0.        ],
       [0.33333334, 0.33333334, 0.33333334, 0.        ],
       [0.25      , 0.25      , 0.25      , 0.25      ]], dtype=float32)>

In [48]:
# weights is (T, T); matmul broadcasts it across the batch axis:
# (T, T) @ (B, T, C) -> (B, T, C). Row t of every batch is the mean of x[b, :t+1].
xbow = weights @ x
assert xbow.shape == x.shape, "averaging must preserve the (B, T, C) shape"
xbow

<tf.Tensor: shape=(4, 4, 2), dtype=float32, numpy=
array([[[ 0.      ,  1.      ],
        [ 1.      ,  2.      ],
        [ 2.      ,  3.      ],
        [ 3.      ,  4.      ]],

       [[ 8.      ,  9.      ],
        [ 9.      , 10.      ],
        [10.      , 11.      ],
        [11.      , 12.      ]],

       [[16.      , 17.      ],
        [17.      , 18.      ],
        [18.      , 19.000002],
        [19.      , 20.      ]],

       [[24.      , 25.      ],
        [25.      , 26.      ],
        [26.      , 27.      ],
        [27.      , 28.      ]]], dtype=float32)>

### Single head of self-attention

In [49]:
# Re-seed so this section's random input is reproducible independently of the cells above.
tf.random.set_seed(seed=42)
# NOTE(review): these reuse the names B, T, C and x from the toy example above.
B, T, C = 4, 8, 32   # batch, time, channels
x = tf.random.normal(shape=(B, T, C))

In [50]:
head_size = 16
# Bias-free linear projections of each token's C channels into key / query / value spaces.
key = tf.keras.layers.Dense(head_size, activation=None, use_bias=False)
query = tf.keras.layers.Dense(head_size, activation=None, use_bias=False)
value = tf.keras.layers.Dense(head_size, activation=None, use_bias=False)
k = key(x)      # (B, T, head_size)
q = query(x)    # (B, T, head_size)
# Scaled dot-product affinities: q @ k^T / sqrt(head_size) -> (B, T, T).
# transpose_b swaps k's last two axes for any rank, instead of hard-coding perm=[0, 2, 1].
# The 1/sqrt(head_size) factor keeps the score variance from growing with head_size, so the
# softmax applied later does not saturate toward one-hot rows.
weights = tf.matmul(q, k, transpose_b=True) * head_size ** -0.5

In [51]:
# Shapes of keys, queries and the (B, T, T) affinity matrix, shown as one tuple.
tuple(t.shape for t in (k, q, weights))

(TensorShape([4, 8, 16]), TensorShape([4, 8, 16]), TensorShape([4, 8, 8]))

In [52]:
# Causal mask rebuilt for the new T: position t may only attend to positions <= t.
tril = tf.linalg.band_part(tf.ones(shape=(T, T)), num_lower=-1, num_upper=0)
# Hide future positions with -inf, then normalise each row over the last axis.
# NOTE(review): this overwrites `weights`, so re-running the cell without re-running the
# projection cell applies the softmax to already-normalised scores.
weights = tf.nn.softmax(
    tf.where(tf.equal(tril, 0.0), float('-inf'), weights),
    axis=-1,
)
v = value(x)                  # (B, T, head_size) value projection
out = tf.matmul(weights, v)   # (B, T, T) @ (B, T, head_size) -> (B, T, head_size)

In [53]:
# One head maps each (B, T, C) input to a (B, T, head_size) output.
out.shape

TensorShape([4, 8, 16])