Runs on a Kaggle TPU v3-8 accelerator.

In [1]:
import os
# Set in the first cell, ahead of the JAX imports in the following cells.
# NOTE(review): JAX documents this variable as the fraction of accelerator memory the
# XLA client preallocates (GPU allocator) — confirm that it has an effect on TPU.
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '.98'

In [3]:
import jax.numpy     as jnp
import jax.random    as random
import jax.tree_util as jtree
from jax import grad, jit, vmap
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P 
import jax

In [4]:
import optax
import flax
from flax import nnx

In [5]:
import time
import requests
import numpy as np
from tqdm import tqdm
import matplotlib.pyplot as plt

from datasets import load_dataset

from IPython.display import display, clear_output

In [None]:
# ---- Model hyper-parameters ----
CONTEXT_SIZE = 384  # Context window size (tokens per sequence)
EMBED_DIM = 256     # Embedding dimension
QKV_DIM = 64        # Per-head Q/K/V feature dimension
HEAD_SIZE = 4       # Number of attention heads
BLOCK_SIZE = 8      # Number of decoder blocks

# ---- Batching ----
MICRO_BATCH_SIZE = 16   # Sequences per forward/backward pass
GLOBAL_BATCH_SIZE = 32  # Sequences per optimizer step (gradient accumulation)

# ---- Optimizer ----
LEARNING_RATE = 1e-6
LEARING_RATE = LEARNING_RATE  # Misspelled alias kept because later cells reference it
MOMENTUM = 0.9  # Momentum factor for optimizer

DATASET = "minipile"  # "toy" or "openwebtext" or "minipile"

train_steps = 5000
eval_every = 100

# The attention heads must tile the embedding dimension exactly.
assert QKV_DIM * HEAD_SIZE == EMBED_DIM, (
    f"QKV_DIM * HEAD_SIZE ({QKV_DIM * HEAD_SIZE}) must equal EMBED_DIM ({EMBED_DIM})"
)
# The training loop derives the number of micro-batches with integer division,
# so a non-multiple would silently train on fewer sequences than requested.
assert GLOBAL_BATCH_SIZE % MICRO_BATCH_SIZE == 0, (
    "GLOBAL_BATCH_SIZE must be a multiple of MICRO_BATCH_SIZE"
)

In [None]:
# List the JAX devices attached to this host (the TPU cores when run on a TPU runtime).
jax.local_devices()

In [None]:
# Device mesh and sharding setup for tensor/model parallelism.

# Take the device list once and use it both for the count and for the mesh array,
# so the reshape below can never disagree with the number of devices.
devices = np.array(jax.devices())
num_device = devices.size

# 2-D logical mesh: a size-1 'data' axis and a 'model' axis spanning every device.
mesh = Mesh(devices=devices.reshape(1, num_device), axis_names=('data', 'model'))

# Sharding for input batches: an empty PartitionSpec partitions no dimension.
data_sharding = NamedSharding(mesh, P())

# Root pseudo-random number generator key (fixed seed).
key = random.PRNGKey(777)

# Dataset

In [None]:
import time
import requests
from tqdm import tqdm

start_time = time.time()

if DATASET == "openwebtext":
    # NOTE(review): this branch only fills `lines`; the following cells read `ds`,
    # so they will not run for this dataset without further adaptation.
    file_path = "/kaggle/input/openwebtext-dataset/train_split.txt"
    with open(file_path, "r", encoding="utf-8") as f:
        lines = list(tqdm(f, desc="Loading OpenWebText", unit=" lines"))

    print("Openwebtext loaded.")

elif DATASET == "minipile":

    ds = load_dataset("JeanKaddour/minipile")

    # `len(ds)` on the split mapping counts the splits, not the records,
    # so add up the length of every split instead.
    total_records = sum(len(split) for split in ds.values())
    print(f"資料集總筆數: {total_records}")
    print("MiniPile loaded.")

else:
    # Fail here with a clear message instead of a NameError in a later cell.
    raise ValueError(
        f"Unsupported DATASET {DATASET!r}; no loader is implemented for it in this cell."
    )

end_time = time.time()
print(f"經過的時間: {end_time - start_time:.2f} 秒")

