# Video Diffusion with FlaxUNet3DConditionModel

This notebook demonstrates how to use the FlaxUNet3DConditionModel for video diffusion tasks.

In [None]:
import argparse
import os
from datetime import datetime

import jax
import jax.numpy as jnp
import numpy as np
import optax

from flaxdiff.data.datasets import get_dataset_grain, get_media_dataset_grain
from flaxdiff.models.autoencoder.diffusers import StableDiffusionVAE
from flaxdiff.models.simple_unet import Unet
from flaxdiff.predictors import EpsilonPredictionTransform, KarrasPredictionTransform
from flaxdiff.samplers.euler import EulerAncestralSampler, EulerSampler
from flaxdiff.schedulers import EDMNoiseScheduler, KarrasVENoiseScheduler
from flaxdiff.trainer.general_diffusion_trainer import GeneralDiffusionTrainer, ConditionalInputConfig
from flaxdiff.utils import defaultTextEncodeModel

# Batch size passed to the dataset loaders; also used to derive `batches`.
BATCH_SIZE = 16
# Spatial scale passed to the loaders as `media_scale` / `image_scale`.
IMAGE_SIZE = 256

  from .autonotebook import tqdm as notebook_tqdm


In [None]:
# Load dataset
# Builds the "ucf101" media loader with the configured batch size and scale.
# NOTE(review): a later cell rebinds `data` to the "oxford_flowers102" loader.
data = get_media_dataset_grain("ucf101", batch_size=BATCH_SIZE, media_scale=IMAGE_SIZE)
# 'train_len' is presumably the number of training samples — TODO confirm.
datalen = data['train_len']
# Floor division: a trailing partial batch is not counted.
batches = datalen // BATCH_SIZE

In [3]:
# data['train'] is stored as a callable: call it, wrap the result in an
# iterator, and pull a single batch for inspection.
dataiter = iter(data['train']())
batch = next(dataiter)

In [2]:
# Load dataset
# NOTE(review): this rebinds `data`, replacing the "ucf101" loader created in
# the earlier cell; everything below sees the "oxford_flowers102" loader.
data = get_dataset_grain("oxford_flowers102", batch_size=BATCH_SIZE, image_scale=IMAGE_SIZE)
datalen = data['train_len']
# Floor division: a trailing partial batch is not counted.
batches = datalen // BATCH_SIZE

# Text encoder; get_val_dataset uses its tokenize() method on the prompts.
text_encoder = defaultTextEncodeModel()
# VAE wrapper constructed from the "pcuenq/sd-vae-ft-mse-flax" model name.
autoencoder = StableDiffusionVAE(**{"modelname": "pcuenq/sd-vae-ft-mse-flax"})

# Validation prompts (flower descriptions) that get_val_dataset tokenizes in batches.
# The leading spaces on most entries are part of the strings handed to the tokenizer.
val_prompts = ['water tulip', ' a water lily', ' a water lily', ' a photo of a rose', ' a photo of a rose', ' a water lily', ' a water lily', ' a photo of a marigold', ' a photo of a marigold', ' a photo of a marigold', ' a water lily', ' a photo of a sunflower', ' a photo of a lotus', ' columbine', ' columbine', ' an orchid', ' an orchid', ' an orchid', ' a water lily', ' a water lily', ' a water lily', ' columbine', ' columbine', ' a photo of a sunflower', ' a photo of a sunflower', ' a photo of a sunflower', ' a photo of a lotus', ' a photo of a lotus', ' a photo of a marigold', ' a photo of a marigold', ' a photo of a rose', ' a photo of a rose', ' a photo of a rose', ' orange dahlia', ' orange dahlia', ' a lenten rose', ' a lenten rose', ' a water lily', ' a water lily', ' a water lily', ' a water lily', ' an orchid', ' an orchid', ' an orchid', ' hard-leaved pocket orchid', ' bird of paradise', ' bird of paradise', ' a photo of a lovely rose', ' a photo of a lovely rose', ' a photo of a globe-flower', ' a photo of a globe-flower', ' a photo of a lovely rose', ' a photo of a lovely rose', ' a photo of a ruby-lipped cattleya', ' a photo of a ruby-lipped cattleya', ' a photo of a lovely rose', ' a water lily', ' a osteospermum', ' a osteospermum', ' a water lily', ' a water lily', ' a water lily', ' a red rose', ' a red rose']

