In [None]:
# Print the NVIDIA driver / GPU status; the cells below move the models to CUDA.
!nvidia-smi

## 1. Clone SAM và SAM 2, sau đó cài đặt các thư viện cần thiết

In [None]:
!git clone https://github.com/facebookresearch/segment-anything.git
!git clone https://github.com/facebookresearch/segment-anything-2.git

In [None]:
!pip install opencv-python matplotlib scikit-image tqdm

## 2. Tải checkpoint

In [None]:
# SAM
!wget https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth

# SAM2 - large
!wget https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_large.pt

# SAM2 - tiny
!wget https://dl.fbaipublicfiles.com/segment_anything_2/092824/sam2.1_hiera_tiny.pt

## 3. Import & khởi tạo model

In [None]:
import cv2, os, time
import numpy as np
import torch
import pandas as pd
from tqdm import tqdm

### 3.1 SAM

In [None]:
from segment_anything import SamPredictor, sam_model_registry

# Look up the "vit_h" builder, load the checkpoint downloaded above, then move
# the resulting model to the GPU.
build_vit_h = sam_model_registry["vit_h"]
sam = build_vit_h(checkpoint="sam_vit_h_4b8939.pth")
sam = sam.cuda()

# Predictor wrapper used by `run_sam` for prompt-based inference.
sam_predictor = SamPredictor(sam)

### 3.2 SAM 2

In [None]:
%cd segment-anything-2 # direct vào folder sam2 đã clone
!pip install -e .

In [None]:
import inspect

import hydra
from hydra import initialize_config_dir
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor


def _resolve_checkpoint(filename, search_dirs=(".", "..", "/kaggle/working")):
    """Return the absolute path of the first `filename` found in `search_dirs`.

    The checkpoints are downloaded into the notebook's start folder, but an
    earlier cell may have changed the working directory, so the current folder,
    its parent and the Kaggle working folder are all tried in that order.

    Raises:
        FileNotFoundError: If the file is in none of the folders.
    """
    for folder in search_dirs:
        candidate = os.path.abspath(os.path.join(folder, filename))
        if os.path.isfile(candidate):
            return candidate
    raise FileNotFoundError(
        f"Checkpoint {filename!r} not found in any of: {list(search_dirs)}"
    )


# Hydra only accepts an absolute config directory, so point at the `configs`
# folder next to `sam2/build_sam.py` instead of relying on the working directory.
SAM2_CONFIG_DIR = os.path.join(
    os.path.dirname(os.path.abspath(inspect.getfile(build_sam2))),
    "configs",
)

# Drop any Hydra state that is already initialised so that the config directory
# registered below is the one `build_sam2` resolves config names against.
hydra.core.global_hydra.GlobalHydra.instance().clear()

initialize_config_dir(
    config_dir=SAM2_CONFIG_DIR,
    version_base=None
)

sam2 = build_sam2(
    config_file="sam2.1/sam2.1_hiera_l",
    ckpt_path=_resolve_checkpoint("sam2.1_hiera_large.pt"),
    device="cuda"
)

sam2_predictor = SAM2ImagePredictor(sam2)

sam2_tiny = build_sam2(
    config_file="sam2.1/sam2.1_hiera_t",
    ckpt_path=_resolve_checkpoint("sam2.1_hiera_tiny.pt"),
    device="cuda"
)

sam2_tiny_predictor = SAM2ImagePredictor(sam2_tiny)

## 4. Hàm inference

### 4.1 SAM

