In [1]:
import equinox as eqx
import jax
import jax.numpy as jnp
import optax
from jaxtyping import Array, Float, PRNGKeyArray, PyTree
from icecream import ic

In [2]:
def get_next_multiple_of_64(number: int) -> int:
    while number % 64 != 0:
        number += 1
    return number

In [3]:
import tiktoken

enc = tiktoken.get_encoding("gpt2")
encoded = enc.encode("hello, world")
print(f"{encoded=}")
# pad enc.n_vocab to nearest multiple of 64 to make it even and for efficiency
n_vocab = get_next_multiple_of_64(enc.n_vocab)
print(f"{n_vocab=}")

encoded=[31373, 11, 995]
n_vocab=50304


In [4]:
decoded = enc.decode(encoded)
print(f"{decoded=}")

decoded='hello, world'


In [5]:
def get_positional_encoding(n_tokens: int, n_vocab: int) -> Float[Array, "n_tokens n_vocab"]:     
    pos = jnp.arange(n_tokens)[:, jnp.newaxis]
    div_term = jnp.exp(jnp.arange(0, n_vocab, 2) * -(jnp.log(10000.0) / n_vocab))
    # alternatively: div_term = 1 / 10000 ** (jnp.arange(0, D, 2) / D) 
    # that's closer to the actual notation they used. 
    pos_enc = jnp.zeros((n_tokens, n_vocab))
    pos_enc = pos_enc.at[:, 0::2].set(jnp.sin(pos * div_term))
    pos_enc = pos_enc.at[:, 1::2].set(jnp.cos(pos * div_term))
    return pos_enc 


In [9]:
class MultiHeadAttention(eqx.Module):
    n_heads: int = eqx.field(static=True)
    qkv_size: int = eqx.field(static=True)

    query: eqx.nn.Linear
    key: eqx.nn.Linear
    value: eqx.nn.Linear

    output: eqx.nn.Linear
    def __init__(self, input_dim: int, n_heads: int, key: PRNGKeyArray) -> None:
        key, *subkeys = jax.random.split(key, 5)

        self.qkv_size = input_dim // n_heads
        
        self.query = eqx.nn.Linear(in_features=input_dim, out_features=n_heads * self.qkv_size, key=subkeys[0], use_bias=False)
        self.key = eqx.nn.Linear(in_features=input_dim, out_features=n_heads * self.qkv_size, key=subkeys[1], use_bias=False)
        self.value = eqx.nn.Linear(in_features=input_dim, out_features=n_heads * self.qkv_size, key=subkeys[2], use_bias=False)

        self.output = eqx.nn.Linear(in_features=input_dim, out_features=input_dim, key=subkeys[3], use_bias=False) 

        self.n_heads = n_heads

    def _project(self, proj, x):
        seq_length, _ = x.shape
        projection = jax.vmap(proj)(x)
        return projection.reshape(seq_length, self.n_heads, -1)

    def __call__(self, x: Array, masking: bool):
        T, _ = x.shape

        q = self._project(self.query, x)
        k = self._project(self.key, x)
        v = self._project(self.value, x)

        assert q.shape == (T, self.n_heads, self.qkv_size)
        assert k.shape == (T, self.n_heads, self.qkv_size)
        assert v.shape == (T, self.n_heads, self.qkv_size)

        dot_product_vmap = jax.vmap(
            lambda q, k: jnp.dot(q, k.T), 
            in_axes=(1, 1), 
            out_axes=1
        )
        attention_scores = dot_product_vmap(q, k)
        ic(attention_scores.shape)
        attention_scores = attention_scores / jnp.sqrt(self.qkv_size)
        if masking:
            mask = jnp.tril(jnp.ones(shape=(T, T))) == 1
            mask = jnp.expand_dims(mask, axis=1) # we add an extra dimension at axis 1 for broadcasting
            ic(mask.shape)
            attention_scores = jnp.where(mask, attention_scores, float("-inf"))
            # print(f"{attention_scores}")
        
        attention_scores = jax.nn.softmax(attention_scores, axis=-1)
        ic(attention_scores.shape)
        matmul_vmap = jax.vmap(
            lambda s, v: jnp.dot(s, v), 
            in_axes=(1, 1), 
            out_axes=1
        )

        output = matmul_vmap(attention_scores, v)
        # print(f"before reshaping {output.shape=}")
        output = output.reshape(T, -1)
        # print(f"after reshaping {output.shape=}")
        output = jax.vmap(self.output)(output)
        ic(output.shape)
        return output
    
n_vocab = 128
N_HEADS = 2
N_EMBD = 4096
T = 4 # 4 tokens 
mha = MultiHeadAttention(
    input_dim=n_vocab,
    n_heads=N_HEADS,
    key=jax.random.PRNGKey(21)
)

x = jax.random.uniform(shape=(T, n_vocab), key=jax.random.PRNGKey(11))
output = mha(x, True)


