# In [None]:
###### Combination of SAM and Depth Anything (simple pipeline) ######
# Segments sampled EndoVis 2018 validation frames with SAM, recolours the masked
# tissue/tool regions, and compares Depth Anything depth maps of the original
# frame against the recoloured one.



# 0. install dependencies: PyTorch, Segment-anything, OpenCV, matplotlib, tqdm, pandas, Transformers(for Depth Anything)
# NOTE: the `!pip` lines are IPython/Colab shell escapes, not plain Python, so this
# file only runs as a notebook. PyTorch is pulled from the CUDA 12.1 wheel index.
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
!pip install git+https://github.com/facebookresearch/segment-anything.git
!pip install opencv-python matplotlib tqdm pandas
!pip install transformers accelerate safetensors

# 1. mount Google Drive (Colab only); every input and output path below lives under /content/drive
from google.colab import drive
drive.mount('/content/drive')

# 2. file path configuration
# Every path hangs off the project folder on Google Drive: the dataset lives in
# `endovis2018_sub`, and everything this notebook writes goes to `generated_files`.
import os

_project_root = "/content/drive/MyDrive/Colab Notebooks/MR_Project"

base_dir = f"{_project_root}/endovis2018_sub"
# endo_image_dir = f"{base_dir}/train/image"
val_image_dir = f"{base_dir}/val/image"
val_label_dir = f"{base_dir}/val/label"

_generated_root = f"{_project_root}/generated_files"
output_mask_dir = f"{_generated_root}/sam_masks"
test_image_dir = f"{_generated_root}/test/image"
test_label_dir = f"{_generated_root}/test/label"
test_mask_dir = f"{_generated_root}/test/sam_mask"
fused_output_dir = f"{_generated_root}/fused_depth"

# Create the output folders up front; existing folders are left untouched.
for _out_folder in (output_mask_dir, test_image_dir, test_label_dir, test_mask_dir, fused_output_dir):
    os.makedirs(_out_folder, exist_ok=True)

# 3. import libraries and SAM model
import torch, cv2, numpy as np, shutil, matplotlib.pyplot as plt, pandas as pd
from tqdm import tqdm
from segment_anything import sam_model_registry, SamPredictor

# SAM ViT-B checkpoint stored next to the project data on Drive; `model_type`
# selects the matching builder from the registry.
sam_checkpoint = "/content/drive/MyDrive/Colab Notebooks/MR_Project/sam_vit_b_01ec64.pth"
model_type = "vit_b"

# Run on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Using device: {device}")

# Build the model, move it to `device`, and wrap it in a single predictor that is
# reused for every image.
_sam_builder = sam_model_registry[model_type]
sam = _sam_builder(checkpoint=sam_checkpoint)
sam.to(device=device)
predictor = SamPredictor(sam)

# 4. Define SAM mask inference function
def generate_mask_with_sam(image_path, predictor, save_path=None, visualize=False):
    """Segment one image with SAM (no prompts) and return its best-scoring mask.

    Args:
        image_path: path of the image to segment; it is read with OpenCV and
            converted from BGR to RGB before being passed to SAM.
        predictor: a ``SamPredictor`` wrapping the loaded SAM model.
        save_path: if given, the mask is written there as an 8-bit image
            (0 = background, 255 = foreground).
        visualize: if True, show the input image and the mask side by side.

    Returns:
        The mask from SAM's multimask output with the highest predicted score.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``image_path``.
        OSError: if ``save_path`` is given but the mask could not be written.
    """
    image = cv2.imread(image_path)
    if image is None:
        # cv2.imread signals failure by returning None instead of raising; fail
        # here with a clear message rather than deep inside cvtColor.
        raise FileNotFoundError(f"Could not read image: {image_path}")
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    predictor.set_image(image)
    # No point/box prompts: ask for several candidate masks and keep the one
    # with the highest predicted quality score.
    masks, scores, _ = predictor.predict(point_coords=None, point_labels=None, multimask_output=True)
    best_mask = masks[np.argmax(scores)]
    if save_path:
        # cv2.imwrite reports failure (e.g. a missing folder) via a False return value.
        if not cv2.imwrite(save_path, (best_mask * 255).astype(np.uint8)):
            raise OSError(f"Could not write mask to: {save_path}")
    if visualize:
        plt.figure(figsize=(10, 5))
        plt.subplot(1, 2, 1); plt.imshow(image); plt.title("Original")
        plt.subplot(1, 2, 2); plt.imshow(best_mask, cmap='gray'); plt.title("Predicted Mask")
        plt.show(); plt.close('all')
    # Free cached GPU memory between images (relies on the module-level `device`).
    if device == "cuda":
        torch.cuda.empty_cache()
    return best_mask

