***

# SDXL Turbo Storyboard to Video
## Version 1.0

***

#### https://github.com/asigalov61/tegridy-vibe-code

***

### Project Los Angeles
### Tegridy Code 2026

***

# Setup Environment

!pip install -U diffusers

In [None]:
!pip install -U accelerate

In [None]:
!pip install -U transformers

In [None]:
!pip install -U moviepy

In [None]:
!pip install numpy==1.26.4

# Run the script

In [None]:
#!/usr/bin/env python3
"""
Complete patched SDXL image2image morphing script
- Style-preserving pixel-space interpolation of decoded VAE images
- Per-frame strength interpolation to retain source style near endpoints
- Temporal smoothing in latent space
- Robust VAE encode/decode in float32 with NaN guards
- Upcast VAE to float32 and disable AMP autocast during pipe(...) calls to avoid fp16 instability
- Retry wrapper for pipeline calls with aggressive normalization and prompt-text fallback
- Deterministic seeding per-segment and per-frame
"""

import os
import math
import time
import csv
from typing import List, Tuple, Optional

import numpy as np
from PIL import Image
import torch
import torchvision.transforms as T
from diffusers import AutoPipelineForImage2Image
from torch.amp import autocast

# -------------------------
# User configuration
# -------------------------
# Compute device used for the pipeline, generators and all tensors below.
device = "cuda" if torch.cuda.is_available() else "cpu"
# dtype passed to from_pretrained() and used for prompt embeddings.
pipeline_torch_dtype = torch.float32
# dtype of the preprocessed image tensor handed to the VAE encoder.
vae_encode_dtype = torch.float32
# Arguments forwarded to torch.amp.autocast around VAE/pipeline calls;
# with autocast_enabled = False the autocast context is a no-op.
autocast_dtype = torch.float32
autocast_enabled = False

# Model repository and weight variant passed to from_pretrained().
# `variant` is only forwarded when it is truthy.
model_id = "stabilityai/sdxl-turbo"
variant = "fp16"

# Storyboard prompts; at least two are required (validated further below).
# Each consecutive pair of prompts forms one morphing segment.
prompts: List[str] = [
    "On a dark desert highway,",
    "cool wind in my hair",
    "Warm smell of colitas rising up through the air",
    "Up ahead in the distance, I saw a shimmering light",
]

# One image per prompt or None to use noise for that endpoint
prompt_images: Optional[List[Optional[str]]] = ['image_0.jpg', 'image_1.jpg', 'image_2.jpg', 'image_3.jpg']
# Either a single int (reused for every prompt) or a list of length 1 or
# len(prompts); an example list is [42, 1337, 7, 14].
seeds = -1 # [42, 1337, 7, 14]

# Number of interpolation steps per segment.
num_steps_between = 120
# Directory receiving the PNG frames and the debug CSV files (created here).
output_dir = "frames_sdxl_img2img"
os.makedirs(output_dir, exist_ok=True)

# Strength controls: lower values preserve source style more
# base_strength is the starting img2img strength; min_strength is the floor.
base_strength = 0.45
min_strength = 0.12

guidance_scale = 0.0  # SDXL Turbo requirement
num_inference_steps = 12

# Video output settings.
framerate = 16
output_video = "sdxl_img2img_morph.mp4"
# When True, the PNG frames are deleted after the video has been written.
cleanup_frames_after_video = False

# Extra debug output is printed for frames whose index is below DEBUG_FRAMES.
DEBUG_FRAMES = 3
# Per-frame latent statistics and per-image preprocessing statistics.
DEBUG_CSV = os.path.join(output_dir, "latent_stats.csv")
IMG_DEBUG_CSV = os.path.join(output_dir, "image_encode_debug.csv")

# Scale applied to a synthetic noise endpoint when a segment mixes an image
# endpoint with a noise endpoint.
noise_mix_scale = 0.5
# Weight of the previous frame's latent when smoothing in the noise-only branch.
temporal_smooth_alpha = 0.24

# -------------------------
# Utility functions
# -------------------------
def lerp(a: torch.Tensor, b: torch.Tensor, t: float) -> torch.Tensor:
    """Linearly blend ``a`` toward ``b``; ``t=0`` gives ``a`` and ``t=1`` gives ``b``."""
    weight_a = 1.0 - t
    return a * weight_a + b * t

def cosine_ease(t: float) -> float:
    """Map ``t`` through a half-cosine ease-in/ease-out curve (0 -> 0, 1 -> 1)."""
    half_cosine = 0.5 * math.cos(math.pi * t)
    return 0.5 - half_cosine

def slerp(a: torch.Tensor, b: torch.Tensor, t: float, eps: float = 1e-6) -> torch.Tensor:
    """Spherically interpolate between ``a`` and ``b`` per leading-dimension row.

    Each row (everything after the first dimension is flattened) is
    interpolated along the great arc between the two directions. Rows whose
    directions are nearly (anti-)parallel fall back to linear interpolation.
    The result is rescaled so its norm is the linear blend of both input norms.

    Args:
        a: Start tensor; its first dimension indexes the rows.
        b: End tensor with the same shape as ``a``.
        t: Interpolation position, 0 for ``a`` and 1 for ``b``.
        eps: Lower bound applied to norms before dividing by them.

    Returns:
        A tensor with the shape of ``a``.
    """
    flat_a = a.reshape(a.shape[0], -1)
    flat_b = b.reshape(b.shape[0], -1)

    len_a = torch.linalg.norm(flat_a, dim=1, keepdim=True).clamp(min=eps)
    len_b = torch.linalg.norm(flat_b, dim=1, keepdim=True).clamp(min=eps)

    dir_a = flat_a / len_a
    dir_b = flat_b / len_b

    # Clamp slightly inside [-1, 1] so acos stays finite.
    cos_angle = (dir_a * dir_b).sum(dim=1, keepdim=True).clamp(-1.0 + 1e-7, 1.0 - 1e-7)
    angle = torch.acos(cos_angle)
    sin_angle = torch.sin(angle)

    # A tiny sine makes the slerp weights ill-conditioned; use lerp for those rows.
    degenerate = (sin_angle.abs() < 1e-3).squeeze(1)

    blended = torch.empty_like(flat_a)
    for row, use_lerp in enumerate(degenerate):
        if use_lerp:
            blended[row] = lerp(flat_a[row], flat_b[row], t)
            continue
        theta = angle[row, 0]
        denom = sin_angle[row, 0]
        weight_a = torch.sin((1.0 - t) * theta) / denom
        weight_b = torch.sin(t * theta) / denom
        blended[row] = weight_a * flat_a[row] + weight_b * flat_b[row]

    # Re-impose a magnitude that moves linearly from |a| to |b|.
    target_mag = lerp(len_a, len_b, t)
    current_mag = torch.linalg.norm(blended, dim=1, keepdim=True).clamp(min=eps)
    blended = blended / current_mag * target_mag

    return blended.reshape(a.shape)

