<a href="https://colab.research.google.com/github/Ace-Chrono/Coral_Lesion_Measurer/blob/main/Lesion_Measurer.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Pip installs

In [None]:
# Ultralytics provides the YOLO detectors used below for lesion and ruler detection.
!pip install ultralytics #This is where we get the YOLO packages

In [None]:
# Install SAM 2 from source as an editable package so `sam2.build_sam` / `sam2.sam2_image_predictor`
# can be imported later. Note that %cd changes the notebook's working directory.
!git clone https://github.com/facebookresearch/sam2.git sam2_repo
%cd sam2_repo
!pip install -e . --no-build-isolation

In [None]:
# ipympl backs `%matplotlib widget`, which the interactive manual-calibration GUI relies on.
!pip install ipympl

## Libraries

In [None]:
# Restart the kernel, presumably so the packages installed above become importable.
# Continue with the next cell once the restart has finished.
get_ipython().kernel.do_shutdown(restart=True)

{'status': 'ok', 'restart': True}

In [1]:
from google.colab import drive, files
import torch
from ultralytics import YOLO
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor
from google.colab import output
output.enable_custom_widget_manager()
%matplotlib widget
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import cv2
from PIL import Image
import datetime
import json
import pandas as pd
import ipywidgets as widgets
from ipywidgets import Button, Output, VBox, HBox
from IPython.display import display, clear_output
import os
import gc
import re

## Setup

In [2]:
# Mount Google Drive, choose the compute device and load the lesion/ruler detectors and SAM 2.
drive.mount('/content/gdrive/')
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
print(torch.__version__)
print(DEVICE)  # If this prints "cpu", switch the Colab runtime to a GPU runtime.

# All model weights live in one Drive folder.
MODEL_DIR = "/content/gdrive/MyDrive/Coral Lesion Research/Prototype Code/ML Models/"
lesion_bbox_model = YOLO(os.path.join(MODEL_DIR, "YOLOV11_Lesion.pt"))
ruler_bbox_model = YOLO(os.path.join(MODEL_DIR, "YOLOV11_Ruler.pt"))
sam_location = os.path.join(MODEL_DIR, "sam2.1_hiera_large.pt")
sam_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
# build_sam2 defaults to device="cuda"; pass DEVICE so SAM lands on the device selected above.
lesion_sam_model = SAM2ImagePredictor(build_sam2(sam_cfg, sam_location, device=DEVICE))

Drive already mounted at /content/gdrive/; to attempt to forcibly remount, call drive.mount("/content/gdrive/", force_remount=True).
2.6.0+cu124
cuda:0


## YOLO and SAM Processing

In [None]:
def open_image(image_path):
    """Load an image file as an RGB PIL image.

    Args:
        image_path: path of the image on disk.

    Returns:
        (image, height, width): the decoded RGB ``PIL.Image`` and its size in pixels.
    """
    # Decode inside ``with`` so the file handle is released, and force RGB so grayscale,
    # CMYK or RGBA files also yield the 3-channel layout that output_image_cv checks for.
    # The previous ``height, width, channels = shape`` unpack crashed on 2-D grayscale arrays.
    with Image.open(image_path) as src:
        image = src.convert("RGB")
    width, height = image.size
    return image, height, width

def image_info(image_name):
    """Parse the coral ID, capture date and repetition number from an image file stem.

    Example: ``"RRC_S1_ECA_LC-041_a_2021_05_28 (1)"`` -> ``("LC-041", "2021-05-28", 1)``.

    Returns:
        (id, date, repetition):
        - id: the first token starting with ``LC``, up to the next underscore or whitespace.
        - date: the first ``YYYY_MM_DD`` run, rewritten as ``"YYYY-MM-DD"``.
        - repetition: the integer inside a trailing ``(N)`` suffix.
        Any part that is not found is ``None``.
    """
    repetition = None
    suffix = re.search(r'\s*\((\d+)\)$', image_name)
    if suffix is None:
        stem = image_name
    else:
        # Strip the "(N)" duplicate marker so it cannot interfere with the other patterns.
        repetition = int(suffix.group(1))
        stem = re.sub(r'\s*\(\d+\)$', '', image_name)

    date_hit = re.search(r'(\d{4}_\d{2}_\d{2})', stem)
    date = date_hit.group(1).replace('_', '-') if date_hit else None

    id_hit = re.search(r'(LC[^_\s]*)', stem)
    coral_id = id_hit.group(1) if id_hit else None

    return coral_id, date, repetition

