<a href="https://colab.research.google.com/github/MadElf1337/jax-gpt2/blob/main/jax_gpt2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from typing import Any, Optional, Tuple, Callable
from dataclasses import dataclass
from functools import partial
import jax
import jax.numpy as jnp
from jax.typing import DTypeLike as Dtype
import flax.linen as nnx
from flax.core import FrozenDict, freeze, unfreeze
from flax.traverse_util import flatten_dict, unflatten_dict
from tqdm import tqdm

In [None]:
!pip install -q transformers

In [None]:
@dataclass(frozen=True)
class GPTConfig:
    """Hyper-parameters for the GPT model defined in this notebook.

    The defaults are a 12-layer, 12-head, 768-wide model with a 1024-token
    context. ``vocab_size`` is 50304; presumably this is GPT-2's 50257 tokens
    rounded up to a multiple of 64. Confirm this before loading pretrained
    weights.

    Raises:
        ValueError: if a size is non-positive, ``n_embeds`` is not divisible
            by ``n_heads``, or ``dropout_rate`` lies outside ``[0, 1]``.
    """

    ctx_len: int = 1024            # max sequence length (rows of the position table)
    vocab_size: int = 50304        # number of token embeddings
    n_layers: int = 12             # number of transformer blocks
    n_heads: int = 12              # attention heads; must divide n_embeds
    n_embeds: int = 768            # model width C
    dropout_rate: float = 0.1      # dropout probability used throughout the model
    use_bias: bool = True          # whether Dense / LayerNorm layers carry a bias
    dtype: Optional[Dtype] = None  # computation dtype; None lets Flax infer it

    def __post_init__(self):
        # Fail fast at construction instead of deep inside the attention layer.
        for field_name in ("ctx_len", "vocab_size", "n_heads", "n_embeds"):
            value = getattr(self, field_name)
            if value <= 0:
                raise ValueError(f"{field_name} must be positive, got {value}")
        if self.n_layers < 0:
            raise ValueError(f"n_layers must be non-negative, got {self.n_layers}")
        if self.n_embeds % self.n_heads != 0:
            raise ValueError(
                f"n_embeds ({self.n_embeds}) must be divisible by n_heads ({self.n_heads})"
            )
        if not 0.0 <= self.dropout_rate <= 1.0:
            raise ValueError(f"dropout_rate must be in [0, 1], got {self.dropout_rate}")

In [None]:
class CausalSelfAttention(nnx.Module):
    """Multi-head causal self-attention written as a Flax Linen module.

    ``nnx`` in this notebook is ``flax.linen``, so the module is a dataclass.
    Its fields replace the former ``__init__`` and keep the same argument
    order.

    Attributes:
        config: Model hyper-parameters. ``n_heads``, ``dtype`` and
            ``dropout_rate`` are read from it.
        deterministic: Default dropout switch. It may instead be given at call
            time; ``merge_param`` requires exactly one of the two to be set.
        use_proj_bias: Whether the ``c_attn`` and ``c_proj`` projections have
            a bias term.
    """

    config: GPTConfig
    deterministic: Optional[bool] = None
    use_proj_bias: bool = True

    @nnx.compact
    def __call__(self, x, mask=None, deterministic=None):
        """Apply causal self-attention to ``x``.

        Args:
            x: Activations of shape ``(B, T, C)``; ``C`` must be divisible by
                ``config.n_heads``.
            mask: Mask broadcastable to ``(B, n_heads, T, T)``. Truthy entries
                mark allowed (query, key) pairs, e.g. the output of
                ``make_causal_mask(idx, dtype=bool)``, which is
                ``(B, 1, T, T)``. If None, a lower-triangular causal mask is
                built here.
            deterministic: If true, dropout is disabled. It is merged with the
                attribute of the same name.

        Returns:
            Array of shape ``(B, T, C)``.

        Raises:
            ValueError: if ``C`` is not divisible by ``config.n_heads``.
        """
        config = self.config
        B, T, C = x.shape
        if C % config.n_heads != 0:
            raise ValueError(f"width {C} is not divisible by n_heads={config.n_heads}")
        head_dim = C // config.n_heads
        deterministic = nnx.merge_param('deterministic', self.deterministic, deterministic)

        qkv = nnx.Dense(3 * C, use_bias=self.use_proj_bias, dtype=config.dtype, name='c_attn')(x)
        # (B, T, 3C) -> (B, T, 3H, D). Splitting the head axis into three
        # gives q, k, v as consecutive C-wide slices of qkv, each (B, T, H, D).
        qkv = qkv.reshape(B, T, 3 * config.n_heads, head_dim)
        q, k, v = jnp.split(qkv, 3, axis=2)

        # Scores are laid out (B, H, T_q, T_k) so a (B, 1, T, T) mask
        # broadcasts over heads. A Python-float scale keeps the array dtype.
        attn = jnp.einsum('bqhd,bkhd->bhqk', q, k) * (head_dim ** -0.5)
        if mask is None:
            mask = jnp.tril(jnp.ones((T, T), dtype=bool))
        attn = jnp.where(mask, attn, jnp.finfo(attn.dtype).min)
        # Softmax in float32 for stability, then cast back to the value dtype.
        attn = jax.nn.softmax(attn.astype(jnp.float32), axis=-1).astype(v.dtype)
        attn = nnx.Dropout(rate=config.dropout_rate)(attn, deterministic=deterministic)

        # Weighted sum of values, then merge heads back to (B, T, C).
        y = jnp.einsum('bhqk,bkhd->bqhd', attn, v).reshape(B, T, C)
        y = nnx.Dense(C, use_bias=self.use_proj_bias, dtype=config.dtype, name='c_proj')(y)
        return nnx.Dropout(rate=config.dropout_rate)(y, deterministic=deterministic)

