<a href="https://colab.research.google.com/github/WilsonWang01/hands-on-gpt2/blob/main/training_pipeline.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Train a miniGPT language model with JAX

<table class="tfo-notebook-buttons" align="left">
  <td>
    <a target="_blank" href="https://kaggle.com/kernels/welcome?src=https://github.com/jax-ml/jax-ai-stack/blob/main/docs/source/JAX_for_LLM_pretraining.ipynb"><img src="https://www.kaggle.com/static/images/logos/kaggle-logo-transparent-300.png" height="32" width="70"/>Run in Kaggle</a>
  </td>
  <td>
    <a target="_blank" href="https://colab.research.google.com/github/jax-ml/jax-ai-stack/blob/main/docs/source/JAX_for_LLM_pretraining.ipynb"><img src="https://www.tensorflow.org/images/colab_logo_32px.png" />Run in Google Colab</a>
  </td>
  <td>
    <a target="_blank" href="https://github.com/jax-ml/jax-ai-stack/blob/main/docs/source/JAX_for_LLM_pretraining.ipynb"><img src="https://www.tensorflow.org/images/GitHub-Mark-32px.png" />View source on GitHub</a>
  </td>
</table>

This tutorial demonstrates how to use JAX, [Flax NNX](http://flax.readthedocs.io) and [Optax](http://optax.readthedocs.io) for language model (pre)training with data and tensor [parallelism](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization) in the [Single-Program Multi-Data](https://en.wikipedia.org/wiki/Single_program,_multiple_data) (SPMD) style. It was originally inspired by the [Keras miniGPT tutorial](https://keras.io/examples/generative/text_generation_with_miniature_gpt/).

Here, you will learn how to:

- Define the miniGPT model with Flax and JAX automatic parallelism
- Load and preprocess the dataset
- Create the loss and training step functions
- Train the model on TPUs on Kaggle or Google Colab
- Profile for hyperparameter tuning

If you are new to JAX for AI, check out the [introductory tutorial](https://jax-ai-stack.readthedocs.io/en/latest/neural_net_basics.html), which covers neural network building with [Flax NNX](https://flax.readthedocs.io/en/latest/nnx_basics.html).

**Note:** If you are using [Kaggle](https://www.kaggle.com/), select the free TPU v5e-8 as the hardware accelerator. If you are using [Google Colab](https://colab.research.google.com/), select the free Google Cloud TPU v5e-1 as the hardware accelerator. You may also use Google Cloud TPUs.

Check the available JAX devices, or [`jax.Device`](https://jax.readthedocs.io/en/latest/_autosummary/jax.Device.html), with [`jax.devices()`](https://jax.readthedocs.io/en/latest/_autosummary/jax.devices.html). On a Kaggle TPU v5e-8, the output of the cell below lists eight devices.

In [None]:
# Force-install a compatible JAX version (0.8.2 or later) to avoid version conflicts
!pip install "jax[tpu]>=0.8.2" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

In [1]:
import jax
# Check the device list
print(jax.devices())

E0000 00:00:1769878847.384461   53229 common_lib.cc:650] Could not set metric server port: INVALID_ARGUMENT: Could not find SliceBuilder port 8471 in any of the 0 ports provided in `tpu_process_addresses`="local"
=== Source Location Trace: === 
learning/45eac/tfrc/runtime/common_lib.cc:238


[TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=2, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,1,0), core_on_chip=0), TpuDevice(id=4, process_index=0, coords=(0,2,0), core_on_chip=0), TpuDevice(id=5, process_index=0, coords=(1,2,0), core_on_chip=0), TpuDevice(id=6, process_index=0, coords=(0,3,0), core_on_chip=0), TpuDevice(id=7, process_index=0, coords=(1,3,0), core_on_chip=0)]


### Download the GuoFeng Webnovel dataset from Dropbox

In [None]:
import os
import json
import zipfile

# 1. Download the GuoFeng Corpus (Dropbox)
print("⬇️ Downloading GuoFeng Webnovel Corpus...")
!wget -O webnovel_data.zip "https://www.dropbox.com/scl/fo/dtrf3pe1vfbo5nse16648/AAZ5SFnuwohj7IJ2J-Q8zHs?rlkey=486vbn17qra1ez91btj0n4xu2&e=1&dl=1"

# 2. Unzip
print("📂 Unzipping...")
with zipfile.ZipFile("webnovel_data.zip", 'r') as zip_ref:
    zip_ref.extractall("webnovel_raw")

# 3. Convert to JSONL
output_file = "/kaggle/working/webnovel_train.jsonl"
source_zh = None
source_en = None

# Automatically locate the files, which may be nested in subdirectories after unzipping
for root, dirs, files in os.walk("webnovel_raw"):
    if "train.zh" in files: source_zh = os.path.join(root, "train.zh")
    if "train.en" in files: source_en = os.path.join(root, "train.en")

if source_zh and source_en:
    print(f"✅ Found source files:\n  ZH: {source_zh}\n  EN: {source_en}")
    
    with open(source_zh, 'r', encoding='utf-8') as f_zh, \
         open(source_en, 'r', encoding='utf-8') as f_en, \
         open(output_file, 'w', encoding='utf-8') as f_out:
        
        count = 0
        for line_zh, line_en in zip(f_zh, f_en):
            zh_text = line_zh.strip()
            en_text = line_en.strip()
            if not zh_text or not en_text: continue
            
            # Format: Chinese + newline + English
            # A model trained on this format, given Chinese text, learns to predict the English translation
            record = { "text": f"{zh_text}\n{en_text}" }
            
            f_out.write(json.dumps(record, ensure_ascii=False) + '\n')
            count += 1
            
    print(f"🎉 Done! Processed {count} lines. File saved at: {output_file}")
else:
    print("❌ Error: Could not find train.zh or train.en files.")

In [None]:
# Colab variant of the cell above: download and process the data, writing the JSONL file to /content
import os
import json
import zipfile
# 1. Download the GuoFeng Corpus directly
# We use the 'dl=1' flag to get the file directly
print("⬇️ Downloading GuoFeng Webnovel Corpus...")
!wget -O webnovel_data.zip "https://www.dropbox.com/scl/fo/dtrf3pe1vfbo5nse16648/AAZ5SFnuwohj7IJ2J-Q8zHs?rlkey=486vbn17qra1ez91btj0n4xu2&e=1&dl=1"
# 2. Unzip
print("📂 Unzipping...")
with zipfile.ZipFile("webnovel_data.zip", 'r') as zip_ref:
    zip_ref.extractall("webnovel_raw")
# 3. Convert to JSONL (JAX friendly format)
# We will combine train.zh (Chinese) and train.en (English)
# into a format where the model learns to translate or associate pairs.
final_output_file = "/content/webnovel_train.jsonl"
source_zh = "webnovel_raw/V1/TRAIN/train.zh"
source_en = "webnovel_raw/V1/TRAIN/train.en"
# Verify path logic (in case unzip structure varies)
# Sometimes dropbox zips create a top-level folder. We search for the files.
found_zh = None
found_en = None
for root, dirs, files in os.walk("webnovel_raw"):
    if "train.zh" in files: found_zh = os.path.join(root, "train.zh")
    if "train.en" in files: found_en = os.path.join(root, "train.en")
if found_zh and found_en:
    print(f"✅ Found source files:\n  ZH: {found_zh}\n  EN: {found_en}")

    print("🔄 Converting to JSONL...")
    with open(found_zh, 'r', encoding='utf-8') as f_zh, \
         open(found_en, 'r', encoding='utf-8') as f_en, \
         open(final_output_file, 'w', encoding='utf-8') as f_out:

        count = 0
        for line_zh, line_en in zip(f_zh, f_en):
            zh_text = line_zh.strip()
            en_text = line_en.strip()
            if not zh_text or not en_text: continue

            # --- CHOOSE YOUR TRAINING FORMAT HERE ---

            # OPTION A: Translation (Instruction Tuning style)
            # Use this if your model supports 'instruction' fields
            # record = {
            #     "instruction": "Translate chinese to english",
            #     "input": zh_text,
            #     "output": en_text
            # }
            # OPTION B: Raw Text (Pretraining style)
            # Use this if you are doing standard Causal Language Modeling
            # The model just sees the Chinese followed by English
            record = {
                "text": f"{zh_text}\n{en_text}"
            }

            f_out.write(json.dumps(record, ensure_ascii=False) + '\n')
            count += 1

    print(f"🎉 Done! Created {final_output_file} with {count} examples.")
else:
    print("❌ Could not find train.zh or train.en inside the zip.")
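The record format used above can be checked in isolation. The following is a minimal sketch using made-up sentence pairs (the sample strings and the in-memory buffer are illustrative assumptions, not part of the corpus): it writes one JSON object per line and reads the records back.

```python
import io
import json

# Hypothetical sentence pairs standing in for lines of train.zh / train.en.
pairs = [
    ("你好，世界。", "Hello, world."),
    ("", "skipped: empty source side"),
    ("谢谢。", "Thank you."),
]

buffer = io.StringIO()  # in-memory stand-in for the output .jsonl file
for zh_text, en_text in pairs:
    if not zh_text or not en_text:
        continue  # same filter as above: drop pairs with an empty side
    record = {"text": f"{zh_text}\n{en_text}"}
    # json.dumps escapes the inner newline, so each record stays on one line.
    buffer.write(json.dumps(record, ensure_ascii=False) + "\n")

# Read the records back, one JSON object per line.
records = [json.loads(line) for line in buffer.getvalue().splitlines() if line.strip()]
print(len(records))
print(records[0]["text"].split("\n"))
```

Because the newline between the Chinese and English text is escaped inside the JSON string, the file can still be read back line by line.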

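Install and import the necessary modules, including JAX NumPy, Flax NNX, Optax, Grain, and pandas (Tiktoken is installed as well, but the tokenizer used later in this notebook comes from Hugging Face `transformers`):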

In [None]:
# Install all potentially missing libraries in one go
!pip install tiktoken grain-nightly flax optax

In [3]:
import jax
import jax.numpy as jnp

from jax.sharding import Mesh, PartitionSpec as P, NamedSharding # For data and model parallelism (explained in more detail later)
from jax.experimental import mesh_utils

import flax.nnx as nnx
import optax

from dataclasses import dataclass
import grain.python as pygrain
import pandas as pd
import time

## Define the miniGPT model with Flax and JAX automatic parallelism

### Leveraging JAX's data and tensor parallelism

One of the most powerful features of JAX is [device parallelism](https://jax.readthedocs.io/en/latest/notebooks/Distributed_arrays_and_automatic_parallelization) for SPMD.

- Data parallelism splits the training data into parts (this is called sharding), so that batches are processed in parallel across different devices, such as GPUs and Google TPUs. This allows larger batch sizes, which speeds up training.
- Tensor parallelism allows us to split the model parameter tensors across several devices (sharding model tensors).
- You can learn more about the basics of JAX parallelism in the [Introduction to parallel programming](https://jax.readthedocs.io/en/latest/sharded-computation.html) on the JAX documentation site.

In this example, we'll use a 4-way data parallel and 2-way tensor parallel setup, which matches the Kaggle TPU v5e-8 and newer GCP TPU chips.

Note that as of October 2025, free-tier Colab only offers TPU v5e-1, which has a single device and so cannot run the multi-device SPMD setup.

### jax.sharding.Mesh

Earlier, we imported [`jax.sharding.Mesh`](https://jax.readthedocs.io/en/latest/jax.sharding.html#jax.sharding.Mesh), which wraps a multidimensional NumPy array of JAX devices, where each axis of the mesh has a name, such as `'x'` or `'y'`. It encapsulates how the TPU resources are organized for distributing computations across the devices.

Our `Mesh` will have two arguments:
- `devices`: This will take the value of [`mesh_utils.create_device_mesh((4, 2))`](https://jax.readthedocs.io/en/latest/jax.experimental.mesh_utils.html), which builds a device mesh. It is a NumPy ndarray of JAX devices (taken from the list of backend devices returned by [`jax.devices()`](https://jax.readthedocs.io/en/latest/_autosummary/jax.devices.html#jax.devices)).
- `axis_names`, where:
  - `batch`: 4 devices along the first axis - i.e. sharded into 4 - for data parallelism; and
  - `model`: 2 devices along the second axis - i.e. sharded into 2 - for tensor parallelism

This matches the structure in the Kaggle TPU v5e setup.

Let's instantiate `Mesh` as `mesh` and declare the TPU configuration to define how data and model parameters are distributed across the devices:

In [9]:
# Create a `Mesh` object representing TPU device arrangement.
# For example, for Kaggle TPU v5e-8:
if jax.device_count() == 8:
    mesh = Mesh(mesh_utils.create_device_mesh((4, 2)), ('batch', 'model'))

    ### Alternatively, we could use the 8-way data parallelism with only one line of code change.
    ### JAX enables quick experimentation with different partitioning strategies
    ### like this. We will come back to this point at the end of this tutorial.
    # mesh = Mesh(mesh_utils.create_device_mesh((8, 1)), ('batch', 'model'))

### For free-tier Colab TPU, which only has a single TPU core
if jax.device_count() == 1:
    mesh = Mesh(mesh_utils.create_device_mesh((1, 1)), ("batch", "model"))
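To see what sharding along the `batch` axis does to an array, here is a minimal sketch. It makes no assumption about the number of devices: the illustrative mesh `demo_mesh` is built as `(n, 1)` from whatever `jax.devices()` returns, so it also runs on a single device, and the array shape is arbitrary.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

n = jax.device_count()
# Illustrative mesh: all available devices on the 'batch' axis, one on 'model'.
demo_mesh = Mesh(np.array(jax.devices()).reshape(n, 1), ("batch", "model"))

# A toy "batch" with 8 rows per device and 4 columns.
x = jnp.arange(8 * n * 4).reshape(8 * n, 4)

# Split rows across the 'batch' axis; replicate along columns (None).
x_sharded = jax.device_put(x, NamedSharding(demo_mesh, P("batch", None)))

print(x_sharded.shape)                             # global shape is unchanged
print(x_sharded.addressable_shards[0].data.shape)  # shape held by one device
```

The global array keeps its shape, while each device holds only its own slice of the rows.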

### Multilingual tokenizer (Yi-1.5, 64K vocabulary)

The tokenizer comes from `01-ai/Yi-1.5-6B` on Hugging Face. Its vocabulary of about 64,000 tokens is suited to mixed Chinese-English training.

In [4]:
# ============================================================================
# KAGGLE TPU MULTILINGUAL TOKENIZER (Yi-1.5, 64K vocab)
# ============================================================================
# Vocab size: 64,000 (suited to mixed Chinese-English training)
# Source: 01-ai/Yi-1.5-6B (Hugging Face)
# ============================================================================

# Step 1: Install dependencies (restart the session after running this once)
# !pip install transformers

from typing import List, Optional, Union


class MultilingualTokenizer:
    """
    Yi-1.5 tokenizer wrapper (64K vocab).
    Supports Chinese and English; exposes a tiktoken-compatible API.
    """

    def __init__(self, model_name: str = "01-ai/Yi-1.5-6B"):
        """
        Args:
            model_name: Hugging Face model ID. Recommended options:
                - "01-ai/Yi-1.5-6B" (default): 64K vocab, strong on Chinese and English
                - "baichuan-inc/Baichuan2-7B-Base": 125K vocab
        """
        from transformers import AutoTokenizer
        
        self._tokenizer = AutoTokenizer.from_pretrained(
            model_name, 
            trust_remote_code=True,
            use_fast=True
        )
        
        self._eot_token = self._tokenizer.eos_token_id
        self._bos_token = self._tokenizer.bos_token_id
        self._pad_token = self._tokenizer.pad_token_id if self._tokenizer.pad_token_id is not None else 0
        
        # TPU alignment: pad the vocab size up to a multiple of 128
        raw_vocab = len(self._tokenizer)
        self._padded_vocab = ((raw_vocab // 128) + 1) * 128 if raw_vocab % 128 != 0 else raw_vocab
    
    @property
    def n_vocab(self) -> int:
        """Raw vocabulary size."""
        return len(self._tokenizer)

    @property
    def padded_vocab_size(self) -> int:
        """Vocabulary size after TPU alignment."""
        return self._padded_vocab

    @property
    def eot_token(self) -> int:
        """End-of-text token ID."""
        return self._eot_token

    @property
    def bos_token(self) -> int:
        """Beginning-of-sequence token ID."""
        return self._bos_token

    @property
    def pad_token(self) -> int:
        """Padding token ID."""
        return self._pad_token

    def encode(self, text: str, allowed_special: Optional[set] = None, add_special_tokens: bool = False) -> List[int]:
        """Encode text into token IDs (`allowed_special` is accepted for tiktoken compatibility and ignored)."""
        return self._tokenizer.encode(text, add_special_tokens=add_special_tokens)

    def decode(self, tokens: Union[List[int], int]) -> str:
        """Decode token IDs into text."""
        if isinstance(tokens, int):
            tokens = [tokens]
        return self._tokenizer.decode(tokens, skip_special_tokens=True)


def get_tokenizer(model_name: str = "01-ai/Yi-1.5-6B") -> MultilingualTokenizer:
    """Create the tokenizer (defaults to Yi-1.5, 64K vocab)."""
    return MultilingualTokenizer(model_name)


# ============================================================================
# USAGE
# ============================================================================
if __name__ == "__main__" or True:  # Always run in notebook
    print("=" * 50)
    print("Initializing Yi-1.5 Tokenizer (64K vocab)...")
    
    tokenizer = get_tokenizer()
    
    print(f"‚úì Raw vocab size: {tokenizer.n_vocab:,}")
    print(f"‚úì Padded vocab (TPU): {tokenizer.padded_vocab_size:,}")
    print(f"‚úì Divisible by 128: {tokenizer.padded_vocab_size % 128 == 0}")
    
    # ÊµãËØïÁºñÁ†Å
    test_texts = [
        "ËøôÊòØ‰∏≠ÊñáÊµãËØï",
        "English test",
        "‰∏≠Ëã±Ê∑∑ÂêàMixedÊñáÊú¨"
    ]
    
    print("\n--- Tokenization Test ---")
    for text in test_texts:
        ids = tokenizer.encode(text)
        print(f"'{text}' ‚Üí {len(ids)} tokens")


# ============================================================================
# Use in the model configuration:
# vocab_size = tokenizer.padded_vocab_size  # 64,000
# model = MiniGPT(vocab_size=vocab_size, ...)
# ============================================================================


Initializing Yi-1.5 Tokenizer (64K vocab)...




‚úì Raw vocab size: 63,992
‚úì Padded vocab (TPU): 64,000
‚úì Divisible by 128: True

--- Tokenization Test ---
'ËøôÊòØ‰∏≠ÊñáÊµãËØï' ‚Üí 3 tokens
'English test' ‚Üí 2 tokens
'‰∏≠Ëã±Ê∑∑ÂêàMixedÊñáÊú¨' ‚Üí 6 tokens
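The padding rule used in the tokenizer wrapper rounds the vocabulary size up to the next multiple of 128. A minimal standalone sketch of the same arithmetic (the helper name `round_up_to_multiple` is hypothetical and not part of the notebook):

```python
def round_up_to_multiple(n: int, multiple: int = 128) -> int:
    """Round n up to the nearest multiple of `multiple` (n itself if already aligned)."""
    return -(-n // multiple) * multiple  # ceiling division, then scale back up


raw_vocab = 63_992  # raw vocabulary size reported in the output above
padded_vocab = round_up_to_multiple(raw_vocab)
print(padded_vocab)        # 64000
print(padded_vocab % 128)  # 0
```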


To leverage model parallelism, we need to instruct the JAX compiler how to shard the model tensors across the TPU devices. Earlier, we also imported [`jax.sharding.PartitionSpec`](https://jax.readthedocs.io/en/latest/jax.sharding.html#jax.sharding.PartitionSpec) and [`jax.sharding.NamedSharding`](https://jax.readthedocs.io/en/latest/jax.sharding.html#jax.sharding.NamedSharding):
- [`PartitionSpec`](https://jax.readthedocs.io/en/latest/jax.sharding.html#jax.sharding.PartitionSpec) (using alias `P`) defines how tensors are sharded across the devices in our `Mesh`. Its elements describe how an input dimension is partitioned across mesh dimensions. For example, in `PartitionSpec('x', 'y')` the first dimension of data is sharded across `x` axis of the mesh, and the second one - across the `y` axis.
  - We'll use `PartitionSpec` to describe how to shard a tensor across, for example, the `model` axis or be replicated on other dimensions (which is denoted by `None`).
- [`NamedSharding`](https://jax.readthedocs.io/en/latest/jax.sharding.html#jax.sharding.NamedSharding) is a (`Mesh`, `PartitionSpec`) pair that describes how to shard a model tensor across our `mesh`.
- We combine `Mesh` (the TPU resources) with `PartitionSpec` and create a `NamedSharding`, which instructs how to shard each model tensor across the TPU devices.

Additionally, we'll use Flax NNX's [`flax.nnx.with_partitioning`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/spmd.html#flax.nnx.with_partitioning) to let each model layer know that its weights need to be sharded according to our specification. We need to do this for every tensor/layer in the model.
- `nnx.with_partitioning` takes two main arguments: the `initializer` (such as [`flax.nnx.initializers.xavier_uniform`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/initializers.html#flax.nnx.initializers.xavier_uniform) and [`flax.nnx.initializers.zeros_init`](https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/initializers.html#flax.nnx.initializers.zeros_init)) and `sharding`, which names the mesh axis for each tensor dimension (e.g. `P(None, 'model')` or `P('model')` in our case).
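The two partition specs used throughout the model, `P(None, 'model')` for kernels and `P('model')` for biases, can be tried on plain arrays. This is a sketch under assumptions: the mesh `demo_mesh` puts every available device on the `model` axis so that it runs on any device count, and the toy array shapes are illustrative rather than the model's real shapes.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

n = jax.device_count()
# Illustrative mesh: one device on 'batch', all available devices on 'model'.
demo_mesh = Mesh(np.array(jax.devices()).reshape(1, n), ("batch", "model"))

# A toy kernel of shape (in_features, out_features).
kernel = jnp.ones((4, 2 * n))

# P(None, 'model'): replicate rows, split columns across the 'model' axis.
kernel = jax.device_put(kernel, NamedSharding(demo_mesh, P(None, "model")))

# P('model'): split a 1-D bias across the 'model' axis.
bias = jax.device_put(jnp.zeros((2 * n,)), NamedSharding(demo_mesh, P("model")))

print(kernel.addressable_shards[0].data.shape)  # (4, 2) per device
print(bias.addressable_shards[0].data.shape)    # (2,) per device
```

`nnx.with_partitioning` attaches the same kind of axis names to a parameter's initializer, so that the layer's weights can be laid out this way when the model is created.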

In [8]:
# --- 1. Embedding Layer ---
class TokenAndPositionEmbedding(nnx.Module):
    def __init__(self, maxlen, vocab_size, embed_dim, rngs):
        self.token_embed = nnx.Embed(
            num_embeddings=vocab_size,
            features=embed_dim,
            embedding_init=nnx.with_partitioning(nnx.initializers.normal(stddev=0.02), P(None, 'model')),
            rngs=rngs,
        )
        self.position_embed = nnx.Embed(
            num_embeddings=maxlen,
            features=embed_dim,
            embedding_init=nnx.with_partitioning(nnx.initializers.normal(stddev=0.02), P(None, 'model')),
            rngs=rngs,
        )
    def __call__(self, x):
        positions = jnp.arange(0, x.shape[-1])
        return self.token_embed(x) + self.position_embed(positions)
# --- 2. Transformer Block ---
class TransformerBlock(nnx.Module):
    def __init__(self, embed_dim: int, num_heads: int, ff_dim: int, *, rngs: nnx.Rngs, rate: float = 0.1):
        self.mha = nnx.MultiHeadAttention(
            num_heads=num_heads,
            in_features=embed_dim,
            kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), P(None, 'model')),
            bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), P('model')),
            rngs=rngs,
        )
        self.dropout1 = nnx.Dropout(rate=rate)
        self.layer_norm1 = nnx.LayerNorm(
            epsilon=1e-6,
            num_features=embed_dim,
            scale_init=nnx.with_partitioning(nnx.initializers.ones_init(), P('model')),
            bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), P('model')),
            rngs=rngs,
        )
        self.linear1 = nnx.Linear(
            in_features=embed_dim,
            out_features=ff_dim,
            kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), P(None, 'model')),
            bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), P('model')),
            rngs=rngs,
        )
        self.dropout2 = nnx.Dropout(rate=rate)
        self.linear2 = nnx.Linear(
            in_features=ff_dim,
            out_features=embed_dim,
            kernel_init=nnx.with_partitioning(nnx.initializers.xavier_uniform(), P(None, 'model')),
            bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), P('model')),
            rngs=rngs,
        )
        self.layer_norm2 = nnx.LayerNorm(
            epsilon=1e-6,
            num_features=embed_dim,
            scale_init=nnx.with_partitioning(nnx.initializers.ones_init(), P('model')),
            bias_init=nnx.with_partitioning(nnx.initializers.zeros_init(), P('model')),
            rngs=rngs,
        )
    # Added rngs=None
    def __call__(self, x, mask=None, deterministic=False, rngs=None):
        x_norm = self.layer_norm1(x)
        # Pass rngs to MHA and Dropout
        x_mha = self.mha(x_norm, mask=mask, decode=False, deterministic=deterministic, rngs=rngs) 
        x_mha = self.dropout1(x_mha, deterministic=deterministic, rngs=rngs)
        x = x + x_mha 
        x_norm = self.layer_norm2(x)
        x_ff = self.linear1(x_norm)
        x_ff = nnx.gelu(x_ff)
        x_ff = self.dropout2(x_ff, deterministic=deterministic, rngs=rngs)
        x_ff = self.linear2(x_ff)
        x = x + x_ff 
        return x
# --- 3. MiniGPT ---
class MiniGPT(nnx.Module):
    def __init__(self, maxlen, vocab_size, embed_dim, num_heads, feed_forward_dim, num_transformer_blocks, *, rngs: nnx.Rngs):
        self.embedding_layer = TokenAndPositionEmbedding(
            maxlen, vocab_size, embed_dim, rngs=rngs
        )
        self.transformer_blocks = nnx.List([
            TransformerBlock(embed_dim, num_heads, feed_forward_dim, rngs=rngs) 
            for _ in range(num_transformer_blocks)
        ])
        self.output_layer = nnx.Linear(
            in_features=embed_dim,
            out_features=vocab_size,
            kernel_init=nnx.with_partitioning(
                nnx.initializers.xavier_uniform(), 
                P(None, 'model')
            ),
            bias_init=nnx.with_partitioning(
                nnx.initializers.zeros_init(), 
                P('model') 
            ),
            rngs=rngs
        )
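    def __call__(self, x, mask=None, deterministic=False, rngs=None):
        # Default to a causal (lower-triangular) mask so that each position
        # attends only to itself and earlier positions.
        if mask is None:
            seq_len = x.shape[-1]
            mask = jnp.tril(jnp.ones((seq_len, seq_len), dtype=bool))
        x = self.embedding_layer(x)
        for block in self.transformer_blocks:
            x = block(x, mask=mask, deterministic=deterministic, rngs=rngs)
        x = self.output_layer(x)
        return x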
    def generate_token(self, input_ids):
        logits = self(input_ids, deterministic=True)
        return logits[0, -1, :] 
    def generate_text(self, max_tokens, start_tokens):
        tokens = list(start_tokens)
        for _ in range(max_tokens - len(start_tokens)):
            input_ids = jnp.array([tokens])
            logits = self(input_ids, deterministic=True)
            next_token = jnp.argmax(logits[0, -1, :]).item()
            tokens.append(next_token)
            if next_token == tokenizer.eot_token:
                break
        return tokenizer.decode(tokens)

Set some hyperparameters.

In [7]:
vocab_size = tokenizer.padded_vocab_size  # Already aligned, so use it directly
print(f"Vocab: {vocab_size}")  # Should be 64,000

Vocab: 64000


Hyperparameter configuration:

In [6]:
# Hyperparameter configuration
vocab_size = tokenizer.padded_vocab_size
num_transformer_blocks = 8
maxlen = 256
embed_dim = 256
num_heads = 8
feed_forward_dim = 256
batch_size = 144 * jax.device_count() // 2  # integer division keeps batch_size an int
if jax.device_count() == 1:
    batch_size = 144
num_epochs = 1
top_k = 10

## Loading and preprocessing the data

Data loading and preprocessing with [Grain](https://github.com/google/grain).

In [10]:
import json
@dataclass
class TextDataset:
    data: list
    maxlen: int
    def __len__(self):
        return len(self.data)
    def __getitem__(self, idx: int):
        # Tokenize with the Yi-1.5 tokenizer wrapper defined above
        text = self.data[idx]
        encoding = tokenizer.encode(text)

        # Append the end-of-text token if it is not already present
        if not encoding or encoding[-1] != tokenizer.eot_token:
            encoding.append(tokenizer.eot_token)
        # Truncate to maxlen, then pad with 0 up to maxlen
        encoding = encoding[:self.maxlen]
        padded = encoding + [0] * (self.maxlen - len(encoding))
        return padded
def load_and_preprocess_data(file_path, batch_size, maxlen):
    print(f"📖 Loading data from {file_path}...")
    
    # Ensure batch_size is an Integer
    batch_size = int(batch_size) 
    
    data = []
    with open(file_path, 'r', encoding='utf-8') as f:
        for line in f:
            if line.strip():
                item = json.loads(line)
                if 'text' in item:
                    data.append(item['text'])
    
    print(f"✅ Loaded {len(data)} examples.")
    
    dataset = TextDataset(data, maxlen)
    sampler = pygrain.IndexSampler(
        len(dataset),
        shuffle=True,
        seed=42,
        shard_options=pygrain.NoSharding(),
        num_epochs=num_epochs,
    )
    dl = pygrain.DataLoader(
        data_source=dataset,
        sampler=sampler,
        # pygrain.Batch requires an integer batch size
        operations=[pygrain.Batch(batch_size=batch_size, drop_remainder=True)],
    )
    return dl
# Load the training data (re-run this line whenever the function above changes)
text_dl = load_and_preprocess_data('/kaggle/working/webnovel_train.jsonl', batch_size, maxlen)

📖 Loading data from /kaggle/working/webnovel_train.jsonl...
✅ Loaded 1920191 examples.
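The per-example preprocessing in `TextDataset.__getitem__` can be illustrated without a tokenizer. This sketch uses made-up token IDs and a hypothetical end-of-text ID of `2`; the real IDs come from the Yi-1.5 tokenizer, and the helper name `prepare_example` is not part of the notebook.

```python
EOT_TOKEN = 2  # hypothetical end-of-text ID, for illustration only
PAD_TOKEN = 0  # padding value used by TextDataset


def prepare_example(token_ids, maxlen):
    """Append EOT if missing, truncate to maxlen, then right-pad with PAD_TOKEN."""
    ids = list(token_ids)
    if not ids or ids[-1] != EOT_TOKEN:
        ids.append(EOT_TOKEN)
    ids = ids[:maxlen]
    return ids + [PAD_TOKEN] * (maxlen - len(ids))


short = prepare_example([11, 12, 13], maxlen=6)
long = prepare_example([11, 12, 13, 14, 15, 16], maxlen=4)
print(short)  # [11, 12, 13, 2, 0, 0]
print(long)   # [11, 12, 13, 14]
```

Note that truncation happens after the end-of-text token is appended, so examples longer than `maxlen` lose that token.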


## Defining the loss function and training step function

In [12]:
# Loss function; `rngs` supplies the dropout RNG stream
def loss_fn(model, batch, rngs):
    # batch[0] holds the input tokens, batch[1] the shifted targets
    logits = model(batch[0], rngs=rngs, deterministic=False)
    loss = optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=batch[1]).mean()
    return loss, logits

@nnx.jit
def train_step(model: MiniGPT, optimizer: nnx.Optimizer, metrics: nnx.MultiMetric, batch, dropout_key):
    # Create the NNX RNG stream from the JAX key
    dropout_rngs = nnx.Rngs(dropout=dropout_key)

    grad_fn = nnx.value_and_grad(loss_fn, has_aux=True)

    # Pass `dropout_rngs` as the third argument to match loss_fn(model, batch, rngs)
    (loss, logits), grads = grad_fn(model, batch, dropout_rngs)

    metrics.update(loss=loss, logits=logits, labels=batch[1])
    optimizer.update(model, grads)
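For intuition about the loss, the following sketch computes the same quantity as `optax.softmax_cross_entropy_with_integer_labels(...).mean()` using plain `jax.numpy`: the negative log-probability of the correct token, averaged over all positions. The helper name, toy shapes and values are assumptions chosen for illustration.

```python
import jax
import jax.numpy as jnp


def cross_entropy_with_integer_labels(logits, labels):
    """Mean negative log-likelihood of `labels` under softmax(logits)."""
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    # Pick out the log-probability assigned to the correct token at each position.
    picked = jnp.take_along_axis(log_probs, labels[..., None], axis=-1)
    return -picked.squeeze(-1).mean()


# Toy batch: 2 sequences, 3 positions, vocabulary of 4 tokens.
logits = jnp.zeros((2, 3, 4))  # uniform predictions
labels = jnp.array([[0, 1, 2], [3, 0, 1]])

loss = cross_entropy_with_integer_labels(logits, labels)
print(loss)  # uniform logits give log(4), about 1.386
```

By the same reasoning, a model that predicts roughly uniformly over a 64,000-token vocabulary would have a loss near log(64000), about 11.07.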

## Training the model

Start training. It takes ~50 minutes on Colab.

Note that for data parallelism, we shard the training data along the `batch` axis using `jax.device_put` with `NamedSharding`.

We are also using the `jax.vmap` transformation to produce the target sequences faster.
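The target preparation can be tried on a toy batch. This sketch mirrors the `prep_target_batch` function from the training cell; the token IDs are made up for illustration.

```python
import jax
import jax.numpy as jnp

# Shift each row left by one token and pad the end with 0,
# so that position i in the target holds token i + 1 of the input.
prep_target_batch = jax.vmap(
    lambda tokens: jnp.concatenate((tokens[1:], jnp.array([0])))
)

input_batch = jnp.array([[5, 6, 7, 8], [9, 10, 11, 12]])  # toy token IDs
target_batch = prep_target_batch(input_batch)
print(target_batch)
```

`jax.vmap` applies the per-sequence shift to every row of the batch without an explicit Python loop.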

In [13]:
def create_model(rngs):
    # Build MiniGPT from the hyperparameters defined earlier
    return MiniGPT(maxlen, vocab_size, embed_dim, num_heads, feed_forward_dim, num_transformer_blocks, rngs=rngs)

with mesh:
    model = create_model(rngs=nnx.Rngs(0))
    optimizer = nnx.Optimizer(model, optax.adam(1e-3), wrt=nnx.Param)
metrics = nnx.MultiMetric(
    loss=nnx.metrics.Average("loss"),
)
# Initialize the main random key
rng = jax.random.PRNGKey(0)
start_prompt = "Once upon a time"
start_tokens = tokenizer.encode(start_prompt)[:maxlen]
print("Initial generated text (untrained):")
generated_text = model.generate_text(maxlen, start_tokens)
print(generated_text)
metrics_history = {
    "train_loss": [],
}
# Targets are the inputs shifted left by one token, padded with 0 at the end
prep_target_batch = jax.vmap(
    lambda tokens: jnp.concatenate((tokens[1:], jnp.array([0])))
)
step = 0
for epoch in range(num_epochs):
    start_time = time.time()
    for batch in text_dl:
        if len(batch) % len(jax.devices()) != 0:
            continue

        input_batch = jnp.array(jnp.array(batch).T)
        target_batch = prep_target_batch(input_batch)

        # Split off a fresh dropout key for this step
        rng, dropout_key = jax.random.split(rng)

        train_step(
            model,
            optimizer,
            metrics,
            jax.device_put(
                (input_batch, target_batch), NamedSharding(mesh, P("batch", None))
            ),
            dropout_key,
        )
        if (step + 1) % 10 == 0:  # Log every 10 steps
            for metric, value in metrics.compute().items():
                metrics_history[f"train_{metric}"].append(value)
            metrics.reset()
            elapsed_time = time.time() - start_time
            print(
                f"\nStep {step + 1}, Loss: {metrics_history['train_loss'][-1]:.4f}, Time: {elapsed_time:.2f}s"
            )
            # Re-generate text to see progress
            print("Generated text:")
            print(model.generate_text(maxlen, start_tokens))

            start_time = time.time()
        step += 1
# Final text generation
print("Final generated text:")
print(model.generate_text(maxlen, start_tokens))

Visualize the training loss.

In [None]:
import matplotlib.pyplot as plt
plt.plot(metrics_history['train_loss'])
plt.title('Training Loss')
plt.xlabel('Step')
plt.ylabel('Loss')
plt.show()

As you can see, the model goes from generating completely random words at the beginning to generating sensible tiny stories at the end of the training. So essentially we have pretrained a small LLM to write tiny stories for us.

## Saving the checkpoint

Save the model checkpoint.

In [None]:
import orbax.checkpoint as orbax

state = nnx.state(model)

checkpointer = orbax.PyTreeCheckpointer()
checkpointer.save('/content/save', args=orbax.args.PyTreeSave(state), force=True)

# Make sure the files are there
!ls /content/save/

## Profiling for hyperparameter tuning

**Note:** This section assumes multiple TPU cores. It cannot run on the free-tier Colab TPU v5e-1.

In [None]:
!pip install -Uq tensorboard-plugin-profile tensorflow tensorboard

Load the tensorboard colab extension.

In [None]:
%load_ext tensorboard

As we're going to be running this model a number of times, we need some scaffolding to more easily compare our work. For a baseline, we'll need to perform some warmup to guarantee that our code is JIT'd and that our TPUs are warm. For improved comparability, we'll only start tracing after we've finished warmup.

In [None]:
trace_dir = "/tmp/jax-trace/"

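def loop_step(batch, step):
    input_batch = jnp.array(jnp.array(batch).T)
    target_batch = prep_target_batch(input_batch)
    # train_step also expects a dropout key; derive a distinct one per step
    dropout_key = jax.random.fold_in(jax.random.PRNGKey(0), step)
    sharded_batch = jax.device_put((input_batch, target_batch), NamedSharding(mesh, P('batch', None)))
    train_step(model, optimizer, metrics, sharded_batch, dropout_key)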

def generate_trace():
    tracing_steps = 30
    warmup_steps = 5
    for current_step in range(warmup_steps + tracing_steps):
        if current_step == warmup_steps:
            jax.profiler.start_trace(trace_dir)
        with jax.profiler.StepTraceAnnotation("train", step_num=current_step):
            batch = next(text_dl)
            loop_step(batch, current_step)

    jax.profiler.stop_trace()

Now we'll perform some traces to compare results of different batch sizes. This will take several minutes as we need to reprocess our input data to prepare new batches each time.

In [None]:
trace_dir = "/tmp/jax-trace-batch-comparison/"

batch_size = 64
text_dl = iter(load_and_preprocess_data('/kaggle/working/webnovel_train.jsonl', batch_size, maxlen))
generate_trace()

batch_size = 256
text_dl = iter(load_and_preprocess_data('/kaggle/working/webnovel_train.jsonl', batch_size, maxlen))
generate_trace()

Run TensorBoard with the Profiler Plugin to compare our runs. Runs are listed in order from newest to oldest, so the top run in the list will have `batch_size = 256`.

The key metrics to focus on here for this hyperparameter are FLOPS Utilization and Average Step Time.

In general, we want to maximize FLOPS Utilization while minimizing the step time per training example. In this case, increasing the batch size from 64 to 256 achieves both. FLOPS Utilization increases from 16% to 27%. Average Step Time increases from 100 ms to 260 ms, but the batch size is four times larger (an increase of 300%), so the time per training example drops from about 1.56 ms to about 1.02 ms.
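The per-example figures follow directly from dividing the average step time by the batch size. A short sketch of that arithmetic, using only the numbers quoted above:

```python
# Average step times (ms) keyed by batch size, as quoted above.
runs = {64: 100.0, 256: 260.0}

per_example_ms = {bs: step_ms / bs for bs, step_ms in runs.items()}
for bs, ms in per_example_ms.items():
    print(f"batch_size={bs}: {ms:.2f} ms per example")

# Ratio of per-example times: how much faster each example is processed.
speedup = per_example_ms[64] / per_example_ms[256]
print(f"per-example speedup: {speedup:.2f}x")
```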

In [None]:
%tensorboard --logdir=$trace_dir

Next, we can explore alternative parallelism methods. Earlier, we used 4-way data parallelism and 2-way tensor parallelism. 8-way data parallelism is another popular choice. Let's compare results between them. To switch to 8-way data parallelism, we'll replace the `Mesh` definition with:

`mesh = Mesh(mesh_utils.create_device_mesh((8, 1)), ('batch', 'model'))`

JAX will automatically figure out how to shard the model and data to use the new partitioning strategy, and nothing else needs to be done. Reconnect the TPU runtime and run it again to see how it performs.

How simple and powerful is this! And that's the beauty of JAX automatic parallelism.

In [None]:
trace_dir = "/tmp/jax-trace-parallelism-comparison/"

mesh = Mesh(mesh_utils.create_device_mesh((4, 2)), ('batch', 'model'))
generate_trace()

mesh = Mesh(mesh_utils.create_device_mesh((8, 1)), ('batch', 'model'))
generate_trace()

Once again we'll run tensorboard.

Looking at the results, we see that the step times are nearly the same; however, FLOPS Utilization is at 13% for 8-way data parallelism compared to 27% for 4-way data parallelism.

By looking at the Trace Viewer tool and looking under each TPU's ops, we can see that the TPUs spend a large amount of time idle while waiting for the host, as well as spending a good amount of time in `reduce_sum` operations.

In [None]:
%tensorboard --logdir=$trace_dir

By changing hyperparameters and comparing profiles, we're able to gain significant insights into our bottlenecks and limitations. These are just two examples of hyperparameters to tune, but plenty more of them will have significant effects on training speed and resource utilization.