In [None]:
from cli_visualization import HuatuoChatbot
import torch
from PIL import Image
import json
from tqdm import tqdm
import requests
from PIL import Image
from io import BytesIO
from transformers import TextStreamer
import os
import torch
import json
from tqdm import tqdm
import re

import seaborn as sns
import matplotlib.pyplot as plt


# Side length of the square patch grid the attention vectors are reshaped into.
PATCHES = 24
# Number of image tokens sliced out of the attention tensors: one per grid cell.
NUM_IMG_TOKENS = PATCHES * PATCHES
# (width, height) used when resizing the image and upsampling attention maps.
SIZE = (336, 336)

# Single chatbot wrapper instance; the attention helpers defined in this
# notebook read it as the module-level global `bot`.
bot = HuatuoChatbot("FreedomIntelligence/HuatuoGPT-Vision-7B")

In [None]:
import numpy as np
import cv2

def _image_attention_grid(prompt, image_path, layer, query):
    """Run one inference pass and return attention over the image tokens as a grid.

    Returns a float32 numpy array of shape (PATCHES, PATCHES).
    """
    # -200 is the id searched for to find where the image tokens start;
    # presumably the LLaVA-style image placeholder id -- TODO confirm against
    # cli_visualization.
    image_placeholder_id = -200

    model_output, input_ids = bot.inference_with_attention_output(prompt, image_path)
    input_ids = input_ids[0].cpu()

    positions = torch.where(input_ids == image_placeholder_id)[0]
    if positions.numel() != 1:
        # A tensor with zero or several elements cannot be used as a slice
        # bound; fail with an explicit message instead of a cryptic TypeError.
        raise ValueError(
            f"Expected exactly one image placeholder token in the prompt, "
            f"found {positions.numel()}."
        )
    start = int(positions[0])

    # assumes the attention tensor is (batch, heads, query, key) -- TODO confirm.
    # Take batch 0, the chosen query position, and the NUM_IMG_TOKENS key
    # positions starting at the placeholder; then average over the first
    # remaining axis.
    image_attention = model_output['attentions'][layer][0, :, query, start:start + NUM_IMG_TOKENS]
    averaged = image_attention.mean(dim=0).to(torch.float32)
    return averaged.detach().cpu().numpy().reshape(PATCHES, PATCHES)


def generate_attention_maps(question, image_path, layer=16, query=-1):
    """Compute image-attention grids for a question and for a generic prompt.

    Two inference passes are run on the same image: one with `question` and
    one with a fixed "describe the image" prompt. Both prompts get the same
    short-answer suffix appended.

    Args:
        question: Question text about the image.
        image_path: Image location forwarded to the chatbot.
        layer: Index into the returned per-layer attention sequence.
        query: Query position whose attention row is extracted (-1 = last).

    Returns:
        Tuple `(att_map, general_att_map)` of float32 numpy arrays, each of
        shape (PATCHES, PATCHES).

    Raises:
        ValueError: If the tokenized prompt does not contain exactly one
            image placeholder token.
    """
    general_question = 'Write a general description of the image.'

    prompt = f"{question} Answer the question using a single word or phrase."
    general_prompt = f"{general_question} Answer the question using a single word or phrase."

    att_map = _image_attention_grid(prompt, image_path, layer, query)
    general_att_map = _image_attention_grid(general_prompt, image_path, layer, query)

    return att_map, general_att_map


def show_mask_on_image(img, mask):
    """Overlay a colour-mapped mask on an image.

    Args:
        img: Image with values on a 0-255 scale (it is divided by 255).
        mask: Array that is multiplied by 255 and cast to uint8 before the
            OpenCV HSV colormap is applied.

    Returns:
        Tuple `(overlay, heatmap)`: the uint8 blend of heatmap and image,
        rescaled so its maximum is 255, and the raw colour-mapped heatmap.
    """
    image_01 = np.float32(img) / 255
    heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_HSV)
    heatmap_01 = np.float32(heatmap) / 255

    # Additive blend, then divide by the peak so the result stays within [0, 1].
    blended = heatmap_01 + np.float32(image_01)
    blended = blended / np.max(blended)
    return np.uint8(255 * blended), heatmap


def normalize(img):
    """Min-max scale `img` to the range [0, 1].

    Args:
        img: Array-like exposing `.min()` / `.max()` and arithmetic operators
            (numpy arrays and torch tensors are both used in this notebook).

    Returns:
        `(img - min) / (max - min)`. When every element is equal (zero range)
        an all-zero floating result of the same shape is returned instead of
        dividing by zero.
    """
    lowest = img.min()
    span = img.max() - lowest
    if span == 0:
        # Constant input: avoid 0/0 (NaN); multiplying by 0.0 keeps the shape
        # and yields a floating result like the regular path does.
        return (img - lowest) * 0.0
    return (img - lowest) / span


