In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2
from tqdm import trange
import os
import shutil


In [None]:
def show_mask(mask, ax, random_color=False):
    """Draw ``mask`` on ``ax`` as a translucent RGBA overlay.

    Args:
        mask: Array whose trailing two dimensions are (H, W); it is reshaped to
            (H, W, 1), so it must hold exactly H * W elements.
        ax: Matplotlib axes (anything with an ``imshow`` method).
        random_color: When True the RGB colour is drawn from
            ``np.random.random``; otherwise the fixed blue (30, 144, 255) / 255
            is used. Alpha is 0.6 in both cases.
    """
    if not random_color:
        rgba = np.array([30/255, 144/255, 255/255, 0.6])
    else:
        rgba = np.append(np.random.random(3), 0.6)
    height, width = mask.shape[-2:]
    # Broadcast (H, W, 1) * (1, 1, 4) -> (H, W, 4) RGBA image.
    overlay = mask.reshape(height, width, 1) * rgba[np.newaxis, np.newaxis, :]
    ax.imshow(overlay)
    
    
def show_points(coords, labels, ax, marker_size=200):
    """Scatter point prompts on ``ax``: label 1 as green stars, label 0 as red crosses.

    Args:
        coords: (N, 2) array of (x, y) coordinates.
        labels: Length-N array. Points labelled 1 are positive and points
            labelled 0 are negative; any other label is not drawn.
        ax: Matplotlib axes (anything with a ``scatter`` method).
        marker_size: Marker area for positive points; negatives use half.
    """
    positive = coords[labels == 1]
    negative = coords[labels == 0]
    star_style = dict(color='green', marker='*', s=marker_size,
                      edgecolor='white', linewidth=1.25)
    cross_style = dict(color='red', marker='x', s=marker_size/2,
                       linewidth=1.25)
    for points, style in ((positive, star_style), (negative, cross_style)):
        ax.scatter(points[:, 0], points[:, 1], **style)
   
def show_box(box, ax):
    """Outline the box ``(x0, y0, x1, y1)`` on ``ax`` as a green, unfilled rectangle.

    Args:
        box: Indexable whose first four entries are left, top, right, bottom.
        ax: Matplotlib axes to add the rectangle patch to.
    """
    left, top = box[0], box[1]
    right, bottom = box[2], box[3]
    outline = plt.Rectangle((left, top), right - left, bottom - top,
                            edgecolor='green', facecolor=(0, 0, 0, 0), lw=2)
    ax.add_patch(outline)

def calculate_iou(mask1, mask2):
    """Compute the intersection-over-union (IoU) of two masks.

    Any non-zero element counts as foreground, so 0/1 and 0/255 masks can be
    compared directly.

    Args:
        mask1: First mask, a ``numpy.ndarray`` or ``torch.Tensor``.
        mask2: Second mask, of the same kind and shape as ``mask1``.

    Returns:
        IoU as a Python float in [0, 1]. When both masks are empty the union is
        empty and the masks are identical, so 1.0 is returned. (Before this
        change the NumPy path returned nan and the torch path raised
        ZeroDivisionError.)

    Raises:
        ValueError: If the shapes differ.
        TypeError: If the masks are not both NumPy arrays or both tensors.
    """
    if mask1.shape != mask2.shape:
        raise ValueError("mask1 and mask2 must have the same shape.")
    supported = (np.ndarray, torch.Tensor)
    if isinstance(mask1, np.ndarray) and isinstance(mask2, np.ndarray):
        a = mask1.astype(bool)
        b = mask2.astype(bool)
        intersection = int(np.count_nonzero(a & b))
        union = int(np.count_nonzero(a | b))
    elif isinstance(mask1, torch.Tensor) and isinstance(mask2, torch.Tensor):
        a = mask1.to(torch.bool)
        b = mask2.to(torch.bool)
        intersection = int(torch.logical_and(a, b).sum().item())
        union = int(torch.logical_or(a, b).sum().item())
    elif isinstance(mask1, supported) and isinstance(mask2, supported):
        # Mixed inputs used to fail with an opaque AttributeError.
        raise TypeError("mask1 and mask2 must both be numpy arrays or both be PyTorch tensors.")
    else:
        raise TypeError("Input mask must be a numpy array or PyTorch tensor.")
    if union == 0:
        return 1.0
    return intersection / union