def get_conversion_ratio(image, image_name, ruler_length_cm=30.5, model=None):
    """Estimate the image scale (pixels per centimetre) from a detected ruler.

    Uses the first box returned by the ruler detector. That box is assumed to span a ruler of
    physical length ``ruler_length_cm``, so the box's longer side (in pixels) divided by that
    length gives the scale.

    Args:
        image: image passed to the detector's ``predict``.
        image_name: name used only in log messages.
        ruler_length_cm: physical length of the ruler in cm (default 30.5, the value this
            function previously hard-coded).
        model: detector exposing ``predict(image, verbose=False)``; defaults to the
            notebook-level ``ruler_bbox_model``.

    Returns:
        The ratio in pixels/cm, or ``None`` when no ruler is found or detection raises. Errors
        are printed rather than propagated, so a batch run keeps going and the image can be
        calibrated manually.
    """
    try:
        detector = ruler_bbox_model if model is None else model
        results = detector.predict(image, verbose=False)
        if len(results) > 0:
            boxes_xyxy = results[0].boxes.xyxy
            if len(boxes_xyxy) > 0:
                x_min, y_min, x_max, y_max = boxes_xyxy[0].tolist()
                # The ruler may lie horizontally or vertically, so measure the longer side.
                ruler_length_px = max(x_max - x_min, y_max - y_min)
                conversion_ratio = ruler_length_px / ruler_length_cm
                print(f"Ruler detected in {image_name}: {conversion_ratio:.2f} pixels/cm")
                return conversion_ratio
    except Exception as e:
        print(f"Error in ruler detection for {image_name}: {str(e)}")
    return None

def run_yolo_lesion(image, model=None):
    """Detect lesions and return their boxes as ``[[x_min, y_min, x_max, y_max], ...]``.

    Args:
        image: image passed to the detector's ``predict``.
        model: detector exposing ``predict(image, verbose=False)``; defaults to the
            notebook-level ``lesion_bbox_model``.

    Returns:
        A list of box coordinate lists taken from the last result object, or ``[]`` when the
        model returns no results. Previously an empty result raised UnboundLocalError.
    """
    detector = lesion_bbox_model if model is None else model
    results = detector.predict(image, verbose=False)
    if len(results) == 0:
        return []
    # Only the last result was ever kept; a single input image yields a single result anyway.
    return results[-1].boxes.xyxy.tolist()

def run_sam(image, bboxes):
    """Segment each box prompt with the SAM 2 predictor.

    Args:
        image: PIL image or array; converted with ``np.array`` before being handed to SAM.
        bboxes: sequence of ``[x_min, y_min, x_max, y_max]`` pixel boxes (one prompt each).

    Returns:
        The masks returned by ``lesion_sam_model.predict`` (single-mask output per box). When
        ``bboxes`` is empty, returns an empty float32 array of shape ``(0, H, W)`` without
        calling SAM; an empty prompt array is not a valid ``(N, 4)`` box input.
    """
    input_boxes = np.asarray(bboxes, dtype=np.float32).reshape(-1, 4)
    image_np = np.array(image)
    if len(input_boxes) == 0:
        return np.zeros((0,) + image_np.shape[:2], dtype=np.float32)

    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
        lesion_sam_model.set_image(image_np)
        masks, _, _ = lesion_sam_model.predict(
            point_coords=None,
            point_labels=None,
            box=input_boxes,               # shape: (N, 4)
            multimask_output=False,
        )
    return masks

## Mask Manipulation

