In [None]:
import numpy as np
import cv2
import matplotlib.pyplot as plt
import sys
import os

# --- Configuration: locate the 'nutri-snap' project root ---

current_executed_dir = os.getcwd()
print(f"Current executed_dir (os.getcwd()): {current_executed_dir}")

# Supported launch locations:
#   * 'nutri-snap/notebooks/' -> the project root is the parent directory
#   * 'nutri-snap/'           -> the project root is the working directory
# Any other location falls back to the working directory, with a warning.

# Start with the parent directory as the candidate root; it is kept only when
# the working directory really is 'nutri-snap/notebooks'.
project_root = os.path.abspath(os.path.join(current_executed_dir, os.pardir))
if not (
    os.path.basename(current_executed_dir) == 'notebooks'
    and os.path.basename(project_root) == 'nutri-snap'
):
    if os.path.basename(current_executed_dir) != 'nutri-snap':
        # Unrecognized layout: warn, then assume the working directory anyway.
        print(f"WARNING: Unexpected directory structure. Attempting to use '{current_executed_dir}' as the nutri-snap project root.")
        print("Please run this script from the 'nutri-snap' directory or the 'nutri-snap/notebooks' directory.")
    project_root = current_executed_dir  # May require manual adjustment in the fallback case

print(f"Project root (nutri-snap folder): {project_root}")

# Make project-local packages importable (added once, at the end of sys.path).
if project_root not in sys.path:
    sys.path.append(project_root)

# --- User Parameters ---
DISH_ID = 'dish_1558640849'  # ID of the dish to visualize
SPLIT = 'train'  # Dataset split, e.g. 'train' or 'test'

# Processed data is read from <project_root>/data/processed
base_processed_dir = os.path.join(project_root, *('data', 'processed'))
print(f"Using base_processed_dir: {base_processed_dir}")

# Per-dish inputs follow <base>/<split>/<kind>/<dish_id>/<dish_id>.<ext>
sam_instance_file_path = os.path.join(
    base_processed_dir, SPLIT, 'sam_instance', DISH_ID, DISH_ID + '.npy'
)
rgb_image_path = os.path.join(
    base_processed_dir, SPLIT, 'rgb', DISH_ID, DISH_ID + '.png'
)

print(f"Attempting to load SAM masks from: {sam_instance_file_path}")
print(f"Attempting to load RGB image from: {rgb_image_path}")

# Mask palette: use matplotlib colormaps when available. DEFAULT_COLORS (and
# cm = None) exist only when the import fails.
try:
    import matplotlib.cm as cm
except ImportError:
    # Fixed fallback palette, stored in OpenCV's BGR channel order.
    DEFAULT_COLORS = [
        (255, 0, 0),
        (0, 255, 0),
        (0, 0, 255),
        (255, 255, 0),
        (0, 255, 255),
        (255, 0, 255),
        (192, 192, 192),
        (128, 0, 0),
        (0, 128, 0),
        (0, 0, 128),
    ]
    cm = None
    print("Warning: matplotlib.cm could not be imported. Mask colors might be less distinct.")

# Fill color (black) for the canvas used when the RGB image is not used/found.
background_color_bgr = (0, 0, 0)

# --- Load RGB image (for dimensions and optional background) ---
# A None result from cv2.imread is treated as "image unavailable"; the later
# mask-loading step then supplies the dimensions and a black background.
img_rgb = cv2.imread(rgb_image_path)
if img_rgb is not None:
    # OpenCV loads BGR; matplotlib expects RGB.
    img_rgb = cv2.cvtColor(img_rgb, cv2.COLOR_BGR2RGB)
    image_height, image_width = img_rgb.shape[0], img_rgb.shape[1]
    print(f"RGB image {DISH_ID} loaded. Dimensions: {image_width}x{image_height}")
    # Draw on a copy so the loaded image itself stays untouched.
    display_image = img_rgb.copy()
