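# Hedlin et al. — Unsupervised Semantic Correspondence Using Stable Diffusion

Given a keypoint in a source image, this notebook estimates the corresponding point in a target image. It optimizes Stable Diffusion text embeddings so that their attention highlights the source keypoint. It then takes the location where those same embeddings attend most strongly in the target image as the match.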

In [1]:
import numpy as np
import torch
from PIL import Image
import torchvision.transforms.functional as TF

In [2]:
import sys
sys.path.append('hedlin/utils/')
from hedlin.utils.optimize_token import load_ldm, optimize_prompt, run_image_with_tokens_cropped, find_max_pixel_value

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
def find_corresponding_point(ldm, src_img, trg_img, src_kpts, *,
                             upsample_res=512, num_steps=100, noise_level=10,
                             layers=None, device='cpu', lr=1e-3,
                             num_opt_iterations=5, sigma=32, flip_prob=0.5,
                             crop_percent=80, num_iterations=20, img_size=512):
    """Estimate where each source keypoint lands in the target image.

    For every source keypoint:

    - ``num_opt_iterations`` text embeddings are optimized with
      ``optimize_prompt`` on ``src_img``.
    - Each embedding is run on ``trg_img`` with ``run_image_with_tokens_cropped``.
    - The resulting attention maps are averaged (over dim 0 of each layer,
      then over embeddings).
    - Each layer's map is softmax-normalized over all pixels, the layers are
      averaged, and the location of the maximum is taken as the match via
      ``find_max_pixel_value``.

    Args:
        ldm: Model returned by ``load_ldm``.
        src_img: Source image tensor (as produced by ``load_image`` in this
            notebook; other layouts untested — TODO confirm).
        trg_img: Target image tensor, same convention as ``src_img``.
        src_kpts: Keypoints of shape (1, 2, N) in pixel coordinates of an
            ``img_size`` x ``img_size`` image; column ``j`` is keypoint ``j``.

    Keyword Args:
        upsample_res, num_steps, noise_level, layers, device, lr, sigma,
        flip_prob, crop_percent, num_iterations: Forwarded to the hedlin
            helpers. The defaults reproduce the previously hard-coded
            settings; ``layers=None`` means ``[0, 1, 2, 3, 4, 5]``.
        num_opt_iterations: Number of independent embeddings optimized per
            keypoint.
        img_size: Pixel size used to normalize source keypoints and passed to
            ``find_max_pixel_value`` (default 512).

    Returns:
        Tensor shaped like ``src_kpts`` holding the estimated target keypoints.
        It is initialized to -1 and filled column by column.
    """
    layers = [0, 1, 2, 3, 4, 5] if layers is None else list(layers)

    est_keypoints = -1 * torch.ones_like(src_kpts)

    # src_kpts is (1, 2, N): iterate over the N keypoint columns (last axis),
    # not over the 2 coordinates on axis 1.
    for j in range(src_kpts.shape[-1]):
        # Normalize the source keypoint by the image size (loop-invariant below).
        src_point = src_kpts[0, :, j] / img_size

        # Optimize several independent text embeddings for the source point.
        contexts = [
            optimize_prompt(ldm, src_img, src_point, num_steps=num_steps,
                            device=device, layers=layers, lr=lr,
                            upsample_res=upsample_res, noise_level=noise_level,
                            sigma=sigma, flip_prob=flip_prob,
                            crop_percent=crop_percent)
            for _ in range(num_opt_iterations)
        ]

        # Run each embedding on the target and average each layer's maps over dim 0.
        per_context_maps = []
        for context in contexts:
            attn_maps, _ = run_image_with_tokens_cropped(
                ldm, trg_img, context, index=0, upsample_res=upsample_res,
                noise_level=noise_level, layers=layers, device=device,
                crop_percent=crop_percent, num_iterations=num_iterations,
                image_mask=None)
            per_context_maps.append(torch.stack(
                [torch.mean(attn_maps[k], dim=0, keepdim=True)
                 for k in range(attn_maps.shape[0])],
                dim=0))

        # Average across embeddings, then softmax each layer's map over all pixels.
        mean_maps = torch.mean(torch.stack(per_context_maps, dim=0), dim=0)
        probs = torch.softmax(
            mean_maps.reshape(len(layers), upsample_res * upsample_res), dim=-1)
        probs = probs.reshape(len(layers), upsample_res, upsample_res)

        # Average over layers; the location of the maximum is the estimated match.
        heatmap = torch.mean(probs, dim=0)
        max_val = find_max_pixel_value(heatmap, img_size=img_size)
        # NOTE(review): +0.5 presumably shifts a pixel index to the pixel centre — confirm.
        est_keypoints[0, :, j] = max_val + 0.5

    return est_keypoints

In [4]:
# Load Stable Diffusion v1.4 on the CPU.
# NOTE(review): if load_ldm returns a diffusers pipeline, attention slicing
# lowers peak memory at some speed cost — confirm against hedlin's load_ldm.
LDM_DEVICE = 'cpu'
LDM_MODEL_ID = 'CompVis/stable-diffusion-v1-4'
ldm = load_ldm(LDM_DEVICE, LDM_MODEL_ID)
ldm.enable_attention_slicing()

Fetching 30 files: 100%|██████████| 30/30 [00:00<00:00, 167548.76it/s]
`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["id2label"]` will be overriden.
`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["bos_token_id"]` will be overriden.
`text_config_dict` is provided which will be used to initialize `CLIPTextConfig`. The value `text_config["eos_token_id"]` will be overriden.
The config attributes {'scaling_factor': 0.18215} were passed to AutoencoderKL, but are not expected and will be ignored. Please verify your config.json configuration file.


In [5]:
def load_image(img_name, size=512):
    """Load an image file as a float CHW tensor with values in [0, 1].

    The image is converted to RGB and bilinearly resized to ``size`` x ``size``
    pixels. The aspect ratio is not preserved.

    Args:
        img_name: Path of the image file to read.
        size: Output height and width in pixels (default 512, matching the
            512-pixel keypoint scaling used in this notebook).

    Returns:
        Tensor of shape (3, size, size) with values in [0, 1].
    """
    # Context manager closes the underlying file handle deterministically;
    # convert() loads the pixel data before the file is closed.
    with Image.open(img_name) as raw:
        rgb = raw.convert('RGB').resize((size, size), Image.BILINEAR)
    pixels = np.array(rgb).transpose(2, 0, 1)  # HWC -> CHW
    return torch.tensor(pixels) / 255.0
        
# Seed the RNGs so the stochastic prompt optimization (random flips, crops and
# noise are configured in find_corresponding_point) is reproducible.
# NOTE(review): assumes the hedlin helpers draw randomness from torch/numpy — confirm.
torch.manual_seed(0)
np.random.seed(0)

img1 = load_image('hedlin/assets/source_cat.png')
img2 = load_image('hedlin/assets/target_cat.jpeg')

# One source keypoint given as normalized coordinates, laid out as a (2, 1)
# column (one column per keypoint) and scaled to 512-pixel coordinates.
point1 = torch.tensor([[0.4], [0.9]]) * 512.0

# NOTE(review): the captured output shows this call was interrupted; with the
# default settings it is very slow on CPU.
point2 = find_corresponding_point(ldm, img1, img2, point1.unsqueeze(0))

  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


KeyboardInterrupt: 