In [None]:
import numpy as np
import torch
from transformers import AutoImageProcessor, AutoModel
import matplotlib.pyplot as plt
from PIL import Image
from tifffile import imread

In [None]:

# Hugging Face checkpoint id (per its name: a DINOv3 ViT-L/16 variant pretrained on satellite data).
model_name = "facebook/dinov3-vitl16-pretrain-sat493m"

# Prefer the GPU when torch can see one; otherwise stay on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Image processor that belongs to the same checkpoint as the model.
processor = AutoImageProcessor.from_pretrained(model_name)

# Inference only: switch to eval mode, then move the weights to the chosen device.
model = AutoModel.from_pretrained(model_name)
model.eval()
model = model.to(device)

print("device:", device)

In [None]:
# Read one aerial image tile and the incomplete, unrefined label raster with tifffile.
# NOTE(review): both paths are absolute and machine-specific.
image, labels = (
    imread(tif_path)
    for tif_path in (
        "/home/mak/PycharmProjects/SegEdge/experiments/get_data_from_api/patches_mt/dop20_593000_5982000_1km_20cm.tif",
        "/run/media/mak/Partition of 1TB disk/SH_dataset/planet_labels_2022.tif",
    )
)

# Cell output: the shapes of both arrays.
image.shape, labels.shape

In [None]:
import rasterio
from rasterio.plot import reshape_as_image
from rasterio.windows import Window

img_path = "/home/mak/PycharmProjects/SegEdge/experiments/get_data_from_api/patches_mt/dop20_593000_5982000_1km_20cm.tif"

with rasterio.open(img_path) as img1:
    # Pixel data: every band, laid out as (bands, H, W); read() applies no normalization.
    arr = img1.read()

    # Georeferencing: coordinate reference system, pixel -> map affine transform,
    # and the tile extent (left, bottom, right, top) in CRS units.
    crs, transform, bounds = img1.crs, img1.transform, img1.bounds

    # Raster size in pixels and number of bands.
    width = img1.width
    height = img1.height
    count = img1.count

    print("crs:", crs, " transform:", transform, " size:", (width, height), " bands:", count)

# Channel-last copy for plotting: (C, H, W) -> (H, W, C).
img = reshape_as_image(arr)

# Show the tile.
plt.imshow(img)

In [None]:
import rasterio
from rasterio.warp import reproject, Resampling

lab_path = "/run/media/mak/Partition of 1TB disk/SH_dataset/planet_labels_2022.tif"

# Warp the label raster onto the image tile's grid (CRS, transform and size taken from
# the image), so that labels_arr[:, r, c] and the image pixel (r, c) cover the same spot.
# Here `ref` is the image tile and `img1` is the label raster.
with rasterio.open(img_path) as ref, rasterio.open(lab_path) as img1:
    dst_meta = ref.meta.copy()
    dst_meta.update(dtype=img1.dtypes[0], count=img1.count)  # keep label dtype/bands

    # Open the in-memory target through context managers so that both the dataset and
    # its backing MemoryFile are closed once the pixels have been copied out.
    with rasterio.io.MemoryFile() as memfile, memfile.open(**dst_meta) as labels_on_img:
        for i in range(1, img1.count + 1):
            dest = rasterio.band(labels_on_img, i)
            # The destination band already fixes the output size, so no explicit
            # width/height is passed to reproject.
            reproject(
                source=rasterio.band(img1, i),
                destination=dest,
                src_transform=img1.transform, src_crs=img1.crs,
                dst_transform=ref.transform, dst_crs=ref.crs,
                resampling=Resampling.nearest,  # categorical labels: never interpolate
            )
        # Copy to a plain array before the dataset is closed.
        labels_arr = labels_on_img.read()  # shape (bands, H, W) aligned to image

print("reprojected labels shape:", labels_arr.shape)


plt.imshow(labels_arr[0])


In [None]:
from skimage.segmentation import find_boundaries

# labels_arr is (bands, H, W); boundaries are computed on the 2D first band.
labels_2d = labels_arr[0]  # now shape (H, W)

