# Interactive Video Sampling with Echo-Dream

This notebook allows you to generate videos using the Echo-Dream model with:
1. User-defined conditions (class ID, LVEF, view)
2. Custom conditioning frames
3. Different sampling modes (diffusion, flow matching)
4. Various output formats

In [1]:
# Import required libraries
import os
import sys
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
from pathlib import Path
from IPython.display import display, HTML, clear_output
import ipywidgets as widgets
from omegaconf import OmegaConf

# Add Echo-Dream to path if needed
sys.path.append(os.path.abspath(".."))

# Import Echo-Dream modules
from echo.common import load_model, save_as_mp4, save_as_gif, save_as_avi, save_as_img
from echo.lvdm.sample import get_conditioning_vector

## Setup Model Paths and Configuration

First, we need to set up paths to our models and configuration files.

In [None]:
# Generic, repository-relative locations. Edit these to match your checkout.
DEFAULT_OUTPUT_PATH = "../samples/output/interactive"
DEFAULT_CONDITIONING_PATH = "../samples/data/reference_frames"
DEFAULT_VAE_PATH = "../models/vae"
DEFAULT_UNET_PATH = "../models/unet"
DEFAULT_CONFIG_PATH = "../echo/lvdm/configs/default.yaml"

# Generated videos are written below this directory, so make sure it exists.
Path(DEFAULT_OUTPUT_PATH).mkdir(parents=True, exist_ok=True)

In [4]:
# Site-specific locations used for the run recorded in this notebook.
# Adjust these to your own setup before executing the cell.
DEFAULT_CONFIG_PATH = "../echo/lvdm/configs/default.yaml"
DEFAULT_UNET_PATH = "../experiments/lvdm_cardiacnet_df/checkpoint-100000/unet_ema"
DEFAULT_VAE_PATH = "/nfs/usrhome/khmuhammad/Echonet/models/vae"
DEFAULT_CONDITIONING_PATH = "/nfs/usrhome/khmuhammad/Echonet/data/latents/cardiacnet/Latents"
DEFAULT_OUTPUT_PATH = "../samples/output/interactive"

# Generated videos are written below this directory, so make sure it exists.
Path(DEFAULT_OUTPUT_PATH).mkdir(parents=True, exist_ok=True)

## Load Models and Configuration

Now let's load the models and configuration.

In [5]:
def load_models(
    config_path=DEFAULT_CONFIG_PATH,
    unet_path=DEFAULT_UNET_PATH,
    vae_path=DEFAULT_VAE_PATH,
):
    """Load the config, UNet and VAE and prepare both models for inference.

    Args:
        config_path: File passed to ``OmegaConf.load``.
        unet_path: Location handed to ``load_model`` for the UNet.
        vae_path: Location handed to ``load_model`` for the VAE.

    Returns:
        tuple: ``(config, unet, vae, device)``. Both models have been moved
        to ``device`` (CUDA when available, otherwise CPU) and switched to
        evaluation mode.
    """
    config = OmegaConf.load(config_path)
    print(f"✅ Loaded config from {config_path}")

    # Load both networks first, then place them on the chosen device.
    loaded = (load_model(unet_path), load_model(vae_path))
    backend = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(backend)
    unet, vae = (network.to(device) for network in loaded)

    # Inference only: switch both networks to evaluation mode.
    for network in (unet, vae):
        network.eval()

    for message in (
        f"✅ Loaded UNet from {unet_path}",
        f"✅ Loaded VAE from {vae_path}",
        f"🖥️  Using device: {device}",
    ):
        print(message)

    return config, unet, vae, device


# Load everything once with the default paths. The resulting module-level
# names (config, unet, vae, device) are read by the generation cell below.
config, unet, vae, device = load_models()

The config attributes {'decay': 0.9999, 'inv_gamma': 1.0, 'min_decay': 0.0, 'optimization_step': 100000, 'power': 0.6666666666666666, 'update_after_step': 0, 'use_ema_warmup': False} were passed to UNetSpatioTemporalConditionModel, but are not expected and will be ignored. Please verify your config.json configuration file.