def ensure_list_of_seeds(seeds_input, n: int) -> List[int]:
    """Expand ``seeds_input`` into a list of ``n`` seeds.

    Args:
        seeds_input: A single int (repeated ``n`` times) or a list/tuple of
            length ``n`` (copied) or length 1 (its element repeated).
        n: Number of seeds required.

    Returns:
        A new list with ``n`` entries.

    Raises:
        ValueError: If the input has another type or an unsupported length.
    """
    if isinstance(seeds_input, int):
        return [seeds_input for _ in range(n)]
    if not isinstance(seeds_input, (list, tuple)):
        raise ValueError("seeds must be an int or a list/tuple of ints")

    count = len(seeds_input)
    if count == n:
        return list(seeds_input)
    if count == 1:
        return [seeds_input[0] for _ in range(n)]
    raise ValueError(f"seeds must be an int or a list of length 1 or {n}")

def ensure_list_of_images(images_input, n: int) -> List[Optional[str]]:
    """Expand ``images_input`` into a list of ``n`` image paths (or ``None``).

    Args:
        images_input: ``None`` or a single path string (either is repeated
            ``n`` times), or a list/tuple of length ``n`` (copied) or length 1
            (its element repeated).
        n: Number of entries required.

    Returns:
        A new list with ``n`` entries.

    Raises:
        ValueError: If the input has another type or an unsupported length.
    """
    if images_input is None or isinstance(images_input, str):
        return [images_input for _ in range(n)]
    if not isinstance(images_input, (list, tuple)):
        raise ValueError("prompt_images must be None, a str, or a list/tuple of paths or None")

    count = len(images_input)
    if count == n:
        return list(images_input)
    if count == 1:
        return [images_input[0] for _ in range(n)]
    raise ValueError(f"prompt_images must be None, a str, or a list of length 1 or {n}")

# -------------------------
# Load pipeline
# -------------------------
# Build the image2image pipeline once at module level; everything below
# refers to the resulting global `pipe`.
print("Loading image2image pipeline...")
load_kwargs = {"torch_dtype": pipeline_torch_dtype}
# Only request a specific weight variant when one is configured.
if variant:
    load_kwargs["variant"] = variant

pipe = AutoPipelineForImage2Image.from_pretrained(model_id, **load_kwargs).to(device)
# Silence the per-call progress bar; the generation loop prints its own status.
pipe.set_progress_bar_config(disable=True)
torch.backends.cudnn.benchmark = True

# The script interpolates prompt embeddings itself, so it needs encode_prompt.
if not hasattr(pipe, "encode_prompt"):
    raise RuntimeError("encode_prompt not available. Upgrade diffusers.")

# Noise scale reported by the scheduler; falls back to 1.0 when the
# scheduler does not expose init_noise_sigma.
init_sigma_raw = getattr(pipe.scheduler, "init_noise_sigma", None)
scheduler_init_sigma = float(init_sigma_raw) if init_sigma_raw is not None else 1.0
print("Scheduler init sigma:", scheduler_init_sigma)

# -------------------------
# Prompt encoding
# -------------------------
def encode_single(prompt: str) -> Tuple[torch.Tensor, torch.Tensor]:
    """Encode one prompt into (sequence embeddings, pooled embeddings).

    Calls ``pipe.encode_prompt`` without classifier-free guidance and moves
    both results to the configured device and pipeline dtype.

    Args:
        prompt: Prompt text to encode.

    Returns:
        ``(prompt_embeds, pooled_prompt_embeds)``. When the pipeline provides
        no pooled embedding, a zero tensor of shape ``(1, projection_dim)`` is
        substituted.
    """
    out = pipe.encode_prompt(
        prompt=prompt,
        device=device,
        num_images_per_prompt=1,
        do_classifier_free_guidance=False,
        negative_prompt=None,
    )
    if isinstance(out, tuple):
        prompt_embeds = out[0]
        if len(out) >= 4:
            # SDXL pipelines return (prompt, negative_prompt, pooled,
            # negative_pooled). Index 1 is the negative embedding, which is
            # None without guidance, so the pooled embedding must come from
            # index 2 or it would silently degrade to zeros.
            pooled_prompt_embeds = out[2]
        elif len(out) > 1:
            pooled_prompt_embeds = out[1]
        else:
            pooled_prompt_embeds = None
    else:
        prompt_embeds = out
        pooled_prompt_embeds = None

    prompt_embeds = prompt_embeds.to(device=device, dtype=pipeline_torch_dtype)
    if pooled_prompt_embeds is None:
        # Last-resort placeholder so downstream interpolation still has a tensor.
        proj_dim = getattr(pipe.text_encoder_2.config, "projection_dim", 1280)
        pooled_prompt_embeds = torch.zeros((1, proj_dim), device=device, dtype=pipeline_torch_dtype)
    else:
        pooled_prompt_embeds = pooled_prompt_embeds.to(device=device, dtype=pipeline_torch_dtype)
    return prompt_embeds, pooled_prompt_embeds

