In [None]:
import os
import sys
import cv2
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
import numpy as np
import warnings

from config import RunConfig
from matplotlib import pyplot as plt
from scipy import stats, ndimage
from skimage.feature import peak_local_max
from pycocotools import mask
from PIL import Image
from tqdm import tqdm
from collections import defaultdict
from torchvision.transforms import ToTensor

In [None]:
from typing import List, Dict, Union
from utils import vis_utils
from diffusers import DDIMScheduler, DDIMInverseScheduler
from pipeline_scribble_guide import ScribbleGuidePipeline, AttentionStore
from transformers import BlipForConditionalGeneration, BlipProcessor

In [None]:
warnings.filterwarnings('ignore')

In [None]:
def show_cam_on_image(img, mask):
    """Overlay a JET colour-mapped heatmap on an image and renormalise the sum.

    Args:
        img: image array; it is only cast to float32 and added to the heatmap,
            so it presumably holds values in [0, 1] — TODO confirm with callers.
        mask: per-pixel weights; ``255 * mask`` is cast to uint8 before the
            colour map is applied, so values are presumably in [0, 1].

    Returns:
        float32 array ``(heatmap + img) / max(heatmap + img)``.
    """
    # cv2.applyColorMap returns a uint8 colour image (BGR channel order in OpenCV).
    colored = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
    overlay = np.float32(colored) / 255 + np.float32(img)
    peak = np.max(overlay)
    return overlay / peak

In [None]:
def get_segmentation_image(segmentations, image_to_file, segm_idx, image_idx):
    """Decode one COCO-style annotation into a mask image plus its source photo.

    Args:
        segmentations: COCO-style dict with an ``"annotations"`` list.
        image_to_file: mapping from ``image_id`` to a local image path.
        segm_idx: index of the annotation whose ``"segmentation"`` is decoded.
        image_idx: index of the annotation whose ``"image_id"`` selects the
            photo (callers in this notebook pass the same index for both).

    Returns:
        ``(segm, image)`` as PIL images: the decoded mask scaled to 0/255 and
        the photo converted to RGB.
    """
    anns = segmentations["annotations"]
    segm = anns[segm_idx]["segmentation"]
    image_path = image_to_file[anns[image_idx]["image_id"]]

    # Context manager releases the file handle deterministically.
    with Image.open(image_path) as src:
        image = np.array(src.convert("RGB"))
    h, w = image.shape[:2]

    if isinstance(segm, dict) and isinstance(segm.get("counts"), (str, bytes)):
        # Already a compressed RLE: frPyObjects only accepts polygons, boxes
        # and *uncompressed* RLE, so pass it through as-is.
        rles = [segm]
    else:
        rles = mask.frPyObjects(segm, h, w)
        # A single uncompressed RLE comes back as a dict; merge expects a list.
        if isinstance(rles, dict):
            rles = [rles]

    rle = mask.merge(rles)
    segm = mask.decode(rle) * 255

    return Image.fromarray(segm), Image.fromarray(image)

In [None]:
def blend_images(init_image, mask_image):
    """Composite a mask over an image, keeping ``init_image`` where the mask is white.

    Every pixel whose mask RGB value is exactly (255, 255, 255) takes the
    corresponding ``init_image`` pixel; every other pixel takes the mask pixel.

    Args:
        init_image: PIL image providing the pixels kept under white mask pixels.
        mask_image: PIL image of the same size, in any mode PIL converts to RGBA.

    Returns:
        The blended result as an RGB PIL image.

    Raises:
        AssertionError: if the two images differ in size.
    """
    assert init_image.size == mask_image.size, "Images must be the same size!"

    # Convert through RGBA as before, then keep only the colour channels:
    # the final RGBA->RGB conversion discarded alpha anyway.
    init_rgb = np.asarray(init_image.convert("RGBA"))[..., :3]
    mask_rgb = np.asarray(mask_image.convert("RGBA"))[..., :3]

    # Vectorised replacement for the per-pixel getpixel/putpixel loop.
    keep_init = np.all(mask_rgb == 255, axis=-1, keepdims=True)
    blended = np.where(keep_init, init_rgb, mask_rgb)

    return Image.fromarray(blended)