✅ Loaded config from ../echo/lvdm/configs/default.yaml
✅ Loaded UNet from ../experiments/lvdm_cardiacnet_df/checkpoint-100000/unet_ema
✅ Loaded VAE from /nfs/usrhome/khmuhammad/Echonet/models/vae
🖥️  Using device: cuda
✅ Loaded UNet from ../experiments/lvdm_cardiacnet_df/checkpoint-100000/unet_ema
✅ Loaded VAE from /nfs/usrhome/khmuhammad/Echonet/models/vae
🖥️  Using device: cuda


## Load Conditioning Images/Latents

We need to load the conditioning frames that will be used to guide the video generation.

In [6]:
from echo.common.datasets import TensorSet, ImageSet, TensorSetv2
from torch.utils.data import DataLoader


def load_conditioning_data(conditioning_path=DEFAULT_CONDITIONING_PATH):
    """Open the conditioning directory as a dataset, based on its file type.

    The directory is scanned in sorted order and the first entry with a
    supported extension (``pt``, ``jpg`` or ``png``) decides how it is
    loaded. Scanning instead of looking only at the first ``os.listdir``
    entry keeps the result deterministic and tolerates stray entries such
    as dotfiles or extensionless files.

    Args:
        conditioning_path: Directory containing the conditioning files.

    Returns:
        tuple: ``(dataset, file_ext)`` where ``dataset`` is a ``TensorSetv2``
        for ``pt`` files or an ``ImageSet`` for images, and ``file_ext`` is
        the detected lower-case extension without the leading dot.

    Raises:
        ValueError: If the directory is empty or holds no supported file.
    """
    supported = ("pt", "jpg", "png")

    # os.listdir order is arbitrary, so sort before picking an extension.
    entries = sorted(os.listdir(conditioning_path))
    if not entries:
        raise ValueError(f"No files found in {conditioning_path}")

    extensions = [
        os.path.splitext(entry)[1].lstrip(".").lower() for entry in entries
    ]
    file_ext = next((ext for ext in extensions if ext in supported), None)
    if file_ext is None:
        found = ", ".join(sorted({ext for ext in extensions if ext})) or "<none>"
        raise ValueError(
            f"Unsupported file extension: {found}. Must be pt, jpg, or png."
        )

    # Load appropriate dataset
    if file_ext == "pt":
        dataset = TensorSetv2(conditioning_path)
        print(f"📁 Loaded {len(dataset)} tensor files from {conditioning_path}")
    else:
        dataset = ImageSet(conditioning_path, ext=file_ext)
        print(f"📁 Loaded {len(dataset)} {file_ext} images from {conditioning_path}")

    return dataset, file_ext


# Load the conditioning data from the default directory. The detected
# `file_ext` is read later by the UI and generation cells.
conditioning_dataset, file_ext = load_conditioning_data()

📁 Loaded 655 tensor files from /nfs/usrhome/khmuhammad/Echonet/data/latents/cardiacnet/Latents


## Create User Interface

Let's create a user interface for controlling the video generation process.

