## Prompt-to-Prompt with Stable Diffusion

In [48]:
from typing import List
import torch
from diffusers import StableDiffusionPipeline,  DDIMScheduler
import numpy as np
import abc
import ptp_utils

In [49]:
# Hugging Face access token; unused while models are loaded with local_files_only=True.
# MY_TOKEN = '<replace with your token>'
MY_TOKEN = None
LOW_RESOURCE = False
NUM_DIFFUSION_STEPS = 50
GUIDANCE_SCALE = 7.5
MAX_NUM_WORDS = 77

# Device for diffusion: first GPU when available, otherwise CPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# model_key = "runwayml/stable-diffusion-v1-5"
model_key = "stabilityai/stable-diffusion-2-1-base"

# Load the fp16 diffusion pipeline from the local cache and move it to `device`.
ldm_stable = StableDiffusionPipeline.from_pretrained(
    model_key,
    local_files_only=True,
    torch_dtype=torch.float16,
).to(device)
# ldm_stable.enable_xformers_memory_efficient_attention()
tokenizer = ldm_stable.tokenizer

# Replace the pipeline's scheduler with DDIM, overriding beta_start/beta_end/steps_offset.
ldm_stable.scheduler = DDIMScheduler.from_pretrained(
    model_key,
    subfolder="scheduler",
    beta_start=0.00085,
    beta_end=0.012,
    steps_offset=1,
)

from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration

# BLIP captioner, used in `baseline` to extend the class prompt with image-specific text.
processor = BlipProcessor.from_pretrained("../blip-image-captioning-large")
# Place BLIP on the same `device` as the diffusion pipeline instead of a hard-coded
# "cuda:0", so the CPU fallback chosen above does not crash here.
model = BlipForConditionalGeneration.from_pretrained("../blip-image-captioning-large").to(device)

Loading pipeline components...:   0%|          | 0/6 [00:00<?, ?it/s]

In [50]:
# Code for storing attention maps.
class AttentionControl(abc.ABC):
    """Base class for per-layer attention hooks.

    An instance is called once per attention layer with that layer's attention
    tensor. It counts calls in ``cur_att_layer`` and, after
    ``num_att_layers + num_uncond_att_layers`` calls, resets that counter,
    increments ``cur_step`` and runs ``between_steps``. Subclasses implement
    ``forward`` to inspect or replace the maps.

    ``num_att_layers`` starts at -1; it is presumably filled in by whatever
    registers the hook on the UNet (ptp_utils, not shown here) — confirm there.
    """

    def __init__(self):
        self.cur_step = 0
        self.num_att_layers = -1
        self.cur_att_layer = 0

    def reset(self):
        """Rewind the step and layer counters (``num_att_layers`` is kept)."""
        self.cur_step = 0
        self.cur_att_layer = 0

    def step_callback(self, x_t):
        """Hook on the latent after a step; the base class returns it unchanged."""
        return x_t

    def between_steps(self):
        """Hook run after the last layer of each step; no-op by default."""
        return

    @property
    def num_uncond_att_layers(self):
        # In LOW_RESOURCE mode the first `num_att_layers` calls of every step are
        # passed through untouched (presumably the separate unconditional pass).
        if LOW_RESOURCE:
            return self.num_att_layers
        return 0

    @abc.abstractmethod
    def forward(self, attn, is_cross: bool, place_in_unet: str):
        raise NotImplementedError

    def __call__(self, attn, is_cross: bool, place_in_unet: str):
        if self.cur_att_layer >= self.num_uncond_att_layers:
            if LOW_RESOURCE:
                attn = self.forward(attn, is_cross, place_in_unet)
            else:
                # Only the second half of the batch dimension is handed to forward
                # (presumably the text-conditioned half of a CFG batch).
                half = attn.shape[0] // 2
                attn[half:] = self.forward(attn[half:], is_cross, place_in_unet)
        self.cur_att_layer += 1
        layers_per_step = self.num_att_layers + self.num_uncond_att_layers
        if self.cur_att_layer == layers_per_step:
            self.cur_att_layer = 0
            self.cur_step += 1
            self.between_steps()
        return attn

class EmptyControl(AttentionControl):
    """Controller that neither records nor edits attention."""

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        # Identity: hand the attention tensor back unchanged.
        return attn
    
    
