In [None]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import os
import cv2
from tqdm import trange
import shutil

In [None]:
def show_anns(anns,img_name,save_path,iou_list_all):
    """Render SAM annotations as translucent colour layers and save the figure.

    Annotations are painted from the largest ``'area'`` to the smallest, each
    with a random RGB colour and an alpha of 0.35 inside its
    ``'segmentation'`` mask.  The current pyplot figure is saved to
    ``os.path.join(save_path, img_name)`` and cleared afterwards.

    Args:
        anns: list of annotation dicts with ``'area'`` and ``'segmentation'``
            (2-D mask) entries; nothing is drawn or saved when it is empty.
        img_name: output file name.
        save_path: output directory (not created here).
        iou_list_all: not used by this function; accepted so the signature
            matches ``show_anns_private``.
    """
    if not len(anns):
        return

    def _by_area(ann):
        return ann['area']

    for ann in sorted(anns, key=_by_area, reverse=True):
        seg = ann['segmentation']
        height, width = seg.shape[0], seg.shape[1]
        overlay = np.ones((height, width, 3))
        # Broadcast one random RGB triple over every pixel.
        overlay[...] = np.random.random((1, 3)).tolist()[0]
        rgba = np.dstack((overlay, seg * 0.35))
        plt.imshow(rgba)

    plt.tight_layout()
    plt.axis('off')
    out_file = os.path.join(save_path, img_name)
    plt.savefig(out_file, bbox_inches='tight', pad_inches=0)
    plt.clf()

def show_anns_private(anns,img_name,save_path,iou_list_all, high_iou_thresh=0.8, low_iou_thresh=0.01):
    """Render SAM annotations, save the figure and file a copy by IoU bucket.

    Annotations are painted from the largest ``'area'`` to the smallest, each
    with a random RGB colour and an alpha of 0.35 inside its
    ``'segmentation'`` mask.  The figure is saved to
    ``save_path/img_name``.  If any score in *iou_list_all* exceeds
    *high_iou_thresh*, a copy goes to ``save_path/high_iou/``.  Otherwise, if
    any score is below *low_iou_thresh*, a copy goes to ``save_path/low_iou/``.
    Missing directories are created.  The pyplot figure is always cleared,
    even if saving fails.

    Args:
        anns: list of annotation dicts with ``'area'`` and ``'segmentation'``
            entries; nothing is drawn or saved when it is empty.
        img_name: output file name.
        save_path: output directory.
        iou_list_all: per-class scores; NaN entries compare False against
            both thresholds and therefore never trigger a bucket.
        high_iou_thresh: strict lower bound for the ``high_iou`` bucket.
        low_iou_thresh: strict upper bound for the ``low_iou`` bucket.
    """
    if len(anns) == 0:
        return
    try:
        sorted_anns = sorted(anns, key=(lambda x: x['area']), reverse=True)
        for ann in sorted_anns:
            m = ann['segmentation']
            img = np.ones((m.shape[0], m.shape[1], 3))
            color_mask = np.random.random((1, 3)).tolist()[0]
            for i in range(3):
                img[:, :, i] = color_mask[i]
            plt.imshow(np.dstack((img, m * 0.35)))
        plt.tight_layout()
        plt.axis('off')

        # savefig / shutil.copy do not create missing directories.
        if save_path:
            os.makedirs(save_path, exist_ok=True)
        img_save_path = os.path.join(save_path, img_name)
        plt.savefig(img_save_path, bbox_inches='tight', pad_inches=0)

        if any(iou > high_iou_thresh for iou in iou_list_all):
            bucket = 'high_iou'
        elif any(iou < low_iou_thresh for iou in iou_list_all):
            bucket = 'low_iou'
        else:
            bucket = None
        if bucket is not None:
            bucket_dir = os.path.join(save_path, bucket)
            os.makedirs(bucket_dir, exist_ok=True)
            shutil.copy(img_save_path, os.path.join(bucket_dir, img_name))
    finally:
        # Clear even on failure so stale layers don't leak into the next image.
        plt.clf()

def calculate_iou(mask1, mask2):
    """Return the intersection-over-union (Jaccard index) of two masks.

    Any non-zero element counts as foreground, so 0/1, 0/255 and boolean
    masks can be mixed.  Both masks must have the same shape.

    Args:
        mask1, mask2: array-likes of identical shape.

    Returns:
        ``|mask1 AND mask2| / |mask1 OR mask2|``.  When both masks are empty
        the union is zero.  In that case 0.0 is returned instead of numpy's
        0/0 = NaN (which also emits a RuntimeWarning), so that ``np.max`` over
        several IoUs is not turned into NaN by one degenerate pair.
    """
    a = np.asarray(mask1).astype(bool)
    b = np.asarray(mask2).astype(bool)
    union = np.logical_or(a, b).sum()
    if union == 0:
        return 0.0
    intersection = np.logical_and(a, b).sum()
    return intersection / union

