In [2]:
import jax
import jax.numpy as jnp
import jax.random as jrn
import flax.linen as nn
from flax import nnx
from jax import grad, jit
import numpy as np

In [None]:
# Descriptive alias for `jax.random.normal` (`jrn` is `jax.random`, see the import cell).
# Nothing is sampled here; the function object is only bound to a new name.
gaussian_sampler = jrn.normal


TypeError: iteration over a 0-d key array

In [2]:
# Define a simple MLP

class SimpleMLP(nnx.Module):
    """Two-layer perceptron: ``Linear -> ReLU -> Linear``.

    Args:
        hidden_dim: Width of the hidden layer.
        output_dim: Number of output features.
        rngs: ``nnx.Rngs`` used to initialise both linear layers. Required,
            even though the signature keeps ``None`` as its default for
            backward compatibility.
        input_dim: Number of input features (keyword-only). Defaults to 10,
            the value that used to be hard-coded.

    Raises:
        ValueError: If ``rngs`` is not provided.
    """

    def __init__(self, hidden_dim: int = 64, output_dim: int = 1, rngs: nnx.Rngs = None,
                 *, input_dim: int = 10):
        # Fail early with a readable message instead of an opaque error raised
        # from inside `nnx.Linear` when it tries to draw keys from `None`.
        if rngs is None:
            raise ValueError(
                "SimpleMLP requires an `nnx.Rngs` instance, e.g. SimpleMLP(rngs=nnx.Rngs(0))"
            )
        self.hidden = nnx.Linear(input_dim, hidden_dim, rngs=rngs)
        self.output = nnx.Linear(hidden_dim, output_dim, rngs=rngs)

    def __call__(self, x):
        """Apply the MLP to ``x``, whose last axis must have ``input_dim`` features."""
        x = nnx.relu(self.hidden(x))
        return self.output(x)


        


In [3]:
# Initialize models
# `model_alpha` draws its initial weights from `rngs`; `model_beta` is built from
# `rngs.fork()`.
# NOTE(review): `fork()` presumably yields a separate key stream so that the two
# models start from different weights — confirm against the installed Flax version.
rngs = nnx.Rngs(42)
model_alpha = SimpleMLP(rngs=rngs)
model_beta = SimpleMLP(rngs=rngs.fork())

# Test input: a batch of 5 samples with 10 features each, drawn with its own fixed key.
# `x` is reused by later cells in the notebook.
x = jax.random.normal(jax.random.PRNGKey(0), (5, 10))

# Report the parameter shapes of the alpha model.
print("Alpha model structure:")
print(f"Hidden weight shape: {model_alpha.hidden.kernel.shape}")
print(f"Hidden bias shape: {model_alpha.hidden.bias.shape}")
print(f"Output weight shape: {model_alpha.output.kernel.shape}")

# Test forward pass on both models; `out_alpha` is compared against the
# interpolated model's output in a later cell.
out_alpha = model_alpha(x)
out_beta = model_beta(x)
print(f"Output shapes: {out_alpha.shape}, {out_beta.shape}")

Alpha model structure:
Hidden weight shape: (10, 64)
Hidden bias shape: (64,)
Output weight shape: (64, 1)
Output shapes: (5, 1), (5, 1)


In [7]:
# Sanity check: `model_alpha` is an `nnx.Module` instance.
isinstance(model_alpha,nnx.Module)

True

In [4]:
def interpolate_models(model_alpha, model_beta, weight=0.5):
    """Return a new model whose state is a linear blend of two models' states.

    Each leaf ``a`` of ``model_alpha``'s state is combined with the matching
    leaf ``b`` of ``model_beta``'s state as ``weight * a + (1 - weight) * b``.
    ``weight`` is therefore the share given to ``model_alpha``: ``weight=1``
    yields alpha's values and ``weight=0`` yields beta's.

    Args:
        model_alpha: Model whose state is scaled by ``weight``; it is also the
            one that gets cloned to host the blended state.
        model_beta: Model whose state is scaled by ``1 - weight``. Its state
            must have the same tree structure as ``model_alpha``'s.
        weight: Blend coefficient applied to ``model_alpha``.

    Returns:
        A clone of ``model_alpha`` updated with the blended state. The update
        is applied to the clone rather than to either argument.
    """
    def blend(a, b):
        return weight * a + (1 - weight) * b

    # `nnx.state` is called without filters, so every variable it returns
    # takes part in the blend.
    blended_state = jax.tree.map(blend, nnx.state(model_alpha), nnx.state(model_beta))

    blended_model = nnx.clone(model_alpha)
    nnx.update(blended_model, blended_state)
    return blended_model