## Example: explore the shape-prompt validation annotations

In [None]:
dataset_dir = "../datasets"

# Read the COCO-style validation annotations; `with` closes the file handle
# instead of leaking it for the lifetime of the kernel.
with open(os.path.join(dataset_dir, "shape_prompts", "val.json")) as shape_prompts_file:
    shape_prompts = json.load(shape_prompts_file)

annotations = shape_prompts['annotations']

In [None]:
# Number of annotations loaded from val.json.
len(annotations)

In [None]:
# Fields present on an annotation record (inspected on the first one).
annotations[0].keys()

In [None]:
# Category id of every annotation, then the distinct ids among them.
annotation_id_list = [ann['category_id'] for ann in annotations]
annotation_id_set_list = list({*annotation_id_list})

In [None]:
# Number of distinct category ids that occur in the annotations.
len(annotation_id_set_list)

In [None]:
# Keep only the category records whose id occurs in at least one annotation.
category_list = list(filter(lambda cat: cat['id'] in annotation_id_set_list,
                            shape_prompts['categories']))

In [None]:
# Category records from shape_prompts['categories'] that occur in the annotations.
category_list

In [None]:
# Two-way lookup between category ids and category names.
category_id_to_name = {cat['id']: cat['name'] for cat in category_list}
category_name_to_id = {cat['name']: cat['id'] for cat in category_list}

In [None]:
# Show both lookup tables, one per line.
print(category_id_to_name, category_name_to_id, sep='\n')

In [None]:
# Number of annotations per category name.
category_count = defaultdict(int)
for category_id in (ann['category_id'] for ann in annotations):
    category_count[category_id_to_name[category_id]] += 1

In [None]:
# Annotation count per category name.
category_count

In [None]:
# image id -> local path, obtained by swapping the COCO URL host for dataset_dir.
image_to_file = {
    img["id"]: img["coco_url"].replace("http://images.cocodataset.org", dataset_dir)
    for img in shape_prompts["images"]
}

In [None]:
# Disabled one-off export: for each annotation, resize its mask and photo to
# 512x512, blend them with blend_images, and save the result as
# ./dataset/category/<category_name>/<index>.jpg and
# ./dataset/masked/<category_name>_<index>.jpg. Uncomment to regenerate.
# for index in range(len(annotations)):
#     annotation = annotations[index]
#     category_name = category_id_to_name[annotation['category_id']]
    
#     mask_image, init_image = get_segmentation_image(shape_prompts, image_to_file, index, index)
#     mask_image, init_image = mask_image.resize((512, 512)), init_image.resize((512, 512))
    
#     blended_image = blend_images(init_image, mask_image)
    
#     save_category_image_path = f'./dataset/category/{category_name}'
#     save_masked_image_path = f'./dataset/masked'
    
#     if not os.path.exists(save_category_image_path):
#         os.makedirs(save_category_image_path)
    
#     if not os.path.exists(save_masked_image_path):
#         os.makedirs(save_masked_image_path)
    
#     blended_image.save(f'{save_category_image_path}/{index}.jpg')
#     blended_image.save(f'{save_masked_image_path}/{category_name}_{index}.jpg')

In [None]:
# NOTE(review): none of these constants is referenced in the visible cells.
# run_and_display builds RunConfig without guidance_scale or
# num_inference_steps, so RunConfig's own defaults are what reach the pipeline.
NUM_DIFFUSION_STEPS = 50
GUIDANCE_SCALE = 7.5
MAX_NUM_WORDS = 77  # presumably mirrors the text encoder's 77-token limit — confirm

## Scribble Guidance