def get_attention_map_from_vit(question, image_path, layer=4):
    """Return one vision-encoder attention layer as a (PATCHES, PATCHES) grid.

    Args:
        question: Prompt text forwarded to the chatbot.
        image_path: Image location forwarded to the chatbot.
        layer: Index into the returned per-layer attention sequence.

    Returns:
        float32 numpy array of shape (PATCHES, PATCHES).
    """
    per_layer_attention = bot.inference_with_vit_attention(question, image_path)

    # Batch 0, query position 0 attending to key positions 1 onwards;
    # presumably position 0 is the CLS token -- TODO confirm.
    row = per_layer_attention[layer][0, :, 0, 1:]
    averaged = row.mean(dim=0).to(torch.float32)
    return averaged.detach().cpu().numpy().reshape(PATCHES, PATCHES)

In [None]:
import matplotlib.patches as patches


def convert_bbox_after_padding_and_resize(bbox, original_size, resized_size):
    """Map a bounding box from original-image to padded-and-resized coordinates.

    The image is assumed to have been centre-padded to a square whose side is
    the longer original edge, and then resized to `resized_size`.

    Args:
        bbox: `(x_min, y_min, width, height)` in original-image pixels.
        original_size: `(width, height)` of the original image.
        resized_size: `(width, height)` of the final resized image.

    Returns:
        `(x_min, y_min, width, height)` in resized-image coordinates.
    """
    orig_w, orig_h = original_size
    target_w, target_h = resized_size

    # The padded canvas is square; only the shorter dimension receives padding,
    # split evenly (floor) between its two sides.
    side = max(orig_w, orig_h)
    pad_left = (side - orig_w) // 2
    pad_top = (side - orig_h) // 2

    x_min, y_min, width, height = bbox

    # Padded-square -> resized scale factors per axis.
    scale_x = target_w / side
    scale_y = target_h / side

    return (
        (x_min + pad_left) * scale_x,
        (y_min + pad_top) * scale_y,
        width * scale_x,
        height * scale_y,
    )


def load_image(image_file, timeout=30):
    """Load an image from a local path or an HTTP(S) URL as an RGB PIL image.

    Args:
        image_file: Filesystem path, or a URL starting with `http://` or
            `https://`.
        timeout: Seconds to wait for the HTTP request before giving up; only
            used for URLs.

    Returns:
        The image converted to mode 'RGB'.

    Raises:
        requests.HTTPError: If the server answers with an error status.
        requests.Timeout: If the request exceeds `timeout`.
    """
    if image_file.startswith(('http://', 'https://')):
        # Without a timeout a stalled server would hang the notebook forever.
        response = requests.get(image_file, timeout=timeout)
        # Surface HTTP errors directly instead of trying to decode an error page.
        response.raise_for_status()
        return Image.open(BytesIO(response.content)).convert('RGB')

    # The context manager releases the file handle once the converted copy
    # has been produced.
    with Image.open(image_file) as source:
        return source.convert('RGB')


