In [673]:
import torch
import torch.nn as nn
from torch.nn import functional as F

import jax
import jax.numpy as jnp
from flax import linen as nn
import optax
import math

In [674]:
batch_size = 4
block_size = 8

In [675]:
ix = torch.randint(100, (batch_size,))

In [676]:
ix.shape

torch.Size([4])

In [677]:
ix

tensor([13, 83, 91, 20])

In [678]:
torch.arange(block_size)

tensor([0, 1, 2, 3, 4, 5, 6, 7])

In [679]:
jnp.arange(block_size)

Array([0, 1, 2, 3, 4, 5, 6, 7], dtype=int32)

In [680]:
key = jax.random.PRNGKey(42)

data = jax.random.randint(key, shape=(4, 8), minval=1, maxval=164)

In [681]:
data.shape

(4, 8)

In [682]:
data

Array([[  2,  93,   9,  79, 115, 137,  59,  76],
       [  9, 121,  34, 140,   9,  64,  97,  48],
       [ 96, 163, 108,  63,  88, 138,  55,  23],
       [ 31,  29,  56,  10,  55,   5,  31,  25]], dtype=int32)

In [683]:
data_transposed = jnp.swapaxes(data, -2,-1)
data_transposed.shape

weights = jnp.matmul(data,jnp.swapaxes(data, -2,-1))
print(weights.shape)

(4, 4)


In [684]:
vocab_size = 160
embed_dim = 32
block_size = 8
head_size = 16

In [685]:
class SingleAttentionHead(nn.Module):
    embed_size: int
    head_size: int

    def setup(self):
        self.key = nn.Dense(self.head_size, use_bias=False)
        self.query = nn.Dense(self.head_size, use_bias=False)
        self.value = nn.Dense(self.head_size, use_bias=False)

    def __call__(self, data):
        k = self.key(data)
        q = self.query(data)
        v = self.value(data)

        weights = jnp.matmul(q,jnp.swapaxes(k, -2,-1)) / math.sqrt(self.head_size) # (B,T,T)
        #Lower triangular mask matrix of the size B, T, C (same btw as attention)
        mask = jnp.tril(weights)
        # for every zero, make it to -inf
        weights = nn.softmax(jnp.where(mask == 0, -9e16, weights), axis=-1) # axis=-1 since we only want to softmax for each row of T not for the whole data as a whole

        attention = jnp.matmul(weights, v) # (B,T,C)

        return attention

In [686]:
class Model(nn.Module):

    def setup(self):
        self.token_embedding_table = nn.Embed(vocab_size, embed_dim)
        self.position_embedding_table = nn.Embed(block_size, embed_dim) # 1-D array of blocksize (context window), device=optional?
        self.attention_head = SingleAttentionHead(embed_size=embed_dim, head_size=head_size)
    def __call__(self, data):
        token = self.token_embedding_table(data)
        position = self.position_embedding_table(jnp.arange(block_size))
        embedded_data = token + position

        attention = self.attention_head(embedded_data)

        return attention

In [687]:
model = Model()
params = model.init(jax.random.PRNGKey(0), jax.numpy.zeros((1,1), dtype=jax.numpy.int32))['params']
attention = model.apply({'params': params}, data)

In [688]:
attention.shape

(4, 8, 16)

In [689]:
attention_head = SingleAttentionHead(embed_size=embed_dim, head_size=head_size)

In [690]:
attention_head(data)

AttributeError: "SingleAttentionHead" object has no attribute "key". If "key" is defined in '.setup()', remember these fields are only accessible from inside 'init' or 'apply'.

In [None]:


weights = jnp.dot(q,jnp.transpose(k)) / math.sqrt(head_size) # (B,T,T)
#getting the dimensions from the attention
#B, T, C = attention.shape

#Lower triangular mask matrix of the size B, T, C (same btw as attention)
mask = jnp.tril(weights)
        # for every zero, make it to -inf
weights = nn.softmax(jnp.where(mask == 0, -9e16, weights), axis=-1) # axis=-1 since we only want to softmax for each row of T not for the whole data as a whole

attention = jnp.dot(weights, v) # (B,T,C)