# NOTE(review): step 5 is disabled. The code below is wrapped in a bare string
# literal, so none of it runs. It also reads `endo_image_dir`, whose definition is
# commented out in the path configuration, so re-enable both together. When
# enabled, it writes one SAM mask per training image into `output_mask_dir`,
# skipping images whose mask file already exists.
'''
# 5. Batch generate training set SAM mask
batch_size = 5
image_files = sorted(os.listdir(endo_image_dir))
num_batches = (len(image_files)+batch_size-1)//batch_size

for i in range(num_batches):
    batch = image_files[i*batch_size:(i+1)*batch_size]
    print(f"Processing train batch {i+1}/{num_batches}")
    for fname in batch:
        img_path = os.path.join(endo_image_dir,fname)
        save_path = os.path.join(output_mask_dir,fname)
        if os.path.exists(save_path): continue
        generate_mask_with_sam(img_path,predictor,save_path=save_path)
print("All SAM masks generated.")
'''

# 6. Select validation samples
# Pick up to 20 .bmp validation frames (reproducibly, seed 42). Copy each frame
# and its label into the test folders, and generate a SAM mask for it.
val_files = sorted(f for f in os.listdir(val_image_dir) if f.endswith(".bmp"))
if not val_files:
    raise ValueError(f"{val_image_dir} empty!")

np.random.seed(42)
num_samples = min(20, len(val_files))
sample_files = np.random.choice(val_files, num_samples, replace=False)
print("Selected validate files:", sample_files)

for fname in sample_files:
    copied_image = os.path.join(test_image_dir, fname)
    shutil.copy(os.path.join(val_image_dir, fname), copied_image)
    # Labels share the image's file name.
    shutil.copy(os.path.join(val_label_dir, fname), os.path.join(test_label_dir, fname))
    generate_mask_with_sam(copied_image, predictor, save_path=os.path.join(test_mask_dir, fname))

# 7. Pixel-level agreement between the SAM mask and the ground truth.
# EndoVis 2018 provides the ground-truth labels. Both images are binarised at 127
# and compared pixel by pixel. Pixel accuracy is dominated by whichever class
# covers most of the frame, so the foreground IoU is reported alongside it.
def _read_grayscale_or_raise(path):
    """Read ``path`` as an 8-bit grayscale image; raise if OpenCV cannot read it."""
    img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
    if img is None:  # cv2.imread returns None on failure instead of raising
        raise FileNotFoundError(f"Could not read image: {path}")
    return img


accuracy_list = []
for fname in sample_files:
    label = _read_grayscale_or_raise(os.path.join(test_label_dir, fname))
    mask = _read_grayscale_or_raise(os.path.join(test_mask_dir, fname))
    if mask.shape != label.shape:
        raise ValueError(f"{fname}: mask shape {mask.shape} != label shape {label.shape}")
    mask_bin = mask > 127
    label_bin = label > 127
    acc = (mask_bin == label_bin).sum() / label_bin.size
    union = np.logical_or(mask_bin, label_bin).sum()
    # If both masks are empty they agree perfectly; avoid dividing 0 by 0.
    iou = np.logical_and(mask_bin, label_bin).sum() / union if union else 1.0
    accuracy_list.append({"filename": fname, "accuracy": acc, "iou": iou})
accuracy_df = pd.DataFrame(accuracy_list)
accuracy_df[["accuracy", "iou"]] = (accuracy_df[["accuracy", "iou"]] * 100).round(2)
print(f"Average pixel-level accuracy: {accuracy_df['accuracy'].mean():.2f}%")
print(f"Average foreground IoU: {accuracy_df['iou'].mean():.2f}%")
display(accuracy_df)