def expand2square(pil_img, background_color):
    """Centre-pad a PIL image to a square whose side is its longer edge.

    Args:
        pil_img: Source image; returned unchanged when already square.
        background_color: Fill colour for the padded area.

    Returns:
        The original image if square, otherwise a new square image with the
        source pasted in the centre of the shorter dimension.
    """
    width, height = pil_img.size
    if width == height:
        return pil_img

    side = max(width, height)
    canvas = Image.new(pil_img.mode, (side, side), background_color)

    # Landscape images are shifted down, portrait images are shifted right.
    if width > height:
        offset = (0, (side - height) // 2)
    else:
        offset = ((side - width) // 2, 0)
    canvas.paste(pil_img, offset)
    return canvas

        
def draw_bboxes(image_input, bboxes, name, thickness, folder):
    """Draw red bounding boxes on an image and save it as `source.png`.

    Args:
        image_input: Image convertible with `np.array` (e.g. a PIL image).
        bboxes: Iterable of `(x, y, w, h)` boxes in image pixel coordinates.
        name: Sub-directory of `folder` for this sample; created if missing.
        thickness: Line width of the box outlines.
        folder: Root output directory.

    The figure is written to `<folder>/<name>/source.png`.
    """
    pixels = np.array(image_input)

    # A maximum of at most 1.0 is taken to mean a [0, 1] image; rescale to 8-bit.
    if pixels.max() <= 1.0:
        pixels = (pixels * 255).astype(np.uint8)

    fig, ax = plt.subplots(figsize=(6, 6))
    ax.imshow(pixels, cmap='viridis')

    for x, y, w, h in bboxes:
        outline = patches.Rectangle(
            (x, y), w, h,
            linewidth=thickness, edgecolor='red', facecolor='none',
        )
        ax.add_patch(outline)

    ax.axis('off')
    plt.tight_layout()

    # Make sure the per-sample directory exists before writing into it.
    sample_dir = os.path.join(folder, name)
    os.makedirs(sample_dir, exist_ok=True)
    plt.savefig(os.path.join(folder, f"{name}/source.png"), bbox_inches='tight', pad_inches=0)



def draw_att_map(qs, image_path, layers, bboxes, name, thickness, folder):
    """Render one relative-attention figure per requested layer.

    For every layer, the question attention grid is divided element-wise by
    the generic-description attention grid, min-max normalized, upsampled to
    SIZE with nearest-neighbour interpolation, and saved with the bounding
    boxes drawn on top.

    Args:
        qs: Question text.
        image_path: Image location forwarded to the attention helper.
        layers: Iterable of layer indices to visualize.
        bboxes: Boxes `(x, y, w, h)` in SIZE-pixel coordinates.
        name: Sample sub-directory name inside `folder`.
        thickness: Line width of the box outlines.
        folder: Root output directory.
    """
    # NOTE(review): generate_attention_maps runs inference twice on every call,
    # so the model is re-run for each layer in `layers`.
    for layer in layers:
        question_map, baseline_map = generate_attention_maps(qs, image_path, layer=layer)

        # NOTE(review): no guard against zeros in the baseline map here.
        relative = torch.tensor(question_map / baseline_map)
        relative = normalize(relative)

        # interpolate expects leading batch and channel axes; drop them afterwards.
        upsampled = torch.nn.functional.interpolate(
            relative[None, None],
            size=SIZE,
            mode='nearest',
        ).squeeze()

        draw_bboxes_attention_map(upsampled, bboxes, name, layer, thickness, folder)


def draw_bboxes_attention_map(image_input, bboxes, name, layer, thickness, folder):
    """Draw red bounding boxes on an attention map and save it per layer.

    Args:
        image_input: 2-D torch tensor holding the attention map (converted
            with `.numpy()`).
        bboxes: Iterable of `(x, y, w, h)` boxes in map pixel coordinates.
        name: Sub-directory of `folder` for this sample; created if missing.
        layer: Layer index, used only in the output file name.
        thickness: Line width of the box outlines.
        folder: Root output directory.

    The figure is written to `<folder>/<name>/layer<layer>.png`.
    """
    image = image_input.numpy()

    # A maximum of at most 1.0 is taken to mean a [0, 1] map; rescale to 8-bit.
    if image.max() <= 1.0:
        image = (image * 255).astype(np.uint8)

    fig, ax = plt.subplots(figsize=(6, 6))

    # Show the map with the viridis colormap.
    im = ax.imshow(image, cmap='viridis')

    # Draw each bounding box.
    for x, y, w, h in bboxes:
        rect = patches.Rectangle((x, y), w, h, linewidth=thickness,
                                 edgecolor='red', facecolor='none')
        ax.add_patch(rect)

    ax.axis('off')
    plt.tight_layout()
    # plt.colorbar(im, ax=ax)

    # Create the sample directory here as well, so this function does not
    # depend on draw_bboxes having been called first.
    os.makedirs(os.path.join(folder, name), exist_ok=True)
    plt.savefig(os.path.join(folder, f"{name}/layer{layer}.png"), bbox_inches='tight', pad_inches=0)

In [None]:
# Layers whose attention maps are rendered.
# (previously tried: image_num = 6, layers = [0,15,25,35,45,55])
layers = [0, 5, 10, 16, 20, 27]

# Dataset and subset together determine the questions file name.
# (alternatives: dataset = "COCO", subset = "localization")
dataset = "SLAKE"
subset = "attribute"
input_path = f"./{dataset}_{subset}_questions.jsonl"

# The questions file holds one JSON object per line.
questions = []
with open(input_path, "r") as infile:
    questions.extend(json.loads(line) for line in infile)

# Pick the sample to visualize.
qs_num = 1
sample = questions[qs_num]
image_path = os.path.join(
    "/local_data/local_data/mllm_datasets/evaluation_datasets/slake/imgs",
    sample["image"],
    "source.jpg",
)
# COCO variant of the image path:
# image_path = os.path.join('../../local_data/mllm_datasets/evaluation_datasets/coco/val2014', f'COCO_val2014_{sample["image"]:012d}.jpg')

# Exactly one bounding box per sample is expected below.
bboxes = sample["bbox"]
assert len(bboxes) == 1

# Pad to a square, resize to SIZE, and map the box into the same coordinates.
original_image = load_image(image_path)
image = expand2square(original_image, 0).resize(SIZE)
bbox = convert_bbox_after_padding_and_resize(bboxes[0], original_image.size, resized_size=SIZE)
print(sample["question"])

# Output location: <folder>/<name>/...
name = f"{dataset}_{subset}_{qs_num}"
folder = "./huatuo34b_samples"

In [None]:
# Outline width used for the bounding boxes in the saved figures.
thickness = 3
draw_bboxes(image, [bbox], name, thickness, folder)

# Store the question text next to the rendered source figure.
question_file = os.path.join(folder, name, "question.txt")
with open(question_file, "w") as handle:
    handle.write(sample["question"])

In [None]:
# Render and save one attention figure per entry in `layers` for the selected
# sample, with its bounding box drawn on top.
draw_att_map(sample["question"], image_path, layers, [bbox], name, thickness, folder)