In [5]:
# Test interpolation
# weight=0.3 gives 0.3 * alpha + 0.7 * beta for every state leaf
# (see the formula inside `interpolate_models`).
model_gamma = interpolate_models(model_alpha, model_beta, weight=0.3)
out_gamma = model_gamma(x)
print(f"Interpolated output shape: {out_gamma.shape}")
# Mean absolute difference between the blended model's output and alpha's output
# on the same input batch `x`.
print(f"Interpolated vs alpha: {jnp.mean(jnp.abs(out_gamma - out_alpha)):.6f}")

Interpolated output shape: (5, 1)
Interpolated vs alpha: 0.402743


In [6]:
# Test gradients through parameter interpolation
def loss_fn(model_alpha, model_beta, x):
    """Mean absolute output of the default-weight blend of two models.

    Args:
        model_alpha: First model passed to ``interpolate_models``.
        model_beta: Second model passed to ``interpolate_models``.
        x: Input batch fed to the blended model.

    Returns:
        The mean of the absolute values of the blended model's output.
    """
    # No weight is given, so `interpolate_models` uses its default.
    blended = interpolate_models(model_alpha, model_beta)
    predictions = blended(x)
    return jnp.abs(predictions).mean()

# Differentiate `loss_fn` with respect to both model arguments
# (argnums=(0, 1)); the result is a pair of gradient trees, one per model.
loss_grad_fn = nnx.grad(loss_fn, argnums=(0, 1))

print("Computing gradients...")
grads_alpha, grads_beta = loss_grad_fn(model_alpha, model_beta, x)

# The gradient trees are indexed like the model: layer name, then parameter name.
print("Gradient for alpha model:")
print(f"Hidden kernel grad norm: {jnp.linalg.norm(grads_alpha.hidden.kernel):.6f}")
print(f"Hidden bias grad norm: {jnp.linalg.norm(grads_alpha.hidden.bias):.6f}")
print(f"Output kernel grad norm: {jnp.linalg.norm(grads_alpha.output.kernel):.6f}")

# `loss_fn` blends with the default weight 0.5, so alpha and beta enter the loss
# symmetrically and their gradients are expected to have matching norms.
print("\nGradient for beta model:")
print(f"Hidden kernel grad norm: {jnp.linalg.norm(grads_beta.hidden.kernel):.6f}")
print(f"Output kernel grad norm: {jnp.linalg.norm(grads_beta.output.kernel):.6f}")

Computing gradients...
Gradient for alpha model:
Hidden kernel grad norm: 0.495711
Hidden bias grad norm: 0.117731
Output kernel grad norm: 0.926063

Gradient for beta model:
Hidden kernel grad norm: 0.495711
Output kernel grad norm: 0.926063


In [56]:
class ResNet(nnx.Module):
    """Fully connected residual network.

    Structure: an input ``Linear`` layer, ``num_layers`` residual blocks of the
    form ``x = activation(Linear(x)) + x``, and an output ``Linear`` layer.

    Args:
        input_dim: Number of input features.
        width: Feature width shared by all residual blocks.
        num_layers: Number of residual blocks.
        output_dim: Number of output features.
        activation: Elementwise activation applied inside each residual block.
        rngs: ``nnx.Rngs`` used to initialise every linear layer. Required,
            even though the signature keeps ``None`` as its default for
            backward compatibility.

    Raises:
        ValueError: If ``rngs`` is not provided.
    """

    def __init__(self, input_dim: int = 1, width: int = 10, num_layers: int = 1,
                 output_dim: int = 1, activation=nnx.relu, rngs: nnx.Rngs = None):
        # Fail early with a readable message instead of an opaque error raised
        # from inside `nnx.Linear` when it tries to draw keys from `None`.
        if rngs is None:
            raise ValueError(
                "ResNet requires an `nnx.Rngs` instance, e.g. ResNet(rngs=nnx.Rngs(0))"
            )

        self.input_dim = input_dim
        self.width = width
        self.num_layers = num_layers
        self.output_dim = output_dim
        self.activation = activation

        # Layers are created in the order input -> hidden -> output.
        self.init_layer = nnx.Linear(input_dim, width, rngs=rngs)

        # NOTE(review): a plain Python list of modules is used here; this relies
        # on the installed NNX version traversing lists. Newer Flax releases may
        # require `nnx.List` instead — confirm when upgrading.
        self.layers = [nnx.Linear(width, width, rngs=rngs) for _ in range(num_layers)]

        self.output_layer = nnx.Linear(width, output_dim, rngs=rngs)

    def __call__(self, x):
        """Apply the network to ``x``, whose last axis must have ``input_dim`` features."""
        x = self.init_layer(x)

        # Residual blocks: the skip connection adds each block's input to its
        # activated output.
        for layer in self.layers:
            residual = x
            x = self.activation(layer(x)) + residual

        return self.output_layer(x)

