In [None]:
# This Python 3 environment comes with many helpful analytics libraries installed
# It is defined by the kaggle/python Docker image: https://github.com/kaggle/docker-python
# For example, here's several helpful packages to load

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)

# Input data files are available in the read-only "../input/" directory
# For example, running this (by clicking run or pressing Shift+Enter) will list all files under the input directory

import os
for dirname, _, filenames in os.walk('/kaggle/input'):
    for filename in filenames:
        print(os.path.join(dirname, filename))

# You can write up to 20GB to the current directory (/kaggle/working/) that gets preserved as output when you create a version using "Save & Run All" 
# You can also write temporary files to /kaggle/temp/, but they won't be saved outside of the current session

In [1]:
# ============================================================
# Toy Transformer using JAX + Flax on AG News
# ============================================================

!pip install -q flax optax datasets tensorboard

# ============================================================
# 0. GPU preference + bfloat16 global policy
# ============================================================
import os
os.environ["JAX_PLATFORM_NAME"] = "gpu"

import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.training import train_state
import optax
from datasets import load_dataset
from torch.utils.tensorboard import SummaryWriter
import numpy as np

# Check available backend
print("JAX default backend:", jax.default_backend())
print("Available devices:", jax.devices())

# ============================================================
# 1. Enable bfloat16 precision (optional fallback to float32)
# ============================================================
if any(d.device_kind == "TPU" or "P100" in d.device_kind for d in jax.devices()):
    dtype = jnp.bfloat16
    print(" Using bfloat16 precision")
else:
    dtype = jnp.float32
    print("Using float32 (fallback, GPU/CPU without bfloat16 support)")

# ============================================================
# 2. Dataset (AG News small subset)
# ============================================================
key = jax.random.PRNGKey(42)
dataset = load_dataset("ag_news")
train_data = dataset["train"].shuffle(seed=0).select(range(500))
test_data = dataset["test"].select(range(200))

def tokenize(text):
    return text.lower().split()

vocab = {}
def encode(text, max_len=32):
    tokens = tokenize(text)
    ids = []
    for t in tokens:
        if t not in vocab:
            vocab[t] = len(vocab) + 1
        ids.append(vocab[t])
    ids = ids[:max_len]
    arr = np.array(ids + [0]*(max_len - len(ids)))
    return arr

X_train = np.stack([encode(x["text"]) for x in train_data])
y_train = np.array([x["label"] for x in train_data])
X_test = np.stack([encode(x["text"]) for x in test_data])
y_test = np.array([x["label"] for x in test_data])

vocab_size = len(vocab) + 1
num_classes = 4
max_len = X_train.shape[1]

