In [None]:
import os
import torch
import numpy as np
from PIL import Image as PILImage
from tqdm import tqdm
import torch.nn as nn
import torch.nn.functional as F
import abc
from diffusers import StableVideoDiffusionPipeline, DDIMScheduler

# Run on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
class AttentionControl(abc.ABC):
    """Abstract target that patched attention layers report their maps to.

    Subclasses implement :meth:`forward`. Calling a controller runs ``forward``
    and hands back the ``'attn'`` entry of the dict it returns.

    Attributes set in ``__init__``:
        cur_step: step counter, zeroed by :meth:`reset`.
        num_att_layers: number of patched layers; -1 until registration sets it.
        cur_att_layer: layer counter, zeroed by :meth:`reset`.
    """

    def __init__(self):
        self.cur_step, self.num_att_layers, self.cur_att_layer = 0, -1, 0

    def step_callback(self, x_t):
        """Per-step hook on the latent; the base implementation returns it unchanged."""
        return x_t

    def between_steps(self):
        """Hook run between steps; a no-op in the base class."""
        return None

    @property
    def num_uncond_att_layers(self):
        """Number of unconditional attention layers; always 0 here."""
        return 0

    @abc.abstractmethod
    def forward(self, attn_dict, is_cross: bool, place_in_unet: str):
        """Process ``attn_dict`` and return a dict that still has an ``'attn'`` key."""
        raise NotImplementedError

    def __call__(self, attn_dict, is_cross: bool, place_in_unet: str):
        processed = self.forward(attn_dict, is_cross, place_in_unet)
        return processed['attn']

    def reset(self):
        """Zero the step and layer counters (``num_att_layers`` is left untouched)."""
        self.cur_step = self.cur_att_layer = 0

In [None]:
class AttentionStore(AttentionControl):
    """Controller that records every attention map passed to it, in call order."""

    @staticmethod
    def get_empty_store():
        """Return a fresh store: a dict with an empty ``"attn"`` list."""
        store = dict(attn=[])
        return store

    def __init__(self):
        super().__init__()
        self.step_store = self.get_empty_store()

    def forward(self, attn_dict, is_cross: bool, place_in_unet: str):
        """Append ``attn_dict['attn']`` to the store and return the dict unchanged."""
        recorded = self.step_store["attn"]
        recorded.append(attn_dict['attn'])
        return attn_dict

    def reset(self):
        """Zero the counters and drop everything recorded so far."""
        super().reset()
        self.step_store = self.get_empty_store()