# 8. Hugging Face Depth Anything pipeline
from transformers import pipeline
from PIL import Image

# Depth Anything V2 (small) through the Hugging Face "depth-estimation" pipeline.
# Device index 0 is the first GPU; -1 keeps the pipeline on the CPU.
_depth_device_index = 0 if torch.cuda.is_available() else -1
depth_pipe = pipeline(
    "depth-estimation",
    model="depth-anything/Depth-Anything-V2-Small-hf",
    device=_depth_device_index,
)

# 9. Define (optimized) mask fusion function for following depth estimation
def generate_depth_with_mask_optimized(image_path, mask_path,
                                       base_strength_org=0.5,   # strength factor for tissue, set as 0.5
                                       base_strength_tool=0.8,  # strength factor for tool, set as 0.8
                                       blur_ksize=7,            # Gaussian blur kernel size for the mask edge; positive odd number
                                       depth_pipe=None):
    """Recolour the SAM-masked region of an image, then optionally estimate depth on it.

    Pixels inside the SAM mask are split with a red-channel heuristic on the
    original image: red > 100 counts as tissue, everything else as tool.

    * tissue (reddish/pink): G and B are blended towards 255 with weight
      ``base_strength_org``;
    * tool (silver/grey/black): B is blended towards 255 with weight
      ``base_strength_tool`` and R is scaled by ``1 - 0.3 * base_strength_tool``.

    The Gaussian-blurred mask is a per-pixel blend weight in [0, 1]. The
    recolouring is at full strength well inside the mask and fades out across a
    band of roughly ``blur_ksize`` pixels around its border. (Previously the
    blurred mask was thresholded at > 0, which only dilated the mask and applied
    full strength right up to a hard edge.)

    Args:
        image_path: image to recolour (converted to RGB).
        mask_path: SAM mask for that image; pixels > 127 count as foreground.
            It must have the same width/height as the image.
        base_strength_org: blend strength for tissue pixels (0 = no change).
        base_strength_tool: blend strength for tool pixels (0 = no change).
        blur_ksize: Gaussian kernel size used to soften the mask edge.
        depth_pipe: optional callable, such as the Hugging Face depth-estimation
            pipeline, returning a dict with a ``"depth"`` image. If None, no
            depth is computed.

    Returns:
        ``(mod_img, depth_uint8)``: the recoloured image as an ``(H, W, 3)``
        uint8 array, and the predicted depth min-max normalised to 0..255 uint8
        (``None`` when ``depth_pipe`` is None).

    Raises:
        ValueError: if ``blur_ksize`` is not a positive odd integer, or if the
            mask size does not match the image size.
    """
    if blur_ksize <= 0 or blur_ksize % 2 == 0:
        # cv2.GaussianBlur rejects such sizes with an opaque assertion error.
        raise ValueError(f"blur_ksize must be a positive odd integer, got {blur_ksize!r}")

    with Image.open(image_path) as src:
        img = np.asarray(src.convert("RGB"), dtype=np.float32)
    with Image.open(mask_path) as src:
        mask = np.asarray(src.convert("L"))
    if mask.shape != img.shape[:2]:
        raise ValueError(
            f"mask size {mask.shape[::-1]} does not match image size {img.shape[1::-1]}"
        )

    mask_bin = (mask > 127).astype(np.float32)
    # Soft weight: ~1 inside the mask, ~0.5 on its border, and 0 beyond
    # blur_ksize // 2 pixels outside it.
    mask_weight = cv2.GaussianBlur(mask_bin, (blur_ksize, blur_ksize), 0)

    # Red-channel heuristic (R-0 G-1 B-2): reddish pixels are tissue, the rest are tool.
    is_tissue = img[..., 0] > 100
    tissue_alpha = base_strength_org * np.where(is_tissue, mask_weight, 0.0)
    tool_alpha = base_strength_tool * np.where(is_tissue, 0.0, mask_weight)

    out = img.copy()
    # Tissue: pull G and B towards white to counter the red/pink colour.
    for ch in (1, 2):
        out[..., ch] = out[..., ch] * (1 - tissue_alpha) + 255 * tissue_alpha
    # Tool: pull B towards 255 and damp R (silver/grey/black instruments).
    # The tissue and tool weights are never both non-zero on the same pixel.
    out[..., 2] = out[..., 2] * (1 - tool_alpha) + 255 * tool_alpha
    out[..., 0] = out[..., 0] * (1 - 0.3 * tool_alpha)

    mod_img = np.clip(out, 0, 255).astype(np.uint8)

    if depth_pipe is None:
        return mod_img, None

    result = depth_pipe(Image.fromarray(mod_img))
    # Cast before normalising so the arithmetic is done in float whatever the
    # dtype of the returned depth image.
    depth = np.asarray(result["depth"], dtype=np.float64)
    depth_norm = (depth - depth.min()) / (depth.max() - depth.min() + 1e-8)
    depth_uint8 = (depth_norm * 255).astype(np.uint8)
    return mod_img, depth_uint8


