In [None]:
# Colab notebook: Auto-label images using SAM (Segment Anything) + CLIP
# Purpose: Automatically generate bounding boxes for novel classes where pretrained detectors don't exist.
# Dataset structure (on your Drive):
# /content/drive/MyDrive/data_managment_nutrion/food-101/images/"cuisines"/"images"
# Supports mixed image formats (.jpg, .jpeg, .png, ...)
# ------------------------------------------------------------
# Cell 1: Install required packages
# - Segment Anything (SAM)
# - CLIP (OpenAI)
# - Supporting libraries (torch, torchvision, OpenCV, PIL, tqdm)
# ------------------------------------------------------------

# Segment Anything is installed straight from its GitHub repository.
!pip -q install git+https://github.com/facebookresearch/segment-anything.git
# Hugging Face + imaging stack imported by the cells below.
!pip -q install transformers ftfy regex tqdm torchvision pillow opencv-python
# NOTE(review): the visible cells load CLIP through `transformers` (CLIPModel / CLIPProcessor),
# not through this package — confirm it is still needed.
!pip -q install git+https://github.com/openai/CLIP.git
# NOTE(review): `supervision` is not imported by any cell visible in this notebook.
!pip -q install supervision
# ImageHash provides the `imagehash` module used by the de-duplication cell.
!pip install ImageHash
print("‚úÖ Dependencies installed")


In [None]:
# ------------------------------------------------------------
# Cell 2: Mount Google Drive
# - Access dataset and store models/checkpoints
# ------------------------------------------------------------
from google.colab import drive
# Mount Google Drive at /content/drive; the dataset, model and output paths
# used by the later cells are all rooted under /content/drive/MyDrive.
drive.mount('/content/drive')

print("‚úÖ Google Drive mounted")


In [None]:
# ------------------------------------------------------------
# Cell 3: Prepare models folder and download SAM checkpoint
# ------------------------------------------------------------
from pathlib import Path
import os

# Folder on Drive where model checkpoints are cached between Colab sessions.
models_dir = Path('/content/drive/MyDrive/models')
models_dir.mkdir(parents=True, exist_ok=True)
print("Models folder:", models_dir)

# SAM vit_h checkpoint (~1.1 GB)
sam_checkpoint = models_dir / 'sam_vit_h_4b8939.pth'

if not sam_checkpoint.exists():
    print("SAM checkpoint not found. Downloading...")
    sam_url = 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth'
    # Download into a side file and move it into place only after wget
    # succeeds. `-O` creates its target immediately, so writing directly to
    # `sam_checkpoint` would leave a truncated file that passes the
    # `.exists()` check on the next run after an interrupted download.
    partial_checkpoint = sam_checkpoint.with_name(sam_checkpoint.name + '.part')
    ret = os.system(f"wget -c '{sam_url}' -O '{partial_checkpoint}'")
    if ret != 0 or not partial_checkpoint.exists():
        raise RuntimeError(
            f"SAM checkpoint download failed (return code {ret}). "
            f"Download it manually from {sam_url} and place it in {models_dir}"
        )
    partial_checkpoint.replace(sam_checkpoint)
    print("SAM checkpoint saved to:", sam_checkpoint)
else:
    print("‚úÖ SAM checkpoint already exists:", sam_checkpoint)


In [None]:
# ------------------------------------------------------------
# Cell 4: Define dataset and output paths
# ------------------------------------------------------------
from pathlib import Path

# Raw images live under <dataset_root>/<cuisine>/<class>/<image file>
dataset_root = Path('/content/drive/MyDrive/data_managment_nutrion/food-101/images')

# Destination for the generated YOLO-style labels (created if missing)
output_root = Path('/content/drive/MyDrive/auto_labeled_dataset')
output_root.mkdir(parents=True, exist_ok=True)

# Every directory directly under the dataset root counts as one cuisine
cuisines = list(filter(lambda entry: entry.is_dir(), dataset_root.iterdir()))
print("Found cuisines:", [c.name for c in cuisines])

# For each cuisine, we will later iterate over the food class subfolders:
# dataset_root / cuisine_name / class_name


In [None]:
# ------------------------------------------------------------
# Cell 5: Import required Python libraries
# ------------------------------------------------------------
import cv2
import numpy as np
import torch
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt

from transformers import CLIPProcessor, CLIPModel
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator

print("‚úÖ Libraries imported")


In [None]:
# ------------------------------------------------------------
# Cell 6: Initialize SAM and CLIP models
# ------------------------------------------------------------
# Run on the GPU when the runtime has one, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print("Using device:", device)

# SAM: build the ViT-H variant from the Drive checkpoint and wrap it in the
# automatic mask generator used by the labelling cells
sam_model_type = 'vit_h'
sam = sam_model_registry[sam_model_type](checkpoint=str(sam_checkpoint))
sam.to(device)
mask_generator = SamAutomaticMaskGenerator(sam)
print("‚úÖ SAM loaded")

# CLIP: the model and its pre-processor are loaded from the same checkpoint name
clip_model_name = 'openai/clip-vit-base-patch32'
clip_processor = CLIPProcessor.from_pretrained(clip_model_name)
clip_model = CLIPModel.from_pretrained(clip_model_name)
clip_model = clip_model.to(device)
print("‚úÖ CLIP loaded")


In [None]:
# ------------------------------------------------------------
# Cell 7: Helper functions
# ------------------------------------------------------------
import cv2
import numpy as np
from PIL import Image
import torch
import matplotlib.pyplot as plt

# Load image as RGB
def load_image_cv2(path):
    """Read an image file with OpenCV and return it converted from BGR to RGB.

    Raises:
        RuntimeError: if `cv2.imread` returns None for `path`.
    """
    bgr = cv2.imread(str(path))
    if bgr is not None:
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    raise RuntimeError(f"Failed to read image: {path}")

# Convert boolean mask to bounding box
def mask_to_box(mask):
    """Return the tight bounding box around the truthy pixels of a 2-D mask.

    Returns:
        (x_min, y_min, x_max, y_max) as Python ints, both corners inclusive,
        or None when the mask contains no truthy pixel.
    """
    rows, cols = np.nonzero(mask)
    if cols.size == 0 or rows.size == 0:
        return None
    left, right = int(cols.min()), int(cols.max())
    top, bottom = int(rows.min()), int(rows.max())
    return left, top, right, bottom