boundaries = find_boundaries(labels_2d, mode='thick')  # boolean mask of label edges
overlay = img.copy()
# Paint boundary pixels black. A scalar broadcasts over any number of bands,
# whereas a hard-coded [0, 0, 0] only works for exactly three channels.
overlay[boundaries] = 0

plt.figure(figsize=(20,20))
plt.imshow(overlay)
plt.axis('off')
plt.show()


In [None]:
import torch
import numpy as np
from sklearn.cluster import KMeans

from skimage.transform import resize
import matplotlib.pyplot as plt

# ---- Extract patch features from DINOv3 ----
inputs = processor(images=img, return_tensors="pt").to(device)

# Report the size the processor actually produced; the patch grid is derived from it.
print(f"Processed input shape: {inputs['pixel_values'].shape}")  # (1, 3, H, W)

with torch.no_grad():
    outputs = model(**inputs)
    last_hidden_states = outputs.last_hidden_state  # (1, 1 + registers + num_patches, hidden_dim)

# The Hugging Face DINOv3 output is ordered [CLS, register tokens, patch tokens].
# Dropping only the CLS token leaves the register tokens mixed in with the patches
# (e.g. 196 + 4 = 200 tokens for a 14x14 grid), which breaks every spatial reshape.
num_register_tokens = getattr(model.config, "num_register_tokens", 0) or 0
patch_features = last_hidden_states[:, 1 + num_register_tokens:, :]  # (1, num_patches, hidden_dim)
batch_size, num_patches, hidden_dim = patch_features.shape
print(f"Total patches: {num_patches}, hidden_dim: {hidden_dim}")

# Patch grid implied by the processed image size (may not be square)
processed_h = inputs['pixel_values'].shape[2]
processed_w = inputs['pixel_values'].shape[3]
patch_size = model.config.patch_size

num_patches_h = processed_h // patch_size
num_patches_w = processed_w // patch_size
expected_patches = num_patches_h * num_patches_w

print(f"Processed image: {processed_h}x{processed_w}")
print(f"Patch size: {patch_size}")
print(f"Calculated patch grid: {num_patches_h}x{num_patches_w} = {expected_patches} patches")
print(f"Actual patches from model: {num_patches}")

# Without a matching grid there is no way to place patches spatially, so stop here
# instead of guessing a layout or painting arbitrary cluster ids.
if expected_patches != num_patches:
    raise RuntimeError(
        f"Patch grid {num_patches_h}x{num_patches_w} = {expected_patches} does not match "
        f"the {num_patches} patch tokens returned by the model "
        f"(after removing 1 CLS + {num_register_tokens} register tokens)."
    )

patch_features_flat = patch_features[0].cpu().numpy()  # (num_patches, hidden_dim)
patch_features_2d = patch_features_flat.reshape(num_patches_h, num_patches_w, hidden_dim)

# ---- Restrict clustering to labelled patches ----
# Downsample the label mask to the patch grid (nearest neighbour keeps it binary)
mask_patches = resize(labels_2d.astype(float) > 0,
                     (num_patches_h, num_patches_w),
                     order=0, anti_aliasing=False) > 0.5

# Features of the patches that fall inside the labelled area
features_inside = patch_features_2d[mask_patches]

print(f"Features to cluster: {features_inside.shape}")

# ---- Cluster with k-means ----
n_clusters = 5
if features_inside.shape[0] < n_clusters:
    raise ValueError(
        f"Only {features_inside.shape[0]} labelled patches available, "
        f"need at least n_clusters={n_clusters}."
    )
kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
cluster_labels = kmeans.fit_predict(features_inside)

# Cluster map at patch resolution: 0 = outside the label mask, 1..n_clusters = cluster id + 1
cluster_map = np.zeros((num_patches_h, num_patches_w), dtype=int)
cluster_map[mask_patches] = cluster_labels + 1

# Upsample to original image size (nearest neighbour keeps the ids intact)
cluster_map_full = resize(cluster_map.astype(float),
                          (img.shape[0], img.shape[1]),
                          order=0, preserve_range=True, anti_aliasing=False).astype(int)

