In [2]:
# IMPORTS
import jax.numpy as jnp
import numpy as np
import flax
import flax.linen as nn
from typing import Any, Tuple
import functools
import jax


In [None]:
class alpha_step_projection(nn.Module):
    """Random Fourier-feature embedding of a scalar-per-sample input.

    A fixed (non-trainable) weight vector ``W`` with ``embed_dim // 2`` entries
    is drawn once at initialisation from a normal distribution with standard
    deviation ``scale``. Each input value is multiplied by ``2 * pi * W`` and
    mapped to the concatenation of its sine and cosine.

    Attributes:
        embed_dim: Requested width of the embedding. The actual output width
            is ``2 * (embed_dim // 2)``, so an odd ``embed_dim`` yields one
            column fewer than requested.
        scale: Standard deviation used to sample ``W``.
    """
    embed_dim : int
    scale : float = 30.

    @nn.compact
    def __call__(self, x):
        """Embed ``x``.

        Args:
            x: Array indexed as ``x[:, None]``; with a 1-D input of shape
                ``(batch,)`` the result has shape
                ``(batch, 2 * (embed_dim // 2))``.

        Returns:
            ``concat([sin(2*pi*x*W), cos(2*pi*x*W)], axis=-1)``.
        """
        # `xavier_normal` has no `stddev` argument and needs a >=2-D shape to
        # compute fan-in/fan-out, so it cannot initialise this 1-D vector.
        # A plain normal initializer with the requested stddev is used instead.
        W = self.param('W', jax.nn.initializers.normal(stddev=self.scale),
                       (self.embed_dim // 2,))
        # The frequencies stay fixed: block gradients from flowing into W.
        W = jax.lax.stop_gradient(W)
        x_proj = x[:, None] * W[None, :] * 2 * jnp.pi
        return jnp.concatenate([jnp.sin(x_proj), jnp.cos(x_proj)], axis=-1)
    
class Dense(nn.Module):
    """Thin wrapper around a single ``nn.Dense`` layer.

    The input is projected to ``output_dim`` features and returned as-is;
    no reshaping is performed on the result.

    Attributes:
        output_dim: Number of output features of the wrapped dense layer.
    """
    output_dim : int

    @nn.compact
    def __call__(self, x):
        """Apply the dense projection to ``x`` and return the result."""
        projection = nn.Dense(features=self.output_dim)
        return projection(x)
    
class diffusion_model(nn.Module):
    """Network that combines an input ``x`` with an embedding of ``alpha``.

    NOTE(review): the definition looks unfinished -- ``__call__`` builds only
    the first hidden activation and has no ``return`` statement, so it
    currently returns ``None``. Only ``channels[0]`` is read.
    """
    # Layer widths; only index 0 is used by the code below.
    # NOTE(review): for a variable-length tuple the annotation would normally
    # be ``Tuple[int, ...]``.
    channels : Tuple[int] = (64,64,64,64,64,2)
    # Width of the alpha embedding.
    embed_dim : int = 256

    @nn.compact
    def __call__(self, x, alpha):
        """Compute the first hidden activation from ``x`` and ``alpha``.

        Args:
            x: Input passed to a dense layer with ``channels[0]`` features.
            alpha: Value(s) fed to ``alpha_step_projection``; that module
                indexes its input as ``alpha[:, None]``.

        Returns:
            Nothing at present (implicit ``None``) -- see the class note.
        """
        act = nn.swish
        # alpha -> Fourier features -> dense(embed_dim) -> swish.
        embed = act(nn.Dense(self.embed_dim)(alpha_step_projection(embed_dim=self.embed_dim)(alpha)))

        # Project x and the embedding to the same width and add them.
        h1 = nn.Dense(features=self.channels[0])(x)
        h1+= Dense(output_dim=self.channels[0])(embed)
        # use_running_average=True is hard-coded here.
        # NOTE(review): presumably this should be switchable for training so
        # batch statistics can be updated -- confirm against the Flax docs.
        h1 = nn.BatchNorm(use_running_average=True)(h1)

        




In [5]:
# Fixed seed so the random draw below is reproducible across runs.
key = jax.random.PRNGKey(0)
rng, step_rng = jax.random.split(key)

# x0: a single all-ones row with two features.
x0 = jnp.ones(shape=(1, 2))
# x1: one uniform sample in [0, 1), drawn with the split-off sub-key.
x1 = jax.random.uniform(step_rng, shape=(1,))

In [6]:
# Display x0, the (1, 2) all-ones array created in the previous cell.
x0

Array([[1., 1.]], dtype=float32)

In [7]:
# Display x1, the single uniform draw created alongside x0.
x1

Array([0.10536897], dtype=float32)

In [8]:
# Scale x0 by (1 - x1); x1 has shape (1,) and broadcasts over x0's (1, 2).
x_alpha = x0 * (1 - x1)

In [9]:
# Display x_alpha, i.e. (1 - x1) * x0 from the previous cell.
x_alpha

Array([[0.894631, 0.894631]], dtype=float32)