In [None]:
def register_attention_control(model, controller: AttentionControl, feature_upsample_res=256):
    """Patch ``forward`` of every ``CrossAttention`` module under ``model``'s "up" children.

    The patched forward computes scaled dot-product attention and returns the
    projected output, as before. It additionally captures maps for the controller
    under these conditions:
    - the call is cross-attention;
    - the controller has a ``step_store``;
    - the query sequence length is at most ``feature_upsample_res**2`` and tiles an
      ``H x W`` grid;
    - fewer than 4 maps are stored so far.

    When all hold, the query input is bicubically upsampled to
    ``feature_upsample_res x feature_upsample_res``. The attention probabilities are
    then recomputed at that resolution and handed to ``controller``. Only these
    upsampled maps are reported, so every stored map has the same shape and can be
    stacked. The module output always comes from the original-resolution attention.

    Args:
        model: module whose named children containing "up" are searched.
        controller: attention controller. ``None`` installs a pass-through that
            stores nothing.
        feature_upsample_res: side length of the upsampled query grid.

    Returns:
        The installed controller, with ``num_att_layers`` set to the number of
        patched modules.

    Raises:
        AssertionError: if no ``CrossAttention`` module was found.
    """
    if controller is None:
        # AttentionControl is abstract and cannot be instantiated directly; use a
        # minimal concrete subclass that returns the attention unchanged.
        class _PassThroughControl(AttentionControl):
            def forward(self, attn_dict, is_cross: bool, place_in_unet: str):
                return attn_dict

        controller = _PassThroughControl()

    def ca_forward(self, place_in_unet):
        # Use only the output Linear when to_out is a ModuleList.
        to_out = self.to_out
        if isinstance(to_out, nn.ModuleList):
            to_out = self.to_out[0]

        def forward(x, context=None, mask=None):
            batch_size, sequence_length, dim = x.shape
            h = self.heads
            is_cross = context is not None
            context = context if is_cross else x
            # Project, then fold the heads into the batch dimension.
            q = self.reshape_heads_to_batch_dim(self.to_q(x))
            k = self.reshape_heads_to_batch_dim(self.to_k(context))
            v = self.reshape_heads_to_batch_dim(self.to_v(context))
            sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
            max_neg_value = -torch.finfo(sim.dtype).max
            if mask is not None:
                mask = mask.reshape(batch_size, -1).bool()
                # diffusers' reshape_heads_to_batch_dim folds heads batch-major
                # (row = sample * heads + head), so each sample's mask row is
                # repeated `h` times consecutively rather than tiled.
                mask = mask[:, None, :].repeat_interleave(h, dim=0)
                sim = sim.masked_fill(~mask, max_neg_value)
            attn = torch.softmax(sim, dim=-1)
            out = torch.matmul(attn, v)

            store = getattr(controller, "step_store", None)
            capture = (
                is_cross
                and store is not None
                and sequence_length <= feature_upsample_res ** 2
                and len(store["attn"]) < 4
            )
            if capture:
                spatial = int(sequence_length ** 0.5)
                if spatial * spatial == sequence_length:
                    H = W = spatial
                else:
                    # Non-square: fall back to latent dims recorded by the pre-forward hook.
                    H = getattr(controller, "latent_h", spatial)
                    W = getattr(controller, "latent_w", spatial)
                # Only a sequence that exactly tiles an H x W grid can be reshaped/upsampled.
                if H * W == sequence_length:
                    x_grid = x.reshape(batch_size, H, W, dim).permute(0, 3, 1, 2)
                    x_grid = F.interpolate(x_grid, size=(feature_upsample_res, feature_upsample_res),
                                           mode="bicubic", align_corners=False)
                    x_up = x_grid.permute(0, 2, 3, 1).reshape(batch_size, -1, dim)
                    # Recompute attention for the upsampled queries against the same keys.
                    q_up = self.reshape_heads_to_batch_dim(self.to_q(x_up))
                    sim_up = torch.einsum("b i d, b j d -> b i j", q_up, k) * self.scale
                    if mask is not None:
                        sim_up = sim_up.masked_fill(~mask, max_neg_value)
                    attn_up = torch.softmax(sim_up, dim=-1).clone()
                    controller({"attn": attn_up}, is_cross, place_in_unet)

            out = self.reshape_batch_dim_to_heads(out)
            return to_out(out)

        return forward

    # Depth-first walk; every CrossAttention module gets the patched forward.
    def register_recurse(net, count, place_in_unet):
        if net.__class__.__name__ == "CrossAttention":
            net.forward = ca_forward(net, place_in_unet)
            return count + 1
        elif hasattr(net, "children"):
            for child in net.children():
                count = register_recurse(child, count, place_in_unet)
        return count

    cross_att_count = 0
    for name, module in model.named_children():
        if "up" in name:  # only the up-sampling blocks are instrumented
            cross_att_count += register_recurse(module, 0, "up")
    controller.num_att_layers = cross_att_count
    assert cross_att_count != 0, "No cross-attention layers found. Please check model or diffusers version."
    return controller

In [1]:
# Hook function to capture input shapes before each forward pass of the U-Net
def _unet_pre_forward_hook(module, inputs, controllers_dict, current_device, feature_upsample_res):
    """U-Net forward pre-hook: record latent geometry on a controller and re-patch attention.

    Args:
        module: the U-Net module about to run. Under DataParallel this may be a replica.
        inputs: positional forward args; ``inputs[0]`` is the latent.
        controllers_dict: mapping ``torch.device`` -> attention controller.
        current_device: key used when the latent's own device has no controller.
        feature_upsample_res: forwarded to ``register_attention_control``.

    Returns:
        None, so the forward inputs are left unchanged.

    Raises:
        KeyError: if ``current_device`` is not in ``controllers_dict``.
    """
    latent = inputs[0]
    # When one hook is registered per GPU, every hook runs on every call. Looking
    # the controller up by the shard's actual device keeps maps in the store for
    # that device, rather than in whichever hook happened to run last.
    controller = controllers_dict.get(latent.device, controllers_dict[current_device])
    if latent.dim() == 5:
        # (batch, frames, channels, height, width)
        controller.frames = latent.shape[1]
        controller.latent_h, controller.latent_w = latent.shape[3], latent.shape[4]
    elif latent.dim() == 4:
        # (batch, channels, height, width): a single frame
        controller.frames = 1
        controller.latent_h, controller.latent_w = latent.shape[2], latent.shape[3]
    # Patching on every call binds the attention closures to `module`'s own
    # submodules (e.g. those of a DataParallel replica).
    register_attention_control(module, controller, feature_upsample_res=feature_upsample_res)
    return None

