<a href="https://colab.research.google.com/github/KasiR07/MealLens-Inc/blob/main/SAM.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import torch
import cv2
import numpy as np
import os
import matplotlib.pyplot as plt


# Attempt to import necessary modules from SAM2 package

In [None]:
try:
    from sam2.sam2_image_predictor import SAM2ImagePredictor
    from sam2.build_sam import build_sam2
    from sam2.automatic_mask_generator import SAM2AutomaticMaskGenerator
except ImportError as error:
    print(f"ImportError: {error}")
    print("Make sure 'sam2_env' is activated and dependencies are installed from the 'sam2' directory.")
    raise

# File Path Configuration

In [5]:
BASE_DIR = 'https://github.com/KasiR07/MealLens-Inc/blob/main/Org_Image1.jpg'
CHECKPOINT_FOLDER = os.path.join(BASE_DIR, 'checkpoints')
CONFIG_FILE = os.path.join(BASE_DIR, 'sam2', 'configs', 'sam2.1', 'sam2.1_hiera_b+.yaml')
MODEL_FILE = 'sam2.1_hiera_base_plus.pt'
CHECKPOINT_PATH = os.path.join(CHECKPOINT_FOLDER, MODEL_FILE)
IMG_DIR = 'C:/Users/akhil/food_images_sample'

print("Libraries and modules loaded successfully.")


Libraries and modules loaded successfully.


# Load SAM2 Model

In [7]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Running on: {device}")

try:
    if not os.path.isfile(CHECKPOINT_PATH):
        raise FileNotFoundError(f"Checkpoint not found at: {CHECKPOINT_PATH}")
    if not os.path.isfile(CONFIG_FILE):
        raise FileNotFoundError(f"Config file not found at: {CONFIG_FILE}")

    # Load and build the SAM2 model
    model = build_sam2(
        config_file=CONFIG_FILE,
        ckpt_path=CHECKPOINT_PATH,
        device=device
    )
    model.eval()

    # Initialize mask generator with model and desired parameters
    mask_generator = SAM2AutomaticMaskGenerator(
        model=model,
        points_per_side=16,
        pred_iou_thresh=0.86,
        stability_score_thresh=0.92,
        min_mask_region_area=100,
        output_mode="binary_mask",
        multimask_output=True
    )

    print(f"Model {MODEL_FILE} initialized successfully on {device}.")

except Exception as err:
    print(f"Model loading failed: {err}")
    import traceback
    traceback.print_exc()
    raise


Running on: cpu
Model loading failed: Checkpoint not found at: https://github.com/KasiR07/MealLens-Inc/blob/main/Org_Image1.jpg/checkpoints/sam2.1_hiera_base_plus.pt


Traceback (most recent call last):
  File "<ipython-input-7-981014112>", line 6, in <cell line: 0>
    raise FileNotFoundError(f"Checkpoint not found at: {CHECKPOINT_PATH}")
FileNotFoundError: Checkpoint not found at: https://github.com/KasiR07/MealLens-Inc/blob/main/Org_Image1.jpg/checkpoints/sam2.1_hiera_base_plus.pt


FileNotFoundError: Checkpoint not found at: https://github.com/KasiR07/MealLens-Inc/blob/main/Org_Image1.jpg/checkpoints/sam2.1_hiera_base_plus.pt

# Load Image Paths

In [None]:
image_files = []
if os.path.isdir(IMG_DIR):
    image_files = [os.path.join(IMG_DIR, f) for f in os.listdir(IMG_DIR)
                   if f.lower().endswith(('.png', '.jpg', '.jpeg', '.gif', '.bmp', '.tiff'))]
    print(f"{len(image_files)} image(s) found in directory: {IMG_DIR}")
    if not image_files:
        raise FileNotFoundError("No supported image files found.")
else:
    raise FileNotFoundError(f"Directory not found: {IMG_DIR}")


# Perform Segmentation

In [None]:
results = {}
print("Beginning segmentation on available images...")

for img_path in image_files:
    name = os.path.basename(img_path)
    print(f"Processing: {name}")

    img_bgr = cv2.imread(img_path)
    if img_bgr is None:
        print(f"Failed to read: {img_path}")
        continue

    img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)

    try:
        masks = mask_generator.generate(img_rgb)
        results[name] = masks
        print(f"Generated {len(masks)} mask(s) for {name}")
    except Exception as e:
        print(f"Error processing {name}: {e}")
        import traceback
        traceback.print_exc()

print(f"Segmentation completed for {len(results)} image(s).")

# Display First Image with Masks Overlay

In [None]:
if results:
    first_img = list(results.keys())[0]
    first_path = os.path.join(IMG_DIR, first_img)
    img_np = cv2.cvtColor(cv2.imread(first_path), cv2.COLOR_BGR2RGB)
    masks_info = results[first_img]

    if masks_info:
        plt.figure(figsize=(10, 10))
        plt.imshow(img_np)
        ax = plt.gca()
        ax.set_autoscale_on(False)

        h, w = img_np.shape[:2]
        overlay = np.zeros((h, w, 4), dtype=np.uint8)

        for mask_info in sorted(masks_info, key=lambda x: x['area'], reverse=True):
            mask = mask_info['segmentation']
            color = np.concatenate([np.random.randint(0, 256, 3), [int(255 * 0.6)]])
            overlay[mask] = color

        plt.imshow(overlay)
        plt.title(f"Segmentation Preview: {first_img}")
        plt.axis('off')
        plt.show()
    else:
        print(f"No masks available for {first_img}.")
else:
    print("No segmentation results found to display.")