In [1]:
import jax
import jax.numpy as jnp
import flax.linen as nn
from flax.training import train_state
from flax.training import checkpoints
from flax import struct
import optax
import numpy as np
from typing import Optional, Tuple, Any
import math
from datasets import load_dataset
from torch.utils.data import DataLoader, Dataset
import torch
from transformers import AutoTokenizer
import os
from tqdm import tqdm
import wandb
import os
import re

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import jax
import jax.numpy as jnp
# Check TPU availability and setup
def check_tpu_setup():
    """Report whether JAX can see TPU devices and print a short summary.

    Returns:
        bool: True when at least one TPU device is visible; False when the
        TPU backend yields no devices or probing it raises.
    """
    try:
        # Query the TPU backend once and reuse the result; on machines
        # without a TPU runtime this call raises instead of returning [].
        devices = jax.devices('tpu')
    except Exception as e:
        # Deliberately broad: this is a best-effort probe and must not abort
        # the notebook just because the TPU runtime is missing.
        print(f"❌ Error checking TPU: {e}")
        return False

    if not devices:
        print("❌ No TPU devices found")
        print(f"   Available devices: {jax.devices()}")
        return False

    print("✅ TPU detected!")
    print(f"   TPU devices: {len(devices)}")
    for i, device in enumerate(devices):
        print(f"   Device {i}: {device}")

    # jax.default_backend() is the public replacement for the deprecated
    # jax.lib.xla_bridge.get_backend().platform accessor.
    print(f"   JAX backend: {jax.default_backend()}")
    print(f"   Total TPU cores: {jax.device_count()}")
    print(f"   Local devices: {jax.local_device_count()}")
    return True

# Check current setup
print("Device Information:")
print(f"JAX version: {jax.__version__}")
# jax.default_backend() replaces the deprecated
# jax.lib.xla_bridge.get_backend().platform accessor.
print(f"Platform: {jax.default_backend()}")
print(f"Device count: {jax.device_count()}")

# Check TPU specifically
has_tpu = check_tpu_setup()

# Test a simple computation to confirm the default backend actually executes work
test_array = jnp.array([1, 2, 3, 4, 5])
result = jnp.sum(test_array ** 2)
print(f"\nTest computation result: {result}")
print(f"Computation device: {result.device}")

# Topology summary: global/local device counts and multi-process layout
print("jax.devices():", jax.devices())
print("jax.device_count():", jax.device_count())
print("jax.local_device_count():", jax.local_device_count())
print("jax.process_count():", jax.process_count())
print("jax.process_index():", jax.process_index())

Device Information:
JAX version: 0.6.2
Platform: cpu
Device count: 1
❌ Error checking TPU: Backend 'tpu' failed to initialize: INTERNAL: Failed to open libtpu.so: dlopen(libtpu.so, 0x0001): tried: 'libtpu.so' (no such file), '/System/Volumes/Preboot/Cryptexes/OSlibtpu.so' (no such file), '/opt/miniconda3/envs/trainLLM/bin/../lib/libtpu.so' (no such file), '/usr/lib/libtpu.so' (no such file, not in dyld cache), 'libtpu.so' (no such file). Available backends are ['cpu']


  print(f"Platform: {jax.lib.xla_bridge.get_backend().platform}")



Test computation result: 55
Computation device: TFRT_CPU_0
jax.devices(): [CpuDevice(id=0)]
jax.device_count(): 1
jax.local_device_count(): 1
jax.process_count(): 1
jax.process_index(): 0


In [3]:
import os
# os.environ['FLAX_USE_LEGACY_CHECKPOINTS'] = '1'  # Force legacy checkpoints to avoid Orbax issues

# Export the NCCL settings through the process environment.
# NOTE(review): an earlier cell has already run a JAX computation by the time
# this cell executes — confirm these variables are still picked up.
os.environ.update(
    NCCL_LL128_BUFFSIZE="-2",
    NCCL_LL_BUFFSIZE="-2",
    NCCL_PROTO="SIMPLE,LL,LL128",
)

# XLA_FLAGS is one space-separated string; every flag (including the last)
# is followed by a single space. Any previous value is overwritten.
os.environ['XLA_FLAGS'] = "".join(
    flag + " "
    for flag in (
        '--xla_gpu_triton_gemm_any=True',
        '--xla_gpu_enable_latency_hiding_scheduler=true',
    )
)

