# SAM2 Evaluation Pipeline - Colab Notebook

This notebook provides an interactive way to run the SAM2 evaluation pipeline.

**Workflow:**
1.  **Setup:** Clone repository, install dependencies, and install the required SAM2 library.
2.  **Configuration:** Set parameters for the pipeline (model, data paths, etc.).
3.  **Data Preparation:** Generate the `degradation_map.json` (assumes image data exists).
4.  **(Optional) Visualization:** Inspect sample images and masks.
5.  **Run Pipeline:** Execute the evaluation using the configured settings.
6.  **View Results:** Load and display the output CSV.

## 1. Setup Environment

In [None]:
# Check if running in Colab
import sys
IN_COLAB = 'google.colab' in sys.modules

# Base directory for the project
# If in Colab, clone the repo. Otherwise, assume we are running from the repo root.
import os

if IN_COLAB:
    print('Running in Colab, cloning repository...')
    # TODO: Replace with your repo URL (use token if private)
    # Example: !git clone https://<your_token>@github.com/YOUR_USERNAME/SAM2_analysis.git
    %git clone https://github.com/YOUR_USERNAME/SAM2_analysis.git # <-- REPLACE THIS
    %cd SAM2_analysis
    PROJECT_ROOT = '/content/SAM2_analysis'
else:
    print('Running locally, assuming current directory is project root.')
    # Find the project root assuming this notebook is in the root
    PROJECT_ROOT = os.path.abspath('.')
    # Verify by checking for a known file/directory
    if not os.path.exists(os.path.join(PROJECT_ROOT, 'main.py')):
        print(f'Warning: Could not confirm project root at {PROJECT_ROOT}')

print(f'Project Root: {PROJECT_ROOT}')
os.chdir(PROJECT_ROOT) # Ensure we are in the project root directory

In [None]:
# Install the project's Python dependencies from requirements.txt.
# The relative path is resolved against the current working directory, which
# the setup cell changes to PROJECT_ROOT via os.chdir.
print('\nInstalling dependencies...')
%pip install -r requirements.txt

In [None]:
# Install the SAM2 library
# Assumes the sam2 code is located in 'external/sam2' within the project
print('\nInstalling SAM2 library...')
SAM2_DIR = os.path.join(PROJECT_ROOT, 'external/sam2')

if not os.path.exists(SAM2_DIR):
    print(f'Error: SAM2 directory not found at {SAM2_DIR}')
    print('Please ensure you have cloned the SAM2 repository into external/sam2')
    # Optional: Add command to clone it if missing
    # print('Attempting to clone SAM2...')
    # !git clone <SAM2_REPO_URL> external/sam2 # <-- Add SAM2 repo URL if desired
else:
    # Use pip install -e for editable install
    %pip install -e "{SAM2_DIR}"

## 2. Configuration

In [None]:
# --- Pipeline Configuration ---
# Same layout as sam2_eval_config.json, expressed as a Python dict.
import json

config = {
    "pipeline_name": "sam2_eval",
    "description": "Evaluate SAM2 auto-mask generator on data map (Colab)",

    # --- Data ---
    # Location of the generated data map, relative to the project root.
    "data_path": "data/degradation_map.json",

    # Directory that the filepaths stored in the data map
    # (data_path['versions'][*]['filepath']) are relative to.
    "image_base_dir": "data",  # i.e. images under data/images, data/pic_degraded/*

    # --- Model ---
    # Hugging Face identifier of the SAM2 checkpoint, e.g.
    # 'facebook/sam2-hiera-tiny', 'facebook/sam2-hiera-small',
    # 'facebook/sam2-hiera-base', 'facebook/sam2-hiera-large'.
    "model_hf_id": "facebook/sam2-hiera-tiny",  # small model for faster testing

    # --- Mask generator ---
    # Keyword arguments for SAM2AutomaticMaskGenerator
    # (see the SAM2 library documentation for the full list).
    "generator_config": {
        "points_per_side": 16,           # lower for faster processing
        "pred_iou_thresh": 0.80,         # default: 0.88
        "stability_score_thresh": 0.90,  # default: 0.95
        "crop_n_layers": 0,              # default: 0 (no cropping)
        "min_mask_region_area": 10       # default: 0
    },

    # --- Evaluation metrics ---
    "iou_threshold": 0.5,  # for matching a predicted mask to a GT mask
    "bf1_tolerance": 2,    # pixel tolerance for the Boundary F1 score

    # --- Output ---
    # Directory for the results CSV, relative to the project root.
    "output_dir": "output",
    "results_filename_prefix": "results_colab_"
}