def get_val_dataset(batch_size=8):
    """Yield the validation prompts as tokenized batches.

    Walks the module-level ``val_prompts`` list in consecutive slices of
    ``batch_size`` entries (the final slice may be shorter) and tokenizes
    each slice with the module-level ``text_encoder``.

    Args:
        batch_size: Number of prompts per yielded batch.

    Yields:
        dict: ``{"text": tokens}`` where ``tokens`` is whatever
        ``text_encoder.tokenize`` returns for that slice of prompts.
    """
    for start in range(0, len(val_prompts), batch_size):
        stop = start + batch_size
        yield {"text": text_encoder.tokenize(val_prompts[start:stop])}

# Register the prompt generator as the 'test' split. The function itself is
# stored (not called), mirroring how data['train'] is invoked as data['train']().
data['test'] = get_val_dataset
# 'test_len' counts individual prompts, not batches.
data['test_len'] = len(val_prompts)


2025-04-13 14:55:42.973940: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:467] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1744556142.998162  192486 cuda_dnn.cc:8579] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1744556143.005274  192486 cuda_blas.cc:1407] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
W0000 00:00:1744556143.022849  192486 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1744556143.022872  192486 computation_placer.cc:177] computation placer already registered. Please check linkage and avoid linking the same target more than once.
W0000 00:00:1744556143.022874  192486 computation_placer.cc:177] computation placer alr

Scaling factor: 0.18215
Calculating downscale factor...
Downscale factor: 8
Latent channels: 4


In [3]:
# Same statements as the earlier first-batch cell: call data['train'] to get an
# iterable and pull one batch from it for inspection.
dataiter = iter(data['train']())
batch = next(dataiter)

In [5]:
# Inspect the shape of the batch's 'image' entry.
batch['image'].shape

(16, 256, 256, 3)

In [2]:
def create_model(rng):
    """Build a small FlaxUNet3DConditionModel and initialize its parameters.

    The model is initialized by tracing it once with dummy inputs: a latent
    video of shape (1, num_frames, 32, 32, 4), a single timestep, and a text
    embedding of shape (1, 77, 64). The parameter count is printed.

    Args:
        rng: JAX PRNG key. It is split internally so that parameter
            initialization and each dummy input use an independent key
            (JAX keys should not be reused across random calls).

    Returns:
        tuple: ``(model, params)`` — the module instance and the variables
        returned by ``model.init``.
    """
    # Sizes shared between the model configuration and the dummy inputs.
    num_frames = 8
    latent_size = 32
    latent_channels = 4
    text_seq_len = 77  # 77 is standard for CLIP text tokens
    text_embed_dim = 64

    model = FlaxUNet3DConditionModel(
        sample_size=(num_frames, latent_size, latent_size),
        in_channels=latent_channels,
        out_channels=latent_channels,
        down_block_types=(
            "CrossAttnDownBlock3D",
            "CrossAttnDownBlock3D",
            "CrossAttnDownBlock3D",
            "DownBlock3D",
        ),
        up_block_types=(
            "UpBlock3D",
            "CrossAttnUpBlock3D",
            "CrossAttnUpBlock3D",
            "CrossAttnUpBlock3D",
        ),
        block_out_channels=(32, 64, 128, 256),
        layers_per_block=1,
        # Must match the last dimension of the text embeddings fed to the model.
        cross_attention_dim=text_embed_dim,
        attention_head_dim=8,
        dropout=0.0,
        dtype=jnp.bfloat16
    )

    # One independent key per random consumer instead of reusing `rng` three times.
    init_key, sample_key, text_key = jax.random.split(rng, 3)

    # Create dummy inputs for initialization
    batch_size = 1
    sample = jax.random.normal(
        sample_key,
        shape=(batch_size, num_frames, latent_size, latent_size, latent_channels),
        dtype=jnp.bfloat16
    )

    timestep = jnp.array([0], dtype=jnp.int32)

    # Create dummy text embeddings
    encoder_hidden_states = jax.random.normal(
        text_key,
        shape=(batch_size, text_seq_len, text_embed_dim),
        dtype=jnp.bfloat16
    )

    # Initialize the model
    params = model.init(init_key, sample, timestep, encoder_hidden_states)

    # Print model summary
    param_count = sum(p.size for p in jax.tree_util.tree_leaves(params))
    print(f"Model initialized with {param_count:,} parameters")

    return model, params