In [None]:
class FFN(nnx.Module):
    """Position-wise feed-forward sub-layer: Dense(4C) -> GELU -> Dense(C) -> Dropout.

    ``nnx`` here is ``flax.linen``, so the module is a dataclass. Its fields
    replace the former ``__init__`` and keep the same argument order.

    Attributes:
        config: Model hyper-parameters. ``dtype``, ``use_bias`` and
            ``dropout_rate`` are read from it.
        deterministic: Default dropout switch. It may instead be given at call
            time; ``merge_param`` requires exactly one of the two to be set.
    """

    config: GPTConfig
    deterministic: Optional[bool] = None

    @nnx.compact
    def __call__(self, x, deterministic=None):
        """Apply the feed-forward network over the last axis of ``x``.

        Args:
            x: Activations whose last axis is the model width ``C``.
            deterministic: If true, dropout is disabled. It is merged with the
                attribute of the same name.

        Returns:
            Array with the same shape as ``x``.
        """
        config = self.config
        width = x.shape[-1]
        deterministic = nnx.merge_param('deterministic', self.deterministic, deterministic)

        h = nnx.Dense(4 * width, dtype=config.dtype, use_bias=config.use_bias, name='c_fc')(x)
        h = nnx.gelu(h, approximate=True)  # tanh approximation of GELU
        h = nnx.Dense(width, dtype=config.dtype, use_bias=config.use_bias, name='c_proj')(h)
        return nnx.Dropout(rate=config.dropout_rate)(h, deterministic=deterministic)

In [None]:
class Block(nnx.Module):
    """One pre-LayerNorm transformer block.

    It computes ``x + attn(ln_1(x))`` followed by ``x + ffn(ln_2(x))``.

    ``nnx`` here is ``flax.linen``, so submodules are created in ``setup()``
    instead of a custom ``__init__``. Linen does not allow overriding
    ``__init__``.

    Attributes:
        config: Model hyper-parameters shared with the sub-layers.
        deterministic: Default dropout switch. It may instead be given at call
            time; ``merge_param`` requires exactly one of the two to be set.
    """

    config: GPTConfig
    deterministic: Optional[bool] = None

    def setup(self):
        cfg = self.config
        self.ln_1 = nnx.LayerNorm(epsilon=1e-5, dtype=cfg.dtype, use_bias=cfg.use_bias)
        # Sub-layers get no default `deterministic`; the flag resolved in
        # __call__ is passed explicitly, so merge_param never sees two values.
        self.attn = CausalSelfAttention(cfg, use_proj_bias=cfg.use_bias)
        self.ln_2 = nnx.LayerNorm(epsilon=1e-5, dtype=cfg.dtype, use_bias=cfg.use_bias)
        self.ffn = FFN(cfg)

    def __call__(self, x, mask=None, deterministic=None):
        """Apply the block.

        Args:
            x: Activations of shape ``(B, T, C)``.
            mask: Attention mask forwarded to ``self.attn``.
            deterministic: If true, dropout is disabled. It is merged with the
                attribute of the same name.

        Returns:
            Array with the same shape as ``x``.
        """
        deterministic = nnx.merge_param('deterministic', self.deterministic, deterministic)
        x = x + self.attn(self.ln_1(x), mask, deterministic=deterministic)
        x = x + self.ffn(self.ln_2(x), deterministic=deterministic)
        return x