# CLIP score for whole image
def clip_image_text_score(image_rgb, text):
    """Cosine similarity between the CLIP embeddings of an image and a text prompt.

    Args:
        image_rgb: image array accepted by `PIL.Image.fromarray`.
        text: prompt string to compare the image against.

    Returns:
        The similarity as a Python float.

    Uses the notebook-level `clip_processor`, `clip_model` and `device`.
    """
    batch = clip_processor(
        text=[text],
        images=Image.fromarray(image_rgb),
        return_tensors="pt",
        padding=True,
    ).to(device)
    with torch.no_grad():
        result = clip_model(**batch)

    def _unit(vec):
        # Scale to unit length so the dot product below is a cosine similarity
        return vec / vec.norm(dim=-1, keepdim=True)

    similarity = _unit(result.image_embeds) @ _unit(result.text_embeds).T
    return similarity.item()

# CLIP score for a mask
def clip_mask_score(image_rgb, mask, text):
    """Score how well the region selected by `mask` matches `text` with CLIP.

    The image is cropped to the mask's bounding box and every pixel outside
    the mask is set to black before encoding, so only the masked region
    contributes. Attaching the mask as an alpha channel does not achieve
    this: the CLIP processor converts its input to RGB, and Pillow's
    RGBA -> RGB conversion discards alpha without compositing.

    Args:
        image_rgb: H x W x 3 RGB array.
        mask: H x W boolean (or 0/1) array selecting the region.
        text: prompt to compare the region against.

    Returns:
        Cosine similarity as a float, or -1.0 when the mask is empty.

    Uses the notebook-level `clip_processor`, `clip_model` and `device`.
    """
    box = mask_to_box(mask)
    if box is None:
        return -1.0
    x1, y1, x2, y2 = box
    # Box corners are inclusive, hence the +1 on the slice ends.
    crop = image_rgb[y1:y2+1, x1:x2+1]
    mask_crop = np.asarray(mask[y1:y2+1, x1:x2+1], dtype=bool)
    if not mask_crop.any():
        return -1.0
    # Keep masked pixels, blank everything else.
    masked_crop = np.where(mask_crop[..., None], crop, 0).astype(crop.dtype)
    pil_img = Image.fromarray(masked_crop)
    inputs = clip_processor(
        text=[text],
        images=pil_img,
        return_tensors="pt",
        padding=True
    ).to(device)
    with torch.no_grad():
        outputs = clip_model(**inputs)
    img_emb = outputs.image_embeds / outputs.image_embeds.norm(dim=-1, keepdim=True)
    txt_emb = outputs.text_embeds / outputs.text_embeds.norm(dim=-1, keepdim=True)
    return (img_emb @ txt_emb.T).item()

# Visualize image with mask overlay + bounding box
def visualize_result(image_rgb, mask, box, score, title="Result"):
    """Display the image with the mask tinted green and the box outlined in red.

    Args:
        image_rgb: RGB image array; it is copied, not modified.
        mask: boolean array used to index the image pixels to tint.
        box: (x1, y1, x2, y2) pixel corners, or None to skip the rectangle.
        score: number printed under the title with three decimals.
        title: first line of the figure title.
    """
    canvas = image_rgb.copy()
    tint = np.zeros_like(canvas)
    tint[..., 1] = 255
    # Inside the mask: 60 % original pixel + 40 % pure green
    canvas[mask] = (0.6 * canvas[mask] + 0.4 * tint[mask]).astype(np.uint8)
    if box is not None:
        left, top, right, bottom = box
        # (255, 0, 0) is red because the array is in RGB order
        cv2.rectangle(canvas, (left, top), (right, bottom), (255, 0, 0), 2)
    plt.figure(figsize=(6,6))
    plt.imshow(canvas)
    plt.title(f"{title}\nCLIP score: {score:.3f}")
    plt.axis("off")
    plt.show()


In [None]:
# ------------------------------------------------------------
# Cell 8: Label all cuisines and classes, save mask + bounding box
# ------------------------------------------------------------
IMAGE_PRESENCE_THRESHOLD = 0.18  # min CLIP score to keep image
MAX_MASKS = 200                   # masks per image to check
TOP_K_MASKS = 5                   # merge top-K masks for final

img_exts = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp'}

# Walk <dataset_root>/<cuisine>/<class>/<image>. For every image that passes
# the CLIP presence check, segment it with SAM, rank the masks with CLIP,
# merge the best TOP_K_MASKS and write a mask PNG plus a YOLO label file.
for cuisine_dir in dataset_root.iterdir():
    if not cuisine_dir.is_dir():
        continue
    print(f"\nüçΩÔ∏è Processing cuisine: {cuisine_dir.name}")

    # Loop over classes/foods
    class_dirs = [d for d in cuisine_dir.iterdir() if d.is_dir()]
    for class_dir in class_dirs:
        class_name = class_dir.name.replace('_', ' ')
        # The same prompt drives the image-level filter and the per-mask ranking.
        class_prompt = f"a photo of {class_name} food"
        print(f"\nüîπ Class: {class_name}")

        out_class_dir = output_root / cuisine_dir.name / class_dir.name
        out_class_dir.mkdir(parents=True, exist_ok=True)

        kept_images = 0
        labeled_images = 0

        for img_path in tqdm(list(class_dir.iterdir()), desc=f"{class_name}"):
            if img_path.suffix.lower() not in img_exts:
                continue

            try:
                image_rgb = load_image_cv2(img_path)
            except Exception as e:
                print("Skipping", img_path, "error:", e)
                continue

            # ---- IMAGE PRESENCE CHECK ----
            presence_score = clip_image_text_score(image_rgb, class_prompt)
            if presence_score < IMAGE_PRESENCE_THRESHOLD:
                continue
            kept_images += 1

            # ---- SAM SEGMENTATION ----
            try:
                masks = mask_generator.generate(image_rgb)
            except Exception as e:
                print("SAM failed for", img_path, e)
                continue

            # ---- SCORE MASKS WITH CLIP ----
            scored_masks = []
            for m in masks[:MAX_MASKS]:
                score = clip_mask_score(image_rgb, m['segmentation'], class_prompt)
                if score > 0:
                    scored_masks.append((score, m['segmentation']))

            if not scored_masks:
                continue

            # ---- MERGE TOP-K MASKS ----
            scored_masks.sort(key=lambda x: x[0], reverse=True)
            top_masks = scored_masks[:TOP_K_MASKS]
            merged_mask = np.zeros_like(top_masks[0][1], dtype=bool)
            for _, m in top_masks:
                merged_mask |= m

            best_box = mask_to_box(merged_mask)
            if best_box is None:
                continue

            # The list is sorted descending, so the first entry has the best score.
            best_score = top_masks[0][0]

            # ---- SAVE MASK ----
            mask_path = out_class_dir / f"{img_path.stem}_mask.png"
            Image.fromarray((merged_mask*255).astype(np.uint8)).save(mask_path)

            # ---- SAVE YOLO LABEL ----
            h, w, _ = image_rgb.shape
            x1, y1, x2, y2 = best_box
            # mask_to_box returns inclusive pixel indices: the box covers
            # (x2 - x1 + 1) pixels and its far edges lie at x2 + 1 / y2 + 1.
            # Without the +1 a full-frame mask gets width (w - 1) / w and a
            # one-pixel mask gets a zero-size box.
            xc = ((x1 + x2 + 1) / 2) / w
            yc = ((y1 + y2 + 1) / 2) / h
            bw = (x2 - x1 + 1) / w
            bh = (y2 - y1 + 1) / h

            label_path = out_class_dir / f"{img_path.stem}.txt"
            with open(label_path, "w") as f:
                # NOTE(review): the class id is always 0 even though this loop
                # covers every class — confirm single-class labels are intended.
                f.write(f"0 {xc:.6f} {yc:.6f} {bw:.6f} {bh:.6f}\n")

            labeled_images += 1

            # ---- VISUALIZE MASK + BOUNDING BOX ----
            # Renders one figure per labelled image.
            visualize_result(image_rgb, merged_mask, best_box, best_score, title=f"{class_name} | {img_path.name}")

        print(f"‚úÖ DONE: {class_name}")
        print(f"Images with class present: {kept_images}")
        print(f"Images successfully labeled: {labeled_images}")
        print(f"Output folder: {out_class_dir}")