In [4]:
def masks_to_polygons(masks):  # Creates polygons from a list of SAM masks
    """Trace each mask's external contours and return them as flattened polygons.

    Args:
        masks: iterable of masks; each must squeeze to a non-empty 2-D array. Masks whose
            maximum exceeds 1 are thresholded at > 127. Otherwise values are scaled by 255, and
            any pixel that is non-zero after the uint8 cast counts as foreground.

    Returns:
        A list with one entry per input mask. Each entry is a list of polygons
        ``[x0, y0, x1, y1, ...]`` and is empty for an all-zero mask. Empty masks keep their
        slot so callers can zip the result with the matching bounding boxes (they used to be
        dropped, which shifted every later polygon onto the wrong box).

    Raises:
        ValueError: if a mask is None, is not 2-D after squeezing, or has a zero-length side.
    """
    all_polygons = []

    for i, mask in enumerate(masks):
        # Check for None before squeezing: np.squeeze(None) yields a 0-d object array.
        if mask is None:
            raise ValueError(f"Mask {i} is None.")
        mask = np.squeeze(mask)  # Ensures (H, W)
        if mask.ndim != 2:
            raise ValueError(f"Mask {i} must be 2D after squeeze, got shape {mask.shape}")
        if mask.shape[0] == 0 or mask.shape[1] == 0:
            raise ValueError(f"Mask {i} has invalid shape {mask.shape}")
        if not np.any(mask):
            all_polygons.append([])  # keep alignment with the input masks
            continue

        # findContours requires a uint8 image. cv2.threshold keeps the input dtype, so a float
        # mask with values > 1 used to reach findContours as float; threshold with NumPy instead.
        if mask.max() > 1:
            binary_mask = np.where(mask > 127, 255, 0).astype(np.uint8)
        else:
            binary_mask = (mask * 255).astype(np.uint8)

        contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)

        polygons = []
        for contour in contours:
            # Simplify the contour (tolerance 0.1% of its length) to reduce the number of points.
            epsilon = 0.001 * cv2.arcLength(contour, True)
            approx = cv2.approxPolyDP(contour, epsilon, True)
            # approx is (K, 1, 2); flatten to [x0, y0, x1, y1, ...].
            polygons.append(approx.reshape(-1).tolist())

        all_polygons.append(polygons)

    return all_polygons

def get_perimeter(all_polygons, conversion_ratio):
    """Return the closed perimeter, in micrometres, of every polygon.

    Args:
        all_polygons: per-mask lists of flattened polygons ``[x0, y0, x1, y1, ...]``
            (the structure produced by masks_to_polygons).
        conversion_ratio: image scale in pixels per centimetre.

    Returns:
        A flat list with one perimeter (µm) per polygon. NOTE(review): a mask with several
        disjoint contours contributes several entries, so entry i is not necessarily lesion i.
    """
    perimeters = []
    for mask_polygons in all_polygons:
        for points in mask_polygons:
            # cv2.arcLength only accepts int32/float32 point arrays; np.array on Python ints
            # would default to int64. float32 represents pixel coordinates exactly.
            contour = np.asarray(points, dtype=np.float32).reshape(-1, 1, 2)
            perimeter_px = cv2.arcLength(contour, True)
            perimeter_cm = perimeter_px / conversion_ratio
            perimeters.append(perimeter_cm * 10_000)  # cm -> µm
    return perimeters

def get_areas_and_centers(masks, bboxes, conversion_ratio):
    """Compute each mask's area in µm² and each box's centre point.

    Args:
        masks: masks whose non-zero pixels are counted as lesion area.
        bboxes: ``[x_min, y_min, x_max, y_max]`` boxes.
        conversion_ratio: image scale in pixels per centimetre.

    Returns:
        (areas, centers): a list of areas in µm² (one per mask), and a list of
        ``(center_x, center_y)`` tuples (one per box).
    """
    # One pixel covers (10_000 / conversion_ratio) µm on each side.
    areas = [np.count_nonzero(m) * ((1 / conversion_ratio) * 10000) ** 2 for m in masks]
    centers = [((x0 + x1) / 2, (y0 + y1) / 2) for x0, y0, x1, y1 in bboxes]
    return areas, centers

## Outputting