def encode_prompts(prompts_list: List[str]):
    """Encode every prompt and report the longest embedding sequence length.

    Args:
        prompts_list: Prompt texts, encoded in order via ``encode_single``.

    Returns:
        ``(encoded, max_seq_len)`` where ``encoded`` is a list of
        ``(prompt_embeds, pooled_prompt_embeds)`` pairs and ``max_seq_len`` is
        the largest ``prompt_embeds.shape[1]`` seen (0 for an empty list).
    """
    encoded = [encode_single(text) for text in prompts_list]
    max_seq_len = max((embeds.shape[1] for embeds, _ in encoded), default=0)
    return encoded, max_seq_len

def pad_prompt_embeds(pe: torch.Tensor, target_seq_len: int) -> torch.Tensor:
    """Zero-pad ``pe`` along its sequence axis (dim 1) to ``target_seq_len``.

    Args:
        pe: Prompt embeddings indexed as ``[batch, sequence, feature]``.
        target_seq_len: Desired sequence length.

    Returns:
        ``pe`` itself when it already has the target length, otherwise a new
        tensor on the configured device/dtype with ``pe`` copied into the
        leading sequence positions and zeros elsewhere.
    """
    current_len = pe.shape[1]
    if current_len == target_seq_len:
        return pe

    feature_dim = pe.shape[2]
    result = torch.zeros(
        (1, target_seq_len, feature_dim),
        device=device,
        dtype=pipeline_torch_dtype,
    )
    result[:, :current_len, :] = pe
    return result

# -------------------------
# Image to latent and latent to PIL helpers
# -------------------------
def load_image(path_or_pil):
    """Return an RGB PIL image from an existing PIL image or a file path."""
    if isinstance(path_or_pil, Image.Image):
        source = path_or_pil
    else:
        source = Image.open(path_or_pil)
    return source.convert("RGB")

def preprocess_image_for_vae_fp32(img: Image.Image, target_pixel: int) -> torch.Tensor:
    """Convert a PIL image into a batched, normalized tensor for the VAE.

    The image is resized to a ``target_pixel`` square with Lanczos filtering,
    converted to a tensor, normalized with mean 0.5 / std 0.5 per channel, and
    given a leading batch dimension.

    Args:
        img: Source image.
        target_pixel: Edge length of the square output in pixels.

    Returns:
        The tensor on the configured device with dtype ``vae_encode_dtype``.
    """
    stages = [
        T.Resize((target_pixel, target_pixel), interpolation=T.InterpolationMode.LANCZOS),
        T.ToTensor(),
        T.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ]
    to_normalized = T.Compose(stages)
    batched = to_normalized(img).unsqueeze(0)
    return batched.to(device=device, dtype=vae_encode_dtype)

def has_nan(t: torch.Tensor) -> bool:
    """Return True when at least one element of ``t`` is NaN."""
    nan_mask = torch.isnan(t)
    return torch.any(nan_mask).item()

