In [3]:
import jax
import jax.numpy as jnp
from flax import nnx
import optax

In [4]:
class SelfAttention(nnx.Module):
    def __init__(self, num_heads, d_model, rngs: nnx.Rngs):
        self.num_heads = num_heads
        self.d_model = d_model
        self.q_proj = nnx.Linear(d_model, d_model, rngs=rngs)
        self.k_proj = nnx.Linear(d_model, d_model, rngs=rngs)
        self.v_proj = nnx.Linear(d_model, d_model, rngs=rngs)
        self.out_proj = nnx.Linear(d_model, d_model, rngs=rngs)

    def __call__(self, x):
        batch, seq_len, d_model = x.shape
        head_dim = d_model // self.num_heads

        # Compute queries, keys, and values
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)

        # Reshape to (batch, seq_len, num_heads, head_dim) then transpose to (batch, num_heads, seq_len, head_dim)
        q = q.reshape(batch, seq_len, self.num_heads, head_dim)
        k = k.reshape(batch, seq_len, self.num_heads, head_dim)
        v = v.reshape(batch, seq_len, self.num_heads, head_dim)
        q = jnp.transpose(q, (0, 2, 1, 3))
        k = jnp.transpose(k, (0, 2, 1, 3))
        v = jnp.transpose(v, (0, 2, 1, 3))

        # Scaled dot-product attention
        scale = head_dim ** -0.5
        attn_logits = jnp.einsum('bhqd,bhkd->bhqk', q, k) * scale
        attn_weights = jax.nn.softmax(attn_logits, axis=-1)
        attn_output = jnp.einsum('bhqk,bhvd->bhqd', attn_weights, v)

        # Transpose back and reshape to (batch, seq_len, d_model)
        attn_output = jnp.transpose(attn_output, (0, 2, 1, 3))
        attn_output = attn_output.reshape(batch, seq_len, d_model)
        output = self.out_proj(attn_output)
        return output

In [5]:
class MLPBlock(nnx.Module):
    def __init__(self, d_model, d_hidden, rngs: nnx.Rngs):
        self.linear1 = nnx.Linear(d_model, d_hidden, rngs=rngs)
        self.linear2 = nnx.Linear(d_hidden, d_model, rngs=rngs)

    def __call__(self, x):
        x = nnx.gelu(self.linear1(x))
        x = self.linear2(x)
        return x

In [6]:
class TransformerBlock(nnx.Module):
    def __init__(self, d_model, num_heads, d_hidden, rngs: nnx.Rngs):
        # Use BatchNorm for simplicity (alternatively, LayerNorm can be used)
        self.ln1 = nnx.BatchNorm(d_model, rngs=rngs)
        self.attn = SelfAttention(num_heads, d_model, rngs=rngs)
        self.ln2 = nnx.BatchNorm(d_model, rngs=rngs)
        self.mlp = MLPBlock(d_model, d_hidden, rngs=rngs)

    def __call__(self, x):
        attn_out = self.attn(self.ln1(x))
        x = x + attn_out
        mlp_out = self.mlp(self.ln2(x))
        x = x + mlp_out
        return x

In [7]:
class MiniGPT(nnx.Module):
    def __init__(self, vocab_size, seq_len, d_model, num_heads, d_hidden, num_layers, rngs: nnx.Rngs):
        self.vocab_size = vocab_size
        self.seq_len = seq_len
        self.token_embed = nnx.Embed(vocab_size, d_model, rngs=rngs)
        # Initialize positional embedding as a parameter (learned)
        self.pos_embed = nnx.Param(jnp.zeros((seq_len, d_model)), sharding=None)
        # Create a list of Transformer blocks
        self.blocks = [TransformerBlock(d_model, num_heads, d_hidden, rngs=rngs) for _ in range(num_layers)]
        self.ln_f = nnx.BatchNorm(d_model, rngs=rngs)
        self.head = nnx.Linear(d_model, vocab_size, rngs=rngs)

    def __call__(self, x):
        # x shape: (batch, seq_len) containing token ids
        x = self.token_embed(x)  # (batch, seq_len, d_model)
        x = x + self.pos_embed    # add positional embedding
        for block in self.blocks:
            x = block(x)
        x = self.ln_f(x)
        logits = self.head(x)     # (batch, seq_len, vocab_size)
        return logits