In [None]:
def output_image_cv(image, bboxes, masks, segmentations, areas, centers, image_output_path):
    """Draw detections on a copy of ``image`` and save the result with OpenCV.

    Overlays, in order: bounding boxes (green), masks (20% blue tint), polygon outlines (red)
    and, at each centre, an area label "<area> um^2" in white on a black background.

    Args:
        image: RGB PIL image or NumPy array; the caller's object is not modified.
        bboxes: ``[x_min, y_min, x_max, y_max]`` pixel boxes.
        masks: masks (2-D, or 3-D with a singleton axis that is squeezed), assumed to be at
            image resolution.
        segmentations: per-mask lists of flattened polygons ``[x0, y0, x1, y1, ...]``.
        areas: values printed as labels, paired with ``centers`` by position.
        centers: ``(x, y)`` label anchors.
        image_output_path: destination file. A failed write is reported, not raised, so a
            batch run keeps going.
    """
    # Convert PIL Image to NumPy array if necessary
    if not isinstance(image, np.ndarray):
        image = np.array(image)

    img = image.copy()

    # Inputs are RGB, but OpenCV drawing colours and imwrite use BGR.
    if img.shape[-1] == 3:
        img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)

    # Draw bounding boxes (green)
    for box in bboxes:
        x1, y1, x2, y2 = map(int, box)
        cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 2)

    # Draw SAM masks (semi-transparent blue overlays)
    if len(masks) > 0:
        # The tint layer is identical for every mask, so build it once instead of per mask.
        alpha = 0.2
        color_mask = np.zeros_like(img, dtype=np.uint8)
        color_mask[:, :, 0] = 255  # Blue in BGR
        for mask in masks:
            if mask.dtype != np.uint8:
                mask = (mask * 255).astype(np.uint8)
            if mask.ndim == 3:
                mask = mask.squeeze()
            # Blend against the current image so overlapping masks tint cumulatively.
            mask_3ch = np.stack([mask] * 3, axis=-1)
            img = np.where(mask_3ch, (1 - alpha) * img + alpha * color_mask, img).astype(np.uint8)

    # Draw segmentation polygons (red outlines)
    for polygons in segmentations:
        for polygon in polygons:
            points = np.array(polygon, dtype=np.int32).reshape(-1, 2)
            cv2.polylines(img, [points], isClosed=True, color=(0, 0, 255), thickness=2)

    # Draw area annotations (white text with black background)
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 0.6
    thickness = 1
    for area, center in zip(areas, centers):
        x, y = map(int, center)
        text = f"{area:.2f} um^2"
        (text_w, text_h), _ = cv2.getTextSize(text, font, font_scale, thickness)
        cv2.rectangle(img, (x, y - text_h), (x + text_w, y), (0, 0, 0), -1)
        cv2.putText(img, text, (x, y - 2), font, font_scale, (255, 255, 255), thickness, cv2.LINE_AA)

    # cv2.imwrite signals failure (bad path, unwritable Drive folder) by returning False.
    if not cv2.imwrite(image_output_path, img):
        print(f"Warning: could not write annotated image to {image_output_path}")

def output_csv(folder_name, image_name, lesion_count, conversion_ratio, areas, perimeters, csv_output_path, columns=None):
    """Append one summary row for an image to the results CSV.

    The header is written only when the file does not exist yet.

    Args:
        folder_name: source folder label stored in "Folder".
        image_name: output image name stored in "Image Name".
        lesion_count: number of lesions stored in "# Lesions".
        conversion_ratio: image scale stored in "Pixels Per um". NOTE(review): the callers in
            this notebook pass pixels-per-cm ratios, so the column label does not match the unit.
        areas: per-lesion areas (µm²); written as a list in one cell.
        perimeters: per-contour perimeters (µm); may differ in length from ``areas``.
        csv_output_path: CSV file to append to.
        columns: column order; defaults to the notebook-level ``csv_columns``.
    """
    if columns is None:
        columns = csv_columns

    new_row = {
        "Folder": folder_name,
        "Image Name": image_name,
        "# Lesions": lesion_count,
        "Pixels Per um": conversion_ratio,
        # Convert to plain floats so the cells read like "[1.5, 2.0]" instead of NumPy reprs.
        # Build new lists rather than overwriting the caller's. Convert each list on its own,
        # since a mask can yield several contours (len(perimeters) != len(areas) in general).
        "um^2": [float(a) for a in areas],
        "Perimeters": [float(p) for p in perimeters],
    }
    new_row_df = pd.DataFrame([new_row], columns=columns)
    new_row_df.to_csv(csv_output_path, mode='a', header=not os.path.exists(csv_output_path), index=False)

def get_metadata(index, image_name, height, width, bboxes, segmentations):
    """Build the COCO ``image`` record and its ``annotation`` records for one image.

    Args:
        index: COCO image id; annotation ids are ``index * 1000 + position``.
        image_name: file name stored in the image record.
        height, width: image size in pixels.
        bboxes: ``[x_min, y_min, x_max, y_max]`` boxes, paired with ``segmentations`` by position.
        segmentations: polygon lists stored verbatim as each annotation's "segmentation".

    Returns:
        (image_record, annotations): a dict and a list of dicts. Each annotation's "bbox" is
        ``[x, y, w, h]`` and its "area" is the box area ``w * h``.
    """
    image_record = {
        "id": index,
        "license": 1,
        "file_name": image_name,
        "height": height,
        "width": width,
        "date_captured": datetime.datetime.now().isoformat(),
    }

    annotations = []
    for position, (box, polygons) in enumerate(zip(bboxes, segmentations)):
        x0, y0, x1, y1 = box
        box_w = x1 - x0
        box_h = y1 - y0
        annotations.append({
            "id": index * 1000 + position,  # unique while an image has < 1000 annotations
            "image_id": index,
            "category_id": 1,
            "bbox": [x0, y0, box_w, box_h],
            "area": box_w * box_h,
            "segmentation": polygons,
            "iscrowd": 0,
        })
    return image_record, annotations