def encode_image_to_latent(path_or_pil) -> torch.Tensor:
    """Encode an image into a scaled VAE latent in float32.

    The image is loaded, resized to ``sample_size * 8`` pixels per side,
    logged to ``IMG_DEBUG_CSV``, encoded with ``pipe.vae`` and multiplied by
    the VAE scaling factor. NaN/inf values are replaced by zeros at several
    points.

    Args:
        path_or_pil: A file path or a PIL image.

    Returns:
        The latent on the configured ``device`` as float32, already multiplied
        by the VAE scaling factor.

    Raises:
        RuntimeError: If the UNet sample size or the VAE is unavailable, if
            ``vae.encode`` returns ``None``, or if no latent can be extracted
            from the encoder output.
    """
    img = load_image(path_or_pil)
    sample_size = getattr(pipe.unet.config, "sample_size", None)
    if sample_size is None:
        raise RuntimeError("Cannot determine UNet sample_size from pipeline")
    # Pixel resolution is 8x the latent resolution used here.
    target_pixel = int(sample_size * 8)
    img_tensor = preprocess_image_for_vae_fp32(img, target_pixel)

    # Log preprocess stats
    img_min = float(img_tensor.min().cpu())
    img_max = float(img_tensor.max().cpu())
    img_mean = float(img_tensor.mean().cpu())
    img_std = float(img_tensor.std().cpu())
    # Write the CSV header only the first time the file is created.
    if not os.path.exists(IMG_DEBUG_CSV):
        with open(IMG_DEBUG_CSV, "w", newline="") as csvfile:
            writer = csv.writer(csvfile)
            writer.writerow(["image_path", "tensor_dtype", "tensor_shape", "min", "max", "mean", "std"])
    with open(IMG_DEBUG_CSV, "a", newline="") as csvfile:
        writer = csv.writer(csvfile)
        writer.writerow([getattr(path_or_pil, "filename", str(path_or_pil)), str(img_tensor.dtype), tuple(img_tensor.shape), f"{img_min:.8f}", f"{img_max:.8f}", f"{img_mean:.8f}", f"{img_std:.8f}"])

    if has_nan(img_tensor):
        img_tensor = torch.nan_to_num(img_tensor, nan=0.0, posinf=0.0, neginf=0.0)

    if not (hasattr(pipe, "vae") and pipe.vae is not None):
        raise RuntimeError("No VAE found on pipeline")

    vae = pipe.vae
    vae_device = next(vae.parameters()).device
    orig_vae_dtype = next(vae.parameters()).dtype

    # Temporarily run a half-precision VAE in float32 for the encode.
    # NOTE(review): the dtype is not restored if vae.encode raises.
    need_cast_back = False
    if orig_vae_dtype == torch.float16:
        vae.to(torch.float32)
        need_cast_back = True

    img_for_encode = img_tensor.to(device=vae_device, dtype=torch.float32)

    with torch.no_grad():
        if device.startswith("cuda"):
            # from torch.amp import autocast
            with autocast("cuda", dtype=autocast_dtype, enabled=autocast_enabled):
                enc = vae.encode(img_for_encode)
        else:
            enc = vae.encode(img_for_encode)

    if need_cast_back:
        vae.to(orig_vae_dtype)

    lat = None
    if enc is None:
        raise RuntimeError("pipe.vae.encode returned None")

    # Candidate 1: the distribution mean, or a sample from it if the mean has NaNs.
    try:
        if hasattr(enc, "latent_dist") and hasattr(enc.latent_dist, "mean"):
            cand = enc.latent_dist.mean.to(device=vae_device, dtype=torch.float32)
            if not torch.isnan(cand).any():
                lat = cand
            else:
                try:
                    samp = enc.latent_dist.sample().to(device=vae_device, dtype=torch.float32)
                    if not torch.isnan(samp).any():
                        lat = samp
                except Exception:
                    pass
    except Exception:
        pass

    # Candidate 2: the first element when the encoder returned a list/tuple of tensors.
    if lat is None:
        try:
            if isinstance(enc, (list, tuple)) and len(enc) > 0 and isinstance(enc[0], torch.Tensor):
                cand = enc[0].to(device=vae_device, dtype=torch.float32)
                if not torch.isnan(cand).any():
                    lat = cand
        except Exception:
            pass

    # Candidate 3: the encoder output itself when it is a plain tensor.
    if lat is None and isinstance(enc, torch.Tensor):
        cand = enc.to(device=vae_device, dtype=torch.float32)
        if not torch.isnan(cand).any():
            lat = cand

    # Last resort: accept the distribution mean with NaN/inf replaced by zeros.
    if lat is None:
        try:
            cand = enc.latent_dist.mean.to(device=vae_device, dtype=torch.float32)
            if torch.isnan(cand).any():
                cand = torch.nan_to_num(cand, nan=0.0, posinf=0.0, neginf=0.0)
            lat = cand
            print(f"[IMG DEBUG] used fallback latent (NaNs replaced) for {getattr(path_or_pil, 'filename', str(path_or_pil))}")
        except Exception as e:
            raise RuntimeError(f"Failed to extract any latent candidate: {e}")

    if has_nan(lat):
        lat = torch.nan_to_num(lat, nan=0.0, posinf=0.0, neginf=0.0)

    lat = lat.to(device=device, dtype=torch.float32)

    # Apply the VAE scaling factor, defaulting to 0.13025 when it is missing or None.
    vae_scaling = getattr(pipe.vae.config, "scaling_factor", 0.13025)
    if vae_scaling is None:
        vae_scaling = 0.13025
    lat = lat * float(vae_scaling)

    # print(f"[IMG DEBUG] Final image latent for {getattr(path_or_pil, 'filename', str(path_or_pil))}: dtype {lat.dtype} shape {lat.shape} min {float(lat.min()):.6f} max {float(lat.max()):.6f} std {float(lat.std()):.6f}")

    return lat

def latent_to_pil(latent: torch.Tensor) -> Image.Image:
    """Decode a scaled latent with the pipeline VAE and return a PIL image.

    The latent is divided by the VAE scaling factor, decoded in float32, mapped
    from ``[-1, 1]`` to ``[0, 255]`` and the first batch item is converted to
    an image.

    Args:
        latent: Latent tensor that has been multiplied by the VAE scaling factor.

    Returns:
        The decoded image for batch index 0.

    Raises:
        ValueError: If ``latent`` is ``None``.
        RuntimeError: If the output of ``vae.decode`` cannot be turned into a tensor.
    """
    if latent is None:
        raise ValueError("latent is None")

    vae = pipe.vae

    # Same fallback as encode_image_to_latent: a config that carries an
    # explicit None would otherwise make float() raise.
    vae_scaling = getattr(vae.config, "scaling_factor", 0.13025)
    if vae_scaling is None:
        vae_scaling = 0.13025
    lat_for_decode = (latent / float(vae_scaling)).to(dtype=torch.float32)

    vae_device = next(vae.parameters()).device
    orig_vae_dtype = next(vae.parameters()).dtype

    # Temporarily run a half-precision VAE in float32 for the decode.
    need_cast_back = orig_vae_dtype == torch.float16
    if need_cast_back:
        vae.to(torch.float32)

    try:
        noisy_for_decode = lat_for_decode.to(device=vae_device, dtype=torch.float32)
        with torch.no_grad():
            if device.startswith("cuda"):
                with autocast("cuda", dtype=autocast_dtype, enabled=autocast_enabled):
                    decoded = vae.decode(noisy_for_decode)
            else:
                decoded = vae.decode(noisy_for_decode)
    finally:
        # Restore the original dtype even when decode raises, so the pipeline
        # is never left in a half-converted state.
        if need_cast_back:
            vae.to(orig_vae_dtype)

    # vae.decode may return a tensor, an object with a `sample` attribute or
    # method, a sequence, or a dict; try each shape in turn.
    decoded_tensor = None
    if isinstance(decoded, torch.Tensor):
        decoded_tensor = decoded
    else:
        sample_attr = getattr(decoded, "sample", None)
        if callable(sample_attr):
            try:
                decoded_tensor = sample_attr()
            except Exception:
                decoded_tensor = None
        if decoded_tensor is None and sample_attr is not None and not callable(sample_attr):
            try:
                if isinstance(sample_attr, torch.Tensor):
                    decoded_tensor = sample_attr
                else:
                    decoded_tensor = torch.as_tensor(sample_attr)
            except Exception:
                decoded_tensor = None
        if decoded_tensor is None and isinstance(decoded, (list, tuple)) and len(decoded) > 0:
            cand = decoded[0]
            try:
                decoded_tensor = torch.as_tensor(cand) if not isinstance(cand, torch.Tensor) else cand
            except Exception:
                decoded_tensor = None
        if decoded_tensor is None and isinstance(decoded, dict):
            for key in ("sample", "reconstruction", "recon", "decoded"):
                if key in decoded:
                    cand = decoded[key]
                    try:
                        decoded_tensor = torch.as_tensor(cand) if not isinstance(cand, torch.Tensor) else cand
                        break
                    except Exception:
                        decoded_tensor = None

    if decoded_tensor is None:
        raise RuntimeError("Unexpected VAE.decode return type; cannot convert to a tensor for image conversion.")

    # Map [-1, 1] to [0, 1], then to 8-bit channel-last layout for PIL.
    decoded_tensor = decoded_tensor.to(device="cpu", dtype=torch.float32)
    decoded_tensor = (decoded_tensor.clamp(-1.0, 1.0) + 1.0) / 2.0
    decoded_arr = (decoded_tensor * 255.0).permute(0, 2, 3, 1).cpu().numpy().astype("uint8")
    pil = Image.fromarray(decoded_arr[0])
    return pil