In [None]:
# Preview the raw text of the first record in the training split.
ds["train"][0]["text"]

In [None]:
# Vocabulary Encoding / Decoding 

def build_vocabulary(dataset):
    """Build a character-level vocabulary from ``dataset['train']``.

    Every record's ``'text'`` field is scanned and each distinct character is
    collected.

    Returns:
        A list of tokens: the special tokens ``<<pad>>``, ``<<eos>>`` and
        ``<<unk>>`` first, followed by the distinct characters in sorted order.
    """
    print("正在從資料集建立詞彙表...")

    seen_chars = set()
    for record in tqdm(dataset['train'], desc="掃描字元"):
        seen_chars |= set(record['text'])

    # Special tokens occupy the first indices:
    #   <<pad>>  padding
    #   <<eos>>  end of sequence
    #   <<unk>>  unknown character
    # (a start-of-sequence token is intentionally not used)
    return ["<<pad>>", "<<eos>>", "<<unk>>"] + sorted(seen_chars)


chars = build_vocabulary(ds)
vocab_size = len(chars)

# Lookup tables: character -> index and index -> character.
ctoi = {ch: idx for idx, ch in enumerate(chars)}
itoc = {idx: ch for ch, idx in ctoi.items()}

PAD_IDX, EOS_IDX, UNK_IDX = (ctoi[token] for token in ("<<pad>>", "<<eos>>", "<<unk>>"))


def encode(s):
    """Map each character of ``s`` to its index; unknown characters map to ``<<unk>>``."""
    unknown = ctoi["<<unk>>"]
    return [ctoi.get(ch, unknown) for ch in s]


def decode(l):
    """Map a sequence of indices back to the concatenated token strings."""
    return "".join(itoc[idx] for idx in l)


print(f"{vocab_size=}")
print(f"(PAD_IDX): {PAD_IDX}, (EOS_IDX): {EOS_IDX}")

In [None]:
# Round-trip check: show the encoded indices, then the text decoded from them.
sample_text = "Hello World"
sample_ids = encode(sample_text)
print(sample_ids)
print(decode(sample_ids))

In [None]:
def create_batch_generator(key, dataset, batch_size, context_size, eos_idx, pad_idx, encode_fn=None):
    """Endlessly yield random ``(x, y, mask)`` batches for next-token training.

    Each batch samples ``batch_size`` distinct records from ``dataset`` and turns
    every record's ``'text'`` into one fixed-length training row:

    * longer than ``context_size`` tokens: truncated; ``y`` is ``x`` shifted left
      by one using the next real token, no EOS is added;
    * exactly ``context_size`` tokens: ``y`` is ``x`` shifted left with EOS appended;
    * shorter: EOS is appended, the row is padded with ``pad_idx``, and ``mask`` is
      0 on the padded positions.

    Args:
        key: JAX PRNG key; split once per batch.
        dataset: Indexable collection of records with a ``'text'`` field; must
            support ``len`` and integer indexing.
        batch_size: Rows per batch. Sampling is without replacement, so it must
            not exceed ``len(dataset)``.
        context_size: Length of every row.
        eos_idx: Token id appended where a text ends inside the window.
        pad_idx: Token id used for padding.
        encode_fn: Optional callable mapping a text to a sequence of token ids.
            When omitted, the module-level character-level ``encode`` is used.

    Yields:
        Tuple ``(x, y, mask)`` of ``int32`` arrays, each of shape
        ``(batch_size, context_size)``.
    """
    n_samples = len(dataset)
    # At most context_size + 1 tokens are ever needed: the window plus one label.
    max_tokens = context_size + 1

    while True:
        key, subkey = random.split(key)
        indices = random.choice(subkey, n_samples, shape=(batch_size,), replace=False)

        batch_x, batch_y, batch_mask = [], [], []

        # Move the sampled indices to the host once instead of once per element.
        for idx in np.asarray(indices).tolist():
            text = dataset[idx]['text']
            if encode_fn is None:
                # The default encoder emits one token per character, so slicing the
                # text first is equivalent and avoids encoding the whole document.
                tokens = encode(text[:max_tokens])
            else:
                tokens = list(encode_fn(text))[:max_tokens]

            if len(tokens) > context_size:
                # Case 1: text continues past the window -> plain truncation.
                x = tokens[:context_size]
                y = tokens[1:context_size + 1]
                mask = [1] * context_size

            elif len(tokens) == context_size:
                # Case 2: text ends exactly at the window edge -> last label is EOS.
                x = tokens
                y = tokens[1:context_size] + [eos_idx]
                mask = [1] * context_size

            else:
                # Case 3: text is shorter -> append EOS, then pad to the window size.
                content_with_eos = tokens + [eos_idx]
                n_pad = context_size - len(content_with_eos)

                x = content_with_eos + [pad_idx] * n_pad
                # Labels are the inputs shifted left by one position.
                y = x[1:] + [pad_idx]
                # Mask marks the non-padding input positions.
                mask = [1] * len(content_with_eos) + [0] * n_pad

            batch_x.append(x)
            batch_y.append(y)
            batch_mask.append(mask)

        yield (jnp.array(batch_x, dtype=jnp.int32),
               jnp.array(batch_y, dtype=jnp.int32),
               jnp.array(batch_mask, dtype=jnp.int32))