In [None]:
# BLIP image captioner; both objects are handed to ScribbleGuidePipeline below.
# torch_dtype is left at its default (a float16 load was tried and disabled).
captioner_id = "Salesforce/blip-image-captioning-base"
processor = BlipProcessor.from_pretrained(captioner_id)
model = BlipForConditionalGeneration.from_pretrained(
    captioner_id,
    low_cpu_mem_usage=True,
)

In [None]:
sd_model_ckpt = "CompVis/stable-diffusion-v1-4"

# Prefer the first GPU and fall back to CPU so the pipeline still loads without CUDA.
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')
stable = ScribbleGuidePipeline.from_pretrained(
    sd_model_ckpt,  # single source of truth for the checkpoint id
    caption_generator=model,
    caption_processor=processor,
    safety_checker=None,
    # torch_dtype=torch.float16
).to(device)

tokenizer = stable.tokenizer
# Replace the default scheduler with DDIM, and attach a DDIM inverse scheduler
# built from the same (now DDIM) config.
stable.scheduler = DDIMScheduler.from_config(stable.scheduler.config)
stable.inverse_scheduler = DDIMInverseScheduler.from_config(stable.scheduler.config)

In [None]:
# Whether CUDA is visible; the pipeline `device` falls back to CPU when this is False.
torch.cuda.is_available()

In [None]:
def run_on_prompt(prompt: List[str],
                  token_masks: Union[
                    List[torch.Tensor],
                    List[Image.Image],
                    List[np.ndarray]
                  ],
                  model: ScribbleGuidePipeline,
                  controller: AttentionStore,
                  token_indices: List[int],
                  seed: int,
                  output_path: str,
                  generator: torch.Generator,
                  config: RunConfig,
                  latents: torch.FloatTensor = None,
                  ) -> Image.Image:
    """Call the scribble-guided pipeline once and return its first image.

    Sampling settings (attention resolution, guidance scale, step count,
    ``run_standard``, ``scale_factor``, ``scale_range``) come from ``config``;
    every other argument is forwarded to ``model`` unchanged.

    Args:
        prompt: prompt strings passed straight to the pipeline.
        token_masks: spatial masks, forwarded as ``token_masks``.
        model: pipeline instance to call.
        controller: attention store, forwarded as ``attention_store``.
        token_indices: prompt token positions, forwarded as ``indices_list``.
        seed: forwarded unchanged.
        output_path: forwarded unchanged.
        generator: forwarded unchanged.
        config: run configuration supplying the sampling settings above.
        latents: optional initial latents, forwarded unchanged.

    Returns:
        The first entry of the pipeline output's ``images``.
    """
    sampling_kwargs = {
        "attention_resolution": config.attention_res,
        "guidance_scale": config.guidance_scale,
        "num_inference_steps": config.num_inference_steps,
        "run_standard": config.run_standard,
        "scale_factor": config.scale_factor,
        "scale_range": config.scale_range,
    }
    pipeline_output = model(
        prompt=prompt,
        token_masks=token_masks,
        attention_store=controller,
        indices_list=token_indices,
        latents=latents,
        generator=generator,
        seed=seed,
        output_path=output_path,
        **sampling_kwargs,
    )
    return pipeline_output.images[0]

