# Code adapted from https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/JAX/tutorial2/Introduction_to_JAX.html

In [56]:
import os
import math
import numpy as np
import time
import jax
import jax.numpy as jnp
from flax import linen as nn
import optax
from flax.training import train_state
from flax.training import checkpoints
from typing import Any, Callable, Dict, Tuple
from ml_collections import ConfigDict

import torch.utils.data as data

import matplotlib.pyplot as plt
%matplotlib inline
from IPython.display import set_matplotlib_formats
set_matplotlib_formats('svg', 'pdf') # For export
from matplotlib.colors import to_rgba
import seaborn as sns
sns.set()

from tqdm.auto import tqdm

  set_matplotlib_formats('svg', 'pdf') # For export


# JAX Example Network

In [15]:
class SimpleClassifierCompact(nn.Module):
    """Small two-layer network (Dense -> tanh -> Dense) in Flax's compact style.

    The layers are created inline inside ``__call__`` instead of in a
    separate ``setup`` method.
    """

    num_hidden: int   # Width of the hidden layer
    num_outputs: int  # Width of the output layer

    @nn.compact  # Submodules are declared at their point of use
    def __call__(self, x):
        """Return the raw (un-activated) outputs of the network for ``x``."""
        hidden = nn.Dense(features=self.num_hidden)(x)
        activated = nn.tanh(hidden)
        return nn.Dense(features=self.num_outputs)(activated)

In [19]:
# Root PRNG key of the XOR example.
rng = jax.random.PRNGKey(42)
model = SimpleClassifierCompact(num_hidden=8, num_outputs=1)
# Derive separate keys for the dummy input and the parameter initialisation,
# and keep a fresh root key in `rng`.
rng, inp_rng, init_rng = jax.random.split(rng, 3)
# Dummy batch of 8 two-dimensional inputs, passed to `init` below.
inp = jax.random.normal(inp_rng, (8, 2))
params = model.init(init_rng, inp)
print(params)