def calculate_mask_area(mask):
    """Return the sum of all elements of *mask*.

    This is a plain sum, not a pixel count: a 0/255 mask yields
    255 x (number of foreground pixels).

    Args:
        mask: ``np.ndarray`` or ``torch.Tensor``.

    Returns:
        The numpy sum for arrays, or a Python number (via ``.item()``) for
        tensors.

    Raises:
        TypeError: if *mask* is neither a numpy array nor a torch tensor.
    """
    if isinstance(mask, torch.Tensor):
        total = torch.sum(mask)
        return total.item()
    if isinstance(mask, np.ndarray):
        return np.sum(mask)
    raise TypeError("Input mask must be a numpy array or PyTorch tensor.")

# Shape descriptor of a mask: circularity of its largest outer contour.
def calculate_mask_shape(mask):
    """Return the circularity ``4*pi*A / P**2`` of the mask's largest outer contour.

    Args:
        mask: image passed straight to ``cv2.findContours``; presumably a
            single-channel 8-bit mask with non-zero foreground (what OpenCV
            requires for RETR_EXTERNAL) -- confirm at call sites.

    Returns:
        The circularity of the largest outer contour (by ``cv2.contourArea``).
        The formula gives 1.0 for an ideal disc and smaller values for
        elongated or irregular shapes.  Returns 0 when there is no contour or
        the perimeter is zero.
    """
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if len(contours) == 0:
        return 0
    # findContours does not order contours by size, so contours[0] would be
    # an arbitrary component of a multi-part mask; use the largest instead.
    contour = max(contours, key=cv2.contourArea)
    area = cv2.contourArea(contour)
    perimeter = cv2.arcLength(contour, True)
    if perimeter == 0:
        return 0
    return 4 * np.pi * area / (perimeter * perimeter)

def calculate_max_IoU_everything(masks, gt, Loose_IoU=False, area_threshold=5, circularity_threshold=0.5):
    """Return the best IoU (or loose IoU) between any predicted mask and the GT.

    Args:
        masks: list of SAM annotation dicts; each ``'segmentation'`` entry is
            converted to a 0/255 uint8 mask before scoring.
        gt: GT mask, scored against each prediction with ``calculate_iou``.
            With ``Loose_IoU`` it is also passed to ``cv2.findContours`` (via
            ``calculate_mask_shape``), so it should then be single-channel uint8.
        Loose_IoU: if True, a prediction scores 0 instead of its IoU when
            ``|pred_area - gt_area| / gt_area > area_threshold`` or
            ``|pred_circularity - gt_circularity| > circularity_threshold``.
            An empty GT counts as an area mismatch.
        area_threshold: relative area-difference limit used by ``Loose_IoU``.
        circularity_threshold: absolute circularity-difference limit used by
            ``Loose_IoU``.

    Returns:
        The maximum score over *masks*, or 0 when *masks* is empty.
    """
    true_area = true_circularity = None
    if Loose_IoU:
        # GT features do not depend on the prediction: compute them once.
        true_area = float(calculate_mask_area(gt))
        true_circularity = calculate_mask_shape(gt)

    ious = []
    for ann in masks:
        mask = (ann['segmentation'].astype(int) * 255).astype(np.uint8)
        if Loose_IoU:
            predicted_area = float(calculate_mask_area(mask))
            predicted_circularity = calculate_mask_shape(mask)
            # With an empty GT the relative area difference is unbounded;
            # dividing by it would raise ZeroDivisionError.
            area_mismatch = (
                true_area == 0
                or abs(predicted_area - true_area) / true_area > area_threshold
            )
            shape_mismatch = abs(predicted_circularity - true_circularity) > circularity_threshold
            iou = 0 if (area_mismatch or shape_mismatch) else calculate_iou(mask, gt)
        else:
            iou = calculate_iou(mask, gt)
        ious.append(iou)

    return np.max(ious) if ious else 0
    