In [None]:
context_size = 5
batch_size = 4

# Mock dataset with one entry per branch of the batching logic.
mock_dataset = [
    {'text': '123456'},  # case 1: longer than context_size (truncated)
    {'text': '12345'},   # case 2: exactly context_size
    {'text': '1234'},    # case 3: shorter than context_size (EOS fills the row)
    {'text': '123'},     # case 3: shorter than context_size (EOS + padding)
]

print("--- 開始測試 ---")
print(f"Context Size: {context_size}, PAD Index: {PAD_IDX}, EOS Index: {EOS_IDX}\n")

# Build a generator over the whole mock dataset ...
mock_generator = create_batch_generator(
    random.PRNGKey(0), mock_dataset, batch_size, context_size, EOS_IDX, PAD_IDX
)

# ... and draw one full batch to check the batched shapes.
x_batch, y_batch, mask_batch = next(mock_generator)
assert x_batch.shape == y_batch.shape == mask_batch.shape == (batch_size, context_size)

# --- Per-sample inspection: one single-record batch for each mock entry ---
for sample in mock_dataset:
    x, y, mask = next(create_batch_generator(
        random.PRNGKey(0), [sample], 1, context_size, EOS_IDX, PAD_IDX
    ))
    assert x.shape == y.shape == mask.shape == (1, context_size)

    print(f"輸出 x: {x}")

    print(f"輸出 y: {y}")

    print(f"輸出 mask: {mask}")


# Model Definition

In [None]:
# Hyper-parameter dictionary handed to every model component.
config = dict(
    vocab_size=vocab_size,
    context_size=CONTEXT_SIZE,
    embed_dim=EMBED_DIM,
    qkv_dim=QKV_DIM,
    head_size=HEAD_SIZE,      # number of attention heads
    block_size=BLOCK_SIZE,    # number of decoder blocks
    activation="gelu",
    ffn_d=EMBED_DIM * 4,      # hidden width of the feed-forward network
    dropout=0.2,
    mask_pad=0,               # value that marks padded positions in the mask
)