class AttentionStore(AttentionControl):
    """Controller that accumulates attention maps over diffusion steps.

    During a step, maps are appended to ``step_store`` under keys like
    ``"up_cross"`` / ``"mid_self"``. At each step boundary they are summed
    element-wise into ``attention_store``. ``get_average_attention`` divides
    the running sums by the number of completed steps.
    """

    def __init__(self):
        super().__init__()
        self.step_store = self.get_empty_store()
        self.attention_store = {}

    @staticmethod
    def get_empty_store():
        """Return a fresh dict with an empty list per (UNet place, attention kind)."""
        keys = ("down_cross", "mid_cross", "up_cross",
                "down_self", "mid_self", "up_self")
        return {name: [] for name in keys}

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        kind = "cross" if is_cross else "self"
        # Only keep maps with at most 64*64 entries along dim 1, to limit memory use.
        if attn.shape[1] <= 64 ** 2:
            self.step_store[f"{place_in_unet}_{kind}"].append(attn)
        return attn

    def between_steps(self):
        if not self.attention_store:
            # First completed step: adopt this step's maps as the running sums.
            self.attention_store = self.step_store
        else:
            for name, running in self.attention_store.items():
                for idx in range(len(running)):
                    running[idx] += self.step_store[name][idx]
        self.step_store = self.get_empty_store()

    def get_average_attention(self):
        """Return the accumulated maps divided by ``cur_step``, keyed like the store."""
        steps = self.cur_step
        return {
            name: [summed / steps for summed in maps]
            for name, maps in self.attention_store.items()
        }

    def reset(self):
        super().reset()
        self.step_store = self.get_empty_store()
        self.attention_store = {}

In [51]:
from PIL import Image

# Code for aggregating attention maps.
def aggregate_all_attention(prompts, attention_store: AttentionStore, from_where: List[str], is_cross: bool, select: int):
    """Average stored attention maps separately at each of the resolutions 8/16/32/64.

    Each stored map whose dim 1 equals res*res is reshaped to
    (len(prompts), -1, res, res, last_dim) and the ``select``-th prompt is kept.
    All kept maps of a resolution are concatenated and averaged over dim 0.

    Args:
        prompts: prompt list; only its length is used, to split the batch dimension.
        attention_store: object whose ``get_average_attention()`` returns the maps.
        from_where: UNet places to read, e.g. ("mid", "up", "down").
        is_cross: read cross-attention maps if True, else self-attention maps.
        select: index of the prompt to keep.
    Returns:
        list of four CPU tensors, ordered [8, 16, 32, 64], each (res, res, last_dim).
        ``torch.cat`` raises if some resolution has no maps.
    """
    attention_maps = attention_store.get_average_attention()
    kind = 'cross' if is_cross else 'self'
    sides = (8, 16, 32, 64)
    side_for_pixels = {side * side: side for side in sides}
    per_side = {side: [] for side in sides}

    for location in from_where:
        for item in attention_maps[f"{location}_{kind}"]:
            side = side_for_pixels.get(item.shape[1])
            if side is None:
                continue  # not one of the four tracked resolutions
            per_side[side].append(
                item.reshape(len(prompts), -1, side, side, item.shape[-1])[select]
            )

    averaged = []
    for side in sides:
        stacked = torch.cat(per_side[side], dim=0)
        averaged.append((stacked.sum(0) / stacked.shape[0]).cpu())
    return averaged

def aggregate_attention(prompts, attention_store: AttentionStore, res: int, from_where: List[str], is_cross: bool, select: int):
    """Average the stored attention maps of a single resolution ``res``.

    Keeps maps whose dim 1 equals ``res ** 2``, reshapes each to
    (len(prompts), -1, res, res, last_dim), selects prompt ``select``,
    concatenates them and averages over dim 0.

    Returns:
        CPU tensor of shape (res, res, last_dim).
    """
    attention_maps = attention_store.get_average_attention()
    kind = 'cross' if is_cross else 'self'
    collected = []
    for location in from_where:
        for item in attention_maps[f"{location}_{kind}"]:
            if item.shape[1] != res ** 2:
                continue
            collected.append(item.reshape(len(prompts), -1, res, res, item.shape[-1])[select])
    merged = torch.cat(collected, dim=0)
    return (merged.sum(0) / merged.shape[0]).cpu()