# -------------------------
# Robust pipeline call helper with VAE upcast and autocast disabled
# -------------------------
def safe_call_pipe_with_image(pipe, *, prompt_embeds=None, pooled_prompt_embeds=None,
                              prompt_text: Optional[str]=None, image: Image.Image,
                              strength: float, num_inference_steps: int,
                              guidance_scale: float, generator: torch.Generator,
                              output_type: str = "pil", max_retries: int = 2,
                              min_scheduler_steps: int = 12, retry_with_steps: int = 20):
    """Call the image2image pipeline with retries and a prompt-text fallback.

    The input image is converted to RGB and resized to the UNet's pixel
    resolution. The pipeline is first called with the supplied embeddings up
    to ``max_retries`` times. If those attempts fail and ``prompt_text`` is
    given, one more call is made with the raw prompt text.

    Args:
        pipe: The image2image pipeline to call.
        prompt_embeds: Precomputed prompt embeddings (optional).
        pooled_prompt_embeds: Precomputed pooled embeddings (optional).
        prompt_text: Prompt used for the final fallback call; no fallback is
            attempted when it is ``None``.
        image: Input image for img2img.
        strength: img2img strength passed to the pipeline.
        num_inference_steps: Step count used for the first attempt.
        guidance_scale: Guidance scale passed to the pipeline.
        generator: Random generator passed to the pipeline.
        output_type: Pipeline output type.
        max_retries: Number of attempts on the embeddings path.
        min_scheduler_steps: Lower bound for the up-front scheduler setup.
        retry_with_steps: Step count used after a zero-element reshape
            failure, and the minimum step count for the text fallback.

    Returns:
        Whatever the pipeline call returns.

    Raises:
        RuntimeError: If every attempt (including the fallback) fails; the
            last underlying exception is chained as the cause.
    """
    from contextlib import nullcontext

    # Normalize image
    if image.mode != "RGB":
        image = image.convert("RGB")

    target_px = pipe.unet.config.sample_size * 8
    if image.size != (target_px, target_px):
        image = image.resize((target_px, target_px), resample=Image.LANCZOS)

    # Best-effort scheduler setup kept from the original; the pipeline call
    # receives its own num_inference_steps, so a failure here is not fatal.
    try:
        safe_steps = max(int(num_inference_steps), int(min_scheduler_steps))
        pipe.scheduler.set_timesteps(safe_steps)
        timesteps = getattr(pipe.scheduler, "timesteps", None)
        if timesteps is None or (hasattr(timesteps, "__len__") and len(timesteps) == 0):
            pipe.scheduler.set_timesteps(max(safe_steps, min_scheduler_steps))
    except Exception:
        pass

    vae = pipe.vae
    orig_vae_dtype = next(vae.parameters()).dtype
    upcast_vae = orig_vae_dtype == torch.float16

    def _run(current_image, steps, **prompt_kwargs):
        """Invoke the pipeline once with the VAE in float32 and restore it afterwards."""
        if device.startswith("cuda"):
            ctx = autocast("cuda", dtype=autocast_dtype, enabled=autocast_enabled)
        else:
            ctx = nullcontext()
        if upcast_vae:
            vae.to(torch.float32)
        try:
            with ctx:
                return pipe(
                    image=current_image,
                    strength=strength,
                    num_inference_steps=steps,
                    guidance_scale=guidance_scale,
                    generator=generator,
                    output_type=output_type,
                    **prompt_kwargs,
                )
        finally:
            # Always restore, including on failure and on retry paths.
            if upcast_vae:
                vae.to(orig_vae_dtype)

    last_exc = None
    attempted_with_increased_steps = False
    steps_to_use = num_inference_steps

    for attempt in range(max_retries):
        try:
            # Ensure prompt_embeds are on correct device/dtype
            if prompt_embeds is not None:
                prompt_embeds = prompt_embeds.to(device=device, dtype=pipeline_torch_dtype)
            if pooled_prompt_embeds is not None:
                pooled_prompt_embeds = pooled_prompt_embeds.to(device=device, dtype=pipeline_torch_dtype)

            return _run(
                image,
                steps_to_use,
                prompt_embeds=prompt_embeds,
                pooled_prompt_embeds=pooled_prompt_embeds,
            )

        except Exception as e:
            last_exc = e
            err_msg = str(e).lower()
            is_reshape_zero_err = "cannot reshape tensor of 0 elements" in err_msg or "shape [0" in err_msg

            if is_reshape_zero_err and (not attempted_with_increased_steps) and int(steps_to_use) < int(retry_with_steps):
                attempted_with_increased_steps = True
                print(f"[PIPE RETRY-STEPS] detected small-step reshape failure; retrying with {retry_with_steps} steps.")
                # The next attempt must actually pass the larger step count to
                # the pipeline; adjusting the scheduler alone has no effect
                # because the pipeline sets its own timesteps from this value.
                steps_to_use = int(retry_with_steps)
                continue

            # Otherwise rebuild the image from raw pixels and try again.
            print(f"[PIPE RETRY] attempt {attempt+1} failed: {e}. Retrying after re-normalize.")
            try:
                arr = np.asarray(image).astype(np.uint8)
                image = Image.fromarray(arr).convert("RGB").resize((target_px, target_px), Image.LANCZOS)
            except Exception as e2:
                print(f"[PIPE RETRY] aggressive normalize failed: {e2}")

    # Final fallback: call pipeline with prompt text (let pipeline compute embeddings)
    if prompt_text is not None:
        try:
            print("[PIPE FALLBACK] calling pipeline with prompt text fallback.")
            return _run(
                image,
                max(num_inference_steps, retry_with_steps),
                prompt=prompt_text,
            )
        except Exception as e:
            last_exc = e

    raise RuntimeError(f"image2image pipeline failed after retries: {last_exc}") from last_exc

