# Points Prompt for COCO format

## Import

In [None]:
import tempfile
import json
import os
from pycocotools.coco import COCO
from pycocotools import mask as maskUtils
import matplotlib.pyplot as plt
from icecream import ic
import cv2
import numpy as np
from copy import deepcopy
import uuid

## Build Predictor

In [None]:
import os
import torch
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

In [None]:
# use bfloat16 for the entire notebook
# The autocast context is entered by hand and no matching __exit__ call is
# made in this cell, so it is still in effect for the code that follows.
torch.autocast(device_type="cuda", dtype=torch.bfloat16).__enter__()

# Inspects CUDA device 0; a compute-capability major version of 8 or higher
# switches on the TF32 flags below.
if torch.cuda.get_device_properties(0).major >= 8:
    # turn on tfloat32 for Ampere GPUs (https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices)
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

In [None]:
from sam2.build_sam import build_sam2_video_predictor

# Relative path to the SAM 2 "hiera tiny" checkpoint file.
sam2_checkpoint = "../SurgicalSAM2/checkpoints/sam2_hiera_tiny.pt"
# Model config name passed to the builder together with the checkpoint.
model_cfg = "sam2_hiera_t.yaml"

# Video predictor used by every later cell (init_state, prompting, propagation).
predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)

### Helpers

In [None]:
def show_mask(mask, ax, obj_id=None, random_color=False):
    """Overlay a semi-transparent segmentation mask on ``ax``.

    Args:
        mask: Array whose last two dimensions are (height, width). It is
            reshaped to ``(height, width, 1)``, so any leading dimensions
            must have size 1.
        ax: Object exposing ``imshow`` (normally a matplotlib ``Axes``).
        obj_id: Integer used to pick a colour from the ``tab10`` palette;
            ``None`` selects the first colour. Ids are wrapped into the
            palette range, so ids that differ by the palette size share a
            colour.
        random_color: If true, ignore ``obj_id`` and use a random RGB colour.
    """
    if random_color:
        rgb = np.random.random(3)
    else:
        cmap = plt.get_cmap("tab10")
        # Wrap the id into the palette range: integer indices outside
        # [0, cmap.N) would otherwise all be clipped to the same colour.
        cmap_idx = 0 if obj_id is None else obj_id % cmap.N
        rgb = np.array(cmap(cmap_idx)[:3])
    # RGB plus a fixed alpha of 0.6.
    color = np.concatenate([rgb, np.array([0.6])], axis=0)
    h, w = mask.shape[-2:]
    # Broadcasting gives an (h, w, 4) RGBA image that is transparent wherever
    # the mask is zero.
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)


def show_points(coords, labels, ax, marker_size=200):
    """Scatter prompt points on ``ax`` as stars, coloured by label.

    ``coords`` is filtered with boolean masks built from ``labels``; column 0
    is used as x and column 1 as y. Points labelled 1 are drawn in green
    first, then points labelled 0 in red. Other label values are not drawn.
    """
    star_style = dict(marker="*", s=marker_size, edgecolor="white", linewidth=1.25)
    # Select both groups up front, then issue one scatter call per group.
    groups = [
        (coords[labels == label_value], colour)
        for label_value, colour in ((1, "green"), (0, "red"))
    ]
    for selected, colour in groups:
        ax.scatter(selected[:, 0], selected[:, 1], color=colour, **star_style)


def show_box(box, ax):
    """Draw a green, unfilled rectangle for ``box`` on ``ax``.

    ``box`` is read as ``(x0, y0, x1, y1)``: the first two entries anchor the
    rectangle and the differences to the last two give its width and height.
    """
    left, top = box[0], box[1]
    width = box[2] - box[0]
    height = box[3] - box[1]
    outline = plt.Rectangle(
        (left, top),
        width,
        height,
        edgecolor="green",
        facecolor=(0, 0, 0, 0),  # fully transparent fill
        lw=2,
    )
    ax.add_patch(outline)

## Build Video Dir

### Load Video and Frames Information

In [None]:
# Load the per-video frame listing; later cells index it by video_order and
# read each entry's 'frames' and 'video_id' fields.
with open("/bd_byta6000i0/users/sam2/kyyang/endoscapes_video.json", mode="r") as video_json:
    video_info = json.load(video_json)

