In [2]:
import os
import numpy as np
from PIL import Image
from tqdm import tqdm
import torch
from torchvision import transforms
from segment_anything import sam_model_registry, SamPredictor

# ---------------------------
# 0) configuration: dataset, checkpoint and output locations
# ---------------------------
# Dataset root: the folder that contains images.txt, bounding_boxes.txt, parts/ and images/.
ROOT_DIR = r"D:\bishe\CUB_200_2011\CUB_200_2011"
# Destination folder; one sub-directory is created per processed image.
OUT_DIR = r"D:\bishe\data_nor_14_w3_64"
# SAM checkpoint file and the registry key used to build the matching model.
SAM_CKPT = r"D:\bishe\sam\sam_vit_b_01ec64.pth"
MODEL_TYPE = "vit_b"
# Run on the GPU whenever torch reports one, otherwise on the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
# (width, height) of the saved image / mask crops.
OUTPUT_SIZE = (64, 64)

# ---------------------------
# 1) read CUB annotation files (images.txt, parts/part_locs.txt, bounding_boxes.txt)
# ---------------------------
def load_image_list(root_dir):
    """
    Read ``images.txt`` found directly under ``root_dir``.

    Every non-blank line is expected to be "<image_id> <relative_path>".
    Blank lines are skipped.

    return: (ids, relpaths) - two parallel lists in file order;
            ids are ints, relpaths are the path strings as written in the file.
    """
    p = os.path.join(root_dir, 'images.txt')
    ids, relpaths = [], []
    with open(p, 'r') as f:
        for line in f:
            line = line.strip()
            if not line:
                # A blank (often trailing) line would otherwise break the
                # two-value unpack below with a ValueError.
                continue
            # Split only once so a relative path containing spaces stays intact.
            idx, rel = line.split(None, 1)
            ids.append(int(idx))
            relpaths.append(rel)
    return ids, relpaths

def load_keypoints(root_dir):
    """
    Read ``parts/part_locs.txt`` under ``root_dir``.

    Every non-blank line is expected to be "<img_id> <part_id> <x> <y> <vis>".
    Blank lines are skipped.

    return dict: img_id -> {part_id(1..15): (x,y,vis)}
                 with x, y as floats and vis as an int
    """
    kp_dict = {}
    p = os.path.join(root_dir, 'parts', 'part_locs.txt')
    with open(p, 'r') as f:
        for line in f:
            fields = line.split()
            if not fields:
                # Whitespace-only line: nothing to parse, and unpacking it
                # into five names would raise ValueError.
                continue
            img_id, part_id, x, y, v = fields
            per_image = kp_dict.setdefault(int(img_id), {})
            per_image[int(part_id)] = (float(x), float(y), int(v))
    return kp_dict

def load_bboxes(root_dir):
    """
    Read ``bounding_boxes.txt`` found directly under ``root_dir``.

    Every non-blank line is expected to be "<img_id> <x> <y> <w> <h>".
    Blank lines are skipped.

    return dict: img_id -> (x, y, w, h), all four values as floats
    """
    bbox_dict = {}
    p = os.path.join(root_dir, 'bounding_boxes.txt')
    with open(p, 'r') as f:
        for line in f:
            fields = line.split()
            if not fields:
                # Whitespace-only line: skip it rather than fail on the unpack.
                continue
            img_id, x, y, w, h = fields
            bbox_dict[int(img_id)] = (float(x), float(y), float(w), float(h))
    return bbox_dict

# ---------------------------
# 2) SAM: build the predictor and predict a mask from a box prompt
# ---------------------------
def build_sam(ckpt, model_type="vit_b", device="cpu"):
    """
    Build a SAM model from a checkpoint and wrap it in a ``SamPredictor``.

    ckpt:       path passed as ``checkpoint=`` to the registry entry
    model_type: key looked up in ``sam_model_registry``
    device:     target passed to the model's ``.to()``
    return: the ``SamPredictor`` wrapping the model
    """
    model_factory = sam_model_registry[model_type]
    sam_model = model_factory(checkpoint=ckpt)
    return SamPredictor(sam_model.to(device))