# visualize cross att
def show_cross_attention(prompts,attention_store: AttentionStore, res: int, from_where: List[str], select: int = 0):
    """Display one cross-attention heat map per prompt token, captioned with that token.

    Attention map ``i`` (last axis of the aggregated maps) belongs to token ``i``
    of the encoded prompt. The "++" emphasis token is skipped — both its map and
    its caption — so every remaining map stays paired with its own label.

    Args:
        prompts: prompt list; ``prompts[select]`` is tokenized for the captions.
        attention_store: store holding the recorded attention.
        res: resolution of the maps to show (8/16/32/64).
        from_where: UNet places to aggregate.
        select: index of the prompt to visualise.
    """
    tokens = tokenizer.encode(prompts[select])
    decoder = tokenizer.decode
    attention_maps = aggregate_attention(prompts, attention_store, res, from_where, True, select)
    images = []
    for i, token in enumerate(tokens):
        label = decoder(int(token))
        if label == "++":
            # Skip the map as well as the label. Skipping only the label shifted
            # every following caption onto the previous token's map.
            continue
        image = attention_maps[:, :, i]
        image = 255 * image / image.max()
        image = image.unsqueeze(-1).expand(*image.shape, 3)
        image = image.numpy().astype(np.uint8)
        image = np.array(Image.fromarray(image).resize((256, 256)))
        images.append(ptp_utils.text_under_image(image, label))
    ptp_utils.view_images(np.stack(images, axis=0))
    
# visualize self att
def show_self_attention_comp(prompts,attention_store: AttentionStore, res: int, from_where: List[str],
                        max_com=10, select: int = 0):
    """Visualize the top ``max_com`` right singular vectors of the self-attention matrix.

    The aggregated (res*res, res*res) self-attention matrix has each row's mean
    subtracted and is decomposed with SVD. Each of the first ``max_com`` rows of
    ``vh`` is reshaped to (res, res), min-max scaled to 0..255, upsampled to
    256x256 and shown side by side.
    """
    n_pix = res ** 2
    attn = aggregate_attention(prompts, attention_store, res, from_where, False, select)
    attn = attn.float().numpy().reshape((n_pix, n_pix))
    centered = attn - attn.mean(axis=1, keepdims=True)
    _, _, components = np.linalg.svd(centered)

    def _to_tile(component):
        # Min-max scale one component to 0..255 and turn it into a 256x256 RGB tile.
        tile = component.reshape(res, res)
        tile = tile - tile.min()
        tile = 255 * tile / tile.max()
        rgb = np.repeat(tile[:, :, None], 3, axis=2).astype(np.uint8)
        return np.array(Image.fromarray(rgb).resize((256, 256)))

    tiles = [_to_tile(components[k]) for k in range(max_com)]
    ptp_utils.view_images(np.concatenate(tiles, axis=1))

In [52]:
def run_and_display(prompts, controller, latent=None, run_baseline=False, generator=None,t=NUM_DIFFUSION_STEPS):
    """Run ``ptp_utils.text2image_ldm_stable`` with ``controller`` and return (images, x_t).

    ``t`` is forwarded as ``num_inference_steps``. ``run_baseline`` is accepted
    for signature compatibility but is not used.
    """
    generated, final_latent = ptp_utils.text2image_ldm_stable(
        ldm_stable,
        prompts,
        controller,
        latent=latent,
        num_inference_steps=t,
        guidance_scale=GUIDANCE_SCALE,
        generator=generator,
        low_resource=LOW_RESOURCE,
    )
    return generated, final_latent

In [53]:
# Shared attention recorder; passed to (and reset by) generate_att for every image.
controller = AttentionStore()
# VAE of the loaded pipeline, used by encode_imgs to map images to latents.
vae = ldm_stable.vae.to(device)

def encode_imgs(imgs):
    """Encode an image batch into scaled VAE latents.

    Args:
        imgs: [B, 3, H, W] tensor, presumably in [0, 1] (the caller builds it with
            ``transforms.ToTensor``) — confirm for other callers.
    Returns:
        The mean of the VAE posterior multiplied by the VAE's latent scaling factor.
    """
    imgs = 2 * imgs - 1  # map [0, 1] inputs to [-1, 1]
    posterior = vae.encode(imgs).latent_dist.mean
    # Take the factor from the VAE config rather than hard-coding it, so switching
    # `model_key` to a checkpoint with a different factor keeps latents scaled
    # correctly. 0.18215 (the previous constant) is used if the config lacks it.
    scaling_factor = getattr(vae.config, "scaling_factor", 0.18215)
    return posterior * scaling_factor