def calculate_max_IoU_everything_multi_class(masks, GT, p, Loose_IoU=False):
    """Score SAM "everything" masks against each class of a label map and save a visualisation.

    For every class id ``i`` in ``1 .. len(object_classes) - 1``, the pixels
    with ``GT == i`` are split into outer connected components.  Each
    component is scored with ``calculate_max_IoU_everything``, and the class
    score is the NaN-mean of its component scores.  Classes with fewer than
    20 pixels score NaN.  Each class score is appended to
    ``everything_iou_list[i][p]``.  The rendering is saved via
    ``show_anns_private``.

    Reads notebook globals: object_classes, everything_iou_list,
    everything_mode_list, image_path, save_path, model_type.

    Args:
        masks: SAM annotation dicts from ``SamAutomaticMaskGenerator.generate``.
        GT: 2-D label map whose pixel values are class ids.
        p: index into ``everything_mode_list`` identifying the generator that
            produced *masks*.
        Loose_IoU: forwarded to ``calculate_max_IoU_everything``.

    Returns:
        One score per foreground class (NaN where the class is absent or too
        small).
    """
    iou_list_all = []
    for i in range(1, len(object_classes)):
        class_mask = (GT == i).astype(np.uint8)
        if class_mask.sum() < 20:
            # Too few pixels to score: record the class as absent.
            iou_list_all.append(np.nan)
            continue

        # One outer contour per connected component of this class.
        contours, _ = cv2.findContours(class_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        iou_list = []
        for contour in contours:
            # Fill the outer contour and intersect it with the class mask, so the
            # component keeps exactly its own pixels (as 0/255). A bounding-box
            # crop would also pick up other components that fall inside the box.
            region = np.zeros_like(class_mask)
            cv2.drawContours(region, [contour], -1, 1, thickness=cv2.FILLED)
            component = np.where((region > 0) & (class_mask > 0), 255, 0).astype(np.uint8)
            iou_list.append(calculate_max_IoU_everything(masks, component, Loose_IoU=Loose_IoU))

        class_iou = np.nanmean(iou_list)
        everything_iou_list[i][p].append(class_iou)
        iou_list_all.append(class_iou)

    # File name: <stem>_<backbone>_<mode>_<per-class scores>.png; the backbone
    # comes from the loaded model instead of a hard-coded label.
    backbone = model_type.replace('_', '-')
    scores = '_'.join('None' if np.isnan(iou) else str(round(iou, 3)) for iou in iou_list_all)
    stem = os.path.splitext(os.path.basename(image_path))[0]
    _img_name = stem + '_' + backbone + '_' + everything_mode_list[p] + '_' + scores + '.png'
    show_anns_private(masks, _img_name, save_path, iou_list_all)
    return iou_list_all

In [None]:
import sys
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor

# Load the SAM model: keep exactly one checkpoint / model_type pair active.
# sam_checkpoint = "/home/ubuntu/8TDisk/tao/HCX/GroundedSAM-zero-shot-anomaly-detection/weights/sam_vit_h_4b8939.pth"
# model_type = "vit_h"

sam_checkpoint = "/home/ubuntu/8TDisk/tao/HCX/GroundedSAM-zero-shot-anomaly-detection/weights/sam_vit_b_01ec64.pth"
model_type = "vit_b"

# sam_checkpoint = "/home/ubuntu/8TDisk/tao/HCX/GroundedSAM-zero-shot-anomaly-detection/weights/sam_vit_l_0b3195.pth"
# model_type = "vit_l"

# Use the GPU when available; a hard-coded "cuda" fails on CPU-only machines.
device = "cuda" if torch.cuda.is_available() else "cpu"

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

# Automatic ("everything") mask generators that differ only in prompt-grid
# density: points_per_side**2 points in total (32 is the library default).
mask_generator_32 = SamAutomaticMaskGenerator(sam)
mask_generator_8 = SamAutomaticMaskGenerator(
    model=sam,
    points_per_side=8,
)
mask_generator_64 = SamAutomaticMaskGenerator(
    model=sam,
    points_per_side=64,
)
mask_generator_128 = SamAutomaticMaskGenerator(
    model=sam,
    points_per_side=128,
)


In [None]:
# Everything mode on binary (single-class) GT masks.
# Dataset folder
folder_path = "/home/ubuntu/8TDisk/tao/HCX/GroundedSAM-zero-shot-anomaly-detection/jj2233/dataset_reg_mvtec/"
iou_list_32 = []
iou_list_8 = []
iou_list_64 = []
iou_list_128 = []
img_format = ['.png', '.jpg', '.jpeg', '.PNG', '.JPG', '.JPEG', '.bmp', '.BMP']
save_path = "/home/ubuntu/8TDisk/tao/HCX/GroundedSAM-zero-shot-anomaly-detection/jj2233/dataset_reg_mvtec_everything/"

# (points_per_side, generator, per-image IoU list), in the original evaluation order.
generators = [
    (32, mask_generator_32, iou_list_32),
    (8, mask_generator_8, iou_list_8),
    (64, mask_generator_64, iou_list_64),
    (128, mask_generator_128, iou_list_128),
]
# Label output files with the backbone that is actually loaded.
backbone = model_type.replace('_', '-')

# List the folder once instead of on every iteration.
file_names = os.listdir(folder_path)
for i in trange(len(file_names)):
    file_name = file_names[i]
    stem, ext = os.path.splitext(file_name)
    # Only original images; GT files carry a '_mask' suffix.
    if ext not in img_format or '_mask' in file_name:
        continue

    image_path = os.path.join(folder_path, file_name)
    # splitext instead of file_name[:-4], which breaks on '.jpeg' / '.JPEG'.
    mask_path = os.path.join(folder_path, stem + '_mask.png')
    image = cv2.imread(image_path)
    GT = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
    if image is None or GT is None:
        # cv2.imread returns None (it does not raise) for missing/unreadable files.
        print(f'skip {file_name}: cannot read image or {os.path.basename(mask_path)}')
        continue
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

    for n_points, generator, scores in generators:
        everything = generator.generate(image)
        max_iou = calculate_max_IoU_everything(everything, GT, Loose_IoU=True)
        scores.append(max_iou)
        img_name = f'{stem}_{backbone}_everything_{n_points}_{max_iou!s}.png'
        # Save inference results (disabled):
        # show_anns(everything, img_name, save_path, max_iou)

# mean (nan when no image was scored)
print('32:', np.mean(iou_list_32) if iou_list_32 else np.nan)
print('8:', np.mean(iou_list_8) if iou_list_8 else np.nan)
print('64:', np.mean(iou_list_64) if iou_list_64 else np.nan)
print('128:', np.mean(iou_list_128) if iou_list_128 else np.nan)




In [None]:
# PRIVATE_DATASETS
# Everything mode on multi-class label-map GT masks.
# Dataset folder
folder_path = "/home/ubuntu/8TDisk/tao/HCX/GroundedSAM-zero-shot-anomaly-detection/jj2233/dataset_reg_screen/"
everything_mode_list = ['everything_32','everything_8','everything_64','everything_128']
# object_classes = ['background','yiwu','dengzhu','xigao','tongban','lougu']
object_classes = ['background','class1','class2','class3','class4','class5']
img_format = ['.png', '.jpg', '.jpeg', '.PNG', '.JPG', '.JPEG', '.bmp', '.BMP']
save_path = "/home/ubuntu/8TDisk/tao/HCX/GroundedSAM-zero-shot-anomaly-detection/jj2233/dataset_reg_screen_everything/"

# Same order as everything_mode_list: the position is the `p` index passed to
# calculate_max_IoU_everything_multi_class.
mode_generators = [mask_generator_32, mask_generator_8, mask_generator_64, mask_generator_128]

# List the folder once instead of on every iteration (same order as before).
file_names = os.listdir(folder_path)
for times in range(1):
    everything_iou_list = [[[] for _ in range(len(everything_mode_list))] for _ in range(len(object_classes))]
    # Only the first tenth of the directory entries (images and masks alike) is scanned.
    for i in trange(len(file_names) // 10):
        file_name = file_names[i]
        stem, ext = os.path.splitext(file_name)
        # Only original images; GT files carry a '_mask' suffix.
        if ext not in img_format or '_mask' in file_name:
            continue

        # image_path / file_name are globals read by the multi-class scorer.
        image_path = os.path.join(folder_path, file_name)
        # splitext instead of file_name[:-4], which breaks on '.jpeg' / '.JPEG'.
        mask_path = os.path.join(folder_path, stem + '_mask.png')
        image = cv2.imread(image_path)
        GT = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if image is None or GT is None:
            print(f'skip {file_name}: cannot read image or {os.path.basename(mask_path)}')
            continue
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

        # Scores are accumulated into everything_iou_list by the scorer itself.
        for p, generator in enumerate(mode_generators):
            everything = generator.generate(image)
            calculate_max_IoU_everything_multi_class(everything, GT, p, Loose_IoU=False)

    print(f'第{times}轮:')
    for class_name, per_mode in zip(object_classes, everything_iou_list):
        print(class_name + ':')
        for mode_name, scores in zip(everything_mode_list, per_mode):
            # Classes never scored (e.g. background) print nan without numpy's
            # "Mean of empty slice" warning.
            mean_iou = np.mean(scores) if scores else np.nan
            print(str(mode_name) + ':' + str(mean_iou))
    print('*'*20 )




In [None]:
# single image inference

# Image path; the GT path is derived from it.
image_path = "/home/ubuntu/8TDisk/tao/HCX/GroundedSAM-zero-shot-anomaly-detection/jj2233/dataset_reg_private/miniLED_raw_1.png"

file_name = os.path.basename(image_path)
save_path = '/home/ubuntu/8TDisk/tao/HCX/GroundedSAM-zero-shot-anomaly-detection/jj2233/dataset_reg_mvtec_everything/'
stem, ext = os.path.splitext(file_name)
# Build the GT path from the stem: str.replace('.png', ...) did nothing for
# non-PNG inputs, so the image itself was read back as its own GT.
mask_path = os.path.join(os.path.dirname(image_path), stem + '_mask.png')

# Read image and GT; cv2.imread returns None instead of raising on failure.
image = cv2.imread(image_path)
GT = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
if image is None or GT is None:
    raise FileNotFoundError(f'cannot read {image_path} or {mask_path}')
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)

# Label output files with the backbone that is actually loaded.
backbone = model_type.replace('_', '-')
for n_points, generator in [
    (32, mask_generator_32),
    (8, mask_generator_8),
    (64, mask_generator_64),
    (128, mask_generator_128),
]:
    everything = generator.generate(image)
    max_iou = calculate_max_IoU_everything(everything, GT, Loose_IoU=False)
    img_name = f'{stem}_{backbone}_everything_{n_points}_{max_iou!s}.png'
    # Save the inference result.
    show_anns(everything, img_name, save_path, max_iou)







SamAutomaticMaskGenerator()参数详解：

- model (Sam):用于掩膜预测的SAM模型。
- points_per_side (int or None): 沿图像每条边采样的点数，采样点总数为 points_per_side**2。若为 None，则必须通过 point_grids 显式提供采样点。
- points_per_batch (int): 模型每批同时处理的点数。数值越大可能越快，但会占用更多 GPU 显存。
- pred_iou_thresh (float): 过滤阈值，取值范围 [0,1]，依据模型预测的掩膜质量分数过滤掩膜。
- stability_score_thresh (float): 过滤阈值，取值范围 [0,1]，依据掩膜在二值化截断值变化时的稳定性过滤掩膜。
- stability_score_offset (float): 计算稳定性分数时截断值的偏移量。
- box_nms_thresh (float): 非极大值抑制（NMS）过滤重复掩膜时使用的框 IoU 阈值。
- crop_n_layers (int): 若 >0，将在图像的裁剪区域上再次运行掩膜预测。该值设置运行的层数，其中第 i_layer 层有 2**i_layer 个图像裁剪。
- crop_nms_thresh (float): 非极大值抑制在不同裁剪之间过滤重复掩膜时使用的框 IoU 阈值。
- crop_overlap_ratio (float): 设置裁剪之间的重叠程度。在第一个裁剪层中，裁剪之间按图像边长的这一比例重叠；之后裁剪数更多的层会相应缩小该重叠。
- crop_n_points_downscale_factor (int): 第 n 层中每条边采样的点数按 crop_n_points_downscale_factor**n 缩减。
- point_grids (list(np.ndarray) or None): 用于采样的显式点网格列表，坐标归一化到 [0,1]。列表中的第 n 个网格用于第 n 个裁剪层。与 points_per_side 互斥。
- min_mask_region_area (int): 若 >0，将进行后处理，移除掩膜中面积小于 min_mask_region_area 的不连通区域和孔洞。需要 opencv。
- output_mode (str): 掩膜的返回形式，可以是 'binary_mask'、'uncompressed_rle' 或 'coco_rle'。'coco_rle' 需要 pycocotools。对于高分辨率图像，'binary_mask' 可能占用大量内存。


SamAutomaticMaskGenerator()参数默认值：
model: Sam,
points_per_side: Optional[int] = 32,
points_per_batch: int = 64,
pred_iou_thresh: float = 0.88,
stability_score_thresh: float = 0.95,
stability_score_offset: float = 1.0,
box_nms_thresh: float = 0.7,
crop_n_layers: int = 0,
crop_nms_thresh: float = 0.7,
crop_overlap_ratio: float = 512 / 1500,
crop_n_points_downscale_factor: int = 1,
point_grids: Optional[List[np.ndarray]] = None,
min_mask_region_area: int = 0,
output_mode: str = "binary_mask",