# Cell 8 Explanation: Single-Class Labeling with SAM + CLIP

This cell processes a single food class (`koshary`) for testing purposes, using **SAM (Segment Anything Model)** for segmentation and **CLIP** for class verification. The goal is to generate:

1. **Masks** — pixel-wise segmentation of the main food item.
2. **Bounding Boxes** — converted from masks, in YOLO format.
3. **Visualization** — show mask overlay + bounding box for verification.

### Steps

1. **Image Presence Check**  
   - Use CLIP to compute similarity between image and the prompt `"a clean photo of koshary food on a plate"`.  
   - Images with low CLIP scores are discarded to remove irrelevant images.

2. **SAM Segmentation**  
   - Generate up to `MAX_MASKS` per image to capture possible food regions.  
   - Even messy or side-bunched food can be captured because multiple masks are generated.

3. **Mask Scoring with CLIP**  
   - Each mask is cropped and scored with CLIP against the same prompt.  
   - Only masks with meaningful scores are considered, filtering out background and unrelated objects.

4. **Mask Merging**  
   - Merge **top-K masks** (by CLIP score) into a single mask.  
   - Weighted merging ensures that high-confidence masks dominate, reducing background contamination.

5. **Bounding Box Extraction**  
   - Convert merged mask to a bounding box.  
   - Bounding box saved in **YOLO format** for downstream tasks (e.g., object detection, calorie estimation).

6. **Mask and Bounding Box Saving**  
   - Save mask as PNG (`*_mask.png`).  
   - Save bounding box in YOLO `.txt` file.

7. **Visualization**  
   - Overlay merged mask on the original image in green.  
   - Draw bounding box in red.  
   - Display CLIP score for verification.

### Hyperparameter Notes

| Parameter | Value | Purpose |
|-----------|-------|--------|
| IMAGE_PRESENCE_THRESHOLD | 0.25 | Remove images that don't clearly contain the main class. |
| MAX_MASKS | 300 | Generate more candidate regions, capturing messy arrangements. |
| TOP_K_MASKS | 3 | Merge top-scoring masks; avoids background masks. |
| mask-to-image ratio | 0.02–0.9 | Remove tiny noise masks or full-frame background. |

This approach ensures that the **main food class is preserved**, while minimizing inclusion of background or side objects.





### New Features / Improvements

1. **Plate / Background Removal**
   - Masks that occupy **too large a fraction of the image** are discarded (likely the plate).  
   - Masks that are **too small** are ignored (noise).  
   - Aspect ratio filtering removes elongated or perfectly circular regions (often plate edges).

2. **Weighted CLIP Scoring**
   - Each mask's CLIP score is multiplied by the **mask area fraction** to penalize huge masks that include background.

3. **Top-K Mask Merging**
   - Merge only **top-K masks** by weighted score to generate the final mask.  

4. **Image Presence Filtering**
   - CLIP is used to discard images that **don't contain the main class** with a threshold.

5. **Visualization**
   - Overlay merged mask in green and bounding box in red.  
   - CLIP score is shown to confirm relevance.

### Hyperparameters

| Parameter | Value | Purpose |
|-----------|-------|--------|
| IMAGE_PRESENCE_THRESHOLD | 0.25 | Discard images without the main food. |
| MAX_MASKS | 300 | Check more candidate masks to catch messy arrangements. |
| TOP_K_MASKS | 3 | Merge only top-K masks by weighted score. |
| MIN_MASK_RATIO | 0.02 | Ignore tiny masks (noise). |
| MAX_MASK_RATIO | 0.6 | Ignore masks that cover plate/background. |
| ASPECT_RATIO_RANGE | 0.3–3 | Ignore elongated or perfectly round masks (likely plate). |

This approach ensures that the **mask focuses on the main food**, even if it is messy or side-bunched, while removing most of the plate and background.


In [None]:
# ImageHash provides the `imagehash` module imported by the next cell.
# NOTE(review): the first cell of the notebook already installs it; this repeat
# only matters when the de-duplication cells are run on their own.
!pip install ImageHash

In [None]:
import hashlib
from pathlib import Path
from tqdm import tqdm
from PIL import Image
import imagehash  # pip install ImageHash