def output_coco_json(image_info_list, annotations_list, output_path):
    """Write a COCO-format annotation file, overwriting ``output_path``.

    The given image and annotation records are wrapped with fixed dataset metadata:
    - an info section;
    - a single CC BY-NC-SA 2.0 license with id 1;
    - a single "coral lesion" category with id 1.
    The result is saved as JSON indented by 4 spaces.
    """
    info = {
        "description": "Coral Dataset",
        "version": "1.0",
        "year": 2025,
        "contributor": "Richard Zhao",
        "date_created": datetime.datetime.now().isoformat(),
    }
    license_entry = {
        "id": 1,
        "name": "Attribution-NonCommercial-ShareAlike License",
        "url": "http://creativecommons.org/licenses/by-nc-sa/2.0/",
    }
    category_entry = {
        "id": 1,
        "name": "coral lesion",
        "supercategory": "marine_life",
    }
    coco_dict = dict(
        info=info,
        licenses=[license_entry],
        images=image_info_list,
        annotations=annotations_list,
        categories=[category_entry],
    )

    with open(output_path, "w") as handle:
        json.dump(coco_dict, handle, indent=4)


def append_row_to_excel(file_path, new_row_dict):
    """Add one row to a per-coral Excel sheet, keeping at most one row per 'Date'.

    If ``file_path`` does not exist, a new workbook holding just this row is written.
    Otherwise the row is skipped (with a message) when its 'Date' is already present.
    If not, it is appended, the sheet is re-sorted by 'Date', and the whole file is rewritten.
    """
    new_row_df = pd.DataFrame([new_row_dict])

    if not os.path.exists(file_path):
        new_row_df.to_excel(file_path, index=False)
        return

    existing = pd.read_excel(file_path)
    if new_row_dict['Date'] in existing['Date'].values:
        print(f"Skipping duplicate date {new_row_dict['Date']} in {file_path}")
        return

    combined = pd.concat([existing, new_row_df], ignore_index=True).sort_values(by='Date')
    combined.to_excel(file_path, index=False)

## Manual Conversion Ratio GUI

