In [3]:
from dataclasses import dataclass

In [4]:
@dataclass
class ModelConfig:
  vocab_size : int = 50257
  n_head : int = 12
  n_embed : int = 768
  n_layer : int = 12
  block_size : int = 1024
  dropout_rate : float = 0.1


In [13]:
import jax
import jax.numpy as jnp

In [5]:
from flax import linen as nn
import jax.numpy as jnp

class CausalSelfAttention(nn.Module):
  config:ModelConfig

  @nn.compact

  def __call__(self, x, deterministic = True):
    assert len(x.shape)==3
    b, l, d = x.shape

    q = nn.Dense(self.config.n_embed)(x)
    k = nn.Dense(self.config.n_embed)(x)
    v = nn.Dense(self.config.n_embed)(x)

    q = jnp.reshape(q, (b, l, d//self.config.n_head, self.config.n_head))
    k = jnp.reshape(k, (b, l, d//self.config.n_head, self.config.n_head))
    v = jnp.reshape(v, (b, l, d//self.config.n_head, self.config.n_head))

    norm = jnp.sqrt(list(jnp.shape(k))[-1])

    attn = jnp.matmul(q, jnp.transpose(k, (0, 1, 3, 2))) / norm
    mask = jnp.tril(attn)

    attn = jnp.where(mask[:, :, :l, :l], attn, float("-inf"))
    probs = jax.nn.softmax(attn, axis=-1)
    y = jnp.matmul(probs, v)
    y = jnp.reshape(y, (b, l, d))
    y = nn.Dense(self.config.n_embed)(y)

    return y





In [6]:
class MLP(nn.Module):

  config : ModelConfig

  @nn.compact
  def __call__(self, x, deterministic = True):
    x=nn.Dense(self.config.n_embed*4)(x)
    x=nn.gelu(x)
    x=nn.Dropout(rate=self.config.dropout_rate)(x, deterministic=True)
    x=nn.Dense(self.config.n_embed)(x)
    x=nn.Dropout(rate=self.config.dropout_rate)(x, deterministic=True)
    return x


class Block(nn.Module):
  config : ModelConfig

  @nn.compact
  def __call__(self, x):
    x = nn.LayerNorm()(x)
    x = x + CausalSelfAttention(self.config)(x)
    x = nn.LayerNorm()(x)
    x = x + MLP(self.config)(x)
    return x

In [7]:
class GPT(nn.Module):
  config : ModelConfig

  @nn.compact

  def __call__(self, x, deterministic=True):
    B, T = x.shape
    assert T <= self.config.block_size

    pos = jnp.arange(0, T)[None]
    pos_emb = nn.Embed(self.config.block_size, self.config.n_embed)(pos)
    wte = nn.Embed(self.config.vocab_size, self.config.n_embed)
    tok_emb = wte(x)
    x = tok_emb + pos_emb

    for _ in range(self.config.n_layer):
      x = Block(self.config)(x)
    x = nn.LayerNorm()(x)
    logits = nn.Dense(config.n_embed, config.vocab_size)(x)

    return logits


  def init(self, rng):
    tokens = jnp.zeros((1, self.config.block_size), dtype=jnp.uint16)
    params = jax.jit(super().init, static_argnums=(2,))(rng, tokens, True)
    return params



In [10]:
def count_params(params):
  p=jax.tree_util.tree_map(lambda a: a.size if isinstance(a, jnp.ndarray) else 0, params)
  return jax.tree_util.tree_reduce(lambda a, b : a+b, p)

In [35]:
config = ModelConfig()
key = jax.random.PRNGKey(0)
model = GPT(config)
params = model.init(key)
# count_params(params)

In [15]:
print(model)

GPT(
    # attributes
    config = ModelConfig(vocab_size=50257, n_head=12, n_embed=768, n_layer=12, block_size=1024, dropout_rate=0.1)
)


In [16]:
# from transformers import AutoModelForCausalLM

# model = AutoModelForCausalLM.from_pretrained("gpt2")

# print(model)


In [None]:
params

{'params': {'Dense_0': {'kernel': Array([[-0.95366824,  0.43563786, -0.7954482 , -0.49190977],
          [ 0.5308177 ,  0.74109775,  0.6027838 , -0.02684463],
          [ 0.24410059,  0.44881055, -1.050442  ,  0.4932145 ]],      dtype=float32),
   'bias': Array([0., 0., 0., 0.], dtype=float32)},
  'Dense_1': {'kernel': Array([[ 0.00166371,  0.16012576],
          [ 0.09040862, -0.42028674],
          [ 0.32189375,  0.43688348],
          [-0.5580085 , -0.36031362]], dtype=float32),
   'bias': Array([0., 0.], dtype=float32)}}}

In [36]:
class DataLoader:
  def __init__(self, B, T):
    self.current_position = 0
    self.B = B
    self.T = T

    with open("/content/input.txt","r") as f:
      text = f.read()
    enc = tiktoken.get_encoding("gpt2")
    self.tokens = jnp.array(enc.encode(text))
    print(f"loaded {len(self.tokens)} tokens in the datasets" )
    print(f" 1 epoch = {len(self.tokens)//(B*T)} batches")

  def next_batch(self):
    B,T = self.B, self.T
    buf = self.tokens[self.current_position:self.current_position+B*T+1]
    x,y = jnp.reshape(buf[:-1],(B,T)), jnp.reshape(buf[1:],(B,T))
    self.current_position += B*T
    if self.current_position + B*T+1 > len(self.tokens):
      self.current_position = 0
    return x,y

In [40]:
import jax
import jax.numpy as jnp
import optax
from flax import linen as nn
from flax.training.train_state import TrainState  # <- THIS is the TrainState you need
from flax.core import FrozenDict
from typing import Tuple


def init_train_state(key, config) -> TrainState:
  model = GPT(config)
  params = model.init(key)
  optimizer = optax.adamw(3e-4, b1=0.9, b2=0.98, eps=1e-9, weight_decay=1e-1)
  train_state = TrainState.create(
        apply_fn=model.apply,
        params=params,
        tx=optimizer)
  return train_state

@jax.jit
def train_step(state: TrainState, x: jnp.ndarray, y: jnp.ndarray) -> Tuple[jnp.ndarray, TrainState]:

  def loss_fn(params: FrozenDict) -> jnp.ndarray:

      logits = state.apply_fn(params, x, False)
      loss = optax.softmax_cross_entropy_with_integer_labels(logits, y).mean()
      print(loss)
      return loss

  loss, grads = jax.value_and_grad(loss_fn, has_aux=False)(state.params)
  new_state = state.apply_gradients(grads=grads)
  return loss, new_state

In [41]:
import tiktoken
import time, math
train_steps = 50
data_loader = DataLoader(B=4, T=128)
x, y = data_loader.next_batch()
for step in range(train_steps):
  t0 = time.time()
  loss, train_state = train_step(train_state, x, y)
  t1 = time.time()
  dt = t1-t0

  tokens_processed = data_loader.B * data_loader.T
  tokens_per_sec = tokens_processed/dt

  print(f"step {step}/{train_steps} | loss : {loss:.4f} | dt {dt*1000 :.2f}ms | token/sec = {tokens_per_sec:.3f}")

loaded 338025 tokens in the datasets
 1 epoch = 660 batches
Traced<ShapedArray(float32[])>with<JVPTrace> with
  primal = Traced<ShapedArray(float32[])>with<DynamicJaxprTrace>
  tangent = Traced<ShapedArray(float32[])>with<JaxprTrace> with
    pval = (ShapedArray(float32[]), None)
    recipe = JaxprEqnRecipe(eqn_id=<object object at 0x7bf49468f1c0>, in_tracers=(Traced<ShapedArray(float32[4,128]):JaxprTrace>,), out_tracer_refs=[<weakref at 0x7bf4cd794590; to 'JaxprTracer' at 0x7bf4cd796d00>], out_avals=[ShapedArray(float32[])], primitive=pjit, params={'jaxpr': { lambda ; a:f32[4,128]. let
    b:f32[] = reduce_sum[axes=(0, 1)] a
    c:f32[] = div b 512.0
  in (c,) }, 'in_shardings': (UnspecifiedValue,), 'out_shardings': (UnspecifiedValue,), 'in_layouts': (None,), 'out_layouts': (None,), 'donated_invars': (False,), 'ctx_mesh': None, 'name': '_mean', 'keep_unused': False, 'inline': True, 'compiler_options_kvs': ()}, effects=set(), source_info=<jax._src.source_info_util.SourceInfo object at 

In [38]:
data_loader = DataLoader(B=4, T=128)
x, y = data_loader.next_batch()

loaded 338025 tokens in the datasets
 1 epoch = 660 batches


In [39]:
print(x)
print(y)

[[ 5962 22307    25   198  8421   356  5120   597  2252    11  3285   502
   2740    13   198   198  3237    25   198  5248   461    11  2740    13
    198   198  5962 22307    25   198  1639   389   477 12939  2138   284
   4656   621   284  1145   680    30   198   198  3237    25   198  4965
   5634    13 12939    13   198   198  5962 22307    25   198  5962    11
    345   760   327  1872   385  1526 28599   318  4039  4472   284   262
    661    13   198   198  3237    25   198  1135   760   470    11   356
    760   470    13   198   198  5962 22307    25   198  5756   514  1494
    683    11   290   356  1183   423 11676   379   674   898  2756    13
    198  3792   470   257 15593    30   198   198  3237    25   198  2949
    517  3375   319   470    26  1309   340   307]
 [ 1760    25  1497    11  1497     0   198   198 12211 22307    25   198
   3198  1573    11   922  4290    13   198   198  5962 22307    25   198
   1135   389 17830  3595  4290    11   262  1458  1173  1547