{'params': {'Dense_0': {'kernel': Array([[-0.8734889 ,  0.03292416,  0.45095628,  0.9860286 ,  0.9650168 ,
        -0.50356966, -0.567441  , -0.32092765],
       [ 0.6106076 , -0.8035141 , -0.8497237 , -1.0364467 ,  0.11642699,
        -0.37274948, -0.06301995,  0.23880544]], dtype=float32), 'bias': Array([0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)}, 'Dense_1': {'kernel': Array([[-0.08973367],
       [-0.15572299],
       [ 0.12597609],
       [-0.02248076],
       [ 0.48822802],
       [ 0.19107282],
       [-0.32372728],
       [-0.04857434]], dtype=float32), 'bias': Array([0.], dtype=float32)}}}


In [21]:
class XORDataset(data.Dataset):
    """Synthetic two-dimensional XOR classification dataset with Gaussian noise."""

    def __init__(self, size, seed, std=0.1):
        """
        Inputs:
            size - How many samples to draw
            seed - Seed of the NumPy random state used to draw the samples
            std - Standard deviation of the Gaussian noise added to each coordinate
        """
        super().__init__()
        self.size = size
        self.np_rng = np.random.RandomState(seed=seed)
        self.std = std
        self.generate_continuous_xor()

    def generate_continuous_xor(self):
        """Draw binary points, label them with XOR and perturb them with noise."""
        corners = self.np_rng.randint(low=0, high=2, size=(self.size, 2))
        corners = corners.astype(np.float32)
        # Exactly one active coordinate -> label 1; equal coordinates -> label 0.
        xor_label = (corners.sum(axis=1) == 1).astype(np.int32)
        # In-place addition keeps the samples in float32.
        noise = self.np_rng.normal(loc=0.0, scale=self.std, size=corners.shape)
        corners += noise

        self.data, self.label = corners, xor_label

    def __len__(self):
        # Number of samples; equals self.data.shape[0] and self.label.shape[0].
        return self.size

    def __getitem__(self, idx):
        # A sample is the pair (data point, label) at position idx.
        return self.data[idx], self.label[idx]

In [25]:
def numpy_collate(batch):
    """Collate a list of samples into NumPy arrays.

    Arrays are stacked along a new leading axis, tuples/lists are collated
    field by field (recursively), and anything else goes through ``np.array``.
    """
    first = batch[0]
    if isinstance(first, np.ndarray):
        return np.stack(batch)
    if not isinstance(first, (tuple, list)):
        return np.array(batch)
    # Regroup [(a0, b0), (a1, b1), ...] into (a0, a1, ...), (b0, b1, ...).
    return [numpy_collate(field) for field in zip(*batch)]

In [27]:
# Importantly, the dataset holds NumPy arrays and the loader collates with
# `numpy_collate`, so every batch is a list of NumPy arrays.
train_dataset = XORDataset(size=2500, seed=42)
train_data_loader = data.DataLoader(train_dataset, batch_size=128, shuffle=True, collate_fn=numpy_collate)

In [29]:
# Plain SGD; the train state bundles the apply function, the parameters and the optimizer.
optimizer = optax.sgd(learning_rate=0.1)
model_state = train_state.TrainState.create(apply_fn=model.apply,
                                            params=params,
                                            tx=optimizer)

In [31]:
def calculate_loss_acc(state, params, batch):
    """Return the mean binary cross-entropy loss and the accuracy on one batch.

    ``params`` is a separate argument (rather than read from ``state``) so that
    callers can differentiate with respect to it.
    """
    inputs, targets = batch
    # One logit per sample: drop the trailing output axis.
    logits = state.apply_fn(params, inputs).squeeze(axis=-1)
    loss = optax.sigmoid_binary_cross_entropy(logits, targets).mean()
    # A positive logit means a sigmoid probability above 0.5, i.e. class 1.
    predictions = (logits > 0).astype(jnp.float32)
    acc = (predictions == targets).mean()
    return loss, acc

In [33]:
@jax.jit  # Compile the whole update step
def train_step(state, batch):
    """Run one optimisation step and return ``(new_state, loss, acc)``."""
    # Differentiate w.r.t. the second argument (the parameters); the accuracy
    # is passed through as auxiliary output.
    loss_and_grad = jax.value_and_grad(calculate_loss_acc, argnums=1, has_aux=True)
    (loss, acc), grads = loss_and_grad(state, state.params, batch)
    return state.apply_gradients(grads=grads), loss, acc

In [35]:
@jax.jit  # Compile the evaluation step
def eval_step(state, batch):
    """Return the accuracy of ``state`` on ``batch``; the loss is discarded."""
    return calculate_loss_acc(state, state.params, batch)[1]

In [37]:
def train_model(state, data_loader, num_epochs=100):
    """Train ``state`` for ``num_epochs`` passes over ``data_loader`` and return it."""
    for _ in tqdm(range(num_epochs)):
        for batch in data_loader:
            # Loss and accuracy are available here for logging (e.g. TensorBoard)
            # but are currently not used.
            state, _loss, _acc = train_step(state, batch)
    return state

In [39]:
# Train the XOR classifier for 100 epochs on the training loader.
trained_model_state = train_model(model_state, train_data_loader, num_epochs=100)

In [41]:
# Persist the trained state under 'my_checkpoints/' with the file prefix 'my_model'.
checkpoints.save_checkpoint(ckpt_dir='my_checkpoints/',  # Folder to save checkpoint in
                            target=trained_model_state,  # What to save. To only save parameters, use model_state.params
                            step=100,  # Training step or other metric to save best model on
                            prefix='my_model',  # Checkpoint file name prefix
                            overwrite=True   # Overwrite existing checkpoint files
                           )

# Read the checkpoint back; the untrained `model_state` serves as the template
# that the restored state is rebuilt into.
loaded_model_state = checkpoints.restore_checkpoint(
                                             ckpt_dir='my_checkpoints/',   # Folder with the checkpoints
                                             target=model_state,   # (optional) matching object to rebuild state in
                                             prefix='my_model'  # Checkpoint file name prefix
                                            )

In [43]:
# Held-out data drawn with a different seed than the training set.
test_dataset = XORDataset(size=500, seed=123)
# drop_last=False -> keep the final batch even though it is smaller than 128
test_data_loader = data.DataLoader(test_dataset,
                                   batch_size=128,
                                   shuffle=False,
                                   drop_last=False,
                                   collate_fn=numpy_collate)

In [45]:
def eval_model(state, data_loader):
    """Print the accuracy of ``state`` over all batches of ``data_loader``."""
    weighted_acc_sum, num_samples = 0, 0
    for batch in data_loader:
        current_size = batch[0].shape[0]
        # Weight every batch accuracy by its size, since batches may differ in size.
        weighted_acc_sum = weighted_acc_sum + eval_step(state, batch) * current_size
        num_samples = num_samples + current_size
    acc = weighted_acc_sum / num_samples
    print(f"Accuracy of the model: {100.0*acc:4.2f}%")

In [47]:
# Report the accuracy of the trained model on the held-out XOR data.
eval_model(trained_model_state, test_data_loader)
# trained_model = model.bind(trained_model_state.params)

In [49]:
# For dropout random state
# See https://flax-linen.readthedocs.io/en/latest/guides/training_techniques/dropout.html

# For batch norm running stats
# See https://flax-linen.readthedocs.io/en/latest/guides/training_techniques/batch_norm.html

# Memory Reduction Techniques

In [89]:
import os

# XLA GPU options, exported as a single string in which every flag is
# followed by one space.
os.environ["XLA_FLAGS"] = "".join(
    flag + " "
    for flag in (
        "--xla_gpu_enable_triton_softmax_fusion=true",
        "--xla_gpu_triton_gemm_any=false",
        "--xla_gpu_enable_async_collectives=true",
        "--xla_gpu_enable_latency_hiding_scheduler=true",
        "--xla_gpu_enable_highest_priority_async_stream=true",
    )
)

## Mixed Precision

In [58]:
class MLPClassifier(nn.Module):
    """One-hidden-layer classifier whose computation dtype is configurable.

    ``dtype`` is forwarded to the Dense and LayerNorm layers; the final
    log-softmax is always evaluated in float32.
    """

    dtype: Any
    hidden_size: int = 256
    num_classes: int = 100
    dropout_rate: float = 0.1

    @nn.compact
    def __call__(self, x: jax.Array, train: bool) -> jax.Array:
        """Return float32 log-probabilities; dropout is only active if ``train``."""
        # Computation in the specified dtype, params stay in float32.
        hidden = nn.Dense(features=self.hidden_size, dtype=self.dtype)(x)
        hidden = nn.LayerNorm(dtype=self.dtype)(hidden)
        hidden = nn.silu(hidden)
        hidden = nn.Dropout(rate=self.dropout_rate, deterministic=not train)(hidden)
        logits = nn.Dense(features=self.num_classes, dtype=self.dtype)(hidden)
        # Up-cast before the log-softmax.
        return nn.log_softmax(logits.astype(jnp.float32), axis=-1)

In [62]:
# Dummy input batch and separate PRNG keys for parameter init and dropout.
x = jnp.ones((512, 128), dtype=jnp.float32)
rngs = {"params": jax.random.PRNGKey(0), "dropout": jax.random.PRNGKey(1)}
# Tabulate the float32 variant of the classifier.
model_float32 = MLPClassifier(dtype=jnp.float32)
model_float32.tabulate(rngs, x, train=True, console_kwargs={"force_jupyter": True});

In [64]:
# Same tabulation with bfloat16 as the computation dtype, for comparison.
model_bfloat16 = MLPClassifier(dtype=jnp.bfloat16)
model_bfloat16.tabulate(rngs, x, train=True, console_kwargs={"force_jupyter": True})

'\n\n'

## Gradient Checkpointing / Activation Recomputation

In [68]:
def gelu(x: jax.Array) -> jax.Array:
    """Tanh-based GELU that prints a message each time it is executed."""
    # The message makes re-executions (e.g. under remat) visible in the output.
    jax.debug.print("Executing GeLU")
    cubic_term = 0.044715 * jnp.power(x, 3)
    inner = np.sqrt(2 / np.pi) * (x + cubic_term)
    return 0.5 * x * (1 + jnp.tanh(inner))

In [70]:
def loss_fn(x: jax.Array, remat: bool) -> jax.Array:
    """Mean GELU activation of ``x``; with ``remat`` the activation is wrapped in ``jax.remat``."""
    activation = jax.remat(gelu) if remat else gelu
    return jnp.mean(activation(x))

In [72]:
# Random 100-element input for the remat demonstration.
x = jax.random.normal(jax.random.PRNGKey(0), (100,))
grad_fn = jax.grad(loss_fn)
# remat controls which tensors are stored and which are recomputed during the
# backward pass; the captured output shows the GeLU message twice for this call.
_ = grad_fn(x, remat=True)

Executing GeLU
Executing GeLU


In [74]:
# Plain forward pass without remat: the GeLU message is printed once.
_ = loss_fn(x, remat=False)

Executing GeLU


## Gradient Accumulation

In [97]:
import functools
from pprint import pprint
from typing import Any, Callable, Dict, Tuple

import flax.linen as nn
import jax
import jax.numpy as jnp
import numpy as np
import optax
from flax.struct import dataclass
from flax.training import train_state

# Type aliases
# Arbitrary (possibly nested) container of arrays, e.g. parameters or gradients.
PyTree = Any
# Metric name -> tuple of arrays; the loss functions in this file store
# (summed value, number of elements) per metric.
Metrics = Dict[str, Tuple[jax.Array, ...]]

In [99]:
class TrainState(train_state.TrainState):
    # PRNG key carried along with the training state; the training step splits
    # it to obtain a fresh per-step key and stores the remainder back.
    rng: jax.Array

In [101]:
@dataclass
class Batch:
    # Model inputs; the leading axis is sliced into minibatches during gradient accumulation.
    inputs: jax.Array
    # Integer targets, compared against the argmax of the logits.
    labels: jax.Array

In [103]:
def classification_loss_fn(
    params: PyTree, apply_fn: Any, batch: Batch, rng: jax.Array
) -> Tuple[PyTree, Metrics]:
    """Cross-entropy classification loss.

    Returns the mean loss together with a metrics dict whose entries are
    ``(sum over the batch, batch size)`` tuples for the loss and the accuracy.
    """
    logits = apply_fn({"params": params}, batch.inputs, train=True, rngs={"dropout": rng})
    per_example_loss = optax.softmax_cross_entropy_with_integer_labels(logits, batch.labels)
    is_correct = jnp.equal(jnp.argmax(logits, axis=-1), batch.labels)
    num_examples = batch.inputs.shape[0]
    step_metrics = {
        "loss": (per_example_loss.sum(), num_examples),
        "accuracy": (is_correct.sum(), num_examples),
    }
    return per_example_loss.mean(), step_metrics

In [105]:
def accumulate_gradients_loop(
    state: TrainState,
    batch: Batch,
    rng: jax.Array,
    num_minibatches: int,
    loss_fn: Callable,
) -> Tuple[PyTree, Metrics]:
    """Calculate gradients and metrics for a batch using gradient accumulation.

    The minibatches are processed in a Python loop, so under ``jax.jit`` the
    loop is unrolled during tracing.

    Args:
        state: Current training state. Only ``params`` and ``apply_fn`` are read.
        batch: Full training batch. Every leaf is sliced along its leading axis.
        rng: PRNG key; it is split into one key per minibatch.
        num_minibatches: Number of minibatches to split the batch into. Equal to the number of gradient accumulation steps.
        loss_fn: Function ``(params, apply_fn, minibatch, rng) -> (loss, metrics)``.

    Returns:
        Tuple with the gradients averaged over the minibatches and the metrics
        summed over the minibatches.

    Note:
        The minibatch size is ``batch_size // num_minibatches``; if the batch
        size is not divisible, the trailing remainder of the batch is not used.
    """
    batch_size = batch.inputs.shape[0]
    minibatch_size = batch_size // num_minibatches
    rngs = jax.random.split(rng, num_minibatches)
    # Define gradient function for single minibatch.
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    # Prepare loop variables.
    grads = None
    metrics = None
    for minibatch_idx in range(num_minibatches):
        with jax.named_scope(f"minibatch_{minibatch_idx}"):
            # Split the batch into minibatches.
            start = minibatch_idx * minibatch_size
            end = start + minibatch_size
            # `jax.tree_util.tree_map` replaces the deprecated `jax.tree_map` alias.
            minibatch = jax.tree_util.tree_map(lambda x: x[start:end], batch)
            # Calculate gradients and metrics for the minibatch.
            (_, step_metrics), step_grads = grad_fn(
                state.params, state.apply_fn, minibatch, rngs[minibatch_idx]
            )
            # Accumulate gradients and metrics across minibatches.
            if grads is None:
                grads = step_grads
                metrics = step_metrics
            else:
                grads = jax.tree_util.tree_map(jnp.add, grads, step_grads)
                metrics = jax.tree_util.tree_map(jnp.add, metrics, step_metrics)
    # Average gradients over minibatches.
    grads = jax.tree_util.tree_map(lambda g: g / num_minibatches, grads)
    return grads, metrics

In [107]:
def accumulate_gradients_scan(
    state: TrainState,
    batch: Batch,
    rng: jax.Array,
    num_minibatches: int,
    loss_fn: Callable,
) -> Tuple[PyTree, Metrics]:
    """Calculate gradients and metrics for a batch using gradient accumulation.

    In this version, we use `jax.lax.scan` to loop over the minibatches. This is more efficient in terms of compilation time.

    Args:
        state: Current training state. Only ``params`` and ``apply_fn`` are read.
        batch: Full training batch. Every leaf is sliced along its leading axis.
        rng: PRNG key; it is split into one key per minibatch.
        num_minibatches: Number of minibatches to split the batch into. Equal to the number of gradient accumulation steps.
        loss_fn: Function ``(params, apply_fn, minibatch, rng) -> (loss, metrics)``.

    Returns:
        Tuple with the gradients averaged over the minibatches and the metrics
        summed over the minibatches.

    Note:
        The minibatch size is ``batch_size // num_minibatches``; if the batch
        size is not divisible, the trailing remainder of the batch is not used.
    """
    batch_size = batch.inputs.shape[0]
    minibatch_size = batch_size // num_minibatches
    rngs = jax.random.split(rng, num_minibatches)
    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)

    def _minibatch_step(minibatch_idx: jax.Array | int) -> Tuple[PyTree, Metrics]:
        """Determine gradients and metrics for a single minibatch."""
        # `jax.tree_util.tree_map` replaces the deprecated `jax.tree_map` alias.
        minibatch = jax.tree_util.tree_map(
            lambda x: jax.lax.dynamic_slice_in_dim(  # Slicing with variable index (jax.Array).
                x, start_index=minibatch_idx * minibatch_size, slice_size=minibatch_size, axis=0
            ),
            batch,
        )
        (_, step_metrics), step_grads = grad_fn(
            state.params, state.apply_fn, minibatch, rngs[minibatch_idx]
        )
        return step_grads, step_metrics

    def _scan_step(
        carry: Tuple[PyTree, Metrics], minibatch_idx: jax.Array | int
    ) -> Tuple[Tuple[PyTree, Metrics], None]:
        """Scan step function for looping over minibatches."""
        step_grads, step_metrics = _minibatch_step(minibatch_idx)
        carry = jax.tree_util.tree_map(jnp.add, carry, (step_grads, step_metrics))
        return carry, None

    # Determine initial shapes for gradients and metrics, so that the scan
    # carry can start from zeros of matching shape and dtype.
    grads_shapes, metrics_shape = jax.eval_shape(_minibatch_step, 0)
    grads = jax.tree_util.tree_map(lambda x: jnp.zeros(x.shape, x.dtype), grads_shapes)
    metrics = jax.tree_util.tree_map(lambda x: jnp.zeros(x.shape, x.dtype), metrics_shape)
    # Loop over minibatches to determine gradients and metrics.
    (grads, metrics), _ = jax.lax.scan(
        _scan_step, init=(grads, metrics), xs=jnp.arange(num_minibatches), length=num_minibatches
    )
    # Average gradients over minibatches.
    grads = jax.tree_util.tree_map(lambda g: g / num_minibatches, grads)
    return grads, metrics

