In [1]:
import jax
jax.config.update("jax_platform_name", "cpu")

print("Current backend:", jax.default_backend())

Current backend: cpu


In [2]:
import jax, sys

print("Python version:", sys.version)
print("JAX version:", jax.__version__)
print("Devices:", jax.devices())

Python version: 3.12.3 (main, Aug 14 2025, 17:47:21) [GCC 13.3.0]
JAX version: 0.5.3
Devices: [CpuDevice(id=0)]


In [3]:
import jax
import jax.numpy as jnp
from flax import nnx

RMSNorm = nnx.RMSNorm

In [4]:
key = nnx.Rngs(42)
hidden_size = (5, 6)
x = jax.random.normal(jax.random.key(0), hidden_size)
norm = RMSNorm(num_features=6,rngs=key)
output = norm(x)
print(output)

[[ 1.4136842   1.7644585  -0.37775773 -0.06849329  0.15341455 -0.8469071 ]
 [-0.37608212  0.37538347  0.5044429  -0.7214626   1.6549253  -1.484553  ]
 [ 0.37910467  0.16683145  1.3502184   1.5969633   1.0262417   0.6339453 ]
 [ 0.01842601 -1.4296368  -1.387021    1.2891457   0.03520264  0.607316  ]
 [ 0.17680073  0.38078466  1.6741968   0.92929405 -1.0779979  -0.9975626 ]]


In [3]:
import jax
import jax.numpy as jnp
from flax import nnx

class SwiGLU(nnx.Module):
    def __init__(
        self,
        in_features: int,
        hidden_features: int = None,
        out_features: int = None,
        *,
        use_bias: bool = True,
        rngs: nnx.rnglib.Rngs,
    ):
        self.in_features = in_features
        self.hidden_features = hidden_features or in_features
        self.out_features = out_features or in_features
        self.w12 = nnx.Linear(self.in_features, 2*self.hidden_features, use_bias=use_bias, rngs=rngs)
        self.w3 = nnx.Linear(self.hidden_features, self.out_features, use_bias=use_bias, rngs=rngs)
    
    def __call__(self, x):
        w12 = self.w12(x)
        x1, x2 = jnp.split(w12, 2, axis=-1)
        return self.w3(nnx.silu(x1) * x2)

if __name__ == '__main__':
    x = jax.random.normal(jax.random.PRNGKey(42), (2, 16, 128))
    model = SwiGLU(128, 256, 128, rngs=nnx.Rngs(params=42))
    output = model(x)
    print(output)