In [None]:
class MaskedSelfAttention(nnx.Module):
    """Multi-head causal self-attention with an additional padding mask.

    The Q/K/V projection kernel is partitioned along its output dimension and the
    output projection kernel along its input dimension, both over the ``'model'``
    mesh axis. All parameters and computations use ``bfloat16``.
    """

    def __init__(self, config: dict, rngs: nnx.Rngs):
        """Create the projections and dropout layer.

        Args:
            config: Reads ``embed_dim``, ``head_size`` (number of heads),
                ``qkv_dim`` (per-head feature size), ``mask_pad`` and ``dropout``.
            rngs: NNX random number streams used for initialisation and dropout.
        """
        init_fn = nnx.initializers.lecun_normal()

        assert config["embed_dim"] == config["head_size"] * config["qkv_dim"]
        self.qkv_feature = config["qkv_dim"]
        self.num_heads = config["head_size"]
        self.mask_pad = config["mask_pad"]  # value marking padded positions in the mask

        self.qkv_proj = nnx.Linear(in_features=config["embed_dim"], out_features=config["embed_dim"]*3,
                                   kernel_init=nnx.with_partitioning(init_fn, (None, 'model')),
                                   dtype=jnp.bfloat16, param_dtype=jnp.bfloat16, rngs=rngs)
        self.output_proj = nnx.Linear(in_features=config["embed_dim"], out_features=config["embed_dim"],
                                      kernel_init=nnx.with_partitioning(init_fn, ('model', None)),
                                      dtype=jnp.bfloat16, param_dtype=jnp.bfloat16, rngs=rngs)
        self.dropout = nnx.Dropout(config["dropout"], rngs=rngs)

    @nnx.remat
    def __call__(self, x: jax.Array, masks: jax.Array):
        """Apply masked self-attention.

        Args:
            x: Activations of shape ``(batch_size, seq_len, embed_dim)``.
            masks: Array of shape ``(batch_size, seq_len)``; positions equal to
                ``mask_pad`` are excluded as attention keys.

        Returns:
            Array of shape ``(batch_size, seq_len, embed_dim)``.
        """
        B, S, E = x.shape  # (batch_size, seq_len, embed_dim)
        B_mask, S_mask = masks.shape  # (batch_size, seq_len)

        assert B == B_mask and S == S_mask

        x = jax.lax.with_sharding_constraint(x, P())

        # (batch_size, seq_len, embed_dim) -> (batch_size, seq_len, 3 * qkv_dim * num_heads)
        qkv = self.qkv_proj(x)

        # -> (batch_size, num_heads, seq_len, 3 * qkv_dim)
        qkv = qkv.reshape(B, S, self.num_heads, self.qkv_feature * 3).transpose(0, 2, 1, 3)

        q, k, v = jnp.array_split(qkv, 3, axis=-1)

        # Scaled dot product: QK^T / sqrt(d_k)
        # (batch, heads, seq, qkv_dim) @ (batch, heads, qkv_dim, seq) -> (batch, heads, seq, seq)
        scores = jnp.matmul(q, k.transpose(0, 1, 3, 2)) / jnp.sqrt(self.qkv_feature)

        # True marks entries that must NOT be attended to.
        causal_mask = jnp.triu(jnp.ones((S, S), dtype=jnp.bool_), k=1)[None, None, :, :]
        padding_mask = (masks == self.mask_pad)[:, None, None, :]
        attention_mask = jnp.logical_or(causal_mask, padding_mask)

        # Masked scores get the most negative finite value of the score dtype.
        large_negative = jnp.finfo(scores.dtype).min
        scores = jnp.where(attention_mask, large_negative, scores)

        # Attention weights
        attn_weights = nnx.softmax(scores, axis=-1)

        # (batch, heads, seq, seq) @ (batch, heads, seq, qkv_dim) -> (batch, heads, seq, qkv_dim)
        attn_output = jnp.matmul(attn_weights, v)

        # (batch, heads, seq, qkv_dim) -> (batch, seq, heads, qkv_dim) -> (batch, seq, embed_dim)
        attn_output = attn_output.transpose(0, 2, 1, 3).reshape(B, S, E)

        attn_output = jax.lax.with_sharding_constraint(attn_output, P(None, None, 'model'))

        output = self.output_proj(attn_output)

        # Constrain the projected output (not the already-consumed input `x`) so the
        # result is laid out like the feed-forward block's output.
        output = jax.lax.with_sharding_constraint(output, P(None, 'model', None))
        output = self.dropout(output)
        return output