In [54]:
# fix random seed
def same_seeds(seed):
    """Make runs reproducible by seeding every RNG in use and pinning cuDNN.

    Seeds PyTorch's CPU generator, the CUDA generators (current device and all
    devices) when CUDA is available, and NumPy's global RNG. Then disables cuDNN
    benchmarking and requests deterministic cuDNN kernels.
    """
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        # Current GPU first, then every GPU.
        for seed_cuda in (torch.cuda.manual_seed, torch.cuda.manual_seed_all):
            seed_cuda(seed)
    np.random.seed(seed)
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False  # benchmarking may select different kernels between runs
    cudnn.deterministic = True


In [55]:
import cv2
import torch.nn.functional as F
# cam visual_code
def show_cam_on_image(img, mask):
    """Overlay an attention mask on an image as a JET heat map.

    Args:
        img: PIL RGB image; ``img.size`` is read as (width, height).
        mask: 2-D torch tensor, expected in [0, 1] (e.g. the output of generate_att).
    Returns:
        uint8 array of shape (H, W, 3) in BGR channel order (OpenCV convention).
        Flip the last axis (``cam[:, :, ::-1]``) to display it as RGB.
    """
    # Resize the mask to the image; F.interpolate takes (height, width).
    mask = F.interpolate(
        mask.unsqueeze(0).unsqueeze(0),
        size=(img.size[1], img.size[0]),
        mode='bilinear',
        align_corners=False,
    ).squeeze()
    # Detach and move to host memory so NumPy can read it even for CUDA tensors.
    mask = mask.detach().cpu().numpy()
    # cv2.applyColorMap yields BGR, so reorder the RGB image to BGR before blending.
    # Without this, the blended image has its red and blue channels swapped.
    img = np.float32(img)[:, :, ::-1] / 255.
    heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
    heatmap = np.float32(heatmap) / 255
    cam = heatmap + img
    cam = cam / np.max(cam)
    return np.uint8(255 * cam)

# def show_attention(mask, save_path):
#     mask = (mask-mask.min())/(mask.max()-mask.min())
#     heatmap = cv2.applyColorMap(np.uint8(255 * mask), cv2.COLORMAP_JET)
#     attention = np.uint8(heatmap)
#     cv2.imwrite(save_path, attention)