# ---- Visualize ----
plt.figure(figsize=(15, 5))
plt.subplot(131)
plt.imshow(img)
plt.title("Original")
plt.axis('off')

plt.subplot(132)
plt.imshow(labels_2d > 0, cmap='gray')
plt.title("Label Boundaries")
plt.axis('off')

plt.subplot(133)
plt.imshow(cluster_map_full, cmap='tab10')
plt.title(f"DINOv3 Clusters (k={n_clusters})")
plt.axis('off')
plt.tight_layout()
plt.show()

print(f"Unique clusters: {np.unique(cluster_labels)}")


In [None]:
import torch
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from scipy.ndimage import binary_dilation
import matplotlib.pyplot as plt

def extract_features_from_image(image_array, model, processor, device):
    """Extract DINOv3 patch features from an image as a spatial grid.

    Args:
        image_array: Image passed to ``processor(images=...)``.
        model: Model whose output exposes ``last_hidden_state`` and whose config
            provides ``patch_size`` (and optionally ``num_register_tokens``).
        processor: Image processor returning a mapping with ``pixel_values``.
        device: Device the processed inputs are moved to.

    Returns:
        (patch_features_2d, num_patches_h, num_patches_w), where patch_features_2d
        is a numpy array of shape (num_patches_h, num_patches_w, hidden_dim).

    Raises:
        ValueError: If the number of patch tokens does not match the patch grid
            implied by the processed image size.
    """
    inputs = processor(images=image_array, return_tensors="pt").to(device)

    with torch.no_grad():
        last_hidden_states = model(**inputs).last_hidden_state

    # Patch grid implied by the size the processor actually produced
    processed_h, processed_w = inputs['pixel_values'].shape[2:4]
    patch_size = model.config.patch_size
    num_patches_h = processed_h // patch_size
    num_patches_w = processed_w // patch_size

    # Token order is [CLS, register tokens, patch tokens]; skip CLS *and* registers.
    # Keeping the register tokens makes the token count disagree with the grid.
    num_special_tokens = 1 + (getattr(model.config, "num_register_tokens", 0) or 0)
    patch_tokens = last_hidden_states[0, num_special_tokens:, :]  # (num_patches, hidden_dim)

    expected = num_patches_h * num_patches_w
    if patch_tokens.shape[0] != expected:
        raise ValueError(
            f"Expected {expected} patch tokens for a {num_patches_h}x{num_patches_w} grid "
            f"but got {patch_tokens.shape[0]} after dropping {num_special_tokens} special tokens"
        )

    # Reshape to spatial grid
    patch_features_2d = patch_tokens.cpu().numpy().reshape(num_patches_h, num_patches_w, -1)

    return patch_features_2d, num_patches_h, num_patches_w