In [6]:
class ClickCollector:
    """Interactive widget for measuring an image's scale by hand.

    Workflow:
    1. Enable click mode and click two points spanning a known real-world length.
    2. Press "Finish".
    3. Type that length in cm and submit.

    The resulting pixels-per-cm ratio is passed to ``on_done`` at most once.

    Args:
        image: image shown with ``imshow``.
        image_name: plot title.
        on_done: optional callable that receives the ratio (pixels/cm).
    """

    def __init__(self, image, image_name="Image", on_done=None):
        self.image = image
        self.image_name = image_name
        self.on_done = on_done
        self.coords = []        # clicked (x, y) points in data coordinates
        self.line = None        # line artist joining the two points, once drawn
        self.dots = []          # marker artists, one per click
        self.click_mode = False
        self._done = False      # True once a ratio has been delivered to on_done

        self.out = Output()

        # Setup plot
        with self.out:
            self.fig, self.ax = plt.subplots(figsize=(8,6))
            self.ax.imshow(self.image)
            self.ax.set_title(f"{self.image_name}")
            self.cid = self.fig.canvas.mpl_connect('button_press_event', self.onclick)
            plt.show()

        # Buttons
        self.click_button = Button(description="Enable Click Mode", button_style='primary')
        self.click_button.on_click(self.toggle_click_mode)

        self.clear_button = Button(description="Clear", button_style='danger')
        self.clear_button.on_click(self.clear)

        self.finish_button = Button(description="Finish", button_style='success')
        self.finish_button.on_click(self.finish)

        # Distance input & submit button, hidden until "Finish" is pressed
        self.dist_input = widgets.FloatText(description="Real distance (cm):")
        self.dist_submit = widgets.Button(description="Submit distance")
        self.dist_submit.on_click(self.submit_distance)
        self.dist_input.layout.display = 'none'
        self.dist_submit.layout.display = 'none'

        display(VBox([
            self.out,
            HBox([self.click_button, self.clear_button, self.finish_button]),
            VBox([self.dist_input, self.dist_submit])
        ]))

    def toggle_click_mode(self, b):
        """Switch between clicking points and using the toolbar's zoom/pan tools."""
        self.click_mode = not self.click_mode
        if self.click_mode:
            self.click_button.description = "Click Mode: ON (Click Image)"
            self.click_button.button_style = 'warning'
            print("🖱️ Click mode enabled: Click two points.")
        else:
            self.click_button.description = "Click Mode: OFF"
            self.click_button.button_style = 'primary'
            print("✋ Click mode disabled: Use zoom/pan tools.")

    def onclick(self, event):
        """Record a click on the image while click mode is on; draw the line after the second point."""
        if not self.click_mode or event.inaxes != self.ax:
            return
        x, y = event.xdata, event.ydata
        print(f"📍 Clicked at ({x:.1f}, {y:.1f})")
        self.coords.append((x, y))
        self.draw_dot(x, y)
        if len(self.coords) == 2:
            self.draw_line()
            # Leave click mode automatically so further clicks do not add points.
            self.toggle_click_mode(None)

    def draw_dot(self, x, y):
        """Mark a clicked point with a red dot."""
        dot = self.ax.plot(x, y, 'ro', markersize=6)[0]
        self.dots.append(dot)
        self.fig.canvas.draw()

    def draw_line(self):
        """Draw (or redraw) the red segment between the first two clicked points."""
        x_vals = [self.coords[0][0], self.coords[1][0]]
        y_vals = [self.coords[0][1], self.coords[1][1]]
        if self.line:
            self.line.remove()
        self.line, = self.ax.plot(x_vals, y_vals, 'r-', linewidth=2)
        self.fig.canvas.draw()

    def clear(self, b):
        """Remove all clicked points and the line so the measurement can be redone."""
        for dot in self.dots:
            dot.remove()
        self.dots = []

        if self.line:
            self.line.remove()
            self.line = None

        self.coords = []
        self.fig.canvas.draw()
        print("🧹 Cleared all points and lines.")

    def finish(self, b):
        """Validate that exactly two points exist, then reveal the distance input."""
        if len(self.coords) != 2:
            with self.out:
                print("Please click exactly 2 points before finishing.")
            return
        with self.out:
            print(f"✅ Line set from {self.coords[0]} to {self.coords[1]}")
            print("Please enter the real-world distance (cm) below:")
        self.dist_input.layout.display = None
        self.dist_submit.layout.display = None

    def submit_distance(self, b):
        """Compute pixels/cm from the clicked segment and the entered distance, then call ``on_done``.

        Repeated submissions are ignored after the first successful one. Each ``on_done``
        call advances the caller's queue, so a double click would otherwise skip an image.
        """
        if getattr(self, "_done", False):
            return
        # Points may have been cleared after "Finish"; re-check before indexing them.
        if len(self.coords) != 2:
            with self.out:
                print("Please click exactly 2 points before finishing.")
            return
        dist = self.dist_input.value
        if dist <= 0:
            with self.out:
                print("Distance must be positive.")
            return
        pixel_dist = np.linalg.norm(np.array(self.coords[0]) - np.array(self.coords[1]))
        ratio = pixel_dist / dist
        with self.out:
            print(f"➡️ Pixel distance: {pixel_dist:.2f}")
            print(f"➡️ Real distance: {dist:.2f} cm")
            print(f"➡️ Conversion ratio: {ratio:.2f} pixels/cm")
        self._done = True
        self.dist_submit.disabled = True
        self._release_figure()
        if self.on_done:
            self.on_done(ratio)

    def _release_figure(self):
        """Close this collector's figure so pyplot does not keep every calibration image alive."""
        fig = getattr(self, "fig", None)
        if fig is not None:
            plt.close(fig)

## Define Path Information

In [7]:
# Inputs: every sub-folder of the input root (found recursively), plus the root itself.
image_input_root = "/content/gdrive/MyDrive/Coral Lesion Research/Prototype Code/Input/"

image_input_path = [
    os.path.join(root, d)
    for root, dirs, _ in os.walk(image_input_root)
    for d in dirs
]
image_input_path.append(image_input_root)