In [7]:
def create_ui():
    """Build the ipywidgets controls for interactive video generation.

    Creates the sampling, conditioning, output and frame-selection widgets,
    wires the observers that enable/disable the conditioning sliders and
    refresh the frame preview, and groups everything into a tabbed layout.
    Reads the module-level ``DEFAULT_CONDITIONING_PATH`` and ``file_ext``.

    Returns:
        tuple: 14 items, in this order: ``ui`` (the ``Tab``),
        ``generate_button``, ``output``, ``sampling_mode``, ``num_steps``,
        ``guidance_scale``, ``conditioning_type``, ``class_id_value``,
        ``lvef_value``, ``view_value``, ``num_frames``, ``output_format``,
        ``seed``, ``conditioning_file``.
    """
    wide_label = {"description_width": "initial"}

    # Sampling parameters
    sampling_mode = widgets.Dropdown(
        options=[("Diffusion", "diffusion"), ("Flow Matching", "flow_matching")],
        value="diffusion",
        description="Sampling:",
        style=wide_label,
    )

    num_steps = widgets.IntSlider(
        min=10,
        max=500,
        step=10,
        value=64,
        description="Steps:",
        style=wide_label,
    )

    guidance_scale = widgets.FloatSlider(
        min=1.0,
        max=10.0,
        step=0.5,
        value=1.0,
        description="Guidance:",
        style=wide_label,
        disabled=False,
    )

    # Conditioning parameters
    conditioning_type = widgets.Dropdown(
        options=[
            ("Class ID", "class_id"),
            ("LVEF Value", "lvef"),
            ("View Type", "view"),
        ],
        value="class_id",
        description="Condition:",
        style=wide_label,
    )

    # Only the slider matching the selected conditioning type is enabled.
    class_id_value = widgets.IntSlider(
        min=0,
        max=10,
        step=1,
        value=3,
        description="Class ID:",
        style=wide_label,
    )

    lvef_value = widgets.IntSlider(
        min=10,
        max=90,
        step=5,
        value=50,
        description="LVEF:",
        style=wide_label,
        disabled=True,
    )

    view_value = widgets.IntSlider(
        min=0,
        max=5,
        step=1,
        value=0,
        description="View ID:",
        style=wide_label,
        disabled=True,
    )

    # Output parameters
    num_frames = widgets.IntSlider(
        min=32,
        max=320,
        step=32,
        value=192,
        description="Frames:",
        style=wide_label,
    )

    output_format = widgets.SelectMultiple(
        options=["mp4", "gif", "jpg", "avi"],
        value=["mp4", "jpg"],
        description="Format:",
        style=wide_label,
    )

    # NOTE(review): IntText cannot hold an empty value; passing value=None
    # falls back to 0 and IntText has no `placeholder` trait. The seed
    # therefore defaults to 0, so this is written out explicitly.
    seed = widgets.IntText(
        value=0,
        description="Seed:",
        style=wide_label,
    )

    # File selection for conditioning frame (sorted for a stable dropdown order)
    file_options = sorted(
        f for f in os.listdir(DEFAULT_CONDITIONING_PATH) if f.endswith(f".{file_ext}")
    )

    conditioning_file = widgets.Dropdown(
        options=file_options,
        value=file_options[0] if file_options else None,
        description="Frame:",
        style=wide_label,
    )

    # Generate button
    generate_button = widgets.Button(
        description="Generate Video",
        button_style="success",
        tooltip="Click to generate video with the selected parameters",
    )

    # Output area for preview
    output = widgets.Output()

    # Enable the slider that matches the conditioning type, disable the rest.
    value_sliders = {
        "class_id": class_id_value,
        "lvef": lvef_value,
        "view": view_value,
    }

    def update_conditioning_ui(change):
        selected = change["new"]
        if selected not in value_sliders:
            return
        for key, slider in value_sliders.items():
            slider.disabled = key != selected

    conditioning_type.observe(update_conditioning_ui, names="value")

    def to_preview_array(frame):
        """Turn a loaded tensor into a uint8 array that imshow can display."""
        frame = frame.detach().float().cpu()
        # Drop leading (batch/time) dimensions by keeping the first entry.
        while frame.dim() > 3:
            frame = frame[0]
        if frame.dim() == 3:
            if frame.shape[0] in (1, 3):
                # Format: C x H x W
                frame = frame.permute(1, 2, 0)
            elif frame.shape[-1] not in (1, 3):
                # Any other channel count (e.g. latents): show channel 0 only.
                frame = frame[0]
            if frame.dim() == 3 and frame.shape[2] == 1:  # Grayscale
                frame = frame[:, :, 0]
        array = frame.numpy()
        low, high = float(array.min()), float(array.max())
        if low < 0.0 or high > 1.0:
            # Values outside [0, 1] are stretched to the displayable range.
            span = high - low
            array = (array - low) / span if span > 0 else np.zeros_like(array)
        return (array * 255).astype(np.uint8)

    # Redraw the preview of the selected conditioning frame
    def show_frame_preview(change):
        with output:
            clear_output()
            if not change["new"]:
                return

            frame_path = os.path.join(DEFAULT_CONDITIONING_PATH, change["new"])

            try:
                if file_ext == "pt":
                    # Load on CPU so tensors saved from a GPU can be displayed.
                    frame = to_preview_array(
                        torch.load(frame_path, map_location="cpu")
                    )
                else:
                    # Load image
                    frame = np.array(Image.open(frame_path))

                plt.figure(figsize=(5, 5))
                plt.imshow(frame, cmap="gray" if frame.ndim == 2 else None)
                plt.title("Selected conditioning frame")
                plt.axis("off")
                plt.show()
            except Exception as e:
                # The preview is best-effort: report the problem, keep the UI alive.
                print(f"Error loading frame: {e}")

    conditioning_file.observe(show_frame_preview, names="value")

    # Layout
    sampling_box = widgets.VBox([sampling_mode, num_steps, guidance_scale])
    conditioning_box = widgets.VBox(
        [conditioning_type, class_id_value, lvef_value, view_value]
    )
    output_box = widgets.VBox([num_frames, output_format, seed])
    frame_box = widgets.VBox([conditioning_file])

    # Main layout
    ui = widgets.Tab()
    ui.children = [sampling_box, conditioning_box, output_box, frame_box]
    for index, title in enumerate(("Sampling", "Conditioning", "Output", "Frame")):
        ui.set_title(index, title)

    # Show initial frame preview
    show_frame_preview({"new": conditioning_file.value})

    return (
        ui,
        generate_button,
        output,
        sampling_mode,
        num_steps,
        guidance_scale,
        conditioning_type,
        class_id_value,
        lvef_value,
        view_value,
        num_frames,
        output_format,
        seed,
        conditioning_file,
    )


