In [1]:
# import huggingface_hub
# if not hasattr(huggingface_hub, "cached_download"):
#     from huggingface_hub import hf_hub_download
#     huggingface_hub.cached_download = hf_hub_download

import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image
import torchvision.transforms.functional as TF
from safetensors.torch import load_file
from diffusers import (
    AutoencoderKL,
    UNet2DConditionModel,
    UniPCMultistepScheduler,
)
from transformers import CLIPTextModel, CLIPTokenizer
from tqdm import tqdm
import matplotlib.pyplot as plt
# your classes
from flownet import DualFlowControlNet
from pipeline import StableDiffusionDualFlowControlNetPipeline

  from .autonotebook import tqdm as notebook_tqdm
  @torch.library.impl_abstract("xformers_flash::flash_fwd")
  @torch.library.impl_abstract("xformers_flash::flash_bwd")
  @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
  def backward(self, tenOutgrad):


In [2]:
# ---------------------------
# Helpers: image & .flo loaders
# ---------------------------
def read_flo(path: str) -> np.ndarray:
    """Read a Middlebury ``.flo`` optical-flow file into a ``[H, W, 2]`` float32 array.

    The file layout read here is: one float32 magic number (202021.25), two
    int32 values (width, then height), then ``2 * width * height`` float32
    samples which are reshaped to ``(height, width, 2)``.

    Args:
        path: Path to the ``.flo`` file.

    Returns:
        float32 array of shape ``[H, W, 2]`` (flow vectors, pixel units).

    Raises:
        ValueError: If the file is empty, has a wrong magic number, has a
            missing or non-positive width/height, or holds fewer samples
            than the header promises.
    """
    with open(path, "rb") as f:
        # np.fromfile returns a short (possibly empty) array at EOF, so check
        # the size before indexing instead of failing with an IndexError.
        header = np.fromfile(f, np.float32, 1)
        if header.size != 1:
            raise ValueError(f"Invalid .flo file: {path} (empty file, no magic number)")
        magic = header[0]
        if magic != 202021.25:
            raise ValueError(f"Invalid .flo file: {path} (magic={magic})")

        dims = np.fromfile(f, np.int32, 2)
        if dims.size != 2:
            raise ValueError(f"Invalid .flo file: {path} (truncated header)")
        w, h = int(dims[0]), int(dims[1])
        if w <= 0 or h <= 0:
            raise ValueError(f"Invalid .flo file: {path} (width={w}, height={h})")

        expected = 2 * w * h
        data = np.fromfile(f, np.float32, expected)

    # A truncated payload would otherwise surface as an opaque reshape error.
    if data.size != expected:
        raise ValueError(
            f"Invalid .flo file: {path} (expected {expected} flow values, found {data.size})"
        )
    return data.reshape(h, w, 2)

def resize_flow_to(flow_hw2: np.ndarray, target_h: int, target_w: int) -> torch.Tensor:
    """Bilinearly resample a flow field and rescale its vectors to the new grid.

    Args:
        flow_hw2: Flow array laid out as ``[H, W, 2]``.
        target_h: Output height.
        target_w: Output width.

    Returns:
        Tensor of shape ``[1, 2, target_h, target_w]``. Channel 0 is multiplied
        by ``target_w / W`` and channel 1 by ``target_h / H`` so displacements
        stay expressed in pixels of the resized grid.
    """
    src_h, src_w = flow_hw2.shape[0], flow_hw2.shape[1]

    # [H,W,2] -> [2,H,W] -> [1,2,H,W], the layout F.interpolate expects.
    flow_nchw = torch.from_numpy(flow_hw2).permute(2, 0, 1)[None]
    resized = F.interpolate(
        flow_nchw,
        size=(target_h, target_w),
        mode="bilinear",
        align_corners=True,
    )

    # max(..., 1) guards the division against a zero-sized source dimension.
    scale_x = target_w / max(src_w, 1)
    scale_y = target_h / max(src_h, 1)
    resized[:, 0].mul_(scale_x)
    resized[:, 1].mul_(scale_y)
    return resized  # [1,2,target_h,target_w]