# Seed once, then hand a dedicated sub-key to model creation; `rng` stays
# available for later cells.
rng = jax.random.PRNGKey(42)
rng, model_rng = jax.random.split(rng)
model, params = create_model(model_rng)

Model initialized with 20,027,236 parameters


In [4]:
# Forward-pass smoke test on a batch of two dummy latent videos.
# Keys are derived from `rng` without rebinding it, so re-running this cell
# (or the whole notebook) feeds the model exactly the same inputs; the previous
# version drew from NumPy's unseeded global RNG.
video_key, text_key = jax.random.split(rng)

# Dummy latents, uniform in [0, 1): (batch, frames, height, width, channels)
sample_video = jax.random.uniform(video_key, shape=(2, 8, 32, 32, 4), dtype=jnp.float32)
timestep = jnp.zeros((2,), dtype=jnp.int32)

# Create dummy text embeddings
encoder_hidden_states = jax.random.normal(
    text_key,
    shape=(2, 77, 64),  # 77 is standard for CLIP text tokens
    dtype=jnp.bfloat16
)

out = model.apply(
    params,
    sample_video,
    timestep,
    encoder_hidden_states,
    return_dict=True
)
print(out.shape)  # Should be (2, 8, 32, 32, 4)
# Enforce the expectation instead of leaving it to visual inspection.
assert out.shape == sample_video.shape, f"unexpected output shape {out.shape}"

(2, 8, 32, 32, 4)


## 2. Set up the Diffusion Process

Now we'll set up the noise scheduler and sampler for our diffusion process.

In [5]:
# Create a noise scheduler
# NOTE(review): the meaning of the leading positional `1` is not visible in
# this notebook (presumably the number of timesteps) — confirm against
# EDMNoiseScheduler before relying on it.
noise_scheduler = EDMNoiseScheduler(1, sigma_min=0.002, sigma_max=80.0, rho=7.0)

# Create a prediction transform
model_output_transform = EpsilonPredictionTransform()

# Create a sampler
# This module-level sampler is the one used by generate_video and
# generate_with_frame_conditioning; experiment_with_guidance_scale builds its
# own samplers with other guidance scales. Guidance is set to 0 here.
sampler = EulerSampler(
    model=model,
    params=params,
    noise_schedule=noise_scheduler,
    model_output_transform=model_output_transform,
    guidance_scale=0
)

## 3. Generate a Simple Video

Let's generate a short video from random noise with the freshly initialized (untrained) model, using random tensors as stand-ins for text embeddings. In a real application, you would encode prompts with a text encoder such as CLIP.

In [6]:
def generate_video(num_frames=8, height=32, width=32, steps=20):
    """Sample one video from the module-level ``sampler`` with random conditioning.

    Args:
        num_frames: Forwarded to the sampler as ``sequence_length``.
        height: Accepted for API symmetry but not used by this function.
        width: Accepted for API symmetry but not used by this function.
        steps: Forwarded to the sampler as ``diffusion_steps``.

    Returns:
        Whatever ``sampler.generate_images`` returns for a batch of one.

    NOTE(review): the recorded run of this cell stopped with a flax
    ``ScopeParamShapeError`` on ``/conv_in`` (a 4-vs-3 input-channel mismatch).
    Presumably the sampler is not told about the 4-channel latent shape the
    model was initialized with — confirm against the sampler API.
    """
    n_videos = 1
    # Fixed seed: the stand-in text conditioning is identical on every call.
    text_key = jax.random.PRNGKey(123)
    # Random tensor in place of real text-encoder output (e.g. CLIP).
    fake_text_embeddings = jax.random.normal(
        text_key,
        shape=(n_videos, 77, 64),
        dtype=jnp.float32,
    )

    print(f"Generating {num_frames} frames with {steps} diffusion steps...")
    return sampler.generate_images(
        params=params,
        batch_size=n_videos,
        sequence_length=num_frames,
        diffusion_steps=steps,
        start_step=1000,
        end_step=0,
        priors=None,
        model_conditioning_inputs=(fake_text_embeddings,),
    )