else:
    print(f"Warning: Could not load RGB image from {rgb_image_path}. Dimensions will be taken from masks. Black background used.")
    display_image = None
    image_height = image_width = -1  # Sentinel: dimensions still unknown

# --- Load and Draw SAM Instance Masks ---
# Loads the mask stack from the .npy file, paints each instance in its own
# color on top of the RGB image (or on a plain canvas when the image is missing
# or has a different size) and shows the result with matplotlib.
try:
    sam_masks = np.load(sam_instance_file_path)
    print(f"SAM instance file '{sam_instance_file_path}' loaded successfully.")

    # Only a non-empty (N, H, W) stack can be drawn.
    if sam_masks.ndim != 3 or sam_masks.shape[0] == 0:
        print(f"  Error: The .npy file does not contain a valid stack of masks. Expected shape (N, H, W), got {sam_masks.shape}.")
        raise ValueError("Invalid mask format")

    num_masks, h_mask, w_mask = sam_masks.shape
    print(f"Number of instance masks found: {num_masks}, Mask dimensions: {w_mask}x{h_mask}")

    if image_height == -1:  # RGB image was not loaded
        image_height, image_width = h_mask, w_mask
        print(f"Image dimensions set from masks: {image_width}x{image_height}")

    # The RGB image can only serve as background when it has the mask size.
    if display_image is not None and (h_mask, w_mask) != (image_height, image_width):
        print(f"  Warning: Mask dimensions ({w_mask}x{h_mask}) do not match RGB image ({image_width}x{image_height}).")
        display_image = None

    if display_image is None:
        # The canvas must have the MASK dimensions: boolean-mask indexing
        # below requires mask and canvas to share the same height and width.
        # Reversing the BGR fill color yields the RGB order matplotlib expects.
        display_image = np.full(
            (h_mask, w_mask, 3), background_color_bgr[::-1], dtype=np.uint8
        )

    # One display color (RGB, 0-255) per mask.
    if cm:
        # 'tab20' resampled to num_masks entries. plt.get_cmap is used because
        # matplotlib.cm.get_cmap no longer exists in recent Matplotlib releases.
        colormap = plt.get_cmap('tab20', num_masks)
        colors_for_display = [
            tuple(int(channel * 255) for channel in colormap(i)[:3])
            for i in range(num_masks)
        ]
    else:
        # Fallback palette is declared in BGR order; flip it to RGB.
        colors_for_display = [tuple(color[::-1]) for color in DEFAULT_COLORS]

    # Paint on a copy; later masks overwrite earlier ones where they overlap.
    overlay = display_image.copy()
    for i, mask in enumerate(sam_masks):
        # Boolean indexing needs a bool mask; the palette is cycled when there
        # are more masks than colors.
        overlay[mask.astype(np.bool_, copy=False)] = colors_for_display[i % len(colors_for_display)]
    # Plain replacement, no transparency. For a translucent overlay, blend
    # `overlay` with the previous `display_image` here instead.
    display_image = overlay

    # --- Display with Matplotlib ---
    plt.figure(figsize=(12, 10))  # Adjust figure size if necessary
    plt.imshow(display_image)  # Image is already in RGB
    plt.title(f"Visualization of {num_masks} SAM instance masks for {DISH_ID} (Split: {SPLIT})")
    plt.axis('off')  # Hide axes
    plt.show()

except FileNotFoundError:
    print(f"ERROR: SAM instance file '{sam_instance_file_path}' or RGB image '{rgb_image_path}' not found. Please check the displayed paths, DISH_ID, and SPLIT.")
except ImportError:
    print("ERROR: Please install necessary libraries: pip install numpy opencv-python matplotlib")
except Exception as e:
    # Notebook-level boundary: report the problem with a traceback instead of
    # aborting the cell.
    print(f"An error occurred during visualization: {e}")
    import traceback
    traceback.print_exc()