In [56]:
# Main routine: build a segmentation mask from diffusion attention.
def generate_att(
        t, input_latent, noise, prompts, controller, pos,
        is_self=True,
        is_multi_self=False,
        is_cross_norm=True,
        weight=(0.3, 0.5, 0.1, 0.1)
):
    """Build a 512x512 soft mask for the prompt token(s) at ``pos``.

    Steps:
      1. Noise ``input_latent`` to timestep ``t``.
      2. Run the pipeline with ``controller`` recording attention.
      3. Fuse the 8/16/32/64 cross-attention maps (weighted by ``weight``) of the
         tokens in ``pos``.
      4. Optionally refine the result with self-attention, then sharpen it with a
         sigmoid.

    Args:
        t: timestep for ``scheduler.add_noise``; also forwarded to
            ``run_and_display`` as the number of inference steps.
        input_latent: VAE latent of the input image (see ``encode_imgs``).
        noise: noise tensor shaped like ``input_latent``.
        prompts: prompt list; prompt 0 is analysed.
        controller: AttentionStore that records attention during the pass.
        pos: list of token indices whose maps are averaged.
        is_self: multiply the fused map once by the 64x64 self-attention
            (only when ``is_multi_self`` is False).
        is_multi_self: multiply each resolution's cross-attention by that
            resolution's self-attention before fusion.
        is_cross_norm: min-max normalise each token map per resolution.
        weight: fusion weights for resolutions [8, 16, 32, 64].
    Returns:
        (512, 512) float tensor in [0, 1].
    """
    controller.reset()

    # torch.Generator's positional argument is a *device*, so the seed has to go
    # through manual_seed() rather than the constructor.
    g_cpu = torch.Generator().manual_seed(4307)

    # Noise the clean latent to timestep t and run the pipeline, so `controller`
    # actually records attention. Without this pass the store is still empty after
    # reset(), and the aggregation below fails with a KeyError.
    latents_noisy = ldm_stable.scheduler.add_noise(
        input_latent, noise, torch.tensor(t, device=device)
    )
    run_and_display(
        prompts, controller, latent=latents_noisy, run_baseline=False, generator=g_cpu, t=t
    )

    layers = ("mid", "up", "down")
    cross_attention_maps = aggregate_all_attention(prompts, controller, layers, True, 0)
    self_attention_maps = aggregate_all_attention(prompts, controller, ["up", "mid", "down"], False, 0)

    resolutions = (8, 16, 32, 64)
    out_res = resolutions[-1]  # every map is brought to this resolution before fusion
    imgs = []
    for idx, res in enumerate(resolutions):
        # (res, res, tokens) -> (tokens, res, res)
        out_att = cross_attention_maps[idx].permute(2, 0, 1).float()

        if is_cross_norm:
            att_max = torch.amax(out_att, dim=(1, 2), keepdim=True)
            att_min = torch.amin(out_att, dim=(1, 2), keepdim=True)
            out_att = (out_att - att_min) / (att_max - att_min)

        if is_multi_self:
            self_att = self_attention_maps[idx].view(res * res, res * res).float()
            self_att = self_att / self_att.max()
            out_att = torch.matmul(self_att.unsqueeze(0), out_att.view(-1, res * res, 1)).view(-1, res, res)

        if res != out_res:
            out_att = F.interpolate(
                out_att.unsqueeze(0), size=(out_res, out_res), mode='bilinear', align_corners=False
            ).squeeze()

        # Apply this resolution's weight and keep the result for fusion.
        imgs.append(out_att * weight[idx])

    # Sum over resolutions, keep the requested tokens, average them -> column vector.
    cross_att_map = torch.stack(imgs).sum(0)[pos].mean(0).view(out_res * out_res, 1)

    if is_self and not is_multi_self:
        self_att = self_attention_maps[-1].view(out_res * out_res, out_res * out_res).float()
        self_att = self_att / self_att.max()
        cross_att_map = torch.matmul(self_att, cross_att_map)

    att_map = cross_att_map.view(out_res, out_res)
    att_map = F.interpolate(
        att_map.unsqueeze(0).unsqueeze(0), size=(512, 512), mode='bilinear', align_corners=False
    ).squeeze()

    # Min-max normalise, boost contrast with a sigmoid centred at 0.4, renormalise.
    att_map = (att_map - att_map.min()) / (att_map.max() - att_map.min())
    att_map = torch.sigmoid(8 * (att_map - 0.4))
    att_map = (att_map - att_map.min()) / (att_map.max() - att_map.min())

    # For per-resolution debugging, call show_cross_attention / show_self_attention_comp
    # with `controller` and res in (8, 16, 32, 64) after this function returns.
    return att_map

In [57]:
# from torchvision import transforms
# from IPython.display import display
# from PIL import Image
# with torch.no_grad():
#     same_seeds(3407)
#     
#     img_path = "./sample_img/thread.png"
#     input_img = Image.open(img_path).convert("RGB")
# 
#     print("ori_image")
#     display(input_img)
#     
#     # 图像转换操作
#     t = []
#     t.append(transforms.ToTensor())
#     transforms = transforms.Compose(t)
# 
#     img_tensor = (transforms(input_img).unsqueeze(0)).to(device)
# 
#     rgb_512 = F.interpolate(img_tensor, (512, 512), mode='bilinear', align_corners=False).half()
# 
#     input_latent = encode_imgs(rgb_512)
# 
#     noise = torch.randn_like(input_latent).to(device)
# 
#     raw_image = input_img
# 
#     # 目标类名称
#     cls_name = "thread"
#     text = f"a photograph of {cls_name}"
#     
#     # 使用BLIP进行文本-图像输入处理
#     inputs = processor(raw_image,text,return_tensors="pt").to("cuda")  
# 
#     # 使用BLIP生成新的文本描述，进行增强
#     # use blip and "++" emphasizing semantic information of target categories
#     out = model.generate(**inputs)
#     texts = processor.decode(out[0], skip_special_tokens=True)
#     texts = text +"++"+ texts[len(text):]       # 增强目标类别描述，加入更多语义信息
# 
#     g_cpu = torch.Generator(3407)
#     prompts = [texts]
#     print("**** blip_prompt: "+texts+"****")    # 打印生成的增强文本
# 
#     # 以下参数设置：  
#     # - `pos`是目标类别在句子中的位置，例如"plane"在"a photograph of plane"中的位置是4
#     # - `t`是去噪步数，通常在50到150之间
#     pos = [4]   
#     t = 100
#     
#     # 生成注意力掩码，用于引导生成模型
#     mask = generate_att(
#         t,input_latent, noise, prompts, controller, pos,
#         is_self=True,
#         is_multi_self=False,
#         is_cross_norm=True,
#         weight=[0.3, 0.5, 0.1, 0.1]
#     )
# 
#     cam = show_cam_on_image(raw_image, mask)
#     print("visual_cam")
#     display(Image.fromarray(cam[:,:,::-1]))