# Outputs: annotated PNGs and the summary CSV go to image_output_path. Per-coral Excel
# workbooks go to its "areas" and "perimeters" sub-folders.
image_output_path = "/content/gdrive/MyDrive/Coral Lesion Research/Prototype Code/Output/"
# os.path.join avoids the doubled "/" that concatenation produced (the root already ends in "/").
csv_output_path = os.path.join(image_output_path, "coral_areas_output.csv")
# NOTE(review): the "Pixels Per um" column receives the pixels-per-cm ratio computed by the
# measurer. The header is kept unchanged so existing CSV files stay appendable.
csv_columns = ["Folder", "Image Name", "# Lesions", "Pixels Per um", "um^2", "Perimeters"]

areas_folder = os.path.join(image_output_path, "areas")
perimeters_folder = os.path.join(image_output_path, "perimeters")
os.makedirs(areas_folder, exist_ok=True)
os.makedirs(perimeters_folder, exist_ok=True)

## Run Measurer

In [None]:
# Automatic pass: measure every image whose ruler the detector finds, and queue the rest
# for manual calibration further below.
pending_conversion = {}  # image name -> (image, height, width, source path)

# One COCO file covers every input folder (all PNGs are written to image_output_path), so the
# metadata lists and image ids accumulate across folders instead of restarting per folder.
# Previously the file was overwritten per folder and ids collided.
image_metadata_list = []
annotation_metadata_list = []
json_output_path = os.path.join(image_output_path, "_annotations.coco.json")
image_index = 0

for folder in image_input_path:
    folder_name = os.path.basename(folder.rstrip('/'))

    # sorted() makes the processing order, and therefore the COCO image ids, reproducible.
    for file in sorted(os.listdir(folder)):
        # Accept .jpg/.jpeg in any letter case (previously only an exact ".JPG" matched).
        if not file.lower().endswith(('.jpg', '.jpeg')):
            continue

        old_image_path = os.path.join(folder, file)
        image_name, _ = os.path.splitext(file)
        image_name_png = image_name + ".png"
        new_image_path = os.path.join(image_output_path, image_name_png)
        image, height, width = open_image(old_image_path)

        conversion_ratio = get_conversion_ratio(image, image_name)
        if conversion_ratio is None:
            pending_conversion[image_name] = (image, height, width, old_image_path)
            continue

        bboxes_lesion = run_yolo_lesion(image)
        masks_lesion = run_sam(image, bboxes_lesion)
        segmentations = masks_to_polygons(masks_lesion)
        perimeters = get_perimeter(segmentations, conversion_ratio)
        areas, centers = get_areas_and_centers(masks_lesion, bboxes_lesion, conversion_ratio)
        coral_id, date, repetition = image_info(image_name)  # avoid shadowing builtin id()

        # Per-coral Excel time series; append_row_to_excel keeps one row per date.
        area_row = {'Date': date, **{f'Area {i+1}': a for i, a in enumerate(areas)}}
        perim_row = {'Date': date, **{f'Perimeter {i+1}': p for i, p in enumerate(perimeters)}}
        append_row_to_excel(os.path.join(areas_folder, f"{coral_id}_areas.xlsx"), area_row)
        append_row_to_excel(os.path.join(perimeters_folder, f"{coral_id}_perimeters.xlsx"), perim_row)

        output_image_cv(image, bboxes_lesion, masks_lesion, segmentations, areas, centers, new_image_path)
        output_csv(folder_name, image_name_png, len(masks_lesion), conversion_ratio, areas, perimeters, csv_output_path)

        image_metadata, annotation_metadata = get_metadata(image_index, image_name_png, height, width, bboxes_lesion, segmentations)
        image_metadata_list.append(image_metadata)
        # get_metadata returns a list of annotations; COCO expects one flat list, not a list of lists.
        annotation_metadata_list.extend(annotation_metadata)

        image_index += 1
        # Free the large per-image arrays before the next image (GPU memory included).
        del image, bboxes_lesion, masks_lesion, segmentations, areas, centers
        gc.collect()
        torch.cuda.empty_cache()

# Written even when no lesions were found: images without annotations are still valid COCO.
if image_metadata_list:
    output_coco_json(image_metadata_list, annotation_metadata_list, json_output_path)

# Manual calibration for images without a detected ruler.
# resolved_ratios maps image name -> pixels/cm as each ratio is submitted in the GUI;
# pending_items fixes the order in which the queued images are shown.
resolved_ratios = {}

pending_items = list(pending_conversion.items())
current_idx = 0

main_output = widgets.Output()
display(main_output)

