In [1]:
import os
import re
import numpy as np
import pandas as pd
from PIL import Image
from segment_anything import sam_model_registry, SamPredictor

In [2]:
# Load SAM Model
model_type = "vit_l"  # Model type can be "vit_h", "vit_l", or "vit_b"

# Checkpoint file for each supported model type. Selecting the checkpoint
# through this table keeps it in sync with `model_type`, instead of relying
# on commenting lines in and out by hand.
sam_checkpoints = {
    "vit_h": "./sam_vit_h_4b8939.pth",
    "vit_l": "./sam_vit_l_0b3195.pth",
    "vit_b": "./sam_vit_b_01ec64.pth",
}
if model_type not in sam_checkpoints:
    raise ValueError(
        f"Unknown model_type {model_type!r}; expected one of {sorted(sam_checkpoints)}"
    )

sam = sam_model_registry[model_type](checkpoint=sam_checkpoints[model_type])
predictor = SamPredictor(sam)  # shared predictor used by process_images

  state_dict = torch.load(f)


In [3]:
def get_identifier(filename: str) -> str:
    """
    Derive the key used to pair images from a file name.

    The name has to begin with a letter A-D followed by digits, an underscore and
    more digits, and has to end with an underscore, exactly three digits and ".tif".
    The key is the leading part and the three-digit part joined by an underscore.
    Returns None when the file name does not follow this pattern.
    """
    found = re.match(r"^([A-D]\d+_\d+).*_(\d{3})\.tif$", filename)
    if not found:
        return None
    leading_part, trailing_digits = found.groups()
    return f"{leading_part}_{trailing_digits}"

def split_mask(mask: np.array) -> tuple:
    """
    Divide a mask along its second axis (columns) into a head part and a tail part.

    The head keeps the columns left of one third of the width (integer division)
    and is zero elsewhere; the tail keeps the remaining columns and is zero in the
    head region. Both results have the shape and dtype of the input, and the input
    itself is not modified.
    """
    boundary = mask.shape[1] // 3  # first column that belongs to the tail

    # Start from copies and blank out the region that belongs to the other part.
    head_mask = mask.copy(order="K")
    head_mask[:, boundary:] = 0

    tail_mask = mask.copy(order="K")
    tail_mask[:, :boundary] = 0

    return head_mask, tail_mask

def _read_uint16(path: str) -> np.ndarray:
    """Read an image file into a uint16 array and close the file handle right away."""
    with Image.open(path) as image:
        return np.array(image, dtype=np.uint16)


def _save_binary_mask(mask: np.ndarray, path: str) -> None:
    """Write a mask to disk as an 8-bit image (mask values scaled by 255)."""
    Image.fromarray((mask * 255).astype(np.uint8)).save(path)


def _region_stats(gfp: np.ndarray, region: np.ndarray) -> tuple:
    """Return (mean, total) of the GFP pixels where ``region > 0``; mean is 0 for an empty region."""
    values = gfp[region > 0]
    mean_value = values.mean() if values.size > 0 else 0
    # Force a 64-bit accumulator: the default for uint16 input is the platform
    # integer, which is 32-bit on some platforms and can wrap for large regions.
    total_value = values.sum(dtype=np.uint64)
    return mean_value, total_value