In [None]:
# Cell output: number of entries in video_info.
len(video_info)

In [None]:
# Index into video_info of the video processed by the rest of the notebook.
video_order = 7
# Scratch directory holding one symlink per frame. mkdtemp does not remove it
# automatically; delete it by hand once the notebook is finished.
video_dir = tempfile.mkdtemp()

for idx, frame in enumerate(video_info[video_order]['frames']):
    # Links are named by the zero-padded (8 digits) frame index, so a sorted
    # directory listing follows the order of the frame list.
    frame_name = str(idx).zfill(8)
    dst_path = os.path.join(video_dir, f'{frame_name}.jpg')
    # A symlink stores its target string verbatim, so a relative frame path
    # would be resolved against video_dir and dangle; make it absolute first.
    src_path = os.path.abspath(frame['path'])
    os.symlink(src_path, dst_path)
# ic(sorted(os.listdir(video_dir)))


## Load frames into predictor

In [None]:
# Create the predictor's inference state from the frame links in video_dir.
inference_state = predictor.init_state(video_path=video_dir)

## Load COCO info

In [None]:
# COCO-style annotation file for the train_seg split; it supplies the images,
# categories and masks used to build the prompts.
annotation_file = "/bd_byta6000i0/users/dataset/MedicalImage/Endoscapes2023/raw/train_seg/annotation_coco.json"
coco = COCO(annotation_file)

In [None]:
# Number of categories in the annotation file. (num_categories + 1) is the
# base used later to pack a category id into each object id and to unpack it.
num_categories = len(coco.cats)
ic(num_categories)

### Get the first annotated image of current video

In [None]:
img_ids = coco.getImgIds()
imgs = coco.loadImgs(img_ids)
target_video_id = video_info[video_order]["video_id"]
# First image, in the order returned by coco.getImgIds(), that belongs to the
# selected video and has at least one annotation.
first_frame = next(
    (
        img
        for img in imgs
        if img["video_id"] == target_video_id and coco.getAnnIds(imgIds=img["id"])
    ),
    None,
)
if first_frame is None:
    # Fail here rather than hitting a NameError below or, on a re-run with a
    # different video_order, silently reusing the previous video's frame.
    raise ValueError(f"no annotated image found for video_id {target_video_id!r}")
ic(first_frame)

In [None]:
# Position of the first annotated image inside the video's frame list, matched
# by file name; this is the frame index on which the prompts are placed.
prompt_frame_id = next(
    (
        idx
        for idx, frame in enumerate(video_info[video_order]['frames'])
        if frame['file_name'] == first_frame['file_name']
    ),
    None,
)
if prompt_frame_id is None:
    # Without this check a missing match would leave prompt_frame_id undefined
    # (or stale from an earlier run of this cell).
    raise ValueError(
        f"file {first_frame['file_name']!r} is not among the frames of video {video_order}"
    )

### Calculate the prompt point

In [None]:
# TODO: sample more points

def getSamplePointsFromMask(mask: np.ndarray, kernel_size: int = 3) -> list:
    """Return one prompt point (the centroid) per connected region of ``mask``.

    The mask is morphologically closed before labelling so that regions
    separated only by small gaps (relative to the kernel) count as one region.

    Args:
        mask: Single-channel binary mask; it is cast to ``uint8`` and non-zero
            pixels are treated as foreground.
        kernel_size: Side length (>= 1) of the square closing kernel. Larger
            values bridge wider gaps; ``1`` leaves the mask unchanged.

    Returns:
        A list with one entry per connected region, in OpenCV label order.
        Each entry is a list of ``[x, y]`` points and currently holds only the
        region's centroid.

    NOTE(review): a centroid can fall outside a non-convex region, in which
    case the point does not lie on the mask.
    """
    binary_mask = mask.astype(np.uint8)
    kernel = np.ones((kernel_size, kernel_size), np.uint8)
    closed_mask = cv2.morphologyEx(binary_mask, cv2.MORPH_CLOSE, kernel)

    # Label the *closed* mask; labelling the raw mask would discard the
    # closing step computed above.
    num_labels, _, _, centroids = cv2.connectedComponentsWithStats(closed_mask)

    # Label 0 is the background, so real regions start at 1.
    return [
        [[centroids[label, 0], centroids[label, 1]]]
        for label in range(1, num_labels)
    ]

