In [None]:
# File: examples/comparison/quick_demo.py
"""Quick demonstration of the core difference between Datarax and Grain.

Compares the Datarax (stateful) and Grain (stateless) approaches to tracking iteration state.

Run this to immediately see why stateful is better!
"""

In [None]:
import flax.nnx as nnx
import numpy as np

In [None]:
# Title banner: the headline framed by two 60-character rules.
print(
    "=" * 60,
    "DATARAX vs GRAIN: CORE DIFFERENCE IN 100 LINES",
    "=" * 60,
    sep="\n",
)

## GRAIN APPROACH: External State Management (Complex)

In [None]:
def grain_iterate_data(data, batch_size, state):
    """Grain-style: return the next batch together with the updated state.

    The function keeps no state of its own: the caller passes the current
    state in and must carry the returned state into the next call.

    Args:
        data: Sliceable sequence that supports ``len()``.
        batch_size: Maximum number of items to take from ``data``.
        state: Mapping with ``"position"`` (index of the next unread item)
            and ``"samples_seen"`` (running count of items returned so far).
            It is read but not modified.

    Returns:
        A ``(batch, new_state)`` tuple. ``batch`` is ``data[position:end]``
        and may be shorter than ``batch_size`` (or empty) near the end of
        ``data``; ``new_state`` is a fresh dict with the advanced position
        and the updated sample count.
    """
    position = state["position"]

    # Clamp the slice end so the final batch may be short instead of overrunning.
    end = min(position + batch_size, len(data))
    batch = data[position:end]

    # Count what was actually returned rather than `end - position`: for an
    # in-range position the two are equal, but a position past the end of
    # `data` yields an empty batch and must not decrease samples_seen.
    new_state = {"position": end, "samples_seen": state["samples_seen"] + len(batch)}

    return batch, new_state  # The caller has to keep both.

In [None]:
# Section 1 heading: the stateless (Grain-style) walkthrough.
print("\n1. GRAIN APPROACH (Stateless):", "-" * 40, sep="\n")

In [None]:
# Bookkeeping the caller has to own and thread through every call,
# plus ten integers to iterate over.
state = dict(position=0, samples_seen=0)
data = np.arange(0, 10)

In [None]:
# Three manual steps: each call receives the current state, and the returned
# state has to be rebound or the next call would start from the same position.
for i in range(3):
    batch, state = grain_iterate_data(data, batch_size=3, state=state)
    print(f"  Batch {i}: {batch}, state={state}")

In [None]:
# Drawbacks of threading state through every call by hand.
print(
    "\nProblems:",
    "  ✗ Must pass state to every function",
    "  ✗ Must remember to update state",
    "  ✗ Easy to forget state = ... assignment",
    "  ✗ State scattered across code",
    sep="\n",
)

## DATARAX APPROACH: Internal State Management (Simple)

In [None]:
class StatefulLoader(nnx.Module):
    """Datarax style: a batch loader that carries its own iteration state.

    Attributes:
        data: The sequence being iterated; must support ``len()`` and slicing.
        batch_size: Maximum number of items returned per ``get_batch`` call.
        position: ``nnx.Variable`` holding the index of the next unread item.
        samples_seen: ``nnx.Variable`` holding the running count of items
            returned by ``get_batch``.
    """

    def __init__(self, data, batch_size=3):
        """Store the data and batch size and start both counters at 0."""
        self.data = data
        self.batch_size = batch_size

        # Iteration state is kept on the module as NNX Variables, so callers
        # never have to pass it in or collect it from a return value.
        self.position = nnx.Variable(0)
        self.samples_seen = nnx.Variable(0)

    def get_batch(self):
        """Return the next batch and advance the internal counters.

        The batch is ``data[position:position + batch_size]`` clamped to the
        end of ``data``, so the last batch may be short or empty.
        """
        start = self.position.value
        stop = min(start + self.batch_size, len(self.data))
        batch = self.data[start:stop]

        # Advance only after the slice has been taken.
        self.position.value = stop
        self.samples_seen.value += stop - start

        return batch

    def reset(self):
        """Rewind ``position`` to 0; ``samples_seen`` is left unchanged."""
        self.position.value = 0

In [None]:
# Section 2 heading: the stateful (Datarax-style) walkthrough.
print("\n2. DATARAX APPROACH (Stateful):", "-" * 40, sep="\n")

In [None]:
# A loader over ten integers; its position and sample count live inside it.
loader = StatefulLoader(data=np.arange(10), batch_size=3)

In [None]:
# Process batches - no state management needed!
# get_batch() advances loader.position itself, so the position printed on each
# iteration is the one reached after that batch was taken.
for i in range(3):
    batch = loader.get_batch()  # No state passing!
    print(f"  Batch {i}: {batch}, position={loader.position.value}")

In [None]:
# Benefits claimed for keeping the state inside the module.
print(
    "\nAdvantages:",
    "  ✓ No state passing needed",
    "  ✓ State updates are automatic",
    "  ✓ Can't forget to update state",
    "  ✓ State encapsulated in module",
    "  ✓ Works with JAX transformations",
    sep="\n",
)

In [None]:
# Section 3 heading: capturing and restoring the loader's state.
print("\n3. BONUS - AUTOMATIC CHECKPOINTING:", "-" * 40, sep="\n")

In [None]:
# Snapshot both counters into a plain dict in a single expression.
checkpoint = dict(position=loader.position.value, samples=loader.samples_seen.value)
print(f"  Checkpoint: {checkpoint}")

In [None]:
# Reset and restore (clean!)
# reset() rewinds only the position to 0; samples_seen keeps its current value.
loader.reset()
print(f"  After reset: position={loader.position.value}")

In [None]:
# Restore every field captured in the checkpoint, not just the position, so
# the loader's sample count matches the snapshot as well.
loader.position.value = checkpoint["position"]
loader.samples_seen.value = checkpoint["samples"]
print(f"  After restore: position={loader.position.value}")

In [None]:
# Closing banner: a blank line, then the summary framed by two rules.
print(
    "",
    "=" * 60,
    "SUMMARY: Stateful is simpler, cleaner, and less error-prone!",
    "=" * 60,
    sep="\n",
)