In [8]:
@nnx.jit
def loss_fn(model, x, y, smoothing=0.1):
    logits = model(x)  # shape: (batch, seq_len, vocab_size)
    # Convert integer labels to one-hot vectors
    one_hot = jax.nn.one_hot(y, num_classes=logits.shape[-1])
    # Apply label smoothing: for correct class use (1 - smoothing), distribute smoothing uniformly to all classes.
    smoothed_labels = one_hot * (1 - smoothing) + smoothing / logits.shape[-1]
    # Compute log softmax of logits
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    # Compute the mean cross-entropy loss
    loss = -jnp.mean(jnp.sum(smoothed_labels * log_probs, axis=-1))
    return loss

In [9]:
@nnx.jit
def train_step(model, optimizer, x, y):
    loss, grads = nnx.value_and_grad(loss_fn)(model, x, y)
    optimizer.update(grads)  # In-place update of model parameters via shared references
    return loss

@nnx.jit
def eval_step(model, x, y):
    logits = model(x)
    loss = loss_fn(model, x, y)
    predictions = jnp.argmax(logits, axis=-1)
    accuracy = jnp.mean(predictions == y)
    return loss, accuracy

In [10]:
vocab_size = 100
seq_len = 16
batch_size = 32
num_batches = 100  # Number of training batches for this demo

# Initialize a random key for synthetic data generation.
rng = jax.random.PRNGKey(42)

In [11]:
def get_batch(rng):
    rng, subkey = jax.random.split(rng)
    x = jax.random.randint(subkey, (batch_size, seq_len), 0, vocab_size)
    rng, subkey = jax.random.split(rng)
    y = jax.random.randint(subkey, (batch_size, seq_len), 0, vocab_size)
    return rng, x, y

d_model = 32      # Embedding and hidden dimension
num_heads = 4     # Number of attention heads
d_hidden = 64     # Hidden dimension for the MLP in each Transformer block
num_layers = 2    # Number of Transformer blocks

# Create a PRNG stream for Flax NNX initialization.
rngs = nnx.Rngs(0)

# Instantiate the model
model = MiniGPT(vocab_size, seq_len, d_model, num_heads, d_hidden, num_layers, rngs=rngs)

# Create the optimizer; the optimizer shares a reference with the model for in-place updates.
optimizer = nnx.Optimizer(model, optax.adam(learning_rate=1e-3))

for i in range(num_batches):
    rng, x_batch, y_batch = get_batch(rng)
    loss = train_step(model, optimizer, x_batch, y_batch)
    if i % 10 == 0:
        print(f"Batch {i:03d}, Loss: {loss:.4f}")

num_eval_batches = 20
eval_losses = []
eval_accuracies = []

for i in range(num_eval_batches):
    rng, x_eval, y_eval = get_batch(rng)
    loss, acc = eval_step(model, x_eval, y_eval)
    eval_losses.append(loss)
    eval_accuracies.append(acc)

avg_eval_loss = jnp.mean(jnp.array(eval_losses))
avg_eval_accuracy = jnp.mean(jnp.array(eval_accuracies))
print(f"Evaluation Loss: {avg_eval_loss:.4f}, Evaluation Accuracy: {avg_eval_accuracy*100:.2f}%")

rng, x_infer, _ = get_batch(rng)
logits = model(x_infer)
print("Inference logits shape:", logits.shape)

Batch 000, Loss: 5.0854
Batch 010, Loss: 4.9538
Batch 020, Loss: 4.8708
Batch 030, Loss: 4.7964
Batch 040, Loss: 4.7713
Batch 050, Loss: 4.6911
Batch 060, Loss: 4.6878
Batch 070, Loss: 4.6713
Batch 080, Loss: 4.6481
Batch 090, Loss: 4.6484
Evaluation Loss: 4.6365, Evaluation Accuracy: 0.90%
Inference logits shape: (32, 16, 100)