def calculate_mask_area(mask):
    """Return the sum of all elements of ``mask``.

    Note that this is a value sum, not a pixel count: a 0/255 mask yields 255
    times its foreground pixel count. Pass ``mask > 0`` to count pixels.

    Args:
        mask: ``numpy.ndarray`` or ``torch.Tensor``.

    Returns:
        The NumPy sum for arrays, or a Python number (``.item()``) for tensors.

    Raises:
        TypeError: If ``mask`` is neither an ndarray nor a tensor.
    """
    if isinstance(mask, torch.Tensor):
        return torch.sum(mask).item()
    if isinstance(mask, np.ndarray):
        return np.sum(mask)
    raise TypeError("Input mask must be a numpy array or PyTorch tensor.")

# Shape descriptor of a mask: circularity of its largest outer contour.
def calculate_mask_shape(mask):
    """Return the circularity ``4 * pi * area / perimeter**2`` of ``mask``.

    The measure is taken on the largest (by area) outer contour. Previously
    ``contours[0]`` was used, which is whichever contour OpenCV happens to
    return first, not necessarily the main region of a multi-region mask.

    Args:
        mask: 2-D mask; any non-zero value is treated as foreground.

    Returns:
        Circularity (1.0 for a perfect disc, smaller for elongated or ragged
        shapes). Returns 0 when the mask has no contour or the chosen contour
        has zero perimeter.
    """
    # findContours needs a single-channel 8-bit image. Binarising also lets
    # bool masks through, which cv2 would otherwise reject.
    binary = (np.asarray(mask) > 0).astype(np.uint8)
    contours, _ = cv2.findContours(binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if len(contours) == 0:
        return 0
    contour = max(contours, key=cv2.contourArea)
    area = cv2.contourArea(contour)
    perimeter = cv2.arcLength(contour, True)
    if perimeter == 0:
        return 0
    return 4 * np.pi * area / (perimeter * perimeter)
  
def calculate_max_Loose_iou_points(masks, gt, Loose_IoU=False,
                                   area_threshold=5, circularity_threshold=0.8):
    """Return the best (optionally "loose") IoU among candidate masks, and its index.

    With ``Loose_IoU`` enabled, a candidate scores 0 instead of its IoU in
    either of two cases:

    - its foreground pixel count differs from the GT's by more than
      ``area_threshold`` times the GT count;
    - its circularity (``calculate_mask_shape``) differs from the GT's by more
      than ``circularity_threshold``.

    Args:
        masks: Iterable of candidate NumPy masks, e.g. the stacked masks
            returned by ``SamPredictor.predict``. Each has the same shape as
            ``gt``.
        gt: Ground-truth NumPy mask; any non-zero pixel is foreground.
        Loose_IoU: Enable the area/shape gate described above.
        area_threshold: Maximum allowed ``|pred_area - gt_area| / gt_area``.
        circularity_threshold: Maximum allowed circularity difference.

    Returns:
        ``(max_iou, max_iou_index)``, or ``(0, 0)`` when ``masks`` is empty.
    """
    if Loose_IoU:
        # GT statistics are the same for every candidate, so compute them once.
        # Areas are foreground pixel COUNTS, not value sums. Candidates are
        # scaled to 0/255 below while one-hot GT masks are 0/1, so summing
        # inflated the predicted area about 255x and the gate rejected nearly
        # every candidate for such GTs.
        true_area = float(calculate_mask_area(np.asarray(gt) > 0))
        true_circularity = calculate_mask_shape(gt)

    ious = []
    for mask in masks:
        mask = mask.astype(np.uint8) * 255
        if Loose_IoU:
            predicted_area = float(calculate_mask_area(mask > 0))
            predicted_circularity = calculate_mask_shape(mask)
            if true_area == 0:
                # Empty GT: avoid dividing by zero; any predicted foreground mismatches.
                area_mismatch = predicted_area > 0
            else:
                area_mismatch = abs(predicted_area - true_area) / true_area > area_threshold
            shape_mismatch = abs(predicted_circularity - true_circularity) > circularity_threshold
            iou = 0 if (area_mismatch or shape_mismatch) else calculate_iou(mask, gt)
        else:
            iou = calculate_iou(mask, gt)
        ious.append(iou)

    if not ious:
        return 0, 0
    return np.max(ious), np.argmax(ious)

def calculate_max_Loose_iou_bbox(masks, gt, Loose_IoU=False,
                                 area_threshold=5, circularity_threshold=0.5):
    """Return the (optionally "loose") IoU of the union of box-prompt masks against a GT mask.

    Despite the name, no maximum is taken. All predicted masks are merged
    (pixel-wise OR over dim 0) and that single union mask is scored.

    With ``Loose_IoU`` enabled, the union scores 0 in either of two cases:

    - its foreground pixel count differs from the GT's by more than
      ``area_threshold`` times the GT count;
    - its circularity differs from the GT's by more than
      ``circularity_threshold``.

    Args:
        masks: Torch tensor with the per-box predictions stacked along dim 0.
            Only index 0 of the merged tensor's first axis is scored. This
            presumes a (B, 1, H, W) layout as produced by
            ``SamPredictor.predict_torch(..., multimask_output=False)``;
            TODO confirm for other callers.
        gt: Ground-truth NumPy mask (H, W); any non-zero pixel is foreground.
        Loose_IoU: Enable the area/shape gate described above.
        area_threshold: Maximum allowed ``|pred_area - gt_area| / gt_area``.
        circularity_threshold: Maximum allowed circularity difference.

    Returns:
        The IoU from ``calculate_iou``, or 0 when the loose gate rejects the
        prediction.
    """
    # Pixel-wise union across boxes. Clamping the sum to [0, 1] before the
    # ``> 0`` test, as was done before, never changes the result.
    merged_mask = torch.sum(masks, dim=0) > 0
    mask = merged_mask[0].cpu().numpy().astype(np.uint8) * 255
    if not Loose_IoU:
        return calculate_iou(mask, gt)

    # Count foreground pixels rather than summing values. ``mask`` is 0/255
    # while one-hot GT masks are 0/1, so summing made the predicted area about
    # 255x too large and the area gate rejected almost every prediction.
    predicted_area = float(calculate_mask_area(mask > 0))
    true_area = float(calculate_mask_area(np.asarray(gt) > 0))
    predicted_circularity = calculate_mask_shape(mask)
    true_circularity = calculate_mask_shape(gt)

    if true_area == 0:
        # Empty GT: avoid dividing by zero; any predicted foreground mismatches.
        area_mismatch = predicted_area > 0
    else:
        area_mismatch = abs(predicted_area - true_area) / true_area > area_threshold
    shape_mismatch = abs(predicted_circularity - true_circularity) > circularity_threshold
    if area_mismatch or shape_mismatch:
        return 0
    return calculate_iou(mask, gt)

def get_random_points(GT, foreground_num, background_num):
    """Randomly sample positive/negative point prompts from a mask.

    Foreground (``GT > 0``) points are drawn with replacement. Background
    (``GT == 0``) points are drawn without replacement, so ``background_num``
    must not exceed the number of background pixels.

    Args:
        GT: 2-D mask indexed as ``GT[y, x]``.
        foreground_num: Number of positive points (may be 0).
        background_num: Number of negative points (may be 0).

    Returns:
        ``(input_point, input_label)``:

        - ``input_point`` is an integer array of shape
          ``(foreground_num + background_num, 2)`` holding (x, y) coordinates,
          positives first.
        - ``input_label`` is a float array with 1.0 for positive and 0.0 for
          negative points.

    Raises:
        ValueError: If points are requested from a region with no pixels, or
            more background points than background pixels are requested.
    """
    fg_y, fg_x = np.where(GT > 0)
    bg_y, bg_x = np.where(GT == 0)
    if foreground_num > 0 and fg_y.size == 0:
        raise ValueError("GT has no foreground pixels to sample positive points from.")
    if background_num > 0 and bg_y.size == 0:
        raise ValueError("GT has no background pixels to sample negative points from.")

    # Same RNG calls, in the same order, as before: results for a given seed are unchanged.
    fg_idx = np.random.choice(fg_y.size, size=foreground_num, replace=True)
    bg_idx = np.random.choice(bg_y.size, size=background_num, replace=False)

    # np.stack keeps an (0, 2) shape when a count is 0, so the concatenation
    # below also works for foreground_num == 0 (it used to fail).
    foreground_pixels = np.stack([fg_x[fg_idx], fg_y[fg_idx]], axis=1)
    background_pixels = np.stack([bg_x[bg_idx], bg_y[bg_idx]], axis=1)

    input_point = np.concatenate([foreground_pixels, background_pixels], axis=0)
    input_label = np.concatenate([np.ones(foreground_num), np.zeros(background_num)], axis=0)
    return input_point, input_label

def get_random_points_from_bbox(GT, foreground_num, background_num):
    """Sample point prompts inside enlarged bounding boxes of the GT regions.

    Each outer contour of ``GT`` gives a bounding box. The box is enlarged by
    10% in width and height about its centre and clipped to the image. Boxes
    with an area below 50 px are skipped, unless every box is that small; in
    that case the un-enlarged boxes are used.

    Within every box, ``foreground_num`` positive (``GT > 0``) and
    ``background_num`` negative (``GT == 0``) pixels are drawn with
    replacement.

    Args:
        GT: Single-channel uint8 mask indexed as ``GT[y, x]``.
        foreground_num: Positive points per box.
        background_num: Negative points per box.

    Returns:
        ``(input_point, input_label)``:

        - ``input_point`` is an integer array of (x, y) rows: all positives
          (box by box) followed by all negatives.
        - ``input_label`` holds 1.0 for positives and 0.0 for negatives.

    Raises:
        ValueError: If ``GT`` has no foreground region, or a box lacks the
            pixels needed for the requested samples.
    """
    img_h, img_w = GT.shape[:2]
    contours, _ = cv2.findContours(GT, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    boxes = [cv2.boundingRect(contour) for contour in contours]
    if not boxes:
        raise ValueError("GT contains no foreground region to sample points from.")

    boxes_new = []
    for x0, y0, w, h in boxes:
        if w * h < 50:
            continue
        w_new = int(1.1 * w)
        h_new = int(1.1 * h)
        # Grow about the centre (the box used to grow only toward the
        # bottom-right), then clip. Clipping also keeps the slice below from
        # starting at a negative index, which NumPy would wrap around.
        left = x0 - (w_new - w) // 2
        top = y0 - (h_new - h) // 2
        right = min(left + w_new, img_w)
        bottom = min(top + h_new, img_h)
        left, top = max(left, 0), max(top, 0)
        boxes_new.append((left, top, right - left, bottom - top))
    if not boxes_new:
        boxes_new = boxes

    foreground_pixels = []
    background_pixels = []
    for x0, y0, w, h in boxes_new:
        crop = GT[y0:y0 + h, x0:x0 + w]
        fg_y, fg_x = np.where(crop > 0)
        bg_y, bg_x = np.where(crop == 0)
        fg_idx = np.random.choice(fg_y.size, size=foreground_num, replace=True)
        bg_idx = np.random.choice(bg_y.size, size=background_num, replace=True)
        # Shift crop-local coordinates back to image coordinates.
        foreground_pixels.append(np.stack([fg_x[fg_idx] + x0, fg_y[fg_idx] + y0], axis=1))
        background_pixels.append(np.stack([bg_x[bg_idx] + x0, bg_y[bg_idx] + y0], axis=1))

    n_boxes = len(boxes_new)
    input_point = np.concatenate(foreground_pixels + background_pixels, axis=0)
    input_label = np.concatenate(
        [np.ones(foreground_num * n_boxes), np.zeros(background_num * n_boxes)], axis=0)
    return input_point, input_label

def get_bbox(GT, scale=0, device=None):
    """Return one box per outer contour of ``GT``, rescaled about its centre.

    Args:
        GT: Single-channel uint8 mask, as required by ``cv2.findContours``.
        scale: Relative change of box width and height. Each side length is
            multiplied by ``1 + scale`` around the box centre (negative values
            shrink the box).
        device: Torch device for the returned tensor. ``None`` (the default)
            uses the module-level ``predictor.device``, as before.

    Returns:
        Float tensor of shape (K, 4) with rows (left, top, right, bottom).

    Raises:
        ValueError: If ``GT`` has no foreground contour.
    """
    contours, _ = cv2.findContours(GT, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:
        # This used to print(file_name), a notebook-loop global, which raised
        # NameError whenever the helper was called outside those loops.
        raise ValueError('ori_boxes is empty: GT contains no foreground contour')
    if device is None:
        device = predictor.device

    # cv2.boundingRect gives (x, y, w, h); convert to (left, top, right, bottom).
    corners = []
    for contour in contours:
        x, y, w, h = cv2.boundingRect(contour)
        corners.append([x, y, x + w, y + h])
    boxes = torch.tensor(corners, device=device)

    centers = (boxes[:, :2] + boxes[:, 2:]) / 2
    # Half of the scaled (width, height) of every box.
    half_sizes = (boxes[:, 2:] - boxes[:, :2]) * (1 + scale) / 2
    return torch.cat([centers - half_sizes, centers + half_sizes], dim=1)



In [None]:
import sys
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

# Load a Segment Anything (SAM) model and wrap it in a SamPredictor.
# Switch checkpoints by changing `model_type` (vit_b is used below).
weights_dir = "/home/ubuntu/8TDisk/tao/HCX/GroundedSAM-zero-shot-anomaly-detection/weights"
sam_checkpoint_files = {
    "vit_h": "sam_vit_h_4b8939.pth",
    "vit_l": "sam_vit_l_0b3195.pth",
    "vit_b": "sam_vit_b_01ec64.pth",
}
model_type = "vit_b"
sam_checkpoint = os.path.join(weights_dir, sam_checkpoint_files[model_type])

# Fall back to the CPU when no CUDA device is available, instead of failing in sam.to().
device = "cuda" if torch.cuda.is_available() else "cpu"

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

# SamPredictor.set_image() computes an image embedding once; later predict()
# calls on the same image reuse that embedding.
predictor = SamPredictor(sam)



In [None]:
# Experiment: point prompts on binary defect masks.
# For every prompt configuration [n_foreground, n_background] in task_list,
# random points are sampled from each image's GT mask. SAM predicts three
# candidate masks and the best loose IoU is recorded per dataset. The sweep is
# repeated for 5 rounds because the point sampling is random.

# Images are "<Dataset>_*.<ext>"; their masks are "<stem>_mask.png".
folder_path = "/home/ubuntu/8TDisk/tao/HCX/GroundedSAM-zero-shot-anomaly-detection/jj2233/dataset_reg_other/"
dataset_lists = ['AITEX', 'BSData', 'CrackForest', 'KSDD', 'MTD', 'RSDDs']
task_list = [[1, 0], [2, 0], [4, 0], [16, 0], [1, 1], [2, 2], [4, 4], [16, 16]]
img_format = ['.png', '.jpg', '.jpeg', '.PNG', '.JPG', '.JPEG', '.bmp', '.BMP']

# List the folder once. Re-listing it on every iteration is O(n^2) and relies
# on os.listdir returning the same order each time.
all_files = os.listdir(folder_path)

for times in range(1, 6):
    task_iou_list = [[[] for _ in range(len(task_list))] for _ in range(len(dataset_lists))]
    for p, task in enumerate(task_list):
        for i in trange(len(all_files)):
            file_name = all_files[i]
            stem, ext = os.path.splitext(file_name)
            # Only original images; skip masks and other files.
            if ext not in img_format or '_mask' in file_name:
                continue
            image_path = os.path.join(folder_path, file_name)
            # splitext instead of file_name[:-4], which broke on 4-character
            # extensions such as ".jpeg".
            mask_path = os.path.join(folder_path, stem + '_mask.png')
            image = cv2.imread(image_path)
            GT = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
            if image is None or GT is None:
                print(f'skip {file_name}: cannot read image or mask')
                continue
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            # Compute SAM's image embedding; the predict() call below reuses it.
            predictor.set_image(image)
            input_point, input_label = get_random_points(GT, task[0], task[1])
            masks, scores, logits = predictor.predict(
                point_coords=input_point,
                point_labels=input_label,
                multimask_output=True,  # return several candidate masks
            )
            iou, num = calculate_max_Loose_iou_points(masks, GT, Loose_IoU=True)
            task_iou_list[dataset_lists.index(file_name.split('_')[0])][p].append(iou)
            # Visualisation (disabled):
            # plt.imshow(image)
            # show_mask(masks[num], plt.gca())
            # show_points(input_point, input_label, plt.gca())
            # plt.tight_layout()
            # plt.axis('off')
            # img_name = stem + '_vit-b_' + str(task[0]) + '_' + str(task[1]) + '_' + str(iou) + '.png'
            # plt.savefig('./point_bbox_vit_b/' + img_name, bbox_inches='tight', pad_inches=0)
            # plt.clf()
    print(f'第{times}轮:')
    for d, dataset in enumerate(dataset_lists):
        print(dataset + ':')
        for t, cfg in enumerate(task_list):
            ious = task_iou_list[d][t]
            # np.mean([]) warns and yields nan; print nan without the warning.
            print(str(cfg) + ':' + str(np.mean(ious) if ious else np.nan))
    print('*' * 20)
    







In [None]:
# Experiment: point prompts on multi-class defect masks.
# GT pixels hold class ids 1..len(object_classes)-1 (0 = background). For each
# class present in an image, points are sampled from that class's binary mask,
# and the best loose IoU of SAM's candidates is recorded per class.

folder_path = "/home/ubuntu/8TDisk/tao/HCX/GroundedSAM-zero-shot-anomaly-detection/jj2233/dataset_reg_screen/"
# object_classes = ['background','yiwu','dengzhu','xigao','tongban','lougu']
object_classes = ['background', 'class1', 'class2', 'class3', 'class4', 'class5']
# task_list=[[1,0],[2,0],[4,0],[16,0],[1,1],[2,2],[4,4],[16,16]]
task_list = [[1, 0]]
img_format = ['.png', '.jpg', '.jpeg', '.PNG', '.JPG', '.JPEG', '.bmp', '.BMP']
save_path = "/home/ubuntu/8TDisk/tao/HCX/GroundedSAM-zero-shot-anomaly-detection/jj2233/dataset_reg_screen_point_bbox_vit_b/"
# Only the first 1/subset_divisor of the directory listing is evaluated
# (quick run). Set to 1 to use every file.
subset_divisor = 100

# List the folder once instead of on every iteration.
all_files = os.listdir(folder_path)
num_classes = len(object_classes)

for times in range(1):
    task_iou_list = [[[] for _ in range(len(task_list))] for _ in range(num_classes)]
    for p, task in enumerate(task_list):
        for j in trange(len(all_files) // subset_divisor):
            file_name = all_files[j]
            stem, ext = os.path.splitext(file_name)
            # Only original images; skip masks and other files.
            if ext not in img_format or '_mask' in file_name:
                continue
            image_path = os.path.join(folder_path, file_name)
            # splitext instead of file_name[:-4], which broke on ".jpeg".
            mask_path = os.path.join(folder_path, stem + '_mask.png')
            image = cv2.imread(image_path)
            GT = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
            if image is None or GT is None:
                print(f'skip {file_name}: cannot read image or mask')
                continue
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            # Compute SAM's image embedding once; every class below reuses it.
            predictor.set_image(image)
            for class_id in range(1, num_classes):
                class_mask = (GT == class_id).astype(np.uint8)
                if class_mask.sum() == 0:
                    continue
                input_point, input_label = get_random_points(class_mask, task[0], task[1])
                masks, scores, logits = predictor.predict(
                    point_coords=input_point,
                    point_labels=input_label,
                    multimask_output=True,  # return several candidate masks
                )
                iou, num = calculate_max_Loose_iou_points(masks, class_mask, Loose_IoU=True)
                task_iou_list[class_id][p].append(iou)
                # Visualisation (disabled). plt.imshow belongs here as well:
                # calling it on every prediction without plt.clf() stacked one
                # image per prediction on the current axes and leaked memory.
                # plt.imshow(image)
                # show_mask(masks[num], plt.gca())
                # show_points(input_point, input_label, plt.gca())
                # plt.tight_layout()
                # plt.axis('off')
                # img_name = stem + '_vit-b_' + object_classes[class_id] + '_' + str(task[0]) + '_' + str(task[1]) + '_' + str(iou) + '.png'
                # img_save_path = os.path.join(save_path, img_name)
                # plt.savefig(img_save_path, bbox_inches='tight', pad_inches=0)
                # if iou > 0.9:
                #     shutil.copy(img_save_path, os.path.join(save_path, 'high_iou', img_name))
                # elif iou < 0.01:
                #     shutil.copy(img_save_path, os.path.join(save_path, 'low_iou', img_name))
                # plt.clf()
    print(f'第{times}轮:')
    for c, class_name in enumerate(object_classes):
        print(class_name + ':')
        for t, cfg in enumerate(task_list):
            ious = task_iou_list[c][t]
            # Background (class 0) is never scored; print nan without np.mean's warning.
            print(str(cfg) + ':' + str(np.mean(ious) if ious else np.nan))
    print('*' * 20)
    







In [None]:
# Experiment: box prompts on binary defect masks.
# Each outer contour of the GT becomes a box prompt, optionally rescaled by
# `scale` about its centre. SAM predicts one mask per box, and the loose IoU of
# their union against the GT is recorded per dataset and scale.

folder_path = "/home/ubuntu/8TDisk/tao/HCX/GroundedSAM-zero-shot-anomaly-detection/jj2233/dataset_reg_mvtec/"
# folder_path = '/home/ubuntu/8TDisk/tao/CJZ/industrial-5i/'
# dataset_lists = ['AITEX', 'BSData', 'CrackForest', 'KSDD', 'MTD', 'RSDDs']
dataset_lists = ['MVTEC']
# scale_list=[0,0.05,0.1,0.2,-0.05,-0.1,-0.2]
scale_list = [0]
img_format = ['.png', '.jpg', '.jpeg', '.PNG', '.JPG', '.JPEG', '.bmp', '.BMP']

# List the folder once instead of on every iteration.
all_files = os.listdir(folder_path)

for times in range(1, 2):
    scale_iou_list = [[[] for _ in range(len(scale_list))] for _ in range(len(dataset_lists))]
    for p, scale in enumerate(scale_list):
        for i in trange(len(all_files)):
            file_name = all_files[i]
            stem, ext = os.path.splitext(file_name)
            # Only original images; skip masks and other files.
            if ext not in img_format or '_mask' in file_name:
                continue
            image_path = os.path.join(folder_path, file_name)
            # splitext instead of file_name[:-4], which broke on ".jpeg".
            mask_path = os.path.join(folder_path, stem + '_mask.png')
            image = cv2.imread(image_path)
            GT = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
            if image is None or GT is None:
                print(f'skip {file_name}: cannot read image or mask')
                continue
            if not np.any(GT):
                # Defect-free image: there is no contour to build a box prompt
                # from (get_bbox raises ValueError on an empty mask).
                continue
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            # Compute SAM's image embedding; predict_torch() below reuses it.
            predictor.set_image(image)
            input_boxes = get_bbox(GT, scale=scale)
            transformed_boxes = predictor.transform.apply_boxes_torch(input_boxes, image.shape[:2])
            masks, _, _ = predictor.predict_torch(
                point_coords=None,
                point_labels=None,
                boxes=transformed_boxes,
                multimask_output=False,
            )
            iou = calculate_max_Loose_iou_bbox(masks, GT, Loose_IoU=True)
            scale_iou_list[dataset_lists.index(file_name.split('_')[0])][p].append(iou)
            # Visualisation (disabled):
            # plt.imshow(image)
            # for mask in masks:
            #     show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
            # for box in input_boxes:
            #     show_box(box.cpu().numpy(), plt.gca())
            # plt.tight_layout()
            # plt.axis('off')
            # img_name = stem + '_vit-b_bbox_scale_' + str(scale) + '_' + str(iou) + '.png'
            # plt.savefig('./point_bbox_vit_b/' + img_name, bbox_inches='tight', pad_inches=0)
            # plt.clf()
    print(f'第{times}轮:')
    for d, dataset in enumerate(dataset_lists):
        print(dataset + ':')
        for s, scale_value in enumerate(scale_list):
            ious = scale_iou_list[d][s]
            # np.mean([]) warns and yields nan; print nan without the warning.
            print(str(scale_value) + ':' + str(np.mean(ious) if ious else np.nan))
    print('*' * 20)



In [None]:
# Experiment: box prompts on multi-class defect masks, with visualisation.
# For every class present in an image, the class's contours become box
# prompts, and the loose IoU of SAM's merged prediction is recorded per class.
# An overlay image is saved to save_path and additionally copied to high_iou/
# (IoU > 0.9) or low_iou/ (IoU < 0.01).

folder_path = "/home/ubuntu/8TDisk/tao/HCX/GroundedSAM-zero-shot-anomaly-detection/jj2233/dataset_reg_screen/"
# object_classes = ['background','yiwu','dengzhu','xigao','tongban','lougu']
object_classes = ['background', 'class1', 'class2', 'class3', 'class4', 'class5']
# scale_list=[0,0.05,0.1,0.2,-0.05,-0.1,-0.2]
scale_list = [0]
img_format = ['.png', '.jpg', '.jpeg', '.PNG', '.JPG', '.JPEG', '.bmp', '.BMP']
save_path = "/home/ubuntu/8TDisk/tao/HCX/GroundedSAM-zero-shot-anomaly-detection/jj2233/dataset_reg_screen_point_bbox_vit_b/"
# Only the first 1/subset_divisor of the directory listing is evaluated
# (quick run). Set to 1 to use every file.
subset_divisor = 10

# savefig and shutil.copy fail if the destination folders do not exist.
for sub_dir in ('', 'high_iou', 'low_iou'):
    os.makedirs(os.path.join(save_path, sub_dir), exist_ok=True)

# List the folder once instead of on every iteration.
all_files = os.listdir(folder_path)
num_classes = len(object_classes)

for times in range(1):
    scale_iou_list = [[[] for _ in range(len(scale_list))] for _ in range(num_classes)]
    for p, scale in enumerate(scale_list):
        # `j` indexes files and `class_id` indexes classes; both loops used to share `i`.
        for j in trange(len(all_files) // subset_divisor):
            file_name = all_files[j]
            stem, ext = os.path.splitext(file_name)
            # Only original images; skip masks and other files.
            if ext not in img_format or '_mask' in file_name:
                continue
            image_path = os.path.join(folder_path, file_name)
            # splitext instead of file_name[:-4], which broke on ".jpeg".
            mask_path = os.path.join(folder_path, stem + '_mask.png')
            image = cv2.imread(image_path)
            GT = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
            if image is None or GT is None:
                print(f'skip {file_name}: cannot read image or mask')
                continue
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            # Compute SAM's image embedding once; every class below reuses it.
            predictor.set_image(image)
            for class_id in range(1, num_classes):
                class_mask = (GT == class_id).astype(np.uint8)
                if class_mask.sum() == 0:
                    continue
                input_boxes = get_bbox(class_mask, scale=scale)
                transformed_boxes = predictor.transform.apply_boxes_torch(input_boxes, image.shape[:2])
                masks, _, _ = predictor.predict_torch(
                    point_coords=None,
                    point_labels=None,
                    boxes=transformed_boxes,
                    multimask_output=False,
                )
                iou = calculate_max_Loose_iou_bbox(masks, class_mask, Loose_IoU=True)
                scale_iou_list[class_id][p].append(iou)

                # Draw the image, one random-colour overlay per predicted mask,
                # and the box prompts. Each is drawn and saved exactly once;
                # the image used to be drawn twice and saved twice.
                plt.imshow(image)
                for mask in masks:
                    show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
                for box in input_boxes:
                    show_box(box.cpu().numpy(), plt.gca())
                plt.tight_layout()
                plt.axis('off')
                img_name = (stem + '_vit-b_' + object_classes[class_id] + '_bbox_scale_'
                            + str(scale) + '_' + str(iou) + '.png')
                img_save_path = os.path.join(save_path, img_name)
                plt.savefig(img_save_path, bbox_inches='tight', pad_inches=0)
                if iou > 0.9:
                    shutil.copy(img_save_path, os.path.join(save_path, 'high_iou', img_name))
                elif iou < 0.01:
                    shutil.copy(img_save_path, os.path.join(save_path, 'low_iou', img_name))
                plt.clf()
    print(f'第{times}轮:')
    for c, class_name in enumerate(object_classes):
        print(class_name + ':')
        for s, scale_value in enumerate(scale_list):
            ious = scale_iou_list[c][s]
            # Background (class 0) is never scored; print nan without np.mean's warning.
            print(str(scale_value) + ':' + str(np.mean(ious) if ious else np.nan))
    print('*' * 20)





In [None]:
# Experiment: box prompts on the industrial-5i dataset.
# Images are .png/.jpg/.jpeg files; each mask is the ".bmp" file with the same
# stem. Images whose mask is empty (no defect) are skipped.

# folder_path = "/home/ubuntu/8TDisk/tao/HCX/GroundedSAM-zero-shot-anomaly-detection/jj2233/dataset_reg_mvtec/"
folder_path = '/home/ubuntu/8TDisk/tao/CJZ/industrial-5i/'
dataset_lists = ['industrial-5i']
# scale_list=[0,0.05,0.1,0.2,-0.05,-0.1,-0.2]
scale_list = [0]
# .bmp is excluded because the .bmp files in this folder are the masks.
img_format = ['.png', '.jpg', '.jpeg', '.PNG', '.JPG', '.JPEG']

# List the folder once instead of on every iteration.
all_files = os.listdir(folder_path)
# Every file in this folder belongs to the single industrial-5i dataset.
dataset_idx = dataset_lists.index('industrial-5i')

for times in range(1, 2):
    scale_iou_list = [[[] for _ in range(len(scale_list))] for _ in range(len(dataset_lists))]
    for p, scale in enumerate(scale_list):
        for i in trange(len(all_files)):
            file_name = all_files[i]
            stem, ext = os.path.splitext(file_name)
            # Only original images; skip masks and other files.
            if ext not in img_format or '_mask' in file_name:
                continue
            image_path = os.path.join(folder_path, file_name)
            # Mask = same stem with a .bmp extension. The old
            # image_path.replace('.png', '.bmp') left .jpg/.jpeg (and .PNG)
            # paths unchanged, so those images were scored against themselves.
            mask_path = os.path.join(folder_path, stem + '.bmp')
            image = cv2.imread(image_path)
            GT = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
            if image is None or GT is None:
                # A missing mask used to slip past `np.all(None == 0)` and crash in get_bbox.
                print(f'skip {file_name}: cannot read image or mask')
                continue
            if np.all(GT == 0):
                continue
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            # Compute SAM's image embedding; predict_torch() below reuses it.
            predictor.set_image(image)
            input_boxes = get_bbox(GT, scale=scale)
            transformed_boxes = predictor.transform.apply_boxes_torch(input_boxes, image.shape[:2])
            masks, _, _ = predictor.predict_torch(
                point_coords=None,
                point_labels=None,
                boxes=transformed_boxes,
                multimask_output=False,
            )
            iou = calculate_max_Loose_iou_bbox(masks, GT, Loose_IoU=True)
            scale_iou_list[dataset_idx][p].append(iou)
            # Visualisation (disabled):
            # plt.imshow(image)
            # for mask in masks:
            #     show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
            # for box in input_boxes:
            #     show_box(box.cpu().numpy(), plt.gca())
            # plt.tight_layout()
            # plt.axis('off')
            # img_name = stem + '_vit-b_bbox_scale_' + str(scale) + '_' + str(iou) + '.png'
            # plt.savefig('./point_bbox_vit_b/' + img_name, bbox_inches='tight', pad_inches=0)
            # plt.clf()
    print(f'第{times}轮:')
    for d, dataset in enumerate(dataset_lists):
        print(dataset + ':')
        for s, scale_value in enumerate(scale_list):
            ious = scale_iou_list[d][s]
            # np.mean([]) warns and yields nan; print nan without the warning.
            print(str(scale_value) + ':' + str(np.mean(ious) if ious else np.nan))
    print('*' * 20)