In [None]:
def run_sam(image_path):
    """Segment one image with SAM from a single positive click at the image centre.

    Args:
        image_path: Path of the image file, read with OpenCV.

    Returns:
        tuple: `(mask, elapsed)` where `mask` is the first predicted mask as a
        boolean array and `elapsed` is the time in seconds measured between the
        two CUDA synchronisations (`set_image`, prompt construction, `predict`).

    Raises:
        OSError: If OpenCV cannot read `image_path`.
    """
    img = cv2.imread(image_path)
    if img is None:
        # cv2.imread reports failure by returning None; fail here with the path
        # instead of with an opaque cvtColor error.
        raise OSError(f"Could not read image: {image_path}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    # Wait for pending GPU work so the timer only covers this image.
    torch.cuda.synchronize()
    start = time.perf_counter()

    sam_predictor.set_image(img)
    h, w = img.shape[:2]
    point = np.array([[w // 2, h // 2]])  # (x, y) of the image centre
    label = np.array([1])

    masks, _, _ = sam_predictor.predict(
        point_coords=point,
        point_labels=label,
        multimask_output=False
    )

    # GPU kernels run asynchronously; synchronise before stopping the clock.
    torch.cuda.synchronize()
    elapsed = time.perf_counter() - start

    mask = masks[0]
    if isinstance(mask, torch.Tensor):
        mask = mask.detach().cpu().numpy()
    if mask.ndim == 3:
        mask = mask[0]

    return mask.astype(bool), elapsed


### 4.2 SAM 2

In [None]:
def run_sam2(image_path, predictor=None):
    """Segment one image with SAM 2 from a single positive click at the image centre.

    Args:
        image_path: Path of the image file, read with OpenCV.
        predictor: Predictor to run. Defaults to the module-level `sam2_predictor`;
            pass another one (e.g. `sam2_tiny_predictor`) to time a different model.

    Returns:
        tuple: `(mask, elapsed)` where `mask` is the first predicted mask as a
        boolean array and `elapsed` is the time in seconds measured between the
        two CUDA synchronisations (`set_image`, prompt construction, `predict`).

    Raises:
        OSError: If OpenCV cannot read `image_path`.
    """
    if predictor is None:
        predictor = sam2_predictor

    img = cv2.imread(image_path)
    if img is None:
        # cv2.imread reports failure by returning None; fail here with the path
        # instead of with an opaque cvtColor error.
        raise OSError(f"Could not read image: {image_path}")
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

    # Wait for pending GPU work so the timer only covers this image.
    torch.cuda.synchronize()
    start = time.perf_counter()

    predictor.set_image(img)

    h, w = img.shape[:2]
    point = np.array([[w // 2, h // 2]])  # (x, y) of the image centre
    label = np.array([1])  # same array form as the prompt built in run_sam

    masks, _, _ = predictor.predict(
        point_coords=point,
        point_labels=label,
        multimask_output=False
    )

    mask = masks[0]

    # GPU kernels run asynchronously; synchronise before stopping the clock.
    torch.cuda.synchronize()
    elapsed = time.perf_counter() - start

    if isinstance(mask, torch.Tensor):
        mask = mask.detach().cpu().numpy()
    if mask.ndim == 3:
        mask = mask[0]

    return mask.astype(bool), elapsed


## 5. Metrics

In [None]:
def iou(pred, gt):
    """Intersection-over-union of a predicted mask and a ground-truth mask.

    Both inputs are combined element-wise with NumPy's logical AND / OR, so any
    non-zero entry counts as foreground.

    Returns:
        Intersection count divided by union count, or 0 when the union is empty.
    """
    overlap = np.logical_and(pred, gt).sum()
    combined = np.logical_or(pred, gt).sum()
    if combined > 0:
        return overlap / combined
    return 0

def dice(pred, gt):
    """Dice coefficient of a predicted mask and a ground-truth mask.

    Both inputs are first converted to boolean arrays (non-zero = foreground), so
    masks stored as 0/255 are counted per pixel rather than summed by value.

    Args:
        pred: Predicted mask (array-like).
        gt: Ground-truth mask (array-like) with the same shape as `pred`.

    Returns:
        float: `2 * |pred AND gt| / (|pred| + |gt|)`, or 0.0 when both masks are
        empty.
    """
    pred = np.asarray(pred, dtype=bool)
    gt = np.asarray(gt, dtype=bool)

    inter = np.logical_and(pred, gt).sum()
    total = pred.sum() + gt.sum()

    # Explicit guard instead of an epsilon so a perfect match scores exactly 1.0.
    return float(2.0 * inter / total) if total > 0 else 0.0

## 6. Evaluation

In [None]:
def find_mask_path(mask_dir, image_name, exts=(".png", ".jpg", ".jpeg")):
    """Locate the mask file that shares its base name with `image_name`.

    Args:
        mask_dir: Folder that holds the mask files.
        image_name: File name of the image; its extension is ignored.
        exts: Extensions to try, in priority order.

    Returns:
        The path of the first existing regular file `mask_dir/<base><ext>`, or
        None when no candidate exists.
    """
    base = os.path.splitext(image_name)[0]
    for ext in exts:
        candidate = os.path.join(mask_dir, base + ext)
        # isfile (not exists) so a directory with a matching name is not returned.
        if os.path.isfile(candidate):
            return candidate
    return None


def evaluate_dataset(dataset_name, eval_root):
    """Run SAM and SAM 2 over one dataset and collect per-image metrics.

    Images are read from `eval_root/img`; the ground truth for each image is the
    file in `eval_root/mask` that `find_mask_path` returns for the image name.
    Images without a readable mask are skipped with a warning.

    Args:
        dataset_name: Label used in the progress bar and the printed summary.
        eval_root: Folder containing the `img` and `mask` sub-folders.

    Returns:
        tuple: `(df, summary)` where `df` has one row per evaluated image (IoU,
        Dice and time for both models) and `summary` is the mean of its numeric
        columns.

    Raises:
        AssertionError: If `eval_root/img` or `eval_root/mask` is not a directory.
    """
    img_dir = os.path.join(eval_root, "img")
    mask_dir = os.path.join(eval_root, "mask")

    assert os.path.isdir(img_dir), f"Image folder not found: {img_dir}"
    assert os.path.isdir(mask_dir), f"Mask folder not found: {mask_dir}"

    columns = [
        "Image",
        "IoU_SAM", "IoU_SAM2",
        "Dice_SAM", "Dice_SAM2",
        "Time_SAM", "Time_SAM2"
    ]
    records = []
    skipped = 0

    for name in tqdm(sorted(os.listdir(img_dir)), desc=f"Evaluating {dataset_name}"):

        img_path = os.path.join(img_dir, name)
        mask_path = find_mask_path(mask_dir, name)
        if mask_path is None:
            print(f"[WARN] Missing mask for {name}, skipping.")
            skipped += 1
            continue

        gt_raw = cv2.imread(mask_path, cv2.IMREAD_GRAYSCALE)
        if gt_raw is None:
            print(f"[WARN] Missing mask: {mask_path}, skipping.")
            skipped += 1
            continue

        sam_mask, t_sam = run_sam(img_path)
        sam2_mask, t_sam2 = run_sam2(img_path)

        # The metrics compare the masks element-wise, so a ground truth stored at
        # another resolution is resampled (nearest neighbour keeps it binary)
        # instead of aborting the whole evaluation.
        pred_h, pred_w = sam_mask.shape[-2:]
        if gt_raw.shape[:2] != (pred_h, pred_w):
            print(
                f"[WARN] Mask size {gt_raw.shape[:2]} differs from prediction size "
                f"{(pred_h, pred_w)} for {name}, resizing mask."
            )
            # cv2.resize takes the target size as (width, height).
            gt_raw = cv2.resize(gt_raw, (pred_w, pred_h), interpolation=cv2.INTER_NEAREST)

        gt = gt_raw > 0  # any non-zero pixel is foreground

        records.append({
            "Image": name,
            "IoU_SAM": iou(sam_mask, gt),
            "IoU_SAM2": iou(sam2_mask, gt),
            "Dice_SAM": dice(sam_mask, gt),
            "Dice_SAM2": dice(sam2_mask, gt),
            "Time_SAM": t_sam,
            "Time_SAM2": t_sam2,
        })

    # Passing `columns` keeps the column order and the header even when no image
    # was evaluated.
    df = pd.DataFrame(records, columns=columns)

    summary = df.mean(numeric_only=True)

    print(f"\n{dataset_name} summary:")
    display(summary)

    if skipped > 0:
        print(f"[INFO] Skipped {skipped} images due to missing masks.")

    return df, summary


In [None]:
# Dataset label -> evaluation root. The relative paths are resolved against the
# current working directory.
# NOTE(review): an earlier cell runs `%cd segment-anything-2`; confirm these
# folders are reachable from the directory the kernel is in at this point.
DATASETS = {
    "COD10K": "COD10K_SPLIT/COD10K_SPLIT/eval",
    "DUTS": "DUTS_SPLIT/DUTS_SPLIT/eval",
    "ECSSD": "ECSSD_SPLIT/ECSSD_SPLIT/eval",
    "Kvasir": "Kvasir_VAL/Kvasir_VAL",
    "MSRA-B": "MSRA-B_SPLIT/MSRA-B_SPLIT/eval"
}

all_summaries = []

for name, path in DATASETS.items():
    _, summary = evaluate_dataset(name, path)
    # Name the Series after the dataset instead of inserting a string entry:
    # mixing a string into the numeric Series would turn every metric column of
    # the final table into object dtype.
    all_summaries.append(summary.rename(name))

# Each named Series becomes one row; the Series names form the index.
final_df = pd.DataFrame(all_summaries).rename_axis("Dataset")
display(final_df)