# ============================================================
# 3. Transformer Encoder (bfloat16 support)
# ============================================================
class PositionalEncoding(nn.Module):
    emb_dim: int
    max_len: int
    dtype: any

    @nn.compact
    def __call__(self, x):
        pos = jnp.arange(self.max_len)[:, None]
        i = jnp.arange(self.emb_dim)[None, :]
        angle_rates = 1 / jnp.power(10000, (2 * (i//2)) / self.emb_dim)
        angle_rads = pos * angle_rates
        pos_encoding = jnp.where(i % 2 == 0,
                                 jnp.sin(angle_rads),
                                 jnp.cos(angle_rads))
        pos_encoding = pos_encoding[jnp.newaxis, :, :].astype(self.dtype)
        return x + pos_encoding[:, :x.shape[1], :]

class FeedForward(nn.Module):
    emb_dim: int
    hidden_dim: int
    dtype: any

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(self.hidden_dim, dtype=self.dtype)(x)
        x = nn.relu(x)
        x = nn.Dense(self.emb_dim, dtype=self.dtype)(x)
        return x

class TransformerBlock(nn.Module):
    emb_dim: int
    num_heads: int
    ff_dim: int
    dtype: any

    @nn.compact
    def __call__(self, x):
        attn_out = nn.MultiHeadDotProductAttention(
            num_heads=self.num_heads,
            dtype=self.dtype
        )(x, x)
        x = nn.LayerNorm(dtype=self.dtype)(x + attn_out)

        ff_out = FeedForward(self.emb_dim, self.ff_dim, self.dtype)(x)
        x = nn.LayerNorm(dtype=self.dtype)(x + ff_out)
        return x

class TransformerEncoder(nn.Module):
    vocab_size: int
    max_len: int
    emb_dim: int = 64
    num_heads: int = 4
    ff_dim: int = 128
    num_layers: int = 2
    num_classes: int = 4
    dtype: any = jnp.float32

    @nn.compact
    def __call__(self, x):
        # Embedding + positional encoding
        x = nn.Embed(self.vocab_size, self.emb_dim, dtype=self.dtype)(x)
        x = PositionalEncoding(self.emb_dim, self.max_len, self.dtype)(x)

        for _ in range(self.num_layers):
            x = TransformerBlock(self.emb_dim, self.num_heads, self.ff_dim, self.dtype)(x)

        x = jnp.mean(x, axis=1)
        x = nn.Dense(self.num_classes, dtype=jnp.float32)(x)  # logits in float32
        return x

# ============================================================
# 4. Initialize model & optimizer
# ============================================================
model = TransformerEncoder(vocab_size=vocab_size, max_len=max_len, dtype=dtype)
params = model.init(key, jnp.ones((1, max_len), dtype=jnp.int32))
tx = optax.adam(1e-3)

state = train_state.TrainState.create(
    apply_fn=model.apply,
    params=params,
    tx=tx
)

# ============================================================
# 5. Training functions
# ============================================================
def cross_entropy_loss(params, batch):
    logits = state.apply_fn(params, batch["inputs"])
    one_hot = jax.nn.one_hot(batch["labels"], num_classes=num_classes)
    return optax.softmax_cross_entropy(logits, one_hot).mean()

@jax.jit
def train_step(state, batch):
    grads = jax.grad(cross_entropy_loss)(state.params, batch)
    return state.apply_gradients(grads=grads)

@jax.jit
def compute_accuracy(params, batch):
    preds = jnp.argmax(state.apply_fn(params, batch["inputs"]), axis=-1)
    return jnp.mean(preds == batch["labels"])

# ============================================================
# 6. TensorBoard setup
# ============================================================
writer = SummaryWriter("runs/jax_bfloat16_transformer")

# ============================================================
# 7. Training loop
# ============================================================
for epoch in range(50):
    batch = {
        "inputs": jax.device_put(jnp.array(X_train, dtype=jnp.int32)),
        "labels": jax.device_put(jnp.array(y_train))
    }

    state = train_step(state, batch)
    loss = cross_entropy_loss(state.params, batch)
    acc = compute_accuracy(state.params, batch)

    writer.add_scalar("Loss/train", float(loss), epoch)
    writer.add_scalar("Accuracy/train", float(acc), epoch)

    print(f"Epoch {epoch+1}: Loss={loss:.4f}, Accuracy={acc*100:.2f}%")

writer.close()
print("Training complete — visualize with TensorBoard!")


[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m42.8/42.8 MB[0m [31m26.1 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
[?25h[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
bigframes 2.12.0 requires google-cloud-bigquery-storage<3.0.0,>=2.30.0, which is not installed.
pylibcudf-cu12 25.2.2 requires pyarrow<20.0.0a0,>=14.0.0; platform_machine == "x86_64", but you have pyarrow 21.0.0 which is incompatible.
cudf-cu12 25.2.2 requires pyarrow<20.0.0a0,>=14.0.0; platform_machine == "x86_64", but you have pyarrow 21.0.0 which is incompatible.
bigframes 2.12.0 requires google-cloud-bigquery[bqstorage,pandas]>=3.31.0, but you have google-cloud-bigquery 3.25.0 which is incompatible.
bigframes 2.12.0 requires rich<14,>=12.4.4, but you have rich 14.1.0 which is incompatible.
cudf-polars-cu12 25.6.0 requires pylibcudf-cu12==25.6.*, but you have pylibc

2025-10-09 18:01:29.781306: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1760032889.968786      37 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1760032890.026787      37 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
INFO:2025-10-09 18:01:39,351:jax._src.xla_bridge:924: Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
INFO:2025-10-09 18:01:39,363:jax._src.xla_bridge:924: Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory


JAX default backend: gpu
Available devices: [CudaDevice(id=0)]
 Using bfloat16 precision


README.md: 0.00B [00:00, ?B/s]

data/train-00000-of-00001.parquet:   0%|          | 0.00/18.6M [00:00<?, ?B/s]

data/test-00000-of-00001.parquet:   0%|          | 0.00/1.23M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/120000 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/7600 [00:00<?, ? examples/s]

Epoch 1: Loss=1.4297, Accuracy=29.00%
Epoch 2: Loss=1.5390, Accuracy=22.00%
Epoch 3: Loss=1.4309, Accuracy=22.00%
Epoch 4: Loss=1.3730, Accuracy=23.20%
Epoch 5: Loss=1.4010, Accuracy=29.20%
Epoch 6: Loss=1.4179, Accuracy=29.00%
Epoch 7: Loss=1.3968, Accuracy=29.00%
Epoch 8: Loss=1.3815, Accuracy=39.80%
Epoch 9: Loss=1.3763, Accuracy=25.80%
Epoch 10: Loss=1.3620, Accuracy=25.80%
Epoch 11: Loss=1.3478, Accuracy=46.00%
Epoch 12: Loss=1.3472, Accuracy=40.20%
Epoch 13: Loss=1.3524, Accuracy=44.20%
Epoch 14: Loss=1.3503, Accuracy=45.00%
Epoch 15: Loss=1.3395, Accuracy=44.80%
Epoch 16: Loss=1.3266, Accuracy=55.20%
Epoch 17: Loss=1.3167, Accuracy=48.60%
Epoch 18: Loss=1.3085, Accuracy=72.40%
Epoch 19: Loss=1.2979, Accuracy=54.60%
Epoch 20: Loss=1.2863, Accuracy=53.40%
Epoch 21: Loss=1.2770, Accuracy=47.80%
Epoch 22: Loss=1.2631, Accuracy=47.00%
Epoch 23: Loss=1.2381, Accuracy=52.40%
Epoch 24: Loss=1.2105, Accuracy=54.40%
Epoch 25: Loss=1.1790, Accuracy=55.20%
Epoch 26: Loss=1.1389, Accuracy=75

In [3]:
from tensorboard import notebook

notebook.start("--logdir runs/jax_bfloat16_transformer")


<IPython.core.display.Javascript object>

In [4]:
import os
print(os.listdir("runs/jax_bfloat16_transformer"))


['events.out.tfevents.1760032910.420c89ce379d.37.0']