ic| attention_scores.shape: (4, 2, 4)
ic| mask.shape: (4, 1, 4)
ic| attention_scores.shape: (4, 2, 4)
ic| output.shape: (4, 128)


In [10]:
T = 4
tril = jnp.tril(jnp.ones(shape=(T, T)))
print(tril)
mask = jnp.where(tril == 0, jnp.full(shape=(T, T), fill_value=float("-inf")), jnp.zeros(shape=(T,T)))
print(f"{mask}")
mask = jax.nn.softmax(mask, axis=-1)
print(f"{mask}")

[[1. 0. 0. 0.]
 [1. 1. 0. 0.]
 [1. 1. 1. 0.]
 [1. 1. 1. 1.]]
[[  0. -inf -inf -inf]
 [  0.   0. -inf -inf]
 [  0.   0.   0. -inf]
 [  0.   0.   0.   0.]]
[[1.         0.         0.         0.        ]
 [0.5        0.5        0.         0.        ]
 [0.33333334 0.33333334 0.33333334 0.        ]
 [0.25       0.25       0.25       0.25      ]]


In [11]:
T = 4
h = 2
mask = jnp.tril(jnp.ones(shape=(T, T))) == 1
mask = jnp.expand_dims(mask, axis=1)
print(f"{mask.shape=}")
logits = jax.random.uniform(shape=(T, h, T), key=jax.random.PRNGKey(0))
logits = jnp.where(mask, logits, float("-inf"))
logits = jax.nn.softmax(logits, axis=-1)
print(f"{logits.shape=}")


mask.shape=(4, 1, 4)
logits.shape=(4, 2, 4)


In [12]:
T = 4
causal_mask_offset = 0
query_indices = jnp.arange(T)[:, None]
print(f"{query_indices.shape=}")
kv_indices = jnp.arange(T)[None, :]
mask = kv_indices <= query_indices + causal_mask_offset
print(mask)
logits = jax.random.uniform(shape=(mask.shape), key=jax.random.PRNGKey(221))
print(f"{jnp.finfo(logits.dtype).min=}")
print(logits)
logits = jnp.where(mask, logits, jnp.finfo(logits.dtype).min) # for more numerical stability
print(logits)
logits = jax.nn.softmax(logits, axis=-1)
print(logits)


query_indices.shape=(4, 1)
[[ True False False False]
 [ True  True False False]
 [ True  True  True False]
 [ True  True  True  True]]
jnp.finfo(logits.dtype).min=-3.4028235e+38
[[0.49120486 0.09953129 0.8435687  0.6532923 ]
 [0.7960056  0.16815436 0.27717125 0.25922954]
 [0.77414536 0.59465444 0.42191577 0.20185745]
 [0.11166441 0.01811409 0.218642   0.5060872 ]]
[[ 4.91204858e-01 -3.40282347e+38 -3.40282347e+38 -3.40282347e+38]
 [ 7.96005607e-01  1.68154359e-01 -3.40282347e+38 -3.40282347e+38]
 [ 7.74145365e-01  5.94654441e-01  4.21915770e-01 -3.40282347e+38]
 [ 1.11664414e-01  1.81140900e-02  2.18641996e-01  5.06087184e-01]]
[[1.         0.         0.         0.        ]
 [0.6520021  0.3479979  0.         0.        ]
 [0.3938847  0.32916766 0.27694768 0.        ]
 [0.22187074 0.20205593 0.246922   0.32915136]]


In [13]:
class RMSNorm(eqx.Module):
    weight: Array
    eps: float

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = jnp.ones(dim)

    def _norm(self, x: Array):
        return x * jax.lax.rsqrt(jnp.mean(x**2, axis=-1, keepdims=True) + self.eps)

    def __call__(self, x: Array) -> Array:
        output = self._norm(x)
        return output * self.weight



In [15]:
class Transformer(eqx.Module):
    input_embedding: eqx.nn.Embedding
    masked_mha: MultiHeadAttention
    feedforward: eqx.nn.MLP
    rms_norm: RMSNorm

    output: eqx.nn.Linear
    positional_encoding: Array 

    def __init__(self, n_dims: int, n_embd: int, n_heads: int, key: PRNGKeyArray, width_size: int=32, depth: int = 2, max_token_size: int = 8) -> None:
        key, *subkeys = jax.random.split(key, 20) # let's just split 20 for now, we'll probably need them later
        self.input_embedding = eqx.nn.Embedding(n_dims, n_embd, key=subkeys[0])
        self.masked_mha = MultiHeadAttention(input_dim=n_embd, n_heads=n_heads, key=subkeys[1])

        # Equinox has a built-in MLP module
        self.feedforward = eqx.nn.MLP(in_size=n_embd, out_size=n_embd, width_size=width_size, key=subkeys[2], depth=depth)
        self.positional_encoding = get_positional_encoding(max_token_size, n_embd)

        self.rms_norm = RMSNorm(dim=n_embd)

        self.output = eqx.nn.Linear(in_features=n_embd, out_features=n_dims, key=subkeys[4], use_bias=False)

    def __call__(self, x):
        print(f"side effect")
        x = jax.vmap(self.input_embedding)(x)
        x += self.positional_encoding
        x = self.rms_norm(self.masked_mha(x, masking=True) + x) # residual connection
        x = self.rms_norm(jax.vmap(self.feedforward)(x) + x) # residual connection
        x = jax.vmap(self.output)(x)
        # x = jax.nn.softmax(x) # we don't softmax here, because we want the raw logits for our loss function 
        # but you can totally softmax here and inverse that later; 
        return x 