In [None]:
# Reset the inference state before the prompts are added in the next cell.
predictor.reset_state(inference_state)

In [None]:
# Annotations attached to the prompt image.
ann_ids = coco.getAnnIds(imgIds=first_frame["id"])
anns = coco.loadAnns(ann_ids)
# Running index of prompted regions; it makes every object id unique.
ann_count = 0
# One point array per prompted region, kept for plotting.
all_points = []
# Label map for plotting: pixels covered by an annotation hold its category id,
# all other pixels stay 0.
all_masks = np.zeros((first_frame['height'],first_frame['width']))

for ann in anns:
    mask = coco.annToMask(ann)
    # Where annotation masks overlap, the later annotation overwrites the earlier one.
    all_masks[mask==1] = ann["category_id"]
    sample_points = getSamplePointsFromMask(mask)
    
    # Each connected region of the annotation becomes its own tracked object.
    for reigon_samples in sample_points:
        
        # Every sampled point gets label 1.
        labels = np.ones(len(reigon_samples))
        points = np.array(reigon_samples)
        # show_points(points, labels, plt.gca())
        # Pack the category into the object id so that
        # obj_id % (num_categories + 1) gives it back later.
        # Assumes category ids lie in [0, num_categories] — TODO confirm.
        ann_obj_id = ann_count * (num_categories + 1) + ann["category_id"]
        # ic(ann_obj_id)
        # out_obj_ids / out_mask_logits from the last call are read again by
        # the first-frame visualisation cell.
        _, out_obj_ids, out_mask_logits = predictor.add_new_points_or_box(
            inference_state=inference_state,
            frame_idx=prompt_frame_id,
            obj_id=ann_obj_id,
            points=points,
            labels=labels,
        )
        ann_count += 1
        all_points.append(points)
    # break    
    
# ic(anns)

In [None]:
cmap = plt.get_cmap('tab10')  # 'tab10' provides 10 distinct colours

# One scatter call per prompted region: column 0 is x, column 1 is y.
# Colours repeat after ten regions.
for i, points in enumerate(all_points):
    plt.scatter(points[:, 0], points[:, 1], color=cmap(i % 10))
# Draw the category label map in the same axes as the points.
plt.imshow(all_masks)
plt.show()

### Visualize the first prompt frame

In [None]:
plt.figure(figsize=(12, 8))
plt.title(f"frame {prompt_frame_id}")
plt.imshow(Image.open(video_info[video_order]['frames'][prompt_frame_id]['path']))
# Draw every prompt and every returned mask. A hard-coded object index raises
# IndexError when fewer objects were prompted and pairs one object's mask with
# another object's points.
for prompt_points in all_points:
    # All prompts were added with label 1.
    show_points(prompt_points, np.ones(len(prompt_points)), plt.gca())
for i, out_obj_id in enumerate(out_obj_ids):
    # Logits above 0 count as foreground, as in the propagation cell.
    show_mask((out_mask_logits[i] > 0.0).cpu().numpy(), plt.gca(), obj_id=out_obj_id)

## Predict on the whole video

In [None]:
# video_segments maps frame index -> {object id: boolean mask array}
video_segments = {}
# Two passes over the video: reverse first, then forward. A frame yielded by
# both passes keeps the result of the forward pass.
for run_reversed in (True, False):
    for out_frame_idx, out_obj_ids, out_mask_logits in predictor.propagate_in_video(
        inference_state, reverse=run_reversed
    ):
        # Threshold each object's logits at 0 to get its binary mask.
        frame_masks = {}
        for i, out_obj_id in enumerate(out_obj_ids):
            frame_masks[out_obj_id] = (out_mask_logits[i] > 0.0).cpu().numpy()
        video_segments[out_frame_idx] = frame_masks

# render the segmentation results every few frames


In [None]:
# Cell output: the {object id: mask} dictionary stored for frame index 0.
video_segments[0]