# Generate video
video = generate_video(num_frames=8, steps=20)

Generating 8 frames with 20 diffusion steps...


  0%|          | 0/20 [00:00<?, ?it/s]


ScopeParamShapeError: Initializer expected to generate shape (3, 3, 3, 4, 32) but got shape (3, 3, 3, 3, 32) instead for parameter "kernel" in "/conv_in". (https://flax.readthedocs.io/en/latest/api_reference/flax.errors.html#flax.errors.ScopeParamShapeError)

## 4. Visualize the Generated Video

In [None]:
def visualize_video(video):
    """Show a generated video as an inline animation plus a strip of frames.

    Only the first item of the batch is shown, and only its first three
    channels are drawn as RGB.

    Args:
        video: Batched array-like indexed as ``video[0]`` to obtain
            (frames, height, width, channels). Values are assumed to lie in
            [-1, 1]; they are mapped to [0, 1] and clipped.

    Returns:
        The ``FuncAnimation`` object, so the caller can keep it alive or
        re-render it.
    """
    from IPython.display import HTML, display

    # Normalize to [0, 1] range for visualization
    video_clip = np.array(video[0])
    video_clip = (video_clip + 1.0) / 2.0  # Assuming [-1, 1] range
    video_clip = np.clip(video_clip, 0.0, 1.0)

    # Only use RGB channels (first 3) for visualization
    video_clip = video_clip[:, :, :, :3]

    # Create a figure for animation
    fig, ax = plt.subplots(figsize=(5, 5))
    ax.axis('off')

    # Create initial frame
    img = ax.imshow(video_clip[0])

    # Animation function
    def animate(i):
        img.set_array(video_clip[i])
        return [img]

    # Create animation
    anim = animation.FuncAnimation(
        fig, animate, frames=len(video_clip), interval=200, blit=True
    )

    # Display the animation. A bare HTML(...) expression is only rendered when
    # it is the last expression of a cell; inside a function its value is
    # discarded, so display() must be called explicitly.
    display(HTML(anim.to_jshtml()))

    # Also display individual frames for reference. squeeze=False keeps `axes`
    # two-dimensional so a single-frame clip does not return a lone Axes.
    _, axes = plt.subplots(1, len(video_clip), figsize=(15, 3), squeeze=False)
    for i, frame_ax in enumerate(axes[0]):
        frame_ax.imshow(video_clip[i])
        frame_ax.set_title(f"Frame {i}")
        frame_ax.axis('off')
    plt.tight_layout()

    return anim

# Visualize the generated video
anim = visualize_video(video)

## 5. Save the Generated Video

In [None]:
def save_video(video, filename='generated_video.mp4'):
    """Write the first video of a batch to disk as a movie and as PNG frames.

    The clip is mapped from an assumed [-1, 1] range to [0, 1], clipped, and
    reduced to its first three channels. The movie is written with
    matplotlib's 'ffmpeg' writer; every frame is additionally saved as
    ``frame_<i>.png`` relative to the current working directory.

    Args:
        video: Batched array-like indexed as ``video[0]`` to obtain
            (frames, height, width, channels).
        filename: Output path of the movie file.
    """
    video_clip = np.array(video[0])
    video_clip = (video_clip + 1.0) / 2.0  # Assuming [-1, 1] range
    video_clip = np.clip(video_clip, 0.0, 1.0)

    # Only use RGB channels (first 3) for saving
    video_clip = video_clip[:, :, :, :3]

    # Create a figure for animation
    fig, ax = plt.subplots(figsize=(5, 5))
    try:
        ax.axis('off')

        # Create initial frame
        img = ax.imshow(video_clip[0])

        # Animation function
        def animate(i):
            img.set_array(video_clip[i])
            return [img]

        # Create animation
        anim = animation.FuncAnimation(
            fig, animate, frames=len(video_clip), interval=200, blit=True
        )

        # Save the animation
        anim.save(filename, writer='ffmpeg', fps=5, dpi=100)
        print(f"Video saved to {filename}")
    finally:
        # The figure only exists to render the movie; close it even when
        # saving fails so repeated calls do not accumulate open figures.
        plt.close(fig)

    # Also save individual frames
    for i, frame in enumerate(video_clip):
        plt.imsave(f"frame_{i}.png", frame)
    
# Comment out if you don't have ffmpeg installed
# save_video(video)

## 6. Experiment with Different Parameters

Let's experiment with different guidance scales to see how they affect the generated video.

In [None]:
def experiment_with_guidance_scale(guidance_scales=(1.0, 3.0, 5.0, 7.0), num_frames=8, steps=20):
    """Generate one video per guidance scale with otherwise identical settings.

    A fresh ``EulerSampler`` is built for every scale from the module-level
    ``model``, ``params``, ``noise_scheduler`` and ``model_output_transform``.
    The same random text conditioning (fixed seed) is used for every scale so
    the results are comparable.

    Args:
        guidance_scales: Iterable of guidance scales to try.
        num_frames: Number of frames per generated video.
        steps: Number of diffusion steps per video.

    Returns:
        dict mapping each guidance scale to the sampler output for that scale.

    NOTE(review): this calls ``generate_images`` with ``num_images`` /
    ``image_shape``, whereas generate_video in this notebook uses
    ``batch_size`` / ``sequence_length`` — confirm which signature the
    installed sampler accepts.
    """
    # Create mock text embeddings once: the seed is fixed, so building them
    # inside the loop produced the same tensor on every iteration.
    batch_size = 1
    rng_gen = jax.random.PRNGKey(123)  # Using a consistent seed for comparison
    encoder_hidden_states = jax.random.normal(
        rng_gen,
        shape=(batch_size, 77, 64),
        dtype=jnp.float32
    )

    results = {}
    for gs in guidance_scales:
        print(f"Generating video with guidance scale {gs}...")

        # Create a sampler with the current guidance scale
        temp_sampler = EulerSampler(
            model=model,
            params=params,
            noise_schedule=noise_scheduler,
            model_output_transform=model_output_transform,
            guidance_scale=gs,
        )

        # Generate video
        results[gs] = temp_sampler.generate_images(
            params=params,
            num_images=batch_size,
            diffusion_steps=steps,
            start_step=1000,
            end_step=0,
            priors=None,
            image_shape=(num_frames, 32, 32, 4),
            model_conditioning_inputs=(encoder_hidden_states,),
        )

    return results

# Uncomment to run the experiment
# guidance_results = experiment_with_guidance_scale()

## 7. Processing Existing Video

In a real-world scenario, you might want to process existing video frames. Here's how you could do that with the UNet3D model.

In [None]:
def process_existing_video(video_frames, noise_level=0.2):
    """
    Process existing video frames with the UNet3D model.
    This is a simple example that adds noise and then runs the model once.

    Args:
        video_frames: numpy array of shape (num_frames, height, width, channels)
            with at most 4 channels. Values above 1 are treated as 0-255 data
            and divided by 255; data in [0, 1] is then mapped to [-1, 1].
            Fewer than 4 channels are zero-padded to 4.
        noise_level: Amount of noise to add (0-1)

    Returns:
        dict with keys 'original', 'noisy' and 'processed'. Each value holds
        the first (up to) three channels of the respective frames for the
        single batch item, mapped to [0, 1] and clipped.

    Raises:
        ValueError: If ``video_frames`` has more than 4 channels.

    NOTE(review): 'processed' is the raw model output at timestep 500. With an
    epsilon-prediction transform this is presumably the predicted noise rather
    than a cleaned frame — confirm before interpreting it as a denoised video.
    """
    model_channels = 4  # matches in_channels used when the model was created

    # Convert to JAX array and ensure correct shape
    video_frames = jnp.array(video_frames)
    batch_size = 1
    num_frames, height, width, channels = video_frames.shape
    if channels > model_channels:
        raise ValueError(
            f"video_frames has {channels} channels; at most {model_channels} are supported"
        )

    # Scale to [-1, 1] if needed
    if video_frames.max() > 1.0:
        video_frames = video_frames / 255.0
    if video_frames.max() <= 1.0 and video_frames.min() >= 0.0:
        video_frames = video_frames * 2.0 - 1.0

    # Add a batch dimension
    video_frames = video_frames[None, ...]

    # If channels < 4, pad with zeros
    if channels < model_channels:
        padding = jnp.zeros((batch_size, num_frames, height, width, model_channels - channels))
        video_frames = jnp.concatenate([video_frames, padding], axis=-1)

    # Add noise
    rng_noise = jax.random.PRNGKey(456)
    noise = jax.random.normal(rng_noise, video_frames.shape)
    noisy_frames = video_frames + noise_level * noise

    # Create mock text embeddings
    rng_text = jax.random.PRNGKey(789)
    encoder_hidden_states = jax.random.normal(
        rng_text,
        shape=(batch_size, 77, 64),
        dtype=jnp.float32
    )

    # Process the video
    print("Processing video...")

    # For a simple demonstration, we'll just do a single model evaluation
    timestep = jnp.array([500], dtype=jnp.int32)  # Middle of the diffusion process
    output = model.apply(params, noisy_frames, timestep, encoder_hidden_states)

    # Unwrap the prediction. The forward-pass check earlier in this notebook
    # read `.shape` straight off the result of model.apply, which suggests a
    # bare array is returned; dict / attribute / tuple containers are still
    # accepted so either return convention works.
    if isinstance(output, dict):
        prediction = output['sample']
    elif isinstance(output, (tuple, list)):
        prediction = output[0]
    else:
        prediction = getattr(output, 'sample', output)

    def to_display(frames):
        # First batch item, first 3 channels, mapped from [-1, 1] to [0, 1].
        return jnp.clip((frames[0, :, :, :, :3] + 1.0) / 2.0, 0.0, 1.0)

    return {
        'original': to_display(video_frames),
        'noisy': to_display(noisy_frames),
        'processed': to_display(prediction)
    }

# Create some synthetic video frames for demonstration
def create_synthetic_video(num_frames=8, height=32, width=32):
    """Create a simple synthetic video with moving shapes.

    Every frame contains a red disc (radius < 5 pixels) whose centre moves on
    a circle around the image centre, completing one revolution over
    ``num_frames`` frames, plus a static green square covering rows and
    columns 5..14.

    Args:
        num_frames: Number of frames to generate.
        height: Frame height in pixels.
        width: Frame width in pixels.

    Returns:
        numpy float array of shape (num_frames, height, width, 3) with values
        0.0 or 1.0.
    """
    frames = np.zeros((num_frames, height, width, 3))

    # Row / column index grids that broadcast to (height, width); they replace
    # the per-pixel Python loops with one vectorized distance test per frame.
    rows = np.arange(height)[:, None]
    cols = np.arange(width)[None, :]
    radius = 5

    for i in range(num_frames):
        angle = i / num_frames * 2 * np.pi
        x_center = width // 2 + int(width * 0.3 * np.sin(angle))
        y_center = height // 2 + int(height * 0.3 * np.cos(angle))

        # Integer squared distance < radius**2 is equivalent to distance < radius.
        inside_circle = (cols - x_center) ** 2 + (rows - y_center) ** 2 < radius ** 2

        frame = frames[i]  # view into `frames`, so the writes below land in place
        frame[inside_circle, 0] = 1.0  # Red circle
        frame[5:15, 5:15, 1] = 1.0  # Green square

    return frames

# Generate synthetic video and process it
# Default arguments give 8 RGB frames of 32x32 pixels.
synthetic_video = create_synthetic_video()
# Uncomment to process the video
# processed_results = process_existing_video(synthetic_video, noise_level=0.3)

## 8. Using Frame-Specific Conditioning

The UNet3D model now supports both video-wide conditioning and optional frame-specific conditioning. Let's see how to use this feature.

In [None]:
def generate_with_frame_conditioning(num_frames=8, height=32, width=32, steps=20):
    """Sample two videos: global conditioning only, then global + per-frame.

    Both runs use the module-level ``sampler`` and ``params`` and share the
    same random global embeddings; the second run additionally passes random
    per-frame embeddings.

    Args:
        num_frames: Number of frames per video.
        height: Latent height passed in ``image_shape``.
        width: Latent width passed in ``image_shape``.
        steps: Number of diffusion steps.

    Returns:
        tuple: ``(video_global, video_combined)`` as returned by the sampler.

    NOTE(review): ``generate_images`` is called here with ``num_images`` /
    ``image_shape``, while generate_video in this notebook uses
    ``batch_size`` / ``sequence_length`` — confirm which signature applies.
    """
    n_videos = 1
    # Three-way split of a fixed seed; the first sub-key is not used further.
    _, global_key, frame_key = jax.random.split(jax.random.PRNGKey(123), 3)

    # Random stand-in for video-wide text embeddings.
    global_embeddings = jax.random.normal(
        global_key,
        shape=(n_videos, 77, 64),
        dtype=jnp.float32,
    )

    # Random stand-in for one embedding sequence per frame.
    per_frame_embeddings = jax.random.normal(
        frame_key,
        shape=(n_videos, num_frames, 77, 64),
        dtype=jnp.float32,
    )

    def run_sampler(conditioning):
        # Identical sampling settings for both runs; only the conditioning differs.
        return sampler.generate_images(
            params=params,
            num_images=n_videos,
            diffusion_steps=steps,
            start_step=1000,
            end_step=0,
            priors=None,
            image_shape=(num_frames, height, width, 4),
            model_conditioning_inputs=conditioning,
        )

    print(f"Generating {num_frames} frames with global conditioning only...")
    video_global = run_sampler((global_embeddings,))

    print(f"Generating {num_frames} frames with global + frame-specific conditioning...")
    video_combined = run_sampler((global_embeddings, per_frame_embeddings))

    return video_global, video_combined

# Uncomment to run the experiment
# video_global, video_combined = generate_with_frame_conditioning()

In [None]:
def compare_videos(video1, video2, title1="Global Conditioning", title2="Global + Frame Conditioning"):
    """Plot the frames of two videos in two rows for side-by-side comparison.

    The first batch item of each video is mapped from an assumed [-1, 1]
    range to [0, 1], clipped, and reduced to its first three channels. The
    number of columns is taken from the first video.

    Args:
        video1: Batched video shown in the top row.
        video2: Batched video shown in the bottom row; needs at least as many
            frames as ``video1``.
        title1: Row label for the top row.
        title2: Row label for the bottom row.
    """
    # Normalize both videos
    def normalize_video(video):
        video_clip = np.array(video[0])
        video_clip = (video_clip + 1.0) / 2.0
        video_clip = np.clip(video_clip, 0.0, 1.0)
        video_clip = video_clip[:, :, :, :3]  # RGB only
        return video_clip

    video1_norm = normalize_video(video1)
    video2_norm = normalize_video(video2)

    # Display side by side frames. squeeze=False keeps `axes` two-dimensional
    # even when there is only a single frame (one column).
    num_frames = video1_norm.shape[0]
    fig, axes = plt.subplots(2, num_frames, figsize=(num_frames*2, 4), squeeze=False)

    rows = ((video1_norm, title1), (video2_norm, title2))
    for row, (clip, row_title) in enumerate(rows):
        for i in range(num_frames):
            ax = axes[row, i]
            ax.imshow(clip[i])
            ax.set_title(f"Frame {i}")
            # Hide ticks and spines individually. axis('off') would also
            # suppress the y-label, so the row titles would never be drawn.
            ax.set_xticks([])
            ax.set_yticks([])
            for spine in ax.spines.values():
                spine.set_visible(False)
        axes[row, 0].set_ylabel(row_title)

    plt.tight_layout()
    plt.show()

# Uncomment to compare the videos
# if 'video_global' in locals() and 'video_combined' in locals():
#     compare_videos(video_global, video_combined)

## Conclusion

In this notebook, we've demonstrated:
1. How to initialize and use the FlaxUNet3DConditionModel
2. How to generate new videos from random noise
3. How to modify existing videos using the model
4. How to use frame-specific conditioning for more detailed control

The FlaxUNet3DConditionModel provides a powerful tool for video diffusion tasks, offering the performance benefits of JAX and Flax while maintaining compatibility with diffusers-style APIs.