In [None]:
class GPTFeedForwardNetwork(nnx.Module):
    """Position-wise two-layer feed-forward network followed by dropout.

    The first kernel is partitioned along its output dimension and the second
    along its input dimension, both over the ``'model'`` mesh axis.
    """

    def __init__(self, config: dict, rngs: nnx.Rngs):
        """Build both linear layers, pick the activation and create dropout.

        Reads ``embed_dim``, ``ffn_d``, ``activation`` (``'relu'``, ``'swish'``
        or ``'gelu'``) and ``dropout`` from ``config``.
        """
        kernel_init = nnx.initializers.lecun_normal()
        model_width = config["embed_dim"]
        hidden_width = config["ffn_d"]

        self.layer1 = nnx.Linear(
            in_features=model_width,
            out_features=hidden_width,
            kernel_init=nnx.with_partitioning(kernel_init, (None, 'model')),
            dtype=jnp.bfloat16,
            param_dtype=jnp.bfloat16,
            rngs=rngs,
        )
        self.layer2 = nnx.Linear(
            in_features=hidden_width,
            out_features=model_width,
            kernel_init=nnx.with_partitioning(kernel_init, ('model', None)),
            dtype=jnp.bfloat16,
            param_dtype=jnp.bfloat16,
            rngs=rngs,
        )

        # Supported non-linearities, selected by name.
        available_activations = {'relu': nnx.relu, 'swish': nnx.swish, 'gelu': nnx.gelu}
        self.activation = available_activations[config["activation"]]

        self.dropout = nnx.Dropout(config["dropout"], rngs=rngs)

    @nnx.remat
    def __call__(self, x: jax.Array):
        """Expand, activate, project back, constrain the layout, then apply dropout."""
        replicated = jax.lax.with_sharding_constraint(x, P())
        hidden = self.activation(self.layer1(replicated))
        projected = self.layer2(hidden)
        projected = jax.lax.with_sharding_constraint(projected, P(None, 'model', None))
        return self.dropout(projected)

In [None]:
class GPTDecoderBlock(nnx.Module):
    """Pre-norm decoder block: attention and feed-forward, each with a residual add."""

    def __init__(self, config: dict, rngs: nnx.Rngs):
        """Create the attention, feed-forward and the two layer-norm sub-modules."""
        # Both layer norms share the same construction arguments.
        norm_kwargs = dict(
            num_features=config["embed_dim"],
            dtype=jnp.bfloat16,
            param_dtype=jnp.bfloat16,
            rngs=rngs,
        )

        # Sub-modules are created in the same order as before: atn, norm1, ffn, norm2.
        self.atn = MaskedSelfAttention(config, rngs)
        self.norm1 = nnx.LayerNorm(**norm_kwargs)
        self.ffn = GPTFeedForwardNetwork(config, rngs)
        self.norm2 = nnx.LayerNorm(**norm_kwargs)

    # Recompute activations in the backward pass to reduce memory usage.
    @nnx.remat
    def __call__(self, x: jax.Array, padding_masks : jax.Array):
        """Return ``x`` after the attention and feed-forward residual branches."""
        attended = self.atn(self.norm1(x), padding_masks)
        x = x + attended
        transformed = self.ffn(self.norm2(x))
        return x + transformed