def process_images(input_dir: str, output_dir: str,
                   save_masks: bool = True, save_masked_images: bool = True) -> pd.DataFrame:
    """
    Processes pairs of 'Phase Contrast' and 'GFP' images to generate a binary mask using SAM,
    calculate GFP intensity within the mask, and save mask and masked images for head, tail, and whole fish.

    Files in ``input_dir`` ending in ".tif" are paired through ``get_identifier``; a phase
    contrast image without a GFP partner is skipped. For every pair the phase contrast image
    is segmented with the module-level ``predictor`` using a single positive prompt at the
    image centre, the mask is split with ``split_mask``, and mean/total GFP intensity is
    measured for the whole mask, the head part and the tail part.

    Args:
        input_dir: Directory containing the ".tif" images.
        output_dir: Directory for all outputs; created if missing.
        save_masks: If True, save whole/head/tail masks as PNG files in the
            "phase_masks", "head_masks" and "tail_masks" subdirectories.
        save_masked_images: If True, save the GFP image multiplied by each mask as
            16-bit PNG files directly in ``output_dir``.

    Returns:
        DataFrame with one row per successfully processed pair. The same table is
        written to "gfp_intensity_results_vit_l.csv" in ``output_dir``.

    Errors raised while processing a pair are printed and that pair is skipped, so one
    bad image does not abort the whole run.
    """
    # Create subdirectories for saving masks and masked images
    phase_mask_dir = os.path.join(output_dir, "phase_masks")
    head_mask_dir = os.path.join(output_dir, "head_masks")
    tail_mask_dir = os.path.join(output_dir, "tail_masks")
    os.makedirs(output_dir, exist_ok=True)
    if save_masks:
        for mask_dir in (phase_mask_dir, head_mask_dir, tail_mask_dir):
            os.makedirs(mask_dir, exist_ok=True)

    phase_contrast_files = {}
    gfp_files = {}

    # Sorted so that processing order (and CSV row order) is reproducible;
    # os.listdir returns entries in arbitrary order.
    for file_name in sorted(os.listdir(input_dir)):
        if not file_name.endswith(".tif"):
            continue
        identifier = get_identifier(file_name)
        if not identifier:
            continue
        if "Phase Contrast" in file_name:
            phase_contrast_files[identifier] = file_name
        elif "GFP" in file_name:
            gfp_files[identifier] = file_name

    gfp_intensity_results = []

    for i, (identifier, phase_file) in enumerate(phase_contrast_files.items()):
        if identifier in gfp_files:
            try:
                phase_path = os.path.join(input_dir, phase_file)
                gfp_path = os.path.join(input_dir, gfp_files[identifier])
                phase_stem = os.path.splitext(phase_file)[0]

                # Load Phase Contrast image and rescale it to 8 bit
                phase_np = _read_uint16(phase_path)
                peak = phase_np.max()
                if peak > 0:
                    phase_np = (phase_np / peak * 255).astype(np.uint8)
                else:
                    # All-zero image: avoid dividing by zero during normalization.
                    phase_np = np.zeros(phase_np.shape, dtype=np.uint8)
                phase_rgb = np.stack([phase_np] * 3, axis=-1)

                predictor.set_image(phase_rgb)

                # Generate primary mask from one foreground point at the image centre (x, y)
                input_point = np.array([[phase_rgb.shape[1] // 2, phase_rgb.shape[0] // 2]])
                input_label = np.array([1])
                masks, _, _ = predictor.predict(
                    point_coords=input_point,
                    point_labels=input_label,
                    multimask_output=False
                )
                primary_mask = masks[0]

                # Save whole fish mask
                if save_masks:
                    _save_binary_mask(
                        primary_mask,
                        os.path.join(phase_mask_dir, f"mask_{phase_stem}.png"),
                    )

                # Split mask into head and tail
                head_mask, tail_mask = split_mask(primary_mask)

                # Save head and tail masks
                if save_masks:
                    _save_binary_mask(
                        head_mask,
                        os.path.join(head_mask_dir, f"head_mask_{phase_stem}.png"),
                    )
                    _save_binary_mask(
                        tail_mask,
                        os.path.join(tail_mask_dir, f"tail_mask_{phase_stem}.png"),
                    )

                # Load GFP image
                gfp_np = _read_uint16(gfp_path)

                # Calculate GFP intensities for whole fish, head, and tail
                mean_gfp_intensity, total_gfp_intensity = _region_stats(gfp_np, primary_mask)
                mean_gfp_head, total_gfp_head = _region_stats(gfp_np, head_mask)
                mean_gfp_tail, total_gfp_tail = _region_stats(gfp_np, tail_mask)

                # Save masked GFP images
                if save_masked_images:
                    masked_outputs = (
                        (primary_mask, f"masked_{identifier}.png"),
                        (head_mask, f"masked_head_{identifier}.png"),
                        (tail_mask, f"masked_tail_{identifier}.png"),
                    )
                    for region, out_name in masked_outputs:
                        masked = (gfp_np * region).astype(np.uint16)
                        Image.fromarray(masked).save(os.path.join(output_dir, out_name))

                # Append GFP intensity results
                gfp_intensity_results.append({
                    "Identifier": identifier,
                    "Mean_GFP_Intensity": mean_gfp_intensity,
                    "Total_GFP_Intensity": total_gfp_intensity,
                    "Mean_GFP_Head": mean_gfp_head,
                    "Total_GFP_Head": total_gfp_head,
                    "Mean_GFP_Tail": mean_gfp_tail,
                    "Total_GFP_Tail": total_gfp_tail
                })

            except Exception as e:
                # Deliberate best-effort: report the failing pair and continue with the rest.
                print(f"Error processing {identifier}: {e}")

        if (i + 1) % 5 == 0:
            print(f"Processed {i + 1}/{len(phase_contrast_files)} image pairs.")

    results_df = pd.DataFrame(gfp_intensity_results)
    csv_output_path = os.path.join(output_dir, "gfp_intensity_results_vit_l.csv")
    results_df.to_csv(csv_output_path, index=False)

    print(f"Results saved to {csv_output_path}")
    return results_df

In [4]:
# Define paths
# The defaults are the original workstation locations. Set the environment
# variables MOIM_INPUT_DIR / MOIM_OUTPUT_DIR to run the notebook on another
# machine without editing this cell.
input_dir = os.environ.get("MOIM_INPUT_DIR", r"C:\Daten\MoIm\Images")
output_dir = os.environ.get("MOIM_OUTPUT_DIR", r"C:\Daten\MoIm\Output\Model_L")
os.makedirs(output_dir, exist_ok=True)

In [5]:
# Run the pipeline over every image pair, saving both the masks and the masked GFP images
gfp_intensity_results_df = process_images(
    input_dir=input_dir,
    output_dir=output_dir,
    save_masks=True,
    save_masked_images=True,
)

Processed 5/1281 image pairs.
Processed 10/1281 image pairs.
Processed 15/1281 image pairs.
Processed 20/1281 image pairs.
Processed 25/1281 image pairs.
Processed 30/1281 image pairs.
Processed 35/1281 image pairs.
Processed 40/1281 image pairs.
Processed 45/1281 image pairs.
Processed 50/1281 image pairs.
Processed 55/1281 image pairs.
Processed 60/1281 image pairs.
Processed 65/1281 image pairs.
Processed 70/1281 image pairs.
Processed 75/1281 image pairs.
Processed 80/1281 image pairs.
Processed 85/1281 image pairs.
Processed 90/1281 image pairs.
Processed 95/1281 image pairs.
Processed 100/1281 image pairs.
Processed 105/1281 image pairs.
Processed 110/1281 image pairs.
Processed 115/1281 image pairs.
Processed 120/1281 image pairs.
Processed 125/1281 image pairs.
Processed 130/1281 image pairs.
Processed 135/1281 image pairs.
Processed 140/1281 image pairs.
Processed 145/1281 image pairs.
Processed 150/1281 image pairs.
Processed 155/1281 image pairs.
Processed 160/1281 image pairs.

### Runtimes for 8 image pairs
Base model (`vit_b`): 1 min 3.5 s (imperfect masks)

Large model (`vit_l`): 2 min 35.9 s (perfect masks)

Huge model (`vit_h`): 4 min 53 s (perfect masks)