In [4]:
tokenizer = AutoTokenizer.from_pretrained("gpt2", token="")
# GPT-2's tokenizer ships without a padding token. Registering '[PAD]' through
# add_special_tokens adds it to the vocabulary with its own id; plain
# attribute assignment would not add it to the vocabulary.
# NOTE(review): this grows the vocabulary by one entry — any embedding/output
# layer indexed by these ids must have at least len(tokenizer) rows.
tokenizer.add_special_tokens({"pad_token": "[PAD]"})

In [65]:
# Configuration class for model parameters
from dataclasses import dataclass

@dataclass
class GPTConfig:
    """Hyper-parameters for the model, optimizer and training loop.

    ``min_lr`` and ``gradient_accumulation_steps`` are derived values. Leave
    them as ``None`` to have them computed from the other fields of *this*
    instance (so overriding ``lr``, ``batch_size`` or ``max_seq_len`` keeps
    them consistent); pass an explicit value to pin them.
    """
    vocab_size_english: int = 25000
    vocab_size_hindi: int = 25000
    max_seq_len: int = 64
    d_model: int = 512
    num_layers: int = 6
    num_heads: int = 8
    d_ff: int = 2048
    dropout_rate: float = 0.1
    lr: float = 6e-4
    # Derived in __post_init__ as 0.1 * lr when left as None.
    min_lr: Optional[float] = None
    warmup_steps: int = 700
    total_steps: int = 20000
    batch_size: int = 256
    required_bsz_tokens: int = 524288
    # Derived in __post_init__ as required_bsz_tokens // (batch_size * max_seq_len)
    # when left as None.
    gradient_accumulation_steps: Optional[int] = None
    mixed_precision: bool = True
    num_epochs: int = 1
    eval_steps: int = 200

    def __post_init__(self):
        """Fill in derived fields from the values of this instance."""
        if self.min_lr is None:
            self.min_lr = 0.1 * self.lr
        if self.gradient_accumulation_steps is None:
            tokens_per_micro_batch = self.batch_size * self.max_seq_len
            # Never accumulate over fewer than one micro-batch.
            self.gradient_accumulation_steps = max(
                1, int(self.required_bsz_tokens // tokens_per_micro_batch)
            )

config = GPTConfig()

In [49]:
from jax import config as jax_config

# Pick the default matmul precision from the mixed_precision switch, then
# apply it and report the choice.
if config.mixed_precision:
    jax_config.update("jax_default_matmul_precision", "bfloat16")
    print("Set JAX matmul precision to bfloat16 (mixed precision enabled)")
if not config.mixed_precision:
    jax_config.update("jax_default_matmul_precision", "float32")
    print("Set JAX matmul precision to float32 (mixed precision disabled)")

Set JAX matmul precision to bfloat16 (mixed precision enabled)


In [50]:
# Computation dtype used by every layer, selected by config.mixed_precision
def get_dtype():
    """Return ``jnp.bfloat16`` when mixed precision is on, else ``jnp.float32``.

    The flag is read from the module-level ``config`` on every call.
    """
    if config.mixed_precision:
        return jnp.bfloat16
    return jnp.float32

print(f"Using dtype: {get_dtype()} (mixed_precision={config.mixed_precision})")

Using dtype: <class 'jax.numpy.bfloat16'> (mixed_precision=True)


In [51]:
!wandb login

  pid, fd = os.forkpty()


[34m[1mwandb[0m: Currently logged in as: [33mrajceo2031[0m ([33mrentio[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [66]:
from flax.training import train_state

class TrainState(train_state.TrainState):
    """Flax ``TrainState`` extended with two extra fields.

    NOTE(review): the field names suggest gradient-accumulation bookkeeping,
    but none of the visible cells read or write them — confirm against the
    training loop.
    """
    # Defaults to None; presumably the running sum of gradients — TODO confirm.
    grad_accum: Any = None
    # Defaults to 0; presumably the number of micro-steps accumulated — TODO confirm.
    accum_step: int = 0

In [53]:
# Load the "train" and "validation" splits of roneneldan/TinyStories via
# datasets.load_dataset, one call per split, in that order.
train_dataset, val_dataset = (
    load_dataset("roneneldan/TinyStories", split=split_name, token='')
    for split_name in ("train", "validation")
)

Using the latest cached version of the dataset since roneneldan/TinyStories couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration 'default' at /Users/yuvrajsingh9886/.cache/huggingface/datasets/roneneldan___tiny_stories/default/0.0.0/f54c09fd23315a6f9c86f9dc80f725de7d8f9c64 (last modified on Tue Sep 23 04:41:01 2025).
Using the latest cached version of the dataset since roneneldan/TinyStories couldn't be found on the Hugging Face Hub
Found the latest cached dataset configuration 'default' at /Users/yuvrajsingh9886/.cache/huggingface/datasets/roneneldan___tiny_stories/default/0.0.0/f54c09fd23315a6f9c86f9dc80f725de7d8f9c64 (last modified on Tue Sep 23 04:41:01 2025).


In [67]:
class SelfAttention(nn.Module):
    """A single unmasked attention head followed by a projection to ``d_model``.

    Attributes:
        d_model: Width of the block's output.
        num_heads: Used only to derive the head width ``d_model // num_heads``.
        dropout_rate: Dropout rate applied to the projected output.
    """
    d_model: int = config.d_model
    num_heads: int = config.num_heads
    dropout_rate: float = config.dropout_rate

    def setup(self):
        self.head_size = self.d_model // self.num_heads

        def bias_free_dense(width):
            # All four projections share dtype, initializer and the no-bias setting.
            return nn.Dense(
                features=width,
                use_bias=False,
                dtype=get_dtype(),
                kernel_init=nn.initializers.normal(stddev=0.02),
            )

        # The attribute names double as parameter names, so they are kept as-is.
        self.d_Q = bias_free_dense(self.head_size)
        self.d_K = bias_free_dense(self.head_size)
        self.d_V = bias_free_dense(self.head_size)
        self.d_O = bias_free_dense(self.d_model)
        self.dropout = nn.Dropout(self.dropout_rate)

    def __call__(self, x, training=True):
        """Attend over ``x`` with no mask; dropout is active when ``training``."""
        _batch, _seq_len, _channels = x.shape  # unpacking requires a rank-3 input

        q_proj = self.d_Q(x)
        k_proj = self.d_K(x)
        v_proj = self.d_V(x)

        # Scaled dot-product scores, normalised over the key axis.
        scale = self.head_size ** -0.5
        scores = jnp.matmul(q_proj, jnp.swapaxes(k_proj, 1, 2)) * scale
        attn = nn.softmax(scores, axis=-1)

        projected = self.d_O(jnp.matmul(attn, v_proj))
        return self.dropout(projected, deterministic=not training)

In [68]:
class FullMHA(nn.Module):
    """Runs ``num_heads`` ``SelfAttention`` heads and mixes their outputs.

    The per-head outputs are concatenated on the last axis, passed through a
    biased dense layer of width ``d_model`` and then through dropout.
    """
    d_model: int = config.d_model
    num_heads: int = config.num_heads
    dropout_rate: float = config.dropout_rate

    def setup(self):
        head_kwargs = dict(
            d_model=self.d_model,
            num_heads=self.num_heads,
            dropout_rate=self.dropout_rate,
        )
        self.heads = [SelfAttention(**head_kwargs) for _ in range(self.num_heads)]

        # Output projection applied to the concatenated heads
        self.linear = nn.Dense(
            features=self.d_model,
            dtype=get_dtype(),
            kernel_init=nn.initializers.normal(stddev=0.02),
            bias_init=nn.initializers.zeros,
        )
        self.dropout = nn.Dropout(self.dropout_rate)

    def __call__(self, x, training=True):
        """Apply every head to ``x``, merge the results and project them."""
        per_head = []
        for head in self.heads:
            per_head.append(head(x, training))

        merged = jnp.concatenate(per_head, axis=-1)
        merged = self.linear(merged)
        return self.dropout(merged, deterministic=not training)

In [69]:
class MaskedSelfAttention(nn.Module):
    """A single causally-masked attention head with a projection to ``d_model``.

    Attributes:
        d_model: Width of the block's output.
        num_heads: Used only to derive the head width ``d_model // num_heads``.
        dropout_rate: Dropout rate applied to the projected output.
    """
    d_model: int = config.d_model
    num_heads: int = config.num_heads
    dropout_rate: float = config.dropout_rate

    def setup(self):
        self.head_size = self.d_model // self.num_heads

        def bias_free_dense(width):
            # All four projections share dtype, initializer and the no-bias setting.
            return nn.Dense(
                features=width,
                use_bias=False,
                dtype=get_dtype(),
                kernel_init=nn.initializers.normal(stddev=0.02),
            )

        # The attribute names double as parameter names, so they are kept as-is.
        self.d_Q = bias_free_dense(self.head_size)
        self.d_K = bias_free_dense(self.head_size)
        self.d_V = bias_free_dense(self.head_size)
        self.d_O = bias_free_dense(self.d_model)
        self.dropout = nn.Dropout(self.dropout_rate)

    def __call__(self, x, training=True):
        """Attend over ``x`` so each position sees only itself and earlier ones."""
        _batch, seq_len, _channels = x.shape

        q_proj = self.d_Q(x)
        k_proj = self.d_K(x)
        v_proj = self.d_V(x)

        scale = self.head_size ** -0.5
        scores = jnp.matmul(q_proj, jnp.swapaxes(k_proj, 1, 2)) * scale

        # Lower-triangular ones mark the allowed (query, key) pairs; every
        # other score becomes -inf so softmax assigns it zero weight.
        allowed = jnp.tril(jnp.ones((seq_len, seq_len)))
        blocked = allowed == 0
        scores = jnp.where(blocked, -jnp.inf, scores)

        attn = nn.softmax(scores, axis=-1)

        projected = self.d_O(jnp.matmul(attn, v_proj))
        return self.dropout(projected, deterministic=not training)

In [70]:
class MaskedMHA(nn.Module):
    """Runs ``num_heads`` ``MaskedSelfAttention`` heads and mixes their outputs.

    The per-head outputs are concatenated on the last axis, passed through a
    biased dense layer of width ``d_model`` and then through dropout.
    """
    d_model: int = config.d_model
    num_heads: int = config.num_heads
    dropout_rate: float = config.dropout_rate

    def setup(self):
        head_kwargs = dict(
            d_model=self.d_model,
            num_heads=self.num_heads,
            dropout_rate=self.dropout_rate,
        )
        self.heads = [MaskedSelfAttention(**head_kwargs) for _ in range(self.num_heads)]

        # Output projection applied to the concatenated heads
        self.linear = nn.Dense(
            features=self.d_model,
            dtype=get_dtype(),
            kernel_init=nn.initializers.normal(stddev=0.02),
            bias_init=nn.initializers.zeros,
        )
        self.dropout = nn.Dropout(self.dropout_rate)

    def __call__(self, x, training=True):
        """Apply every masked head to ``x``, merge the results and project them."""
        per_head = []
        for head in self.heads:
            per_head.append(head(x, training))

        merged = jnp.concatenate(per_head, axis=-1)
        merged = self.linear(merged)
        return self.dropout(merged, deterministic=not training)

In [71]:
class CrossAttention(nn.Module):
    """A single unmasked attention head whose query, key and value inputs differ.

    Attributes:
        d_model: Width of the block's output.
        num_heads: Used only to derive the head width ``d_model // num_heads``.
        dropout_rate: Dropout rate applied to the projected output.
    """
    d_model: int = config.d_model
    num_heads: int = config.num_heads
    dropout_rate: float = config.dropout_rate

    def setup(self):
        self.head_size = self.d_model // self.num_heads

        def bias_free_dense(width):
            # All four projections share dtype, initializer and the no-bias setting.
            return nn.Dense(
                features=width,
                use_bias=False,
                dtype=get_dtype(),
                kernel_init=nn.initializers.normal(stddev=0.02),
            )

        # The attribute names double as parameter names, so they are kept as-is.
        self.d_Q = bias_free_dense(self.head_size)
        self.d_K = bias_free_dense(self.head_size)
        self.d_V = bias_free_dense(self.head_size)
        self.d_O = bias_free_dense(self.d_model)
        self.dropout = nn.Dropout(self.dropout_rate)

    def __call__(self, q, k, v, training=True):
        """Attend from ``q`` over ``k``/``v`` with no mask applied."""
        _batch, _q_len, _channels = q.shape  # unpacking requires a rank-3 query

        q_proj = self.d_Q(q)
        k_proj = self.d_K(k)
        v_proj = self.d_V(v)

        # Scaled dot-product scores, normalised over the key axis.
        scale = self.head_size ** -0.5
        scores = jnp.matmul(q_proj, jnp.swapaxes(k_proj, 1, 2)) * scale
        attn = nn.softmax(scores, axis=-1)

        projected = self.d_O(jnp.matmul(attn, v_proj))
        return self.dropout(projected, deterministic=not training)

In [72]:
class CrossMHA(nn.Module):
    """Runs ``num_heads`` ``CrossAttention`` heads and mixes their outputs.

    The per-head outputs are concatenated on the last axis, passed through a
    biased dense layer of width ``d_model`` and then through dropout.
    """
    d_model: int = config.d_model
    num_heads: int = config.num_heads
    dropout_rate: float = config.dropout_rate

    def setup(self):
        head_kwargs = dict(
            d_model=self.d_model,
            num_heads=self.num_heads,
            dropout_rate=self.dropout_rate,
        )
        self.heads = [CrossAttention(**head_kwargs) for _ in range(self.num_heads)]

        # Output projection applied to the concatenated heads
        self.linear = nn.Dense(
            features=self.d_model,
            dtype=get_dtype(),
            kernel_init=nn.initializers.normal(stddev=0.02),
            bias_init=nn.initializers.zeros,
        )
        self.dropout = nn.Dropout(self.dropout_rate)

    def __call__(self, q, k, v, training=True):
        """Apply every head to ``(q, k, v)``, merge the results and project them."""
        per_head = []
        for head in self.heads:
            per_head.append(head(q, k, v, training))

        merged = jnp.concatenate(per_head, axis=-1)
        merged = self.linear(merged)
        return self.dropout(merged, deterministic=not training)

In [73]:
class MLP(nn.Module):
    """Position-wise feed-forward block: Dense -> ReLU -> Dense -> Dropout.

    Attributes:
        d_model: Output width of the second dense layer.
        d_ff: Width of the hidden (first) dense layer.
        dropout_rate: Dropout rate applied to the block output.
    """
    d_model: int = config.d_model
    d_ff: int = config.d_ff
    dropout_rate: float = config.dropout_rate

    def setup(self):
        def biased_dense(width):
            # Both layers share dtype and initializers; only the width differs.
            return nn.Dense(
                features=width,
                dtype=get_dtype(),
                kernel_init=nn.initializers.normal(stddev=0.02),
                bias_init=nn.initializers.zeros,
            )

        # The attribute names double as parameter names, so they are kept as-is.
        self.fc1 = biased_dense(self.d_ff)
        self.fc2 = biased_dense(self.d_model)
        self.dropout = nn.Dropout(self.dropout_rate)

    def __call__(self, x, training=True):
        """Expand to ``d_ff``, apply ReLU, project to ``d_model``, then dropout."""
        hidden = nn.relu(self.fc1(x))
        projected = self.fc2(hidden)
        return self.dropout(projected, deterministic=not training)

In [74]:
class DecoderBlock(nn.Module):
    """Decoder layer: masked self-attention, cross-attention, then MLP.

    Each sub-layer output is added to its input and layer-normalised
    afterwards (``ln(x + sublayer(x))``).
    """
    d_model: int = config.d_model
    num_heads: int = config.num_heads
    d_ff: int = config.d_ff
    dropout_rate: float = config.dropout_rate

    def setup(self):
        attn_args = (self.d_model, self.num_heads, self.dropout_rate)
        self.masked_attention = MaskedMHA(*attn_args)
        self.cross_attention = CrossMHA(*attn_args)
        self.mlp = MLP(self.d_model, self.d_ff, self.dropout_rate)
        self.ln1 = nn.LayerNorm(dtype=get_dtype())
        self.ln2 = nn.LayerNorm(dtype=get_dtype())
        self.ln3 = nn.LayerNorm(dtype=get_dtype())

    def __call__(self, x, k, v, training=True):
        """Run the three sub-layers; ``k`` and ``v`` feed the cross-attention."""
        # 1) masked self-attention over the decoder stream
        self_attended = self.masked_attention(x, training)
        hidden = self.ln1(x + self_attended)

        # 2) cross-attention: queries from the decoder stream, keys/values from k, v
        cross_attended = self.cross_attention(hidden, k, v, training)
        hidden = self.ln2(hidden + cross_attended)

        # 3) position-wise feed-forward
        fed_forward = self.mlp(hidden, training)
        return self.ln3(hidden + fed_forward)

In [75]:
class EncoderBlock(nn.Module):
    """Encoder layer: unmasked self-attention followed by an MLP.

    Each sub-layer output is added to its input and layer-normalised
    afterwards (``ln(x + sublayer(x))``), the same residual pattern that
    ``DecoderBlock`` uses. Without the skip connections the block input only
    reaches later layers through the sub-layers themselves.

    Attributes:
        d_model: Model width.
        num_heads: Number of attention heads.
        d_ff: Hidden width of the MLP.
        dropout_rate: Dropout rate forwarded to the sub-layers.
    """
    d_model: int = config.d_model
    num_heads: int = config.num_heads
    d_ff: int = config.d_ff
    dropout_rate: float = config.dropout_rate

    def setup(self):
        self.attention = FullMHA(self.d_model, self.num_heads, self.dropout_rate)
        self.mlp = MLP(self.d_model, self.d_ff, self.dropout_rate)
        self.ln1 = nn.LayerNorm(dtype=get_dtype())
        self.ln2 = nn.LayerNorm(dtype=get_dtype())

    def __call__(self, x, training=True):
        """Apply attention and MLP, each with a residual connection.

        Args:
            x: Input activations; the attention heads require a rank-3 array.
            training: When True, dropout inside the sub-layers is active.

        Returns:
            Array with the same shape as ``x``.
        """
        # Residual + post-norm around self-attention
        x = self.ln1(x + self.attention(x, training))
        # Residual + post-norm around the feed-forward block
        return self.ln2(x + self.mlp(x, training))

In [76]:
class Transformer(nn.Module):
    """Encoder-decoder transformer with separate source and target embeddings.

    The encoder embeds ``enc_input`` with the "english" tables, the decoder
    embeds ``dec_input`` with the "hindi" tables, and a dense head of width
    ``vocab_size__hindi`` produces the output logits.
    """
    vocab_size__english: int = config.vocab_size_english
    vocab_size__hindi: int = config.vocab_size_hindi
    max_seq_len: int = config.max_seq_len
    d_model: int = config.d_model
    num_layers: int = config.num_layers
    num_heads: int = config.num_heads
    d_ff: int = config.d_ff
    dropout_rate: float = config.dropout_rate

    def setup(self):
        table_init = nn.initializers.normal(stddev=0.02)

        def init_positions(key):
            # Learned positions: one (1, max_seq_len, d_model) table, scaled by 0.01.
            return jax.random.normal(key, (1, self.max_seq_len, self.d_model), dtype=get_dtype()) * 0.01

        # Token embedding tables
        self.token_embedding_eng = nn.Embed(
            num_embeddings=self.vocab_size__english,
            features=self.d_model,
            embedding_init=table_init,
            dtype=get_dtype(),
        )
        self.token_embedding_hindi = nn.Embed(
            num_embeddings=self.vocab_size__hindi,
            features=self.d_model,
            embedding_init=table_init,
            dtype=get_dtype(),
        )

        # Positional tables (parameter names are part of the checkpoint layout)
        self.positional_embedding_english = self.param("positional_embeddings_english", init_positions)
        self.positional_embedding_hindi = self.param("positional_embeddings_hindi", init_positions)

        # Layer stacks
        block_args = (self.d_model, self.num_heads, self.d_ff, self.dropout_rate)
        self.encoder_layers = [EncoderBlock(*block_args) for _ in range(self.num_layers)]
        self.decoder_layers = [DecoderBlock(*block_args) for _ in range(self.num_layers)]

        # Output head over the target vocabulary
        self.head = nn.Dense(
            features=self.vocab_size__hindi,
            dtype=get_dtype(),
            kernel_init=nn.initializers.normal(stddev=0.02),
            bias_init=nn.initializers.zeros,
        )
        self.dropout = nn.Dropout(self.dropout_rate)

    def __call__(self, enc_input, dec_input, training=True):
        """Return logits for ``dec_input`` conditioned on ``enc_input``.

        Both inputs must be rank-2 integer id arrays; positions are sliced
        from the learned tables to each input's second-axis length.
        """
        _, enc_len = enc_input.shape
        _, dec_len = dec_input.shape
        deterministic = not training

        # Encoder stream
        memory = self.token_embedding_eng(enc_input) + self.positional_embedding_english[:, :enc_len, :]
        memory = self.dropout(memory, deterministic=deterministic)
        for encoder_layer in self.encoder_layers:
            memory = encoder_layer(memory, training)

        # Decoder stream; the final encoder output serves as both keys and values
        hidden = self.token_embedding_hindi(dec_input) + self.positional_embedding_hindi[:, :dec_len, :]
        hidden = self.dropout(hidden, deterministic=deterministic)
        for decoder_layer in self.decoder_layers:
            hidden = decoder_layer(hidden, memory, memory, training)

        return self.head(hidden)

In [77]:
# Build the model once, tabulate its layers (torchsummary-style) and write a
# plain-text report to model_summary.txt.

from flax.linen import tabulate
import jax

# Model instance and dummy encoder/decoder inputs used for tracing
model = Transformer()
key = jax.random.PRNGKey(0)
x1 = jnp.ones((1, config.max_seq_len), dtype=jnp.int32)  # encoder-side dummy ids
x2 = jnp.ones((1, config.max_seq_len), dtype=jnp.int32)  # decoder-side dummy ids

# Callable that renders the layer table
tabulate_fn = tabulate(model, key, console_kwargs={'width': 120})

# Initialise parameters and count every scalar in the tree
params = model.init(key, x1, x2)['params']
total_params = sum(leaf.size for leaf in jax.tree_util.tree_leaves(params))

# Render the table, then strip ANSI colour escapes so the log file is plain text
raw_summary = tabulate_fn(x1, x2, training=True)
clean_summary = re.sub(r'\x1b\[[0-9;]*m', '', raw_summary)

# Assemble the report and write it in a single call
with open('model_summary.txt', 'w') as f:
    f.write("".join([
        "=" * 60 + "\n",
        "TRANSFORMER MODEL ARCHITECTURE SUMMARY\n",
        "=" * 60 + "\n\n",
        f"Total Parameters: {total_params:,}\n",
        "Model Configuration:\n",
        f"  - Vocabulary Size English: {config.vocab_size_english:,}\n",
        f"  - Vocabulary Size Hindi: {config.vocab_size_hindi:,}\n",
        f"  - Max Sequence Length: {config.max_seq_len}\n",
        f"  - Model Dimension: {config.d_model}\n",
        f"  - Number of Layers: {config.num_layers}\n",
        f"  - Number of Heads: {config.num_heads}\n",
        f"  - Feed Forward Dimension: {config.d_ff}\n",
        f"  - Dropout Rate: {config.dropout_rate}\n\n",
        "Detailed Layer Information:\n",
        "-" * 40 + "\n",
        clean_summary,
    ]))

print("Model summary saved to model_summary.txt")
print(f"Total Parameters: {total_params:,}")
# Size estimate assumes two bytes per parameter
print(f"Model size: ~{total_params * 2 / (1024**2):.1f} MB (bfloat16)")

Model summary saved to model_summary.txt
Total Parameters: 120,350,120
Model size: ~229.5 MB (bfloat16)


In [78]:
def create_learning_rate_schedule():
    """Build the step -> learning-rate function handed to the optimizer.

    The schedule is
    ``num_layers ** -0.5 * min(step ** -0.5, step * warmup_steps ** -1.5)``:
    it grows linearly for ``config.warmup_steps`` steps and then decays with
    the inverse square root of the step. There is no cosine decay, and
    ``config.lr``, ``config.min_lr`` and ``config.total_steps`` do not enter
    the formula.

    Returns:
        A callable mapping a step count (Python number or JAX array, traced
        or concrete) to a scalar learning rate.
    """
    warmup_steps = config.warmup_steps
    # NOTE(review): the schedule this resembles (Vaswani et al.) scales by
    # d_model ** -0.5; num_layers is kept here to preserve existing behaviour
    # — confirm which was intended.
    scale = config.num_layers ** -0.5

    def get_lr(it):
        # The optimizer passes its step count as a JAX array, traced when the
        # update runs under jit; Python's min() cannot compare tracers, so
        # use jnp.minimum.
        step = jnp.asarray(it, dtype=jnp.float32)
        # Clamp only the decay branch so step 0 yields 0.0 instead of
        # evaluating 0 ** -0.5.
        decay = jnp.maximum(step, 1.0) ** -0.5
        warmup = step * warmup_steps ** -1.5
        return scale * jnp.minimum(decay, warmup)

    return get_lr

In [None]:
def compute_ce_loss(logits, labels, pad_token_id=None):
    """Mean next-token cross-entropy over non-padding targets.

    ``labels`` is shifted left by one and ``logits`` loses its last position,
    so the logits at position ``t`` are scored against the label at ``t + 1``.

    Args:
        logits: Array indexed as ``[batch, position, vocab]``.
        labels: Integer array indexed as ``[batch, position]``.
        pad_token_id: Label id to exclude from the loss. Defaults to
            ``tokenizer.pad_token_id`` from the module-level tokenizer.

    Returns:
        Scalar float32 loss; 0.0 when every target is padding.
    """
    if pad_token_id is None:
        pad_token_id = tokenizer.pad_token_id

    targets = labels[:, 1:]
    # Reduce in float32: the model may emit bfloat16 logits.
    predictions = logits[:, :-1, :].astype(jnp.float32)

    per_token = optax.softmax_cross_entropy_with_integer_labels(predictions, targets)
    keep = targets != pad_token_id
    total = jnp.sum(jnp.where(keep, per_token, 0.0))
    # Guard against an all-padding batch, which would otherwise divide by zero.
    return total / jnp.maximum(keep.sum(), 1)

In [None]:
def create_train_state(rng, config):
    """Create the initial ``TrainState`` (model parameters plus optimizer).

    Args:
        rng: PRNG key used to initialise the model parameters.
        config: Configuration object providing the model hyper-parameters.
            NOTE(review): the layer dtype (``get_dtype``) and the learning
            rate schedule still read the module-level ``config``.

    Returns:
        A ``TrainState`` holding ``model.apply``, the initial parameters and
        an Adam optimizer preceded by global-norm gradient clipping.
    """
    # Build the model from the config that was passed in rather than relying
    # on the class-level defaults.
    model = Transformer(
        vocab_size__english=config.vocab_size_english,
        vocab_size__hindi=config.vocab_size_hindi,
        max_seq_len=config.max_seq_len,
        d_model=config.d_model,
        num_layers=config.num_layers,
        num_heads=config.num_heads,
        d_ff=config.d_ff,
        dropout_rate=config.dropout_rate,
    )

    # Transformer.__call__ takes the encoder and decoder ids as two separate
    # positional arguments, so they must not be wrapped in a tuple.
    dummy_enc_input = jnp.ones((1, config.max_seq_len), dtype=jnp.int32)
    dummy_dec_input = jnp.ones((1, config.max_seq_len), dtype=jnp.int32)
    # training=False keeps dropout deterministic, so only the parameter RNG
    # is needed for initialisation.
    params = model.init(rng, dummy_enc_input, dummy_dec_input, training=False)['params']

    # Create learning rate schedule
    lr_schedule = create_learning_rate_schedule()

    # Clip gradients to global norm 1.0, then apply Adam
    tx = optax.chain(
        optax.clip_by_global_norm(1.0),
        optax.adam(
            learning_rate=lr_schedule,
            b1=0.9,
            b2=0.98,
            eps=1e-9,
        )
    )

    return TrainState.create(
        apply_fn=model.apply,
        params=params,
        tx=tx
    )

In [None]:
@jax.jit
def train_step(state, batch, step):
    """Run one optimisation step and return the updated state.

    Args:
        state: Current ``TrainState``.
        batch: Pair ``(enc_input, dec_input)`` of token-id arrays, matching
            the two inputs of ``Transformer.__call__``.
            NOTE(review): assumed layout — confirm against the data loader.
        step: Either an integer step counter or a PRNG key. A scalar integer
            is folded into a fixed base key to derive the dropout key; any
            other value is used as the dropout key directly.

    Returns:
        Tuple ``(state, loss, grad_norm)`` where ``grad_norm`` is the global
        L2 norm of the gradients before the optimizer's clipping.
    """
    enc_input, dec_input = batch

    # Flax requires a real PRNG key for the 'dropout' stream; a bare step
    # counter is not one, so derive a per-step key from it.
    step_dtype = getattr(step, "dtype", None)
    is_counter = getattr(step, "ndim", 0) == 0 and (
        step_dtype is None or jnp.issubdtype(step_dtype, jnp.integer)
    )
    if is_counter:
        dropout_rng = jax.random.fold_in(jax.random.PRNGKey(0), step)
    else:
        dropout_rng = step

    def loss_fn(params):
        logits = state.apply_fn(
            {"params": params},
            enc_input,
            dec_input,
            training=True,
            rngs={'dropout': dropout_rng},
        )
        # Teacher forcing: the decoder input doubles as the target sequence;
        # compute_ce_loss performs the one-position shift.
        return compute_ce_loss(logits, dec_input)

    loss, grads = jax.value_and_grad(loss_fn)(state.params)

    # Compute gradient norm for logging
    squared_sums = [jnp.sum(jnp.square(g)) for g in jax.tree_util.tree_leaves(grads)]
    grad_norm = jnp.sqrt(sum(squared_sums))

    # Update the parameters
    state = state.apply_gradients(grads=grads)
    return state, loss, grad_norm

In [None]:
# JIT-compiled evaluation step
@jax.jit
def eval_step(state, batch, step):
    """Compute the loss on one batch without updating parameters.

    Args:
        state: Current ``TrainState``.
        batch: Pair ``(enc_input, dec_input)`` of token-id arrays, matching
            the two inputs of ``Transformer.__call__``.
            NOTE(review): assumed layout — confirm against the data loader.
        step: Unused; kept so existing call sites keep working.

    Returns:
        Tuple ``(loss, None)``.
    """
    enc_input, dec_input = batch
    # training=False makes every dropout layer deterministic, so no 'dropout'
    # RNG stream is needed.
    logits = state.apply_fn({'params': state.params}, enc_input, dec_input, training=False)
    loss = compute_ce_loss(logits, dec_input)

    return loss, None