manual_output_path = os.path.join(image_output_path, "manual_conversion")
os.makedirs(manual_output_path, exist_ok=True)
# COCO output for the manually calibrated images, filled in by run_analysis().
image_metadata_list = []
annotation_metadata_list = []
json_output_path = os.path.join(manual_output_path, "_annotations.coco.json")

def process_next():
    """Show the calibration GUI for the next queued image, or run the analysis when none remain.

    Advances the global ``current_idx`` through ``pending_items``. Each ClickCollector reports
    its ratio through ``handle_ratio``, which records it in ``resolved_ratios`` and calls back
    into this function.
    """
    global current_idx
    if current_idx >= len(pending_items):
        with main_output:
            clear_output()
            print("All ratios collected. Now running analysis...\n")
        run_analysis()
        return

    image_name, (image, _height, _width, _source_path) = pending_items[current_idx]
    current_idx += 1

    def handle_ratio(ratio):
        # Ignore repeated callbacks for the same image (e.g. "Submit distance" clicked twice).
        # Otherwise every extra call would advance the queue and skip an image.
        if image_name in resolved_ratios:
            return
        resolved_ratios[image_name] = ratio
        process_next()

    with main_output:
        clear_output(wait=True)
        ClickCollector(image, image_name, on_done=handle_ratio)

def run_analysis():
    """Measure every manually calibrated image and write its outputs.

    Ratios come from ``resolved_ratios`` (pixels/cm) and images from ``pending_conversion``.
    Outputs:
    - rows in the shared per-coral Excel files and summary CSV;
    - annotated PNGs in ``manual_output_path``;
    - a COCO file at the global ``json_output_path``.
    """
    image_index = 0
    for image_name, conversion_ratio in resolved_ratios.items():
        if conversion_ratio is None:
            print(f"Skipping image {image_name} due to missing conversion ratio.")
            continue

        image, height, width, old_image_path = pending_conversion[image_name]
        image_name_png = image_name + ".png"
        new_image_path = os.path.join(manual_output_path, image_name_png)

        bboxes_lesion = run_yolo_lesion(image)
        masks_lesion = run_sam(image, bboxes_lesion)
        segmentations = masks_to_polygons(masks_lesion)
        perimeters = get_perimeter(segmentations, conversion_ratio)
        areas, centers = get_areas_and_centers(masks_lesion, bboxes_lesion, conversion_ratio)
        coral_id, date, repetition = image_info(image_name)  # avoid shadowing builtin id()

        # Excel: per-coral time series, one row per date
        area_row = {'Date': date, **{f'Area {i+1}': a for i, a in enumerate(areas)}}
        perim_row = {'Date': date, **{f'Perimeter {i+1}': p for i, p in enumerate(perimeters)}}
        append_row_to_excel(os.path.join(areas_folder, f"{coral_id}_areas.xlsx"), area_row)
        append_row_to_excel(os.path.join(perimeters_folder, f"{coral_id}_perimeters.xlsx"), perim_row)

        # Output image + CSV
        output_image_cv(image, bboxes_lesion, masks_lesion, segmentations, areas, centers, new_image_path)
        output_csv("manual_conversion", image_name_png, len(masks_lesion), conversion_ratio, areas, perimeters, csv_output_path)

        image_metadata, annotation_metadata = get_metadata(image_index, image_name_png, height, width, bboxes_lesion, segmentations)
        image_metadata_list.append(image_metadata)
        # get_metadata returns a list of annotations; COCO expects one flat list, not a list of lists.
        annotation_metadata_list.extend(annotation_metadata)

        image_index += 1
        del image, bboxes_lesion, masks_lesion, segmentations, areas, centers
        gc.collect()
        torch.cuda.empty_cache()

    # Written even when no lesions were found: images without annotations are still valid COCO.
    if image_metadata_list:
        output_coco_json(image_metadata_list, annotation_metadata_list, json_output_path)

# Start the manual-calibration GUI only when some images had no detectable ruler.
if len(pending_conversion) > 0:
    process_next()

Ruler detected in RRC_S1_ECA_LC-041_a_2021_05_28 (1): 66.60 pixels/mm
Skipping duplicate date 2021-05-28 in /content/gdrive/MyDrive/Coral Lesion Research/Prototype Code/Output/areas/LC-041_areas.xlsx
Skipping duplicate date 2021-05-28 in /content/gdrive/MyDrive/Coral Lesion Research/Prototype Code/Output/perimeters/LC-041_perimeters.xlsx
peak memory: 3682.16 MiB, increment: 653.26 MiB


Output()