def get_precise_hash(img_path):
    """Return the MD5 hex digest of the file's raw bytes.

    Two paths yield the same digest only when their contents are
    byte-for-byte identical (barring MD5 collisions).
    """
    digest = hashlib.md5()
    digest.update(img_path.read_bytes())
    return digest.hexdigest()

def get_perceptual_hash(img_path):
    """Return the perceptual hash (pHash) of an image as a string.

    Returns:
        `str(imagehash.phash(img))`, or None when the file cannot be opened
        or hashed.
    """
    try:
        with Image.open(img_path) as img:
            return str(imagehash.phash(img))
    except Exception:
        # A bare `except:` would also swallow KeyboardInterrupt/SystemExit and
        # make a long de-duplication run impossible to interrupt.
        return None

# Parameters
delete_files = True  # Set to False first if you want to test without deleting
seen_hashes = set()  # holds both MD5 digests and pHash strings of kept images
duplicates_removed = 0
# Counted separately: `seen_hashes` gets up to two entries per kept image,
# so its length is not the number of unique images.
unique_images_kept = 0

print("üöÄ Starting Dataset De-duplication...")

# Hashes are shared across the whole dataset, so an image that repeats one
# seen earlier in another class or cuisine is also treated as a duplicate.
for cuisine_dir in tqdm(list(dataset_root.iterdir()), desc="Cuisines"):
    if not cuisine_dir.is_dir(): continue
    for class_dir in cuisine_dir.iterdir():
        if not class_dir.is_dir(): continue
        for img_path in class_dir.iterdir():
            if img_path.suffix.lower() not in {'.jpg', '.jpeg', '.png', '.bmp', '.webp'}:
                continue

            # 1. Check Byte-level (MD5)
            file_hash = get_precise_hash(img_path)

            # 2. Check Visual-level (pHash); None when the image cannot be read
            vis_hash = get_perceptual_hash(img_path)

            # Identify if either hash has been seen before
            is_duplicate = (file_hash in seen_hashes) or (vis_hash and vis_hash in seen_hashes)

            if is_duplicate:
                if delete_files:
                    img_path.unlink()  # Physically deletes the file
                duplicates_removed += 1
            else:
                seen_hashes.add(file_hash)
                if vis_hash:
                    seen_hashes.add(vis_hash)
                unique_images_kept += 1

print(f"\n‚úÖ Clean-up Complete!")
print(f"Total Unique Images Kept: {unique_images_kept}")
print(f"Duplicates Deleted: {duplicates_removed}")


In [None]:
import pandas as pd
from sklearn.model_selection import train_test_split

# SamPredictor is used below but is not among the names the earlier cells
# import from segment_anything; import it here so this cell runs on a fresh kernel.
from segment_anything import SamPredictor

IMAGE_PRESENCE_THRESHOLD = 0.3   # min whole-image CLIP score to process an image
CONFIDENCE_THRESHOLD = 0.12      # min window food score to become a positive point
MAX_FOOD_AREA = 0.6              # max fraction of the image a mask may cover
WINDOW_STEP = 35                 # stride in pixels of the sliding window
TOP_N_POINTS = 20                # max number of positive prompt points

img_exts = {'.jpg','.jpeg','.png','.bmp','.tiff','.webp'}
meta_rows = []

# One predictor serves every image: set_image() is called again per image.
predictor = SamPredictor(sam)