# 10. Batch generate depth maps and visualize comparisons
# For each sampled frame, estimate the depth of the untouched image ("original")
# and of the mask-recoloured image ("optimized"). Save both maps, record their
# mean/std in `stats`, and store a 4-panel comparison figure.
output_dir = "/content/drive/MyDrive/Colab Notebooks/MR_Project/generated_files/depth_optimized"
comparison_dir = os.path.join(output_dir, "visual_comparison")
for _folder in (output_dir, comparison_dir):
    os.makedirs(_folder, exist_ok=True)

stats = []
for fname in tqdm(sample_files, desc="Generating optimized depth"):
    img_path = os.path.join(test_image_dir, fname)
    mask_path = os.path.join(test_mask_dir, fname)

    # Depth of the unmodified frame, min-max scaled to 0..255.
    raw_depth = np.array(depth_pipe(Image.open(img_path).convert("RGB"))["depth"])
    lo, hi = raw_depth.min(), raw_depth.max()
    depth_uint8 = ((raw_depth - lo) / (hi - lo + 1e-8) * 255).astype(np.uint8)
    cv2.imwrite(os.path.join(output_dir, f"{fname}_original.png"), depth_uint8)
    stats.append({
        "file": fname,
        "mode": "original",
        "mean_depth": depth_uint8.mean(),
        "std_depth": depth_uint8.std(),
    })

    # Depth of the mask-recoloured frame.
    mod_img, depth_opt = generate_depth_with_mask_optimized(img_path, mask_path, depth_pipe=depth_pipe)
    cv2.imwrite(os.path.join(output_dir, f"{fname}_optimized.png"), depth_opt)
    stats.append({
        "file": fname,
        "mode": "optimized",
        "mean_depth": depth_opt.mean(),
        "std_depth": depth_opt.std(),
    })

    # Side-by-side figure: input, SAM mask, and both depth maps.
    panels = [
        (np.array(Image.open(img_path).convert("RGB")), None, "Original"),
        (np.array(Image.open(mask_path).convert("L")), "gray", "SAM Mask"),
        (depth_uint8, "plasma", "Depth: Original"),
        (depth_opt, "plasma", "Depth: Optimized"),
    ]
    fig, axes = plt.subplots(1, 4, figsize=(18, 4))
    for ax, (data, cmap, title) in zip(axes, panels):
        ax.imshow(data, cmap=cmap)
        ax.set_title(title)
        ax.axis("off")
    plt.tight_layout()
    plt.savefig(os.path.join(comparison_dir, fname.replace(".bmp", "_compare.png")))
    plt.close(fig)

# Save the per-image depth statistics (one row per file and mode) to CSV, then
# preview the first rows.
stats_csv = os.path.join(output_dir, "depth_statistics.csv")
stats_df = pd.DataFrame(stats)
stats_df.to_csv(stats_csv, index=False)
display(stats_df.head())

# Visualize some images: list every saved comparison figure, then show the first six inline.
import glob
from IPython.display import Image as IPImage, display as ipy_display

_compare_pattern = os.path.join(comparison_dir, "*_compare.png")
comparison_images = sorted(glob.glob(_compare_pattern))
print(f"total {len(comparison_images)} comparisons。")
for preview_path in comparison_images[:6]:
    print(os.path.basename(preview_path))
    ipy_display(IPImage(filename=preview_path))

# Output hidden; open in https://colab.research.google.com to view.