In [None]:
class GPT(nnx.Module):
    """Decoder-only transformer language model.

    Token and position embeddings share a single table: rows
    ``[0, vocab_size)`` hold token vectors and rows
    ``[vocab_size, vocab_size + context_size)`` hold position vectors.
    The decoder blocks are created with ``nnx.vmap`` and applied with ``nnx.scan``.
    """

    def __init__(self, config: dict, rngs: nnx.Rngs):
        """Build the embedding table, the stacked blocks, final norm and output head.

        Args:
            config: Reads ``vocab_size``, ``context_size``, ``embed_dim`` and
                ``block_size`` (number of decoder blocks); the dictionary is also
                passed on to every decoder block.
            rngs: NNX random number streams.
        """
        init_fn = nnx.initializers.lecun_normal()
        params_key = rngs.params()
        # One key per decoder block.
        split_keys = jax.random.split(params_key, config["block_size"])

        # Shared table for token embeddings and positional embeddings.
        self.embeding = nnx.Embed(num_embeddings=config["vocab_size"]+config["context_size"], features=config["embed_dim"]
                                  , dtype=jnp.bfloat16, param_dtype=jnp.bfloat16, rngs=rngs)
        self.blocks = self.blocks_create(config, split_keys)
        self.norm = nnx.LayerNorm(num_features=config["embed_dim"],
                                  dtype=jnp.bfloat16, param_dtype=jnp.bfloat16, rngs=rngs)
        self.linear_proj = nnx.Linear(in_features=config["embed_dim"], out_features=config["vocab_size"],
                                      kernel_init=nnx.with_partitioning(init_fn, (None, None)),
                                      dtype=jnp.bfloat16, param_dtype=jnp.bfloat16, rngs=rngs)

        # Remember the sizes this instance was built with, so the forward pass does
        # not depend on whatever the module-level `config` holds at call time.
        self.vocab_size = config["vocab_size"]
        self.context_size = config["context_size"]

    def __call__(self, input_data: jax.Array, padding_masks : jax.Array):
        """Compute next-token logits.

        Args:
            input_data: Token ids of shape ``(batch, seq)`` with
                ``seq <= context_size``.
            padding_masks: Array of shape ``(batch, seq)`` forwarded to every block.

        Returns:
            Logits of shape ``(batch, seq, vocab_size)``.

        NOTE(review): the sequence axis is constrained to the ``'model'`` mesh axis
        below; presumably ``seq`` must divide evenly across it — confirm before
        using lengths other than ``context_size``.
        """
        n_batch, n_contxt = input_data.shape  # (Batch, Seq)
        assert n_contxt <= self.context_size

        # Position ids live after the token ids in the shared embedding table.
        position = jnp.arange(self.vocab_size, self.vocab_size + n_contxt, dtype=jnp.int32)
        position = jnp.repeat(position[None, :], n_batch, axis=0)

        # Pair each token id with its position id: (batch, seq, 2).
        x = jnp.stack([input_data.astype(jnp.int32), position], axis=-1)
        x = self.embeding(x)

        # Sum the token vector and the position vector.
        x = jnp.sum(x, axis=-2)

        x = jax.lax.with_sharding_constraint(x, P(None, 'model', None))

        # Blocks
        x = GPT.blocks_forward(self.blocks, x, padding_masks)

        # Last norm
        x = self.norm(x)

        x = jax.lax.with_sharding_constraint(x, P())
        # Last layer
        logits = self.linear_proj(x)

        return logits

    # Create all decoder blocks at once, one per key along the leading axis.
    @nnx.vmap(in_axes=(None, None, 0), out_axes=0,)
    def blocks_create(self, config, key: jax.Array):
      return GPTDecoderBlock(config, nnx.Rngs(key))

    # Apply the stacked blocks one after another, carrying the activations.
    @nnx.scan(in_axes=(0, nnx.Carry, None), out_axes=nnx.Carry,)
    def blocks_forward(block, x, masks):
        x = block(x, masks)
        return x

# Loss Function and Training Step

In [None]:
@nnx.jit
def loss_fn(model, data_inputs, masks, labels):
    """Masked mean cross-entropy of the model's predictions.

    Args:
        model: Callable as ``model(data_inputs, masks)`` returning logits.
        data_inputs: Input token ids.
        masks: Padding mask forwarded to the model.
        labels: Target token ids; positions equal to ``PAD_IDX`` are ignored.

    Returns:
        Tuple ``(mean_loss, logits)`` where ``logits`` are cast to ``float32``.
    """
    logits = model(data_inputs, masks).astype(jnp.float32)

    valid = (labels != PAD_IDX)

    losses = optax.softmax_cross_entropy_with_integer_labels(logits=logits, labels=labels)
    # Clamp the denominator so a batch whose labels are all padding yields a
    # loss of 0 instead of NaN (0 / 0).
    n_valid = jnp.maximum(jnp.sum(valid), 1)
    mean_loss = jnp.sum(losses * valid) / n_valid
    return mean_loss, logits

In [None]:
# Jitted function returning ((loss, logits), grads) for one micro-batch.
grad_fnc = nnx.jit(nnx.value_and_grad(loss_fn, has_aux=True))