key = jax.random.PRNGKey(42)
INPUT_DIMS = 128
N_EMBD = 4096
N_HEADS = 4
MAX_T = 8
transformer = Transformer(n_dims=INPUT_DIMS, n_embd=N_EMBD, n_heads=N_HEADS, key=key)

x = jnp.ones(shape=(MAX_T), dtype=jnp.int32)

transformer(x).shape



ic| attention_scores.shape: (8, 4, 8)


side effect


ic| mask.shape: (8, 1, 8)
ic| attention_scores.shape: (8, 4, 8)
ic| output.shape: (8, 4096)


(8, 128)

In [77]:
from tinyshakespeareloader.hamlet import get_data


data = get_data()


train_dataloader, test_dataloader, vocabulary_size, chars, encode, decode = data["train_dataloader"], data["test_dataloader"], data["vocabulary_size"], data["chars"], data["encode"], data["decode"]
key = jax.random.PRNGKey(420)
INPUT_DIMS: int = int(vocabulary_size)
N_EMBD = 32
N_HEADS = 4
MAX_T = 8

def loss_fn(transformer: Transformer, x: Array, y: Array):
    logits = eqx.filter_vmap(transformer)(x)
    loss = optax.softmax_cross_entropy_with_integer_labels(logits, y)

    return jnp.mean(loss)

def evaluate(transformer: Transformer, test_dataloader):
    loss = 0
    jitted_loss_fn = eqx.filter_jit(loss_fn)
    for x, y in test_dataloader:
        x = jnp.array(x.numpy())
        y = jnp.array(y.numpy())
        loss += jitted_loss_fn(transformer, x, y)
    
    return loss / len(test_dataloader)

@eqx.filter_jit
def step(transformer: PyTree, opt_state: optax.OptState, optimiser: optax.GradientTransformation, x: Array, y: Array):
    loss, grads = eqx.filter_value_and_grad(loss_fn)(transformer, x, y)
    updates, opt_state = optimiser.update(grads, opt_state, transformer)
    transformer = eqx.apply_updates(transformer, updates)
    return transformer, opt_state, loss

transformer = Transformer(n_dims=INPUT_DIMS, n_embd=N_EMBD, n_heads=N_HEADS, key=key)
#start_loss = evaluate(transformer, test_dataloader)
#print(f"{start_loss=}")
optimiser = optax.adamw(learning_rate=0.001)
opt_state = optimiser.init(eqx.filter(transformer, eqx.is_inexact_array))
for i, (x, y) in enumerate(train_dataloader):
    x = jnp.array(x.numpy())
    y = jnp.array(y.numpy())
    transformer, opt_state, loss = step(transformer, opt_state, optimiser, x, y)
    if i % 100 == 0:
        eval_loss = evaluate(transformer, test_dataloader)
        print(f"{i}. {loss=}, {eval_loss=}")

print("done.")
print(f"{evaluate(transformer, test_dataloader)=}")


side effect
side effect
0. loss=Array(4.3851104, dtype=float32), eval_loss=Array(4.289815, dtype=float32)
100. loss=Array(3.234735, dtype=float32), eval_loss=Array(3.4245338, dtype=float32)
200. loss=Array(3.30385, dtype=float32), eval_loss=Array(3.2770212, dtype=float32)
300. loss=Array(2.545452, dtype=float32), eval_loss=Array(3.1839383, dtype=float32)
400. loss=Array(2.1868153, dtype=float32), eval_loss=Array(3.1490939, dtype=float32)
500. loss=Array(2.4573233, dtype=float32), eval_loss=Array(3.0842142, dtype=float32)
600. loss=Array(3.0487125, dtype=float32), eval_loss=Array(3.0219848, dtype=float32)
700. loss=Array(2.1939778, dtype=float32), eval_loss=Array(2.9301224, dtype=float32)
800. loss=Array(1.9340211, dtype=float32), eval_loss=Array(2.9110885, dtype=float32)
900. loss=Array(2.3659914, dtype=float32), eval_loss=Array(2.970211, dtype=float32)
1000. loss=Array(2.9821432, dtype=float32), eval_loss=Array(2.8746958, dtype=float32)
1100. loss=Array(2.7669654, dtype=float32), eval

KeyboardInterrupt: 