In [57]:
# Test parameter combination
# Two ResNets with identical architecture but different seeds (0 and 1), so the
# two state trees have the same structure and can be blended leaf by leaf.
# input_dim=10 matches the (5, 10) test batch `x` defined earlier.
alpha = ResNet(input_dim=10, width=10, num_layers=1, output_dim=1, activation=nnx.relu, rngs=nnx.Rngs(0))
beta = ResNet(input_dim=10, width=10, num_layers=1, output_dim=1, activation=nnx.relu, rngs=nnx.Rngs(1))


In [60]:
# Test gradients of linear combination 



# Test gradients through parameter interpolation
def loss_fn(model_alpha, model_beta, x):
    """Mean absolute output of the 0.4 / 0.6 blend of two models.

    Args:
        model_alpha: Model weighted by 0.4 in the blend.
        model_beta: Model weighted by 0.6 in the blend.
        x: Input batch fed to the blended model.

    Returns:
        The mean of the absolute values of the blended model's output.
    """
    blended = interpolate_models(model_alpha, model_beta, weight=0.4)
    predictions = blended(x)
    return jnp.abs(predictions).mean()

# Differentiate `loss_fn` with respect to both model arguments (argnums=(0, 1)).
# `x` is the (5, 10) test batch created in the SimpleMLP cell earlier in the notebook.
loss_grad_fn = nnx.grad(loss_fn, argnums=(0, 1))
grads_alpha, grads_beta = loss_grad_fn(alpha, beta, x)



In [68]:
from typing import Callable

# Ref: https://docs.jaxstack.ai/en/latest/digits_diffusion_model.html

class _AttentionBlock(nnx.Module):
    """Self-attention over the spatial positions of a ``[B, H, W, C]`` feature map.

    This is a real ``nnx.Module`` rather than a closure over locally created
    layers: the projection layers are module attributes, so they are part of
    the enclosing model's state (``nnx.state`` / ``nnx.clone`` / ``nnx.grad``).

    Args:
        channels: Number of channels of the input (and output) feature map.
        rngs: ``nnx.Rngs`` used to initialise the projection layers.
    """

    def __init__(self, channels: int, *, rngs: nnx.Rngs):
        self.query_proj = nnx.Linear(in_features=channels, out_features=channels, rngs=rngs)
        self.key_proj = nnx.Linear(in_features=channels, out_features=channels, rngs=rngs)
        self.value_proj = nnx.Linear(in_features=channels, out_features=channels, rngs=rngs)

    def __call__(self, x: jax.Array) -> jax.Array:
        """Apply scaled dot-product self-attention plus a residual connection.

        Args:
            x: Input of shape ``[batch, height, width, channels]``.

        Returns:
            Array of the same shape as ``x``.
        """
        B, H, W, C = x.shape
        scale = jnp.sqrt(C).astype(x.dtype)

        # Project, then flatten the spatial grid into a sequence of H * W tokens.
        q = self.query_proj(x).reshape(B, H * W, C)
        k = self.key_proj(x).reshape(B, H * W, C)
        v = self.value_proj(x).reshape(B, H * W, C)

        # Scaled dot-product attention over the token axis.
        attention = jnp.einsum('bic,bjc->bij', q, k) / scale
        attention = jax.nn.softmax(attention, axis=-1)

        out = jnp.einsum('bij,bjc->bic', attention, v)
        out = out.reshape(B, H, W, C)

        return x + out  # Residual connection.