[[[ 0.5698478   0.777241   -0.01974281 ... -0.6180885  -0.46066764
    0.6937558 ]
  [-0.26934257 -0.50588846  0.4076102  ... -0.9948675   0.10751119
   -0.94601625]
  [-0.31217194 -0.9516002   0.5457004  ...  0.03558588 -0.37972584
    0.41939488]
  ...
  [-0.83006567  0.40642098 -1.0011536  ...  0.6166688  -1.0053461
   -0.4291304 ]
  [-0.02426559  0.67361283 -0.258857   ...  0.26960015  0.12398798
    0.44029686]
  [-0.6357401  -0.05438824 -0.837344   ... -0.25044262 -0.0588683
    0.34996915]]

 [[-0.6730153   0.78862244  0.09795742 ... -0.8814044   0.9523628
   -0.2764015 ]
  [ 0.4803267   0.35329565 -0.43849137 ...  0.36068055 -0.98329335
   -0.09067731]
  [-0.79592603 -0.44201332 -0.21084595 ... -0.60548025  0.7878249
   -1.5968543 ]
  ...
  [ 0.27719223  0.16724351 -0.27431744 ...  0.6415455   0.33255875
    0.12926781]
  [ 0.05443227  0.44054908  0.20841222 ...  0.6051239  -0.3080075
   -1.076773  ]
  [ 0.1794864   0.4517041   0.14160526 ... -0.18415213  0.5466403
   -0.156816

In [4]:
import jax
import jax.numpy as jnp
from flax import nnx
from flax.nnx import rnglib
from flax.nnx import RMSNorm

class AdaLN(nnx.Module):
    def __init__(
        self,
        dim: int,
        y_dim: int,
        *,
        gate: bool = True,
        norm_layer: nnx.Module = RMSNorm,
        shared: bool = False,
        rngs: rnglib.Rngs,
    ):
        self.norm = norm_layer(num_features=dim, rngs=rngs)
        self.gate = gate
        if shared:
            self.adaln = None
        else:
            self.adaln = nnx.Linear(
                y_dim,
                dim * (2 + int(self.gate)),
                use_bias = True,
                bias_init = nnx.initializers.zeros_init(),
                kernel_init = nnx.initializers.zeros_init(),
                rngs = rngs,
            )

    def __call__(self, x, y, shared_adaln=None):
        if shared_adaln is None:
            out = self.adaln(y)
            scale, shift, *gate = jnp.split(out, 2 + int(self.gate), axis=-1)
        else:
            scale, shift, *gate = shared_adaln

        result = self.norm(x) * (scale + 1.0) + shift
        g = (gate[0] + 1.0) if self.gate else 1.0
        return result, g

if __name__ == '__main__':
    key = jax.random.PRNGKey(42)
    batch_size, dim, y_dim = 2, 16, 32
    x = jax.random.normal(key, (batch_size, dim))
    y = jax.random.normal(key, (batch_size, y_dim))

    model = AdaLN(dim, y_dim, gate=True, shared=False, rngs=nnx.Rngs(params=42))
    output, gate = model(x, y)
    print(f'{output.shape = }, {gate.shape = }')

    shared_adaln = nnx.Linear(
        y_dim,
        dim * 3,
        use_bias = True,
        bias_init = nnx.initializers.zeros_init(),
        kernel_init = nnx.initializers.zeros_init(),
        rngs = nnx.Rngs(params=43),
    )

    shared = shared_adaln(y)
    scale, shift, *gate = jnp.split(shared, 3, axis=-1)

    model = AdaLN(dim, y_dim, gate=True, shared=True, rngs=nnx.Rngs(params=42))
    output, gate = model(x, y, shared_adaln=(scale, shift, *gate))
    print(f'{output.shape = }, {gate.shape = }')


output.shape = (2, 16), gate.shape = (2, 16)
output.shape = (2, 16), gate.shape = (2, 16)


In [7]:
from typing import Callable

import jax
import jax.numpy as jnp
from jax import lax
from flax import nnx

from functools import partial

@jax.jit
def rotate_half(x):
    x1 = x[..., 0::2]
    x2 = x[..., 1::2]
    x = jnp.stack((-x2, x1), axis=-1)
    *shape, d, r = x.shape
    return x.reshape(*shape, d * r)

@partial(jax.jit, static_argnames=("start_index",))
def apply_rotary_emb(
    freqs: jnp.ndarray,
    t: jnp.ndarray,
    start_index: int = 0,
    scale: float = 1.0
) -> jnp.ndarray:
    rot_dim = freqs.shape[-1]
    end_index = start_index + rot_dim

    t_left  = lax.slice_in_dim(t, 0, start_index, axis=-1)
    t_mid   = lax.slice_in_dim(t, start_index, end_index, axis=-1)
    t_right = lax.slice_in_dim(t, end_index, t.shape[-1], axis=-1)

    freqs_cast = jnp.broadcast_to(freqs, t_mid.shape)
    t_rot = (t_mid * jnp.cos(freqs_cast) * scale) + (rotate_half(t_mid) * jnp.sin(freqs_cast) * scale)
    return jnp.concatenate([t_left, t_rot, t_right], axis=-1)

@partial(jax.jit, static_argnames=("num",))
def centers(start: float, stop: float, num: int, dtype=None):
    edges = jnp.linspace(start, stop, num + 1, dtype=dtype)
    return (edges[:-1] + edges[1:]) / 2

@jax.jit
def make_grid(h_pos, w_pos):
    grid = jnp.stack(jnp.meshgrid(h_pos, w_pos, indexing="ij"), axis=-1)
    return grid.reshape(-1, grid.shape[-1])

@jax.jit
def bounding_box(h: int, w: int, pixel_aspect_ratio: float = 1.0):
    w_adj = w
    h_adj = h * pixel_aspect_ratio
    ar_adj = w_adj / h_adj
    y_min, y_max, x_min, x_max = -1.0, 1.0, -1.0, 1.0
    return jnp.select(
        [ar_adj > 1, ar_adj < 1],
        [jnp.array([-1 / ar_adj, 1 / ar_adj, x_min, x_max]),
         jnp.array([y_min, y_max, -ar_adj, ar_adj])],
        jnp.array([y_min, y_max, x_min, x_max])
    ).astype(jnp.float32)
    # return jnp.array([y_min, y_max, x_min, x_max])

@partial(jax.jit, static_argnames=("h", "w", "align_corners",))
def make_axial_pos(
    h: int,
    w: int,
    pixel_aspect_ratio: float = 1.0,
    align_corners: bool = False,
    dtype = None,
):
    y_min, y_max, x_min, x_max = bounding_box(h, w, pixel_aspect_ratio)
    if align_corners:
        h_pos = jnp.linspace(y_min, y_max, h, dtype=dtype)
        w_pos = jnp.linspace(x_min, x_max, w, dtype=dtype)
    else:
        h_pos = centers(jnp.asarray(y_min), jnp.asarray(y_max), h, dtype=dtype)
        w_pos = centers(jnp.asarray(x_min), jnp.asarray(x_max), w, dtype=dtype)
    return make_grid(h_pos, w_pos)

def freqs_pixel(max_freq: float = 10.0) -> Callable[[tuple], jnp.ndarray]:
    def init(shape: tuple) -> jnp.ndarray:
        freqs = jnp.linspace(1.0, max_freq / 2.0, shape[-1]) * jnp.pi
        return jnp.log(freqs).reshape((1,) * (len(shape) - 1) + (shape[-1],)).repeat(shape[0], 0)
    return init

def freqs_pixel_log(max_freq: float = 10.0) -> Callable[[tuple], jnp.ndarray]:
    def init(shape: tuple) -> jnp.ndarray:
        log_min = jnp.log(jnp.pi)
        log_max = jnp.log(max_freq * jnp.pi / 2.0)
        return jnp.linspace(log_min, log_max, shape[-1]).reshape((1,) * (len(shape) - 1) + (shape[-1],)).repeat(shape[0], 0)
    return init

class AxialRoPE(nnx.Module):
    def __init__(
        self,
        dim: int,
        n_heads: int,
        pos_dim: int = 2,
        start_index: int = 0,
        freqs_init: Callable[[tuple], jnp.ndarray] = freqs_pixel_log(max_freq=10.0),
    ):
        self.dim = dim
        self.n_heads = n_heads
        self.pos_dim = pos_dim
        self.start_index = start_index
        
        per_axis = self.dim // (2 * self.pos_dim)
        log_freqs = freqs_init((self.n_heads, per_axis, 1))
        self.log_freqs = nnx.Param(log_freqs)

    def __call__(self, x: jnp.ndarray, pos: jnp.ndarray):
        if pos.shape[-1] != self.pos_dim:
            raise ValueError(f"Expected pos_dim={self.pos_dim}, got {pos.shape[-1]}")

        freqs_axes = jnp.exp(self.log_freqs)
        freqs_axes = freqs_axes.repeat(self.pos_dim, axis=-1)
        freqs = pos[..., None, None, :] * freqs_axes
        freqs = freqs.reshape(*freqs.shape[:-2], -1)
        freqs = jnp.repeat(freqs, 2, axis=-1)
        return apply_rotary_emb(freqs, x, self.start_index)

if __name__ == '__main__':
    key = jax.random.PRNGKey(42)
    batch_size, seq_len, dim = 2, 16, 64
    n_heads = 8
    pos_dim = 2

    x = jax.random.normal(key, (batch_size, seq_len, n_heads, dim // n_heads))

    h, w = 4, 4
    pos = make_axial_pos(h, w)

    model = AxialRoPE(dim=dim // n_heads, n_heads=n_heads, pos_dim=pos_dim, start_index=0)
    output = model(x, pos)
    print(f'{output.shape = }')

output.shape = (2, 16, 8, 8)


In [21]:
from typing import Optional, Type

import jax
import jax.numpy as jnp
from jax import lax
from flax import nnx

class Identity(nnx.Module):
    def __call__(self, x):
        return x

class PatchEmbed(nnx.Module):
    def __init__(
        self,
        patch_size: int = 4,
        in_channels: int = 3,
        embed_dim: int = 512,
        *,
        norm_layer: Optional[Type[nnx.Module]] = None,
        flatten: bool = True,
        use_bias: bool = True,
        rngs: nnx.rnglib.Rngs,
    ):
        self.patch_size = patch_size
        self.in_channels = in_channels
        self.embed_dim = embed_dim
        self.flatten = flatten
        self.proj = nnx.Conv(
            in_features=in_channels,
            out_features=embed_dim,
            kernel_size=(patch_size, patch_size),
            strides=(patch_size, patch_size),
            padding='VALID',
            use_bias=use_bias,
            rngs=rngs,
        )
        self.norm = Identity() if norm_layer is None else norm_layer(embed_dim, rngs=rngs)

    def __call__(self, x: jnp.ndarray, pos_map: jnp.ndarray | None = None):
        b, _, h, w = x.shape  # NCHW
        x = jnp.transpose(x, (0, 2, 3, 1))  # NHWC
        x = self.proj(x)                    # (B, H/P, W/P, D)
        new_h, new_w = x.shape[1:3]

        if pos_map is not None:
            pos_map = jax.image.resize(
                pos_map, (b, new_h, new_w, pos_map.shape[-1]), method="bilinear"
            )
            pos_map = pos_map.reshape(b, new_h * new_w, -1)

        if self.flatten:
            x = x.reshape(b, new_h * new_w, -1)
            x = self.norm(x)
        else:
            x = self.norm(x)
            x = jnp.transpose(x, (0, 3, 1, 2))  # NCHW

        return x, pos_map


class UnPatch(nnx.Module):
    def __init__(
        self,
        patch_size: int = 4,
        input_dim: int = 512,
        out_channels: int = 3,
        *,
        proj: bool = True,
        rngs: nnx.rnglib.Rngs,
    ):
        self.patch_size = patch_size
        self.input_dim = input_dim
        self.out_channels = out_channels
        self.linear = (
            Identity()
            if not proj
            else nnx.Linear(
                self.input_dim,
                self.patch_size**2 * self.out_channels,
                rngs=rngs
            )
        )

    def __call__(
        self,
        x: jnp.ndarray,
        axis1: int | None = None,
        axis2: int | None = None,
        loss_mask: jnp.ndarray | None = None,
    ):
        b, n, _ = x.shape
        p = q = self.patch_size

        if axis1 is None and axis2 is None:
            w = h = int(jnp.sqrt(n).item())
            assert h * w == n
        else:
            h = axis1 // p if axis1 else n // (axis2 // p)
            w = axis2 // p if axis2 else n // h
            assert h * w == n

        x = self.linear(x)  # (B, N, p^2 * C)

        if loss_mask is not None:
            loss_mask = loss_mask[..., None]
            x = jnp.where(loss_mask, x, lax.stop_gradient(x))

        x = x.reshape(b, h, w, p, q, self.out_channels)
        x = x.transpose(0, 5, 1, 3, 2, 4)
        x = x.reshape(b, self.out_channels, h * p, w * q)
        return x

In [26]:
import jax
import jax.numpy as jnp
from flax import nnx

class TimestepEmbedding(nnx.Module):
    def __init__(
        self,
        dim: int,
        max_period: float = 10000.0,
        time_factor: float = 1000.0,
        *,
        rngs: nnx.rnglib.Rngs,
    ):
        self.dim = dim
        self.max_period = max_period
        self.time_factor = time_factor
        self.freqs = jnp.exp(-jnp.log(self.max_period) * jnp.arange(0, self.dim // 2, dtype=jnp.float32) / (self.dim // 2))[None, :]
        self.proj = nnx.Linear(self.dim, self.dim, rngs=rngs)

    def __call__(self, t: jnp.ndarray) -> jnp.ndarray:
        if t.ndim > 1:
            t = jnp.squeeze(t, -1)
        t = self.time_factor * t
        args = t[:, None] * self.freqs  # (B, dim//2)
        embedding = jnp.concatenate([jnp.cos(args), jnp.sin(args)], axis=-1)

        # padding
        if self.dim % 2 == 1:
            embedding = jnp.concatenate(
                [embedding, jnp.zeros((embedding.shape[0], 1), dtype=embedding.dtype)], axis=-1
            )

        x = self.proj(embedding)
        x = jax.nn.mish(x)
        return x

In [59]:
from typing import Optional, Union, Callable
import jax
import jax.numpy as jnp
from flax import nnx

#from .axial_rope import AxialRoPE

MaskType = Union[jnp.ndarray, None]

def _split_heads(x, n_heads):
    B, N, D = x.shape
    H = n_heads
    return x.reshape(B, N, H, D // H)

def _merge_heads(x):
    B, N, H, D = x.shape
    return x.reshape(B, N, H * D)

def _prepare_mask_bias(mask: MaskType, B: int, H: int, q_len: int, kv_len: int):
    if mask is None:
        return {'mask': None, 'bias': None}
    arr = jnp.asarray(mask)

    def _broadcast(arr):
        # Supported input shapes:
        #  - (q_len, kv_len): e.g. causal mask shared across batch
        #  - (B, kv_len): padding mask
        #  - (B, q_len, kv_len): per-example mask
        #  - (B, H, q_len, kv_len): per-head mask/bias
        if arr.ndim == 2:
            if arr.shape == (B, kv_len):
                arr = arr[:, None, None, :]
            else:
                arr = arr[None, None, :, :]
        elif arr.ndim == 3:
            arr = arr[:, None, :, :]
        elif arr.ndim != 4:
            raise ValueError(f"Unsupported mask/bias shape {arr.shape}")
        return jnp.broadcast_to(arr, (B, H, q_len, kv_len))

    if arr.dtype == jnp.bool_:
        return {'mask': _broadcast(arr), 'bias': None}
    return {'mask': None, 'bias': _broadcast(arr)}

class SelfAttention(nnx.Module):
    def __init__(
        self,
        dim: int,
        n_heads: int = 8,
        head_dim: int = -1,
        pos_dim: int = 2,
        *,
        rope_factory: Optional[Callable[[int, int, int], nnx.Module]] = None, # AxialRoPE alternative
        rngs: nnx.rnglib.Rngs,
    ):
        self.head_dim = head_dim if head_dim > 0 else dim // n_heads
        self.pos_dim = pos_dim
        self.n_heads = dim // self.head_dim
        self.qkv = nnx.Linear(dim, dim * 3, use_bias=False, rngs=rngs)
        self.out = nnx.Linear(dim, dim, rngs=rngs)

        factory = rope_factory or (lambda d, h, p: AxialRoPE(d, h, p))
        self.rope = factory(self.head_dim, self.n_heads, self.pos_dim)

    def __call__(self, x, pos_map=None, mask=None, deterministic=True):
        b, n, _ = x.shape
        qkv = self.qkv(x)
        q, k, v = jnp.split(qkv, 3, axis=-1)

        q = _split_heads(q, self.n_heads)
        k = _split_heads(k, self.n_heads)
        v = _split_heads(v, self.n_heads)

        if pos_map is not None:
            # pos_map expected shape: (b, n, pos_dim)
            q = self.rope(q, pos_map)
            k = self.rope(k, pos_map)

        mask_bias = _prepare_mask_bias(mask, b, self.n_heads, n, n)
        attn = nnx.dot_product_attention(
            q, k, v,
            deterministic=deterministic,
            **mask_bias,
        )
        attn = _merge_heads(attn)
        return self.out(attn)

class CrossAttention(nnx.Module):
    def __init__(
        self,
        dim: int,
        ctx_dim: int,
        n_heads: int = 8,
        head_dim: int = -1,
        pos_dim: int = 2,
        *,
        rope_factory: Optional[Callable[[int, int, int], nnx.Module]] = None, # AxialRoPE alternative
        rngs: nnx.rnglib.Rngs,
    ):
        self.head_dim = head_dim if head_dim > 0 else dim // n_heads
        self.pos_dim = pos_dim
        self.n_heads = dim // self.head_dim
        self.q = nnx.Linear(dim, dim, use_bias=False, rngs=rngs)
        self.kv = nnx.Linear(ctx_dim, dim * 2, use_bias=False, rngs=rngs)
        self.out = nnx.Linear(dim, dim, rngs=rngs)

        factory = rope_factory or (lambda d, h, p: AxialRoPE(d, h, p))
        self.rope = factory(self.head_dim, self.n_heads, self.pos_dim)

    def __call__(self, x, ctx, pos_map=None, ctx_pos_map=None, mask=None, deterministic=True):
        b, n, _ = x.shape
        ctx_n = ctx.shape[1]

        q = _split_heads(self.q(x), self.n_heads)
        k, v = jnp.split(self.kv(ctx), 2, axis=-1)
        k = _split_heads(k, self.n_heads)
        v = _split_heads(v, self.n_heads)

        if pos_map is not None:
            q = self.rope(q, pos_map)
        if ctx_pos_map is not None:
            k = self.rope(k, ctx_pos_map)

        mask_bias = _prepare_mask_bias(mask, b, self.n_heads, n, ctx_n)

        attn = nnx.dot_product_attention(
            q, k, v,
            deterministic=deterministic,
            **mask_bias,
        )
        attn = _merge_heads(attn)
        return self.out(attn)

if __name__ == '__main__':
    key = jax.random.PRNGKey(42)
    key, pkey_sa, pkey_ca, xkey, ckey = jax.random.split(key, 5)

    B, N, D = 2, 16, 64
    x = jax.random.normal(xkey, (B, N, D))
    ctx = jax.random.normal(ckey, (B, N, D))

    sa = SelfAttention(dim=D, n_heads=8, rngs=nnx.Rngs(pkey_sa))
    sa_out = sa(x)
    print(f'{sa_out.shape = }')

    ca = CrossAttention(dim=D, ctx_dim=D, n_heads=8, rngs=nnx.Rngs(pkey_ca))
    ca_out = ca(x, ctx)
    print(f'{ca_out.shape = }')


sa_out.shape = (2, 16, 64)
ca_out.shape = (2, 16, 64)


In [62]:
from typing import Optional, Type

import jax
import jax.numpy as jnp
from flax import nnx

# from .layers import SwiGLU
# from .attention import SelfAttention, CrossAttention
# from .adaln import AdaLN

class TransformerBlock(nnx.Module):
    def __init__(
        self,
        dim: int,
        ctx_dim: Optional[int],
        heads: int,
        dim_head: int,
        mlp_dim: int,
        pos_dim: int,
        *,
        use_adaln: bool = False,
        use_shared_adaln: bool = False,
        ctx_from_self: bool = False,
        norm_layer: Type[nnx.Module] = nnx.RMSNorm,
        rngs: nnx.rnglib.Rngs,
    ):
        self.use_adaln = use_adaln
        self.ctx_from_self = ctx_from_self
        self.sa = SelfAttention(dim, heads, dim_head, pos_dim, rngs=rngs)
        self.xa = CrossAttention(dim, ctx_dim, heads, dim_head, pos_dim, rngs=rngs) if ctx_dim is not None else None
        self.mlp = SwiGLU(dim, mlp_dim, dim, rngs=rngs)

        if self.use_adaln:
            self.sa_pre = AdaLN(dim, dim, norm_layer=norm_layer, shared=use_shared_adaln, rngs=rngs)
            self.xa_pre = AdaLN(dim, dim, norm_layer=norm_layer, shared=use_shared_adaln, rngs=rngs) if self.xa else None
            self.mlp_pre = AdaLN(dim, dim, norm_layer=norm_layer, shared=use_shared_adaln, rngs=rngs)
        else:
            self.sa_pre = norm_layer(dim, rngs=rngs)
            self.xa_pre = norm_layer(dim, rngs=rngs) if self.xa else None
            self.mlp_pre = norm_layer(dim, rngs=rngs)

    def __call__(
        self,
        x,
        ctx=None,
        pos_map=None,
        ctx_pos_map=None,
        y=None,
        x_mask=None,
        ctx_mask=None,
        shared_adaln=None,
        deterministic: bool = True,
    ):
        # Self-Attention
        if self.use_adaln:
            sa_shared = shared_adaln[0] if shared_adaln is not None else None
            normed_x, gate = self.sa_pre(x, y, shared_adaln=sa_shared)
        else:
            normed_x = self.sa_pre(x)
            gate = jnp.ones_like(normed_x)
        if gate.ndim < x.ndim:
            gate = jnp.expand_dims(gate, axis=tuple(range(1, x.ndim - gate.ndim + 1)))
        x = x + gate * self.sa(
            normed_x,
            pos_map=pos_map,
            mask=x_mask,
            deterministic=deterministic
        )

        # Cross-Attention
        if self.xa is not None:
            if self.use_adaln:
                xa_shared = shared_adaln[1] if shared_adaln is not None else None
                normed_x, gate = self.xa_pre(x, y, shared_adaln=xa_shared)
            else:
                normed_x = self.xa_pre(x)
                gate = jnp.ones_like(normed_x)
            if gate.ndim < x.ndim:
                gate = jnp.expand_dims(gate, axis=tuple(range(1, x.ndim - gate.ndim + 1)))
            xa_ctx = x if self.ctx_from_self else ctx
            xa_mask = x_mask if self.ctx_from_self else ctx_mask
            x = x + gate * self.xa(
                normed_x, xa_ctx,
                pos_map=pos_map,
                ctx_pos_map=ctx_pos_map,
                mask=xa_mask,
                deterministic=deterministic
            )

        # MLP
        if self.use_adaln:
            mlp_shared = shared_adaln[2] if shared_adaln is not None else None
            normed_x, gate = self.mlp_pre(x, y, shared_adaln=mlp_shared)
        else:
            normed_x = self.mlp_pre(x)
            gate = jnp.ones_like(normed_x)
        if gate.ndim < x.ndim:
            gate = jnp.expand_dims(gate, axis=tuple(range(1, x.ndim - gate.ndim + 1)))
        x = x + gate * self.mlp(normed_x)
        return x

if __name__ == '__main__':
    key = jax.random.PRNGKey(0)
    batch_size = 2
    seq_len = 16
    dim = 64
    y_dim = 64
    mlp_dim = 128
    heads = 8
    head_dim = 8
    pos_dim = 2

    x = jax.random.normal(key, (batch_size, seq_len, dim))
    ctx = jax.random.normal(key, (batch_size, seq_len, y_dim))

    print("Self-Attention Block")
    sa_block = TransformerBlock(
        dim=dim,
        ctx_dim=None,  # w/o ctx
        heads=heads,
        dim_head=head_dim,
        mlp_dim=mlp_dim,
        pos_dim=pos_dim,
        rngs=nnx.Rngs(key),
    )
    sa_out = sa_block(x, deterministic=True)
    print(x.shape, "->", sa_out.shape)
    assert sa_out.shape == x.shape

    print("Cross-Attention Block")
    xa_block = TransformerBlock(
        dim=dim,
        ctx_dim=y_dim,  # w/ ctx
        heads=heads,
        dim_head=head_dim,
        mlp_dim=mlp_dim,
        pos_dim=pos_dim,
        rngs=nnx.Rngs(key),
    )
    xa_out = xa_block(x, ctx=ctx, deterministic=True)
    print(x.shape, "->", xa_out.shape)
    assert xa_out.shape == x.shape

    print("AdaLN Block")
    adaln_block = TransformerBlock(
        dim=dim,
        ctx_dim=y_dim,
        heads=heads,
        dim_head=head_dim,
        mlp_dim=mlp_dim,
        pos_dim=pos_dim,
        use_adaln=True,
        use_shared_adaln=False,
        rngs=nnx.Rngs(key),
    )
    adaln_out = adaln_block(x, ctx=ctx, y=x, deterministic=True)
    print(x.shape, "->", adaln_out.shape)
    assert adaln_out.shape == x.shape

    print("Shared AdaLN Block")
    shared_adaln_block = TransformerBlock(
        dim=dim,
        ctx_dim=y_dim,
        heads=heads,
        dim_head=head_dim,
        mlp_dim=mlp_dim,
        pos_dim=pos_dim,
        use_adaln=True,
        use_shared_adaln=True,
        rngs=nnx.Rngs(key),
    )
    shared_adaln = (
        (jnp.zeros_like(x), jnp.zeros_like(x), jnp.zeros_like(x)),  # sa
        (jnp.zeros_like(x), jnp.zeros_like(x), jnp.zeros_like(x)),  # xa
        (jnp.zeros_like(x), jnp.zeros_like(x), jnp.zeros_like(x)),  # mlp
    )
    out_shared = shared_adaln_block(
        x,
        ctx=ctx,
        y=x,
        shared_adaln=shared_adaln,
        deterministic=True,
    )
    print(x.shape, "->", out_shared.shape)
    assert out_shared.shape == x.shape

Self-Attention Block
(2, 16, 64) -> (2, 16, 64)
Cross-Attention Block
(2, 16, 64) -> (2, 16, 64)
AdaLN Block
(2, 16, 64) -> (2, 16, 64)
Shared AdaLN Block
(2, 16, 64) -> (2, 16, 64)


In [None]:
from typing import Optional, Sequence, Tuple, List

import jax
import jax.numpy as jnp
from flax import nnx
from flax.nnx import RMSNorm

# from .modules.transformer import TransformerBlock
# from .modules.patch import PatchEmbed, UnPatch
# from .modules.axial_rope import make_axial_pos
# from .modules.time_emb import TimestepEmbedding

def _isiterable(x) -> bool:
    try:
        iter(x)
    except TypeError:
        return False
    return True

class _SharedAdaLNHead(nnx.Module):
    def __init__(self, dim: int, *, rngs: nnx.rnglib.Rngs):
        self.dim = dim
        self.norm = nnx.LayerNorm(num_features=dim, rngs=rngs)
        self.fc1 = nnx.Linear(dim, dim * 4, rngs=rngs)
        self.fc2 = nnx.Linear(
            dim * 4,
            dim * 3,
            kernel_init=nnx.initializers.zeros_init(),
            bias_init=nnx.initializers.zeros_init(),
            rngs=rngs,
        )

    def __call__(self, x: jnp.ndarray) -> jnp.ndarray:
        y = self.norm(x)
        y = self.fc1(y)
        y = jax.nn.mish(y)
        y = self.fc2(y)
        return y  # (B, dim*3)

def _chunk3(v: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]:
    d3 = v.shape[-1]
    assert d3 % 3 == 0
    d = d3 // 3
    return v[..., :d], v[..., d:2*d], v[..., 2*d:]

class TBackBone(nnx.Module):
    def __init__(
        self,
        dim: int = 1024,
        ctx_dim: Optional[int] = 1024,
        heads: int = 16,
        dim_head: int = 64,
        mlp_dim: int = 3072,
        pos_dim: int = 2,
        depth: int = 8,
        use_adaln: bool = False,
        use_shared_adaln: bool = False,
        use_dyt: bool = False,
        grad_ckpt: bool = False,  # remat (not used)
        *,
        rngs: nnx.rnglib.Rngs,
    ):
        norm_layer = RMSNorm
        self.use_adaln = use_adaln
        self.use_shared_adaln = use_shared_adaln
        self.blocks = tuple([
            TransformerBlock(
                dim=dim,
                ctx_dim=ctx_dim,
                heads=heads,
                dim_head=dim_head,
                mlp_dim=mlp_dim,
                pos_dim=pos_dim,
                use_adaln=use_adaln,
                use_shared_adaln=use_shared_adaln,
                norm_layer=norm_layer,
                rngs=rngs,
            )
            for _ in range(depth)
        ])

    def __call__(
        self,
        x: jnp.ndarray,
        ctx: Optional[jnp.ndarray] = None,
        x_mask: Optional[jnp.ndarray] = None,
        ctx_mask: Optional[jnp.ndarray] = None,
        pos_map: Optional[jnp.ndarray] = None,
        y: Optional[jnp.ndarray] = None,
        shared_adaln: Optional[Sequence[Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]]] = None,
        deterministic: bool = True,
    ) -> jnp.ndarray:
        for block in self.blocks:
            x = block(
                x,
                ctx=ctx,
                pos_map=pos_map,
                ctx_pos_map=None,
                y=y,
                x_mask=x_mask,
                ctx_mask=ctx_mask,
                shared_adaln=shared_adaln,
                deterministic=deterministic
            )
        return x

class XUTBackBone(nnx.Module):
    def __init__(
        self,
        dim: int = 1024,
        ctx_dim: Optional[int] = None,
        heads: int = 16,
        dim_head: int = 64,
        mlp_dim: int = 3072,
        pos_dim: int = 2,
        depth: int = 8,
        enc_blocks: int | Sequence[int] = 1,
        dec_blocks: int | Sequence[int] = 2,
        dec_ctx: bool = False,
        use_adaln: bool = False,
        use_shared_adaln: bool = False,
        use_dyt: bool = False,
        grad_ckpt: bool = False,
        *,
        rngs: nnx.rnglib.Rngs,
    ):
        norm_layer = RMSNorm
        self.dim = dim
        self.ctx_dim = ctx_dim
        self.heads = heads
        self.dim_head = dim_head
        self.mlp_dim = mlp_dim
        self.pos_dim = pos_dim
        self.depth = depth
        self.dec_ctx = dec_ctx
        self.use_adaln = use_adaln
        self.use_shared_adaln = use_shared_adaln

        if _isiterable(enc_blocks):
            enc_list = list(enc_blocks)
            assert len(enc_list) == depth
        else:
            enc_list = [int(enc_blocks)] * depth

        if _isiterable(dec_blocks):
            dec_list = list(dec_blocks)
            assert len(dec_list) == depth
        else:
            dec_list = [int(dec_blocks)] * depth

        def mk_block(ctx_dim_inner, ctx_from_self: bool = False):
            return TransformerBlock(
                dim=dim,
                ctx_dim=ctx_dim_inner,
                heads=heads,
                dim_head=dim_head,
                mlp_dim=mlp_dim,
                pos_dim=pos_dim,
                use_adaln=use_adaln,
                use_shared_adaln=use_shared_adaln,
                ctx_from_self=ctx_from_self,
                norm_layer=norm_layer,
                rngs=rngs,
            )

        self.enc_stacks = tuple([
            tuple(mk_block(ctx_dim) for _ in range(enc_list[i]))
            for i in range(depth)
        ])

        self.dec_stacks = tuple([
            tuple(
                mk_block(dim, ctx_from_self=True) if bid == 0
                else mk_block(ctx_dim if dec_ctx else None)
                for bid in range(dec_list[i])
            )
            for i in range(depth)
        ])

    def __call__(
        self,
        x: jnp.ndarray,
        ctx: Optional[jnp.ndarray] = None,
        x_mask: Optional[jnp.ndarray] = None,
        ctx_mask: Optional[jnp.ndarray] = None,
        pos_map: Optional[jnp.ndarray] = None,
        y: Optional[jnp.ndarray] = None,
        shared_adaln: Optional[Sequence[Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]]] = None,
        return_enc_out: bool = False,
        deterministic: bool = True,
    ) -> jnp.ndarray | Tuple[jnp.ndarray, jnp.ndarray]:
        self_ctx: List[jnp.ndarray] = []
        for blocks in self.enc_stacks:
            for block in blocks:
                x = block(x,
                        ctx=ctx,
                        pos_map=pos_map,
                        ctx_pos_map=None,
                        y=y,
                        x_mask=x_mask,
                        ctx_mask=ctx_mask,
                        shared_adaln=shared_adaln,
                        deterministic=deterministic)
            self_ctx.append(x)
        enc_out = x

        for blocks in self.dec_stacks:
            # first block cross attends to last encoder output
            first = blocks[0]
            x = first(x,
                      ctx=self_ctx[-1],
                      pos_map=pos_map,
                      ctx_pos_map=pos_map,
                      y=y,
                      x_mask=x_mask,
                      ctx_mask=ctx_mask,
                      shared_adaln=shared_adaln,
                      deterministic=deterministic)
            for block in blocks[1:]:
                x = block(x,
                        ctx=ctx if self.dec_ctx else None,
                        pos_map=pos_map,
                        ctx_pos_map=None,
                        y=y,
                        x_mask=x_mask,
                        ctx_mask=ctx_mask,
                        shared_adaln=shared_adaln,
                        deterministic=deterministic)

        if return_enc_out:
            return x, enc_out
        return x

class XUDiT(nnx.Module):
    def __init__(
        self,
        patch_size: int = 2,
        input_dim: int = 4,
        dim: int = 1024,
        ctx_dim: Optional[int] = 1024,
        ctx_size: int = 256,
        heads: int = 16,
        dim_head: int = 64,
        mlp_dim: int = 3072,
        depth: int = 8,
        enc_blocks: int | Sequence[int] = 1,
        dec_blocks: int | Sequence[int] = 2,
        dec_ctx: bool = False,
        class_cond: int = 0,
        shared_adaln: bool = True,
        concat_ctx: bool = True,
        use_dyt: bool = False,
        double_t: bool = False,
        addon_info_embs_dim: Optional[int] = None,
        tread_config: Optional[dict] = None,
        grad_ckpt: bool = False,
        *,
        rngs: nnx.rnglib.Rngs,
    ):
        self.rngs = rngs
        self.patch_size = patch_size
        self.input_dim = input_dim
        self.dim = dim
        self.ctx_dim = ctx_dim
        self.ctx_size = ctx_size
        self.heads = heads
        self.dim_head = dim_head
        self.mlp_dim = mlp_dim
        self.depth = depth
        self.enc_blocks = enc_blocks
        self.dec_blocks = dec_blocks
        self.dec_ctx = dec_ctx
        self.class_cond = class_cond
        self.shared_adaln = shared_adaln
        self.concat_ctx = concat_ctx
        self.use_dyt = use_dyt
        self.double_t = double_t
        self.addon_info_embs_dim = addon_info_embs_dim
        self.tread_config = tread_config
        self.grad_ckpt = grad_ckpt

        self.backbone = XUTBackBone(
            dim=self.dim,
            ctx_dim=(None if self.concat_ctx else self.ctx_dim),
            heads=self.heads,
            dim_head=self.dim_head,
            mlp_dim=self.mlp_dim,
            pos_dim=2,
            depth=self.depth,
            enc_blocks=self.enc_blocks,
            dec_blocks=self.dec_blocks,
            dec_ctx=self.dec_ctx,
            use_adaln=True,
            use_shared_adaln=self.shared_adaln,
            use_dyt=self.use_dyt,
            grad_ckpt=self.grad_ckpt,
            rngs=rngs,
        )

        self.use_tread = self.tread_config is not None
        if self.use_tread:
            dr = float(self.tread_config["dropout_ratio"])
            prev_d = int(self.tread_config["prev_trns_depth"])
            post_d = int(self.tread_config["post_trns_depth"])
            self.dropout_ratio = dr
            self.prev_tread_trns = TBackBone(
                dim=self.dim,
                ctx_dim=(None if self.concat_ctx else self.ctx_dim),
                heads=self.heads,
                dim_head=self.dim_head,
                mlp_dim=self.mlp_dim,
                pos_dim=2,
                depth=prev_d,
                use_adaln=True,
                use_shared_adaln=self.shared_adaln,
                use_dyt=self.use_dyt,
                grad_ckpt=self.grad_ckpt,
                rngs=rngs,
            )
            self.post_tread_trns = TBackBone(
                dim=self.dim,
                ctx_dim=(None if self.concat_ctx else self.ctx_dim),
                heads=self.heads,
                dim_head=self.dim_head,
                mlp_dim=self.mlp_dim,
                pos_dim=2,
                depth=post_d,
                use_adaln=True,
                use_shared_adaln=self.shared_adaln,
                use_dyt=self.use_dyt,
                grad_ckpt=self.grad_ckpt,
                rngs=rngs,
            )

        self.patch_size_ = self.patch_size
        self.in_patch = PatchEmbed(
            patch_size=self.patch_size,
            in_channels=self.input_dim,
            embed_dim=self.dim,
            rngs=rngs,
        )
        self.out_patch = UnPatch(
            patch_size=self.patch_size,
            input_dim=self.dim,
            out_channels=self.input_dim,
            rngs=rngs,
        )

        self.time_emb = TimestepEmbedding(self.dim, rngs=rngs)
        if self.double_t:
            self.r_emb = TimestepEmbedding(self.dim, rngs=rngs)

        if self.shared_adaln:
            self.shared_adaln_attn = _SharedAdaLNHead(self.dim, rngs=rngs)
            self.shared_adaln_xattn = _SharedAdaLNHead(self.dim, rngs=rngs)
            self.shared_adaln_ffw = _SharedAdaLNHead(self.dim, rngs=rngs)

        if self.class_cond > 0:
            self.class_token = nnx.Embed(num_embeddings=self.class_cond, features=self.dim, rngs=rngs)
        else:
            self.class_token = None

        if self.concat_ctx and (self.ctx_dim is not None):
            self.ctx_proj = nnx.Linear(self.ctx_dim, self.dim, rngs=rngs)
        else:
            self.ctx_proj = None

        if self.addon_info_embs_dim is not None:
            self.addon_info_embs_proj_1 = nnx.Linear(self.addon_info_embs_dim, self.dim, rngs=rngs)
            self.addon_info_embs_proj_2 = nnx.Linear(
                self.dim,
                self.dim,
                kernel_init=nnx.initializers.zeros_init(),
                bias_init=nnx.initializers.zeros_init(),
                rngs=rngs,
            )

    def _make_shared_adaln_state(self, t_emb: jnp.ndarray):
        attn = self.shared_adaln_attn(t_emb)
        xattn = self.shared_adaln_xattn(t_emb)
        ffw = self.shared_adaln_ffw(t_emb)
        return [_chunk3(attn), _chunk3(xattn), _chunk3(ffw)]

    def __call__(self,
        x: jnp.ndarray,                           # NHWC (B, H, W, C)
        t: jnp.ndarray,                           # (B,) | (B, 1)
        ctx: Optional[jnp.ndarray] = None,        # (B, T_ctx, ctx_dim) | None | (B,) class ids if class_cond>0
        pos_map: Optional[jnp.ndarray] = None,    # (B, N, 2)
        r: Optional[jnp.ndarray] = None,
        addon_info: Optional[jnp.ndarray] = None, # (B, D_addon) or (B,)
        tread_rate: Optional[float] = None,
        return_enc_out: bool = False,
        deterministic: bool = True,
    ):
        """
        -> (B, C, H, W)
        """
        B, H, W, C = x.shape
        x_seq, pos_map_resized = self.in_patch(x, pos_map)  # x_seq: (B, N, D), pos_map_resized: (B, N, 2) or None
        N = x_seq.shape[1]
        if pos_map_resized is None:
            pos_map_resized = make_axial_pos(H // self.patch_size_,
                                             W // self.patch_size_)
            pos_map_resized = jnp.broadcast_to(pos_map_resized[None, ...],
                                               (B, pos_map_resized.shape[0], pos_map_resized.shape[1]))

        t_emb = self.time_emb(t)
        if r is not None:
            r = jnp.reshape(r, (B, -1))
            t_emb = t_emb + self.r_emb((t.reshape(B, -1) - r))

        if (self.class_token is not None) and (ctx is not None):
            if ctx.ndim == 1:
                cls_ids = ctx
            else:
                cls_ids = ctx.reshape((ctx.shape[0], -1))[:, 0]
            t_emb = t_emb + self.class_token(cls_ids)
            ctx = None

        if addon_info is not None:
            if addon_info.ndim == 1:
                addon_info = addon_info[:, None]
            addon_embs = jax.nn.mish(self.addon_info_embs_proj_1(addon_info))
            addon_embs = self.addon_info_embs_proj_2(addon_embs)
            addon_embs = addon_embs[:, None, :]  # (B,1,D)
            t_emb = t_emb + addon_embs

        need_ctx = (self.ctx_dim is not None)
        if (ctx is None) and need_ctx and (not self.concat_ctx):
            ctx = jnp.zeros((B, self.ctx_size, self.ctx_dim), dtype=x_seq.dtype)

        shared_adaln_state = None
        if self.shared_adaln:
            shared_adaln_state = self._make_shared_adaln_state(t_emb)

        length = x_seq.shape[1]

        if self.ctx_proj is not None and ctx is not None:
            ctx_proj = self.ctx_proj(ctx)
            x_seq = jnp.concatenate([x_seq, ctx_proj], axis=1)
            pad_pos = jnp.zeros((B, ctx_proj.shape[1], pos_map_resized.shape[-1]), dtype=pos_map_resized.dtype)
            pos_map_resized = jnp.concatenate([pos_map_resized, pad_pos], axis=1)
            ctx = None

        # TREAD (pre)
        if self.use_tread:
            x_seq = self.prev_tread_trns(
                x_seq, ctx=ctx, pos_map=pos_map_resized, y=t_emb,
                shared_adaln=shared_adaln_state, deterministic=deterministic
            )
            do_tread = (not deterministic) or (tread_rate is not None)
            if do_tread:
                rate = (tread_rate if tread_rate is not None else self.dropout_ratio)
                keep_len = length - int(length * rate)
                key = self.rngs.dropout()
                perm = jax.vmap(lambda k: jax.random.permutation(k, length))(jax.random.split(key, B))
                sel_mask = perm < keep_len
                if self.ctx_proj is not None:
                    ctx_len = x_seq.shape[1] - length
                    ctx_keep = jnp.ones((B, ctx_len), dtype=sel_mask.dtype)
                    sel_mask = jnp.concatenate([sel_mask, ctx_keep], axis=1)
                    keep_len = keep_len + ctx_len

                def _gather_sel(x_row, m_row):
                    idx = jnp.nonzero(m_row, size=m_row.shape[0], fill_value=0)[0]
                    idx = idx[:keep_len]
                    return x_row[idx]
                x_not_sel = jnp.where(sel_mask[..., None], jnp.zeros_like(x_seq), x_seq)
                x_sel = jax.vmap(_gather_sel)(x_seq, sel_mask)
                x_seq = x_sel
                raw_pos = pos_map_resized
                pos_map_resized = jax.vmap(_gather_sel)(pos_map_resized, sel_mask)

        # Backbone
        out = self.backbone(
            x_seq, ctx=ctx, pos_map=pos_map_resized, y=t_emb,
            shared_adaln=shared_adaln_state, return_enc_out=return_enc_out,
            deterministic=deterministic
        )
        if return_enc_out:
            out, enc_out = out

        # TREAD (post)
        if self.use_tread:
            if (not deterministic) or (tread_rate is not None):
                full_len = (raw_pos.shape[1])
                def _scatter_back(x_row, sel_mask_row):
                    idx = jnp.nonzero(sel_mask_row, size=sel_mask_row.shape[0], fill_value=0)[0]
                    idx = idx[:x_row.shape[0]]
                    base = jnp.zeros((full_len, x_row.shape[-1]), dtype=x_row.dtype)
                    return base.at[idx].set(x_row)
                out_full = jax.vmap(_scatter_back)(out, sel_mask)
                out = out_full
                pos_map_resized = raw_pos

            out = self.post_tread_trns(
                out, ctx=ctx, pos_map=pos_map_resized, y=t_emb,
                shared_adaln=shared_adaln_state, deterministic=deterministic
            )

        out = out[:, :length]
        img = self.out_patch(out, axis1=H, axis2=W)
        return img if not return_enc_out else (img, enc_out[:, :length])

In [22]:
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(42)

B = 2          # batch size
H = 32         # height
W = 32         # width
C = 4          # input channels (예: latent image or feature map)
D = 64         # transformer dim 줄여서 테스트
ctx_dim = 32   # context embedding dim
ctx_size = 8   # context length (for conditioning)

x = jax.random.normal(key, (B, H, W, C))
t = jnp.linspace(0, 1, B)  # timestep embedding input (B,)
ctx = jax.random.normal(key, (B, ctx_size, ctx_dim))

model = XUDiT(
    patch_size=4,
    input_dim=C,
    dim=D,
    ctx_dim=ctx_dim,
    ctx_size=ctx_size,
    heads=4,
    dim_head=16,
    mlp_dim=128,
    depth=2,
    enc_blocks=1,
    dec_blocks=1,
    class_cond=0,
    shared_adaln=True,
    concat_ctx=True,
    use_dyt=False,
    double_t=False,
    addon_info_embs_dim=None,
    tread_config=None,  # TREAD off
)

# forward pass
variables = model.init(key, x, t, ctx)
out = model.apply(variables, x, t, ctx)

print("✅ XUDiT output shape:", out.shape)
assert out.shape == (B, C, H, W)
print("✅ Test passed — output shape matches input (NCHW).")


✅ XUDiT output shape: (2, 4, 32, 32)
✅ Test passed — output shape matches input (NCHW).


In [23]:
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(42)

B, H, W, C = 2, 32, 32, 4
D = 64
ctx_dim, ctx_size = 32, 8

x = jax.random.normal(key, (B, H, W, C))
t = jnp.linspace(0, 1, B)
ctx = jax.random.normal(key, (B, ctx_size, ctx_dim))

model = XUDiT(
    patch_size=4,
    input_dim=C,
    dim=D,
    ctx_dim=ctx_dim,
    ctx_size=ctx_size,
    heads=4,
    dim_head=16,
    mlp_dim=128,
    depth=2,
    enc_blocks=1,
    dec_blocks=1,
    class_cond=0,
    shared_adaln=True,
    concat_ctx=True,
    use_dyt=False,
    double_t=False,
    addon_info_embs_dim=None,
    tread_config=None,
)

# Initialize parameters
variables = model.init(key, x, t, ctx)

# Define dummy loss (MSE)
def loss_fn(params, x, t, ctx):
    out = model.apply(params, x, t, ctx)
    loss = jnp.mean((out - x.transpose(0, 3, 1, 2)) ** 2)  # match NCHW order
    return loss

# Compute loss and gradient
loss_val, grads = jax.value_and_grad(loss_fn)(variables, x, t, ctx)

# ✅ Sanity check
print("Loss:", float(loss_val))
# Check gradient finiteness
def has_nan(x):
    return jnp.isnan(x).any()

nan_in_grad = jax.tree_util.tree_reduce(lambda a, b: a or has_nan(b), grads, False)
print("✅ Gradients finite:", not bool(nan_in_grad))


Loss: 9.021492004394531
✅ Gradients finite: True


In [24]:
import jax
import jax.numpy as jnp
import optax
from flax.training import train_state

# ============================================================
# 1. 모델 초기화
# ============================================================
B, H, W, C = 2, 16, 16, 4
D = 32
ctx_dim, ctx_size = 16, 4

key = jax.random.PRNGKey(0)
x = jax.random.normal(key, (B, H, W, C))
t = jnp.linspace(0, 1, B)
ctx = jax.random.normal(key, (B, ctx_size, ctx_dim))

model = XUDiT(
    patch_size=4,
    input_dim=C,
    dim=D,
    ctx_dim=ctx_dim,
    ctx_size=ctx_size,
    heads=2,
    dim_head=8,
    mlp_dim=64,
    depth=1,
    enc_blocks=1,
    dec_blocks=1,
    class_cond=0,
    shared_adaln=True,
    concat_ctx=True,
    use_dyt=False,
    double_t=False,
    addon_info_embs_dim=None,
    tread_config=None,
)

variables = model.init(key, x, t, ctx)

# ============================================================
# 2. Optimizer + Train state
# ============================================================
tx = optax.adamw(learning_rate=1e-3)

class TrainState(train_state.TrainState):
    pass

state = TrainState.create(apply_fn=model.apply, params=variables['params'], tx=tx)

# ============================================================
# 3. Diffusion-style loss 함수
# ============================================================
def loss_fn(params, x, t, ctx, key):
    # “ground truth noise”를 ε로 두고, 모델이 ε_pred를 맞추는 형태
    noise = jax.random.normal(key, x.shape)
    x_noisy = x + 0.1 * noise  # 약간 노이즈 추가
    pred = model.apply({'params': params}, x_noisy, t, ctx)
    loss = jnp.mean((pred - noise.transpose(0, 3, 1, 2)) ** 2)
    return loss

# ============================================================
# 4. 학습 step 정의
# ============================================================
@jax.jit
def train_step(state, x, t, ctx, key):
    loss, grads = jax.value_and_grad(loss_fn)(state.params, x, t, ctx, key)
    state = state.apply_gradients(grads=grads)
    return state, loss

# ============================================================
# 5. 실제 학습 루프
# ============================================================
for step in range(10):
    key, subkey = jax.random.split(key)
    state, loss = train_step(state, x, t, ctx, subkey)
    if step % 1 == 0:
        print(f"Step {step:02d} | Loss: {float(loss):.6f}")


Step 00 | Loss: 4.696184
Step 01 | Loss: 3.102890
Step 02 | Loss: 2.428725
Step 03 | Loss: 2.197905
Step 04 | Loss: 1.968389
Step 05 | Loss: 1.933393
Step 06 | Loss: 1.750799
Step 07 | Loss: 1.765865
Step 08 | Loss: 1.664528
Step 09 | Loss: 1.563093


In [None]:
def lin_beta_schedule(T=1000, beta_start=1e-4, beta_end=2e-2):
    # 미리 고정 스케줄 (CPU에서도 가벼움)
    betas = jnp.linspace(beta_start, beta_end, T, dtype=jnp.float32)
    alphas = 1.0 - betas
    alpha_bars = jnp.cumprod(alphas)
    return betas, alphas, alpha_bars

BETAS, ALPHAS, ALPHA_BARS = lin_beta_schedule()

def loss_fn(params, x0, t_cont, ctx, key):
    """
    x0: (B,H,W,C), t_cont: (B,)∈[0,1]
    모델 출력은 NCHW이므로 마지막에 transpose 정렬해둔 상태 유지
    """
    B = x0.shape[0]
    # 연속 t를 discrete step으로 맵핑
    T = ALPHA_BARS.shape[0]
    t_idx = jnp.clip((t_cont * (T - 1)).astype(jnp.int32), 0, T - 1)

    alpha_bar_t = ALPHA_BARS[t_idx]                 # (B,)
    sqrt_ab = jnp.sqrt(alpha_bar_t)[:, None, None, None]
    sqrt_one_minus_ab = jnp.sqrt(1.0 - alpha_bar_t)[:, None, None, None]

    noise = jax.random.normal(key, x0.shape)
    x_t = sqrt_ab * x0 + sqrt_one_minus_ab * noise

    # 모델은 (B,C,H,W) 출력 → 타깃 noise도 동일 포맷으로
    eps_pred = model.apply({'params': params}, x_t, t_cont, ctx)     # (B,C,H,W)
    eps_tgt = noise.transpose(0, 3, 1, 2)                            # (B,C,H,W)

    return jnp.mean((eps_pred - eps_tgt) ** 2)


In [None]:
import optax
from flax.training import train_state
from flax.core import FrozenDict

tx = optax.chain(
    optax.clip_by_global_norm(1.0),       # 폭주 방지
    optax.adamw(1e-3, weight_decay=0.0),
)

class TrainState(train_state.TrainState):
    ema_params: FrozenDict | None = None

def ema_update(ema_params, new_params, decay=0.999):
    if ema_params is None:    # 첫 스텝 bootstrap
        return new_params
    return jax.tree_map(lambda e, p: decay * e + (1 - decay) * p, ema_params, new_params)

state = TrainState.create(apply_fn=model.apply, params=variables['params'], tx=tx, ema_params=None)

@jax.jit
def train_step(state, x0, t_cont, ctx, key):
    loss, grads = jax.value_and_grad(loss_fn)(state.params, x0, t_cont, ctx, key)
    state = state.apply_gradients(grads=grads)
    new_ema = ema_update(state.ema_params, state.params, decay=0.999)
    state = state.replace(ema_params=new_ema)
    # grad norm 모니터링
    grad_norm = jnp.sqrt(sum([jnp.sum(g**2) for g in jax.tree_leaves(grads)]))
    return state, loss, grad_norm

# 루프
for step in range(10):
    key, sub = jax.random.split(key)
    state, loss, gnorm = train_step(state, x, t, ctx, sub)
    print(f"Step {step:02d} | Loss: {float(loss):.6f} | grad_norm: {float(gnorm):.3f}")



  grad_norm = jnp.sqrt(sum([jnp.sum(g**2) for g in jax.tree_leaves(grads)]))


Step 00 | Loss: 4.491674 | grad_norm: 14.472


  return jax.tree_map(lambda e, p: decay * e + (1 - decay) * p, ema_params, new_params)


Step 01 | Loss: 3.430302 | grad_norm: 9.012
Step 02 | Loss: 2.891752 | grad_norm: 5.750
Step 03 | Loss: 2.492566 | grad_norm: 4.882
Step 04 | Loss: 2.378393 | grad_norm: 3.864
Step 05 | Loss: 2.092716 | grad_norm: 2.930
Step 06 | Loss: 1.965346 | grad_norm: 2.324
Step 07 | Loss: 1.986635 | grad_norm: 1.969
Step 08 | Loss: 1.732124 | grad_norm: 1.480
Step 09 | Loss: 1.677979 | grad_norm: 1.258


In [None]:
def sample_batch(key, B=8, H=16, W=16, C=4, ctx_size=4, ctx_dim=16):
    k1,k2,k3 = jax.random.split(key, 3)
    x0  = jax.random.normal(k1, (B,H,W,C))
    t   = jax.random.uniform(k2, (B,), minval=0.0, maxval=1.0)
    ctx = jax.random.normal(k3, (B, ctx_size, ctx_dim))
    return x0, t, ctx

for step in range(50):
    key, sub = jax.random.split(key)
    xb, tb, ctxb = sample_batch(sub)
    key, sub2 = jax.random.split(key)
    state, loss, gnorm = train_step(state, xb, tb, ctxb, sub2)
    if step % 5 == 0:
        print(f"[{step:03d}] loss={float(loss):.5f} gnorm={float(gnorm):.3f}")


[000] loss=2.65521 gnorm=4.101
[005] loss=1.93928 gnorm=1.516
[010] loss=1.61671 gnorm=0.787
[015] loss=1.57689 gnorm=0.718
[020] loss=1.36980 gnorm=0.648
[025] loss=1.29856 gnorm=0.556
[030] loss=1.12467 gnorm=0.532
[035] loss=1.12139 gnorm=0.512
[040] loss=1.10809 gnorm=0.443
[045] loss=0.99940 gnorm=0.420


In [None]:
def sample_eps_pred(params, x_t, t_cont, ctx):
    return model.apply({'params': params}, x_t, t_cont, ctx)

def simple_denoise(key, params, ctx, H=16, W=16, C=4, steps=16):
    B = ctx.shape[0]
    x = jax.random.normal(key, (B,H,W,C))
    for i in range(steps, 0, -1):
        t_cont = jnp.full((B,), i/steps, dtype=jnp.float32)
        eps = sample_eps_pred(params, x, t_cont, ctx)             # (B,C,H,W)
        eps = eps.transpose(0,2,3,1)                              # → (B,H,W,C)
        # 간단히 한 스텝 Euler 업데이트(학술적으로는 scheduler 필요)
        x = x - 0.1 * eps
    return x  # 샘플된 “이미지/latent” 흉내

In [None]:
key, sub = jax.random.split(key)
# ctx도 학습 시의 shape과 맞춰줍니다
ctx_sample = jax.random.normal(sub, (2, 4, 16))  # (B, ctx_size, ctx_dim)

# EMA 파라미터가 있으면 그걸 쓰는 게 좋고, 없으면 state.params 사용
params = state.ema_params if state.ema_params is not None else state.params

# denoise 실행
key, sub2 = jax.random.split(key)
sampled = simple_denoise(sub2, params, ctx_sample, H=16, W=16, C=4, steps=16)

print("샘플 shape:", sampled.shape)
print("샘플 요약 (평균/표준편차):", jnp.mean(sampled), jnp.std(sampled))

샘플 shape: (2, 16, 16, 4)
샘플 요약 (평균/표준편차): 0.07229324 3.2548957