# Build the widgets and give every returned component a module-level name
# so that the generation callback can read the current UI state.
ui_components = create_ui()
(
    ui,
    generate_button,
    output,
    sampling_mode,
    num_steps,
    guidance_scale,
    conditioning_type,
    class_id_value,
    lvef_value,
    view_value,
    num_frames,
    output_format,
    seed,
    conditioning_file,
) = ui_components

# Render the tab panel, the trigger button and the output area, in that order.
display(ui, generate_button, output)

Tab(children=(VBox(children=(Dropdown(description='Sampling:', options=(('Diffusion', 'diffusion'), ('Flow Mat…

Button(button_style='success', description='Generate Video', style=ButtonStyle(), tooltip='Click to generate v…

Output()

## Generate Video Function

Now let's implement the video generation function that will be triggered by the UI.

In [None]:
from echo.common import pad_reshape, unpad_reshape, padf, unpadf
from torch.utils.data import Subset
from einops import rearrange
import diffusers
from echo.common import FlowMatchingScheduler
import time
import uuid


def generate_video_from_ui(b):
    """Button callback: sample a video with the current UI settings and save it.

    Reads the widget values, builds the scheduler, loads the selected
    conditioning frame, runs the denoising loop (using overlapping temporal
    windows when more frames are requested than the UNet handles at once),
    decodes the latents with the VAE and writes the result in every selected
    output format below ``DEFAULT_OUTPUT_PATH``. All progress and errors are
    printed into the ``output`` widget; exceptions are reported there and not
    re-raised.

    Args:
        b: The button instance passed by ``on_click`` (unused).
    """
    # Clear previous output
    with output:
        clear_output()
        print("🚀 Starting video generation...")

        # Get parameters from UI
        curr_sampling_mode = sampling_mode.value
        curr_num_steps = num_steps.value
        curr_guidance_scale = guidance_scale.value
        curr_conditioning_type = conditioning_type.value

        # Get the appropriate conditioning value
        if curr_conditioning_type == "class_id":
            conditioning_value = class_id_value.value
        elif curr_conditioning_type == "lvef":
            conditioning_value = lvef_value.value
        else:  # view
            conditioning_value = view_value.value

        curr_num_frames = num_frames.value
        curr_output_formats = output_format.value
        # NOTE(review): an IntText widget always holds an int, so the
        # time-based fallback only triggers if `seed` is another widget type.
        curr_seed = (
            seed.value if seed.value not in [None, ""] else int(time.time()) % 1000000
        )
        curr_conditioning_file = conditioning_file.value

        # Print generation parameters
        print("📋 Parameters:")
        print(f"   - Sampling mode: {curr_sampling_mode}")
        print(f"   - Steps: {curr_num_steps}")
        print(f"   - Guidance scale: {curr_guidance_scale}")
        print(f"   - Conditioning: {curr_conditioning_type}={conditioning_value}")
        print(f"   - Frames: {curr_num_frames}")
        print(f"   - Output formats: {', '.join(curr_output_formats)}")
        print(f"   - Seed: {curr_seed}")
        print(f"   - Conditioning file: {curr_conditioning_file}")

        try:
            # Set up the scheduler based on sampling mode
            if curr_sampling_mode == "diffusion":
                scheduler_kwargs = OmegaConf.to_container(config.noise_scheduler)
                scheduler_klass_name = scheduler_kwargs.pop("_class_name")
                scheduler_klass = getattr(diffusers, scheduler_klass_name, None)
                # Raise explicitly: an assert would be stripped under `python -O`.
                if scheduler_klass is None:
                    raise ValueError(
                        f"Could not find scheduler class {scheduler_klass_name}"
                    )
                scheduler = scheduler_klass(**scheduler_kwargs)
            else:  # flow_matching
                scheduler = FlowMatchingScheduler(
                    num_train_timesteps=config.get("num_train_timesteps", 1000)
                )

            scheduler.set_timesteps(curr_num_steps)
            timesteps = scheduler.timesteps

            # Set up generator with seed
            generator = torch.Generator(device=device).manual_seed(curr_seed)

            # Set up dimensions
            B = 1  # Just generate one video
            C = config.unet.out_channels
            H, W = config.unet.sample_size, config.unet.sample_size
            T = config.unet.num_frames

            # Load the conditioning frame
            conditioning_path = os.path.join(
                DEFAULT_CONDITIONING_PATH, curr_conditioning_file
            )

            if file_ext == "pt":
                latent_cond_images = torch.load(conditioning_path).to(device)
                if latent_cond_images.dim() == 5:  # extra singleton axis at dim 1
                    latent_cond_images = latent_cond_images.squeeze(1)
            else:  # jpg, png
                # The conditioning latent must match the UNet's latent size, so
                # the image is resized to sample_size times the VAE's spatial
                # reduction. NOTE(review): the factor assumes a diffusers-style
                # VAE config exposing `block_out_channels`; falls back to 1.
                block_out = getattr(vae.config, "block_out_channels", None)
                vae_factor = 2 ** (len(block_out) - 1) if block_out else 1
                side = config.unet.sample_size * vae_factor

                # Resize, scale to [0, 1], then normalize to [-1, 1] using the
                # libraries this notebook already imports.
                image = Image.open(conditioning_path).convert("RGB")
                image = image.resize((side, side), Image.BILINEAR)
                pixels = torch.from_numpy(
                    np.asarray(image, dtype=np.float32) / 255.0
                )
                image_tensor = (pixels.permute(2, 0, 1) - 0.5) / 0.5
                image_tensor = image_tensor.unsqueeze(0).to(device)

                # Project to latent space
                with torch.no_grad():
                    latent_cond_images = (
                        vae.encode(image_tensor).latent_dist.sample()
                        * vae.config.scaling_factor
                    )

            # Ensure batch dimension
            if latent_cond_images.dim() == 3:  # C, H, W
                latent_cond_images = latent_cond_images.unsqueeze(0)  # 1, C, H, W

            # A file holding several frames cannot be concatenated with a
            # single-video batch; condition on its first frame instead.
            if latent_cond_images.shape[0] > B:
                print(
                    f"ℹ️  Conditioning file holds {latent_cond_images.shape[0]} "
                    "frames; using the first one."
                )
                latent_cond_images = latent_cond_images[:B]

            # Stitching parameters: windows of T frames overlapping by OT.
            NT = curr_num_frames
            if NT > T:
                OT = T // 2  # overlap
                stride = T - OT
                TR = -(-(NT - T) // stride) + 1  # ceiling division
                # Round up so that the last window is a full T frames long.
                NT = T + (TR - 1) * stride
            else:
                OT = 0
                stride = T
                TR = 1
                NT = T

            print(
                f"🎬 Generating video with dimensions: B={B}, C={C}, T={NT}, H={H}, W={W}"
            )

            # Prepare latent noise
            latents = torch.randn((B, C, NT, H, W), device=device, generator=generator)

            # Get conditioning vector
            dtype = torch.float
            conditioning = get_conditioning_vector(
                curr_conditioning_type, conditioning_value, B, device, dtype
            )

            # Repeat conditioning for temporal stitching if needed
            if TR > 1:
                conditioning = conditioning.repeat_interleave(TR, dim=0)

            # Format input/output functions
            is_spatio_temporal = (
                config.unet._class_name == "UNetSpatioTemporalConditionModel"
            )
            format_input = pad_reshape if is_spatio_temporal else padf
            format_output = unpad_reshape if is_spatio_temporal else unpadf

            # Expand conditioning frame to video latents
            latent_cond_images = latent_cond_images[:, :, None, :, :].repeat(
                1, 1, NT, 1, 1
            )

            # Forward kwargs setup; "timestep" is overwritten on every step.
            forward_kwargs = {
                "timestep": -1,
                "encoder_hidden_states": conditioning,
            }
            if is_spatio_temporal:
                forward_kwargs["added_time_ids"] = torch.zeros(
                    (B * TR, config.unet.addition_time_embed_dim), device=device
                )

            # Classifier-free guidance is only applied in diffusion mode.
            use_guidance = (
                curr_guidance_scale > 1.0 and curr_sampling_mode == "diffusion"
            )
            uncond_states = torch.zeros_like(conditioning) if use_guidance else None

            # Fixed step size for the flow-matching update (guarded against a
            # single-step schedule).
            dt = 1.0 / max(len(timesteps) - 1, 1)

            print("⏳ Starting generation loop...")

            # Denoising loop
            with torch.no_grad():
                for i, t in enumerate(timesteps):
                    if i % 10 == 0 or i == len(timesteps) - 1:
                        print(f"   Step {i + 1}/{len(timesteps)}")

                    forward_kwargs["timestep"] = t

                    # Prepare model input
                    latent_model_input = scheduler.scale_model_input(
                        latents, timestep=t
                    )
                    latent_model_input = torch.cat(
                        (latent_model_input, latent_cond_images), dim=1
                    )  # B x 2C x T x H x W
                    latent_model_input, padding = format_input(
                        latent_model_input, mult=3
                    )

                    # Cut overlapping windows and stack them along the batch axis.
                    inputs = torch.cat(
                        [
                            latent_model_input[:, r * stride : r * stride + T]
                            for r in range(TR)
                        ],
                        dim=0,
                    )

                    noise_pred = unet(inputs, **forward_kwargs).sample
                    if use_guidance:
                        uncond_kwargs = {
                            **forward_kwargs,
                            "encoder_hidden_states": uncond_states,
                        }
                        noise_pred_uncond = unet(inputs, **uncond_kwargs).sample
                        # Guidance is element-wise, so it can be applied before
                        # the windows are stitched back together.
                        noise_pred = noise_pred_uncond + curr_guidance_scale * (
                            noise_pred - noise_pred_uncond
                        )

                    # Stitch the windows: keep the first one whole and drop the
                    # overlapping leading frames of every later window.
                    windows = torch.chunk(noise_pred, TR, dim=0)
                    noise_pred = torch.cat(
                        [
                            window if r == 0 else window[:, OT:]
                            for r, window in enumerate(windows)
                        ],
                        dim=1,
                    )
                    noise_pred = format_output(noise_pred, pad=padding)

                    # Update latents
                    if curr_sampling_mode == "diffusion":
                        latents = scheduler.step(noise_pred, t, latents).prev_sample
                    else:  # flow_matching
                        latents = latents - noise_pred * dt

            print("🎉 Finished generation! Decoding with VAE...")

            # Decode with VAE
            latents = latents / vae.config.scaling_factor

            # Process in chunks to save memory
            latents = rearrange(latents, "b c t h w -> (b t) c h w")
            chunk_size = 16  # Process 16 frames at a time
            video_chunks = []

            for start in range(0, latents.shape[0], chunk_size):
                chunk = latents[start : start + chunk_size].to(device)
                with torch.no_grad():
                    video_chunk = vae.decode(chunk).sample
                video_chunks.append(video_chunk.cpu())

            video = torch.cat(video_chunks, dim=0)  # (B*T) x C x H x W

            # Format output
            video = rearrange(video, "(b t) c h w -> b t h w c", b=B)
            video = (video + 1) * 128
            video = video.clamp(0, 255).to(torch.uint8)[0]  # Remove batch dimension

            print(
                f"📹 Video generated! Shape: {video.shape}, range: [{video.min()}, {video.max()}]"
            )

            # Create unique ID for this video
            video_id = f"video_{uuid.uuid4().hex[:8]}"

            # Create subdirectories for formats if they don't exist
            for fmt in curr_output_formats:
                os.makedirs(os.path.join(DEFAULT_OUTPUT_PATH, fmt), exist_ok=True)

            # Save video in requested formats
            saved_paths = []
            mp4_path = None
            gif_path = None

            if "mp4" in curr_output_formats:
                mp4_path = os.path.join(DEFAULT_OUTPUT_PATH, "mp4", f"{video_id}.mp4")
                save_as_mp4(video, mp4_path)
                saved_paths.append(mp4_path)

            if "avi" in curr_output_formats:
                avi_path = os.path.join(DEFAULT_OUTPUT_PATH, "avi", f"{video_id}.avi")
                save_as_avi(video, avi_path)
                saved_paths.append(avi_path)

            if "gif" in curr_output_formats:
                gif_path = os.path.join(DEFAULT_OUTPUT_PATH, "gif", f"{video_id}.gif")
                save_as_gif(video, gif_path)
                saved_paths.append(gif_path)

            if "jpg" in curr_output_formats:
                jpg_dir = os.path.join(DEFAULT_OUTPUT_PATH, "jpg", video_id)
                save_as_img(video, jpg_dir, ext="jpg")
                saved_paths.append(jpg_dir)

            # Display the video (mp4 preferred, gif as fallback)
            if mp4_path is not None:
                display(
                    HTML(f"""
                <div style="display: flex; justify-content: center;">
                    <video width="320" height="320" controls autoplay loop>
                        <source src="{mp4_path}" type="video/mp4">
                        Your browser does not support the video tag.
                    </video>
                </div>
                """)
                )
            elif gif_path is not None:
                display(Image.open(gif_path))

            print("\n✅ Generated video saved to:")
            for path in saved_paths:
                print(f"   - {path}")

        except Exception as e:
            # UI boundary: report the failure in the output widget instead of
            # letting the exception disappear inside the widget callback.
            print(f"❌ Error generating video: {e}")
            import traceback

            traceback.print_exc()


# Register the callback: clicking the button calls generate_video_from_ui
# with the button instance as its only argument.
generate_button.on_click(generate_video_from_ui)

## Advanced Options

You can customize the generation process further by reloading the models or the conditioning data from different locations with the helper functions below:

In [None]:
# Change model paths if needed
def reload_models(config_path=None, unet_path=None, vae_path=None):
    """Reload the config and models, replacing the module-level globals.

    Any argument that is left out (or is otherwise falsy) falls back to the
    corresponding ``DEFAULT_*`` path.

    Args:
        config_path: Config file to load.
        unet_path: UNet location to load.
        vae_path: VAE location to load.
    """
    global config, unet, vae, device

    config, unet, vae, device = load_models(
        config_path if config_path else DEFAULT_CONFIG_PATH,
        unet_path if unet_path else DEFAULT_UNET_PATH,
        vae_path if vae_path else DEFAULT_VAE_PATH,
    )


# Reload conditioning data from a different location
def reload_conditioning_data(conditioning_path=None):
    """Reload the conditioning dataset, replacing the module-level globals.

    Args:
        conditioning_path: Directory to load from. When omitted (or otherwise
            falsy), ``DEFAULT_CONDITIONING_PATH`` is used.
    """
    global conditioning_dataset, file_ext

    source_dir = conditioning_path if conditioning_path else DEFAULT_CONDITIONING_PATH
    conditioning_dataset, file_ext = load_conditioning_data(source_dir)


# Example usage:
# reload_models(unet_path="/path/to/your/custom/unet")
# reload_conditioning_data("/path/to/your/custom/conditioning/frames")