class _ResidualBlock(nnx.Module):
    """Two 3x3 conv + LayerNorm + GELU stages with a 1x1-conv projection shortcut.

    This is a real ``nnx.Module`` rather than a closure over locally created
    layers, so its convolutions and norms are part of the enclosing model's state.

    Args:
        in_channels: Number of input channels.
        out_channels: Number of output channels.
        rngs: ``nnx.Rngs`` used to initialise the layers.
    """

    def __init__(self, in_channels: int, out_channels: int, *, rngs: nnx.Rngs):
        # Layers are created in the same order as before (conv1, norm1, conv2,
        # norm2, shortcut) so the keys drawn from `rngs` are assigned identically.
        self.conv1 = nnx.Conv(in_features=in_channels,
                              out_features=out_channels,
                              kernel_size=(3, 3),
                              strides=(1, 1),
                              padding=((1, 1), (1, 1)),
                              rngs=rngs)
        self.norm1 = nnx.LayerNorm(out_channels, rngs=rngs)
        self.conv2 = nnx.Conv(in_features=out_channels,
                              out_features=out_channels,
                              kernel_size=(3, 3),
                              strides=(1, 1),
                              padding=((1, 1), (1, 1)),
                              rngs=rngs)
        self.norm2 = nnx.LayerNorm(out_channels, rngs=rngs)

        # The projection shortcut is always applied, so channel counts may differ.
        self.shortcut = nnx.Conv(in_features=in_channels,
                                 out_features=out_channels,
                                 kernel_size=(1, 1),
                                 strides=(1, 1),
                                 rngs=rngs)

    def __call__(self, x: jax.Array) -> jax.Array:
        """Run the block on ``x`` (``[B, H, W, in_channels]``) and add the shortcut."""
        identity = self.shortcut(x)

        x = nnx.gelu(self.norm1(self.conv1(x)))
        x = nnx.gelu(self.norm2(self.conv2(x)))

        return x + identity


