In [1]:
import huggingface_hub

# Compatibility shim: when the installed huggingface_hub has no
# `cached_download` attribute, alias it to `hf_hub_download`.
# NOTE(review): presumably needed because a library imported further down
# still looks up `huggingface_hub.cached_download` — confirm which one, and
# keep this ahead of those imports.
if not hasattr(huggingface_hub, "cached_download"):
    from huggingface_hub import hf_hub_download

    setattr(huggingface_hub, "cached_download", hf_hub_download)

import torch
import torch.nn.functional as F
import numpy as np
from PIL import Image
import torchvision.transforms.functional as TF
from safetensors.torch import load_file
from diffusers import (
    AutoencoderKL,
    UNet2DConditionModel,
    UniPCMultistepScheduler,
)
from transformers import CLIPTextModel, CLIPTokenizer
from tqdm import tqdm
import matplotlib.pyplot as plt
# Project-local modules: ControlNet subclass, custom pipeline, data-loading helper
from controlnet.flownet import DualFlowControlNet
from pipeline import StableDiffusionDualFlowControlNetPipeline
from controlnet.utils import load_controls_and_flows

  from .autonotebook import tqdm as notebook_tqdm
  @torch.library.impl_abstract("xformers_flash::flash_fwd")
  @torch.library.impl_abstract("xformers_flash::flash_bwd")
  @torch.cuda.amp.custom_fwd(cast_inputs=torch.float32)
  def backward(self, tenOutgrad):


In [3]:
# ---------------------------
# Load models (aligned SD-1.5)
# ---------------------------
# Loads the pretrained SD-1.5 components, builds the DualFlowControlNet and
# fills it with every checkpoint tensor whose name and shape match the model.
dtype = torch.float32
base = "stable-diffusion-v1-5/stable-diffusion-v1-5"

vae = AutoencoderKL.from_pretrained(base, subfolder="vae", torch_dtype=dtype)
unet = UNet2DConditionModel.from_pretrained(base, subfolder="unet", torch_dtype=dtype)
text_encoder = CLIPTextModel.from_pretrained(base, subfolder="text_encoder", torch_dtype=dtype)
tokenizer = CLIPTokenizer.from_pretrained(base, subfolder="tokenizer")
scheduler = UniPCMultistepScheduler.from_pretrained(base, subfolder="scheduler")

# --- ControlNet: build the subclass with channel widths taken from the UNet ---
controlnet = DualFlowControlNet(
    block_out_channels=tuple(unet.config.block_out_channels),     # (320, 640, 1280, 1280)
    layers_per_block=2,
    cross_attention_dim=768,
)

# sanity: cross-attn dims must match (768 for SD1.x)
assert unet.config.cross_attention_dim == text_encoder.config.hidden_size == 768
if hasattr(controlnet, "config") and hasattr(controlnet.config, "cross_attention_dim"):
    assert controlnet.config.cross_attention_dim == 768, f"ControlNet CAD={controlnet.config.cross_attention_dim}"

ckpt = load_file('experiments/controlnet/checkpoint-153000/controlnet/diffusion_pytorch_model.safetensors')
model_state = controlnet.state_dict()

# Keep only checkpoint tensors whose key exists in the model with the same shape.
filtered_state_dict = {
    k: v for k, v in ckpt.items()
    if k in model_state and v.shape == model_state[k].shape
}

# Tensors removed by the filter never reach load_state_dict, so its returned
# report cannot mention them. List them explicitly instead of dropping them
# silently.
skipped_keys = sorted(set(ckpt) - set(filtered_state_dict))
if skipped_keys:
    print(f"Skipped {len(skipped_keys)} checkpoint tensor(s) with no same-shaped match in the model:")
    for skipped_key in skipped_keys:
        print(f"  {skipped_key}")

# strict=False: model parameters absent from the filtered dict keep their
# initial values; the returned report (displayed as the cell output) lists them
# under `missing_keys`.
controlnet.load_state_dict(filtered_state_dict, strict=False)