def predict_mask_with_sam(predictor: SamPredictor, image_pil: Image.Image, bbox):
    """
    Segment the object inside ``bbox`` using a single box prompt.

    predictor: SamPredictor; its current image is replaced by ``image_pil``
    image_pil: PIL image, converted to a numpy array before being set
    bbox: (x, y, w, h) in original image coordinates
    return: mask (H, W) uint8 in {0,1}
    """
    predictor.set_image(np.array(image_pil))

    left, top, box_w, box_h = bbox
    # The box prompt is passed in corner form (x0, y0, x1, y1).
    corners = np.array([[left, top, left + box_w, top + box_h]])

    masks, _scores, _logits = predictor.predict(
        point_coords=None,
        point_labels=None,
        box=corners,
        multimask_output=False,
    )
    # multimask_output=False was requested, so only the first mask is used.
    first_mask = masks[0]
    return first_mask.astype(np.uint8)

# ---------------------------
# 3) crop and resize
# ---------------------------
def crop_and_resize_by_mask(image_pil, mask01, keypoints_xyv, out_wh=(256,256)):
    """
    Crop the image to the tight bounding box of the mask, resize the crop to
    ``out_wh`` and express the keypoints in the resized crop's pixel frame.

    image_pil:     PIL image
    mask01:        (H, W) array; non-zero entries mark the object
    keypoints_xyv: (N, 3) array of (x, y, v) in original image coordinates
    out_wh:        (width, height) of the outputs

    return: (resized_img, resized_mask, new_kps)
        resized_img  - PIL image, crop resized with bilinear interpolation
        resized_mask - PIL image built from a uint8 array holding 0 / 255,
                       resized with nearest-neighbour interpolation
        new_kps      - (N, 3) float32 copy of the keypoints. Rows with v != 0
                       that land inside the output hold resized-crop
                       coordinates; rows with v != 0 that land outside it are
                       set to (0, 0, 0); rows that already had v == 0 are
                       left as they were.
    """
    mask_bool = mask01.astype(bool)
    if mask_bool.any():
        ys, xs = np.where(mask_bool)
        y_min, y_max = int(ys.min()), int(ys.max())
        x_min, x_max = int(xs.min()), int(xs.max())
    else:
        # Empty mask: fall back to the full image extent.
        w, h = image_pil.size
        x_min, y_min, x_max, y_max = 0, 0, w - 1, h - 1

    # Crop bounds are inclusive, hence the +1 on the far edges.
    cropped_img = image_pil.crop((x_min, y_min, x_max + 1, y_max + 1))
    # Build the saved mask from the boolean mask so that any non-zero input
    # value becomes exactly 255 (multiplying a non-0/1 uint8 mask by 255
    # would wrap around).
    cropped_mask = mask_bool[y_min:y_max + 1, x_min:x_max + 1].astype(np.uint8) * 255

    # Per-axis scale factors from crop size to output size.
    out_w, out_h = out_wh
    cw, ch = cropped_img.size
    sx, sy = out_w / cw, out_h / ch

    resized_img = cropped_img.resize((out_w, out_h), Image.BILINEAR)
    resized_mask = Image.fromarray(cropped_mask).resize((out_w, out_h), Image.NEAREST)

    # astype() returns a new array, so the caller's keypoints are not modified.
    new_kps = keypoints_xyv.astype(np.float32)
    for i, (x, y, v) in enumerate(new_kps):
        if v == 0:
            continue
        nx = (x - x_min) * sx
        ny = (y - y_min) * sy
        if 0 <= nx < out_w and 0 <= ny < out_h:
            new_kps[i, 0] = nx
            new_kps[i, 1] = ny
        else:
            # The keypoint falls outside the crop. Clear the whole row instead
            # of only the visibility flag: leaving the original-image x/y here
            # would mix two coordinate frames in the returned array.
            new_kps[i] = 0
    return resized_img, resized_mask, new_kps

def normalize_keypoints_xy(keypoints_xyv, out_wh=(256,256)):
    """
    Divide keypoint x by the output width and y by the output height.

    keypoints_xyv: (N, 3) array of (x, y, v); the input is not modified
    out_wh:        (width, height) used as divisors
    return: float32 array of the same shape; the third column is untouched
    """
    width, height = out_wh
    # astype() returns a new array, so the in-place divisions stay local.
    normalized = keypoints_xyv.astype(np.float32)
    for column, extent in enumerate((width, height)):
        normalized[:, column] /= extent
    return normalized