def train_step(model, optimizer: nnx.Optimizer, metrics: nnx.MultiMetric, x, mask ,y, x_stack, mask_stack, y_stack, pad_index):
    """Run one optimizer step with gradient accumulation.

    The first micro-batch is ``(x, mask, y)``; further micro-batches are taken
    along the leading axis of ``x_stack`` / ``mask_stack`` / ``y_stack``.
    Gradients are averaged over all micro-batches before the optimizer update,
    and ``metrics`` is updated with the mean loss and the non-padding
    logits/labels of every micro-batch.
    """
    n_micro_batches = 1 + x_stack.shape[0]

    # First micro-batch initialises the accumulators.
    (first_loss, first_logits), grad_sum = grad_fnc(model, x, mask, y)
    loss_sum = first_loss
    collected_logits = [first_logits]
    collected_labels = [y]

    # Remaining micro-batches.
    for mb_x, mb_mask, mb_y in zip(x_stack, mask_stack, y_stack):
        (mb_loss, mb_logits), mb_grads = grad_fnc(model, mb_x, mb_mask, mb_y)

        loss_sum = loss_sum + mb_loss
        grad_sum = jtree.tree_map(jnp.add, grad_sum, mb_grads)

        collected_logits.append(mb_logits)
        collected_labels.append(mb_y)

    # Average the accumulated gradients and apply them.
    mean_grads = jtree.tree_map(lambda g: g / n_micro_batches, grad_sum)
    optimizer.update(mean_grads)

    # Flatten predictions and targets, then drop padded targets for the metrics.
    n_classes = collected_logits[0].shape[-1]
    flat_logits = jnp.reshape(jnp.concatenate(collected_logits, axis=0), (-1, n_classes))
    flat_labels = jnp.reshape(jnp.concatenate(collected_labels, axis=0), (-1,))
    keep = (flat_labels != pad_index)

    metrics.update(
        loss=loss_sum / len(collected_logits),
        logits=flat_logits[keep],
        labels=flat_labels[keep],
    )
    
    
def eval_step(model, metrics: nnx.MultiMetric, x, masks, y, pad_index):
    """Evaluate one batch and update ``metrics`` in place.

    Computes the loss and logits with ``loss_fn``, flattens the batch and
    sequence axes, and feeds only the positions whose label differs from
    ``pad_index`` to the metrics.
    """
    loss, logits = loss_fn(model, x, masks, y)

    # Merge the batch and sequence axes: (B, S, V) -> (B*S, V) and (B, S) -> (B*S,).
    flat_logits = jnp.reshape(logits, (-1, logits.shape[-1]))
    flat_labels = jnp.reshape(y, (-1,))
    keep = (flat_labels != pad_index)

    metrics.update(loss=loss, logits=flat_logits[keep], labels=flat_labels[keep])  # In-place update.

# Model Sharding and Parallelism
This section handles model partitioning across TPU devices using JAX sharding constraints.

In [None]:
@nnx.jit(static_argnames=["mesh"])
def create_sharded_model(mesh):
    """Build the GPT model from the module-level ``config`` and shard its state.

    Inside the ``mesh`` context the model is created (seed 99), its state is
    constrained to the partition specs attached to the parameters, and the
    constrained state is written back into the model, which is returned.
    """
    with mesh:
        gpt = GPT(config, rngs=nnx.Rngs(99))               # not sharded yet
        raw_state = nnx.state(gpt)                         # model state as a pytree
        partition_specs = nnx.get_partition_spec(raw_state)
        nnx.update(gpt, jax.lax.with_sharding_constraint(raw_state, partition_specs))
        return gpt

In [None]:
try:
    model = create_sharded_model(mesh)
except Exception as e:
    # Report the failure, then re-raise: continuing without `model` would only
    # produce a confusing NameError on the next line.
    print(f"error: {e}")
    raise

# Training mode for the following cells.
model.train()
# NOTE: the second positional argument of optax.adamw is b1, so MOMENTUM sets b1.
optimizer = nnx.Optimizer(model, optax.adamw(LEARING_RATE, MOMENTUM))
metrics = nnx.MultiMetric(
    accuracy=nnx.metrics.Accuracy(),
    loss=nnx.metrics.Average('loss'),
)

# Endless batch generators with fixed, distinct seeds for train and validation.
train_generator = create_batch_generator(
    random.PRNGKey(0), ds["train"], MICRO_BATCH_SIZE, CONTEXT_SIZE, EOS_IDX, PAD_IDX
)
val_generator = create_batch_generator(
    random.PRNGKey(1), ds["validation"], MICRO_BATCH_SIZE, CONTEXT_SIZE, EOS_IDX, PAD_IDX
)

In [None]:
# Show the model's structure with NNX's built-in display helper.
nnx.display(model)

# Training Loop

In [None]:
# ---- Training loop with real-time plotting ----

# One entry is appended to every list at each evaluation point.
metrics_history = {
    'train_loss': [],
    'train_accuracy': [],
    'test_loss': [],
    'test_accuracy': [],
    'steps': []
}