In [None]:
# Prepare the Stable Video Diffusion pipeline and attention controllers for each GPU
def load_ldm(model_name="stabilityai/stable-video-diffusion-img2vid", feature_upsample_res=256):
    """Load a frozen fp16 SVD pipeline and attach attention-capturing controllers.

    Args:
        model_name: model id passed to ``StableVideoDiffusionPipeline.from_pretrained``.
        feature_upsample_res: forwarded to the U-Net pre-forward attention hook.

    Returns:
        ``(ldm, controllers)``:
        - ``ldm`` is the pipeline moved to ``device``. Its U-Net and VAE are wrapped
          in ``nn.DataParallel`` when more than one CUDA device is visible.
        - ``controllers`` maps each ``torch.device`` to its ``AttentionStore``.
    """
    # NOTE(review): this replaces the checkpoint's own scheduler with DDIM — confirm intended.
    scheduler = DDIMScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
                              clip_sample=False, set_alpha_to_one=False)
    scheduler.set_timesteps(50)
    ldm = StableVideoDiffusionPipeline.from_pretrained(
        model_name, torch_dtype=torch.float16, variant="fp16", scheduler=scheduler
    )
    ldm = ldm.to(device)

    multi_gpu = device != "cpu" and torch.cuda.device_count() > 1
    if multi_gpu:
        ldm.unet = nn.DataParallel(ldm.unet)
        ldm.vae = nn.DataParallel(ldm.vae)

    # Freeze every network component that is present; only the keypoint embeddings
    # are optimized. A missing component yields None, not a placeholder.
    # stable-video-diffusion has no text encoder; image_encoder is frozen if it exists.
    for component_name in ("vae", "unet", "text_encoder", "image_encoder"):
        component = getattr(ldm, component_name, None)
        if component is not None:
            for param in component.parameters():
                param.requires_grad = False

    # One AttentionStore per device; each gets a pre-forward hook on the U-Net.
    if multi_gpu:
        hook_target = ldm.unet.module
        devices = [torch.device("cuda", dev_id) for dev_id in ldm.unet.device_ids]
    else:
        hook_target = ldm.unet
        devices = [torch.device(device)]
    controllers = {}
    for dev in devices:
        controllers[dev] = AttentionStore()
        hook_target.register_forward_pre_hook(
            lambda module, inp, dev=dev: _unet_pre_forward_hook(module, inp, controllers, dev, feature_upsample_res)
        )
    return ldm, controllers

In [None]:
# Load the stable video diffusion model and its attention controllers once;
# `ldm` and `controllers` are used by every later cell.
ldm, controllers = load_ldm()
for _label, _value in (
    ("Model loaded on device:", device),
    ("Number of GPUs:", torch.cuda.device_count()),
):
    print(_label, _value)

In [None]:
# Define number of keypoints to find
num_keypoints = 10  # number of learned query tokens, one per keypoint
num_optimization_steps = 500  # optimization iterations (increase for complex data)
batch_size = 1  # images per batch (use >1 for larger datasets if GPU memory allows)
augment_degrees = 30  # max |rotation| in degrees for the equivariance augmentation
augment_scale = (0.9, 1.1)  # (min, max) isotropic zoom factor
augment_translate = (0.1, 0.1)  # max shift, as a fraction of the image size
# Seed the RNGs the notebook draws from. torch covers the embedding init and
# diffusion noise; numpy covers the augmentation and keypoint sampling. This makes
# runs reproducible.
seed = 0
np.random.seed(seed)
torch.manual_seed(seed)