# Turn the project-relative locations into absolute paths for later cells.
for path_key in ('data_path', 'image_base_dir', 'output_dir'):
    config[path_key] = os.path.join(PROJECT_ROOT, config[path_key])

print("Configuration set:")
print(json.dumps(config, indent=2))

## 3. Data Preparation

**IMPORTANT:** This section assumes the necessary image files (e.g., in `data/images/`, `data/pic_degraded/*`) and any COCO annotation files needed by `code_json.py` are already present in the `data/` directory of your Colab environment/mounted drive.

If they are not present, you need to:
1. Run `data/data_scripts/code_degradation.py` locally first (if you haven't).
2. Upload the entire `data/` directory to Colab or sync via Google Drive.

In [None]:
import os # Ensure os is imported if running cells independently

# Ensure data and output directories exist
DATA_DIR = os.path.join(PROJECT_ROOT, 'data')
OUTPUT_DIR = config['output_dir'] # Use absolute path from config
os.makedirs(DATA_DIR, exist_ok=True)
os.makedirs(OUTPUT_DIR, exist_ok=True)


print(f'Checking for image base directory: {config["image_base_dir"]}')
print(f'Expected output map path: {config["data_path"]}')

# Check for a key data file (e.g., the script expects COCO annotations)
expected_coco_file = os.path.join(DATA_DIR, 'annotations', 'instances_val2017_100.json') # Example path
if not os.path.exists(expected_coco_file):
    print(f"Warning: Expected COCO annotation file not found at {expected_coco_file}. "
          f"The 'code_json.py' script might fail if it relies on this.")
# Add checks for other necessary data/image directories if needed

# Run the script to generate the degradation_map.json
print('\nRunning script to generate degradation_map.json...')
# Note: Ensure code_json.py uses relative paths correctly or adjust it if needed
%python data/data_scripts/code_json.py

# Verify the map was created
data_map_path = config['data_path'] # Use absolute path from config
if os.path.exists(data_map_path):
    print(f'Successfully generated {data_map_path}')
else:
    print(f'Error: {data_map_path} was not generated. Check data availability and script output.')
    # Add more detailed error checking if the script provides specific logs

## 4. (Optional) Visualize Data Sample

In [None]:
import json
import random
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import pycocotools.mask as mask_util # Import for RLE decoding
import os # Ensure os is imported

def _decode_gt_mask(gt_rle):
    """Decode a ground-truth RLE into a mask, or return None if that is not possible."""
    if not gt_rle:
        return None
    if isinstance(gt_rle, str):
        # A bare counts string cannot be decoded without the mask size, e.g.
        # {'size': [height, width], 'counts': gt_rle}; that size is not read here.
        print("Warning: GT RLE is string, size info needed for decoding.")
        return None
    if not isinstance(gt_rle, dict):
        print(f"Warning: Unexpected GT RLE format: {type(gt_rle)}")
        return None
    try:
        return mask_util.decode(gt_rle)
    except Exception as e:  # best-effort: a bad mask must not abort the preview
        print(f'  Could not decode GT RLE: {e}')
        return None


def _collect_versions(item_data, image_base_dir):
    """Return a list of {'title', 'path'} dicts for every image version of one entry.

    Two layouts of ``item_data['versions']`` are recognised:
    a flat entry ``{type: {'filepath': ...}}`` and a nested entry
    ``{type: {level: {'filepath': ...}}}``. Anything else is ignored.
    File paths are joined onto ``image_base_dir``.
    """
    versions = item_data.get('versions')
    if not isinstance(versions, dict):
        return []

    collected = []
    for degradation_type, levels_or_data in versions.items():
        if not isinstance(levels_or_data, dict):
            continue
        if 'filepath' in levels_or_data:  # flat entry, e.g. 'original'
            level = levels_or_data.get('level', 'N/A')
            collected.append({
                'title': f'{degradation_type}\n(Level: {level})',
                'path': os.path.join(image_base_dir, levels_or_data['filepath']),
            })
            continue
        # Nested levels like {'1': {...}, '2': {...}}
        for level, version_data in levels_or_data.items():
            if isinstance(version_data, dict) and 'filepath' in version_data:
                level_val = version_data.get('level', level)  # prefer the stored level
                collected.append({
                    'title': f'{degradation_type}_{level}\n(Level: {level_val})',
                    'path': os.path.join(image_base_dir, version_data['filepath']),
                })
    return collected


def visualize_sample(data_map_path, image_base_dir):
    """Load the data map, pick a random image, and display its versions and GT mask.

    Args:
        data_map_path: Path to the JSON data map (an object keyed by image id).
        image_base_dir: Directory that the 'filepath' values in the map are
            relative to.

    Returns:
        None. Problems (missing file, invalid JSON, unexpected structure,
        unreadable images) are reported with ``print`` instead of raising.
    """
    if not os.path.exists(data_map_path):
        print(f'Cannot visualize: {data_map_path} not found.')
        return

    with open(data_map_path, 'r', encoding='utf-8') as f:
        try:
            data_map = json.load(f)
        except json.JSONDecodeError as e:
            print(f"Error reading data map JSON: {e}")
            return

    if not data_map:
        print('Cannot visualize: Data map is empty.')
        return
    if not isinstance(data_map, dict):
        # Valid JSON, but not the expected {image_id: entry} object.
        print(f'Cannot visualize: expected a JSON object keyed by image id, '
              f'got {type(data_map).__name__}.')
        return

    image_id = random.choice(list(data_map))
    print(f'Visualizing sample for image_id: {image_id}')
    item_data = data_map[image_id]
    if not isinstance(item_data, dict):
        print(f'Cannot visualize: entry for {image_id} is not a JSON object.')
        return

    gt_mask = _decode_gt_mask(item_data.get('ground_truth_rle'))
    versions_to_plot = _collect_versions(item_data, image_base_dir)

    # One column per image version, plus one for the GT mask when it decoded.
    plot_cols = len(versions_to_plot) + (1 if gt_mask is not None else 0)
    if plot_cols == 0:
        print("No image versions or GT mask found to plot.")
        return

    # squeeze=False always yields a 2-D axes array, so one column needs no special case.
    fig, axes_grid = plt.subplots(1, plot_cols, figsize=(5 * plot_cols, 5), squeeze=False)
    axes = axes_grid[0]

    # Display versions
    for ax, version_info in zip(axes, versions_to_plot):
        img_path = version_info['path']
        title = version_info['title']
        try:
            img = Image.open(img_path).convert('RGB')
            ax.imshow(img)
            ax.set_title(title)
        except FileNotFoundError:
            print(f'  Image not found: {img_path}')
            ax.set_title(f'{title}\n(Not Found)')
        except Exception as e:  # keep plotting the remaining versions
            print(f"Error loading image {img_path}: {e}")
            ax.set_title(f'{title}\n(Load Error)')
        finally:
            ax.axis('off')

    # Display GT mask in the last column, which was reserved for it above.
    if gt_mask is not None:
        gt_ax = axes[-1]
        gt_ax.imshow(gt_mask, cmap='gray')
        gt_ax.set_title('Ground Truth Mask')
        gt_ax.axis('off')

    plt.tight_layout()
    plt.show()

# --- Run visualization ---
# Preview one random sample, but only if the data map was actually generated.
data_map_path = config['data_path']
image_base_dir = config['image_base_dir']
if not os.path.exists(data_map_path):
    print(f"Skipping visualization because data map not found: {data_map_path}")
else:
    visualize_sample(data_map_path, image_base_dir)


## 5. Run Pipeline

In [None]:
import sys
import os # Ensure os is imported

# Make the project's modules importable from this notebook.
if PROJECT_ROOT not in sys.path:
    sys.path.append(PROJECT_ROOT)

# Import the pipeline entry point. On any failure the name is bound to None
# so that the execution step below is skipped instead of raising.
try:
    from sam2_eval_pipeline import run_evaluation_pipeline
except ImportError as e:
    print(f'Error importing pipeline function: {e}')
    print('Ensure installation steps completed correctly and you are in the project root.')
    run_evaluation_pipeline = None
except Exception as e:
    print(f"An unexpected error occurred during import: {e}")
    run_evaluation_pipeline = None
else:
    print('Imported run_evaluation_pipeline successfully.')


# Execute the pipeline
results_df = None
if not run_evaluation_pipeline:
    print("Skipping pipeline execution due to import failure.")
else:
    print('\nStarting evaluation pipeline...')
    try:
        # Each keyword argument is taken from the same-named key of the
        # config dictionary (paths there are already absolute).
        results_df = run_evaluation_pipeline(**{
            option: config[option]
            for option in (
                'data_path',
                'image_base_dir',
                'model_hf_id',
                'generator_config',
                'iou_threshold',
                'bf1_tolerance',
                'output_dir',
                'results_filename_prefix',
            )
        })
        if results_df is None:
            # The function might return None on failure/no data
            print('Pipeline function completed, but returned None (possibly no data processed or an error occurred). Check logs.')
        else:
            print(f'Pipeline finished successfully. Results saved in {config["output_dir"]}')
    except Exception as e:
        print(f"An error occurred during pipeline execution: {e}")
        import traceback
        traceback.print_exc()  # full traceback for debugging


## 6. View Results

In [None]:
import pandas as pd
import os # Ensure os is imported

output_path = config['output_dir']  # Use absolute path
prefix = config['results_filename_prefix']

# `results_df` only exists if the pipeline cell ran in this session. Default it
# to None so the CSV fallback below still works after a kernel restart.
try:
    results_df
except NameError:
    results_df = None


def _show_table(df):
    """Display *df*, using Colab's interactive DataTable when it is available."""
    try:
        from google.colab.data_table import DataTable
    except ImportError:
        display(df)  # Fallback for non-Colab
    else:
        display(DataTable(df))


if results_df is not None and not results_df.empty:
    print('Displaying results DataFrame returned from pipeline:')
    # Configure pandas display options if needed
    # pd.set_option('display.max_rows', None)
    # pd.set_option('display.max_columns', None)
    _show_table(results_df)

elif os.path.exists(output_path):
    # Try to find the latest results CSV in the output directory if DF is empty/None
    print('Results DataFrame not available directly. Trying to load latest CSV from output directory...')
    try:
        csv_files = [f for f in os.listdir(output_path) if f.startswith(prefix) and f.endswith('.csv')]
        if csv_files:
            # Find the most recently modified CSV
            latest_csv = max(csv_files, key=lambda f: os.path.getmtime(os.path.join(output_path, f)))
            latest_csv_path = os.path.join(output_path, latest_csv)
            print(f'Loading latest results file: {latest_csv_path}')
            results_df_loaded = pd.read_csv(latest_csv_path)
            print('Displaying loaded results DataFrame:')
            _show_table(results_df_loaded)
        else:
            print(f'No results CSV files found in {output_path} matching prefix "{prefix}"')
    except FileNotFoundError:
        # The directory or a listed file vanished between the checks above.
        print(f'Output directory {output_path} seems to have disappeared.')
    except Exception as e:
        print(f'Error loading or listing results CSV: {e}')
else:
    print(f"Neither results DataFrame nor output directory ({output_path}) found.")