def get_subkeys(key):
    """Split ``key`` into a new carry key and one subkey per micro-batch."""
    keys = random.split(key, GLOBAL_BATCH_SIZE // MICRO_BATCH_SIZE + 1)
    return keys[0], keys[1:]


with mesh:
    # Main training loop
    for step in tqdm(range(train_steps)):

        # One micro-batch is passed directly, the remaining ones are stacked
        # along a new leading axis for gradient accumulation.
        key, subkeys = get_subkeys(key)
        x, y, masks = next(train_generator)
        xs, ys, maskss = zip(*(next(train_generator) for _ in subkeys[1:]))
        x_stack, y_stack, masks_stack = jnp.stack(xs), jnp.stack(ys), jnp.stack(maskss)
        train_step(model, optimizer, metrics, x, masks, y, x_stack, masks_stack, y_stack, PAD_IDX)

        # Periodic evaluation and metric logging
        if step > 0 and (step % eval_every == 0 or step == train_steps - 1):
            # Log the training metrics accumulated since the last evaluation.
            for metric, value in metrics.compute().items():
                metrics_history[f'train_{metric}'].append(value)
            metrics.reset()  # Reuse the same metrics object for validation.

            # Validate in evaluation mode so dropout is disabled, then switch
            # back to training mode for the next steps.
            model.eval()
            for _ in range(GLOBAL_BATCH_SIZE // MICRO_BATCH_SIZE):
                x, y, masks = next(val_generator)
                eval_step(model, metrics, x, masks, y, PAD_IDX)
            model.train()

            # Log the validation metrics.
            for metric, value in metrics.compute().items():
                metrics_history[f'test_{metric}'].append(value)
            metrics.reset()

            # Record the step number for the x-axis
            metrics_history['steps'].append(step)

            # --- Update the plot for Kaggle ---
            # Clear the previous cell output; wait=True avoids flickering.
            clear_output(wait=True)

            print(
              f"[train] step: {step}, "
              f"loss: {metrics_history['train_loss'][-1]}, "
              f"accuracy: {metrics_history['train_accuracy'][-1] * 100}"
            )
            print(
              f"[test] step: {step}, "
              f"loss: {metrics_history['test_loss'][-1]}, "
              f"accuracy: {metrics_history['test_accuracy'][-1] * 100}"
            )

            # Recreate the figure and axes each time
            fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))

            ax1.set_title('Loss')
            ax1.set_xlabel('Training Step')
            ax1.set_ylabel('Loss')
            ax2.set_title('Accuracy')
            ax2.set_xlabel('Training Step')
            ax2.set_ylabel('Accuracy (%)')

            ax1.plot(metrics_history['steps'], metrics_history['train_loss'], 'r-', label='train_loss')
            ax1.plot(metrics_history['steps'], metrics_history['test_loss'], 'b-', label='test_loss')
            ax2.plot(metrics_history['steps'], [acc * 100 for acc in metrics_history['train_accuracy']], 'r-', label='train_accuracy')
            ax2.plot(metrics_history['steps'], [acc * 100 for acc in metrics_history['test_accuracy']], 'b-', label='test_accuracy')

            ax1.legend()
            ax2.legend()
            plt.tight_layout()  # Adjust layout to avoid overlap
            display(fig)        # Display the chart in the notebook
            plt.close(fig)      # Free the figure and avoid displaying it twice


print("Training finished.")

# Plot

In [None]:
import matplotlib.pyplot as plt  # Visualization

# Plot the recorded loss and accuracy curves against the training step at which
# each value was logged (the same x-axis as the live plot).
recorded_steps = metrics_history['steps']

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))
ax1.set_title('Loss')
ax1.set_xlabel('Training Step')
ax1.set_ylabel('Loss')
ax2.set_title('Accuracy')
ax2.set_xlabel('Training Step')
ax2.set_ylabel('Accuracy')
for split in ('train', 'test'):
    ax1.plot(recorded_steps, metrics_history[f'{split}_loss'], label=f'{split}_loss')
    ax2.plot(recorded_steps, metrics_history[f'{split}_accuracy'], label=f'{split}_accuracy')
ax1.legend()
ax2.legend()
plt.show()