In [109]:
def accumulate_gradients(*args, use_scan: bool = False, **kwargs) -> Tuple[PyTree, Metrics]:
    """Dispatch to the scan-based or the loop-based gradient accumulation.

    All other arguments are forwarded unchanged to the selected implementation.
    """
    implementation = accumulate_gradients_scan if use_scan else accumulate_gradients_loop
    return implementation(*args, **kwargs)

In [111]:
def train_step(
    state: TrainState,
    metrics: Metrics | None,
    batch: Batch,
    num_minibatches: int,
) -> Tuple[TrainState, Metrics]:
    """Training step function.

    Executes a full training step with gradient accumulation.

    Args:
        state: Current training state.
        metrics: Current metrics, accumulated from previous training steps. If None, the metrics of this step are returned as they are.
        batch: Training batch.
        num_minibatches: Number of minibatches to split the batch into. Equal to the number of gradient accumulation steps.

    Returns:
        Tuple with updated training state (parameters, optimizer state, etc.) and metrics.
    """
    # Split the random number generator for the current step; `rng` is stored
    # back into the state, `step_rng` is consumed by this step.
    rng, step_rng = jax.random.split(state.rng)
    # Determine gradients and metrics for the full batch.
    grads, step_metrics = accumulate_gradients(
        state, batch, step_rng, num_minibatches, loss_fn=classification_loss_fn, use_scan=True
    )
    # Optimizer step.
    new_state = state.apply_gradients(grads=grads, rng=rng)
    # Accumulate metrics across training steps.
    if metrics is None:
        metrics = step_metrics
    else:
        # `jax.tree_util.tree_map` replaces the deprecated `jax.tree_map` alias.
        metrics = jax.tree_util.tree_map(jnp.add, metrics, step_metrics)
    return new_state, metrics

