# 初始化

## 基础导入

In [None]:
%cd '/home/xiaobo/Project/Trans-WSSS'
%pwd
import os.path as osp
import sys
import warnings
from typing import Any

import alchemy_cat.data.plugins.augers as au
import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
from PIL import Image
from alchemy_cat.acplot import col_all
from alchemy_cat.alg import size2HW
from alchemy_cat.contrib.tasks.wsss.viz import viz_cam
from alchemy_cat.contrib.voc import VOC_CLASSES
from alchemy_cat.py_tools import ADict
from skimage.feature import peak_local_max
from skimage.measure import label, regionprops

from libs.seeding.score import cam2score_cuda
from utils.resize import resize_cam_cuda

sys.path.append("others/segment_anything")
from segment_anything import sam_model_registry, SamPredictor

## mask可视化函数

In [None]:
def show_mask(mask, ax, random_color=False):
    """Overlay ``mask`` on ``ax`` as a translucent RGBA layer.

    The layer is a fixed blue (30, 144, 255) at alpha 0.6. With ``random_color`` it uses a
    random RGB at the same alpha. Only the last two axes of ``mask`` are used as (h, w).
    """
    rgba = (np.append(np.random.random(3), 0.6) if random_color
            else np.array([30 / 255, 144 / 255, 255 / 255, 0.6]))
    height, width = mask.shape[-2:]
    ax.imshow(mask.reshape(height, width, 1) * rgba[None, None, :])