In [None]:
# Prepare image dataset
class CustomDataset(torch.utils.data.Dataset):
    """Flat directory of images, each resized to a square float tensor.

    Args:
        data_root: directory scanned (non-recursively) for files ending in
            ``.png``, ``.jpg`` or ``.jpeg`` (case-insensitive), ordered by sorted path.
        image_size: side length every image is resized to.
    """

    def __init__(self, data_root, image_size=512):
        self.data_root = data_root
        self.image_paths = sorted([os.path.join(data_root, fname) for fname in os.listdir(data_root)
                                   if fname.lower().endswith(('.png', '.jpg', '.jpeg'))])
        self.image_size = image_size

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Return ``{"img": tensor}`` of shape (3, image_size, image_size) with values in [0, 1]."""
        # The context manager closes the file handle. convert() returns a new,
        # fully loaded image, so it stays usable after the file is closed.
        with PILImage.open(self.image_paths[idx]) as src:
            img = src.convert("RGB")
        img = img.resize((self.image_size, self.image_size))
        arr = np.array(img).astype(np.float32) / 255.0  # HWC in [0, 1]
        return {"img": torch.from_numpy(arr).permute(2, 0, 1)}

In [None]:
# Function to initialize random context embedding (num_words tokens of dimension 1024)
def init_random_noise(device, num_words=1000, dim=1024):
    """Return a standard-normal tensor of shape (1, num_words, dim) on ``device``."""
    shape = (1, num_words, dim)
    noise = torch.randn(*shape, device=device)
    return noise

In [None]:
# Random affine transform with inverse (for equivariance loss)
class RandomAffineWithInverse:
    """Random rotation/scale/translation that can be undone with :meth:`inverse`.

    Args:
        degrees: rotation angle drawn uniformly from [-degrees, degrees].
        scale: (min, max) range for the isotropic scale factor.
        translate: (fx, fy) maximum horizontal / vertical shift as a fraction of
            image width / height.
    """

    def __init__(self, degrees=30, scale=(0.9, 1.1), translate=(0.1, 0.1)):
        self.degrees = degrees
        self.scale = scale
        self.translate = translate

    def __call__(self, img_tensor):
        """Warp ``img_tensor`` (..., H, W) with freshly sampled parameters.

        The parameters are stored in ``self.last_params`` for :meth:`inverse`.
        """
        import torchvision.transforms.functional as TF
        angle = np.random.uniform(-self.degrees, self.degrees)
        scale_factor = np.random.uniform(self.scale[0], self.scale[1])
        # torchvision's translate is (horizontal, vertical). The x shift scales with
        # the width (last dim) and the y shift with the height (second-to-last dim).
        max_dx = self.translate[0] * img_tensor.shape[-1]
        max_dy = self.translate[1] * img_tensor.shape[-2]
        translations = (np.random.uniform(-max_dx, max_dx), np.random.uniform(-max_dy, max_dy))
        img = TF.affine(img_tensor, angle=angle, translate=translations, scale=scale_factor, shear=0)
        self.last_params = (angle, translations, scale_factor)
        return img

    def inverse(self, img_tensor):
        """Apply the inverse of the most recent :meth:`__call__` warp.

        ``TF.affine`` maps a point ``p`` (relative to the image centre) to
        ``t + s * R(a) p``. The inverse is therefore ``(1/s) R(-a) p - (1/s) R(-a) t``:
        rotate by ``-a``, scale by ``1/s``, and translate by ``-(1/s) R(-a) t``.
        Content that the forward warp pushed out of frame is not recovered.

        Raises:
            AttributeError: if called before any forward transform.
        """
        import torchvision.transforms.functional as TF
        angle, (tx, ty), scale_factor = self.last_params
        inv_scale = 1.0 / scale_factor if scale_factor != 0 else 1.0
        theta = np.deg2rad(angle)
        cos_a, sin_a = float(np.cos(theta)), float(np.sin(theta))
        # R(-a) = [[cos a, sin a], [-sin a, cos a]] in torchvision's (x right, y down) frame.
        inv_translations = (
            -inv_scale * (cos_a * tx + sin_a * ty),
            -inv_scale * (-sin_a * tx + cos_a * ty),
        )
        return TF.affine(img_tensor, angle=-float(angle), translate=inv_translations, scale=inv_scale, shear=0)

In [None]:
# Loss functions
def sharpening_loss(attn_map, device="cuda", sigma=1.0):
    """Mean squared distance of the given attention values from 1.

    Pushes the attention at the selected keypoint locations toward 1, which
    sharpens the peaks.

    Args:
        attn_map: attention values at candidate keypoint locations (any shape).
        device, sigma: unused; kept so existing calls keep working.

    Returns:
        Scalar ``F.mse_loss(attn_map, 1)``.
    """
    target = torch.full_like(attn_map, 1.0)
    return F.mse_loss(attn_map, target)

In [None]:
def equivariance_loss(attn_map, attn_map_transformed, transform, index):
    """MSE between an attention map and its augmented counterpart mapped back.

    Args:
        attn_map: stack of maps computed on the original image, e.g. (N, H, W).
        attn_map_transformed: same-shaped stack computed on the augmented image.
        transform: object whose ``inverse`` undoes the augmentation.
        index: which entry of the stacks to compare.

    Returns:
        Scalar MSE between ``attn_map[index]`` and
        ``transform.inverse(attn_map_transformed)[index]``.
    """
    # Index both sides so the shapes match. Comparing the whole stack against a
    # single entry would silently broadcast (and warn) instead of comparing
    # corresponding maps.
    attn_map_orig = attn_map[index]
    attn_map_trans_inv = transform.inverse(attn_map_transformed)[index]
    return F.mse_loss(attn_map_orig, attn_map_trans_inv)

In [None]:
# Utility: find top k peak candidates in attention map by sampling points proportional to attn intensity
def find_top_k_gaussian(attn_map, num_samples, sigma=16, num_subjects=1):
    """Sample candidate peak locations with probability proportional to attention.

    Args:
        attn_map: tensor of shape (B, H, W). Any (B, ...) shape is accepted and
            flattened per batch item.
        num_samples: number of flat indices drawn per batch item, with replacement.
        sigma, num_subjects: unused; kept for call compatibility.

    Returns:
        Long tensor of shape (B, num_samples) holding flat indices into each
        item's map, on ``attn_map``'s device.
    """
    B = attn_map.shape[0]
    # Two steps before handing the data to numpy:
    # - detach: numpy cannot convert a grad-tracking tensor.
    # - float64: np.random.choice's sum-to-1 tolerance holds even for fp16 maps,
    #   and the 1e-8 floor does not round to 0.
    flat_attn = attn_map.detach().reshape(B, -1).double() + 1e-8
    probs = (flat_attn / flat_attn.sum(dim=1, keepdim=True)).cpu().numpy()
    indices = []
    for b in range(B):
        idx = np.random.choice(flat_attn.shape[1], size=num_samples, p=probs[b])
        indices.append(torch.from_numpy(idx).long().to(attn_map.device))
    return torch.stack(indices)

In [None]:
# Utility: furthest point sampling to choose top_k distinct points from candidates
def furthest_point_sampling(attn_map, top_k, candidate_indices):
    """Pick up to ``top_k`` distinct candidate locations with the highest attention.

    Despite the name, this does not do geometric furthest-point sampling. It drops
    duplicate candidates and keeps the highest-valued ones.

    Args:
        attn_map: attention map of any shape; indexed after flattening.
        top_k: maximum number of locations to return.
        candidate_indices: 1-D long tensor of flat indices into ``attn_map``. It may
            contain repeats, e.g. from sampling with replacement.

    Returns:
        1-D long tensor of at most ``top_k`` distinct flat indices, ordered by
        decreasing attention value.
    """
    flat_attn = attn_map.reshape(-1)
    # Sampling with replacement repeats locations; keep each location once so
    # the returned points are distinct.
    unique_candidates = torch.unique(candidate_indices)
    vals = flat_attn[unique_candidates]
    # topk raises if k exceeds the number of available candidates.
    k = min(top_k, unique_candidates.numel())
    topk = torch.topk(vals, k)
    return unique_candidates[topk.indices]

In [None]:
# Prepare dataset and dataloader
image_dir = "./images"  # directory containing input images/frames
dataset = CustomDataset(data_root=image_dir, image_size=512)
# With drop_last=True, fewer images than batch_size yields zero batches. The
# training loop's `next(...)` on a fresh iterator would then raise a bare
# StopIteration, so fail here with a clear message instead.
if len(dataset) < batch_size:
    raise ValueError(
        f"Found {len(dataset)} image(s) in {image_dir!r}; need at least batch_size={batch_size}."
    )
dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)
dataloader_iter = iter(dataloader)

In [None]:
# Learnable keypoint queries: one 1024-d token per keypoint, optimized directly
# (requires_grad_ flags the tensor in place and returns it).
context = init_random_noise(device, num_words=num_keypoints, dim=1024).requires_grad_(True)

In [None]:
# Only the keypoint embeddings are trainable; the diffusion model stays frozen.
optimizer = torch.optim.Adam(params=[context], lr=5e-3)
# Augmentation that remembers its last parameters so the equivariance loss can undo it.
invertible_transform = RandomAffineWithInverse(
    degrees=augment_degrees,
    scale=augment_scale,
    translate=augment_translate,
)

In [None]:
# Optimization loop
def _unwrap_parallel(module):
    """Return the wrapped module if ``module`` is an ``nn.DataParallel``, else ``module``."""
    return module.module if isinstance(module, nn.DataParallel) else module


def _encode_latent(img):
    """Encode a [0, 1] image batch to scaled VAE latents, without tracking gradients."""
    vae = _unwrap_parallel(ldm.vae)
    with torch.no_grad():
        # The VAE takes [-1, 1] input in its own dtype (the pipeline is loaded in fp16).
        posterior = vae.encode((img * 2 - 1).to(vae.dtype))["latent_dist"]
        # Encoding multiplies by the scaling factor; dividing is the decode-side inverse.
        return posterior.mean * getattr(vae.config, "scaling_factor", 0.18215)


def _first_token_attention_maps():
    """Return each controller's stored maps as a (1, S, S) map of context token 0.

    The maps are averaged over layers and heads. Every controller is reset.
    """
    collected = []
    for ctrl in controllers.values():
        stored = ctrl.step_store["attn"]
        if len(stored) > 0:
            maps = torch.stack(stored, dim=0)  # [layers, batch*heads, seq, context]
            attn_map = maps.mean(dim=(0, 1))  # average over layers and heads -> [seq, context]
            attn_map = attn_map[:, :1]  # keep only context token 0
            collected.append(attn_map.reshape(1, int(attn_map.shape[0] ** 0.5), -1))
        ctrl.reset()
    return collected


def _attention_maps_for(img, t):
    """Noise ``img``'s latent to timestep ``t``, run the U-Net once, return captured maps.

    The U-Net call is deliberately NOT under ``torch.no_grad()``. The losses
    back-propagate through the captured cross-attention maps into ``context``.
    """
    latent = _encode_latent(img)
    noisy_latent = ldm.scheduler.add_noise(latent, torch.randn_like(latent), t)
    n = noisy_latent.shape[0]
    added_time_ids = ldm._get_add_time_ids(
        7, 127, 0.02, latent.dtype, batch_size=n, num_videos_per_prompt=1, do_classifier_free_guidance=False
    ).to(noisy_latent.device)
    # Cast the fp32 leaf `context` to the U-Net's dtype; the cast is differentiable.
    ldm.unet(noisy_latent, t.repeat(n), encoder_hidden_states=context.to(latent.dtype),
             added_time_ids=added_time_ids)
    return _first_token_attention_maps()


# NOTE(review): the pre-forward hook anticipates 5-D (batch, frames, channels, H, W)
# latents, but this loop feeds 4-D image latents — confirm against the SVD U-Net in use.
noise_level = -1  # index into scheduler.timesteps; -1 = last (least noisy) timestep
for step in tqdm(range(num_optimization_steps)):
    try:
        batch = next(dataloader_iter)
    except StopIteration:
        dataloader_iter = iter(dataloader)
        batch = next(dataloader_iter)
    image = batch["img"].to(device)  # [batch, 3, H, W], values in [0, 1]
    transformed_image = invertible_transform(image.clone())
    # Same timestep (fresh noise) for the original and the augmented image.
    t = ldm.scheduler.timesteps[noise_level]
    attn_maps_list = _attention_maps_for(image, t)
    attn_maps_trans_list = _attention_maps_for(transformed_image, t)
    if not attn_maps_list or not attn_maps_trans_list:
        continue  # no maps were captured for this batch
    # batch_size=1 and the first controller holding maps.
    attn_map = attn_maps_list[0][0]  # [H, W]
    attn_map_t = attn_maps_trans_list[0][0]  # [H, W]
    # Candidate sampling only picks indices, so it works on a detached copy.
    candidates = find_top_k_gaussian(attn_map.detach().unsqueeze(0), num_samples=50, sigma=16)
    top_indices = furthest_point_sampling(attn_map, num_keypoints, candidates[0])
    _sharpening_loss = sharpening_loss(attn_map.reshape(-1)[top_indices])
    _equiv_loss = equivariance_loss(attn_map.reshape(1, *attn_map.shape),
                                    attn_map_t.reshape(1, *attn_map_t.shape),
                                    invertible_transform, 0)
    loss = _sharpening_loss + _equiv_loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()


In [None]:
# After optimization, use the learned embeddings to get keypoint positions on each frame
indices_by_frame = []
image_paths = sorted(
    os.path.join(image_dir, fname)
    for fname in os.listdir(image_dir)
    if fname.lower().endswith(('.png', '.jpg', '.jpeg'))
)
image_h = 512  # frames are resized to 512 x 512
image_w = image_h

In [None]:
# Locate each learned token's attention peak on every frame. A single U-Net pass
# per frame is enough: the stored maps cover every context token, so each
# token's map is sliced from that pass instead of re-running the model per token.
indices_by_frame = []  # reset so re-running this cell does not append duplicates
vae_module = ldm.vae.module if isinstance(ldm.vae, nn.DataParallel) else ldm.vae
latent_scale = getattr(vae_module.config, "scaling_factor", 0.18215)
t = ldm.scheduler.timesteps[-1]  # final (least noisy) step
for img_path in image_paths:
    # The context manager closes the file; the converted/resized image is an independent copy.
    with PILImage.open(img_path) as src:
        img = src.convert("RGB").resize((image_w, image_h))
    img_tensor = torch.from_numpy(np.array(img).astype(np.float32) / 255.0).permute(2, 0, 1).unsqueeze(0).to(device)
    for ctrl in controllers.values():
        ctrl.reset()
    with torch.no_grad():
        # Encoding multiplies by the VAE scaling factor (dividing is the decode-side
        # inverse). The input is cast to the VAE's dtype (the pipeline is fp16).
        latent = vae_module.encode((img_tensor * 2 - 1).to(vae_module.dtype))["latent_dist"].mean * latent_scale
        noise = torch.randn_like(latent)
        noisy_latent = ldm.scheduler.add_noise(latent, noise, t)
        added_time_ids = ldm._get_add_time_ids(
            7, 127, 0.02, latent.dtype, batch_size=1, num_videos_per_prompt=1, do_classifier_free_guidance=False
        ).to(noisy_latent.device)
        ldm.unet(noisy_latent, t.repeat(noisy_latent.shape[0]), encoder_hidden_states=context.to(latent.dtype),
                 added_time_ids=added_time_ids)
    # Average the stored maps over layers and heads, then lay them out as (H, W, context_length).
    attn = None
    for ctrl in controllers.values():
        if len(ctrl.step_store["attn"]) > 0:
            maps = torch.stack(ctrl.step_store["attn"], dim=0)
            attn = maps.mean(dim=(0, 1))
            attn = attn.reshape(int(attn.shape[0] ** 0.5), -1, attn.shape[1])
            break
    for ctrl in controllers.values():
        ctrl.reset()
    frame_indices = []
    for token_idx in range(num_keypoints):
        if attn is None or token_idx >= attn.shape[2]:
            frame_indices.append(None)
            continue
        attn_map_token = attn[..., token_idx].unsqueeze(0).unsqueeze(0)  # [1, 1, H, W]
        attn_map_token = F.interpolate(attn_map_token, size=(image_h, image_w),
                                       mode="bicubic", align_corners=False).squeeze()
        # Flat index of the peak in the image_h x image_w map (row * image_w + col).
        frame_indices.append(torch.argmax(attn_map_token).item())
    indices_by_frame.append(frame_indices)

In [None]:
# Report the peak index of every keypoint token, one line per frame.
report_lines = ["Keypoint indices [frame][point]:"]
report_lines += [f"Frame {i}: {frame_indices}" for i, frame_indices in enumerate(indices_by_frame)]
print("\n".join(report_lines))