# For each image: score sliding windows with CLIP, turn the best windows into
# positive / negative SAM point prompts, keep the largest acceptable mask and
# write a YOLO label, a mask PNG and a metadata row.
for cuisine_dir in tqdm(list(dataset_root.iterdir()), desc="Cuisines"):
    if not cuisine_dir.is_dir(): continue
    for class_dir in cuisine_dir.iterdir():
        if not class_dir.is_dir(): continue
        class_name = class_dir.name
        CLASS_PROMPT = f"the appearance of {class_name} food"
        PLATE_PROMPT = "a white plate or rim, bread or background"
        out_class_dir = output_root / cuisine_dir.name / class_name
        out_class_dir.mkdir(parents=True, exist_ok=True)

        for img_path in tqdm(list(class_dir.iterdir()), desc=f"{cuisine_dir.name}/{class_name}", leave=False):
            if img_path.suffix.lower() not in img_exts: continue
            try:
                image_rgb = load_image_cv2(img_path)
            except Exception as e:
                # Was a bare `except:`, which also swallowed KeyboardInterrupt.
                print("Skipping", img_path, "error:", e)
                continue
            h, w, _ = image_rgb.shape

            # IMAGE PRESENCE CHECK
            presence_score = clip_image_text_score(image_rgb, CLASS_PROMPT)
            if presence_score < IMAGE_PRESENCE_THRESHOLD: continue

            # GENERATE POINTS
            # Windows are 1/6 of the image per side and start inside the
            # central 10 %-90 % band; each window centre is a candidate point.
            win_h, win_w = h//6, w//6
            points = []
            for y in range(int(h*0.1), int(h*0.9), WINDOW_STEP):
                for x in range(int(w*0.1), int(w*0.9), WINDOW_STEP):
                    crop = image_rgb[y:y+win_h, x:x+win_w]
                    if crop.size == 0: continue
                    food_score = clip_image_text_score(crop, CLASS_PROMPT)
                    plate_score = clip_image_text_score(crop, PLATE_PROMPT)
                    points.append({'coord':[x+win_w//2,y+win_h//2],'food_score':food_score,'plate_score':plate_score})

            if not points: continue
            # Positive points: best food windows above the confidence threshold
            points.sort(key=lambda p: p['food_score'], reverse=True)
            input_coords = [p['coord'] for p in points[:TOP_N_POINTS] if p['food_score']>CONFIDENCE_THRESHOLD]
            input_labels = [1]*len(input_coords)
            # Negative points (plate/background)
            points.sort(key=lambda p: p['plate_score'], reverse=True)
            for p in points[:10]:
                if p['plate_score'] > p['food_score']+0.05:
                    input_coords.append(p['coord'])
                    input_labels.append(0)

            # At least one positive point is required
            if not input_coords or 1 not in input_labels: continue

            predictor.set_image(image_rgb)
            masks, scores, _ = predictor.predict(np.array(input_coords), np.array(input_labels), multimask_output=True)
            best_mask = None
            # Try candidates from largest to smallest area
            mask_indices = np.argsort([np.sum(m) for m in masks])[::-1]
            for idx in mask_indices:
                m = masks[idx]
                # Reject masks touching any image border
                if np.any(m[0,:]) or np.any(m[-1,:]) or np.any(m[:,0]) or np.any(m[:,-1]): continue
                m_area_frac = np.sum(m)/(h*w)
                if 0.04<m_area_frac<MAX_FOOD_AREA:
                    # Morphological closing with a 7x7 kernel
                    kernel = np.ones((7,7),np.uint8)
                    best_mask = cv2.morphologyEx(m.astype(np.uint8), cv2.MORPH_CLOSE, kernel).astype(bool)
                    break
            if best_mask is None: continue

            # Bounding box
            best_box = mask_to_box(best_mask)
            if best_box is None: continue
            x1, y1, x2, y2 = best_box
            # Corners are inclusive pixel indices, so the box spans
            # (x2 - x1 + 1) pixels; the +1 keeps width/centre exact.
            xc, yc = ((x1+x2+1)/2)/w, ((y1+y2+1)/2)/h
            bw, bh = (x2-x1+1)/w, (y2-y1+1)/h

            # Save YOLO label
            label_path = out_class_dir / f"{img_path.stem}.txt"
            with open(label_path, "w") as f:
                # NOTE(review): class id is always 0 for every class — confirm intended.
                f.write(f"0 {xc:.6f} {yc:.6f} {bw:.6f} {bh:.6f}\n")

            # Save mask
            mask_path = out_class_dir / f"{img_path.stem}_mask.png"
            Image.fromarray((best_mask*255).astype(np.uint8)).save(mask_path)

            # Metadata
            meta_rows.append({
                "image_path": str(img_path),
                "mask_path": str(mask_path),
                "class": class_name,
                "cuisine": cuisine_dir.name
            })

            # Optional visualization
            # The value shown is SAM's mask score, although the figure title
            # produced by visualize_result labels it "CLIP score".
            visualize_result(image_rgb, best_mask, best_box, np.max(scores), title=f"{class_name}")


In [None]:
# single Colab cell -> mounts drive, creates models folder, downloads SAM checkpoint if missing
from pathlib import Path
import os

# 2) Prepare models folder on Drive
models_dir = Path('/content/drive/MyDrive/models')
models_dir.mkdir(parents=True, exist_ok=True)
print("Models folder:", models_dir)

# 3) Check / download SAM vit_h checkpoint (large ~1.1 GB). Only download if missing.
sam_checkpoint = models_dir / 'sam_vit_h_4b8939.pth'
if sam_checkpoint.exists():
    print("‚úÖ SAM checkpoint already exists at:", sam_checkpoint)
else:
    print("SAM checkpoint not found in Drive. Downloading now (this may take a few minutes)...")
    # Official public file hosted by Meta (Facebook Research)
    sam_url = 'https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth'
    # Download into a side file: `-O` creates its target immediately, so a
    # failed download written straight to `sam_checkpoint` would pass the
    # exists() check above on the next run. `-c` still resumes the side file.
    partial_checkpoint = sam_checkpoint.with_name(sam_checkpoint.name + '.part')
    cmd = f"wget -c '{sam_url}' -O '{partial_checkpoint}'"
    print("Running:", cmd)
    ret = os.system(cmd)
    if ret == 0 and partial_checkpoint.exists():
        # Move into place only after a successful download
        partial_checkpoint.replace(sam_checkpoint)
        print("‚úÖ Download completed and saved to Drive:", sam_checkpoint)
    else:
        print("‚ö†Ô∏è Download failed (return code", ret, ").")
        print("Check your internet connection, enough Drive space, or download manually from:")
        print(sam_url)
        print("Then upload the file to", models_dir)

# 4) Use sam_checkpoint as Path for the rest of the notebook
print("sam_checkpoint (Path):", sam_checkpoint)


In [None]:
# -----------------------
# Load CLIP (required before scoring)
# -----------------------
import torch
from transformers import CLIPProcessor, CLIPModel

# Run on the GPU when the runtime has one, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print("Using device:", device)

clip_model_name = "openai/clip-vit-base-patch32"

# Model and pre-processor are loaded from the same checkpoint name
clip_processor = CLIPProcessor.from_pretrained(clip_model_name)
clip_model = CLIPModel.from_pretrained(clip_model_name)
clip_model = clip_model.to(device)

# Inference only: switch to eval mode
clip_model.eval()

print("‚úÖ CLIP loaded")


In [None]:
# =========================================================
# Complete Auto-label + Visualization (all images labeled)
# =========================================================

from pathlib import Path
import cv2
import numpy as np
import torch
from PIL import Image
from tqdm import tqdm
import matplotlib.pyplot as plt

# -----------------------
# CONFIG
# -----------------------
# Folder name of the class to label; also used as the output sub-folder name.
# NOTE(review): the folder is spelled "koshari" but the prompt says "koshary" — confirm both are intended.
CLASS_NAME = "koshari"
# Text prompt used for the image-level check and for ranking masks.
CLASS_PROMPT = "a photo of koshary food"

IMAGE_PRESENCE_THRESHOLD = 0.18   # image must contain main class
# NOTE(review): MASK_SCORE_THRESHOLD is not referenced anywhere in the visible code.
MASK_SCORE_THRESHOLD = 0.0        # merged mask will always label
MAX_MASKS_TO_CHECK = 200
TOP_K_MASKS = 5                   # merge top-K masks

# When True, a figure is shown for every labelled image.
SHOW_VISUALS = True

img_exts = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp'}

# Rebinds `dataset_root` to a single cuisine folder; classes sit directly below it.
dataset_root = Path("/content/drive/MyDrive/data_managment_nutrion/food-101/images/Egyptian")
class_dir = dataset_root / CLASS_NAME

# Labels go to <output_root>/<CLASS_NAME> (no cuisine level in this cell).
output_root = Path("/content/drive/MyDrive/auto_labeled_dataset")
out_class_dir = output_root / CLASS_NAME
out_class_dir.mkdir(parents=True, exist_ok=True)

# Fail early when the class folder is missing.
assert class_dir.exists(), f"Class folder not found: {class_dir}"

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using device:", device)

# -----------------------
# HELPER FUNCTIONS
# -----------------------
def load_image_cv2(path):
    """Read an image file with OpenCV and return it converted from BGR to RGB.

    Raises:
        RuntimeError: if `cv2.imread` returns None for `path`.
    """
    bgr = cv2.imread(str(path))
    if bgr is not None:
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    raise RuntimeError(f"Failed to read image: {path}")

def mask_to_box(mask):
    """Return the tight bounding box around the truthy pixels of a 2-D mask.

    Returns:
        (x_min, y_min, x_max, y_max) as Python ints, both corners inclusive,
        or None when the mask contains no truthy pixel.
    """
    rows, cols = np.nonzero(mask)
    if cols.size == 0:
        return None
    left, right = int(cols.min()), int(cols.max())
    top, bottom = int(rows.min()), int(rows.max())
    return left, top, right, bottom

def clip_image_text_score(image_rgb, text):
    """Cosine similarity between the CLIP embeddings of an image and a text prompt.

    Args:
        image_rgb: image array accepted by `PIL.Image.fromarray`.
        text: prompt string to compare the image against.

    Returns:
        The similarity as a Python float.

    Uses the notebook-level `clip_processor`, `clip_model` and `device`.
    """
    batch = clip_processor(
        text=[text],
        images=Image.fromarray(image_rgb),
        return_tensors="pt",
        padding=True,
    ).to(device)

    with torch.no_grad():
        result = clip_model(**batch)

    def _unit(vec):
        # Scale to unit length so the dot product below is a cosine similarity
        return vec / vec.norm(dim=-1, keepdim=True)

    similarity = _unit(result.image_embeds) @ _unit(result.text_embeds).T
    return similarity.item()

def clip_mask_score(image_rgb, mask, text):
    """Score how well the region selected by `mask` matches `text` with CLIP.

    The image is cropped to the mask's bounding box and every pixel outside
    the mask is set to black before encoding, so only the masked region
    contributes. Attaching the mask as an alpha channel does not achieve
    this: the CLIP processor converts its input to RGB, and Pillow's
    RGBA -> RGB conversion discards alpha without compositing.

    Args:
        image_rgb: H x W x 3 RGB array.
        mask: H x W boolean (or 0/1) array selecting the region.
        text: prompt to compare the region against.

    Returns:
        Cosine similarity as a float, or -1.0 when the mask is empty.

    Uses the notebook-level `clip_processor`, `clip_model` and `device`.
    """
    box = mask_to_box(mask)
    if box is None:
        return -1.0

    x1, y1, x2, y2 = box
    # Box corners are inclusive, hence the +1 on the slice ends.
    crop = image_rgb[y1:y2+1, x1:x2+1]
    mask_crop = np.asarray(mask[y1:y2+1, x1:x2+1], dtype=bool)

    if not mask_crop.any():
        return -1.0

    # Keep masked pixels, blank everything else.
    masked_crop = np.where(mask_crop[..., None], crop, 0).astype(crop.dtype)
    pil_img = Image.fromarray(masked_crop)

    inputs = clip_processor(
        text=[text],
        images=pil_img,
        return_tensors="pt",
        padding=True
    ).to(device)

    with torch.no_grad():
        outputs = clip_model(**inputs)

    img_emb = outputs.image_embeds
    txt_emb = outputs.text_embeds
    img_emb = img_emb / img_emb.norm(dim=-1, keepdim=True)
    txt_emb = txt_emb / txt_emb.norm(dim=-1, keepdim=True)
    return (img_emb @ txt_emb.T).item()

def visualize_result(image_rgb, mask, box, score, title="Result"):
    """Display the image with the mask tinted green and the box outlined in red.

    Args:
        image_rgb: RGB image array; it is copied, not modified.
        mask: boolean array used to index the image pixels to tint.
        box: (x1, y1, x2, y2) pixel corners, or None to skip the rectangle.
        score: number printed under the title with three decimals.
        title: first line of the figure title.
    """
    canvas = image_rgb.copy()
    tint = np.zeros_like(canvas)
    tint[..., 1] = 255
    # Inside the mask: 60 % original pixel + 40 % pure green
    canvas[mask] = (0.6 * canvas[mask] + 0.4 * tint[mask]).astype(np.uint8)
    if box is not None:
        left, top, right, bottom = box
        # (255, 0, 0) is red because the array is in RGB order
        cv2.rectangle(canvas, (left, top), (right, bottom), (255, 0, 0), 2)
    plt.figure(figsize=(6,6))
    plt.imshow(canvas)
    plt.title(f"{title}\nCLIP score: {score:.3f}")
    plt.axis("off")
    plt.show()

# -----------------------
# MAIN LOOP
# -----------------------
kept_images = 0      # images that passed the CLIP presence check
labeled_images = 0   # images for which a YOLO label file was written

# For each image of CLASS_NAME: filter with CLIP, segment with SAM, rank the
# masks with CLIP, merge the best TOP_K_MASKS and write one YOLO box.
for img_path in tqdm(list(class_dir.iterdir()), desc=f"Processing {CLASS_NAME}"):
    if img_path.suffix.lower() not in img_exts:
        continue

    try:
        image_rgb = load_image_cv2(img_path)
    except Exception as e:
        # Report instead of skipping silently, so unreadable files are visible.
        print("Skipping", img_path, "error:", e)
        continue

    # ---- IMAGE-LEVEL PRESENCE CHECK ----
    presence_score = clip_image_text_score(image_rgb, CLASS_PROMPT)
    if presence_score < IMAGE_PRESENCE_THRESHOLD:
        continue

    kept_images += 1

    # ---- SAM SEGMENTATION ----
    try:
        masks = mask_generator.generate(image_rgb)
    except Exception as e:
        print("SAM failed for", img_path, e)
        continue

    scored_masks = []
    for m in masks[:MAX_MASKS_TO_CHECK]:
        score = clip_mask_score(image_rgb, m["segmentation"], CLASS_PROMPT)
        if score > 0:
            scored_masks.append((score, m["segmentation"]))

    if not scored_masks:
        continue

    # ---- MERGE TOP-K MASKS ----
    scored_masks.sort(key=lambda x: x[0], reverse=True)
    top_masks = scored_masks[:TOP_K_MASKS]

    merged_mask = np.zeros_like(top_masks[0][1], dtype=bool)
    for _, m in top_masks:
        merged_mask |= m

    # --- merged mask always gets a box ---
    best_box = mask_to_box(merged_mask)
    if best_box is None:
        continue

    # The list is sorted descending, so the first entry has the best score.
    best_score = top_masks[0][0]

    # ---- WRITE YOLO LABEL ----
    h, w, _ = image_rgb.shape
    x1, y1, x2, y2 = best_box

    # mask_to_box returns inclusive pixel indices: the box covers
    # (x2 - x1 + 1) pixels and its far edges lie at x2 + 1 / y2 + 1.
    # Without the +1 a full-frame mask gets width (w - 1) / w and a
    # one-pixel mask gets a zero-size box.
    xc = ((x1 + x2 + 1) / 2) / w
    yc = ((y1 + y2 + 1) / 2) / h
    bw = (x2 - x1 + 1) / w
    bh = (y2 - y1 + 1) / h

    label_path = out_class_dir / f"{img_path.stem}.txt"
    with open(label_path, "w") as f:
        f.write(f"0 {xc:.6f} {yc:.6f} {bw:.6f} {bh:.6f}\n")

    labeled_images += 1

    # ---- VISUALIZE ALL IMAGES ----
    if SHOW_VISUALS:
        visualize_result(
            image_rgb=image_rgb,
            mask=merged_mask,
            box=best_box,
            score=best_score,
            title=f"{CLASS_NAME} | {img_path.name}"
        )

print("\n‚úÖ DONE:", CLASS_NAME)
print("Images with class present:", kept_images)
print("Images successfully labeled:", labeled_images)
print("Output folder:", out_class_dir)


In [None]:
from google.colab import drive
# Mount at /content/drive: the dataset, output and checkpoint paths in the
# cells below are all rooted at /content/drive/MyDrive, so mounting anywhere
# else (this cell used '/content/drives') leaves them pointing at nothing.
drive.mount('/content/drive')

In [None]:
# -----------------------
# 3) Paths & checkpoints
# -----------------------
import os
from pathlib import Path

# Each sub-folder of this directory is one class, named after the class
dataset_root = Path('/content/drive/MyDrive/data_managment_nutrion/food-101/images/Egyptian')

# YOLO-style labels are written below this directory (created if missing)
output_root = Path('/content/drive/MyDrive/auto_labeled_dataset')
output_root.mkdir(parents=True, exist_ok=True)

# Location of the SAM checkpoint on Drive; existence is checked before loading
sam_checkpoint = Path('/content/drive/MyDrive/models', 'sam_vit_h_4b8939.pth')

# -----------------------
# 4) Imports for the pipeline
# -----------------------
import cv2
import numpy as np
from PIL import Image
from tqdm import tqdm
import torch
import matplotlib.pyplot as plt
from transformers import CLIPProcessor, CLIPModel
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator

# -----------------------
# 5) Initialize SAM and CLIP
# -----------------------
sam_model_type = 'vit_h'

# Stop with a clear message before trying to load a missing checkpoint
checkpoint_missing = not sam_checkpoint.exists()
if checkpoint_missing:
    raise FileNotFoundError(f"SAM checkpoint not found at {sam_checkpoint}. Please download it first.")

# Run on the GPU when the runtime has one, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print('Using device:', device)

# SAM: build the model from the checkpoint and wrap it in the automatic mask generator
sam = sam_model_registry[sam_model_type](checkpoint=str(sam_checkpoint))
sam.to(device=device)
mask_generator = SamAutomaticMaskGenerator(sam)
print('‚úÖ SAM loaded successfully')

# CLIP: model and pre-processor are loaded from the same checkpoint name
clip_model_name = 'openai/clip-vit-base-patch32'
clip_processor = CLIPProcessor.from_pretrained(clip_model_name)
clip_model = CLIPModel.from_pretrained(clip_model_name)
clip_model = clip_model.to(device)
print('‚úÖ CLIP loaded successfully')

# -----------------------
# 6) Helper functions
# -----------------------
def load_image_cv2(path):
    """Read an image file with OpenCV and return it converted from BGR to RGB.

    Raises:
        RuntimeError: if `cv2.imread` returns None for `path`.
    """
    bgr = cv2.imread(str(path))
    if bgr is not None:
        return cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    raise RuntimeError(f'Failed to read image: {path}')

def mask_to_box(mask):
    """Return the tight bounding box around the truthy pixels of a 2-D mask.

    Returns:
        (x_min, y_min, x_max, y_max) as Python ints, both corners inclusive,
        or None when the mask contains no truthy pixel.
    """
    rows, cols = np.nonzero(mask)
    if cols.size == 0 or rows.size == 0:
        return None
    left, right = int(cols.min()), int(cols.max())
    top, bottom = int(rows.min()), int(rows.max())
    return left, top, right, bottom

def get_clip_score_for_mask(image_rgb, mask, text, clip_model, clip_processor, device):
    """Score how well the region selected by `mask` matches `text` with CLIP.

    The image is cropped to the mask's bounding box and every pixel outside
    the mask is set to black before encoding, so only the masked region
    contributes. Attaching the mask as an alpha channel does not achieve
    this: the CLIP processor converts its input to RGB, and Pillow's
    RGBA -> RGB conversion discards alpha without compositing.

    Args:
        image_rgb: H x W x 3 RGB array.
        mask: H x W boolean (or 0/1) array selecting the region.
        text: prompt to compare the region against.
        clip_model: CLIP model returning `image_embeds` / `text_embeds`.
        clip_processor: processor building the model inputs.
        device: device the inputs are moved to.

    Returns:
        Cosine similarity as a float, or -1.0 when the mask is empty.
    """
    box = mask_to_box(mask)
    if box is None:
        return -1.0
    x_min, y_min, x_max, y_max = box
    # Box corners are inclusive, hence the +1 on the slice ends.
    cropped = image_rgb[y_min:y_max+1, x_min:x_max+1]
    mask_crop = np.asarray(mask[y_min:y_max+1, x_min:x_max+1], dtype=bool)
    if not mask_crop.any():
        return -1.0
    # Keep masked pixels, blank everything else.
    masked_crop = np.where(mask_crop[..., None], cropped, 0).astype(cropped.dtype)
    pil_img = Image.fromarray(masked_crop)
    inputs = clip_processor(text=[text], images=pil_img, return_tensors='pt', padding=True).to(device)
    with torch.no_grad():
        outputs = clip_model(**inputs)
    img_emb, txt_emb = outputs.image_embeds, outputs.text_embeds
    img_emb = img_emb / img_emb.norm(dim=-1, keepdim=True)
    txt_emb = txt_emb / txt_emb.norm(dim=-1, keepdim=True)
    return (img_emb @ txt_emb.T).item()

# -----------------------
# 7) Main labeling loop
# -----------------------
mask_score_threshold = 0.18   # min CLIP score of the best mask to write a label
max_masks_to_check = 200      # cap on SAM masks scored per image
img_exts = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff', '.webp'}

# All subfolders are treated as classes; ids follow the sorted folder names
class_dirs = [d for d in dataset_root.iterdir() if d.is_dir()]
classes_sorted = sorted([d.name for d in class_dirs])
class_to_id = {c: i for i, c in enumerate(classes_sorted)}
print("üìÇ Found classes:", classes_sorted)

# For every image keep the single SAM mask with the highest CLIP score and,
# when it reaches the threshold, write its box as a one-line YOLO label.
for class_dir in class_dirs:
    # The prompt is the folder name with underscores replaced by spaces
    class_name = class_dir.name.replace('_', ' ')
    print(f"\nüîç Processing class: {class_name}")
    out_class_folder = output_root / class_dir.name
    out_class_folder.mkdir(parents=True, exist_ok=True)

    for img_path in tqdm(list(class_dir.iterdir()), desc=f"{class_name}"):
        if img_path.suffix.lower() not in img_exts:
            continue
        try:
            image_rgb = load_image_cv2(img_path)
        except Exception as e:
            print("Skipping", img_path, "error:", e)
            continue

        try:
            masks = mask_generator.generate(image_rgb)
        except Exception as e:
            print("SAM failed for", img_path, e)
            continue

        best_score, best_box = -1.0, None
        for mask_dict in masks[:max_masks_to_check]:
            score = get_clip_score_for_mask(image_rgb, mask_dict['segmentation'], class_name, clip_model, clip_processor, device)
            if score > best_score:
                best_score, best_box = score, mask_to_box(mask_dict['segmentation'])

        h, w, _ = image_rgb.shape
        label_file = out_class_folder / (img_path.stem + '.txt')
        # NOTE(review): when the threshold is not met, a label file left by an
        # earlier run is kept as-is — confirm whether it should be removed.
        if best_box is not None and best_score >= mask_score_threshold:
            x_min, y_min, x_max, y_max = best_box
            # mask_to_box returns inclusive pixel indices: the box covers
            # (x_max - x_min + 1) pixels, so add 1 to keep width and centre exact.
            xc, yc = ((x_min + x_max + 1) / 2) / w, ((y_min + y_max + 1) / 2) / h
            bw, bh = (x_max - x_min + 1) / w, (y_max - y_min + 1) / h
            class_id = class_to_id[class_dir.name]
            with open(label_file, 'w') as f:
                f.write(f"{class_id} {xc:.6f} {yc:.6f} {bw:.6f} {bh:.6f}\n")

print("\n‚úÖ Labeling complete! Check:", output_root)

# -----------------------
# 8) Visualize labeled examples
# -----------------------
examples_shown = 8   # max number of labelled images to display
shown = 0
# Derive the grid from examples_shown instead of hard-coding 2 x 4, so a
# larger example count cannot request a subplot index outside the grid.
grid_cols = 4
grid_rows = max(1, -(-examples_shown // grid_cols))  # ceiling division
fig = plt.figure(figsize=(12, 8))
for class_dir in class_dirs:
    for img_path in class_dir.iterdir():
        if shown >= examples_shown:
            break
        if img_path.suffix.lower() not in img_exts:
            continue
        # Only images that received a label file are shown
        label_path = output_root / class_dir.name / (img_path.stem + '.txt')
        if not label_path.exists():
            continue

        with open(label_path) as f:
            vals = f.read().strip().split()
        # Expect exactly one "class xc yc w h" record
        if len(vals) != 5:
            continue

        class_id, xc, yc, bw, bh = int(vals[0]), *map(float, vals[1:])
        image_rgb = load_image_cv2(img_path)
        h, w, _ = image_rgb.shape
        # Normalised centre/size back to pixel corners
        x_min, y_min = int((xc - bw/2) * w), int((yc - bh/2) * h)
        x_max, y_max = int((xc + bw/2) * w), int((yc + bh/2) * h)

        ax = fig.add_subplot(grid_rows, grid_cols, shown + 1)
        ax.imshow(image_rgb)
        ax.add_patch(plt.Rectangle((x_min, y_min), x_max - x_min, y_max - y_min,
                                   fill=False, color='red', linewidth=2))
        ax.set_title(class_dir.name)
        ax.axis('off')
        shown += 1
    if shown >= examples_shown:
        break

plt.tight_layout()
plt.show()


In [None]:
examples_shown = 8   # max number of labelled images to display
shown = 0
# Derive the grid from examples_shown instead of hard-coding 2 x 4, so a
# larger example count cannot request a subplot index outside the grid.
grid_cols = 4
grid_rows = max(1, -(-examples_shown // grid_cols))  # ceiling division
fig = plt.figure(figsize=(12, 8))
for class_dir in class_dirs:
    for img_path in class_dir.iterdir():
        if shown >= examples_shown:
            break
        if img_path.suffix.lower() not in img_exts:
            continue
        # Only images that received a label file are shown
        label_path = output_root / class_dir.name / (img_path.stem + '.txt')
        if not label_path.exists():
            continue

        with open(label_path) as f:
            vals = f.read().strip().split()
        # Expect exactly one "class xc yc w h" record
        if len(vals) != 5:
            continue

        class_id, xc, yc, bw, bh = int(vals[0]), *map(float, vals[1:])
        image_rgb = load_image_cv2(img_path)
        h, w, _ = image_rgb.shape
        # Normalised centre/size back to pixel corners
        x_min, y_min = int((xc - bw/2) * w), int((yc - bh/2) * h)
        x_max, y_max = int((xc + bw/2) * w), int((yc + bh/2) * h)

        ax = fig.add_subplot(grid_rows, grid_cols, shown + 1)
        ax.imshow(image_rgb)
        ax.add_patch(plt.Rectangle((x_min, y_min), x_max - x_min, y_max - y_min,
                                   fill=False, color='red', linewidth=2))
        ax.set_title(class_dir.name)
        ax.axis('off')
        shown += 1
    if shown >= examples_shown:
        break

plt.tight_layout()
plt.show()