def show_points(coords, labels, ax, marker_size=375):
    """Scatter point prompts on ``ax`` as stars.

    Label 1 points are drawn green and label 0 points red. ``coords[:, 0]`` is drawn on the
    x axis and ``coords[:, 1]`` on the y axis.
    """
    star_style = dict(marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    for wanted_label, colour in ((1, 'green'), (0, 'red')):
        picked = coords[labels == wanted_label]
        ax.scatter(picked[:, 0], picked[:, 1], color=colour, **star_style)


def show_box(box, ax):
    """Draw the outline of an ``(x0, y0, x1, y1)`` box on ``ax`` in green with a transparent fill."""
    x0, y0, x1, y1 = box[0], box[1], box[2], box[3]
    outline = plt.Rectangle((x0, y0), x1 - x0, y1 - y0,
                            edgecolor='green', facecolor=(0, 0, 0, 0), lw=2)
    ax.add_patch(outline)

## CAM/score可视化函数

In [None]:
def viz_fg_cam(img: np.ndarray, lb: np.ndarray,
               cam: np.ndarray | torch.Tensor,
               suptitle: str, with_bg: bool=False,
               input_points: list[np.ndarray]=None, input_labels: list[np.ndarray]=None,
               cls_ids: np.ndarray | None = None, img_name: str | None = None):
    """Plot per-channel CAM/score maps over ``img`` and optionally overlay point prompts.

    Args:
        img: Image shown under the maps.
        lb: Label map forwarded to ``viz_cam``.
        cam: Per-channel maps (tensor or array). Tensors are moved to CPU NumPy. One channel
            per foreground class, preceded by a background channel when ``with_bg``.
        suptitle: Figure title. Skipped when empty.
        with_bg: Insert a 'background' name before the foreground class names.
        input_points: Optional per-channel point prompts, as xy coords.
        input_labels: Optional per-channel point labels (1 = fg, 0 = bg). ``None`` entries
            in either list (as produced for ill-conditioned heat maps) are skipped.
        cls_ids: Foreground class indices into ``VOC_CLASSES[1:]``. Defaults to the notebook
            global ``fg_cls``.
        img_name: Image id passed to ``viz_cam``. Defaults to the notebook global ``img_id``.
    """
    if torch.is_tensor(cam):
        cam = cam.detach().cpu().numpy()
    # Fall back to notebook globals so existing call sites keep working unchanged.
    if cls_ids is None:
        cls_ids = fg_cls
    if img_name is None:
        img_name = img_id

    pos_names = ['dummy'] + (['background'] if with_bg else []) + [VOC_CLASSES[1:][cls] for cls in cls_ids]

    plt.figure(dpi=250)

    viz_cam(fig=plt.gcf(),
            img_id=img_name, img=img, label=lb,
            cls_in_label=np.ones(len(pos_names), dtype=np.uint8),
            cam=cam,
            cls_names=pos_names,
            get_row_col=col_all)

    if input_points is not None:
        # NOTE(review): assumes viz_cam lays axes out so that every second axis (starting at 1)
        # holds one channel's map — confirm against viz_cam's layout.
        for input_point, input_label, ax in zip(input_points, input_labels, plt.gcf().axes[1::2], strict=True):
            if input_point is None:  # No prompt was extracted for this channel.
                continue
            show_points(input_point, input_label, ax, marker_size=75)

    if suptitle:
        plt.suptitle(suptitle, fontsize=8)
    plt.tight_layout()
    plt.show()


def cal_sam_input_size(sam_transform, ori_size: Any, low_res: bool=True) -> tuple[int, int]:
    """Return the (h, w) an image of ``ori_size`` occupies inside SAM's input frame.

    The size is computed by ``sam_transform.get_preprocess_shape`` for
    ``sam_transform.target_length``. With ``low_res`` both sides are floor-divided by 4,
    giving the extent inside the low-resolution mask frame.
    """
    in_h, in_w = sam_transform.get_preprocess_shape(*size2HW(ori_size), sam_transform.target_length)
    return (in_h // 4, in_w // 4) if low_res else (in_h, in_w)


def transform_img_lb(sam_transform, img: np.ndarray, lb: np.ndarray, low_res: bool=True) -> tuple[np.ndarray, np.ndarray]:
    """Resize an image/label pair to SAM's input extent and pad it on the right/bottom.

    The pair is resized to the extent given by ``cal_sam_input_size`` (1/4 scale when
    ``low_res``). It is then padded to ``target_length`` (or ``target_length // 4``),
    filling the image with 0 and the label with 255.
    """
    divisor = 4 if low_res else 1
    target_hw = cal_sam_input_size(sam_transform, img.shape[:2], low_res=low_res)
    scaled_img, scaled_lb = au.scale_img_label(target_hw, img, lb, align_corner=False, PIL_mode=Image.BILINEAR)
    padded_img, padded_lb = au.pad_img_label(scaled_img, scaled_lb,
                                             pad_img_to=sam_transform.target_length // divisor,
                                             img_pad_val=0, ignore_label=255, pad_location='right-bottom')
    return padded_img, padded_lb

# 全局数据

# 读取模型

In [None]:
# Build SAM ViT-H from its checkpoint on the GPU and wrap it in a SamPredictor. Later cells
# use `predictor`, `sam_transform`, `sam_img_size` and `low_res_mask_size`.
sam_checkpoint = "pretrains/SAM/sam_vit_h_4b8939.pth"
model_type = "vit_h"

device = "cuda:0"

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

predictor = SamPredictor(sam)

# Resize transform owned by the predictor. cal_sam_input_size / transform_img_lb read its
# get_preprocess_shape and target_length.
sam_transform = predictor.transform

# Side length of the image encoder input. The low-res mask side is a quarter of it,
# matching the `// 4` used by cal_sam_input_size(low_res=True).
sam_img_size = predictor.model.image_encoder.img_size
low_res_mask_size = sam_img_size // 4

print(f'sam_img_size: {sam_img_size}, low_res_mask_size: {low_res_mask_size}')

# 读取数据

## 读取图片

In [None]:
# * Load and display the image.
img_id = '2007_000170'
# img_id = '2007_003205'
img_path = f'datasets/VOC2012/JPEGImages/{img_id}.jpg'
img = cv2.imread(img_path)
if img is None:  # cv2.imread signals a missing/unreadable file by returning None instead of raising.
    raise FileNotFoundError(f"cv2.imread could not read image: {img_path}")
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)  # OpenCV loads BGR; the plots below expect RGB.

# Label map for the same image, passed to the visualizations below.
lb = np.asarray(Image.open(f'datasets/VOC2012/SegmentationClass/{img_id}.png'), dtype=np.uint8)

ori_h, ori_w = img.shape[:2]

# Extent of the image inside SAM's 1/4-scale (low-res mask) frame, plus the image/label
# resized and right-bottom padded into that frame.
sam_low_res_input_h, sam_low_res_input_w = cal_sam_input_size(sam_transform, img.shape[:2], low_res=True)
low_res_img, low_res_lb = transform_img_lb(sam_transform, img, lb, low_res=True)

# Bottom/right padding between the resized image and the square low-res mask.
low_res_pad_h = low_res_mask_size - sam_low_res_input_h
low_res_pad_w = low_res_mask_size - sam_low_res_input_w

plt.imshow(img)
plt.show()

print(f"ori_h: {ori_h}, ori_w: {ori_w}, "
      f"sam_low_res_input_h: {sam_low_res_input_h}, sam_low_res_input_w: {sam_low_res_input_w}, "
      f"low_res_pad_h: {low_res_pad_h}, low_res_pad_w: {low_res_pad_w}")

## 读取CAM、score、affed CAM、affed score、CRF后的score

In [None]:
# * Select the base experiment.
infer_rslt_dir = 'experiment/clip_cam/离线伪真,CI/l1/ps自cl_loss,5100/infer/final'
aff_rslt_dir = osp.join(infer_rslt_dir, 'aff2次,at_cam,att1次,·5掩阈,ce_npp,mask')
crf_rslt_dir = osp.join(aff_rslt_dir, 'seed/best')

# np.load on an .npz returns an NpzFile that keeps the archive open. Each `with` block
# closes it once the needed arrays have been read out (torch.as_tensor copies them to the GPU).

# * Load the image's raw CAM and its foreground class ids (read by viz_fg_cam as `fg_cls`).
with np.load(osp.join(infer_rslt_dir, 'cam', f'{img_id}.npz')) as loaded:
    fg_cls = loaded['fg_cls']
    cam_ori = resize_cam_cuda(torch.as_tensor(loaded['cam'], device=device), (ori_h, ori_w))
score_ori = cam2score_cuda(cam_ori, (ori_h, ori_w), resize_first=True)

viz_fg_cam(img, lb, cam_ori, 'cam_ori', with_bg=False)
viz_fg_cam(img, lb, score_ori, 'score_ori', with_bg=False)

# * Load the image's "affed" (presumably affinity-propagated) CAM.
with np.load(osp.join(aff_rslt_dir, 'cam_affed', f'{img_id}.npz')) as loaded:
    # The propagation may have been applied at the score level rather than to the CAM.
    cam_affed = resize_cam_cuda(torch.as_tensor(loaded['cam'], device=device), (ori_h, ori_w))
score_affed = cam2score_cuda(cam_affed, (ori_h, ori_w), resize_first=True)

viz_fg_cam(img, lb, cam_affed, 'cam_affed', with_bg=False)
viz_fg_cam(img, lb, score_affed, 'score_affed', with_bg=False)

# * Load the post-CRF background+foreground score.
with np.load(osp.join(crf_rslt_dir, 'data', f'{img_id}.npz')) as loaded:
    bg_fg_score = torch.as_tensor(loaded['bg_fg_score'], device=device)

viz_fg_cam(img, lb, bg_fg_score, 'bg_fg_score', with_bg=True)

# 提示法-V1

## 选取合适输入，处理为密集提示。

### 定义密集提示生成函数。

In [None]:
def const_fg_bg_ign(bg_fg_score: torch.Tensor,
                    thresh: float,
                    fg_val:float, bg_val: float, ign_val: float=0, pad_val: float=-40.,
                    low_res_input_size: tuple[int, int] | None = None,
                    mask_size: int | None = None) -> torch.Tensor:
    """Build per-class dense logit prompts for SAM from a background+foreground score map.

    ``bg_fg_score`` has channel 0 as background and channels 1.. as foreground classes.
    It is resized to ``low_res_input_size`` and each pixel is assigned to its argmax channel.
    For foreground channel ``i`` (output channel ``i - 1``), pixels are set as follows:

    * pixels seeded to ``i`` with score ``> thresh`` -> ``fg_val``
    * pixels seeded to ``i`` with score ``<= thresh`` -> ``ign_val``
    * every other pixel inside the image -> ``bg_val``
    * right/bottom padding up to ``mask_size`` -> ``pad_val``

    Args:
        low_res_input_size: (h, w) of the image inside the low-res frame. Defaults to the
            notebook globals ``(sam_low_res_input_h, sam_low_res_input_w)``.
        mask_size: Side of the square output. Defaults to the global ``low_res_mask_size``.

    Returns:
        Tensor of shape ``(C - 1, mask_size, mask_size)`` on the input's device.
    """
    if low_res_input_size is None:
        low_res_input_size = (sam_low_res_input_h, sam_low_res_input_w)
    if mask_size is None:
        mask_size = low_res_mask_size

    bg_fg_score = resize_cam_cuda(bg_fg_score, low_res_input_size)

    seed = torch.argmax(bg_fg_score, dim=0)
    low_res_logit_mask = torch.full_like(bg_fg_score[1:, :, :], bg_val)

    for i in range(1, bg_fg_score.shape[0]):
        is_seed = seed == i
        conf_mask = bg_fg_score[i] > thresh
        low_res_logit_mask[i - 1, is_seed & conf_mask] = fg_val
        low_res_logit_mask[i - 1, is_seed & ~conf_mask] = ign_val

    h, w = low_res_logit_mask.shape[-2:]
    # Pad only to the right/bottom with pad_val. This matches SAM's right-bottom padding,
    # so pixel coordinates need no offset.
    low_res_logit_mask = F.pad(low_res_logit_mask, (0, mask_size - w, 0, mask_size - h), value=pad_val)

    return low_res_logit_mask

### 得到密集提示。

In [None]:
# * Convert the score into the form SAM accepts as a dense prompt: per-class logits on a
#   low_res_mask_size x low_res_mask_size grid (1x256x256 per class).
input_masks = const_fg_bg_ign(bg_fg_score,
                              thresh=0.95,
                              fg_val=4., bg_val=-10., ign_val=0., pad_val=-40.)
viz_fg_cam(low_res_img, low_res_lb, input_masks, 'low_res_logit_mask', with_bg=False)
# Insert a singleton channel axis: (num_fg_cls, S, S) -> (num_fg_cls, 1, S, S).
input_masks = input_masks[:, None, :, :]

## 提取点提示

### 定义点提示生成函数。

In [None]:
def merge_points_other_fg_as_bg(input_points: list[np.ndarray], input_labels: list[np.ndarray]) \
    -> tuple[list[np.ndarray], list[np.ndarray]]:
    fg_points = [input_point[input_label.astype(bool)] for input_point, input_label in zip(input_points, input_labels)]
    bg_points = [input_point[~input_label.astype(bool)] for input_point, input_label in zip(input_points, input_labels)]

    merged_bg_points = []
    for i in range(input_num := len(input_points)):
        my_bg_point = bg_points[i]
        others_fg_points = [fg_points[j] for j in range(input_num) if j != i]
        merged_bg_points.append(np.concatenate([my_bg_point, *others_fg_points], axis=0))

    merged_input_points, merged_input_labels = [], []
    for fg_point, merged_bg_point in zip(fg_points, merged_bg_points):
        merged_input_points.append(merged_input_point := np.concatenate([fg_point, merged_bg_point], axis=0))
        merged_input_label = np.zeros(merged_input_point.shape[0], dtype=np.int64)
        merged_input_label[:fg_point.shape[0]] = 1
        merged_input_labels.append(merged_input_label)

    return merged_input_points, merged_input_labels


def is_ill_heat_map(heat_map: torch.Tensor,
                    fg_min_area_rel: float=.05, bg_min_area_rel: float=.05,
                    fg_thresh_rel: float=0., bg_thresh_rel: float=0.) -> bool | str:
    # * 最大值小于0，最小值大于0，病态。
    max_heat = heat_map.max()
    min_heat = heat_map.min()
    if max_heat < 0 or min_heat > 0:
        return f"max_heat: {max_heat}, min_heat: {min_heat}"
    # * 前景区域过小，病态。
    if (fg_area := (heat_map > fg_thresh_rel * max_heat).sum()) < (fg_min_area := fg_min_area_rel * heat_map.numel()):
        return f"fg_area: {fg_area} < fg_min_area: {fg_min_area}"
    # * 背景区域过小，病态。
    if (bg_area := (heat_map < bg_thresh_rel * min_heat).sum()) < (bg_min_area := bg_min_area_rel * heat_map.numel()):
        return f"bg_area: {bg_area} < bg_min_area: {bg_min_area}"
    return False


def _peak_points(heat_map: torch.Tensor, thresh_rel: float, min_distance: int,
                 peak_thresh_rel, exclude_border, num_peaks, footprint, p_norm) -> np.ndarray:
    """Return local maxima of ``heat_map`` above ``thresh_rel * heat_map.max()`` as (x, y) rows."""
    thresh_abs = thresh_rel * heat_map.max()
    # F.threshold keeps values > thresh_abs and replaces the rest with thresh_abs, flattening
    # everything at/below the threshold into a floor. With threshold_abs=None, peak_local_max
    # defaults its absolute threshold to the image minimum, so that floor yields no peaks.
    clipped = F.threshold(heat_map, thresh_abs, thresh_abs)
    peaks_rc = peak_local_max(clipped.cpu().numpy(),
                              min_distance=min_distance,
                              threshold_abs=None, threshold_rel=peak_thresh_rel,
                              exclude_border=exclude_border, num_peaks=num_peaks,
                              footprint=footprint, p_norm=p_norm)
    return peaks_rc[:, ::-1]  # peak_local_max yields (row, col); flip to (x, y).


def get_point_peak_local(heat_map: torch.Tensor,
                         fg_min_area_rel: float=.05, bg_min_area_rel: float=.05,
                         fg_thresh_rel: float=0., bg_thresh_rel: float=0.,
                         min_distance_rel: float=None,
                         fg_peak_thresh_rel=None, bg_peak_thresh_rel=None,
                         exclude_border=True,
                         max_fg_points_num=np.inf, max_bg_points_num=np.inf,
                         footprint=None, p_norm=np.inf) -> tuple[np.ndarray, np.ndarray] | tuple[None, None]:
    """Pick point prompts as local maxima of ``heat_map`` (foreground) and ``-heat_map`` (background).

    Args:
        heat_map: 2-D heat map (tensor). Positive values are foreground-like, negative
            values background-like.
        fg_min_area_rel, bg_min_area_rel, fg_thresh_rel, bg_thresh_rel: Passed to
            ``is_ill_heat_map``. The ``*_thresh_rel`` values also set the floor below which
            values cannot form peaks, relative to the max of ``heat_map`` / ``-heat_map``.
        min_distance_rel: Minimum peak separation as a fraction of ``sqrt(numel)``.
            The result is clamped to at least 1. ``None`` or 0 means 1.
        fg_peak_thresh_rel, bg_peak_thresh_rel: ``threshold_rel`` for ``peak_local_max``.
        exclude_border, footprint, p_norm: Forwarded to ``peak_local_max``.
        max_fg_points_num, max_bg_points_num: ``num_peaks`` for fg / bg.

    Returns:
        ``(points, labels)``. ``points`` has shape (K, 2) in (x, y) order, foreground first.
        ``labels`` is int64 with 1 for fg and 0 for bg. Returns ``(None, None)`` with a
        warning if the map is ill-conditioned.
    """
    # * Reject degenerate maps.
    if ill_msg := is_ill_heat_map(heat_map,
                                  fg_min_area_rel, bg_min_area_rel,
                                  fg_thresh_rel, bg_thresh_rel):
        warnings.warn(f"发现病态热力图: {ill_msg}")
        return None, None

    # * Shared peak separation. Clamp to 1, the smallest separation skimage recommends,
    #   since rounding a small relative distance can give 0.
    min_distance = max(1, round(np.sqrt(heat_map.numel()) * min_distance_rel)) if min_distance_rel else 1

    # * Foreground points: peaks of the heat map.
    fg_points = _peak_points(heat_map, fg_thresh_rel, min_distance,
                             fg_peak_thresh_rel, exclude_border, max_fg_points_num, footprint, p_norm)
    # * Background points: peaks of the negated heat map.
    bg_points = _peak_points(-heat_map, bg_thresh_rel, min_distance,
                             bg_peak_thresh_rel, exclude_border, max_bg_points_num, footprint, p_norm)

    # * Assemble output: foreground first, labelled 1.
    points = np.concatenate([fg_points, bg_points], axis=0)
    labels = np.zeros(points.shape[0], dtype=np.int64)
    labels[:fg_points.shape[0]] = 1

    return points, labels


def _region_centroids(region_mask: np.ndarray, min_area: float) -> np.ndarray:
    """Return (x, y) centroids of connected components of ``region_mask`` with area >= ``min_area``.

    Components use ``connectivity=2`` (8-connectivity for 2-D). Centroids are truncated to
    int64. The result always has shape (K, 2), including K == 0.
    """
    centroids = [region.centroid[::-1]  # regionprops gives (row, col); flip to (x, y).
                 for region in regionprops(label(region_mask, connectivity=2))
                 if region.area >= min_area]
    # Without reshape an empty list becomes shape (0,), which cannot be concatenated with (K, 2).
    return np.asarray(centroids, dtype=np.int64).reshape(-1, 2)


def get_point_thresh_label_center(heat_map: torch.Tensor,
                                  fg_min_area_rel: float=.05, bg_min_area_rel: float=.05,
                                  fg_thresh_rel: float=0., bg_thresh_rel: float=0.,
                                  fg_area_thresh_rel=.05, bg_area_thresh_rel=.05,
                                  ) -> tuple[np.ndarray, np.ndarray] | tuple[None, None]:
    """Pick point prompts as centroids of thresholded foreground/background regions.

    Foreground regions are pixels with ``heat_map > fg_thresh_rel * max(heat_map)``.
    Background regions are pixels with ``-heat_map > bg_thresh_rel * max(-heat_map)``.
    Each connected component of at least ``*_area_thresh_rel * heat_map.size`` pixels
    contributes its centroid.

    NOTE: the centroid of a non-convex component may fall outside the component itself.

    Args:
        heat_map: 2-D heat map (tensor). Positive values are foreground-like, negative
            values background-like.
        fg_min_area_rel, bg_min_area_rel, fg_thresh_rel, bg_thresh_rel: Passed to
            ``is_ill_heat_map``. The ``*_thresh_rel`` values also define the regions.
        fg_area_thresh_rel, bg_area_thresh_rel: Minimum component area as a fraction of all pixels.

    Returns:
        ``(points, labels)``. ``points`` is int64 of shape (K, 2) in (x, y) order, foreground
        first. ``labels`` is int64 with 1 for fg and 0 for bg. Returns ``(None, None)`` with
        a warning if the map is ill-conditioned.
    """
    # * Reject degenerate maps.
    if ill_msg := is_ill_heat_map(heat_map,
                                  fg_min_area_rel, bg_min_area_rel,
                                  fg_thresh_rel, bg_thresh_rel):
        warnings.warn(f"发现病态热力图: {ill_msg}")
        return None, None

    heat_map = heat_map.cpu().numpy()

    # * Foreground points: centroids of large-enough high-response regions.
    fg_points = _region_centroids(heat_map > fg_thresh_rel * heat_map.max(),
                                  fg_area_thresh_rel * heat_map.size)

    # * Background points: the same on the negated map.
    inv_heat_map = -heat_map
    bg_points = _region_centroids(inv_heat_map > bg_thresh_rel * inv_heat_map.max(),
                                  bg_area_thresh_rel * heat_map.size)

    # * Assemble output: foreground first, labelled 1.
    points = np.concatenate([fg_points, bg_points], axis=0)
    labels = np.zeros(points.shape[0], dtype=np.int64)
    labels[:fg_points.shape[0]] = 1

    return points, labels

### 获取点提示。

In [None]:
# NOTE: either the raw CAM or the affed CAM (cam_affed) can serve as the heat-map prior for point prompts.
heat_maps = cam_ori
# heat_maps = cam_affed

# One heat map per foreground class (iterating the first axis). The extractor returns
# (None, None) for maps that is_ill_heat_map rejects, and those Nones are kept in the lists.
input_points, input_labels = [], []
for heat_map in heat_maps:
    # input_point, input_label = get_point_peak_local(heat_map,
    #                                                 fg_min_area_rel=.05, bg_min_area_rel=.05,
    #                                                 fg_thresh_rel=0., bg_thresh_rel=0.,
    #                                                 min_distance_rel=0.05,
    #                                                 fg_peak_thresh_rel=0.2, bg_peak_thresh_rel=0.2,
    #                                                 exclude_border=True,
    #                                                 max_fg_points_num=10, max_bg_points_num=5,
    #                                                 footprint=None, p_norm=np.inf)
    input_point, input_label = get_point_thresh_label_center(heat_map,
                                                             fg_min_area_rel=.05, bg_min_area_rel=.05,
                                                             fg_thresh_rel=0.05, bg_thresh_rel=0.2,
                                                             fg_area_thresh_rel=.001, bg_area_thresh_rel=.01)
    input_points.append(input_point)
    input_labels.append(input_label)
viz_fg_cam(img, lb, heat_maps, '', with_bg=False, input_points=input_points, input_labels=input_labels)

### 融合点提示。

In [None]:
# For each class, append every other class's foreground points as extra background points, then re-plot.
input_points, input_labels = merge_points_other_fg_as_bg(input_points, input_labels)
viz_fg_cam(img, lb, heat_maps, '', with_bg=False, input_points=input_points, input_labels=input_labels)

# SAM推理

## 输入SAM，获取分割结果

In [None]:
# * Feed the score directly into SAM as a mask and observe the output.
# NOTE(review): `spo.input_mask` is hard-coded to None below, so only point prompts reach SAM.
# `input_masks` currently only takes part in zip's strict length check. Enabling the dense
# prompt would presumably also require converting that CUDA tensor to a NumPy array before
# predictor.predict — confirm.
predictor.set_image(img)

# One ADict per foreground class holding its prompts (and later, SAM's outputs).
sam_prompts_outs = []
for input_point, input_label, input_mask in zip(input_points, input_labels, input_masks, strict=True):
    spo = ADict()
    spo.input_point = input_point
    spo.input_label = input_label
    spo.input_mask = None
    sam_prompts_outs.append(spo)

In [None]:
# * Run SAM once per class, asking for raw logits, and derive the binary mask from them.
#   SamPredictor.predict(return_logits=False) returns exactly `logits > model.mask_threshold`,
#   so a second forward pass just to get the binary mask is unnecessary.
for spo in sam_prompts_outs:
    spo.logit_mask, spo.score, spo.low_res_logit_mask = predictor.predict(
        point_coords=spo.input_point,
        point_labels=spo.input_label,
        mask_input=spo.input_mask,
        multimask_output=False,
        return_logits=True,
    )
    spo.mask = spo.logit_mask > predictor.model.mask_threshold

## 打印SAM分割可视化

In [None]:
# * Plot the prompts and the segmentation results.
for i, spo in enumerate(sam_prompts_outs):
    print(f"=============== Plot spo {i}===============")

    # Binary mask and point prompts over the image, titled with SAM's first predicted score.
    plt.figure(figsize=(10, 10))
    plt.imshow(img)
    show_mask(spo.mask, plt.gca())
    show_points(spo.input_point, spo.input_label, plt.gca())
    plt.title(f"Score: {spo.score[0]:.3f}", fontsize=18)
    plt.axis('on')
    plt.show()

    # * Plot low_res_logit_mask. transpose(1, 0, 2) + reshape tiles the (N, H, W) masks
    #   side by side horizontally.
    plt.figure(figsize=(10, 10))
    plt.imshow(spo.low_res_logit_mask.transpose(1, 0, 2).reshape(low_res_mask_size, -1))
    plt.colorbar(fraction=0.1)
    plt.axis('on')
    plt.show()

    # * Plot the full-resolution logit_mask, tiled the same way.
    plt.figure(figsize=(10, 10))
    plt.imshow(spo.logit_mask.transpose(1, 0, 2).reshape(img.shape[0], -1))
    plt.colorbar(fraction=0.025)
    plt.axis('on')
    plt.show()

## 打印SAM分割数值分析

In [None]:
# * Numeric summary of each SAM output.
for spo_idx, spo in enumerate(sam_prompts_outs):
    print(f"=============== Stat spo {spo_idx}===============")

    print(f"{spo.logit_mask.mean()=}, {spo.logit_mask.std()=}, {spo.logit_mask.min()=}, {spo.logit_mask.max()=}")

    # Mean low-res logit over the padding to the right/bottom of the image. The pad region is
    # the union of the bottom low_res_pad_h rows and the right low_res_pad_w columns; it is
    # empty when both pads are 0.
    for mask_idx, logit in enumerate(spo.low_res_logit_mask):
        pad_region = np.zeros(logit.shape, dtype=bool)
        pad_region[logit.shape[0] - low_res_pad_h:, :] = True
        pad_region[:, logit.shape[1] - low_res_pad_w:] = True
        if pad_region.any():
            print(f"mask {mask_idx} padding logit: {logit[pad_region].mean()=}")
        else:
            print(f"mask {mask_idx} padding logit: no padding")

    # Mean full-resolution logit inside / outside the binary mask.
    for mask_idx, logit in enumerate(spo.logit_mask):
        print(f"mask {mask_idx} fg logit: {logit[spo.mask[mask_idx]].mean()=}")
        print(f"mask {mask_idx} bg logit: {logit[~spo.mask[mask_idx]].mean()=}")

    # Histogram of low-res logits, one panel per mask so multi-mask outputs fit the grid.
    num_masks = len(spo.low_res_logit_mask)
    plt.figure(figsize=(10 * num_masks, 10))
    for mask_idx, logit in enumerate(spo.low_res_logit_mask):
        plt.subplot(1, num_masks, mask_idx + 1)
        plt.hist(logit.flatten(), bins=50)
        plt.yscale('log', base=10)
        plt.tick_params(axis='both', which='major', labelsize=20)
    plt.show()

# 终点