def propagate_from_seed_points(image, seed_points, model, processor, device,
                                 similarity_threshold=0.7, max_iterations=50):
    """
    Segment image by propagating from seed points using feature similarity.

    Every patch is compared (cosine similarity) with the mean feature of the seed
    patches. Starting from the seed patches, the mask grows by one dilation step
    per iteration over neighbouring patches whose similarity reaches the threshold.

    Args:
        image: Image array (H, W, C); only its first two dimensions are read here
        seed_points: Sequence of (row, col) pixel coordinates; coordinates that
            map outside the patch grid are clipped to the nearest border patch
        model: DINOv3 model
        processor: Image processor
        device: torch device
        similarity_threshold: Cosine similarity threshold for expansion
        max_iterations: Max iterations for region growing

    Returns:
        mask_full: Boolean mask (H, W) of the segmented region
        similarity_map: (num_patches_h, num_patches_w) array with the cosine
            similarity of each patch to the mean seed feature

    Raises:
        ValueError: If seed_points is empty.
    """
    # Fail fast, before running the model: without seeds there is no reference feature.
    if len(seed_points) == 0:
        raise ValueError("seed_points must contain at least one (row, col) pair")

    # Extract patch features
    patch_features, num_patches_h, num_patches_w = extract_features_from_image(
        image, model, processor, device
    )
    H, W = image.shape[:2]
    hidden_dim = patch_features.shape[2]

    # Map seed points to patch coordinates
    patch_h_scale = num_patches_h / H
    patch_w_scale = num_patches_w / W

    seed_patches = []
    for row, col in seed_points:
        # Plain ints keep the printed coordinates and the indexing simple
        patch_row = int(np.clip(int(row * patch_h_scale), 0, num_patches_h - 1))
        patch_col = int(np.clip(int(col * patch_w_scale), 0, num_patches_w - 1))
        seed_patches.append((patch_row, patch_col))

    print(f"Seed pixels: {seed_points}")
    print(f"Seed patches: {seed_patches}")

    # Get seed features (average if multiple seeds)
    seed_features = np.array([patch_features[r, c] for r, c in seed_patches])
    seed_feature_mean = seed_features.mean(axis=0, keepdims=True)  # (1, hidden_dim)

    # Compute cosine similarity of all patches to seed
    patch_features_flat = patch_features.reshape(-1, hidden_dim)  # (num_patches, hidden_dim)
    similarities = cosine_similarity(patch_features_flat, seed_feature_mean)[:, 0]  # (num_patches,)
    similarity_map = similarities.reshape(num_patches_h, num_patches_w)  # (H_p, W_p)

    # Initialize mask from seeds (seeds are always part of the result)
    mask_patches = np.zeros((num_patches_h, num_patches_w), dtype=bool)
    for r, c in seed_patches:
        mask_patches[r, c] = True

    # Region growing: iteratively add neighbors above threshold
    for iteration in range(max_iterations):
        # Patches adjacent to the current mask but not yet inside it
        candidates = binary_dilation(mask_patches) & ~mask_patches

        # Keep only the candidates that are similar enough to the seeds
        new_patches = candidates & (similarity_map >= similarity_threshold)

        if not new_patches.any():
            print(f"Converged at iteration {iteration}")
            break

        mask_patches |= new_patches
    else:
        # The loop ran out of iterations, so the region may be truncated
        print(f"Stopped after max_iterations={max_iterations} without converging")

    # Upsample mask to image resolution (nearest neighbour keeps it binary)
    from skimage.transform import resize
    mask_full = resize(mask_patches.astype(float), (H, W),
                      order=0, preserve_range=True, anti_aliasing=False) > 0.5

    print(f"Segmented {mask_full.sum()} pixels ({100*mask_full.sum()/(H*W):.1f}%)")

    return mask_full, similarity_map


# Load your new image (image B)
import rasterio
from rasterio.plot import reshape_as_image

# Path of the second tile ("image B") that the seeds are placed on
img_b_path = "/home/mak/PycharmProjects/SegEdge/experiments/get_data_from_api/patches_mt/dop20_592000_5983000_1km_20cm.tif"

with rasterio.open(img_b_path) as src:
    arr_b = src.read()
# (C, H, W) -> (H, W, C); the array is already in memory, so the file can be closed first
img_b = reshape_as_image(arr_b)

# Seed pixels on image B as (row, col); adjust to the region of interest
seed_points = [(2500, 2500), (2600, 2400)]

# Grow a mask from the seeds on image B (the threshold is the main knob to tune)
mask_b, similarity_map_b = propagate_from_seed_points(
    img_b,
    seed_points,
    model,
    processor,
    device,
    similarity_threshold=0.75,
    max_iterations=50,
)

# Blend 50% green into the masked pixels
overlay = img_b.copy()
overlay[mask_b] = overlay[mask_b] * 0.5 + np.array([0, 255, 0]) * 0.5

# Four panels: input with seeds, similarity map, mask, overlay
fig, axes = plt.subplots(1, 4, figsize=(20, 5))

axes[0].imshow(img_b)
for r, c in seed_points:
    axes[0].plot(c, r, 'r*', markersize=15)  # matplotlib wants (x, y) = (col, row)
axes[1].imshow(similarity_map_b, cmap='hot')
axes[2].imshow(mask_b, cmap='gray')
axes[3].imshow(overlay.astype(np.uint8))

for ax, title in zip(axes, ("Image B", "Feature Similarity to Seeds",
                            "Propagated Mask", "Overlay on Image")):
    ax.set_title(title)
    ax.axis('off')

