In [2]:
import os
import re
import numpy as np
import pandas as pd
from PIL import Image
from segment_anything import sam_model_registry, SamPredictor
from concurrent.futures import ProcessPoolExecutor, as_completed

In [3]:
# Define the SAM model outside of multiprocessing functions to avoid reloading it each time
model_type = "vit_h"
sam = sam_model_registry[model_type](checkpoint="./sam_vit_h_4b8939.pth")
predictor = SamPredictor(sam)

  state_dict = torch.load(f)


In [4]:
# Define paths for loading and saving images
input_dir = r"C:\Users\dave-\OneDrive - ZHAW\HS24\MoIm\MolecularIMaging\Images\test_images"
output_dir = r"C:\Users\dave-\OneDrive - ZHAW\HS24\MoIm\MolecularIMaging\Images\Output_test_images"
phase_mask_dir = os.path.join(output_dir, "phase_masks")
os.makedirs(output_dir, exist_ok=True)

In [5]:
def get_identifier(filename: str) -> str:
    """Extracts a unique identifier from a filename."""
    match = re.match(r"^([A-D]\d+_\d+).*_(\d{3})\.tif$", filename)
    return f"{match.group(1)}_{match.group(2)}" if match else None

def process_image_pair(identifier: str, phase_file: str, gfp_file: str, input_dir: str, output_dir: str, 
                       save_masks: bool = False, save_masked_images: bool = False) -> dict:
    """
    Processes a pair of 'Phase Contrast' and 'GFP' images, generates a binary mask using SAM, 
    calculates GFP intensity within the mask, and optionally saves the mask and masked images.
    
    Args:
        identifier (str): Unique identifier for the image pair.
        phase_file (str): Filename of the Phase Contrast image.
        gfp_file (str): Filename of the GFP image.
        input_dir (str): Directory containing the input images.
        output_dir (str): Directory to save the output files.
        save_masks (bool): If True, saves generated masks as PNG files.
        save_masked_images (bool): If True, saves masked GFP images.
    
    Returns:
        dict: Dictionary containing the identifier, mean GFP intensity, and total GFP intensity.
    """
    # Initialize SAM model in each process
    model_type = "vit_h"  # Model type can be "vit_h", "vit_l", or "vit_b"
    sam = sam_model_registry[model_type](checkpoint="./sam_vit_h_4b8939.pth")
    predictor = SamPredictor(sam)

    phase_mask_dir = os.path.join(output_dir, "phase_masks")
    if save_masks:
        os.makedirs(phase_mask_dir, exist_ok=True)
        
    # Load and process the Phase Contrast image
    phase_path = os.path.join(input_dir, phase_file)
    phase_image = Image.open(phase_path)
    phase_np = np.array(phase_image, dtype=np.uint16)
    phase_np = (phase_np / phase_np.max() * 255).astype(np.uint8)
    phase_rgb = np.stack([phase_np] * 3, axis=-1)

    # Set the image in SAM model for mask generation
    predictor.set_image(phase_rgb)
    input_point = np.array([[phase_rgb.shape[1] // 2, phase_rgb.shape[0] // 2]])
    input_label = np.array([1])

    # Generate mask
    masks, scores, _ = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        multimask_output=False
    )
    mask = masks[0]

    # Optionally save the generated mask
    if save_masks:
        mask_output_path = os.path.join(phase_mask_dir, f"mask_{os.path.splitext(phase_file)[0]}.png")
        Image.fromarray((mask * 255).astype(np.uint8)).save(mask_output_path)

    # Load and process the GFP image
    gfp_path = os.path.join(input_dir, gfp_file)
    gfp_image = Image.open(gfp_path)
    gfp_np = np.array(gfp_image, dtype=np.uint16)

    # Calculate GFP intensities
    gfp_values_within_fish = gfp_np[mask > 0]
    mean_gfp_intensity = gfp_values_within_fish.mean() if gfp_values_within_fish.size > 0 else 0
    total_gfp_intensity = gfp_values_within_fish.sum()

    # Optionally save the masked GFP image
    if save_masked_images:
        gfp_masked = gfp_np * mask
        output_path = os.path.join(output_dir, f"masked_{gfp_file}")
        Image.fromarray(gfp_masked.astype(np.uint16)).save(output_path)

    return {
        "Identifier": identifier,
        "Mean_GFP_Intensity": mean_gfp_intensity,
        "Total_GFP_Intensity": total_gfp_intensity
    }

def collect_image_pairs(input_dir: str):
    """
    Collects and pairs 'Phase Contrast' and 'GFP' image files based on a unique identifier.
    
    Args:
        input_dir (str): Directory containing the input images.
    
    Returns:
        list of tuples: Each tuple contains an identifier and the filenames for Phase Contrast and GFP images.
    """
    phase_contrast_files = {}
    gfp_files = {}

    for file_name in os.listdir(input_dir):
        if file_name.endswith(".tif"):
            identifier = get_identifier(file_name)
            if identifier:
                if "Phase Contrast" in file_name:
                    phase_contrast_files[identifier] = file_name
                elif "GFP" in file_name:
                    gfp_files[identifier] = file_name

    image_pairs = [
        (identifier, phase_contrast_files[identifier], gfp_files[identifier])
        for identifier in phase_contrast_files if identifier in gfp_files
    ]
    return image_pairs

def process_images(input_dir: str, output_dir: str, save_masks: bool = False, save_masked_images: bool = False) -> pd.DataFrame:
    """
    Uses multiprocessing to process all image pairs in parallel, generating masks, calculating GFP intensities, 
    and optionally saving the outputs.

    Args:
        input_dir (str): Directory containing the input images.
        output_dir (str): Directory to save the output files.
        save_masks (bool): If True, saves the generated masks as PNG files.
        save_masked_images (bool): If True, saves the masked GFP images.

    Returns:
        pd.DataFrame: A DataFrame containing GFP intensity data (mean and total) for each image pair.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Collect all Phase Contrast and GFP image pairs
    image_pairs = collect_image_pairs(input_dir)
    results = []

    # Use ProcessPoolExecutor for parallel processing
    with ProcessPoolExecutor() as executor:
        futures = [
            executor.submit(
                process_image_pair, identifier, phase_file, gfp_file, input_dir, output_dir, save_masks, save_masked_images
            )
            for identifier, phase_file, gfp_file in image_pairs
        ]

        # Collect results as tasks complete
        for future in as_completed(futures):
            try:
                result = future.result()
                results.append(result)
            except Exception as e:
                print(f"Error processing an image pair: {e}")

    # Convert results to a DataFrame and save as CSV
    results_df = pd.DataFrame(results)
    csv_output_path = os.path.join(output_dir, "gfp_intensity_results.csv")
    results_df.to_csv(csv_output_path, index=False)
    
    print(f"Results saved to {csv_output_path}")
    return results_df


In [6]:
# Run the parallelized image processing function
gfp_intensity_results_df = process_images(input_dir, output_dir, save_masks=True, save_masked_images=True)

Error processing an image pair: A process in the process pool was terminated abruptly while the future was running or pending.
Error processing an image pair: A process in the process pool was terminated abruptly while the future was running or pending.
Error processing an image pair: A process in the process pool was terminated abruptly while the future was running or pending.
Error processing an image pair: A process in the process pool was terminated abruptly while the future was running or pending.
Results saved to C:\Users\dave-\OneDrive - ZHAW\HS24\MoIm\MolecularIMaging\Images\Output_test_images\gfp_intensity_results.csv