In [None]:
def run_and_display(prompts: List[str],
                    token_masks: Union[
                        List[torch.Tensor],
                        List[Image.Image],
                        List[np.ndarray]
                    ],
                    controller: AttentionStore,
                    indices_to_alter: List[int],
                    seed: int,
                    output_path: str,
                    generator: torch.Generator,
                    latents: torch.FloatTensor = None,
                    run_standard: bool = False,
                    scale_factor: int = 10,
                    display_output: bool = False):
    """Generate one image with the global ``stable`` pipeline and optionally show it.

    A ``RunConfig`` is built from ``prompts[0]``, ``run_standard`` and
    ``scale_factor``; all other settings keep RunConfig's defaults.

    Args:
        prompts: prompt strings; the first one also seeds the RunConfig.
        token_masks: spatial masks forwarded to the pipeline.
        controller: attention store that records attention during generation.
        indices_to_alter: prompt token positions the masks apply to.
        seed, output_path, generator, latents: forwarded unchanged.
        run_standard: forwarded into the RunConfig.
        scale_factor: forwarded into the RunConfig.
        display_output: when True, render the image inline with ``display``.

    Returns:
        The generated image.
    """
    run_config = RunConfig(prompt=prompts[0],
                           run_standard=run_standard,
                           scale_factor=scale_factor)
    generated = run_on_prompt(prompt=prompts,
                              token_masks=token_masks,
                              model=stable,
                              controller=controller,
                              token_indices=indices_to_alter,
                              seed=seed,
                              output_path=output_path,
                              generator=generator,
                              config=run_config,
                              latents=latents)
    if display_output:
        display(generated)
    return generated

In [None]:
# For every annotation: use its segmentation mask as the spatial guide for the
# category token, generate an image, then plot cross- and self-attention maps.
seed = 21       # same seed for every annotation so runs are comparable
latents = None  # no preset latents; presumably the pipeline samples them via `generator`

for index, annotation in enumerate(annotations):
    category_name = category_id_to_name[annotation['category_id']]

    mask_image, _ = get_segmentation_image(shape_prompts, image_to_file, index, index)
    mask_image = mask_image.resize((512, 512))

    prompt = f"a photography of {'an' if category_name[0] in ['a', 'e', 'i', 'o', 'u'] else 'a'} {category_name}"

    prompts = [prompt]
    # Index 5 should be the first word of the category name (start token at 0,
    # then "a photography of a/an"). NOTE(review): assumes each prefix word is a
    # single token; for multi-word names only the first word is guided — confirm.
    token_indices = [5]
    token_masks = [mask_image]

    controller = AttentionStore()

    # Build the generator on the pipeline's device: a hard-coded 'cuda'
    # generator fails on CPU-only machines, where `device` falls back to CPU.
    generator = torch.Generator(device=device).manual_seed(seed)

    display(mask_image)
    image = run_and_display(prompts=prompts,
                            token_masks=token_masks,
                            controller=controller,
                            latents=latents,
                            indices_to_alter=token_indices,
                            generator=generator,
                            seed=seed,
                            output_path=f"runs/{index}",
                            run_standard=False,
                            display_output=True)
    vis_utils.show_cross_attention(attention_store=controller,
                                   prompt=prompt,
                                   tokenizer=tokenizer,
                                   res=16,
                                   from_where=("up", "down", "mid"),
                                   indices_to_alter=token_indices,
                                   global_attention=False,
                                   orig_image=image)
    vis_utils.show_self_attention(attention_store=controller,
                                  res=16,
                                  from_where=("up", "down", "mid"))

    # Disabled baseline comparison (run_standard=True).
    # NOTE(review): as written it omits `output_path`, which run_and_display
    # requires, and it reuses the already-consumed `controller`/`generator`.
    # image = run_and_display(prompts=prompts,
    #                         # image=image,
    #                         token_masks=token_masks,
    #                         controller=controller,
    #                         latents=latents,
    #                         indices_to_alter=token_indices,
    #                         generator=generator,
    #                         seed=seed,
    #                         run_standard=True,
    #                         display_output=True)
    # vis_utils.show_cross_attention(attention_store=controller,
    #                                prompt=prompt,
    #                                tokenizer=tokenizer,
    #                                res=16,
    #                                from_where=("up", "down", "mid"),
    #                                indices_to_alter=token_indices,
    #                                global_attention=False,
    #                                orig_image=image)
    # vis_utils.show_self_attention(attention_store=controller,
    #                                 res=16,
    #                                 from_where=("up", "down", "mid"),
    #                                 )