class UNet(nnx.Module):
    """U-Net with sinusoidal time conditioning and self-attention.

    The encoder has four residual stages (attention on the last two), followed
    by a bridge and a four-stage decoder with skip connections. A projected
    time embedding is added to the output of every encoder stage.

    All sub-blocks are ``nnx.Module`` attributes, so every convolution, norm
    and attention projection is part of ``nnx.state(model)``.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 features: int,
                 time_emb_dim: int = 128,
                 *,
                 rngs: nnx.Rngs):
        """
        Initialize the U-Net architecture with time embedding.

        Args:
            in_channels: Number of channels of the input image.
            out_channels: Number of channels of the output.
            features: Base channel count; deeper stages use multiples of it.
            time_emb_dim: Size of the sinusoidal time embedding. Must be even
                and at least 4.
            rngs: ``nnx.Rngs`` used to initialise all layers.

        Raises:
            ValueError: If ``time_emb_dim`` is odd or smaller than 4.
        """
        # `_pos_encoding` concatenates dim // 2 sines and cosines and divides by
        # (dim // 2 - 1): odd sizes give the wrong width, sizes below 4 divide by zero.
        if time_emb_dim < 4 or time_emb_dim % 2 != 0:
            raise ValueError(
                f"time_emb_dim must be an even integer >= 4, got {time_emb_dim}"
            )

        self.features = features
        # Kept so `__call__` builds an encoding that matches `time_mlp_1`'s input size.
        self.time_emb_dim = time_emb_dim

        # Time embedding layers for diffusion timestep conditioning.
        self.time_mlp_1 = nnx.Linear(in_features=time_emb_dim, out_features=time_emb_dim, rngs=rngs)
        self.time_mlp_2 = nnx.Linear(in_features=time_emb_dim, out_features=time_emb_dim, rngs=rngs)

        # Time projection layers for different scales.
        self.time_proj1 = nnx.Linear(in_features=time_emb_dim, out_features=features, rngs=rngs)
        self.time_proj2 = nnx.Linear(in_features=time_emb_dim, out_features=features * 2, rngs=rngs)
        self.time_proj3 = nnx.Linear(in_features=time_emb_dim, out_features=features * 4, rngs=rngs)
        self.time_proj4 = nnx.Linear(in_features=time_emb_dim, out_features=features * 8, rngs=rngs)

        # The encoder path.
        self.down_conv1 = self._create_residual_block(in_channels, features, rngs)
        self.down_conv2 = self._create_residual_block(features, features * 2, rngs)
        self.down_conv3 = self._create_residual_block(features * 2, features * 4, rngs)
        self.down_conv4 = self._create_residual_block(features * 4, features * 8, rngs)

        # Self-attention blocks.
        self.attention1 = self._create_attention_block(features * 4, rngs)
        self.attention2 = self._create_attention_block(features * 8, rngs)

        # The bridge connecting the encoder and the decoder.
        self.bridge_down = self._create_residual_block(features * 8, features * 16, rngs)
        self.bridge_attention = self._create_attention_block(features * 16, rngs)
        self.bridge_up = self._create_residual_block(features * 16, features * 16, rngs)

        # Decoder path with skip connections. Input channels are the upsampled
        # features plus the concatenated skip (e.g. 16 + 8 = 24 for `up_conv4`).
        self.up_conv4 = self._create_residual_block(features * 24, features * 8, rngs)
        self.up_conv3 = self._create_residual_block(features * 12, features * 4, rngs)
        self.up_conv2 = self._create_residual_block(features * 6, features * 2, rngs)
        self.up_conv1 = self._create_residual_block(features * 3, features, rngs)

        # Output layers.
        self.final_norm = nnx.LayerNorm(features, rngs=rngs)
        self.final_conv = nnx.Conv(in_features=features,
                                 out_features=out_channels,
                                 kernel_size=(3, 3),
                                 strides=(1, 1),
                                 padding=((1, 1), (1, 1)),
                                 rngs=rngs)

    def _create_attention_block(self, channels: int, rngs: nnx.Rngs) -> Callable:
        """Creates a self-attention block with learned query, key, value projections.

        Args:
            channels (int): The number of channels in the input feature maps.
            rngs (flax.nnx.Rngs): Source of PRNG keys for the projection layers.

        Returns:
            Callable: A callable ``nnx.Module`` mapping ``[B, H, W, C]`` to ``[B, H, W, C]``.
        """
        return _AttentionBlock(channels, rngs=rngs)

    def _create_residual_block(self,
                              in_channels: int,
                              out_channels: int,
                              rngs: nnx.Rngs) -> Callable:
        """Creates a residual block with two convolutions and normalization.

        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            rngs (flax.nnx.Rngs): Source of PRNG keys for the layers.

        Returns:
            Callable: A callable ``nnx.Module`` mapping ``[B, H, W, in_channels]``
            to ``[B, H, W, out_channels]``.
        """
        return _ResidualBlock(in_channels, out_channels, rngs=rngs)

    def _pos_encoding(self, t: jax.Array, dim: int) -> jax.Array:
        """Applies sinusoidal positional encoding for time embedding.

        Args:
            t (jax.Array): 1-D array of timesteps, one per batch element.
            dim (int): The dimension of the output positional encoding.

        Returns:
            jax.Array: Array of shape ``[len(t), 2 * (dim // 2)]`` holding the
            sine components followed by the cosine components.
        """
        # Calculate half the embedding dimension.
        half_dim = dim // 2
        # Compute the logarithmic scaling factor for sinusoidal frequencies.
        emb = jnp.log(10000.0) / (half_dim - 1)
        # Generate a range of sinusoidal frequencies.
        emb = jnp.exp(jnp.arange(half_dim) * -emb)
        # Multiply every timestep by every frequency.
        emb = t[:, None] * emb[None, :]
        # Concatenate sine and cosine components.
        emb = jnp.concatenate([jnp.sin(emb), jnp.cos(emb)], axis=1)
        return emb

    def _downsample(self, x: jax.Array) -> jax.Array:
        """Downsamples the input feature map with max pooling."""
        return nnx.max_pool(x, window_shape=(2, 2), strides=(2, 2), padding='SAME')

    def _upsample(self, x: jax.Array, target_size) -> jax.Array:
        """Upsamples the input feature map using nearest neighbor interpolation.

        Args:
            x: Feature map of shape ``[B, H, W, C]``.
            target_size: Either a single int (square output) or a
                ``(height, width)`` pair.

        Returns:
            ``x`` resized to ``[B, height, width, C]``.
        """
        if isinstance(target_size, int):
            target_h = target_w = target_size
        else:
            target_h, target_w = target_size
        return jax.image.resize(x,
                              (x.shape[0], target_h, target_w, x.shape[3]),
                              method='nearest')

    def __call__(self, x: jax.Array, t: jax.Array) -> jax.Array:
        """Perform the forward pass through the U-Net using time embeddings.

        Args:
            x: Input of shape ``[B, H, W, in_channels]``.
            t: 1-D array of timesteps with one entry per batch element.

        Returns:
            Array of shape ``[B, H, W, out_channels]``.
        """

        # Time embedding: sinusoidal encoding followed by a two-layer MLP.
        t_emb = self._pos_encoding(t, self.time_emb_dim)
        t_emb = self.time_mlp_1(t_emb)
        t_emb = nnx.gelu(t_emb)
        t_emb = self.time_mlp_2(t_emb)

        # Project the time embedding to each encoder stage's channel count and
        # add singleton spatial axes so it can be broadcast over H and W.
        t_emb1 = self.time_proj1(t_emb)[:, None, None, :]
        t_emb2 = self.time_proj2(t_emb)[:, None, None, :]
        t_emb3 = self.time_proj3(t_emb)[:, None, None, :]
        t_emb4 = self.time_proj4(t_emb)[:, None, None, :]

        # The encoder path with time injection.
        d1 = self.down_conv1(x)
        t_emb1 = jnp.broadcast_to(t_emb1, d1.shape)
        d1 = d1 + t_emb1

        d2 = self.down_conv2(self._downsample(d1))
        t_emb2 = jnp.broadcast_to(t_emb2, d2.shape)
        d2 = d2 + t_emb2

        d3 = self.down_conv3(self._downsample(d2))
        d3 = self.attention1(d3)
        t_emb3 = jnp.broadcast_to(t_emb3, d3.shape)
        d3 = d3 + t_emb3

        d4 = self.down_conv4(self._downsample(d3))
        d4 = self.attention2(d4)
        t_emb4 = jnp.broadcast_to(t_emb4, d4.shape)
        d4 = d4 + t_emb4

        # The bridge.
        b = self._downsample(d4)
        b = self.bridge_down(b)
        b = self.bridge_attention(b)
        b = self.bridge_up(b)

        # The decoder path: upsample to the skip's (height, width), concatenate
        # along channels, then apply the residual block.
        u4 = self.up_conv4(jnp.concatenate([self._upsample(b, d4.shape[1:3]), d4], axis=-1))
        u3 = self.up_conv3(jnp.concatenate([self._upsample(u4, d3.shape[1:3]), d3], axis=-1))
        u2 = self.up_conv2(jnp.concatenate([self._upsample(u3, d2.shape[1:3]), d2], axis=-1))
        u1 = self.up_conv1(jnp.concatenate([self._upsample(u2, d1.shape[1:3]), d1], axis=-1))

        # Final layers.
        x = self.final_norm(u1)
        x = nnx.gelu(x)
        return self.final_conv(x)

In [69]:
# Configuration for the U-Net experiment. Of these, only `key`, `in_channels`,
# `out_channels` and `features` are used by the cells visible in this notebook.
key = jax.random.PRNGKey(42) # PRNG seed for reproducibility.
in_channels = 1
out_channels = 1
features = 64   # Number of features in the U-Net.
# NOTE(review): the next four look like diffusion/training settings carried over
# from the referenced tutorial; no visible cell reads them — confirm they are needed.
num_steps = 1000
num_epochs = 5000
batch_size = 64
learning_rate = 1e-4
beta_start = 1e-4   # The starting value for beta (noise level schedule).
beta_end = 0.02   # The end value for beta (noise level schedule).

# Initialize model components.



In [72]:
# Build two U-Nets with the same architecture, each initialised from its own
# subkey split off `key`.
# Note: `alpha` and `beta` rebind the names used for the ResNet instances earlier
# (and `beta` is unrelated to the `beta_start` / `beta_end` noise-schedule values).
key, subkey = jax.random.split(key) # Split the JAX PRNG key for initialization.
alpha = UNet(in_channels, out_channels, features, rngs=nnx.Rngs(default=subkey)) # Instantiate the U-Net.
key,subkey = jax.random.split(key)
beta = UNet(in_channels,out_channels,features,rngs=nnx.Rngs(default=subkey))

In [73]:
# Blend the two U-Nets with `interpolate_models`' default weight (0.5).
gamma = interpolate_models(alpha,beta)