# -------------------------
# Prepare prompts, seeds, images
# -------------------------

# Start total timer
total_start = time.time()

# At least two prompts are needed to form one morphing segment.
prompts_ok = isinstance(prompts, (list, tuple)) and len(prompts) >= 2
if not prompts_ok:
    raise ValueError("prompts must be a list with at least two items")

num_prompts = len(prompts)
seeds_list = ensure_list_of_seeds(seeds, num_prompts)
prompt_images_list = ensure_list_of_images(prompt_images, num_prompts)

print("Encoding prompts...")
encoded_list, max_seq_len = encode_prompts(prompts)

# Bring every sequence embedding to a common length so pairs can be blended.
padded_prompts = [pad_prompt_embeds(embeds, max_seq_len) for embeds, _ in encoded_list]
pooled_prompts = [pooled_embeds for _, pooled_embeds in encoded_list]

# Per-prompt endpoint data: latent, "has image" flag, and the VAE round-trip image.
image_latents = []
image_latent_flags = []
image_pils = []
for prompt_idx, source_path in enumerate(prompt_images_list):
    if source_path is None:
        image_latents.append(None)
        image_latent_flags.append(False)
        image_pils.append(None)
        continue

    print(f"Encoding image to latent for prompt index {prompt_idx}: {source_path}")
    endpoint_latent = encode_image_to_latent(source_path)
    image_latents.append(endpoint_latent)
    image_latent_flags.append(True)
    image_pils.append(latent_to_pil(endpoint_latent))

# -------------------------
# Generation loop
# -------------------------
# Main frame generation: for each consecutive prompt pair (a segment), step an
# interpolation parameter from 0 to 1 and render one img2img frame per step.
print("Generating frames with image2image...")
frame_index = 0
total_segments = num_prompts - 1

# Latent shape used when an endpoint has to be synthesized from noise.
shape = (
    1,
    pipe.unet.config.in_channels,
    pipe.unet.config.sample_size,
    pipe.unet.config.sample_size,
)
print("UNet latent shape expected:", shape)

# (Re)create the per-frame statistics CSV with its header row.
with open(DEBUG_CSV, "w", newline="") as csvfile:
    writer = csv.writer(csvfile)
    writer.writerow(["frame_index", "segment", "step", "raw_t", "eased_t", "is_img_a", "is_img_b",
                     "latent_min", "latent_max", "latent_mean", "latent_std", "time_s"])

# Latent of the most recently generated frame; only read by the noise-only branch.
prev_latents: Optional[torch.Tensor] = None

pipe.scheduler.set_timesteps(num_inference_steps)
timesteps = pipe.scheduler.timesteps
t_start = timesteps[0]