In [58]:
"""
baseline：计算MIoU与推理速度
"""
from torchvision import transforms
from PIL import Image
import time
import os

def calculate_iou(gt_mask, pred_mask):
    """Compute the intersection over union (IoU) of two masks and print it.

    Non-zero entries count as foreground.

    :param gt_mask: ground-truth mask (numpy array)
    :param pred_mask: predicted mask (numpy array)
    :return: IoU score; 0 when both masks are empty
    """
    overlap = np.logical_and(gt_mask, pred_mask).sum()
    combined = np.logical_or(gt_mask, pred_mask).sum()
    if not combined:
        return 0  # both masks empty: avoid dividing by zero

    score = overlap / combined
    print(f'iou:{score:.2f}')
    return score

def _load_binary_mask(path, threshold=128):
    """Load an image as grayscale and binarise it: 1 where pixel > threshold, else 0."""
    # Context manager closes the file handle deterministically.
    with Image.open(path) as im:
        gray = np.array(im.convert("L"))
    return (gray > threshold).astype(np.uint8)


def calculate_miou(gt_folder, pred_folder):
    """Compute the mean IoU over all mask pairs in two folders.

    Files are paired by position after sorting each folder's listing, so both
    folders must contain the same number of masks in matching order. Masks are
    converted to grayscale and binarised with ``> 128``.

    :param gt_folder: folder of ground-truth masks
    :param pred_folder: folder of predicted masks
    :return: mean IoU
    :raises ValueError: if the folders hold different numbers of files, or no files
    """
    gt_files = sorted(os.listdir(gt_folder))
    pred_files = sorted(os.listdir(pred_folder))

    if len(gt_files) != len(pred_files):
        raise ValueError("Ground Truth 和预测 Mask 的图片数量不一致！")
    if not gt_files:
        # np.mean([]) would silently return nan (with only a RuntimeWarning).
        raise ValueError(f"No masks found in {gt_folder!r}")

    iou_scores = []
    for gt_file, pred_file in zip(gt_files, pred_files):
        gt_mask = _load_binary_mask(os.path.join(gt_folder, gt_file))
        pred_mask = _load_binary_mask(os.path.join(pred_folder, pred_file))
        iou_scores.append(calculate_iou(gt_mask, pred_mask))

    return np.mean(iou_scores)