In [113]:
# Problem sizes of the random classification task.
batch_size = 512
num_inputs = 128
num_classes = 100
rng_seed = 0

# One independent key each for the inputs, the labels, the model init and the train state.
rng = jax.random.PRNGKey(rng_seed)
data_input_rng, data_label_rng, model_rng, state_rng = jax.random.split(rng, 4)
# Random inputs and random integer labels in [0, num_classes).
batch = Batch(
    inputs=jax.random.normal(data_input_rng, (batch_size, num_inputs)),
    labels=jax.random.randint(data_label_rng, (batch_size,), 0, num_classes),
)

In [115]:
# Zero dropout for checking later equality between training with and without gradient accumulation.
model = MLPClassifier(dtype=jnp.bfloat16, dropout_rate=0.0)
# Initialise with train=False; only the "params" collection is kept.
params = model.init(model_rng, batch.inputs, train=False)["params"]
# Train state with Adam; `rng` seeds the per-step keys split off in train_step.
state = TrainState.create(
    apply_fn=model.apply,
    params=params,
    tx=optax.adam(1e-3),
    rng=state_rng,
)

In [117]:
# Evaluate only the output shapes/dtypes of a training step (with metrics=None)
# to learn the structure of the metrics.
_, metric_shapes = jax.eval_shape(
    functools.partial(train_step, num_minibatches=4),
    state,
    None,
    batch,
)
print("Metric shapes:")
pprint(metric_shapes)