In [None]:
def make_attention_mask(query_input: jax.Array,
                        key_input: jax.Array,
                        pairwise_fn: Callable[..., Any] = jnp.multiply,
                        extra_batch_dims: int = 0,
                        dtype: Dtype = jnp.float32):
    """Combine per-position query/key values into a pairwise attention mask.

    ``query_input`` is turned into a column ``(..., q_len, 1)`` and
    ``key_input`` into a row ``(..., 1, kv_len)``. ``pairwise_fn`` is applied
    to the pair, which gives ``(..., q_len, kv_len)`` when it broadcasts like a
    ``jnp`` ufunc.

    A singleton head axis is then inserted, giving ``(..., 1, q_len, kv_len)``.
    Finally ``extra_batch_dims`` singleton axes are prepended and the result is
    cast to ``dtype``.

    Args:
        query_input: Per-query values; the last axis indexes query positions.
        key_input: Per-key values; the last axis indexes key positions.
        pairwise_fn: Binary function combining a query value with a key value.
            The default ``jnp.multiply`` ANDs 0/1 validity flags.
        extra_batch_dims: Number of leading singleton axes to add.
        dtype: Dtype of the returned mask.
    """
    as_column = jnp.expand_dims(query_input, axis=-1)  # (..., q_len, 1)
    as_row = jnp.expand_dims(key_input, axis=-2)       # (..., 1, kv_len)
    pairwise = pairwise_fn(as_column, as_row)
    with_heads = jnp.expand_dims(pairwise, axis=-3)    # singleton head axis
    leading_axes = tuple(range(extra_batch_dims))
    return jnp.expand_dims(with_heads, axis=leading_axes).astype(dtype)

def make_causal_mask(x: jax.Array, extra_batch_dims: int = 0, dtype: Dtype = jnp.float32) -> jax.Array:
    """Build a lower-triangular (causal) self-attention mask shaped after ``x``.

    Only ``x.shape`` is used, never its values. For ``x`` of shape
    ``(..., T)``, every query position ``q`` is compared with every key
    position ``k`` using ``q >= k``. A position may therefore attend to itself
    and to earlier positions only.

    The comparison goes through ``make_attention_mask``. The result therefore
    has shape ``(..., 1, T, T)``, with ``extra_batch_dims`` leading singleton
    axes added, and dtype ``dtype``.
    """
    seq_len = x.shape[-1]
    positions = jnp.broadcast_to(jnp.arange(seq_len, dtype=jnp.int32), x.shape)
    # query index >= key index  <=>  the key is not in the query's future.
    return make_attention_mask(
        positions,
        positions,
        pairwise_fn=jnp.greater_equal,
        extra_batch_dims=extra_batch_dims,
        dtype=dtype,
    )

In [None]:
class GPT(nnx.Module):
    """GPT language model written as a Flax Linen module.

    The forward pass is:
    token + position embeddings -> dropout -> ``n_layers`` ``Block``s ->
    final LayerNorm -> logits.

    The output projection reuses the token embedding matrix (weight tying),
    via ``Embed.attend``.

    Attributes:
        config: Model hyper-parameters.
    """

    config: GPTConfig

    def setup(self):
        cfg = self.config
        # In setup(), Linen names each submodule after the attribute it is
        # assigned to (list items become blocks_0, blocks_1, ...). Passing an
        # explicit name= here is rejected.
        self.wte = nnx.Embed(cfg.vocab_size, cfg.n_embeds, dtype=cfg.dtype)
        self.wpe = nnx.Embed(cfg.ctx_len, cfg.n_embeds, dtype=cfg.dtype)
        # Setup-style modules must declare submodules here, not in __call__.
        self.drop = nnx.Dropout(rate=cfg.dropout_rate)
        self.blocks = [Block(cfg) for _ in range(cfg.n_layers)]
        self.ln_f = nnx.LayerNorm(epsilon=1e-5, dtype=cfg.dtype, use_bias=cfg.use_bias)

    def __call__(self, idx, deterministic=None):
        """Compute next-token logits.

        Args:
            idx: Integer token ids of shape ``(B, T)`` with
                ``T <= config.ctx_len``.
            deterministic: If true, dropout is disabled. It must be given,
                because neither this module nor its blocks set a default.

        Returns:
            Logits of shape ``(B, T, config.vocab_size)``.

        Raises:
            ValueError: if ``T`` exceeds ``config.ctx_len``.
        """
        _, T = idx.shape
        if T > self.config.ctx_len:
            raise ValueError(f"sequence length {T} exceeds ctx_len {self.config.ctx_len}")

        pos = jnp.arange(T)[None]                      # (1, T), broadcast over batch
        attn_mask = make_causal_mask(idx, dtype=bool)  # (B, 1, T, T)

        x = self.wte(idx) + self.wpe(pos)
        x = self.drop(x, deterministic=deterministic)
        for block in self.blocks:
            x = block(x, attn_mask, deterministic=deterministic)
        x = self.ln_f(x)

        # Weight tying: x @ wte.embedding.T, giving one score per vocab entry.
        return self.wte.attend(x)