_IncompatibleKeys(missing_keys=['feature_extractor.wrapper.0.metric_net.0.weight', 'feature_extractor.wrapper.0.metric_net.0.bias', 'feature_extractor.wrapper.0.metric_net.2.weight', 'feature_extractor.wrapper.0.metric_net.2.bias', 'feature_extractor.wrapper.1.metric_net.0.weight', 'feature_extractor.wrapper.1.metric_net.0.bias', 'feature_extractor.wrapper.1.metric_net.2.weight', 'feature_extractor.wrapper.1.metric_net.2.bias', 'feature_extractor.wrapper.2.metric_net.0.weight', 'feature_extractor.wrapper.2.metric_net.0.bias', 'feature_extractor.wrapper.2.metric_net.2.weight', 'feature_extractor.wrapper.2.metric_net.2.bias', 'feature_extractor.wrapper.3.metric_net.0.weight', 'feature_extractor.wrapper.3.metric_net.0.bias', 'feature_extractor.wrapper.3.metric_net.2.weight', 'feature_extractor.wrapper.3.metric_net.2.bias', 'feature_extractor.zero_convs.0.weight', 'feature_extractor.zero_convs.1.weight', 'feature_extractor.zero_convs.2.weight', 'feature_extractor.zero_convs.3.weight'], une

In [4]:
# ---------------------------
# Build pipeline
# ---------------------------
# No safety checker / feature extractor is attached to this pipeline.
safety_checker = feature_extractor = None

pipeline_components = dict(
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    unet=unet,
    controlnet=controlnet,
    scheduler=scheduler,
    safety_checker=safety_checker,
    feature_extractor=feature_extractor,
)

# Assemble the custom pipeline from the loaded components and move it to the GPU.
pipe = StableDiffusionDualFlowControlNetPipeline(**pipeline_components).to("cuda")

In [5]:
# Validation data: one text prompt and one (control, flow) pair per video.
# `prompts` and `videos` are parallel lists — entry i of each belongs together.
device = pipe.device

videos = ['Beauty', 'Bosphorus', 'ShakeNDry', 'HoneyBee']
prompts = [
    "A beautiful blonde girl smiling with pink lipstick with black background",
    "A Yacht with a red flag ,sailing in front of the Bosphorus in Istanbul , and bridge with cars is in the background.",
    "A German shepherd shakes off water in the middle of a forest trail",
    "Honeybees hover among blooming purple flowers",
]

local_conditions, flow_conditions = [], []
for video in videos:
    # Frames 0 and 4 plus the two .flo files (0000_0003 and 0004_0003) are
    # handed to the project helper, which returns the two conditioning inputs.
    local, flow = load_controls_and_flows(
        f'data/{video}/images/frame_0000.png',
        f'data/{video}/images/frame_0004.png',
        f'data/{video}/optical_flow/optical_flow_gop_4_raft/flow_0000_0003.flo',
        f'data/{video}/optical_flow_bwd/optical_flow_gop_4_raft/flow_0004_0003.flo',
        size=(512, 512),
        device=device,
        dtype=dtype,
    )
    local_conditions += [local]
    flow_conditions += [flow]

In [6]:
# Drop the notebook-level names of the individual model components.
# NOTE(review): these objects were passed to the pipeline constructor, so
# `pipe` presumably still references them and little memory may be released —
# confirm. Re-running this cell raises NameError because the names are gone.
del vae
del tokenizer
del unet
del controlnet

# Ask PyTorch's CUDA caching allocator to release its unused cached blocks.
torch.cuda.empty_cache()

In [7]:
# Turn on FreeU in the pipeline with the given scaling factors.
pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)

# Sampling settings shared by every call below.
sampling_kwargs = dict(
    height=512,
    width=512,
    num_inference_steps=40,
    guidance_scale=4.5,
    negative_prompt=None,
    num_images_per_prompt=2,
    controlnet_conditioning_scale=1.7,
    guess_mode=False,
    output_type="pil",
    return_dict=True,
)

# Only the first two prompt/condition sets are sampled; each call yields two
# PIL images, so `images` ends up as a list of two 2-image lists.
images = []
for i in range(2):
    out = pipe(
        prompt=prompts[i],
        controlnet_cond=local_conditions[i],  # expected [1,6,512,512] per the original note
        flow_cond=flow_conditions[i],         # expected [1,4,512,512] per the original note
        **sampling_kwargs,
    )
    images.append(out.images)

  with torch.cuda.amp.autocast(enabled=False):
100%|█████████████████████████████████████████████████████████████████████████| 40/40 [00:16<00:00,  2.37it/s]
100%|█████████████████████████████████████████████████████████████████████████| 40/40 [00:14<00:00,  2.73it/s]


In [44]:
import matplotlib.pyplot as plt

# For each video that has pipeline predictions: score the two pipeline samples
# and the stored benchmark prediction against ground-truth frame 3 (PSNR and
# MS-SSIM), save a side-by-side figure, and record everything in `image_logs`.
image_logs = []
spacing = 20
img_size = 512
target_size = (img_size, img_size)
labels = ["Pred 1 - Pipeline", "Pred 2 - Pipeline", "Pred 3 - UniControl"]


def _to_unit_tensor(pil_img):
    """Convert an RGB PIL image into a float tensor of shape [1, 3, H, W] in [0, 1]."""
    chw = torch.from_numpy(np.array(pil_img).transpose(2, 0, 1)).float() / 255.0
    return chw.unsqueeze(0)


