In [67]:
import jax
import jax.numpy as jnp
from flax import linen as nn
import optax
import math

In [68]:
# Dimensions of the toy batch used throughout the notebook.
block_size = 8  # tokens per sequence; also sizes the position-embedding table
batch_size = 4  # number of sequences in the sample batch

In [69]:
# Preview the position indices 0 .. block_size - 1 (one per slot in the context window).
jnp.arange(block_size)

Array([0, 1, 2, 3, 4, 5, 6, 7], dtype=int32)

In [70]:
# Deterministic toy batch of integer token ids. The shape is taken from the
# batch/context constants defined above instead of repeating the literals 4 and 8,
# so the sample stays in sync if those constants are changed.
key = jax.random.PRNGKey(42)

# randint's upper bound is exclusive: ids are drawn from [1, 164).
data = jax.random.randint(key, shape=(batch_size, block_size), minval=1, maxval=164)

In [71]:
# Confirm the dimensions of the sample batch.
data.shape

(4, 8)

In [72]:
# Inspect the sampled integer token ids (later fed to the token embedding table).
data

Array([[  2,  93,   9,  79, 115, 137,  59,  76],
       [  9, 121,  34, 140,   9,  64,  97,  48],
       [ 96, 163, 108,  63,  88, 138,  55,  23],
       [ 31,  29,  56,  10,  55,   5,  31,  25]], dtype=int32)

In [73]:
# Model hyper-parameters (read as module-level globals by `Model`).
#
# vocab_size must be strictly greater than the largest token id fed to the
# token embedding. The sample batch is drawn with maxval=164 (exclusive), so ids
# go up to 163; the previous value of 160 left ids 160..163 outside the table.
vocab_size = 164
embed_dim = 32   # width of token / position embeddings
block_size = 8   # context length; number of rows in the position-embedding table
head_num = 2     # number of attention heads
head_size = embed_dim // head_num  # per-head width (16), so concatenated heads span embed_dim

In [74]:
class SingleAttentionHead(nn.Module):
    """One head of causal (masked) scaled dot-product self-attention.

    Attributes:
        embed_dim: Width of the incoming features. Not read by this module (the
            Dense layers infer their input width); kept as part of the constructor
            interface.
        head_size: Output width of the key / query / value projections.

    Call:
        data: array whose last two axes are (T, C) -- typically (B, T, C).

    Returns:
        Array with the same leading axes as `data` and last axis `head_size`,
        where position t is a softmax-weighted mix of the values at positions <= t.
    """
    embed_dim: int
    head_size: int

    def setup(self):
        # Bias-free linear projections; attribute names are kept so the
        # parameter tree ('key', 'query', 'value') is unchanged.
        self.key = nn.Dense(self.head_size, use_bias=False)
        self.query = nn.Dense(self.head_size, use_bias=False)
        self.value = nn.Dense(self.head_size, use_bias=False)

    def __call__(self, data):
        k = self.key(data)
        q = self.query(data)
        v = self.value(data)

        # Scaled dot-product scores between every pair of positions: (..., T, T).
        scores = jnp.matmul(q, jnp.swapaxes(k, -2, -1)) / math.sqrt(self.head_size)

        # Causal mask built from *positions*, not from the score values. Deriving it
        # from `tril(scores) == 0` would also mask legitimate scores that happen to be
        # exactly 0, and would let a row of all-zero scores attend to the future.
        seq_len = scores.shape[-1]
        causal_mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))  # (T, T), broadcasts over batch

        # Fill future positions with the most negative finite value of the dtype so
        # they receive ~0 probability without producing inf - inf NaNs.
        masked_fill = jnp.finfo(scores.dtype).min
        scores = jnp.where(causal_mask, scores, masked_fill)

        # Normalise each query row independently over the key axis.
        weights = nn.softmax(scores, axis=-1)

        attention = jnp.matmul(weights, v)  # (..., T, head_size)

        return attention

In [75]:
class MultiHeadAttention(nn.Module):
    """Runs several `SingleAttentionHead`s in parallel and concatenates their outputs.

    Attributes:
        head_num: Number of attention heads.
        embed_dim: Feature width of the input; each head gets
            `embed_dim // head_num` features so the concatenation is `embed_dim` wide.

    Call:
        data: array passed unchanged to every head.

    Returns:
        The heads' outputs concatenated along the last axis. No output projection
        is applied after concatenation.

    Raises:
        ValueError: if `head_num` is not positive or does not divide `embed_dim`
            (integer division would otherwise silently shrink the output width).
    """
    head_num: int
    embed_dim: int

    def setup(self):
        if self.head_num <= 0 or self.embed_dim % self.head_num != 0:
            raise ValueError(
                f"embed_dim ({self.embed_dim}) must be divisible by a positive "
                f"head_num ({self.head_num})"
            )
        per_head = self.embed_dim // self.head_num
        # Attribute name `heads` is kept so the parameter tree is unchanged.
        self.heads = [
            SingleAttentionHead(embed_dim=self.embed_dim, head_size=per_head)
            for _ in range(self.head_num)
        ]

    def __call__(self, data):
        # Each head yields (..., T, per_head); joining on the feature axis restores embed_dim.
        head_outputs = [head(data) for head in self.heads]
        return jnp.concatenate(head_outputs, axis=-1)

In [76]:
class Model(nn.Module):
    """Token + position embedding followed by causal multi-head self-attention.

    Reads the module-level hyper-parameters `vocab_size`, `embed_dim`,
    `block_size` and `head_num`.

    Call:
        data: integer token ids whose last axis is the sequence axis, e.g. (B, T)
            with T <= block_size.

    Returns:
        The multi-head attention output for the embedded sequence.

    Raises:
        ValueError: if the sequence is longer than `block_size` (there would be no
            position embedding for the extra positions).
    """

    def setup(self):
        self.token_embedding_table = nn.Embed(vocab_size, embed_dim)
        # One learned vector per position in the context window.
        self.position_embedding_table = nn.Embed(block_size, embed_dim)

        self.multihead_attention = MultiHeadAttention(head_num=head_num, embed_dim=embed_dim)

    def __call__(self, data):
        seq_len = data.shape[-1]
        if seq_len > block_size:
            raise ValueError(
                f"sequence length {seq_len} exceeds block_size {block_size}"
            )

        token = self.token_embedding_table(data)  # (..., T, embed_dim)
        # Index positions by the actual sequence length rather than the global
        # block_size, so inputs shorter than the context window work and the sum
        # below does not depend on accidental broadcasting.
        position = self.position_embedding_table(jnp.arange(seq_len))  # (T, embed_dim)
        embedded_data = token + position

        attention = self.multihead_attention(embedded_data)

        return attention

In [77]:
# Instantiate the model, initialise its parameters from a dummy (1, 1) batch of
# token id 0, then run the forward pass on the sample batch.
model = Model()
init_rng = jax.random.PRNGKey(0)
dummy_tokens = jnp.zeros((1, 1), dtype=jnp.int32)
variables = model.init(init_rng, dummy_tokens)
params = variables['params']
attention = model.apply({'params': params}, data)

In [78]:
# Shape of the model output for the sample batch.
attention.shape

(4, 8, 32)