def load_pair_to_sixch(path0, path1, size=(512, 512)) -> torch.Tensor:
    """Load two images as RGB and stack them channel-wise into one 6-channel batch.

    Args:
        path0: Path of the first image (fills channels 0-2).
        path1: Path of the second image (fills channels 3-5).
        size: Forwarded unchanged to ``PIL.Image.resize`` (which reads it as
            ``(width, height)``); ``None`` skips resizing.

    Returns:
        float32 tensor of shape ``[1, 6, H, W]`` with values in ``[0, 1]``.
    """
    frames = []
    for frame_path in (path0, path1):
        rgb = Image.open(frame_path).convert("RGB")
        if size is not None:
            rgb = rgb.resize(size, Image.BICUBIC)
        frames.append(TF.to_tensor(rgb))  # [3,H,W], float32
    stacked = torch.cat(frames, dim=0)  # [6,H,W]
    return stacked[None]  # [1,6,H,W]

def load_controls_and_flows(
    img0_path, img1_path, fwd_flo_path, bwd_flo_path, size=(512, 512), device="cuda", dtype=torch.float32
):
    """Load the two control images and the forward/backward flows at one resolution.

    Args:
        img0_path: Path of the first RGB frame.
        img1_path: Path of the second RGB frame.
        fwd_flo_path: Path of the forward ``.flo`` file.
        bwd_flo_path: Path of the backward ``.flo`` file.
        size: Target resolution as ``(H, W)``.
        device: Device the returned tensors are moved to.
        dtype: Dtype the returned tensors are cast to.

    Returns:
        Tuple ``(sixch, flow4)``: ``sixch`` is ``[1, 6, H, W]`` (both frames
        stacked on the channel axis) and ``flow4`` is ``[1, 4, H, W]``
        (forward flow in channels 0-1, backward flow in channels 2-3).
    """
    H, W = size

    # The image loader hands `size` straight to PIL's Image.resize, which expects
    # (width, height). Swap the order so images and flows share the same H x W
    # even when the target is not square.
    sixch = load_pair_to_sixch(img0_path, img1_path, size=(W, H)).to(device=device, dtype=dtype)  # [1,6,H,W]

    fwd = read_flo(fwd_flo_path)
    bwd = read_flo(bwd_flo_path)
    fwd_t = resize_flow_to(fwd, H, W)
    bwd_t = resize_flow_to(bwd, H, W)
    flow4 = torch.cat([fwd_t, bwd_t], dim=1).to(device=device, dtype=dtype)  # [1,4,H,W]
    return sixch, flow4


In [3]:
# Runtime configuration shared by the cells below.
device = "cuda"
dtype = torch.float32

# ---------------------------
# Load controls
# ---------------------------
# sixch: both frames stacked on the channel axis; flow4: forward + backward flow.
sixch, flow4 = load_controls_and_flows(
    img0_path="data/Beauty/images/frame_0000.png",
    img1_path="data/Beauty/images/frame_0002.png",
    fwd_flo_path="data/Beauty/optical_flow/optical_flow_gop_2_raft/flow_0000_0001.flo",
    bwd_flo_path="data/Beauty/optical_flow_bwd/optical_flow_gop_2_raft/flow_0002_0001.flo",
    size=(512, 512),
    device=device,
    dtype=dtype,
)

In [4]:
# The UNet, text encoder and scheduler in this notebook are all SD-1.5, so the VAE
# has to come from the same checkpoint: the SDXL VAE ("stabilityai/sdxl-vae") was
# trained for a different latent space, so pairing it with the SD-1.5 UNet decodes
# the UNet's latents incorrectly.
# NOTE(review): confirm the ControlNet checkpoint was trained with the SD-1.5 VAE.
vae = AutoencoderKL.from_pretrained(
    "stable-diffusion-v1-5/stable-diffusion-v1-5",
    subfolder="vae",
    torch_dtype=dtype,
)

In [5]:
# ---------------------------
# Load models (aligned SD-1.5)
# ---------------------------
base = "stable-diffusion-v1-5/stable-diffusion-v1-5"

# Each component is pulled from its own subfolder of the same base checkpoint.
# vae = AutoencoderKL.from_pretrained(base, subfolder="vae", torch_dtype=dtype)
tokenizer = CLIPTokenizer.from_pretrained(base, subfolder="tokenizer")
text_encoder = CLIPTextModel.from_pretrained(base, subfolder="text_encoder", torch_dtype=dtype)
unet = UNet2DConditionModel.from_pretrained(base, subfolder="unet", torch_dtype=dtype)
scheduler = UniPCMultistepScheduler.from_pretrained(base, subfolder="scheduler")