# ---------------------------
# 4) CUB15 → 14
# ---------------------------
# 1-based CUB part ids, listed in the order of the 14 output keypoints
# (part id 3 is the only id in 1..15 that is not selected).
CUB_TO_AVIAN_PART_IDS = [
    1, 2, 5, 6, 7, 11, 10,
    15, 4, 14, 8, 12, 9, 13,
]

# The same selection as 0-based row indices into a (15, 3) keypoint array.
CUB_IDX_0_BASE = [cub_id - 1 for cub_id in CUB_TO_AVIAN_PART_IDS]

def cub15_to_avian14(kpts15_xyv):
    """
    Pick and reorder keypoint rows according to ``CUB_IDX_0_BASE``.

    kpts15_xyv: (15,3) in (x,y,v)
    return kpts14_xyv: (14,3)
    """
    # Extract and reorder in one step via direct (list) indexing.
    row_order = CUB_IDX_0_BASE
    return kpts15_xyv[row_order, :]

# ---------------------------
# 5) main
# ---------------------------
def run(
    root_dir=ROOT_DIR,
    out_dir=OUT_DIR,
    sam_ckpt=SAM_CKPT,
    model_type=MODEL_TYPE,
    device=DEVICE,
    out_size=OUTPUT_SIZE
):
    """
    Process every image listed in the dataset and write one folder per image.

    For each image: load it as RGB, collect its keypoints into a (15, 3)
    array, predict a mask with SAM from the annotated bounding box, crop and
    resize by that mask, reduce to 14 keypoints and normalise them.

    root_dir:   dataset root holding the annotation files and ``images/``
    out_dir:    destination root; created if missing
    sam_ckpt:   SAM checkpoint path
    model_type: SAM registry key
    device:     device the SAM model is moved to
    out_size:   (width, height) of the saved crops

    Files written to ``out_dir/<img_id as 6 digits>/``:
    image.png, mask.png, keypoints_15.npy, keypoints14.npy, keypoints14_norm.npy
    """
    os.makedirs(out_dir, exist_ok=True)
    print(f"Device: {device}")

    # Annotations are read once, up front.
    img_ids, img_relpaths = load_image_list(root_dir)
    kp_dict = load_keypoints(root_dir)
    bbox_dict = load_bboxes(root_dir)

    # A single predictor is shared by all images.
    predictor = build_sam(sam_ckpt, model_type=model_type, device=device)

    samples = zip(img_ids, img_relpaths)
    for img_id, rel in tqdm(samples, total=len(img_ids)):
        image = Image.open(os.path.join(root_dir, 'images', rel)).convert('RGB')

        # Dense (15, 3) keypoint table; parts without an annotation stay all-zero.
        annotated = kp_dict.get(img_id, {})
        k15 = np.zeros((15, 3), dtype=np.float32)
        for row, pid in enumerate(range(1, 16)):
            if pid in annotated:
                k15[row] = annotated[pid]

        # Box prompt (x, y, w, h) -> binary mask (H, W) with values 0/1.
        bbox = bbox_dict[img_id]
        mask01 = predict_mask_with_sam(predictor, image, bbox)

        # Crop to the mask and resize; keypoints follow into the crop frame.
        img_resized, mask_resized, k15_cropped = crop_and_resize_by_mask(
            image, mask01, k15, out_wh=out_size
        )

        # 15 -> 14 keypoints, then scale x / y by the output size.
        k14 = cub15_to_avian14(k15_cropped)
        k14_norm = normalize_keypoints_xy(k14, out_wh=out_size)

        # One output folder per image, named by the zero-padded image id.
        dst = os.path.join(out_dir, f"{img_id:06d}")
        os.makedirs(dst, exist_ok=True)

        img_resized.save(os.path.join(dst, "image.png"))
        mask_resized.save(os.path.join(dst, "mask.png"))
        arrays_to_save = (
            ("keypoints_15.npy", k15_cropped),
            ("keypoints14.npy", k14),
            ("keypoints14_norm.npy", k14_norm),
        )
        for file_name, array in arrays_to_save:
            np.save(os.path.join(dst, file_name), array)

    print("✅ finish：image.png, mask.png, keypoints_15.npy, keypoints14.npy, keypoints14_norm.npy")

# run() is invoked from the next cell.

In [None]:
# Run the preprocessing pipeline with the default paths and settings defined above.
run()