Metric shapes:
{'accuracy': (ShapeDtypeStruct(shape=(), dtype=int32),
              ShapeDtypeStruct(shape=(), dtype=int32)),
 'loss': (ShapeDtypeStruct(shape=(), dtype=float32),
          ShapeDtypeStruct(shape=(), dtype=int32))}


  minibatch = jax.tree_map(
  grads = jax.tree_map(lambda x: jnp.zeros(x.shape, x.dtype), grads_shapes)
  metrics = jax.tree_map(lambda x: jnp.zeros(x.shape, x.dtype), metrics_shape)
  minibatch = jax.tree_map(
  carry = jax.tree_map(jnp.add, carry, (step_grads, step_metrics))
  grads = jax.tree_map(lambda g: g / num_minibatches, grads)


In [119]:
# Jitted training step; `num_minibatches` is a static argument because it
# determines Python-level sizes (number of rng splits, minibatch slice size).
train_step_jit = jax.jit(
    train_step,
    static_argnames="num_minibatches",
)

In [121]:
def train_with_minibatches(
    state: TrainState,
    batch: Batch,
    num_minibatches: int,
    num_train_steps: int,
) -> Tuple[TrainState, Metrics]:
    """Small helper function for training loop.

    Runs ``num_train_steps`` jitted training steps on the same ``batch``,
    starting from zero-initialised metrics shaped like the module-level
    ``metric_shapes``.

    Args:
        state: Initial training state.
        batch: Batch that is reused in every step.
        num_minibatches: Number of gradient accumulation steps per training step.
        num_train_steps: Number of training steps to run.

    Returns:
        Tuple with the final training state and the accumulated metrics.
    """
    # `jax.tree_util.tree_map` replaces the deprecated `jax.tree_map` alias.
    train_metrics = jax.tree_util.tree_map(
        lambda x: jnp.zeros(x.shape, dtype=x.dtype), metric_shapes
    )
    for _ in range(num_train_steps):
        state, train_metrics = train_step_jit(state, train_metrics, batch, num_minibatches)
    return state, train_metrics

## JAX Donation

In [124]:
# Jitted training step that additionally donates the `state` and `metrics`
# arguments. NOTE(review): donated inputs should be treated as consumed by the
# call — confirm callers always rebind them to the returned values.
train_step_donated = jax.jit(
    train_step,
    static_argnames="num_minibatches",
    donate_argnames=(
        "state",
        "metrics",
    ),
)

## Putting it Together - Transformer Model

In [129]:
import os
import urllib.request
from urllib.error import HTTPError

# Base location of the helper scripts in the tutorial's GitHub repository.
base_url = "https://raw.githubusercontent.com/phlippe/uvadlc_notebooks/master/docs/tutorial_notebooks/scaling/JAX/"
# Helper scripts this notebook depends on.
python_files = ["single_gpu.py", "utils.py"]
# Fetch every script that is not yet present in the working directory.
for file_name in python_files:
    if os.path.isfile(file_name):
        continue
    file_url = base_url + file_name
    print(f"Downloading {file_url}...")
    try:
        urllib.request.urlretrieve(file_url, file_name)
    except HTTPError as e:
        # Report the failure and carry on with the next file.
        print(
            "Something went wrong. Please try to download the file directly from the GitHub repository, or contact the author with the full output including the following error:\n",
            e,
        )

Downloading https://raw.githubusercontent.com/phlippe/uvadlc_notebooks/master/docs/tutorial_notebooks/scaling/JAX/single_gpu.py...
Downloading https://raw.githubusercontent.com/phlippe/uvadlc_notebooks/master/docs/tutorial_notebooks/scaling/JAX/utils.py...


In [166]:
from utils import set_XLA_flags_gpu
set_XLA_flags_gpu()
from single_gpu import Batch, TrainState, accumulate_gradients, print_metrics # Functions from above

In general, 
- `self.config.dtype` allows us to use mixed precision training
    - We keep softmax calculations in float32
- `self.config.remat` contains the blocks we want to explicitly rematerialize (remat)
- `scan` prevents recompilation of repeated blocks
- `accumulate_gradients` defined earlier allows us to accumulate gradients

In [137]:
class MLPBlock(nn.Module):
    """Feed-forward block: LayerNorm -> Dense (expand) -> GELU -> Dense (project) -> Dropout."""

    config: ConfigDict
    train: bool

    @nn.compact
    def __call__(self, x: jax.Array) -> jax.Array:
        """Apply the block to ``x``; the output keeps the size of the last axis."""
        cfg = self.config
        model_dim = x.shape[-1]
        hidden = nn.LayerNorm(dtype=cfg.dtype, name="pre_norm")(x)
        # Expand the feature axis by `mlp_expansion`, then project back.
        hidden = nn.Dense(
            features=cfg.mlp_expansion * model_dim, dtype=cfg.dtype, name="input_layer"
        )(hidden)
        hidden = nn.gelu(hidden)
        hidden = nn.Dense(features=model_dim, dtype=cfg.dtype, name="output_layer")(hidden)
        # Dropout is only active in training mode.
        return nn.Dropout(rate=cfg.dropout_rate, deterministic=not self.train)(hidden)

In [139]:
def dot_product_attention(
    query: jax.Array,
    key: jax.Array,
    value: jax.Array,
    mask: jax.Array | None,
    softmax_dtype: jnp.dtype = jnp.float32,
):
    """Scaled dot-product attention with a configurable softmax precision.

    Mirrors https://flax.readthedocs.io/en/latest/api_reference/flax.linen/layers.html#flax.linen.dot_product_attention,
    but evaluates the dot product and the softmax in ``softmax_dtype`` for numerical stability.

    Args:
        query: The query array, shape [..., num queries, num heads, hidden size].
        key: The key array, shape [..., num keys, num heads, hidden size].
        value: The value array, shape [..., num keys, num heads, hidden size].
        mask: The boolean mask array (0 for masked values, 1 for non-masked). If None, no masking is applied.
        softmax_dtype: The dtype to use for the softmax and dot-product operation.

    Returns:
        The attention output array, shape [..., num queries, num heads, hidden size],
        in the dtype of ``query``.
    """
    out_dtype = query.dtype
    head_dim = query.shape[-1]
    # Scale the queries while they are still in the input dtype.
    query = query * head_dim**-0.5
    # Up-cast right before the dot product for numerical stability.
    logits = jnp.einsum(
        "...qhd,...khd->...hqk", query.astype(softmax_dtype), key.astype(softmax_dtype)
    )
    if mask is not None:
        # Masked positions receive the most negative finite value of the softmax dtype.
        logits = jnp.where(mask, logits, jnp.finfo(softmax_dtype).min)
    # After the softmax, return to the original dtype.
    attention = nn.softmax(logits, axis=-1).astype(out_dtype)
    return jnp.einsum("...hqk,...khd->...qhd", attention, value).astype(out_dtype)

In [141]:
class AttentionBlock(nn.Module):
    """Multi-head self-attention block: LayerNorm -> fused QKV -> attention -> projection -> Dropout."""

    config: ConfigDict
    mask: jax.Array | None
    train: bool

    @nn.compact
    def __call__(self, x: jax.Array) -> jax.Array:
        """Apply self-attention to ``x``; the output keeps the size of the last axis."""
        cfg = self.config
        model_dim = x.shape[-1]
        normed = nn.LayerNorm(dtype=cfg.dtype, name="pre_norm")(x)
        # A single projection yields queries, keys and values for every head.
        fused_qkv = nn.DenseGeneral(
            features=(cfg.num_heads, cfg.head_dim * 3),
            dtype=cfg.dtype,
            name="qkv",
        )(normed)
        queries, keys, values = jnp.split(fused_qkv, 3, axis=-1)
        attended = dot_product_attention(
            queries, keys, values, mask=self.mask, softmax_dtype=cfg.softmax_dtype
        )
        # Merge the (heads, head_dim) axes back into the model dimension.
        projected = nn.DenseGeneral(
            features=model_dim,
            axis=(-2, -1),
            dtype=cfg.dtype,
            name="output_layer",
        )(attended)
        return nn.Dropout(rate=cfg.dropout_rate, deterministic=not self.train)(projected)

In [143]:
class TransformerBlock(nn.Module):
    """Residual MLP block followed by a residual attention block.

    Each sub-block is wrapped in ``nn.remat`` when its tag ("MLP" / "Attn")
    is listed in ``config.remat``.
    """

    config: ConfigDict
    mask: jax.Array | None
    train: bool

    @nn.compact
    def __call__(self, x: jax.Array) -> jax.Array:
        """Apply both residual sub-blocks to ``x``."""
        remat_tags = self.config.remat
        # MLP block
        mlp_cls = nn.remat(MLPBlock, prevent_cse=False) if "MLP" in remat_tags else MLPBlock
        x = x + mlp_cls(config=self.config, train=self.train, name="mlp")(x)
        # Attention block
        attn_cls = (
            nn.remat(AttentionBlock, prevent_cse=False) if "Attn" in remat_tags else AttentionBlock
        )
        x = x + attn_cls(config=self.config, mask=self.mask, train=self.train, name="attn")(x)
        return x

In [145]:
class Transformer(nn.Module):
    """Transformer: token + positional embedding, a stack of TransformerBlocks, final norm and output layer.

    The stack is either applied through ``nn.scan`` (``config.scan_layers``) or
    as an explicit Python loop over ``config.num_layers`` separately named blocks.
    """

    config: ConfigDict

    @nn.compact
    def __call__(
        self, x: jax.Array, mask: jax.Array | None = None, train: bool = True
    ) -> jax.Array:
        """Map token ids ``x`` to float32 output logits.

        Args:
            x: Integer token ids; indexed as [batch, sequence] below.
            mask: Optional attention mask. If None and ``config.causal_mask`` is set, a causal mask is created.
            train: Forwarded to the blocks; enables dropout.

        Returns:
            Output of the final Dense layer (``config.num_outputs`` features), cast to float32.
        """
        if mask is None and self.config.causal_mask:
            mask = nn.make_causal_mask(x, dtype=jnp.bool_)
        # Input layer.
        x = nn.Embed(
            num_embeddings=self.config.vocab_size,
            features=self.config.hidden_size,
            dtype=self.config.dtype,
            name="embed",
        )(x)
        # Learned positional embedding with one row per position up to max_seq_len.
        pos_emb = self.param(
            "pos_emb",
            nn.initializers.normal(stddev=0.02),
            (self.config.max_seq_len, self.config.hidden_size),
        )
        pos_emb = pos_emb.astype(self.config.dtype)
        # Only the first x.shape[1] positions are used; broadcast over the batch axis.
        x = x + pos_emb[None, : x.shape[1]]
        # Transformer blocks.
        block_fn = functools.partial(TransformerBlock, config=self.config, mask=mask, train=train)
        if "Block" in self.config.remat:
            # NOTE(review): nn.remat is applied to a functools.partial here, not
            # to the Module class itself — confirm this is supported.
            block_fn = nn.remat(block_fn, prevent_cse=False)
        if self.config.scan_layers:
            # A single block definition is applied `num_layers` times via nn.scan;
            # the carry is the activation and no per-step input/output is used.
            block = block_fn(name="block")
            x, _ = nn.scan(
                lambda module, carry, _: (module(carry), None),
                variable_axes={"params": 0},
                split_rngs={"params": True, "dropout": True},
                length=self.config.num_layers,
            )(block, x, ())
        else:
            # Explicit loop with one separately named block per layer.
            for l_idx in range(self.config.num_layers):
                x = block_fn(name=f"block_{l_idx}")(x)
        # Output layer.
        x = nn.LayerNorm(dtype=self.config.dtype, name="post_norm")(x)
        x = nn.Dense(
            features=self.config.num_outputs,
            dtype=self.config.dtype,
            name="output_layer",
        )(x)
        # Return the logits in float32 regardless of the computation dtype.
        x = x.astype(jnp.float32)
        return x

In [147]:
# Shape of the synthetic token data.
data_config = ConfigDict(
    {
        "batch_size": 64,
        "seq_len": 512,
        "vocab_size": 2048,
    }
)
# Transformer hyperparameters; sequence length and vocabulary follow the data config.
model_config = ConfigDict(
    {
        "hidden_size": 1024,
        "dropout_rate": 0.1,
        "mlp_expansion": 4,
        "num_layers": 12,
        "head_dim": 128,
        "causal_mask": True,
        "max_seq_len": data_config.seq_len,
        "vocab_size": data_config.vocab_size,
        "num_outputs": data_config.vocab_size,
        "dtype": jnp.bfloat16,
        "softmax_dtype": jnp.float32,
        "scan_layers": True,
        "remat": ("MLP", "Attn"),
    }
)
# The number of heads is derived so that num_heads * head_dim == hidden_size.
model_config.num_heads = model_config.hidden_size // model_config.head_dim
optimizer_config = ConfigDict(
    {
        "learning_rate": 4e-4,
        "num_minibatches": 4,
    }
)
# Top-level config bundling all of the above plus the seed.
config = ConfigDict(
    {
        "model": model_config,
        "optimizer": optimizer_config,
        "data": data_config,
        "seed": 42,
    }
)

In [149]:
model = Transformer(config=config.model)
# Adam with a learning-rate schedule: warm-up from 0 to the configured peak
# over 10 steps, followed by exponential decay with rate 0.99.
optimizer = optax.adam(
    learning_rate=optax.warmup_exponential_decay_schedule(
        init_value=0,
        peak_value=config.optimizer.learning_rate,
        warmup_steps=10,
        transition_steps=1,
        decay_rate=0.99,
    )
)

In [151]:
# Random token ids in [1, vocab_size); id 0 is kept free for the padding below.
tokens = jax.random.randint(
    jax.random.PRNGKey(0),
    (config.data.batch_size, config.data.seq_len),
    1,
    config.data.vocab_size,
)
# Next-token prediction: the inputs are the tokens shifted right by one
# position (a 0 is prepended, the last token is dropped); the labels are the tokens.
batch_transformer = Batch(
    inputs=jnp.pad(tokens[:, :-1], ((0, 0), (1, 0)), constant_values=0),
    labels=tokens,
)

In [152]:
model_rng, state_rng = jax.random.split(jax.random.PRNGKey(config.seed))
# Initialise on a minibatch-sized slice of the inputs (batch_size // num_minibatches
# rows) with train=False; only the "params" collection is kept.
params = model.init(
    model_rng,
    batch_transformer.inputs[: config.data.batch_size // config.optimizer.num_minibatches],
    train=False,
)["params"]
state = TrainState.create(
    apply_fn=model.apply,
    params=params,
    tx=optimizer,
    rng=state_rng,
)

In [154]:
def next_token_pred_loss(
    params: PyTree, apply_fn: Any, batch: Batch, rng: jax.Array
) -> Tuple[PyTree, Metrics]:
    """Next token prediction loss function.

    Returns the mean cross-entropy over all label positions together with a
    metrics dict of ``(sum, number of label elements)`` tuples.
    """
    logits = apply_fn({"params": params}, batch.inputs, train=True, rngs={"dropout": rng})
    per_token_loss = optax.softmax_cross_entropy_with_integer_labels(logits, batch.labels)
    is_correct = jnp.equal(jnp.argmax(logits, axis=-1), batch.labels)
    # Every label element counts as one prediction.
    num_tokens = np.prod(batch.labels.shape)
    step_metrics = {
        "loss": (per_token_loss.sum(), num_tokens),
        "accuracy": (is_correct.sum(), num_tokens),
    }
    return per_token_loss.mean(), step_metrics

In [155]:
@functools.partial(
    jax.jit,
    donate_argnames=(
        "state",
        "metrics",
    ),
)
def train_step_transformer(
    state: TrainState,
    metrics: Metrics | None,
    batch: Batch,
) -> Tuple[TrainState, Metrics]:
    """Training step function.

    Executes a full training step with gradient accumulation for the next-token prediction task.
    The number of minibatches is read from the module-level ``config.optimizer.num_minibatches``.

    Args:
        state: Current training state. Donated to the jitted call.
        metrics: Current metrics, accumulated from previous training steps. Donated to the jitted call. If None, the metrics of this step are returned as they are.
        batch: Training batch.

    Returns:
        Tuple with updated training state (parameters, optimizer state, etc.) and metrics.
    """
    # Split the random number generator for the current step; `rng` is stored
    # back into the state, `step_rng` is consumed by this step.
    rng, step_rng = jax.random.split(state.rng)
    # Determine gradients and metrics for the full batch.
    grads, step_metrics = accumulate_gradients( # This does the accumulation of gradients that was defined earlier
        state,
        batch,
        step_rng,
        config.optimizer.num_minibatches,
        loss_fn=next_token_pred_loss,
        use_scan=True,
    )
    # Optimizer step.
    new_state = state.apply_gradients(grads=grads, rng=rng)
    # Accumulate metrics across training steps.
    if metrics is None:
        metrics = step_metrics
    else:
        # `jax.tree_util.tree_map` replaces the deprecated `jax.tree_map` alias.
        metrics = jax.tree_util.tree_map(jnp.add, metrics, step_metrics)
    return new_state, metrics

In [159]:
# Evaluate only the output shapes/dtypes of a training step (with metrics=None)
# to learn the structure of the metrics.
_, metric_shapes = jax.eval_shape(
    train_step_transformer,
    state,
    None,
    batch_transformer,
)
# Zero-initialised metrics with matching shapes and dtypes.
# `jax.tree_util.tree_map` replaces the deprecated `jax.tree_map` alias.
metrics = jax.tree_util.tree_map(lambda x: jnp.zeros(x.shape, dtype=x.dtype), metric_shapes)

  minibatch = jax.tree_map(
  grads = jax.tree_map(lambda x: jnp.zeros(x.shape, x.dtype), grads_shapes)
  metrics = jax.tree_map(lambda x: jnp.zeros(x.shape, x.dtype), metrics_shape)
  minibatch = jax.tree_map(
  carry = jax.tree_map(jnp.add, carry, (step_grads, step_metrics))
  grads = jax.tree_map(lambda g: g / num_minibatches, grads)
  metrics = jax.tree_map(lambda x: jnp.zeros(x.shape, dtype=x.dtype), metric_shapes)


In [161]:
# Four training steps on the same batch; state and metrics are rebound to the
# returned values on every iteration.
for _ in tqdm(range(4)):
    state, metrics = train_step_transformer(state, metrics, batch_transformer)
# One more step with freshly zeroed metrics, so that the printed metrics
# cover this final step only.
# `jax.tree_util.tree_map` replaces the deprecated `jax.tree_map` alias.
final_metrics = jax.tree_util.tree_map(
    lambda x: jnp.zeros(x.shape, dtype=x.dtype), metric_shapes
)
state, final_metrics = train_step_transformer(state, final_metrics, batch_transformer)
print_metrics(final_metrics, "Final metrics - Transformer")

In [163]:
# Record a profiler trace of three training steps into "traces/".
jax.profiler.start_trace("traces/")
for i in range(3):
    # Annotate every step (numbered from 1) in the trace.
    with jax.profiler.StepTraceAnnotation("train_step", step_num=i + 1):
        state, metrics = train_step_transformer(state, metrics, batch_transformer)
# Wait for the last step's result before the trace is stopped.
metrics["loss"][0].block_until_ready()
jax.profiler.stop_trace()

In [None]:
# %load_ext tensorboard
# %tensorboard --logdir traces/single_gpu_transformer