def baseline(img_path, pred_folder, cls_name):
    """Segment ``cls_name`` in one image, save the binary mask, return elapsed seconds.

    Steps:
      1. Encode the image to VAE latents.
      2. Caption it with BLIP, conditioned on "a photograph of {cls_name}", and
         emphasise the class name with "++".
      3. Build a soft mask with ``generate_att`` and threshold it at 0.5.
      4. Resize the mask to the original image size and save it under
         ``pred_folder`` with the image's file name.
    """
    start_time = time.perf_counter()  # monotonic clock, suited to timing

    with torch.no_grad():
        same_seeds(3407)

        # Load the image and convert it to a [1, 3, H, W] tensor.
        input_img = Image.open(img_path).convert("RGB")
        transform = transforms.Compose([transforms.ToTensor()])
        img_tensor = transform(input_img).unsqueeze(0).to(device)

        # Resize to 512x512 (fp16) and encode into latent space.
        rgb_512 = F.interpolate(img_tensor, (512, 512), mode='bilinear', align_corners=False).half()
        input_latent = encode_imgs(rgb_512)
        noise = torch.randn_like(input_latent)  # already on input_latent's device

        # Caption with BLIP conditioned on the class prompt.
        text = f"a photograph of {cls_name}"
        # Send inputs to wherever the BLIP model lives rather than a hard-coded "cuda".
        inputs = processor(input_img, text, return_tensors="pt").to(model.device)
        out = model.generate(**inputs)
        texts = processor.decode(out[0], skip_special_tokens=True)
        # Insert "++" after the class prompt to emphasise the target category.
        texts = text + "++" + texts[len(text):]
        prompts = [texts]

        # pos: token index of the target class, e.g. "plane" in
        #      "a photograph of plane" is at index 4.
        # t:   denoising timestep / step count handed to generate_att.
        pos = [4]
        t = 10

        mask = generate_att(
            t, input_latent, noise, prompts, controller, pos,
            is_self=True,
            is_multi_self=False,
            is_cross_norm=True,
            weight=[0.1, 0.4, 0.4, 0.1]
        )

        # Binarise the soft mask to 0/255 (recommended threshold range [0.3, 0.7]).
        threshold = 0.5
        binary_mask = (mask.cpu().numpy() > threshold).astype(np.uint8) * 255
        pred_mask = Image.fromarray(binary_mask).convert("L")

        # Resize back to the original resolution so the mask aligns with the GT.
        # convert("RGB") keeps the size, so no need to reopen the file.
        pred_mask = pred_mask.resize(input_img.size)

        # Save under the same file name as the input image.
        os.makedirs(pred_folder, exist_ok=True)
        pred_path = os.path.join(pred_folder, os.path.basename(img_path))
        pred_mask.save(pred_path)

    return time.perf_counter() - start_time


if __name__ == '__main__':
    class_name = 'scratch'

    # MVTec AD "wood" category: test images, ground-truth masks, our predictions.
    data_root = '/home/saki/data/mvtec_anomaly_detection/wood'
    img_folder = f'{data_root}/test/{class_name}'
    gt_folder = f'{data_root}/ground_truth/{class_name}'
    pred_folder = f"{data_root}/predictions/{class_name}"

    # Segment every test image, timing each one.
    img_paths = sorted(os.path.join(img_folder, name) for name in os.listdir(img_folder))
    n_imgs = len(img_paths)
    total_time = 0
    for idx, img_path in enumerate(img_paths, start=1):
        inference_time = baseline(img_path, pred_folder, class_name)
        total_time += inference_time
        print(f"Processed {idx}/{n_imgs}, Time: {inference_time:.2f}s")

    # Evaluate segmentation quality and speed.
    miou = calculate_miou(gt_folder, pred_folder)
    avg_speed = total_time / n_imgs

    print("\nEvaluation Results:")
    print(f"MIoU: {miou:.4f}")
    print(f"Average Speed: {avg_speed:.2f} seconds/image")

Processed 1/21, Time: 0.59s
Processed 2/21, Time: 0.53s
Processed 3/21, Time: 0.59s
Processed 4/21, Time: 0.54s
Processed 5/21, Time: 0.54s
Processed 6/21, Time: 0.60s
Processed 7/21, Time: 0.58s
Processed 8/21, Time: 0.53s
Processed 9/21, Time: 0.53s
Processed 10/21, Time: 0.53s
Processed 11/21, Time: 0.59s
Processed 12/21, Time: 0.54s
Processed 13/21, Time: 0.54s
Processed 14/21, Time: 0.53s
Processed 15/21, Time: 0.54s
Processed 16/21, Time: 0.58s
Processed 17/21, Time: 0.54s
Processed 18/21, Time: 0.55s
Processed 19/21, Time: 0.54s
Processed 20/21, Time: 0.60s
Processed 21/21, Time: 0.54s
iou:0.54
iou:0.25
iou:0.72
iou:0.61
iou:0.47
iou:0.80
iou:0.63
iou:0.80
iou:0.59
iou:0.69
iou:0.52
iou:0.28
iou:0.70
iou:0.62
iou:0.78
iou:0.55
iou:0.36
iou:0.40
iou:0.48
iou:0.29
iou:0.53

Evaluation Results:
MIoU: 0.5538
Average Speed: 0.55 seconds/image