# NOTE(review): `ms_ssim` is not imported in any visible cell — presumably it
# comes from the `pytorch_msssim` package; add that import to the imports cell
# after confirming.

# `images` only holds results for the videos that were actually sampled, which
# can be fewer than len(videos). Pair them up instead of indexing by position
# so the loop stops at the last available prediction rather than raising
# IndexError.
if len(images) < len(videos):
    print(f"No pipeline predictions for {videos[len(images):]}; skipping them.")

for video, video_preds in zip(videos, images):
    gt = Image.open(f"data/{video}/images/frame_0003.png").convert("RGB").resize(target_size)

    # Two pipeline samples plus the pre-computed benchmark prediction.
    pred3 = Image.open(f"benchmark_results/preds_gop4_q4/{video}/im00003_pred.png").convert("RGB").resize(target_size)
    preds = [
        video_preds[0].resize(target_size),
        video_preds[1].resize(target_size),
        pred3,
    ]

    # Compute metrics; the ground-truth tensor is the same for every prediction.
    gt_tensor = _to_unit_tensor(gt)
    metrics = {}
    for label, pred in zip(labels, preds):
        pred_tensor = _to_unit_tensor(pred)

        ms_ssim_val = ms_ssim(pred_tensor, gt_tensor, data_range=1.0, size_average=True).item()
        mse = F.mse_loss(pred_tensor, gt_tensor).item()
        # Identical images give mse == 0; report infinite PSNR instead of dividing by zero.
        psnr_val = 10 * np.log10(1.0 / mse) if mse != 0 else float('inf')

        metrics[label] = {"MS-SSIM": ms_ssim_val, "PSNR": psnr_val}

    # --- Plot with matplotlib ---
    all_imgs = preds + [gt]
    titles = [
        f"{label}\nPSNR: {metrics[label]['PSNR']:.3f}, MS-SSIM: {metrics[label]['MS-SSIM']:.3f}"
        for label in labels
    ] + ["Ground Truth"]

    ncols = len(all_imgs)
    fig, axs = plt.subplots(1, ncols, figsize=(4*ncols, 6))

    for ax, img, title in zip(axs, all_imgs, titles):
        ax.imshow(img)
        ax.set_title(title, fontsize=10)
        ax.axis("off")

    plt.tight_layout()
    save_path = f"benchmark_results/{video}_comparison_free_u.svg"
    plt.savefig(save_path, format="svg", dpi=800, bbox_inches="tight")
    # Close each figure after saving so figures do not accumulate across videos.
    plt.close(fig)

    # Log
    image_logs.append({
        "video": video,
        "preds": preds,
        "ground_truth": gt,
        "metrics": metrics
    })


In [None]:
import torch
from PIL import Image
import torchvision.transforms as T
from diffusers import AutoencoderKL

# Encode one frame into VAE latents using a freshly loaded SD-1.5 VAE.
dtype  = torch.float32
device = "cuda"

# --- Load the pretrained VAE ---
base = "stable-diffusion-v1-5/stable-diffusion-v1-5"
vae = AutoencoderKL.from_pretrained(base, subfolder="vae", torch_dtype=dtype)
vae = vae.to(device)

# --- Load and preprocess image ---
path = "data/Beauty/images/frame_0000.png"
image = Image.open(path).convert("RGB")

preprocess_steps = [
    T.Resize((512, 512)),       # resize to 512x512 (original note: match training resolution)
    T.ToTensor(),               # values in [0,1]
    T.Normalize([0.5], [0.5]),  # (x - 0.5) / 0.5 -> [-1,1]
]
transform = T.Compose(preprocess_steps)
img_tensor = transform(image).unsqueeze(0).to(device, dtype=dtype)  # [1,3,512,512]

# --- Encode with VAE ---
# Draw one sample from the encoder's latent distribution and apply the VAE's
# configured scaling factor. No gradients are needed here.
with torch.no_grad():
    posterior = vae.encode(img_tensor).latent_dist
    latents = posterior.sample() * vae.config.scaling_factor  # expected [1,4,64,64] per the original note

print(f"Latent shape: {latents.shape}")



In [None]:
# Run the pipeline starting from the encoded-frame latents instead of fresh noise.
#
# The scaled tensor gets its own name: overwriting `latents` in place would
# rescale the previous cell's output again on every re-run of this cell.
init_latents = latents * pipe.scheduler.init_noise_sigma

# NOTE(review): `prompt`, `sixch` and `flow4` are not defined in any visible
# cell — presumably a single prompt plus its control and flow tensors (e.g.
# prompts[0], local_conditions[0], flow_conditions[0] for the 'Beauty' frame
# encoded above); confirm and define them before running.
out = pipe(
    prompt,
    sixch,
    flow4,
    guidance_scale=4,
    controlnet_conditioning_scale=1.85,
    latents=init_latents,
)
out[0][0]