### Visualize the result

In [None]:
# Show one figure for every vis_frame_stride-th frame with all of its masks.
vis_frame_stride = 15
plt.close("all")
video_frames = video_info[video_order]['frames']
for out_frame_idx in range(0, len(video_frames), vis_frame_stride):
    plt.figure(figsize=(6, 4))
    plt.title(f"frame {out_frame_idx}")
    plt.imshow(Image.open(video_frames[out_frame_idx]['path']))
    axes = plt.gca()
    for out_obj_id, out_mask in video_segments[out_frame_idx].items():
        show_mask(out_mask, axes, obj_id=out_obj_id)

In [None]:
# Cell output: the {object id: mask} dictionary stored for frame index 29
# (a KeyError here means no result was stored for that frame).
video_segments[29]

### Use remainder to calculate the category and mask and save as COCO format

In [None]:
# Accumulates the generated annotation dictionaries. Re-run this cell to start
# from an empty list before regenerating them.
coco_annotations = []

In [None]:
def mask_to_bbox(mask):
    """
    Compute the tight bounding box around the non-zero pixels of a mask.

    Returns ``[x, y, width, height]`` as floats, where ``x``/``y`` are the
    smallest column/row indices holding a non-zero pixel and both extents are
    inclusive. Returns ``None`` when the mask has no non-zero pixel.
    """
    nonzero = np.where(mask)
    rows = nonzero[0]
    if rows.size == 0:
        return None
    cols = nonzero[1]
    left, right = cols.min(), cols.max()
    top, bottom = rows.min(), rows.max()
    width = right - left + 1
    height = bottom - top + 1
    return [float(left), float(top), float(width), float(height)]

In [None]:
# Convert the propagated masks into COCO-style annotation dictionaries, one
# per (frame, category) pair.
for frame_id, frame in enumerate(video_info[video_order]["frames"]):
    # Frames without an image id cannot be referenced by an annotation.
    if frame['id'] is None:
        continue

    # Union of all object masks of this frame that share a category.
    merged_mask = {}
    for obj_id, obj_mask in video_segments[frame_id].items():
        # Object ids were built as count * (num_categories + 1) + category_id,
        # so the remainder is the category.
        category_id = obj_id % (num_categories + 1)
        # OR-reduce the leading axis of the stored mask.
        obj_mask = np.logical_or.reduce(obj_mask, axis=0)
        if category_id in merged_mask:
            merged_mask[category_id] = np.logical_or(merged_mask[category_id], obj_mask)
        else:
            merged_mask[category_id] = obj_mask

    for category_id, category_mask in merged_mask.items():
        # An empty mask has no bounding box (mask_to_bbox returns None), so it
        # would only produce an invalid annotation.
        if not category_mask.any():
            continue
        # pycocotools' encoder requires a Fortran-ordered uint8 array; a
        # boolean array is rejected.
        rle = maskUtils.encode(np.asfortranarray(category_mask, dtype=np.uint8))
        # The encoder returns 'counts' as bytes, which json cannot serialise.
        if isinstance(rle["counts"], bytes):
            rle["counts"] = rle["counts"].decode("ascii")
        annotation = {
            # Stored as a string because a UUID object is not JSON
            # serialisable. NOTE(review): the COCO format specifies integer
            # ids; re-number if a consumer requires them.
            "id": str(uuid.uuid4()),
            "image_id": frame['id'],
            "category_id": int(category_id),
            "segmentation": rle,
            "bbox": mask_to_bbox(category_mask),
            "area": int(np.sum(category_mask)),
            "iscrowd": 0,
        }
        coco_annotations.append(annotation)


## Save the result as COCO format

In [None]:
# Assemble the output dataset: every image and category record from the source
# annotation file, plus the annotations generated above.
img_ids = coco.getImgIds()
coco_images = coco.loadImgs(img_ids)
cat_ids = coco.getCatIds()
coco_cats = coco.loadCats(cat_ids)
predict_data = dict(
    images=coco_images,
    annotations=coco_annotations,
    categories=coco_cats,
)

In [None]:
# Cell output: the image records that were placed in predict_data["images"].
coco_images