for seg_idx in range(total_segments):

    seg_start = time.time()
    
    # Endpoint data for this segment: index seg_idx is "a", seg_idx + 1 is "b".
    pe_a = padded_prompts[seg_idx]
    pe_b = padded_prompts[seg_idx + 1]
    pooled_a = pooled_prompts[seg_idx]
    pooled_b = pooled_prompts[seg_idx + 1]

    img_lat_a = image_latents[seg_idx]
    img_lat_b = image_latents[seg_idx + 1]
    pil_a = image_pils[seg_idx]
    pil_b = image_pils[seg_idx + 1]
    is_img_a = image_latent_flags[seg_idx]
    is_img_b = image_latent_flags[seg_idx + 1]

    steps_in_segment = num_steps_between

    seed_a = seeds_list[seg_idx]
    seed_b = seeds_list[seg_idx + 1]

    # Deterministic per-segment seed derived from both endpoint seeds,
    # masked to a non-negative 31-bit value.
    combined_seed = (int(seed_a) * 31 + int(seed_b) * 17) & 0x7FFFFFFF

    # The pipeline calls change scheduler state, so reset it for each segment.
    pipe.scheduler.set_timesteps(num_inference_steps)
    timesteps = pipe.scheduler.timesteps
    t_start = timesteps[0]

    # If mixing image and noise, create scaled noise endpoints
    # NOTE(review): in a mixed segment the pixel-space branch below only uses
    # pil_a / pil_b, so these noise latents do not appear to be read afterwards — confirm.
    if is_img_a and not is_img_b and img_lat_b is None:
        tmp_gen_b = torch.Generator(device=device).manual_seed(int(seed_b))
        img_lat_b = torch.randn(shape, generator=tmp_gen_b, device=device, dtype=torch.float32) * float(scheduler_init_sigma) * noise_mix_scale
    if is_img_b and not is_img_a and img_lat_a is None:
        tmp_gen_a = torch.Generator(device=device).manual_seed(int(seed_a))
        img_lat_a = torch.randn(shape, generator=tmp_gen_a, device=device, dtype=torch.float32) * float(scheduler_init_sigma) * noise_mix_scale

    for step in range(steps_in_segment + 1):

        step_start = time.time()
        
        # Skip the final step of every segment except the last one; presumably
        # this avoids duplicating the frame that the next segment starts with.
        if step == steps_in_segment and seg_idx < total_segments - 1:
            continue

        raw_t = step / float(steps_in_segment)
        eased_t = cosine_ease(raw_t)

        # Blend text conditioning between the two prompts.
        cond_interp = lerp(pe_a, pe_b, eased_t)
        pooled_interp = lerp(pooled_a, pooled_b, eased_t)

        # Branch: both-noise segment
        if (not is_img_a) and (not is_img_b):
            # Endpoints are regenerated from their seeds on every step, which
            # yields the same noise each time for a given seed.
            if img_lat_a is None:
                tmp_gen_a = torch.Generator(device=device).manual_seed(int(seed_a))
                endpoint_a = torch.randn(shape, generator=tmp_gen_a, device=device, dtype=torch.float32) * float(scheduler_init_sigma)
            else:
                endpoint_a = img_lat_a.to(device=device, dtype=torch.float32)
            if img_lat_b is None:
                tmp_gen_b = torch.Generator(device=device).manual_seed(int(seed_b))
                endpoint_b = torch.randn(shape, generator=tmp_gen_b, device=device, dtype=torch.float32) * float(scheduler_init_sigma)
            else:
                endpoint_b = img_lat_b.to(device=device, dtype=torch.float32)

            latents = slerp(endpoint_a, endpoint_b, eased_t)

            # Temporal smoothing: mix in the previous frame's re-encoded latent.
            if prev_latents is not None and temporal_smooth_alpha > 0.0:
                latents = prev_latents * temporal_smooth_alpha + latents * (1.0 - temporal_smooth_alpha)

            if torch.isnan(latents).any():
                latents = torch.nan_to_num(latents, nan=0.0, posinf=0.0, neginf=0.0)

            # NOTE(review): the latent is divided by the scaling factor here and
            # again inside latent_to_pil below — confirm the double division is intended.
            vae_scaling = getattr(pipe.vae.config, "scaling_factor", 0.13025)
            lat_for_noise = latents / float(vae_scaling)

            gen = torch.Generator(device=device).manual_seed(int((combined_seed + step) & 0x7FFFFFFF))
            noise = torch.randn(lat_for_noise.shape, generator=gen, device=device, dtype=torch.float32)

            # Prefer the scheduler's own noising; fall back to a plain additive
            # noise model if the scheduler rejects the call.
            try:
                noisy_latents = pipe.scheduler.add_noise(lat_for_noise, noise, t_start)
            except Exception:
                noisy_latents = lat_for_noise + noise * float(scheduler_init_sigma)

            pil_noisy = latent_to_pil(noisy_latents)

            # Constant strength for noise-only segments.
            strength_use = max(min_strength, base_strength + 0.05)
            frame_gen = torch.Generator(device=device).manual_seed(int((combined_seed + step) & 0x7FFFFFFF))
            # Text used only if the embeddings path fails: the nearer prompt.
            fallback_prompt_text = prompts[seg_idx] if raw_t < 0.5 else prompts[seg_idx + 1]

            out = safe_call_pipe_with_image(
                pipe,
                prompt_embeds=cond_interp,
                pooled_prompt_embeds=pooled_interp,
                prompt_text=fallback_prompt_text,
                image=pil_noisy,
                strength=strength_use,
                num_inference_steps=num_inference_steps,
                guidance_scale=guidance_scale,
                generator=frame_gen,
                output_type="pil"
            )

            image = out.images[0]
            frame_path = os.path.join(output_dir, f"frame_{frame_index:06d}.png")
            image.save(frame_path)
            print(f"Saved {frame_path}")

            # Re-encode the generated frame for smoothing and statistics;
            # a failure here only disables smoothing for the next frame.
            try:
                enc_lat = encode_image_to_latent(image)
                prev_latents = enc_lat.to(device=device, dtype=torch.float32)
            except Exception:
                prev_latents = None

            # NOTE(review): t0 is an absolute wall-clock timestamp, written to the
            # "time_s" column — confirm that an elapsed time was not intended.
            t0 = time.time()
            if prev_latents is not None:
                lmin = float(prev_latents.min().cpu())
                lmax = float(prev_latents.max().cpu())
                lmean = float(prev_latents.mean().cpu())
                lstd = float(prev_latents.std().cpu())
            else:
                lmin = lmax = lmean = lstd = 0.0
            with open(DEBUG_CSV, "a", newline="") as csvfile:
                writer = csv.writer(csvfile)
                writer.writerow([frame_index, seg_idx, step, f"{raw_t:.6f}", f"{eased_t:.6f}", int(is_img_a), int(is_img_b),
                                 f"{lmin:.8f}", f"{lmax:.8f}", f"{lmean:.8f}", f"{lstd:.8f}", f"{t0:.6f}"])

            if frame_index < DEBUG_FRAMES:
                print(f"DEBUG FRAME {frame_index} seg {seg_idx} step {step}/{steps_in_segment} raw_t={raw_t:.4f} eased_t={eased_t:.4f}")
                print("  pil_noisy size", pil_noisy.size)
                print(f"  latent stats min {lmin:.6f} max {lmax:.6f} mean {lmean:.6f} std {lstd:.6f}")

            frame_index += 1
            step_end = time.time()
            step_runtime = step_end - step_start
            print('-' * 70)
            print(f"Step {step} runtime: {step_runtime:.6f} sec")
            print('-' * 70)
            continue

        # Branch: at least one side is an image -> pixel-space interpolation
        # A missing endpoint image is replaced by a copy of the other endpoint
        # (or by mid-gray if both are missing).
        if pil_a is None and pil_b is None:
            pil_a = pil_b = Image.new("RGB", (pipe.unet.config.sample_size * 8, pipe.unet.config.sample_size * 8), (127, 127, 127))
        if pil_a is None:
            pil_a = pil_b.copy()
        if pil_b is None:
            pil_b = pil_a.copy()

        target_px = pipe.unet.config.sample_size * 8
        if pil_a.size != (target_px, target_px):
            pil_a = pil_a.resize((target_px, target_px), resample=Image.LANCZOS)
        if pil_b.size != (target_px, target_px):
            pil_b = pil_b.resize((target_px, target_px), resample=Image.LANCZOS)

        # Cross-fade the two endpoint images in pixel space and use the result
        # as the img2img input.
        a_arr = np.asarray(pil_a).astype(np.float32) / 255.0
        b_arr = np.asarray(pil_b).astype(np.float32) / 255.0
        interp_arr = (1.0 - eased_t) * a_arr + eased_t * b_arr
        interp_arr = (interp_arr * 255.0).clip(0, 255).astype("uint8")
        pil_noisy = Image.fromarray(interp_arr)

        # Strength falls from base_strength at raw_t = 0 toward 0 at raw_t = 1
        # and is floored at min_strength.
        # NOTE(review): this is not symmetric around the segment midpoint, so the
        # strength jumps back to base_strength at each segment start — confirm intent.
        strength = base_strength * (1.0 - 0.5 * (1.0 - math.cos(math.pi * raw_t)))
        strength = max(min_strength, float(strength))

        frame_gen = torch.Generator(device=device).manual_seed(int((combined_seed + step) & 0x7FFFFFFF))
        # Text used only if the embeddings path fails: the nearer prompt.
        fallback_prompt_text = prompts[seg_idx] if raw_t < 0.5 else prompts[seg_idx + 1]

        out = safe_call_pipe_with_image(
            pipe,
            prompt_embeds=cond_interp,
            pooled_prompt_embeds=pooled_interp,
            prompt_text=fallback_prompt_text,
            image=pil_noisy,
            strength=strength,
            num_inference_steps=num_inference_steps,
            guidance_scale=guidance_scale,
            generator=frame_gen,
            output_type="pil"
        )

        image = out.images[0]
        frame_path = os.path.join(output_dir, f"frame_{frame_index:06d}.png")
        image.save(frame_path)
        print(f"Saved {frame_path}")

        # Re-encode the generated frame; in this branch the latent is used for
        # the statistics below (and by a later noise-only segment, if any).
        try:
            enc_lat = encode_image_to_latent(image)
            prev_latents = enc_lat.to(device=device, dtype=torch.float32)
        except Exception:
            prev_latents = None

        # NOTE(review): t0 is an absolute wall-clock timestamp, written to the
        # "time_s" column — confirm that an elapsed time was not intended.
        t0 = time.time()
        if prev_latents is not None:
            lmin = float(prev_latents.min().cpu())
            lmax = float(prev_latents.max().cpu())
            lmean = float(prev_latents.mean().cpu())
            lstd = float(prev_latents.std().cpu())
        else:
            lmin = lmax = lmean = lstd = 0.0
        with open(DEBUG_CSV, "a", newline="") as csvfile:
            writer = csv.writer(csvfile)
            writer.writerow([frame_index, seg_idx, step, f"{raw_t:.6f}", f"{eased_t:.6f}", int(is_img_a), int(is_img_b),
                             f"{lmin:.8f}", f"{lmax:.8f}", f"{lmean:.8f}", f"{lstd:.8f}", f"{t0:.6f}"])

        if frame_index < DEBUG_FRAMES:
            print(f"DEBUG FRAME {frame_index} seg {seg_idx} step {step}/{steps_in_segment} raw_t={raw_t:.4f} eased_t={eased_t:.4f}")
            print("  pil_noisy size", pil_noisy.size)
            print(f"  latent stats min {lmin:.6f} max {lmax:.6f} mean {lmean:.6f} std {lstd:.6f}")

        frame_index += 1

        step_end = time.time()
        step_runtime = step_end - step_start
        print('-' * 70)
        print(f"Step {step} runtime: {step_runtime:.6f} sec")
        print('-' * 70)

    seg_end = time.time()
    seg_runtime = seg_end - seg_start
    print('=' * 70)
    print(f"Segment {seg_idx} runtime: {seg_runtime:.6f} sec")
    print('=' * 70)