plt.tight_layout()
plt.show()



In [None]:
import torch
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from scipy.ndimage import binary_dilation
import matplotlib.pyplot as plt

def extract_features_from_image(image_array, model, processor, device):
    """Extract DINOv3 patch features from an image as a spatial grid.

    Args:
        image_array: Image passed to ``processor(images=...)``.
        model: Model whose output exposes ``last_hidden_state`` and whose config
            provides ``patch_size`` (and optionally ``num_register_tokens``).
        processor: Image processor returning a mapping with ``pixel_values``.
        device: Device the processed inputs are moved to.

    Returns:
        (patch_features_2d, num_patches_h, num_patches_w), where patch_features_2d
        is a numpy array of shape (num_patches_h, num_patches_w, hidden_dim).

    Raises:
        ValueError: If the model returns fewer patch tokens than the patch grid
            implied by the processed image size.
    """
    inputs = processor(images=image_array, return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = model(**inputs)
        last_hidden_states = outputs.last_hidden_state

    # Patch grid implied by the size the processor actually produced
    processed_h = inputs['pixel_values'].shape[2]
    processed_w = inputs['pixel_values'].shape[3]
    patch_size = model.config.patch_size

    num_patches_h = processed_h // patch_size
    num_patches_w = processed_w // patch_size
    expected = num_patches_h * num_patches_w

    # Token order is [CLS, register tokens, patch tokens]; skip CLS *and* registers.
    # Keeping the registers inflates the token count (e.g. 200 instead of 196) and any
    # grid guessed from that count scrambles the spatial layout.
    num_special_tokens = 1 + (getattr(model.config, "num_register_tokens", 0) or 0)
    patch_tokens = last_hidden_states[0, num_special_tokens:, :]  # (num_patches, hidden_dim)
    num_patches = patch_tokens.shape[0]

    print(f"Image processed to: {processed_h}x{processed_w}")
    print(f"Patch grid: {num_patches_h}x{num_patches_w} = {num_patches_h*num_patches_w} (actual patches: {num_patches})")

    if num_patches > expected:
        # Best effort for configs that do not declare their special tokens: those
        # tokens are prepended, so the trailing `expected` tokens are the patches.
        surplus = num_patches - expected
        print(f"WARNING: Expected {expected} patches but got {num_patches}. "
              f"Dropping {surplus} leading special token(s).")
        patch_tokens = patch_tokens[surplus:]
    elif num_patches < expected:
        raise ValueError(
            f"Expected {expected} patch tokens for a {num_patches_h}x{num_patches_w} grid "
            f"but the model returned only {num_patches}"
        )

    # Reshape to spatial grid
    patch_features_2d = patch_tokens.cpu().numpy().reshape(num_patches_h, num_patches_w, -1)

    return patch_features_2d, num_patches_h, num_patches_w


def propagate_from_seed_points(image, seed_points, model, processor, device,
                                 similarity_threshold=0.7, max_iterations=50):
    """
    Segment image by propagating from seed points using feature similarity.

    Every patch is compared (cosine similarity) with the mean feature of the seed
    patches. Starting from the seed patches, the mask grows by one dilation step
    per iteration over neighbouring patches whose similarity reaches the threshold.

    Args:
        image: Image array (H, W, C); only its first two dimensions are read here
        seed_points: Sequence of (row, col) pixel coordinates; coordinates that
            map outside the patch grid are clipped to the nearest border patch
        model: DINOv3 model
        processor: Image processor
        device: torch device
        similarity_threshold: Cosine similarity threshold for expansion
        max_iterations: Max iterations for region growing

    Returns:
        mask_full: Boolean mask (H, W) of the segmented region
        similarity_map: (num_patches_h, num_patches_w) array with the cosine
            similarity of each patch to the mean seed feature

    Raises:
        ValueError: If seed_points is empty.
    """
    # Fail fast, before running the model: without seeds there is no reference feature.
    if len(seed_points) == 0:
        raise ValueError("seed_points must contain at least one (row, col) pair")

    # Extract patch features
    patch_features, num_patches_h, num_patches_w = extract_features_from_image(
        image, model, processor, device
    )
    H, W = image.shape[:2]
    hidden_dim = patch_features.shape[2]

    # Map seed points to patch coordinates
    patch_h_scale = num_patches_h / H
    patch_w_scale = num_patches_w / W

    seed_patches = []
    for row, col in seed_points:
        # Plain ints keep the printed coordinates and the indexing simple
        patch_row = int(np.clip(int(row * patch_h_scale), 0, num_patches_h - 1))
        patch_col = int(np.clip(int(col * patch_w_scale), 0, num_patches_w - 1))
        seed_patches.append((patch_row, patch_col))

    print(f"Seed pixels: {seed_points}")
    print(f"Seed patches: {seed_patches}")

    # Get seed features (average if multiple seeds)
    seed_features = np.array([patch_features[r, c] for r, c in seed_patches])
    seed_feature_mean = seed_features.mean(axis=0, keepdims=True)  # (1, hidden_dim)

    # Compute cosine similarity of all patches to seed
    patch_features_flat = patch_features.reshape(-1, hidden_dim)  # (num_patches, hidden_dim)
    similarities = cosine_similarity(patch_features_flat, seed_feature_mean)[:, 0]  # (num_patches,)
    similarity_map = similarities.reshape(num_patches_h, num_patches_w)  # (H_p, W_p)

    # Initialize mask from seeds (seeds are always part of the result)
    mask_patches = np.zeros((num_patches_h, num_patches_w), dtype=bool)
    for r, c in seed_patches:
        mask_patches[r, c] = True

    # Region growing: iteratively add neighbors above threshold
    for iteration in range(max_iterations):
        # Patches adjacent to the current mask but not yet inside it
        candidates = binary_dilation(mask_patches) & ~mask_patches

        # Keep only the candidates that are similar enough to the seeds
        new_patches = candidates & (similarity_map >= similarity_threshold)

        if not new_patches.any():
            print(f"Converged at iteration {iteration}")
            break

        mask_patches |= new_patches
    else:
        # The loop ran out of iterations, so the region may be truncated
        print(f"Stopped after max_iterations={max_iterations} without converging")

    # Upsample mask to image resolution (nearest neighbour keeps it binary)
    from skimage.transform import resize
    mask_full = resize(mask_patches.astype(float), (H, W),
                      order=0, preserve_range=True, anti_aliasing=False) > 0.5

    print(f"Segmented {mask_full.sum()} pixels ({100*mask_full.sum()/(H*W):.1f}%)")

    return mask_full, similarity_map


# Load your new image (image B)
import rasterio
from rasterio.plot import reshape_as_image

# Path of the second tile ("image B") that the seeds are placed on
img_b_path = "/home/mak/PycharmProjects/SegEdge/experiments/get_data_from_api/patches_mt/dop20_592000_5983000_1km_20cm.tif"

with rasterio.open(img_b_path) as src:
    arr_b = src.read()
# (C, H, W) -> (H, W, C); the array is already in memory, so the file can be closed first
img_b = reshape_as_image(arr_b)

# Seed pixels on image B as (row, col); adjust to the region of interest
seed_points = [(1100, 1100), (2600, 2400)]

# Grow a mask from the seeds on image B (the threshold is the main knob to tune)
mask_b, similarity_map_b = propagate_from_seed_points(
    img_b,
    seed_points,
    model,
    processor,
    device,
    similarity_threshold=0.75,
    max_iterations=50,
)

# Blend 50% green into the masked pixels
overlay = img_b.copy()
overlay[mask_b] = overlay[mask_b] * 0.5 + np.array([0, 255, 0]) * 0.5

# Four panels: input with seeds, similarity map, mask, overlay
fig, axes = plt.subplots(1, 4, figsize=(20, 5))

axes[0].imshow(img_b)
for r, c in seed_points:
    axes[0].plot(c, r, 'r*', markersize=15)  # matplotlib wants (x, y) = (col, row)
axes[1].imshow(similarity_map_b, cmap='hot')
axes[2].imshow(mask_b, cmap='gray')
axes[3].imshow(overlay.astype(np.uint8))

for ax, title in zip(axes, ("Image B", "Feature Similarity to Seeds",
                            "Propagated Mask", "Overlay on Image")):
    ax.set_title(title)
    ax.axis('off')

plt.tight_layout()
plt.show()


In [None]:
import numpy as np
import torch
from sklearn.neighbors import NearestNeighbors
from skimage.transform import resize

# ---- Robust patch-grid utilities ----
def factor_hw(n, ratio):
    """
    Find integers (h, w) with h*w = n minimizing |(h/w) - ratio|.

    Args:
        n: Number of grid cells to factor.
        ratio: Target h / w, e.g. processed_h / processed_w.

    Returns:
        (h, w) tuple of ints. Both orientations of every divisor pair are
        considered, so ratio > 1 (taller than wide) can yield h > w. When no
        divisor is found (n < 1) the initial guess (1, n) is returned.
    """
    best = (1, n)
    best_err = float('inf')
    for d in range(1, int(np.sqrt(n)) + 1):
        if n % d != 0:
            continue
        q = n // d
        # d <= q, so (d, q) covers h/w <= 1 and (q, d) covers h/w >= 1.
        # Checking only (d, q) could never match a portrait ratio.
        for h, w in ((d, q), (q, d)):
            err = abs((h / w) - ratio)
            if err < best_err:
                best = (h, w)
                best_err = err
    return best  # (h, w)

def extract_patch_features(image_hw3, model, processor, device):
    """
    Run the model on one image and return its patch features as a spatial grid.

    Args:
      image_hw3: image passed to processor(images=...)
      model: model whose output exposes last_hidden_state and whose config
             provides patch_size (and optionally num_register_tokens)
      processor: image processor returning a mapping with "pixel_values"
      device: device the processed inputs are moved to

    Returns:
      feats_hwC: (Hp, Wp, C) patch features (L2-normalized)
      Hp, Wp: patch grid size
      processed_hw: (Hproc, Wproc) used by the model
    """
    inputs = processor(images=image_hw3, return_tensors="pt").to(device)  # (1,3,Hproc,Wproc)
    with torch.no_grad():
        out = model(**inputs)
    tokens = out.last_hidden_state[0]  # (1 + registers + N, C)
    Hproc, Wproc = inputs["pixel_values"].shape[2], inputs["pixel_values"].shape[3]

    # Patch grid implied by the processed image size
    ps = model.config.patch_size
    Hp, Wp = Hproc // ps, Wproc // ps

    # Token order is [CLS, register tokens, patch tokens]; skip CLS *and* registers.
    # Keeping the registers makes N disagree with Hp*Wp and forces a guessed grid.
    n_special = 1 + (getattr(model.config, "num_register_tokens", 0) or 0)
    patch_tokens = tokens[n_special:]  # (N, C)
    N, C = patch_tokens.shape[0], patch_tokens.shape[1]

    if N > Hp * Wp:
        # Undeclared special tokens are prepended, so keep the trailing Hp*Wp tokens
        patch_tokens = patch_tokens[N - Hp * Wp:]
    elif N < Hp * Wp:
        # Last resort when tokens are missing: pick the factorization closest to the
        # processed aspect ratio
        Hp, Wp = factor_hw(N, ratio=Hproc / Wproc)

    feats = patch_tokens.cpu().numpy().reshape(Hp, Wp, C)  # (Hp,Wp,C)
    # L2-normalize for cosine operations; eps avoids division by zero
    eps = 1e-8
    norms = np.linalg.norm(feats, axis=2, keepdims=True) + eps
    feats = feats / norms  # (Hp,Wp,C) normalized
    return feats, Hp, Wp, (Hproc, Wproc)

# ---- Build feature bank from Image A labels ----
def build_positive_bank(img_a, labels_a, model, processor, device):
    """
    Collect the patch features of Image A that lie inside its labelled region.

    img_a: image handed to extract_patch_features
    labels_a: (H,W) array; values > 0 mark the positive region
    Returns bank: (Na, C) features of the positive patches, normalized as
    returned by extract_patch_features.
    """
    patch_feats, grid_h, grid_w, _ = extract_patch_features(img_a, model, processor, device)

    # Bring the binary label mask down to the patch grid; order=0 (nearest
    # neighbour) avoids blending label values.
    positive_px = (labels_a > 0).astype(float)
    positive_patches = resize(positive_px, (grid_h, grid_w),
                              order=0, anti_aliasing=False) > 0.5

    # Boolean indexing flattens the selected grid cells to (Na, C)
    return patch_feats[positive_patches]

# ---- Score Image B by kNN to bank ----
def zero_shot_knn_score(img_b, bank, model, processor, device, k=5):
    """
    Score every pixel of img_b by how close its patch feature is to the bank.

    Returns:
      score_full: (Hb,Wb) mean cosine similarity to the k nearest bank entries,
                  clipped to [0,1]; higher = more similar to bank.
    """
    feats_b, grid_h, grid_w, _ = extract_patch_features(img_b, model, processor, device)
    queries = feats_b.reshape(-1, feats_b.shape[2])  # (Nb, C)

    # Never ask for more neighbours than the bank holds (and at least one)
    n_neighbors = min(k, max(1, len(bank)))
    index = NearestNeighbors(n_neighbors=n_neighbors, metric='cosine')
    index.fit(bank)
    cos_dist, _ = index.kneighbors(queries, return_distance=True)  # (Nb, n_neighbors)

    # Cosine distance = 1 - cosine similarity; average the similarities per patch
    patch_scores = (1.0 - cos_dist).mean(axis=1).reshape(grid_h, grid_w)

    # Linear (order=1) upsampling from the patch grid to pixel resolution
    out_h, out_w = img_b.shape[:2]
    upsampled = resize(patch_scores, (out_h, out_w),
                       order=1, preserve_range=True, anti_aliasing=True)

    # Negative similarities are cut off at 0
    return np.clip(upsampled, 0.0, 1.0)

# ---- Complete working example ----

# Build feature bank from Image A's labeled region
print("Building feature bank from Image A...")
bank = build_positive_bank(img, labels_2d, model, processor, device)
print(f"Bank size: {bank.shape[0]} positive patches with {bank.shape[1]} dimensions")

# Score Image B using kNN against the bank.
# One constant drives both the scoring call and the plot title.
knn_k = 5
print("\nScoring Image B...")
score_b = zero_shot_knn_score(img_b, bank, model, processor, device, k=knn_k)
print(f"Score map shape: {score_b.shape}, range: [{score_b.min():.3f}, {score_b.max():.3f}]")

# Threshold for a binary mask.
# NOTE(review): the original comment suggested tuning within 0.6-0.85, but the value
# used here is 0.3; the value is kept as-is and should be confirmed.
threshold = 0.3
mask_b = score_b >= threshold
print(f"Mask covers {mask_b.sum()} pixels ({100*mask_b.sum()/mask_b.size:.1f}%)")

# Visualize overlay
import matplotlib.pyplot as plt

# Blend 50% green into the masked pixels, cast back to the image dtype
overlay = img_b.copy()
overlay[mask_b] = (0.5 * overlay[mask_b] + 0.5 * np.array([0, 255, 0])).astype(overlay.dtype)

fig, axs = plt.subplots(1, 4, figsize=(20, 5))

axs[0].imshow(img)
axs[0].set_title("Image A (with labels)")
axs[0].axis('off')

axs[1].imshow(img_b)
axs[1].set_title("Image B")
axs[1].axis('off')

# Draw the score map once and attach the colorbar to that image; the colorbar is
# no longer stored by overwriting the Axes.colorbar attribute.
score_im = axs[2].imshow(score_b, cmap='hot')
axs[2].set_title(f"kNN Similarity Score (k={knn_k})")
fig.colorbar(score_im, ax=axs[2])
axs[2].axis('off')

axs[3].imshow(overlay)
axs[3].set_title(f"Transferred Mask (threshold={threshold})")
axs[3].axis('off')

plt.tight_layout()
plt.show()