# --- ControlNet: build the project subclass, mirroring the UNet's block widths ---
controlnet = DualFlowControlNet(
    layers_per_block=2,
    cross_attention_dim=768,
    block_out_channels=tuple(unet.config.block_out_channels),  # (320, 640, 1280, 1280)
)

# sanity: cross-attn dims must match (768 for SD1.x)
assert unet.config.cross_attention_dim == 768 and text_encoder.config.hidden_size == 768
if hasattr(controlnet, "config") and hasattr(controlnet.config, "cross_attention_dim"):
    assert controlnet.config.cross_attention_dim == 768, f"ControlNet CAD={controlnet.config.cross_attention_dim}"

# Trained ControlNet weights. strict=False tolerates key mismatches, so the result
# is left as the cell's last expression to show missing/unexpected keys.
ckpt = load_file("experiments/controlnet/checkpoint-11500/controlnet/diffusion_pytorch_model.safetensors")
controlnet.load_state_dict(ckpt, strict=False)

<All keys matched successfully>

In [6]:
# Neither a safety checker nor its feature extractor is attached to this pipeline.
safety_checker = feature_extractor = None

# ---------------------------
# Build pipeline
# ---------------------------
# Assemble the previously loaded components; nothing is moved to a device here.
pipe = StableDiffusionDualFlowControlNetPipeline(
    tokenizer=tokenizer,
    text_encoder=text_encoder,
    vae=vae,
    unet=unet,
    controlnet=controlnet,
    scheduler=scheduler,
    feature_extractor=feature_extractor,
    safety_checker=safety_checker,
)
# pipe = pipe.to("cuda")

In [7]:
# Text prompt used to condition generation.
prompt = "a pretty girl smiling , has pink lipstick and is infront of black background"

# Seeded random generator created on the target device.
SEED = 42
g = torch.Generator(device=device).manual_seed(SEED)

In [None]:
import gc

# Drop the notebook-level names for the individual components.
# NOTE(review): `pipe` was built from these same objects and presumably still
# references them, so this does not by itself release their weights.
del vae, tokenizer, unet, controlnet

# Collect unreachable objects first so their CUDA blocks become free, then hand
# the cached (now unused) blocks back to the driver.
gc.collect()
torch.cuda.empty_cache()

In [None]:
# Move the whole pipeline to the configured device; this uses the shared `device`
# setting rather than a second hard-coded "cuda" literal.
pipe = pipe.to(device)

In [None]:
import torch
import torchvision.transforms as T
from diffusers import AutoencoderKL
from PIL import Image

device = "cuda"
dtype = torch.float32

# --- Load the pretrained VAE (SD-1.5 checkpoint, "vae" subfolder) ---
base = "stable-diffusion-v1-5/stable-diffusion-v1-5"
vae = AutoencoderKL.from_pretrained(base, subfolder="vae", torch_dtype=dtype)
vae = vae.to(device)

# --- Load and preprocess image ---
path = "data/Beauty/images/frame_0000.png"
image = Image.open(path).convert("RGB")

transform = T.Compose(
    [
        T.Resize((512, 512)),  # match training resolution
        T.ToTensor(),  # [0,1]
        T.Normalize(mean=[0.5], std=[0.5]),  # [-1,1]
    ]
)
# Add the batch axis, then move to the target device/dtype.
img_tensor = transform(image)[None].to(device, dtype=dtype)  # [1,3,512,512]

# --- Encode with VAE ---
# Sample from the encoder's posterior and apply the VAE's configured latent scale.
with torch.no_grad():
    posterior = vae.encode(img_tensor).latent_dist
    latents = vae.config.scaling_factor * posterior.sample()  # [1,4,64,64]

print("Latent shape:", latents.shape)



In [8]:
# Run the dual-flow ControlNet pipeline on the prompt, the stacked frame pair and
# the stacked forward/backward flow.
# NOTE(review): the seeded generator `g` created earlier is never passed to this
# call — confirm whether the pipeline accepts a `generator` argument, otherwise
# the seed has no effect on sampling.
# NOTE(review): the recorded run of this cell stopped at step 0/50 with an
# AssertionError that carries no message; the failing assert is not visible in
# this notebook, so check the pipeline/ControlNet code for the violated condition.
out = pipe(prompt, sixch, flow4)

  0%|                                                                                         | 0/50 [00:00<?, ?it/s]


AssertionError: 

In [None]:
# Display the first item of the pipeline output's first field
# (presumably the first generated image — TODO confirm against the pipeline's return type).
out[0][0]