# -------------------------
# Assemble video
# -------------------------
print("Assembling video...")
# Use exactly the frames written by this run (indices 0 .. frame_index - 1).
# Listing the directory instead would also pick up stale frame_*.png files
# left behind by an earlier, longer run in the same output_dir.
frames = [os.path.join(output_dir, f"frame_{i:06d}.png") for i in range(frame_index)]
frames = [p for p in frames if os.path.exists(p)]
if not frames:
    raise RuntimeError("No frames found")

try:
    from moviepy.video.io.ImageSequenceClip import ImageSequenceClip
    clip = ImageSequenceClip(frames, fps=framerate)
    clip.write_videofile(output_video, codec="libx264", audio=False, logger=None)
    print(f"Saved video to {output_video} using moviepy")
except Exception as e:
    # moviepy is optional at this point; fall back to the ffmpeg binary.
    print("moviepy failed, trying ffmpeg:", str(e))
    input_pattern = os.path.join(output_dir, "frame_%06d.png")
    import subprocess
    subprocess.run(
        [
            "ffmpeg", "-y", "-framerate", str(framerate),
            "-i", input_pattern,
            # Stop after this run's frames so stale higher-numbered files are ignored.
            "-frames:v", str(len(frames)),
            "-c:v", "libx264",
            "-pix_fmt", "yuv420p", output_video
        ],
        check=True
    )
    print(f"Saved video to {output_video} using ffmpeg")

if cleanup_frames_after_video:
    for p in frames:
        try:
            os.remove(p)
        except OSError:
            # Best-effort cleanup: a frame that cannot be deleted is left in place.
            pass
    print("Cleaned up intermediate PNG frames")


# End total timer
# Report the wall-clock time elapsed since total_start was recorded.
total_end = time.time()
total_runtime = total_end - total_start

rule = '=' * 70
print(rule)
print(f"\nTotal runtime: {total_runtime:.6f} sec")
print(rule)
print("Done.")
print(rule)

# Congrats! You did it :)